You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
301 lines
9.5 KiB
Python
301 lines
9.5 KiB
Python
from __future__ import annotations
|
|
|
|
import base64
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timezone
|
|
import hashlib
|
|
import hmac
|
|
import json
|
|
from typing import Any
|
|
|
|
from fastapi import HTTPException, status
|
|
|
|
from app.config import Settings
|
|
|
|
|
|
@dataclass(frozen=True)
class AuthContext:
    """Identity and permissions of an authenticated caller.

    Attributes:
        subject: Who the caller is, e.g. the JWT ``sub`` claim, ``"api-key"``
            or ``"anonymous"``.
        auth_type: How the caller authenticated, e.g. ``"jwt"``, ``"api_key"``
            or ``"none"``.
        scopes: Scopes granted to the caller; ``"*"`` is the wildcard used for
            API-key and anonymous access.

    Note: the dataclass is frozen, but ``scopes`` is a plain mutable ``set``.
    Its contents can still be modified in place, and because the generated
    ``__hash__`` hashes every field, ``hash()`` on an instance raises
    ``TypeError``.
    """

    # Identity of the caller.
    subject: str
    # Authentication mechanism that produced this context.
    auth_type: str
    # Granted scopes ("*" means unrestricted).
    scopes: set[str]
|
|
|
|
|
|
class AuthBackend:
    """Authenticate incoming requests according to ``settings.auth_mode``.

    Supported modes are ``"api_key"`` (one shared key), ``"jwt"`` (HS256
    bearer tokens signed with ``settings.auth_jwt_secret``) and ``"hybrid"``
    (accept either). Failures are raised as FastAPI ``HTTPException``s: 401
    for bad or missing credentials, 403 for a JWT lacking a required scope and
    500 for server-side misconfiguration.
    """

    def __init__(self, settings: Settings) -> None:
        self.settings = settings

    def is_enabled(self) -> bool:
        """Return whether the configured mode has a non-blank credential to check.

        Whitespace-only values count as unset. An unrecognised ``auth_mode``
        falls back to the API-key check.
        """
        if self.settings.auth_mode == "api_key":
            return bool(self.settings.agent_api_key.strip())

        if self.settings.auth_mode == "jwt":
            return bool(self.settings.auth_jwt_secret.strip())

        if self.settings.auth_mode == "hybrid":
            return bool(
                self.settings.agent_api_key.strip() or self.settings.auth_jwt_secret.strip()
            )

        return bool(self.settings.agent_api_key.strip())

    def authenticate(
        self,
        *,
        x_api_key: str | None,
        authorization: str | None,
        required_scopes: set[str],
    ) -> AuthContext:
        """Authenticate a request and return the caller's identity and scopes.

        Args:
            x_api_key: Value presented as the API key (presumably the
                ``X-API-Key`` header — confirm at the call site), or ``None``.
            authorization: Raw ``Authorization`` header value, or ``None``.
            required_scopes: Scopes a JWT must carry. API-key access is granted
                the wildcard scope and is not checked against these.

        Raises:
            HTTPException: 401/403 on authentication or authorisation failure;
                500 if ``auth_mode`` is not a supported value.
        """
        mode = self.settings.auth_mode
        if mode == "api_key":
            return self._authenticate_api_key(
                x_api_key=x_api_key,
                authorization=authorization,
                required_scopes=required_scopes,
            )
        if mode == "jwt":
            return self._authenticate_jwt(authorization=authorization, required_scopes=required_scopes)
        if mode == "hybrid":
            return self._authenticate_hybrid(
                x_api_key=x_api_key,
                authorization=authorization,
                required_scopes=required_scopes,
            )

        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Unsupported AUTH_MODE '{mode}'.",
        )

    def _authenticate_hybrid(
        self,
        *,
        x_api_key: str | None,
        authorization: str | None,
        required_scopes: set[str],
    ) -> AuthContext:
        """Accept a matching API key first, otherwise fall back to JWT validation.

        If no JWT secret is configured, the request is handled exactly as in
        ``api_key`` mode.
        """
        if self.settings.agent_api_key:
            api_key = self._resolve_api_key(
                x_api_key=x_api_key,
                authorization=authorization,
            )
            if self._api_key_matches(api_key):
                return AuthContext(subject="api-key", auth_type="api_key", scopes={"*"})

        if self.settings.auth_jwt_secret:
            # A non-matching key is not an error yet: the bearer value may be a JWT.
            return self._authenticate_jwt(
                authorization=authorization,
                required_scopes=required_scopes,
            )

        return self._authenticate_api_key(
            x_api_key=x_api_key,
            authorization=authorization,
            required_scopes=required_scopes,
        )

    def _authenticate_api_key(
        self,
        *,
        x_api_key: str | None,
        authorization: str | None,
        required_scopes: set[str],
    ) -> AuthContext:
        """Authenticate with the shared API key.

        ``required_scopes`` is accepted for signature parity but not checked:
        a valid key is granted the wildcard scope.

        Raises:
            HTTPException: 401 if a key is configured and the presented one
                does not match.
        """
        expected = self.settings.agent_api_key
        if not expected:
            # No key configured: requests are let through anonymously.
            return AuthContext(subject="anonymous", auth_type="none", scopes={"*"})

        provided = self._resolve_api_key(
            x_api_key=x_api_key,
            authorization=authorization,
        )
        if not self._api_key_matches(provided):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid API key.",
            )

        # API key access is treated as admin-level compatibility mode.
        return AuthContext(subject="api-key", auth_type="api_key", scopes={"*"})

    def _authenticate_jwt(
        self,
        *,
        authorization: str | None,
        required_scopes: set[str],
    ) -> AuthContext:
        """Authenticate with an HS256 bearer JWT and enforce ``required_scopes``.

        Raises:
            HTTPException: 500 if no JWT secret is configured; 401 if the token
                is missing, malformed, badly signed or fails claim validation;
                403 if it lacks any of ``required_scopes``.
        """
        secret = self.settings.auth_jwt_secret
        if not secret:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="AUTH_JWT_SECRET is required in jwt auth mode.",
            )

        token = self._resolve_bearer(authorization)
        if not token:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Missing Bearer token.",
            )

        claims = _decode_hs256_jwt(token=token, secret=secret)
        _validate_jwt_claims(
            claims=claims,
            expected_issuer=self.settings.auth_jwt_issuer,
            expected_audience=self.settings.auth_jwt_audience,
        )

        scope_values = _extract_scopes(claims)
        if required_scopes and not required_scopes.issubset(scope_values):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Missing required scope.",
            )

        subject = str(claims.get("sub") or "jwt-subject")
        return AuthContext(subject=subject, auth_type="jwt", scopes=scope_values)

    def _api_key_matches(self, provided: str | None) -> bool:
        """Return whether ``provided`` equals the configured API key.

        Uses ``hmac.compare_digest`` so response timing does not reveal how
        many leading characters of a guessed key are correct. Both sides are
        encoded to bytes because ``compare_digest`` rejects non-ASCII ``str``.
        Returns ``False`` when either side is empty or ``None``.
        """
        expected = self.settings.agent_api_key
        if not provided or not expected:
            return False
        return hmac.compare_digest(provided.encode("utf-8"), expected.encode("utf-8"))

    def _resolve_api_key(self, *, x_api_key: str | None, authorization: str | None) -> str | None:
        """Return the API-key value if present, else the bearer credential."""
        if x_api_key:
            return x_api_key

        return self._resolve_bearer(authorization)

    def _resolve_bearer(self, authorization: str | None) -> str | None:
        """Extract the credential from a ``Bearer <credential>`` header value.

        The scheme is matched case-insensitively. Returns ``None`` when the
        header is absent, uses another scheme, or carries an empty credential.
        """
        if not authorization:
            return None
        scheme, _, credentials = authorization.strip().partition(" ")
        if scheme.lower() != "bearer":
            return None
        # RFC 6750 allows one or more spaces after the scheme; ignore the extras.
        return credentials.strip() or None
