You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

418 lines
14 KiB
Python

from __future__ import annotations

import base64
from dataclasses import dataclass
from datetime import datetime, timezone
import hashlib
import hmac
import http.client
import json
from typing import Any
from urllib import error as urllib_error
from urllib import parse as urllib_parse
from urllib import request as urllib_request

from fastapi import HTTPException, status

from app.config import Settings
@dataclass(frozen=True)
class AuthContext:
    """Result of a successful authentication: who the caller is and what it may do.

    ``frozen=True`` prevents reassigning attributes, but ``scopes`` is a plain
    ``set`` and can still be mutated in place.
    """

    # Principal identifier, e.g. "api-key", "anonymous", or a token's "sub" claim.
    subject: str
    # How the caller authenticated, e.g. "api_key", "jwt", "oauth" or "none".
    auth_type: str
    # Granted scopes; {"*"} is what API-key and anonymous access receive.
    scopes: set[str]
    # NOTE(review): the dataclass-generated __hash__ hashes every field, so
    # hash(ctx) raises TypeError (set is unhashable). Consider frozenset if
    # contexts ever need to be hashed or used as dict keys.
class AuthBackend:
    """Authenticate incoming requests according to ``settings.auth_mode``.

    Modes dispatched by :meth:`authenticate`:

    * ``api_key`` - a single shared key, taken from ``x_api_key`` or a Bearer token.
    * ``jwt`` - an HS256 JWT verified with ``settings.auth_jwt_secret``.
    * ``hybrid`` - the API key is tried first, then a JWT if a secret is configured.
    * ``oauth`` - RFC 7662 token introspection at ``oauth_introspection_url``.
    """

    def __init__(
        self,
        settings: Settings,
        *,
        oauth_introspection_url: str | None = None,
        oauth_client_id: str | None = None,
        oauth_client_secret: str | None = None,
        oauth_issuer: str | None = None,
        oauth_audience: str | None = None,
        oauth_timeout_seconds: float = 8.0,
    ) -> None:
        """Store configuration; nothing is validated or contacted here.

        Args:
            settings: Application settings providing ``auth_mode``,
                ``agent_api_key`` and the ``auth_jwt_*`` values.
            oauth_introspection_url: Introspection endpoint; required in ``oauth`` mode.
            oauth_client_id: Client id presented to the introspection endpoint.
            oauth_client_secret: Client secret; when set together with the id,
                both are sent via HTTP Basic authentication.
            oauth_issuer: Expected ``iss`` of introspected tokens (check skipped if falsy).
            oauth_audience: Expected ``aud`` of introspected tokens (check skipped if falsy).
            oauth_timeout_seconds: Timeout passed to ``urlopen`` for introspection.
        """
        self.settings = settings
        self.oauth_introspection_url = oauth_introspection_url
        self.oauth_client_id = oauth_client_id
        self.oauth_client_secret = oauth_client_secret
        self.oauth_issuer = oauth_issuer
        self.oauth_audience = oauth_audience
        self.oauth_timeout_seconds = oauth_timeout_seconds

    def is_enabled(self) -> bool:
        """Return True when the configured mode has the credentials it needs.

        Whitespace-only secrets are treated as unset. Unknown modes fall back to
        the API-key check.
        """
        if self.settings.auth_mode == "api_key":
            return bool(self.settings.agent_api_key.strip())
        if self.settings.auth_mode == "jwt":
            return bool(self.settings.auth_jwt_secret.strip())
        if self.settings.auth_mode == "hybrid":
            return bool(
                self.settings.agent_api_key.strip() or self.settings.auth_jwt_secret.strip()
            )
        if self.settings.auth_mode == "oauth":
            return bool((self.oauth_introspection_url or "").strip())
        return bool(self.settings.agent_api_key.strip())

    def authenticate(
        self,
        *,
        x_api_key: str | None,
        authorization: str | None,
        required_scopes: set[str],
    ) -> AuthContext:
        """Authenticate a request using the strategy selected by ``auth_mode``.

        Args:
            x_api_key: Raw API key value supplied by the caller, if any.
            authorization: Raw ``Authorization`` header value, if any.
            required_scopes: Scopes the endpoint demands (enforced for JWT/OAuth;
                API-key access is granted every scope).

        Raises:
            HTTPException: 401 for missing/invalid credentials, 403 for missing
                scopes, 500 for misconfiguration or an unsupported mode, 502 when
                OAuth introspection fails.
        """
        mode = self.settings.auth_mode
        if mode == "api_key":
            return self._authenticate_api_key(
                x_api_key=x_api_key,
                authorization=authorization,
                required_scopes=required_scopes,
            )
        if mode == "jwt":
            return self._authenticate_jwt(authorization=authorization, required_scopes=required_scopes)
        if mode == "hybrid":
            return self._authenticate_hybrid(
                x_api_key=x_api_key,
                authorization=authorization,
                required_scopes=required_scopes,
            )
        if mode == "oauth":
            return self._authenticate_oauth(
                authorization=authorization,
                required_scopes=required_scopes,
            )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Unsupported AUTH_MODE '{mode}'.",
        )

    def _authenticate_hybrid(
        self,
        *,
        x_api_key: str | None,
        authorization: str | None,
        required_scopes: set[str],
    ) -> AuthContext:
        """Accept a matching API key; otherwise fall back to JWT (or API-key errors)."""
        if self.settings.agent_api_key:
            api_key = self._resolve_api_key(
                x_api_key=x_api_key,
                authorization=authorization,
            )
            if self._api_key_matches(api_key, self.settings.agent_api_key):
                return AuthContext(subject="api-key", auth_type="api_key", scopes={"*"})
        if self.settings.auth_jwt_secret:
            return self._authenticate_jwt(
                authorization=authorization,
                required_scopes=required_scopes,
            )
        # No JWT secret: let the API-key path produce the 401 (or anonymous access).
        return self._authenticate_api_key(
            x_api_key=x_api_key,
            authorization=authorization,
            required_scopes=required_scopes,
        )

    def _authenticate_api_key(
        self,
        *,
        x_api_key: str | None,
        authorization: str | None,
        required_scopes: set[str],
    ) -> AuthContext:
        """Validate the shared API key; ``required_scopes`` is not checked (admin access)."""
        expected = self.settings.agent_api_key
        if not expected:
            # No key configured: authentication is effectively disabled.
            return AuthContext(subject="anonymous", auth_type="none", scopes={"*"})
        provided = self._resolve_api_key(
            x_api_key=x_api_key,
            authorization=authorization,
        )
        if not self._api_key_matches(provided, expected):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid API key.",
            )
        # API key access is treated as admin-level compatibility mode.
        return AuthContext(subject="api-key", auth_type="api_key", scopes={"*"})

    def _authenticate_jwt(
        self,
        *,
        authorization: str | None,
        required_scopes: set[str],
    ) -> AuthContext:
        """Verify an HS256 Bearer JWT, its time/issuer/audience claims and scopes."""
        secret = self.settings.auth_jwt_secret
        if not secret:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="AUTH_JWT_SECRET is required in jwt auth mode.",
            )
        token = self._resolve_bearer(authorization)
        if not token:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Missing Bearer token.",
            )
        claims = _decode_hs256_jwt(token=token, secret=secret)
        _validate_jwt_claims(
            claims=claims,
            expected_issuer=self.settings.auth_jwt_issuer,
            expected_audience=self.settings.auth_jwt_audience,
        )
        scope_values = _extract_scopes(claims)
        self._require_scopes(scope_values, required_scopes)
        subject = str(claims.get("sub") or "jwt-subject")
        return AuthContext(subject=subject, auth_type="jwt", scopes=scope_values)

    def _authenticate_oauth(
        self,
        *,
        authorization: str | None,
        required_scopes: set[str],
    ) -> AuthContext:
        """Introspect a Bearer token and validate its state, claims and scopes."""
        if not self.oauth_introspection_url:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="OAuth introspection URL is required in oauth auth mode.",
            )
        token = self._resolve_bearer(authorization)
        if not token:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Missing Bearer token.",
            )
        claims = self._introspect_oauth_token(token)
        # RFC 7662 defines "active" as a JSON boolean. Anything other than true
        # (including the string "false", which bool() would accept) is inactive.
        if claims.get("active") is not True:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Inactive OAuth token.",
            )
        _validate_jwt_claims(
            claims=claims,
            expected_issuer=self.oauth_issuer,
            expected_audience=self.oauth_audience,
        )
        scope_values = _extract_scopes(claims)
        self._require_scopes(scope_values, required_scopes)
        subject = str(claims.get("sub") or claims.get("client_id") or "oauth-subject")
        return AuthContext(subject=subject, auth_type="oauth", scopes=scope_values)

    @staticmethod
    def _require_scopes(granted: set[str], required: set[str]) -> None:
        """Raise 403 unless every scope in *required* is present in *granted*."""
        if required and not required.issubset(granted):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Missing required scope.",
            )

    @staticmethod
    def _api_key_matches(provided: str | None, expected: str) -> bool:
        """Compare API keys in constant time (avoids leaking a prefix via timing).

        Both sides are UTF-8 encoded because ``hmac.compare_digest`` rejects
        non-ASCII ``str`` arguments.
        """
        if provided is None:
            return False
        return hmac.compare_digest(provided.encode("utf-8"), expected.encode("utf-8"))

    def _resolve_api_key(self, *, x_api_key: str | None, authorization: str | None) -> str | None:
        """Prefer the explicit API key; otherwise use the Bearer token, if any."""
        if x_api_key:
            return x_api_key
        return self._resolve_bearer(authorization)

    def _resolve_bearer(self, authorization: str | None) -> str | None:
        """Extract the token from ``Bearer <token>`` (scheme case-insensitive).

        Returns None when the header is absent, uses another scheme, or carries
        an empty token. Surrounding whitespace is ignored.
        """
        if not authorization:
            return None
        parts = authorization.split(None, 1)
        if len(parts) != 2 or parts[0].lower() != "bearer":
            return None
        token = parts[1].strip()
        return token or None

    def _introspect_oauth_token(self, token: str) -> dict[str, Any]:
        """POST *token* to the introspection endpoint and return the JSON object.

        Confidential clients (id + secret) authenticate with HTTP Basic. Public
        clients (id only) send ``client_id`` in the form body.

        Raises:
            HTTPException: 500 if a secret is configured without a client id,
                502 if the request fails or the response is not a JSON object.
        """
        if self.oauth_client_secret and not self.oauth_client_id:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="OAuth client_id is required when client_secret is configured.",
            )
        request_body: dict[str, str] = {"token": token}
        if self.oauth_client_id and not self.oauth_client_secret:
            request_body["client_id"] = self.oauth_client_id
        headers = {
            "Content-Type": "application/x-www-form-urlencoded",
            "Accept": "application/json",
        }
        if self.oauth_client_id and self.oauth_client_secret:
            basic_secret = base64.b64encode(
                f"{self.oauth_client_id}:{self.oauth_client_secret}".encode("utf-8")
            ).decode("ascii")
            headers["Authorization"] = f"Basic {basic_secret}"
        body_bytes = urllib_parse.urlencode(request_body).encode("utf-8")
        request = urllib_request.Request(
            self.oauth_introspection_url,  # type: ignore[arg-type]
            data=body_bytes,
            headers=headers,
            method="POST",
        )
        try:
            with urllib_request.urlopen(request, timeout=self.oauth_timeout_seconds) as response:
                raw_body = response.read()
        except (OSError, http.client.HTTPException) as exc:
            # urllib_error.URLError/HTTPError and socket timeouts are OSError
            # subclasses; http.client.HTTPException covers truncated responses.
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail="OAuth introspection request failed.",
            ) from exc
        try:
            payload = json.loads(raw_body.decode("utf-8"))
        except ValueError as exc:
            # UnicodeDecodeError and json.JSONDecodeError are both ValueErrors.
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail="Invalid OAuth introspection response.",
            ) from exc
        if not isinstance(payload, dict):
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail="Invalid OAuth introspection response.",
            )
        return payload
