You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

301 lines
9.5 KiB
Python

from __future__ import annotations
import base64
from dataclasses import dataclass
from datetime import datetime, timezone
import hashlib
import hmac
import json
from typing import Any
from fastapi import HTTPException, status
from app.config import Settings
@dataclass(eq=True, frozen=True)
class AuthContext:
    """Identity and permissions resolved for one authenticated request.

    Instances are immutable (``frozen=True``). Because ``scopes`` is a plain
    ``set``, the generated ``__hash__`` raises ``TypeError`` if an instance is
    ever hashed.
    """

    # Principal identifier: "api-key", "anonymous", or the JWT "sub" claim.
    subject: str
    # How the request was authenticated: "api_key", "jwt", or "none".
    auth_type: str
    # Granted scopes; {"*"} for API-key and anonymous access.
    scopes: set[str]
class AuthBackend:
    """Authenticate incoming requests according to ``settings.auth_mode``.

    Supported modes:

    * ``"api_key"`` -- a static key (``settings.agent_api_key``) supplied via
      ``x_api_key`` or as an ``Authorization: Bearer`` token.
    * ``"jwt"`` -- an HS256 bearer token signed with
      ``settings.auth_jwt_secret``.
    * ``"hybrid"`` -- try the API key first, then fall back to JWT.
    """

    def __init__(self, settings: Settings) -> None:
        self.settings = settings

    def is_enabled(self) -> bool:
        """Return ``True`` if the credentials the current mode needs are set.

        Whitespace-only values count as unset. Unrecognised modes are judged
        by the API-key setting alone.
        """
        if self.settings.auth_mode == "api_key":
            return bool(self.settings.agent_api_key.strip())
        if self.settings.auth_mode == "jwt":
            return bool(self.settings.auth_jwt_secret.strip())
        if self.settings.auth_mode == "hybrid":
            return bool(
                self.settings.agent_api_key.strip() or self.settings.auth_jwt_secret.strip()
            )
        return bool(self.settings.agent_api_key.strip())

    def authenticate(
        self,
        *,
        x_api_key: str | None,
        authorization: str | None,
        required_scopes: set[str],
    ) -> AuthContext:
        """Authenticate a request using the configured ``auth_mode``.

        Args:
            x_api_key: API-key value supplied by the client, if any
                (presumably from an ``X-API-Key`` header -- confirm at caller).
            authorization: raw ``Authorization`` header value, if any.
            required_scopes: scopes the caller must hold. Only enforced for
                JWT authentication; API-key access is granted ``{"*"}``.

        Returns:
            The resolved :class:`AuthContext`.

        Raises:
            HTTPException: 401 for missing/invalid credentials, 403 when a JWT
                lacks a required scope, 500 for an unsupported ``auth_mode``
                or a missing JWT secret in ``jwt`` mode.
        """
        mode = self.settings.auth_mode
        if mode == "api_key":
            return self._authenticate_api_key(
                x_api_key=x_api_key,
                authorization=authorization,
                required_scopes=required_scopes,
            )
        if mode == "jwt":
            return self._authenticate_jwt(authorization=authorization, required_scopes=required_scopes)
        if mode == "hybrid":
            return self._authenticate_hybrid(
                x_api_key=x_api_key,
                authorization=authorization,
                required_scopes=required_scopes,
            )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Unsupported AUTH_MODE '{mode}'.",
        )

    def _authenticate_hybrid(
        self,
        *,
        x_api_key: str | None,
        authorization: str | None,
        required_scopes: set[str],
    ) -> AuthContext:
        """Accept a matching API key, otherwise fall back to JWT validation.

        If no JWT secret is configured, the request is handled exactly as in
        ``api_key`` mode (including anonymous access when no key is set).
        """
        if self.settings.agent_api_key:
            api_key = self._resolve_api_key(
                x_api_key=x_api_key,
                authorization=authorization,
            )
            if self._api_key_matches(api_key):
                return AuthContext(subject="api-key", auth_type="api_key", scopes={"*"})
        if self.settings.auth_jwt_secret:
            # The bearer token may have been meant as an API key; it is now
            # validated as a JWT instead.
            return self._authenticate_jwt(
                authorization=authorization,
                required_scopes=required_scopes,
            )
        return self._authenticate_api_key(
            x_api_key=x_api_key,
            authorization=authorization,
            required_scopes=required_scopes,
        )

    def _authenticate_api_key(
        self,
        *,
        x_api_key: str | None,
        authorization: str | None,
        required_scopes: set[str],
    ) -> AuthContext:
        """Validate the static API key.

        With no key configured, authentication is disabled and an anonymous
        context with ``{"*"}`` scopes is returned. ``required_scopes`` is not
        checked here: a valid key grants every scope.
        """
        expected = self.settings.agent_api_key
        if not expected:
            return AuthContext(subject="anonymous", auth_type="none", scopes={"*"})
        provided = self._resolve_api_key(
            x_api_key=x_api_key,
            authorization=authorization,
        )
        if not self._api_key_matches(provided):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid API key.",
            )
        # API key access is treated as admin-level compatibility mode.
        return AuthContext(subject="api-key", auth_type="api_key", scopes={"*"})

    def _api_key_matches(self, provided: str | None) -> bool:
        """Return whether *provided* equals the configured API key.

        Uses :func:`hmac.compare_digest` so the comparison time does not
        depend on how many leading characters match. Both sides are encoded
        to UTF-8 bytes because ``compare_digest`` rejects non-ASCII ``str``.
        """
        expected = self.settings.agent_api_key
        if not provided or not expected:
            return False
        return hmac.compare_digest(provided.encode("utf-8"), expected.encode("utf-8"))

    def _authenticate_jwt(
        self,
        *,
        authorization: str | None,
        required_scopes: set[str],
    ) -> AuthContext:
        """Validate an HS256 bearer JWT and enforce ``required_scopes``.

        Raises:
            HTTPException: 500 if no JWT secret is configured, 401 if the
                token is missing or invalid, 403 if a required scope is absent.
        """
        secret = self.settings.auth_jwt_secret
        if not secret:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="AUTH_JWT_SECRET is required in jwt auth mode.",
            )
        token = self._resolve_bearer(authorization)
        if not token:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Missing Bearer token.",
            )
        claims = _decode_hs256_jwt(token=token, secret=secret)
        _validate_jwt_claims(
            claims=claims,
            expected_issuer=self.settings.auth_jwt_issuer,
            expected_audience=self.settings.auth_jwt_audience,
        )
        scope_values = _extract_scopes(claims)
        if required_scopes and not required_scopes.issubset(scope_values):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Missing required scope.",
            )
        subject = str(claims.get("sub") or "jwt-subject")
        return AuthContext(subject=subject, auth_type="jwt", scopes=scope_values)

    def _resolve_api_key(self, *, x_api_key: str | None, authorization: str | None) -> str | None:
        """Prefer the explicit ``x_api_key`` value, else the bearer token."""
        if x_api_key:
            return x_api_key
        return self._resolve_bearer(authorization)

    def _resolve_bearer(self, authorization: str | None) -> str | None:
        """Extract the credentials from a ``Bearer`` Authorization header.

        The scheme is matched case-insensitively (auth schemes are
        case-insensitive per RFC 7235). Whitespace around the header and
        around the credentials is ignored. Returns ``None`` when the header
        is absent, uses another scheme, or carries no token.
        """
        if not authorization:
            return None
        scheme, _, credentials = authorization.strip().partition(" ")
        if scheme.lower() != "bearer":
            return None
        return credentials.strip() or None