|
|
|
|
|
|
def _decode_hs256_jwt(*, token: str, secret: str) -> dict[str, Any]:
    """Verify an HS256-signed compact JWT and return its payload claims.

    The HMAC-SHA256 signature over ``"<header>.<payload>"`` is checked against
    *secret* before the header and payload JSON are used. Only the signature
    and the header ``alg`` are verified here; time, issuer and audience claims
    are not.

    Raises:
        HTTPException: 401 if the token does not have exactly three segments,
            the signature segment is not decodable base64url, the signature
            does not match, the header/payload is not a JSON object, or the
            header ``alg`` is not HS256.
    """
    try:
        header_segment, payload_segment, signature_segment = token.split(".")
    except ValueError as exc:
        # Raised when there are not exactly three dot-separated segments.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Malformed JWT.",
        ) from exc

    signing_input = f"{header_segment}.{payload_segment}".encode("utf-8")
    expected_signature = hmac.new(
        secret.encode("utf-8"),
        signing_input,
        hashlib.sha256,
    ).digest()

    try:
        actual_signature = _urlsafe_b64decode(signature_segment)
    except ValueError as exc:
        # binascii.Error (e.g. an impossible base64 length) subclasses
        # ValueError; map it to a 401 rather than letting it escape unhandled.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Malformed JWT.",
        ) from exc

    # Constant-time comparison so timing does not reveal partial matches.
    if not hmac.compare_digest(expected_signature, actual_signature):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid JWT signature.",
        )

    # The signature above is always computed with HMAC-SHA256, so this only
    # rejects tokens whose header advertises another algorithm. The
    # comparison is case-insensitive.
    header = _decode_jwt_json_segment(header_segment)
    algorithm = str(header.get("alg", "")).upper()
    if algorithm != "HS256":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Unsupported JWT algorithm.",
        )

    return _decode_jwt_json_segment(payload_segment)
