feat(auth): add unified auth backend with scopes

master
oabrivard 1 week ago
parent d25d429519
commit d1238f186c

@ -1,6 +1,10 @@
GOOGLE_CLIENT_SECRETS_FILE=credentials.json GOOGLE_CLIENT_SECRETS_FILE=credentials.json
GOOGLE_TOKEN_FILE=token.json GOOGLE_TOKEN_FILE=token.json
AGENT_API_KEY=change-me AGENT_API_KEY=change-me
AUTH_MODE=api_key
AUTH_JWT_SECRET=
AUTH_JWT_ISSUER=
AUTH_JWT_AUDIENCE=
# Preferred Strands settings # Preferred Strands settings
STRANDS_OPENAI_API_KEY= STRANDS_OPENAI_API_KEY=
STRANDS_MODEL_ID=gpt-4.1-mini STRANDS_MODEL_ID=gpt-4.1-mini

@ -19,6 +19,10 @@ class Settings:
gmail_scan_interval_minutes: int gmail_scan_interval_minutes: int
gmail_query: str gmail_query: str
agent_api_key: str agent_api_key: str
auth_mode: str
auth_jwt_secret: str
auth_jwt_issuer: str | None
auth_jwt_audience: str | None
strands_api_key: str strands_api_key: str
strands_model_id: str strands_model_id: str
strands_base_url: str | None strands_base_url: str | None
@ -52,6 +56,10 @@ def get_settings() -> Settings:
"GMAIL_QUERY", "in:inbox is:unread -label:AgentProcessed" "GMAIL_QUERY", "in:inbox is:unread -label:AgentProcessed"
), ),
agent_api_key=os.getenv("AGENT_API_KEY", ""), agent_api_key=os.getenv("AGENT_API_KEY", ""),
auth_mode=_normalize_auth_mode(os.getenv("AUTH_MODE", "api_key")),
auth_jwt_secret=os.getenv("AUTH_JWT_SECRET", "").strip(),
auth_jwt_issuer=os.getenv("AUTH_JWT_ISSUER", "").strip() or None,
auth_jwt_audience=os.getenv("AUTH_JWT_AUDIENCE", "").strip() or None,
strands_api_key=_first_set_env("STRANDS_OPENAI_API_KEY", "LLM_API_KEY"), strands_api_key=_first_set_env("STRANDS_OPENAI_API_KEY", "LLM_API_KEY"),
strands_model_id=_first_set_env("STRANDS_MODEL_ID", "LLM_MODEL") or "gpt-4.1-mini", strands_model_id=_first_set_env("STRANDS_MODEL_ID", "LLM_MODEL") or "gpt-4.1-mini",
strands_base_url=strands_base_url or None, strands_base_url=strands_base_url or None,
@ -102,3 +110,10 @@ def _first_set_env(*names: str) -> str:
if value: if value:
return value.strip() return value.strip()
return "" return ""
def _normalize_auth_mode(value: str) -> str:
normalized = value.strip().lower()
if normalized in {"api_key", "jwt", "hybrid"}:
return normalized
return "api_key"