def _decode_hs256_jwt(*, token: str, secret: str) -> dict[str, Any]:
    """Verify an HS256-signed compact JWT and return its payload claims.

    Args:
        token: compact serialization ``header.payload.signature``.
        secret: shared HMAC key (encoded as UTF-8).

    Returns:
        The decoded payload object.

    Raises:
        HTTPException: 401 when the token does not have exactly three
            segments, the signature segment is not valid base64url, the
            signature does not match, the header ``alg`` is not HS256, or the
            header/payload is not a base64url-encoded JSON object.
    """
    try:
        header_segment, payload_segment, signature_segment = token.split(".")
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Malformed JWT.",
        ) from exc
    try:
        actual_signature = _urlsafe_b64decode(signature_segment)
    except ValueError as exc:
        # binascii.Error (a ValueError subclass) on undecodable base64url;
        # without this it would surface as an unhandled 500.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid JWT signature.",
        ) from exc
    signing_input = f"{header_segment}.{payload_segment}".encode("utf-8")
    expected_signature = hmac.new(
        secret.encode("utf-8"),
        signing_input,
        hashlib.sha256,
    ).digest()
    # Constant-time comparison to avoid leaking signature prefixes.
    if not hmac.compare_digest(expected_signature, actual_signature):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid JWT signature.",
        )
    # The signature is always computed with HMAC-SHA256 regardless of the
    # header, so checking "alg" after verification cannot be abused for
    # algorithm confusion; it only rejects tokens that claim another alg.
    header = _decode_jwt_json_segment(header_segment)
    algorithm = str(header.get("alg", "")).upper()
    if algorithm != "HS256":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Unsupported JWT algorithm.",
        )
    return _decode_jwt_json_segment(payload_segment)