|
|
|
|
|
|
def _decode_jwt_json_segment(segment: str) -> dict[str, Any]:
    """Decode one base64url JWT segment and parse it as a JSON object.

    Raises:
        HTTPException: 401 "Invalid JWT payload." if the segment is not valid
            base64url, UTF-8 or JSON; 401 "Invalid JWT object." if the JSON
            value is not an object.
    """
    try:
        parsed = json.loads(_urlsafe_b64decode(segment).decode("utf-8"))
    except Exception as exc:
        # Deliberately broad: any failure while decoding this untrusted input
        # (including e.g. RecursionError from deeply nested JSON) becomes a 401.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid JWT payload.",
        ) from exc

    if isinstance(parsed, dict):
        return parsed

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid JWT object.",
    )
|
|
|
|
|
|
def _urlsafe_b64decode(segment: str) -> bytes:
    """Decode a base64url string, restoring the ``=`` padding JWTs omit.

    Raises ``binascii.Error`` (a ``ValueError``) when the padded input is not
    valid base64.
    """
    missing_padding = (4 - len(segment) % 4) % 4
    padded_bytes = (segment + "=" * missing_padding).encode("utf-8")
    return base64.urlsafe_b64decode(padded_bytes)
|
|
|
|
|
|
def _validate_jwt_claims(
|
|
*,
|
|
claims: dict[str, Any],
|
|
expected_issuer: str | None,
|
|
expected_audience: str | None,
|
|
) -> None:
|
|
now_ts = int(datetime.now(tz=timezone.utc).timestamp())
|
|
exp = claims.get("exp")
|
|
if exp is not None:
|
|
try:
|
|
exp_ts = int(exp)
|
|
except Exception as exc:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
detail="Invalid exp claim.",
|
|
) from exc
|
|
if exp_ts < now_ts:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
detail="JWT has expired.",
|
|
)
|
|
|
|
nbf = claims.get("nbf")
|
|
if nbf is not None:
|
|
try:
|
|
nbf_ts = int(nbf)
|
|
except Exception as exc:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
detail="Invalid nbf claim.",
|
|
) from exc
|
|
if nbf_ts > now_ts:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
detail="JWT not valid yet.",
|
|
)
|
|
|
|
if expected_issuer:
|
|
issuer = str(claims.get("iss", ""))
|
|
if issuer != expected_issuer:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
detail="Invalid JWT issuer.",
|
|
)
|
|
|
|
if expected_audience:
|
|
audience = claims.get("aud")
|
|
if isinstance(audience, list):
|
|
audience_values = {str(value) for value in audience}
|
|
elif audience is None:
|
|
audience_values = set()
|
|
else:
|
|
audience_values = {str(audience)}
|
|
if expected_audience not in audience_values:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
detail="Invalid JWT audience.",
|
|
)
|
|
|
|
|
|
def _scope_tokens(raw: Any) -> list[str]:
    """Normalise one scope claim value into a list of non-empty scope names.

    A string is split on single spaces (the OAuth space-delimited form). A
    list contributes each element converted with ``str()`` and stripped. Any
    other type contributes nothing.
    """
    if isinstance(raw, str):
        return [value for value in raw.strip().split(" ") if value]
    if isinstance(raw, list):
        stripped = (str(value).strip() for value in raw)
        return [value for value in stripped if value]
    return []


def _extract_scopes(claims: dict[str, Any]) -> set[str]:
    """Collect granted scopes from the ``scope`` and ``scp`` claims.

    Both claims are merged. Each may be a space-delimited string or a list;
    surrounding whitespace is trimmed from every scope in either form.
    """
    scope_values: set[str] = set()
    for claim_name in ("scope", "scp"):
        scope_values.update(_scope_tokens(claims.get(claim_name)))
    return scope_values
|