def _decode_hs256_jwt(*, token: str, secret: str) -> dict[str, Any]:
    """Verify an HS256-signed compact JWT and return its payload claims.

    The signature is always computed with HMAC-SHA256, so a forged ``alg``
    header cannot change how verification is performed. The header is only
    inspected after the signature has been verified.

    Args:
        token: Compact serialisation ``header.payload.signature``.
        secret: Shared HMAC secret (UTF-8 encoded before use).

    Returns:
        The decoded payload object.

    Raises:
        HTTPException: 401 if the token is malformed, the signature does not
            match, or the header's ``alg`` is not HS256.
    """
    try:
        header_segment, payload_segment, signature_segment = token.split(".")
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Malformed JWT.",
        ) from exc
    try:
        actual_signature = _urlsafe_b64decode(signature_segment)
    except ValueError as exc:
        # binascii.Error (a ValueError subclass) for undecodable base64, e.g. a
        # segment whose length is 1 more than a multiple of 4. Previously this
        # escaped as an unhandled 500 instead of a 401.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Malformed JWT.",
        ) from exc
    signing_input = f"{header_segment}.{payload_segment}".encode("utf-8")
    expected_signature = hmac.new(
        secret.encode("utf-8"),
        signing_input,
        hashlib.sha256,
    ).digest()
    if not hmac.compare_digest(expected_signature, actual_signature):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid JWT signature.",
        )
    header = _decode_jwt_json_segment(header_segment)
    algorithm = str(header.get("alg", "")).upper()
    if algorithm != "HS256":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Unsupported JWT algorithm.",
        )
    return _decode_jwt_json_segment(payload_segment)