def _decode_jwt_json_segment(segment: str) -> dict[str, Any]:
    """Decode one base64url JWT segment into a JSON object.

    Raises:
        HTTPException: 401 "Invalid JWT payload." when the segment is not
            base64url-encoded UTF-8 JSON, or 401 "Invalid JWT object." when
            the decoded JSON value is not an object.
    """
    try:
        parsed = json.loads(_urlsafe_b64decode(segment).decode("utf-8"))
    except Exception as exc:
        # Broad on purpose: bad base64, bad UTF-8, bad JSON, and failures such
        # as RecursionError on deeply nested input all become a 401, not a 500.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid JWT payload.",
        ) from exc
    if isinstance(parsed, dict):
        return parsed
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid JWT object.",
    )
def _urlsafe_b64decode(segment: str) -> bytes:
padded = segment + "=" * (-len(segment) % 4)
return base64.urlsafe_b64decode(padded.encode("utf-8"))
def _validate_jwt_claims(
*,
claims: dict[str, Any],
expected_issuer: str | None,
expected_audience: str | None,
) -> None:
now_ts = int(datetime.now(tz=timezone.utc).timestamp())
exp = claims.get("exp")
if exp is not None:
try:
exp_ts = int(exp)
except Exception as exc:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid exp claim.",
) from exc
if exp_ts < now_ts:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="JWT has expired.",
)
nbf = claims.get("nbf")
if nbf is not None:
try:
nbf_ts = int(nbf)
except Exception as exc:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid nbf claim.",
) from exc
if nbf_ts > now_ts:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="JWT not valid yet.",
)
if expected_issuer:
issuer = str(claims.get("iss", ""))
if issuer != expected_issuer:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid JWT issuer.",
)
if expected_audience:
audience = claims.get("aud")
if isinstance(audience, list):
audience_values = {str(value) for value in audience}
elif audience is None:
audience_values = set()
else:
audience_values = {str(audience)}
if expected_audience not in audience_values:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid JWT audience.",
)
def _extract_scopes(claims: dict[str, Any]) -> set[str]:
raw_scope = claims.get("scope", "")
scope_values: set[str] = set()
if isinstance(raw_scope, str):
scope_values.update(value for value in raw_scope.strip().split(" ") if value)
elif isinstance(raw_scope, list):
scope_values.update(str(value) for value in raw_scope if str(value).strip())
raw_scp = claims.get("scp", "")
if isinstance(raw_scp, str):
scope_values.update(value for value in raw_scp.strip().split(" ") if value)
elif isinstance(raw_scp, list):
scope_values.update(str(value) for value in raw_scp if str(value).strip())
return scope_values