feat(auth): add unified auth backend with scopes
parent
d25d429519
commit
d1238f186c
@ -0,0 +1,3 @@
|
||||
from app.security.auth import AuthBackend, AuthContext
|
||||
|
||||
__all__ = ["AuthBackend", "AuthContext"]
|
||||
@ -0,0 +1,300 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from app.config import Settings
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AuthContext:
|
||||
subject: str
|
||||
auth_type: str
|
||||
scopes: set[str]
|
||||
|
||||
|
||||
class AuthBackend:
|
||||
def __init__(self, settings: Settings) -> None:
|
||||
self.settings = settings
|
||||
|
||||
def is_enabled(self) -> bool:
|
||||
if self.settings.auth_mode == "api_key":
|
||||
return bool(self.settings.agent_api_key.strip())
|
||||
if self.settings.auth_mode == "jwt":
|
||||
return bool(self.settings.auth_jwt_secret.strip())
|
||||
if self.settings.auth_mode == "hybrid":
|
||||
return bool(
|
||||
self.settings.agent_api_key.strip() or self.settings.auth_jwt_secret.strip()
|
||||
)
|
||||
return bool(self.settings.agent_api_key.strip())
|
||||
|
||||
def authenticate(
|
||||
self,
|
||||
*,
|
||||
x_api_key: str | None,
|
||||
authorization: str | None,
|
||||
required_scopes: set[str],
|
||||
) -> AuthContext:
|
||||
mode = self.settings.auth_mode
|
||||
if mode == "api_key":
|
||||
return self._authenticate_api_key(
|
||||
x_api_key=x_api_key,
|
||||
authorization=authorization,
|
||||
required_scopes=required_scopes,
|
||||
)
|
||||
if mode == "jwt":
|
||||
return self._authenticate_jwt(authorization=authorization, required_scopes=required_scopes)
|
||||
if mode == "hybrid":
|
||||
return self._authenticate_hybrid(
|
||||
x_api_key=x_api_key,
|
||||
authorization=authorization,
|
||||
required_scopes=required_scopes,
|
||||
)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Unsupported AUTH_MODE '{mode}'.",
|
||||
)
|
||||
|
||||
def _authenticate_hybrid(
|
||||
self,
|
||||
*,
|
||||
x_api_key: str | None,
|
||||
authorization: str | None,
|
||||
required_scopes: set[str],
|
||||
) -> AuthContext:
|
||||
if self.settings.agent_api_key:
|
||||
api_key = self._resolve_api_key(
|
||||
x_api_key=x_api_key,
|
||||
authorization=authorization,
|
||||
)
|
||||
if api_key == self.settings.agent_api_key:
|
||||
return AuthContext(subject="api-key", auth_type="api_key", scopes={"*"})
|
||||
|
||||
if self.settings.auth_jwt_secret:
|
||||
return self._authenticate_jwt(
|
||||
authorization=authorization,
|
||||
required_scopes=required_scopes,
|
||||
)
|
||||
|
||||
return self._authenticate_api_key(
|
||||
x_api_key=x_api_key,
|
||||
authorization=authorization,
|
||||
required_scopes=required_scopes,
|
||||
)
|
||||
|
||||
def _authenticate_api_key(
|
||||
self,
|
||||
*,
|
||||
x_api_key: str | None,
|
||||
authorization: str | None,
|
||||
required_scopes: set[str],
|
||||
) -> AuthContext:
|
||||
expected = self.settings.agent_api_key
|
||||
if not expected:
|
||||
return AuthContext(subject="anonymous", auth_type="none", scopes={"*"})
|
||||
|
||||
provided = self._resolve_api_key(
|
||||
x_api_key=x_api_key,
|
||||
authorization=authorization,
|
||||
)
|
||||
if provided != expected:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid API key.",
|
||||
)
|
||||
|
||||
# API key access is treated as admin-level compatibility mode.
|
||||
return AuthContext(subject="api-key", auth_type="api_key", scopes={"*"})
|
||||
|
||||
def _authenticate_jwt(
|
||||
self,
|
||||
*,
|
||||
authorization: str | None,
|
||||
required_scopes: set[str],
|
||||
) -> AuthContext:
|
||||
secret = self.settings.auth_jwt_secret
|
||||
if not secret:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="AUTH_JWT_SECRET is required in jwt auth mode.",
|
||||
)
|
||||
|
||||
token = self._resolve_bearer(authorization)
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing Bearer token.",
|
||||
)
|
||||
|
||||
claims = _decode_hs256_jwt(token=token, secret=secret)
|
||||
_validate_jwt_claims(
|
||||
claims=claims,
|
||||
expected_issuer=self.settings.auth_jwt_issuer,
|
||||
expected_audience=self.settings.auth_jwt_audience,
|
||||
)
|
||||
|
||||
scope_values = _extract_scopes(claims)
|
||||
if required_scopes and not required_scopes.issubset(scope_values):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Missing required scope.",
|
||||
)
|
||||
|
||||
subject = str(claims.get("sub") or "jwt-subject")
|
||||
return AuthContext(subject=subject, auth_type="jwt", scopes=scope_values)
|
||||
|
||||
def _resolve_api_key(self, *, x_api_key: str | None, authorization: str | None) -> str | None:
|
||||
if x_api_key:
|
||||
return x_api_key
|
||||
|
||||
token = self._resolve_bearer(authorization)
|
||||
return token
|
||||
|
||||
def _resolve_bearer(self, authorization: str | None) -> str | None:
|
||||
if not authorization:
|
||||
return None
|
||||
parts = authorization.split(" ", 1)
|
||||
if len(parts) == 2 and parts[0].lower() == "bearer":
|
||||
return parts[1]
|
||||
return None
|
||||
|
||||
|
||||
def _decode_hs256_jwt(*, token: str, secret: str) -> dict[str, Any]:
|
||||
try:
|
||||
header_segment, payload_segment, signature_segment = token.split(".")
|
||||
except ValueError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Malformed JWT.",
|
||||
) from exc
|
||||
|
||||
signing_input = f"{header_segment}.{payload_segment}".encode("utf-8")
|
||||
expected_signature = hmac.new(
|
||||
secret.encode("utf-8"),
|
||||
signing_input,
|
||||
hashlib.sha256,
|
||||
).digest()
|
||||
actual_signature = _urlsafe_b64decode(signature_segment)
|
||||
if not hmac.compare_digest(expected_signature, actual_signature):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid JWT signature.",
|
||||
)
|
||||
|
||||
header = _decode_jwt_json_segment(header_segment)
|
||||
algorithm = str(header.get("alg", "")).upper()
|
||||
if algorithm != "HS256":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Unsupported JWT algorithm.",
|
||||
)
|
||||
|
||||
return _decode_jwt_json_segment(payload_segment)
|
||||
|
||||
|
||||
def _decode_jwt_json_segment(segment: str) -> dict[str, Any]:
|
||||
try:
|
||||
decoded = _urlsafe_b64decode(segment).decode("utf-8")
|
||||
payload = json.loads(decoded)
|
||||
except Exception as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid JWT payload.",
|
||||
) from exc
|
||||
|
||||
if not isinstance(payload, dict):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid JWT object.",
|
||||
)
|
||||
return payload
|
||||
|
||||
|
||||
def _urlsafe_b64decode(segment: str) -> bytes:
|
||||
padded = segment + "=" * (-len(segment) % 4)
|
||||
return base64.urlsafe_b64decode(padded.encode("utf-8"))
|
||||
|
||||
|
||||
def _validate_jwt_claims(
|
||||
*,
|
||||
claims: dict[str, Any],
|
||||
expected_issuer: str | None,
|
||||
expected_audience: str | None,
|
||||
) -> None:
|
||||
now_ts = int(datetime.now(tz=timezone.utc).timestamp())
|
||||
exp = claims.get("exp")
|
||||
if exp is not None:
|
||||
try:
|
||||
exp_ts = int(exp)
|
||||
except Exception as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid exp claim.",
|
||||
) from exc
|
||||
if exp_ts < now_ts:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="JWT has expired.",
|
||||
)
|
||||
|
||||
nbf = claims.get("nbf")
|
||||
if nbf is not None:
|
||||
try:
|
||||
nbf_ts = int(nbf)
|
||||
except Exception as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid nbf claim.",
|
||||
) from exc
|
||||
if nbf_ts > now_ts:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="JWT not valid yet.",
|
||||
)
|
||||
|
||||
if expected_issuer:
|
||||
issuer = str(claims.get("iss", ""))
|
||||
if issuer != expected_issuer:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid JWT issuer.",
|
||||
)
|
||||
|
||||
if expected_audience:
|
||||
audience = claims.get("aud")
|
||||
if isinstance(audience, list):
|
||||
audience_values = {str(value) for value in audience}
|
||||
elif audience is None:
|
||||
audience_values = set()
|
||||
else:
|
||||
audience_values = {str(audience)}
|
||||
if expected_audience not in audience_values:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid JWT audience.",
|
||||
)
|
||||
|
||||
|
||||
def _extract_scopes(claims: dict[str, Any]) -> set[str]:
|
||||
raw_scope = claims.get("scope", "")
|
||||
scope_values: set[str] = set()
|
||||
if isinstance(raw_scope, str):
|
||||
scope_values.update(value for value in raw_scope.strip().split(" ") if value)
|
||||
elif isinstance(raw_scope, list):
|
||||
scope_values.update(str(value) for value in raw_scope if str(value).strip())
|
||||
|
||||
raw_scp = claims.get("scp", "")
|
||||
if isinstance(raw_scp, str):
|
||||
scope_values.update(value for value in raw_scp.strip().split(" ") if value)
|
||||
elif isinstance(raw_scp, list):
|
||||
scope_values.update(str(value) for value in raw_scp if str(value).strip())
|
||||
|
||||
return scope_values
|
||||
Loading…
Reference in New Issue