def _decode_jwt_json_segment(segment: str) -> dict[str, Any]:
    """Turn one base64url JWT segment into its JSON object.

    Raises:
        HTTPException: 401 "Invalid JWT payload." when the segment is not valid
            base64url, UTF-8 or JSON; 401 "Invalid JWT object." when the JSON
            value is not an object.
    """
    try:
        parsed = json.loads(_urlsafe_b64decode(segment).decode("utf-8"))
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid JWT payload.",
        ) from exc
    if isinstance(parsed, dict):
        return parsed
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid JWT object.",
    )
def _urlsafe_b64decode(segment: str) -> bytes:
    """Decode a base64url segment, restoring the ``=`` padding JWTs strip off."""
    missing_padding = (4 - len(segment) % 4) % 4
    restored = segment + "=" * missing_padding
    return base64.urlsafe_b64decode(restored.encode("utf-8"))
def _validate_jwt_claims(
    *,
    claims: dict[str, Any],
    expected_issuer: str | None,
    expected_audience: str | None,
    leeway_seconds: int = 0,
) -> None:
    """Check the registered time, issuer and audience claims of a token.

    ``exp``/``nbf`` are only checked when present. Issuer/audience are only
    checked when an expected value is configured (truthy).

    Args:
        claims: Decoded token claims.
        expected_issuer: Required exact ``iss`` value, or falsy to skip.
        expected_audience: Value that must appear in ``aud`` (string or list),
            or falsy to skip.
        leeway_seconds: Clock-skew tolerance applied to ``exp`` and ``nbf``;
            defaults to 0 (strict).

    Raises:
        HTTPException: 401 for a non-numeric ``exp``/``nbf``, an expired or
            not-yet-valid token, or an issuer/audience mismatch.
    """
    now_ts = int(datetime.now(tz=timezone.utc).timestamp())

    exp_ts = _read_numeric_date_claim(claims, "exp", "Invalid exp claim.")
    # RFC 7519 §4.1.4: the token must not be accepted on or after exp.
    # now_ts is floored, so ">=" also rejects the fractional second after exp.
    if exp_ts is not None and now_ts >= exp_ts + leeway_seconds:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="JWT has expired.",
        )

    nbf_ts = _read_numeric_date_claim(claims, "nbf", "Invalid nbf claim.")
    if nbf_ts is not None and now_ts < nbf_ts - leeway_seconds:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="JWT not valid yet.",
        )

    if expected_issuer:
        issuer = str(claims.get("iss", ""))
        if issuer != expected_issuer:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid JWT issuer.",
            )

    if expected_audience:
        audience = claims.get("aud")
        if isinstance(audience, list):
            audience_values = {str(value) for value in audience}
        elif audience is None:
            audience_values = set()
        else:
            audience_values = {str(audience)}
        if expected_audience not in audience_values:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid JWT audience.",
            )