@ -12,11 +12,13 @@ from pydantic import BaseModel, Field
from app.config import get_settings from app.config import get_settings
from app.core.service import CoreAgentService from app.core.service import CoreAgentService
from app.security import AuthBackend
settings = get_settings() settings = get_settings()
logging.basicConfig(level=getattr(logging, settings.log_level.upper(), logging.INFO)) logging.basicConfig(level=getattr(logging, settings.log_level.upper(), logging.INFO))
logger = logging.getLogger("personal-agent") logger = logging.getLogger("personal-agent")
core_service = CoreAgentService(settings=settings, logger=logger) core_service = CoreAgentService(settings=settings, logger=logger)
auth_backend = AuthBackend(settings=settings)
scheduler: AsyncIOScheduler | None = None scheduler: AsyncIOScheduler | None = None
scan_lock: asyncio.Lock | None = None scan_lock: asyncio.Lock | None = None
@ -110,28 +112,24 @@ class UnsubscribeExecutionResponse(BaseModel):
def _is_api_auth_enabled() -> bool: def _is_api_auth_enabled() -> bool:
return bool(settings.agent_api_key.strip()) return auth_backend.is_enabled()
def verify_api_key( def require_scope(*required_scopes: str):
scope_set = {scope.strip() for scope in required_scopes if scope.strip()}
def _dependency(
x_api_key: Annotated[str | None, Header(alias="X-API-Key")] = None, x_api_key: Annotated[str | None, Header(alias="X-API-Key")] = None,
authorization: Annotated[str | None, Header()] = None, authorization: Annotated[str | None, Header()] = None,
) -> None: ) -> None:
expected = settings.agent_api_key auth_backend.authenticate(
if not expected: x_api_key=x_api_key,
return authorization=authorization,
required_scopes=scope_set,
provided = x_api_key
if not provided and authorization:
parts = authorization.split(" ", 1)
if len(parts) == 2 and parts[0].lower() == "bearer":
provided = parts[1]
if provided != expected:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key."
) )
return _dependency
def _run_scan_once(max_results: int) -> ScanResponse: def _run_scan_once(max_results: int) -> ScanResponse:
result = core_service.scan_mailbox(max_results=max_results) result = core_service.scan_mailbox(max_results=max_results)
@ -308,8 +306,9 @@ async def lifespan(app: FastAPI):
_get_scan_lock() _get_scan_lock()
_get_unsubscribe_lock() _get_unsubscribe_lock()
logger.info( logger.info(
"API authentication enabled=%s (header: X-API-Key or Bearer token)", "API authentication enabled=%s mode=%s (header: X-API-Key or Bearer token)",
_is_api_auth_enabled(), _is_api_auth_enabled(),
settings.auth_mode,
) )
scheduler = AsyncIOScheduler() scheduler = AsyncIOScheduler()
scheduler.add_job( # type: ignore scheduler.add_job( # type: ignore
@ -362,7 +361,7 @@ def health() -> dict[str, object]:
@app.post( @app.post(
"/scan", "/scan",
response_model=ScanResponse, response_model=ScanResponse,
dependencies=[Depends(verify_api_key)], dependencies=[Depends(require_scope("mail:scan"))],
) )
async def scan_now(max_results: int = Query(100, ge=1, le=500)) -> ScanResponse: async def scan_now(max_results: int = Query(100, ge=1, le=500)) -> ScanResponse:
async with _get_scan_lock(): async with _get_scan_lock():
@ -383,7 +382,7 @@ async def scan_now(max_results: int = Query(100, ge=1, le=500)) -> ScanResponse:
@app.post( @app.post(
"/availability", "/availability",
response_model=AvailabilityResponse, response_model=AvailabilityResponse,
dependencies=[Depends(verify_api_key)], dependencies=[Depends(require_scope("availability:read"))],
) )
async def availability(request: AvailabilityRequest) -> AvailabilityResponse: async def availability(request: AvailabilityRequest) -> AvailabilityResponse:
try: try:
@ -417,7 +416,7 @@ async def availability(request: AvailabilityRequest) -> AvailabilityResponse:
@app.post( @app.post(
"/unsubscribe-digest", "/unsubscribe-digest",
response_model=UnsubscribeDigestResponse, response_model=UnsubscribeDigestResponse,
dependencies=[Depends(verify_api_key)], dependencies=[Depends(require_scope("unsubscribe:digest"))],
) )
async def unsubscribe_digest_now( async def unsubscribe_digest_now(
max_results: int = Query(default=settings.unsubscribe_max_results, ge=1, le=500), max_results: int = Query(default=settings.unsubscribe_max_results, ge=1, le=500),
@ -440,7 +439,7 @@ async def unsubscribe_digest_now(
@app.post( @app.post(
"/unsubscribe/candidates", "/unsubscribe/candidates",
response_model=UnsubscribeCandidatesResponse, response_model=UnsubscribeCandidatesResponse,
dependencies=[Depends(verify_api_key)], dependencies=[Depends(require_scope("unsubscribe:read"))],
) )
async def unsubscribe_candidates( async def unsubscribe_candidates(
max_results: int = Query(default=settings.unsubscribe_hil_max_results, ge=1, le=500), max_results: int = Query(default=settings.unsubscribe_hil_max_results, ge=1, le=500),
@ -458,7 +457,7 @@ async def unsubscribe_candidates(
@app.post( @app.post(
"/unsubscribe/execute", "/unsubscribe/execute",
response_model=UnsubscribeExecutionResponse, response_model=UnsubscribeExecutionResponse,
dependencies=[Depends(verify_api_key)], dependencies=[Depends(require_scope("unsubscribe:execute"))],
) )
async def unsubscribe_execute(request: ExecuteUnsubscribeRequest) -> UnsubscribeExecutionResponse: async def unsubscribe_execute(request: ExecuteUnsubscribeRequest) -> UnsubscribeExecutionResponse:
max_results = request.max_results or settings.unsubscribe_hil_max_results max_results = request.max_results or settings.unsubscribe_hil_max_results
@ -480,7 +479,7 @@ async def unsubscribe_execute(request: ExecuteUnsubscribeRequest) -> Unsubscribe
@app.post( @app.post(
"/unsubscribe/auto-run", "/unsubscribe/auto-run",
response_model=UnsubscribeExecutionResponse, response_model=UnsubscribeExecutionResponse,
dependencies=[Depends(verify_api_key)], dependencies=[Depends(require_scope("unsubscribe:execute"))],
) )
async def unsubscribe_auto_run( async def unsubscribe_auto_run(
max_results: int = Query(default=settings.unsubscribe_hil_max_results, ge=1, le=500), max_results: int = Query(default=settings.unsubscribe_hil_max_results, ge=1, le=500),

@ -0,0 +1,3 @@
from app.security.auth import AuthBackend, AuthContext
__all__ = ["AuthBackend", "AuthContext"]

@ -0,0 +1,300 @@
from __future__ import annotations
import base64
from dataclasses import dataclass
from datetime import datetime, timezone
import hashlib
import hmac
import json
from typing import Any
from fastapi import HTTPException, status
from app.config import Settings
@dataclass(frozen=True)
class AuthContext:
subject: str
auth_type: str
scopes: set[str]
class AuthBackend:
def __init__(self, settings: Settings) -> None:
self.settings = settings
def is_enabled(self) -> bool:
if self.settings.auth_mode == "api_key":
return bool(self.settings.agent_api_key.strip())
if self.settings.auth_mode == "jwt":
return bool(self.settings.auth_jwt_secret.strip())
if self.settings.auth_mode == "hybrid":
return bool(
self.settings.agent_api_key.strip() or self.settings.auth_jwt_secret.strip()
)
return bool(self.settings.agent_api_key.strip())
def authenticate(
self,
*,
x_api_key: str | None,
authorization: str | None,
required_scopes: set[str],
) -> AuthContext:
mode = self.settings.auth_mode
if mode == "api_key":
return self._authenticate_api_key(
x_api_key=x_api_key,
authorization=authorization,
required_scopes=required_scopes,
)
if mode == "jwt":
return self._authenticate_jwt(authorization=authorization, required_scopes=required_scopes)
if mode == "hybrid":
return self._authenticate_hybrid(
x_api_key=x_api_key,
authorization=authorization,
required_scopes=required_scopes,
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Unsupported AUTH_MODE '{mode}'.",
)
def _authenticate_hybrid(
self,
*,
x_api_key: str | None,
authorization: str | None,
required_scopes: set[str],
) -> AuthContext:
if self.settings.agent_api_key:
api_key = self._resolve_api_key(
x_api_key=x_api_key,
authorization=authorization,
)
if api_key == self.settings.agent_api_key:
return AuthContext(subject="api-key", auth_type="api_key", scopes={"*"})
if self.settings.auth_jwt_secret:
return self._authenticate_jwt(
authorization=authorization,
required_scopes=required_scopes,
)
return self._authenticate_api_key(
x_api_key=x_api_key,
authorization=authorization,
required_scopes=required_scopes,
)
def _authenticate_api_key(
self,
*,
x_api_key: str | None,
authorization: str | None,
required_scopes: set[str],
) -> AuthContext:
expected = self.settings.agent_api_key
if not expected:
return AuthContext(subject="anonymous", auth_type="none", scopes={"*"})
provided = self._resolve_api_key(
x_api_key=x_api_key,
authorization=authorization,
)
if provided != expected:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key.",
)
# API key access is treated as admin-level compatibility mode.
return AuthContext(subject="api-key", auth_type="api_key", scopes={"*"})
def _authenticate_jwt(
self,
*,
authorization: str | None,
required_scopes: set[str],
) -> AuthContext:
secret = self.settings.auth_jwt_secret
if not secret:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="AUTH_JWT_SECRET is required in jwt auth mode.",
)
token = self._resolve_bearer(authorization)
if not token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing Bearer token.",
)
claims = _decode_hs256_jwt(token=token, secret=secret)
_validate_jwt_claims(
claims=claims,
expected_issuer=self.settings.auth_jwt_issuer,
expected_audience=self.settings.auth_jwt_audience,
)
scope_values = _extract_scopes(claims)
if required_scopes and not required_scopes.issubset(scope_values):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Missing required scope.",
)
subject = str(claims.get("sub") or "jwt-subject")
return AuthContext(subject=subject, auth_type="jwt", scopes=scope_values)
def _resolve_api_key(self, *, x_api_key: str | None, authorization: str | None) -> str | None:
if x_api_key:
return x_api_key
token = self._resolve_bearer(authorization)
return token
def _resolve_bearer(self, authorization: str | None) -> str | None:
if not authorization:
return None
parts = authorization.split(" ", 1)
if len(parts) == 2 and parts[0].lower() == "bearer":
return parts[1]
return None
def _decode_hs256_jwt(*, token: str, secret: str) -> dict[str, Any]:
try:
header_segment, payload_segment, signature_segment = token.split(".")
except ValueError as exc:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Malformed JWT.",
) from exc
signing_input = f"{header_segment}.{payload_segment}".encode("utf-8")
expected_signature = hmac.new(
secret.encode("utf-8"),
signing_input,
hashlib.sha256,
).digest()
actual_signature = _urlsafe_b64decode(signature_segment)
if not hmac.compare_digest(expected_signature, actual_signature):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid JWT signature.",
)
header = _decode_jwt_json_segment(header_segment)
algorithm = str(header.get("alg", "")).upper()
if algorithm != "HS256":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Unsupported JWT algorithm.",
)
return _decode_jwt_json_segment(payload_segment)
def _decode_jwt_json_segment(segment: str) -> dict[str, Any]:
try:
decoded = _urlsafe_b64decode(segment).decode("utf-8")
payload = json.loads(decoded)
except Exception as exc:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid JWT payload.",
) from exc
if not isinstance(payload, dict):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid JWT object.",
)
return payload
def _urlsafe_b64decode(segment: str) -> bytes:
padded = segment + "=" * (-len(segment) % 4)
return base64.urlsafe_b64decode(padded.encode("utf-8"))
def _validate_jwt_claims(
*,
claims: dict[str, Any],
expected_issuer: str | None,
expected_audience: str | None,
) -> None:
now_ts = int(datetime.now(tz=timezone.utc).timestamp())
exp = claims.get("exp")
if exp is not None:
try:
exp_ts = int(exp)
except Exception as exc:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid exp claim.",
) from exc
if exp_ts < now_ts:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="JWT has expired.",
)
nbf = claims.get("nbf")
if nbf is not None:
try:
nbf_ts = int(nbf)
except Exception as exc:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid nbf claim.",
) from exc
if nbf_ts > now_ts:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="JWT not valid yet.",
)
if expected_issuer:
issuer = str(claims.get("iss", ""))
if issuer != expected_issuer:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid JWT issuer.",
)
if expected_audience:
audience = claims.get("aud")
if isinstance(audience, list):
audience_values = {str(value) for value in audience}
elif audience is None:
audience_values = set()
else:
audience_values = {str(audience)}
if expected_audience not in audience_values:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid JWT audience.",
)
def _extract_scopes(claims: dict[str, Any]) -> set[str]:
raw_scope = claims.get("scope", "")
scope_values: set[str] = set()
if isinstance(raw_scope, str):
scope_values.update(value for value in raw_scope.strip().split(" ") if value)
elif isinstance(raw_scope, list):
scope_values.update(str(value) for value in raw_scope if str(value).strip())
raw_scp = claims.get("scp", "")
if isinstance(raw_scp, str):
scope_values.update(value for value in raw_scp.strip().split(" ") if value)
elif isinstance(raw_scp, list):
scope_values.update(str(value) for value in raw_scp if str(value).strip())
return scope_values
Loading…
Cancel
Save