From d1238f186c3e841e56bc296a1a1e07a944c108c8 Mon Sep 17 00:00:00 2001
From: oabrivard
Date: Mon, 9 Mar 2026 21:57:35 +0100
Subject: [PATCH] feat(auth): add unified auth backend with scopes

---
 .env.example             |   4 +
 app/config.py            |  15 ++
 app/main.py              |  47 +++---
 app/security/__init__.py |   3 +
 app/security/auth.py     | 348 +++++++++++++++++++++++++++++++++++++++
 5 files changed, 393 insertions(+), 24 deletions(-)
 create mode 100644 app/security/__init__.py
 create mode 100644 app/security/auth.py

diff --git a/.env.example b/.env.example
index 2210900..adce7d4 100644
--- a/.env.example
+++ b/.env.example
@@ -1,6 +1,10 @@
 GOOGLE_CLIENT_SECRETS_FILE=credentials.json
 GOOGLE_TOKEN_FILE=token.json
 AGENT_API_KEY=change-me
+AUTH_MODE=api_key
+AUTH_JWT_SECRET=
+AUTH_JWT_ISSUER=
+AUTH_JWT_AUDIENCE=
 # Preferred Strands settings
 STRANDS_OPENAI_API_KEY=
 STRANDS_MODEL_ID=gpt-4.1-mini
diff --git a/app/config.py b/app/config.py
index 8edd274..365fd81 100644
--- a/app/config.py
+++ b/app/config.py
@@ -19,6 +19,10 @@ class Settings:
     gmail_scan_interval_minutes: int
     gmail_query: str
     agent_api_key: str
+    auth_mode: str
+    auth_jwt_secret: str
+    auth_jwt_issuer: str | None
+    auth_jwt_audience: str | None
     strands_api_key: str
     strands_model_id: str
     strands_base_url: str | None
@@ -52,6 +56,10 @@ def get_settings() -> Settings:
             "GMAIL_QUERY", "in:inbox is:unread -label:AgentProcessed"
         ),
         agent_api_key=os.getenv("AGENT_API_KEY", ""),
+        auth_mode=_normalize_auth_mode(os.getenv("AUTH_MODE", "api_key")),
+        auth_jwt_secret=os.getenv("AUTH_JWT_SECRET", "").strip(),
+        auth_jwt_issuer=os.getenv("AUTH_JWT_ISSUER", "").strip() or None,
+        auth_jwt_audience=os.getenv("AUTH_JWT_AUDIENCE", "").strip() or None,
         strands_api_key=_first_set_env("STRANDS_OPENAI_API_KEY", "LLM_API_KEY"),
         strands_model_id=_first_set_env("STRANDS_MODEL_ID", "LLM_MODEL") or "gpt-4.1-mini",
         strands_base_url=strands_base_url or None,
@@ -102,3 +110,10 @@ def _first_set_env(*names: str) -> str:
         if value:
             return value.strip()
     return ""
+
+
+def _normalize_auth_mode(value: str) -> str:
+    normalized = value.strip().lower()
+    if normalized in {"api_key", "jwt", "hybrid"}:
+        return normalized
+    return "api_key"
diff --git a/app/main.py b/app/main.py
index 603a2ac..2ac770e 100644
--- a/app/main.py
+++ b/app/main.py
@@ -12,11 +12,13 @@ from pydantic import BaseModel, Field
 
 from app.config import get_settings
 from app.core.service import CoreAgentService
+from app.security import AuthBackend
 
 settings = get_settings()
 logging.basicConfig(level=getattr(logging, settings.log_level.upper(), logging.INFO))
 logger = logging.getLogger("personal-agent")
 core_service = CoreAgentService(settings=settings, logger=logger)
+auth_backend = AuthBackend(settings=settings)
 
 scheduler: AsyncIOScheduler | None = None
 scan_lock: asyncio.Lock | None = None
@@ -110,28 +112,24 @@ class UnsubscribeExecutionResponse(BaseModel):
 
 
 def _is_api_auth_enabled() -> bool:
-    return bool(settings.agent_api_key.strip())
+    return auth_backend.is_enabled()
 
 
-def verify_api_key(
-    x_api_key: Annotated[str | None, Header(alias="X-API-Key")] = None,
-    authorization: Annotated[str | None, Header()] = None,
-) -> None:
-    expected = settings.agent_api_key
-    if not expected:
-        return
-
-    provided = x_api_key
-    if not provided and authorization:
-        parts = authorization.split(" ", 1)
-        if len(parts) == 2 and parts[0].lower() == "bearer":
-            provided = parts[1]
+def require_scope(*required_scopes: str):
+    scope_set = {scope.strip() for scope in required_scopes if scope.strip()}
 
-    if provided != expected:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key."
+    def _dependency(
+        x_api_key: Annotated[str | None, Header(alias="X-API-Key")] = None,
+        authorization: Annotated[str | None, Header()] = None,
+    ) -> None:
+        auth_backend.authenticate(
+            x_api_key=x_api_key,
+            authorization=authorization,
+            required_scopes=scope_set,
         )
 
+    return _dependency
+
 
 def _run_scan_once(max_results: int) -> ScanResponse:
     result = core_service.scan_mailbox(max_results=max_results)
@@ -308,8 +306,9 @@ async def lifespan(app: FastAPI):
     _get_scan_lock()
     _get_unsubscribe_lock()
     logger.info(
-        "API authentication enabled=%s (header: X-API-Key or Bearer token)",
+        "API authentication enabled=%s mode=%s (header: X-API-Key or Bearer token)",
         _is_api_auth_enabled(),
+        settings.auth_mode,
     )
     scheduler = AsyncIOScheduler()
     scheduler.add_job(  # type: ignore
@@ -362,7 +361,7 @@ def health() -> dict[str, object]:
 @app.post(
     "/scan",
     response_model=ScanResponse,
-    dependencies=[Depends(verify_api_key)],
+    dependencies=[Depends(require_scope("mail:scan"))],
 )
 async def scan_now(max_results: int = Query(100, ge=1, le=500)) -> ScanResponse:
     async with _get_scan_lock():
@@ -383,7 +382,7 @@ async def scan_now(max_results: int = Query(100, ge=1, le=500)) -> ScanResponse:
 @app.post(
     "/availability",
     response_model=AvailabilityResponse,
-    dependencies=[Depends(verify_api_key)],
+    dependencies=[Depends(require_scope("availability:read"))],
 )
 async def availability(request: AvailabilityRequest) -> AvailabilityResponse:
     try:
@@ -417,7 +416,7 @@ async def availability(request: AvailabilityRequest) -> AvailabilityResponse:
 @app.post(
     "/unsubscribe-digest",
     response_model=UnsubscribeDigestResponse,
-    dependencies=[Depends(verify_api_key)],
+    dependencies=[Depends(require_scope("unsubscribe:digest"))],
 )
 async def unsubscribe_digest_now(
     max_results: int = Query(default=settings.unsubscribe_max_results, ge=1, le=500),
@@ -440,7 +439,7 @@ async def unsubscribe_digest_now(
 @app.post(
     "/unsubscribe/candidates",
     response_model=UnsubscribeCandidatesResponse,
-    dependencies=[Depends(verify_api_key)],
+    dependencies=[Depends(require_scope("unsubscribe:read"))],
 )
 async def unsubscribe_candidates(
     max_results: int = Query(default=settings.unsubscribe_hil_max_results, ge=1, le=500),
@@ -458,7 +457,7 @@ async def unsubscribe_candidates(
 @app.post(
     "/unsubscribe/execute",
     response_model=UnsubscribeExecutionResponse,
-    dependencies=[Depends(verify_api_key)],
+    dependencies=[Depends(require_scope("unsubscribe:execute"))],
 )
 async def unsubscribe_execute(request: ExecuteUnsubscribeRequest) -> UnsubscribeExecutionResponse:
     max_results = request.max_results or settings.unsubscribe_hil_max_results
@@ -480,7 +479,7 @@ async def unsubscribe_execute(request: ExecuteUnsubscribeRequest) -> Unsubscribe
 @app.post(
     "/unsubscribe/auto-run",
     response_model=UnsubscribeExecutionResponse,
-    dependencies=[Depends(verify_api_key)],
+    dependencies=[Depends(require_scope("unsubscribe:execute"))],
 )
 async def unsubscribe_auto_run(
     max_results: int = Query(default=settings.unsubscribe_hil_max_results, ge=1, le=500),
diff --git a/app/security/__init__.py b/app/security/__init__.py
new file mode 100644
index 0000000..9176a36
--- /dev/null
+++ b/app/security/__init__.py
@@ -0,0 +1,3 @@
+from app.security.auth import AuthBackend, AuthContext
+
+__all__ = ["AuthBackend", "AuthContext"]
diff --git a/app/security/auth.py b/app/security/auth.py
new file mode 100644
index 0000000..6dcd0dc
--- /dev/null
+++ b/app/security/auth.py
@@ -0,0 +1,348 @@
+"""Unified API authentication backend: API key, HS256 JWT, or hybrid mode."""
+
+from __future__ import annotations
+
+import base64
+from dataclasses import dataclass
+from datetime import datetime, timezone
+import hashlib
+import hmac
+import json
+from typing import Any
+
+from fastapi import HTTPException, status
+
+from app.config import Settings
+
+
+@dataclass(frozen=True)
+class AuthContext:
+    """Identity and granted scopes of an authenticated caller."""
+
+    subject: str
+    auth_type: str
+    scopes: set[str]
+
+
+class AuthBackend:
+    """Authenticate requests according to ``settings.auth_mode``."""
+
+    def __init__(self, settings: Settings) -> None:
+        self.settings = settings
+
+    def is_enabled(self) -> bool:
+        """Return True when the active mode has the credentials it needs."""
+        has_api_key = bool(self.settings.agent_api_key.strip())
+        has_jwt_secret = bool(self.settings.auth_jwt_secret.strip())
+        mode = self.settings.auth_mode
+        if mode == "jwt":
+            return has_jwt_secret
+        if mode == "hybrid":
+            return has_api_key or has_jwt_secret
+        # "api_key" and any unrecognized mode depend on the API key alone.
+        return has_api_key
+
+    def authenticate(
+        self,
+        *,
+        x_api_key: str | None,
+        authorization: str | None,
+        required_scopes: set[str],
+    ) -> AuthContext:
+        """Authenticate a request and enforce ``required_scopes``.
+
+        Raises HTTPException: 401 for missing or invalid credentials, 403 when
+        a JWT lacks a required scope, 500 when the server is misconfigured.
+        """
+        mode = self.settings.auth_mode
+        if mode == "api_key":
+            return self._authenticate_api_key(
+                x_api_key=x_api_key,
+                authorization=authorization,
+                required_scopes=required_scopes,
+            )
+        if mode == "jwt":
+            return self._authenticate_jwt(
+                authorization=authorization,
+                required_scopes=required_scopes,
+            )
+        if mode == "hybrid":
+            return self._authenticate_hybrid(
+                x_api_key=x_api_key,
+                authorization=authorization,
+                required_scopes=required_scopes,
+            )
+
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"Unsupported AUTH_MODE '{mode}'.",
+        )
+
+    def _authenticate_hybrid(
+        self,
+        *,
+        x_api_key: str | None,
+        authorization: str | None,
+        required_scopes: set[str],
+    ) -> AuthContext:
+        """Accept a matching API key first, otherwise fall back to JWT."""
+        expected = self.settings.agent_api_key
+        if expected:
+            provided = self._resolve_api_key(
+                x_api_key=x_api_key,
+                authorization=authorization,
+            )
+            if _secrets_match(provided, expected):
+                return AuthContext(subject="api-key", auth_type="api_key", scopes={"*"})
+
+        if self.settings.auth_jwt_secret:
+            return self._authenticate_jwt(
+                authorization=authorization,
+                required_scopes=required_scopes,
+            )
+
+        return self._authenticate_api_key(
+            x_api_key=x_api_key,
+            authorization=authorization,
+            required_scopes=required_scopes,
+        )
+
+    def _authenticate_api_key(
+        self,
+        *,
+        x_api_key: str | None,
+        authorization: str | None,
+        required_scopes: set[str],
+    ) -> AuthContext:
+        """Validate the shared API key; with no key configured, allow everyone."""
+        expected = self.settings.agent_api_key
+        if not expected:
+            return AuthContext(subject="anonymous", auth_type="none", scopes={"*"})
+
+        provided = self._resolve_api_key(
+            x_api_key=x_api_key,
+            authorization=authorization,
+        )
+        if not _secrets_match(provided, expected):
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail="Invalid API key.",
+            )
+
+        # API key access is treated as admin-level compatibility mode.
+        return AuthContext(subject="api-key", auth_type="api_key", scopes={"*"})
+
+    def _authenticate_jwt(
+        self,
+        *,
+        authorization: str | None,
+        required_scopes: set[str],
+    ) -> AuthContext:
+        """Verify a Bearer HS256 JWT, its claims, and the required scopes."""
+        secret = self.settings.auth_jwt_secret
+        if not secret:
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                detail="AUTH_JWT_SECRET is required in jwt auth mode.",
+            )
+
+        token = self._resolve_bearer(authorization)
+        if not token:
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail="Missing Bearer token.",
+            )
+
+        claims = _decode_hs256_jwt(token=token, secret=secret)
+        _validate_jwt_claims(
+            claims=claims,
+            expected_issuer=self.settings.auth_jwt_issuer,
+            expected_audience=self.settings.auth_jwt_audience,
+        )
+
+        scope_values = _extract_scopes(claims)
+        if required_scopes and not required_scopes.issubset(scope_values):
+            raise HTTPException(
+                status_code=status.HTTP_403_FORBIDDEN,
+                detail="Missing required scope.",
+            )
+
+        subject = str(claims.get("sub") or "jwt-subject")
+        return AuthContext(subject=subject, auth_type="jwt", scopes=scope_values)
+
+    def _resolve_api_key(self, *, x_api_key: str | None, authorization: str | None) -> str | None:
+        """Return the API key from X-API-Key, else from a Bearer header."""
+        if x_api_key:
+            return x_api_key
+        return self._resolve_bearer(authorization)
+
+    def _resolve_bearer(self, authorization: str | None) -> str | None:
+        """Return the token of a ``Bearer`` Authorization header, or None."""
+        if not authorization:
+            return None
+        parts = authorization.split(" ", 1)
+        if len(parts) == 2 and parts[0].lower() == "bearer":
+            return parts[1]
+        return None
+
+
+def _secrets_match(provided: str | None, expected: str) -> bool:
+    """Compare two secrets in constant time; a missing value never matches."""
+    if provided is None:
+        return False
+    # Compare bytes: hmac.compare_digest rejects non-ASCII str arguments.
+    return hmac.compare_digest(provided.encode("utf-8"), expected.encode("utf-8"))
+
+
+def _decode_hs256_jwt(*, token: str, secret: str) -> dict[str, Any]:
+    """Verify the HS256 signature of ``token`` and return its claims.
+
+    Raises HTTPException(401) when the token is malformed, its signature does
+    not match, or its header declares an algorithm other than HS256.
+    """
+    try:
+        header_segment, payload_segment, signature_segment = token.split(".")
+    except ValueError as exc:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="Malformed JWT.",
+        ) from exc
+
+    signing_input = f"{header_segment}.{payload_segment}".encode("utf-8")
+    expected_signature = hmac.new(
+        secret.encode("utf-8"),
+        signing_input,
+        hashlib.sha256,
+    ).digest()
+    try:
+        actual_signature = _urlsafe_b64decode(signature_segment)
+    except ValueError as exc:
+        # binascii.Error is a ValueError: undecodable base64 must be a 401,
+        # not an unhandled 500.
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="Invalid JWT signature.",
+        ) from exc
+    if not hmac.compare_digest(expected_signature, actual_signature):
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="Invalid JWT signature.",
+        )
+
+    header = _decode_jwt_json_segment(header_segment)
+    algorithm = str(header.get("alg", "")).upper()
+    if algorithm != "HS256":
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="Unsupported JWT algorithm.",
+        )
+
+    return _decode_jwt_json_segment(payload_segment)
+
+
+def _decode_jwt_json_segment(segment: str) -> dict[str, Any]:
+    """Decode a base64url JWT segment into a JSON object, or raise a 401."""
+    try:
+        decoded = _urlsafe_b64decode(segment).decode("utf-8")
+        payload = json.loads(decoded)
+    except (ValueError, RecursionError) as exc:
+        # ValueError covers binascii.Error, UnicodeDecodeError, JSONDecodeError.
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="Invalid JWT payload.",
+        ) from exc
+
+    if not isinstance(payload, dict):
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="Invalid JWT object.",
+        )
+    return payload
+
+
+def _urlsafe_b64decode(segment: str) -> bytes:
+    """Decode base64url text, restoring the padding that JWTs strip."""
+    padded = segment + "=" * (-len(segment) % 4)
+    return base64.urlsafe_b64decode(padded.encode("utf-8"))
+
+
+def _validate_jwt_claims(
+    *,
+    claims: dict[str, Any],
+    expected_issuer: str | None,
+    expected_audience: str | None,
+) -> None:
+    """Check ``exp``/``nbf`` and, when configured, ``iss``/``aud``.
+
+    Raises HTTPException(401) on the first claim that fails validation.
+    """
+    now_ts = int(datetime.now(tz=timezone.utc).timestamp())
+
+    exp_ts = _numeric_date_claim(claims, "exp")
+    # RFC 7519: the current time must be strictly before ``exp``.
+    if exp_ts is not None and exp_ts <= now_ts:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="JWT has expired.",
+        )
+
+    nbf_ts = _numeric_date_claim(claims, "nbf")
+    if nbf_ts is not None and nbf_ts > now_ts:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="JWT not valid yet.",
+        )
+
+    if expected_issuer:
+        issuer = str(claims.get("iss", ""))
+        if issuer != expected_issuer:
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail="Invalid JWT issuer.",
+            )
+
+    if expected_audience:
+        audience = claims.get("aud")
+        if isinstance(audience, list):
+            audience_values = {str(value) for value in audience}
+        elif audience is None:
+            audience_values = set()
+        else:
+            audience_values = {str(audience)}
+        if expected_audience not in audience_values:
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail="Invalid JWT audience.",
+            )
+
+
+def _numeric_date_claim(claims: dict[str, Any], name: str) -> int | None:
+    """Return claim ``name`` as an integer timestamp, or None when absent."""
+    value = claims.get(name)
+    if value is None:
+        return None
+    try:
+        return int(value)
+    except (TypeError, ValueError, OverflowError) as exc:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=f"Invalid {name} claim.",
+        ) from exc
+
+
+def _extract_scopes(claims: dict[str, Any]) -> set[str]:
+    """Collect scopes from the ``scope`` and ``scp`` claims.
+
+    Each claim may be a whitespace-delimited string or a list of values.
+    """
+    scope_values: set[str] = set()
+    for claim_name in ("scope", "scp"):
+        raw = claims.get(claim_name, "")
+        if isinstance(raw, str):
+            # split() with no argument also handles tabs and repeated spaces.
+            scope_values.update(raw.split())
+        elif isinstance(raw, list):
+            # Strip entries so a padded value still matches a required scope.
+            stripped = (str(item).strip() for item in raw)
+            scope_values.update(text for text in stripped if text)
+    return scope_values