def _read_numeric_date_claim(claims: dict[str, Any], name: str, invalid_detail: str) -> int | None:
    """Return claim *name* as an int timestamp, None if absent, or raise 401 if not numeric."""
    raw = claims.get(name)
    if raw is None:
        return None
    try:
        return int(raw)
    except (TypeError, ValueError, OverflowError) as exc:
        # TypeError: list/dict; ValueError: non-numeric string or NaN;
        # OverflowError: float infinity.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=invalid_detail,
        ) from exc
def _extract_scopes(claims: dict[str, Any]) -> set[str]:
    """Collect granted scopes from the ``scope`` and ``scp`` claims.

    Each claim may be a whitespace-delimited string (the RFC 6749 / RFC 7662
    form) or a list of values. Tokens from both claims are merged, and values
    of any other type are ignored.
    """
    return _split_scope_claim(claims.get("scope")) | _split_scope_claim(claims.get("scp"))


def _split_scope_claim(raw: Any) -> set[str]:
    """Normalise one scope claim (string or list) into trimmed, non-empty tokens."""
    if isinstance(raw, str):
        # split() with no argument handles runs of spaces, tabs and newlines.
        return set(raw.split())
    if isinstance(raw, list):
        # Strip list items so " read " matches a required scope of "read",
        # consistent with the string form.
        return {text for text in (str(item).strip() for item in raw) if text}
    return set()