from __future__ import annotations import base64 from dataclasses import dataclass from datetime import datetime, timezone import hashlib import hmac import json from typing import Any from urllib import error as urllib_error from urllib import parse as urllib_parse from urllib import request as urllib_request from fastapi import HTTPException, status from app.config import Settings @dataclass(frozen=True) class AuthContext: subject: str auth_type: str scopes: set[str] class AuthBackend: def __init__( self, settings: Settings, *, oauth_introspection_url: str | None = None, oauth_client_id: str | None = None, oauth_client_secret: str | None = None, oauth_issuer: str | None = None, oauth_audience: str | None = None, oauth_timeout_seconds: float = 8.0, ) -> None: self.settings = settings self.oauth_introspection_url = oauth_introspection_url self.oauth_client_id = oauth_client_id self.oauth_client_secret = oauth_client_secret self.oauth_issuer = oauth_issuer self.oauth_audience = oauth_audience self.oauth_timeout_seconds = oauth_timeout_seconds def is_enabled(self) -> bool: if self.settings.auth_mode == "api_key": return bool(self.settings.agent_api_key.strip()) if self.settings.auth_mode == "jwt": return bool(self.settings.auth_jwt_secret.strip()) if self.settings.auth_mode == "hybrid": return bool( self.settings.agent_api_key.strip() or self.settings.auth_jwt_secret.strip() ) if self.settings.auth_mode == "oauth": return bool((self.oauth_introspection_url or "").strip()) return bool(self.settings.agent_api_key.strip()) def authenticate( self, *, x_api_key: str | None, authorization: str | None, required_scopes: set[str], ) -> AuthContext: mode = self.settings.auth_mode if mode == "api_key": return self._authenticate_api_key( x_api_key=x_api_key, authorization=authorization, required_scopes=required_scopes, ) if mode == "jwt": return self._authenticate_jwt(authorization=authorization, required_scopes=required_scopes) if mode == "hybrid": return self._authenticate_hybrid( x_api_key=x_api_key, authorization=authorization, required_scopes=required_scopes, ) if mode == "oauth": return self._authenticate_oauth( authorization=authorization, required_scopes=required_scopes, ) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Unsupported AUTH_MODE '{mode}'.", ) def _authenticate_hybrid( self, *, x_api_key: str | None, authorization: str | None, required_scopes: set[str], ) -> AuthContext: if self.settings.agent_api_key: api_key = self._resolve_api_key( x_api_key=x_api_key, authorization=authorization, ) if api_key == self.settings.agent_api_key: return AuthContext(subject="api-key", auth_type="api_key", scopes={"*"}) if self.settings.auth_jwt_secret: return self._authenticate_jwt( authorization=authorization, required_scopes=required_scopes, ) return self._authenticate_api_key( x_api_key=x_api_key, authorization=authorization, required_scopes=required_scopes, ) def _authenticate_api_key( self, *, x_api_key: str | None, authorization: str | None, required_scopes: set[str], ) -> AuthContext: expected = self.settings.agent_api_key if not expected: return AuthContext(subject="anonymous", auth_type="none", scopes={"*"}) provided = self._resolve_api_key( x_api_key=x_api_key, authorization=authorization, ) if provided != expected: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.", ) # API key access is treated as admin-level compatibility mode. return AuthContext(subject="api-key", auth_type="api_key", scopes={"*"}) def _authenticate_jwt( self, *, authorization: str | None, required_scopes: set[str], ) -> AuthContext: secret = self.settings.auth_jwt_secret if not secret: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="AUTH_JWT_SECRET is required in jwt auth mode.", ) token = self._resolve_bearer(authorization) if not token: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Bearer token.", ) claims = _decode_hs256_jwt(token=token, secret=secret) _validate_jwt_claims( claims=claims, expected_issuer=self.settings.auth_jwt_issuer, expected_audience=self.settings.auth_jwt_audience, ) scope_values = _extract_scopes(claims) if required_scopes and not required_scopes.issubset(scope_values): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Missing required scope.", ) subject = str(claims.get("sub") or "jwt-subject") return AuthContext(subject=subject, auth_type="jwt", scopes=scope_values) def _authenticate_oauth( self, *, authorization: str | None, required_scopes: set[str], ) -> AuthContext: if not self.oauth_introspection_url: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="OAuth introspection URL is required in oauth auth mode.", ) token = self._resolve_bearer(authorization) if not token: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Bearer token.", ) claims = self._introspect_oauth_token(token) if not bool(claims.get("active")): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive OAuth token.", ) _validate_jwt_claims( claims=claims, expected_issuer=self.oauth_issuer, expected_audience=self.oauth_audience, ) scope_values = _extract_scopes(claims) if required_scopes and not required_scopes.issubset(scope_values): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Missing required scope.", ) subject = str(claims.get("sub") or claims.get("client_id") or "oauth-subject") return AuthContext(subject=subject, auth_type="oauth", scopes=scope_values) def _resolve_api_key(self, *, x_api_key: str | None, authorization: str | None) -> str | None: if x_api_key: return x_api_key token = self._resolve_bearer(authorization) return token def _resolve_bearer(self, authorization: str | None) -> str | None: if not authorization: return None parts = authorization.split(" ", 1) if len(parts) == 2 and parts[0].lower() == "bearer": return parts[1] return None def _introspect_oauth_token(self, token: str) -> dict[str, Any]: if self.oauth_client_secret and not self.oauth_client_id: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="OAuth client_id is required when client_secret is configured.", ) request_body: dict[str, str] = {"token": token} if self.oauth_client_id and not self.oauth_client_secret: request_body["client_id"] = self.oauth_client_id headers = { "Content-Type": "application/x-www-form-urlencoded", "Accept": "application/json", } if self.oauth_client_id and self.oauth_client_secret: basic_secret = base64.b64encode( f"{self.oauth_client_id}:{self.oauth_client_secret}".encode("utf-8") ).decode("ascii") headers["Authorization"] = f"Basic {basic_secret}" body_bytes = urllib_parse.urlencode(request_body).encode("utf-8") request = urllib_request.Request( self.oauth_introspection_url, # type: ignore[arg-type] data=body_bytes, headers=headers, method="POST", ) try: with urllib_request.urlopen(request, timeout=self.oauth_timeout_seconds) as response: payload = json.loads(response.read().decode("utf-8")) except (urllib_error.HTTPError, urllib_error.URLError) as exc: raise HTTPException( status_code=status.HTTP_502_BAD_GATEWAY, detail="OAuth introspection request failed.", ) from exc except Exception as exc: raise HTTPException( status_code=status.HTTP_502_BAD_GATEWAY, detail="Invalid OAuth introspection response.", ) from exc if not isinstance(payload, dict): raise HTTPException( status_code=status.HTTP_502_BAD_GATEWAY, detail="Invalid OAuth introspection response.", ) return payload def _decode_hs256_jwt(*, token: str, secret: str) -> dict[str, Any]: try: header_segment, payload_segment, signature_segment = token.split(".") except ValueError as exc: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Malformed JWT.", ) from exc signing_input = f"{header_segment}.{payload_segment}".encode("utf-8") expected_signature = hmac.new( secret.encode("utf-8"), signing_input, hashlib.sha256, ).digest() actual_signature = _urlsafe_b64decode(signature_segment) if not hmac.compare_digest(expected_signature, actual_signature): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid JWT signature.", ) header = _decode_jwt_json_segment(header_segment) algorithm = str(header.get("alg", "")).upper() if algorithm != "HS256": raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Unsupported JWT algorithm.", ) return _decode_jwt_json_segment(payload_segment) def _decode_jwt_json_segment(segment: str) -> dict[str, Any]: try: decoded = _urlsafe_b64decode(segment).decode("utf-8") payload = json.loads(decoded) except Exception as exc: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid JWT payload.", ) from exc if not isinstance(payload, dict): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid JWT object.", ) return payload def _urlsafe_b64decode(segment: str) -> bytes: padded = segment + "=" * (-len(segment) % 4) return base64.urlsafe_b64decode(padded.encode("utf-8")) def _validate_jwt_claims( *, claims: dict[str, Any], expected_issuer: str | None, expected_audience: str | None, ) -> None: now_ts = int(datetime.now(tz=timezone.utc).timestamp()) exp = claims.get("exp") if exp is not None: try: exp_ts = int(exp) except Exception as exc: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid exp claim.", ) from exc if exp_ts < now_ts: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="JWT has expired.", ) nbf = claims.get("nbf") if nbf is not None: try: nbf_ts = int(nbf) except Exception as exc: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid nbf claim.", ) from exc if nbf_ts > now_ts: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="JWT not valid yet.", ) if expected_issuer: issuer = str(claims.get("iss", "")) if issuer != expected_issuer: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid JWT issuer.", ) if expected_audience: audience = claims.get("aud") if isinstance(audience, list): audience_values = {str(value) for value in audience} elif audience is None: audience_values = set() else: audience_values = {str(audience)} if expected_audience not in audience_values: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid JWT audience.", ) def _extract_scopes(claims: dict[str, Any]) -> set[str]: raw_scope = claims.get("scope", "") scope_values: set[str] = set() if isinstance(raw_scope, str): scope_values.update(value for value in raw_scope.strip().split(" ") if value) elif isinstance(raw_scope, list): scope_values.update(str(value) for value in raw_scope if str(value).strip()) raw_scp = claims.get("scp", "") if isinstance(raw_scp, str): scope_values.update(value for value in raw_scp.strip().split(" ") if value) elif isinstance(raw_scp, list): scope_values.update(str(value) for value in raw_scp if str(value).strip()) return scope_values