from __future__ import annotations import base64 from dataclasses import dataclass from datetime import datetime, timezone import hashlib import hmac import json from typing import Any from fastapi import HTTPException, status from app.config import Settings @dataclass(frozen=True) class AuthContext: subject: str auth_type: str scopes: set[str] class AuthBackend: def __init__(self, settings: Settings) -> None: self.settings = settings def is_enabled(self) -> bool: if self.settings.auth_mode == "api_key": return bool(self.settings.agent_api_key.strip()) if self.settings.auth_mode == "jwt": return bool(self.settings.auth_jwt_secret.strip()) if self.settings.auth_mode == "hybrid": return bool( self.settings.agent_api_key.strip() or self.settings.auth_jwt_secret.strip() ) return bool(self.settings.agent_api_key.strip()) def authenticate( self, *, x_api_key: str | None, authorization: str | None, required_scopes: set[str], ) -> AuthContext: mode = self.settings.auth_mode if mode == "api_key": return self._authenticate_api_key( x_api_key=x_api_key, authorization=authorization, required_scopes=required_scopes, ) if mode == "jwt": return self._authenticate_jwt(authorization=authorization, required_scopes=required_scopes) if mode == "hybrid": return self._authenticate_hybrid( x_api_key=x_api_key, authorization=authorization, required_scopes=required_scopes, ) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Unsupported AUTH_MODE '{mode}'.", ) def _authenticate_hybrid( self, *, x_api_key: str | None, authorization: str | None, required_scopes: set[str], ) -> AuthContext: if self.settings.agent_api_key: api_key = self._resolve_api_key( x_api_key=x_api_key, authorization=authorization, ) if api_key == self.settings.agent_api_key: return AuthContext(subject="api-key", auth_type="api_key", scopes={"*"}) if self.settings.auth_jwt_secret: return self._authenticate_jwt( authorization=authorization, required_scopes=required_scopes, ) return self._authenticate_api_key( x_api_key=x_api_key, authorization=authorization, required_scopes=required_scopes, ) def _authenticate_api_key( self, *, x_api_key: str | None, authorization: str | None, required_scopes: set[str], ) -> AuthContext: expected = self.settings.agent_api_key if not expected: return AuthContext(subject="anonymous", auth_type="none", scopes={"*"}) provided = self._resolve_api_key( x_api_key=x_api_key, authorization=authorization, ) if provided != expected: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.", ) # API key access is treated as admin-level compatibility mode. return AuthContext(subject="api-key", auth_type="api_key", scopes={"*"}) def _authenticate_jwt( self, *, authorization: str | None, required_scopes: set[str], ) -> AuthContext: secret = self.settings.auth_jwt_secret if not secret: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="AUTH_JWT_SECRET is required in jwt auth mode.", ) token = self._resolve_bearer(authorization) if not token: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Bearer token.", ) claims = _decode_hs256_jwt(token=token, secret=secret) _validate_jwt_claims( claims=claims, expected_issuer=self.settings.auth_jwt_issuer, expected_audience=self.settings.auth_jwt_audience, ) scope_values = _extract_scopes(claims) if required_scopes and not required_scopes.issubset(scope_values): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Missing required scope.", ) subject = str(claims.get("sub") or "jwt-subject") return AuthContext(subject=subject, auth_type="jwt", scopes=scope_values) def _resolve_api_key(self, *, x_api_key: str | None, authorization: str | None) -> str | None: if x_api_key: return x_api_key token = self._resolve_bearer(authorization) return token def _resolve_bearer(self, authorization: str | None) -> str | None: if not authorization: return None parts = authorization.split(" ", 1) if len(parts) == 2 and parts[0].lower() == "bearer": return parts[1] return None def _decode_hs256_jwt(*, token: str, secret: str) -> dict[str, Any]: try: header_segment, payload_segment, signature_segment = token.split(".") except ValueError as exc: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Malformed JWT.", ) from exc signing_input = f"{header_segment}.{payload_segment}".encode("utf-8") expected_signature = hmac.new( secret.encode("utf-8"), signing_input, hashlib.sha256, ).digest() actual_signature = _urlsafe_b64decode(signature_segment) if not hmac.compare_digest(expected_signature, actual_signature): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid JWT signature.", ) header = _decode_jwt_json_segment(header_segment) algorithm = str(header.get("alg", "")).upper() if algorithm != "HS256": raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Unsupported JWT algorithm.", ) return _decode_jwt_json_segment(payload_segment) def _decode_jwt_json_segment(segment: str) -> dict[str, Any]: try: decoded = _urlsafe_b64decode(segment).decode("utf-8") payload = json.loads(decoded) except Exception as exc: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid JWT payload.", ) from exc if not isinstance(payload, dict): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid JWT object.", ) return payload def _urlsafe_b64decode(segment: str) -> bytes: padded = segment + "=" * (-len(segment) % 4) return base64.urlsafe_b64decode(padded.encode("utf-8")) def _validate_jwt_claims( *, claims: dict[str, Any], expected_issuer: str | None, expected_audience: str | None, ) -> None: now_ts = int(datetime.now(tz=timezone.utc).timestamp()) exp = claims.get("exp") if exp is not None: try: exp_ts = int(exp) except Exception as exc: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid exp claim.", ) from exc if exp_ts < now_ts: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="JWT has expired.", ) nbf = claims.get("nbf") if nbf is not None: try: nbf_ts = int(nbf) except Exception as exc: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid nbf claim.", ) from exc if nbf_ts > now_ts: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="JWT not valid yet.", ) if expected_issuer: issuer = str(claims.get("iss", "")) if issuer != expected_issuer: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid JWT issuer.", ) if expected_audience: audience = claims.get("aud") if isinstance(audience, list): audience_values = {str(value) for value in audience} elif audience is None: audience_values = set() else: audience_values = {str(audience)} if expected_audience not in audience_values: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid JWT audience.", ) def _extract_scopes(claims: dict[str, Any]) -> set[str]: raw_scope = claims.get("scope", "") scope_values: set[str] = set() if isinstance(raw_scope, str): scope_values.update(value for value in raw_scope.strip().split(" ") if value) elif isinstance(raw_scope, list): scope_values.update(str(value) for value in raw_scope if str(value).strip()) raw_scp = claims.get("scp", "") if isinstance(raw_scp, str): scope_values.update(value for value in raw_scp.strip().split(" ") if value) elif isinstance(raw_scp, list): scope_values.update(str(value) for value in raw_scp if str(value).strip()) return scope_values