You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
418 lines
14 KiB
Python
418 lines
14 KiB
Python
from __future__ import annotations
|
|
|
|
import base64
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timezone
|
|
import hashlib
|
|
import hmac
|
|
import json
|
|
from typing import Any
|
|
from urllib import error as urllib_error
|
|
from urllib import parse as urllib_parse
|
|
from urllib import request as urllib_request
|
|
|
|
from fastapi import HTTPException, status
|
|
|
|
from app.config import Settings
|
|
|
|
|
|
@dataclass(frozen=True)
class AuthContext:
    """Identity and granted scopes of an authenticated (or anonymous) caller.

    Instances are immutable. ``scopes`` is a mutable ``set``, so the
    ``__hash__`` that ``dataclass(frozen=True)`` would generate fails with
    ``TypeError: unhashable type: 'set'``. An explicit ``__hash__`` is defined
    below so instances can be used in sets and as dict keys.
    """

    # Caller identity, e.g. the token's ``sub``/``client_id``, "api-key", or "anonymous".
    subject: str
    # How the caller authenticated, e.g. "api_key", "jwt", "oauth", or "none".
    auth_type: str
    # Granted scopes; API-key and anonymous access are given {"*"} by AuthBackend.
    scopes: set[str]

    def __hash__(self) -> int:
        # Hash an immutable snapshot of ``scopes``. Equal instances compare
        # equal sets, so they also produce equal hashes.
        return hash((self.subject, self.auth_type, frozenset(self.scopes)))
|
|
|
|
|
|
class AuthBackend:
    """Authenticate requests according to ``settings.auth_mode``.

    Modes dispatched by :meth:`authenticate`:

    * ``"api_key"``: a shared key taken from ``x_api_key`` or a ``Bearer``
      header. Success grants ``{"*"}``. If no key is configured, every request
      is accepted as ``anonymous``.
    * ``"jwt"``: an HS256 JWT verified with ``settings.auth_jwt_secret``.
    * ``"hybrid"``: accept a matching API key. Otherwise fall back to JWT when
      a JWT secret is configured, or else to API-key mode.
    * ``"oauth"``: a Bearer token validated through an OAuth token-introspection
      endpoint. The token is sent as a form POST and the JSON reply's
      ``active`` flag is checked.

    Any other mode makes :meth:`authenticate` raise HTTP 500.
    """

    def __init__(
        self,
        settings: Settings,
        *,
        oauth_introspection_url: str | None = None,
        oauth_client_id: str | None = None,
        oauth_client_secret: str | None = None,
        oauth_issuer: str | None = None,
        oauth_audience: str | None = None,
        oauth_timeout_seconds: float = 8.0,
    ) -> None:
        """Store configuration. No validation or network I/O happens here.

        Args:
            settings: Application settings. This class reads ``auth_mode``,
                ``agent_api_key``, ``auth_jwt_secret``, ``auth_jwt_issuer`` and
                ``auth_jwt_audience``.
            oauth_introspection_url: Introspection endpoint used in oauth mode.
            oauth_client_id: Client id presented to the introspection endpoint.
            oauth_client_secret: Client secret. Combined with
                ``oauth_client_id`` for HTTP Basic auth.
            oauth_issuer: Expected ``iss`` of introspected tokens, if set.
            oauth_audience: Expected ``aud`` of introspected tokens, if set.
            oauth_timeout_seconds: Timeout passed to ``urlopen`` for
                introspection.
        """
        self.settings = settings
        self.oauth_introspection_url = oauth_introspection_url
        self.oauth_client_id = oauth_client_id
        self.oauth_client_secret = oauth_client_secret
        self.oauth_issuer = oauth_issuer
        self.oauth_audience = oauth_audience
        self.oauth_timeout_seconds = oauth_timeout_seconds

    def is_enabled(self) -> bool:
        """Return whether the configured mode has the credentials it needs.

        Blank (whitespace-only) secrets count as unset. Unknown modes fall
        back to the API-key check.
        """
        mode = self.settings.auth_mode
        if mode == "api_key":
            return bool(self.settings.agent_api_key.strip())
        if mode == "jwt":
            return bool(self.settings.auth_jwt_secret.strip())
        if mode == "hybrid":
            return bool(
                self.settings.agent_api_key.strip() or self.settings.auth_jwt_secret.strip()
            )
        if mode == "oauth":
            return bool((self.oauth_introspection_url or "").strip())
        return bool(self.settings.agent_api_key.strip())

    def authenticate(
        self,
        *,
        x_api_key: str | None,
        authorization: str | None,
        required_scopes: set[str],
    ) -> AuthContext:
        """Authenticate one request using the configured ``auth_mode``.

        Args:
            x_api_key: API key supplied by the caller (presumably the
                ``X-API-Key`` header; confirm at the call site).
            authorization: Raw ``Authorization`` header value, if any.
            required_scopes: Scopes the caller must hold. API-key access
                skips this check and always grants ``{"*"}``.

        Returns:
            The caller's :class:`AuthContext`.

        Raises:
            HTTPException: 401/403 on failed authentication or missing scopes,
                500 on misconfiguration or an unsupported mode, and 502 when
                OAuth introspection fails.
        """
        mode = self.settings.auth_mode
        if mode == "api_key":
            return self._authenticate_api_key(
                x_api_key=x_api_key,
                authorization=authorization,
                required_scopes=required_scopes,
            )
        if mode == "jwt":
            return self._authenticate_jwt(authorization=authorization, required_scopes=required_scopes)
        if mode == "hybrid":
            return self._authenticate_hybrid(
                x_api_key=x_api_key,
                authorization=authorization,
                required_scopes=required_scopes,
            )
        if mode == "oauth":
            return self._authenticate_oauth(
                authorization=authorization,
                required_scopes=required_scopes,
            )

        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Unsupported AUTH_MODE '{mode}'.",
        )

    def _authenticate_hybrid(
        self,
        *,
        x_api_key: str | None,
        authorization: str | None,
        required_scopes: set[str],
    ) -> AuthContext:
        """Try the API key first, then JWT if configured, else API-key mode.

        A wrong API key does not fail immediately when JWT is configured, so
        a caller can send a JWT in ``Authorization`` instead.
        """
        if self.settings.agent_api_key:
            api_key = self._resolve_api_key(
                x_api_key=x_api_key,
                authorization=authorization,
            )
            if self._api_key_matches(api_key, self.settings.agent_api_key):
                return AuthContext(subject="api-key", auth_type="api_key", scopes={"*"})

        if self.settings.auth_jwt_secret:
            return self._authenticate_jwt(
                authorization=authorization,
                required_scopes=required_scopes,
            )

        return self._authenticate_api_key(
            x_api_key=x_api_key,
            authorization=authorization,
            required_scopes=required_scopes,
        )

    def _authenticate_api_key(
        self,
        *,
        x_api_key: str | None,
        authorization: str | None,
        required_scopes: set[str],
    ) -> AuthContext:
        """Check the shared API key. With no key configured, allow anonymous access.

        ``required_scopes`` is accepted for signature parity but not checked,
        because API-key callers get ``{"*"}``.
        """
        expected = self.settings.agent_api_key
        if not expected:
            return AuthContext(subject="anonymous", auth_type="none", scopes={"*"})

        provided = self._resolve_api_key(
            x_api_key=x_api_key,
            authorization=authorization,
        )
        if not self._api_key_matches(provided, expected):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid API key.",
            )

        # API key access is treated as admin-level compatibility mode.
        return AuthContext(subject="api-key", auth_type="api_key", scopes={"*"})

    def _authenticate_jwt(
        self,
        *,
        authorization: str | None,
        required_scopes: set[str],
    ) -> AuthContext:
        """Verify an HS256 Bearer JWT, validate its claims and required scopes."""
        secret = self.settings.auth_jwt_secret
        if not secret:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="AUTH_JWT_SECRET is required in jwt auth mode.",
            )

        token = self._require_bearer(authorization)
        claims = _decode_hs256_jwt(token=token, secret=secret)
        _validate_jwt_claims(
            claims=claims,
            expected_issuer=self.settings.auth_jwt_issuer,
            expected_audience=self.settings.auth_jwt_audience,
        )

        scope_values = _extract_scopes(claims)
        self._require_scopes(required_scopes, scope_values)

        subject = str(claims.get("sub") or "jwt-subject")
        return AuthContext(subject=subject, auth_type="jwt", scopes=scope_values)

    def _authenticate_oauth(
        self,
        *,
        authorization: str | None,
        required_scopes: set[str],
    ) -> AuthContext:
        """Introspect a Bearer token and validate its claims and required scopes."""
        if not self.oauth_introspection_url:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="OAuth introspection URL is required in oauth auth mode.",
            )

        token = self._require_bearer(authorization)
        claims = self._introspect_oauth_token(token)
        # Fail closed: only a JSON boolean ``true`` counts as active. A
        # truthiness check would wrongly accept a string such as "false".
        if claims.get("active") is not True:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Inactive OAuth token.",
            )

        _validate_jwt_claims(
            claims=claims,
            expected_issuer=self.oauth_issuer,
            expected_audience=self.oauth_audience,
        )

        scope_values = _extract_scopes(claims)
        self._require_scopes(required_scopes, scope_values)

        subject = str(claims.get("sub") or claims.get("client_id") or "oauth-subject")
        return AuthContext(subject=subject, auth_type="oauth", scopes=scope_values)

    def _require_bearer(self, authorization: str | None) -> str:
        """Return the Bearer token from ``authorization`` or raise 401."""
        token = self._resolve_bearer(authorization)
        if not token:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Missing Bearer token.",
            )
        return token

    @staticmethod
    def _require_scopes(required_scopes: set[str], granted: set[str]) -> None:
        """Raise 403 unless every scope in ``required_scopes`` is in ``granted``."""
        if required_scopes and not required_scopes.issubset(granted):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Missing required scope.",
            )

    @staticmethod
    def _api_key_matches(provided: str | None, expected: str) -> bool:
        """Compare API keys without a content-dependent timing side channel.

        ``hmac.compare_digest`` only accepts ASCII ``str``, so both values are
        UTF-8 encoded first. A length mismatch still returns early, which only
        reveals the key length.
        """
        if provided is None:
            return False
        return hmac.compare_digest(provided.encode("utf-8"), expected.encode("utf-8"))

    def _resolve_api_key(self, *, x_api_key: str | None, authorization: str | None) -> str | None:
        """Prefer a non-empty ``x_api_key``. Otherwise use the Bearer token, if any."""
        if x_api_key:
            return x_api_key
        return self._resolve_bearer(authorization)

    def _resolve_bearer(self, authorization: str | None) -> str | None:
        """Extract the credentials from a ``Bearer <token>`` header value.

        The scheme is matched case-insensitively and surrounding whitespace is
        ignored. Returns ``None`` for a missing header, another scheme, or an
        empty token.
        """
        if not authorization:
            return None
        scheme, _, credentials = authorization.strip().partition(" ")
        if scheme.lower() != "bearer":
            return None
        return credentials.strip() or None

    def _introspect_oauth_token(self, token: str) -> dict[str, Any]:
        """POST ``token`` to the introspection endpoint and return its JSON object.

        Client authentication: when both id and secret are set, HTTP Basic
        auth is used. When only an id is set, ``client_id`` is added to the
        form body.

        Raises:
            HTTPException: 500 if a client secret is set without a client id.
                502 if the request fails (HTTP error status, network error,
                timeout) or the reply is not a JSON object.
        """
        if self.oauth_client_secret and not self.oauth_client_id:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="OAuth client_id is required when client_secret is configured.",
            )

        form: dict[str, str] = {"token": token}
        headers = {
            "Content-Type": "application/x-www-form-urlencoded",
            "Accept": "application/json",
        }
        if self.oauth_client_id and self.oauth_client_secret:
            credentials = f"{self.oauth_client_id}:{self.oauth_client_secret}".encode("utf-8")
            headers["Authorization"] = "Basic " + base64.b64encode(credentials).decode("ascii")
        elif self.oauth_client_id:
            form["client_id"] = self.oauth_client_id

        request = urllib_request.Request(
            self.oauth_introspection_url,  # type: ignore[arg-type]
            data=urllib_parse.urlencode(form).encode("utf-8"),
            headers=headers,
            method="POST",
        )

        try:
            with urllib_request.urlopen(request, timeout=self.oauth_timeout_seconds) as response:
                raw_body = response.read()
        except OSError as exc:
            # URLError and HTTPError subclass OSError. This also covers socket
            # timeouts and connection resets raised while reading the body.
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail="OAuth introspection request failed.",
            ) from exc
        except Exception as exc:
            # e.g. http.client errors on a truncated or malformed HTTP reply.
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail="Invalid OAuth introspection response.",
            ) from exc

        try:
            payload = json.loads(raw_body.decode("utf-8"))
        except (ValueError, RecursionError) as exc:
            # JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail="Invalid OAuth introspection response.",
            ) from exc

        if not isinstance(payload, dict):
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail="Invalid OAuth introspection response.",
            )
        return payload
|
|
|
|
|
|
def _decode_hs256_jwt(*, token: str, secret: str) -> dict[str, Any]:
    """Verify an HS256 compact JWT and return its decoded payload claims.

    The HMAC-SHA256 signature over ``header.payload`` is checked with a
    constant-time comparison *before* any segment is JSON-decoded. The header
    must then declare ``alg`` HS256, compared case-insensitively.

    Args:
        token: Compact serialization ``header.payload.signature``.
        secret: Shared HMAC key, UTF-8 encoded before use.

    Returns:
        The payload JSON object.

    Raises:
        HTTPException: 401 if any of these hold:

            * the token does not have exactly three segments;
            * the signature segment is not valid base64url;
            * the signature does not match;
            * the algorithm is not HS256;
            * a segment is not a JSON object.
    """
    try:
        header_segment, payload_segment, signature_segment = token.split(".")
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Malformed JWT.",
        ) from exc

    signing_input = f"{header_segment}.{payload_segment}".encode("utf-8")
    expected_signature = hmac.new(
        secret.encode("utf-8"),
        signing_input,
        hashlib.sha256,
    ).digest()

    try:
        actual_signature = _urlsafe_b64decode(signature_segment)
    except ValueError as exc:
        # binascii.Error (bad length/padding) and UnicodeEncodeError are both
        # ValueError subclasses. Without this, a garbled signature escaped as
        # an unhandled 500 instead of a 401.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid JWT signature.",
        ) from exc

    if not hmac.compare_digest(expected_signature, actual_signature):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid JWT signature.",
        )

    header = _decode_jwt_json_segment(header_segment)
    algorithm = str(header.get("alg", "")).upper()
    if algorithm != "HS256":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Unsupported JWT algorithm.",
        )

    return _decode_jwt_json_segment(payload_segment)
|
|
|
|
|
|
def _decode_jwt_json_segment(segment: str) -> dict[str, Any]:
    """Decode one base64url JWT segment (header or payload) into a JSON object.

    Raises:
        HTTPException: 401 ``"Invalid JWT payload."`` if the segment is not
            valid base64url, UTF-8, or JSON. 401 ``"Invalid JWT object."`` if
            the decoded JSON value is not an object.
    """
    try:
        parsed = json.loads(_urlsafe_b64decode(segment).decode("utf-8"))
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid JWT payload.",
        ) from exc

    if isinstance(parsed, dict):
        return parsed

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid JWT object.",
    )
|
|
|
|
|
|
def _urlsafe_b64decode(segment: str) -> bytes:
|
|
padded = segment + "=" * (-len(segment) % 4)
|
|
return base64.urlsafe_b64decode(padded.encode("utf-8"))
|
|
|
|
|
|
def _validate_jwt_claims(
|
|
*,
|
|
claims: dict[str, Any],
|
|
expected_issuer: str | None,
|
|
expected_audience: str | None,
|
|
) -> None:
|
|
now_ts = int(datetime.now(tz=timezone.utc).timestamp())
|
|
exp = claims.get("exp")
|
|
if exp is not None:
|
|
try:
|
|
exp_ts = int(exp)
|
|
except Exception as exc:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
detail="Invalid exp claim.",
|
|
) from exc
|
|
if exp_ts < now_ts:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
detail="JWT has expired.",
|
|
)
|
|
|
|
nbf = claims.get("nbf")
|
|
if nbf is not None:
|
|
try:
|
|
nbf_ts = int(nbf)
|
|
except Exception as exc:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
detail="Invalid nbf claim.",
|
|
) from exc
|
|
if nbf_ts > now_ts:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
detail="JWT not valid yet.",
|
|
)
|
|
|
|
if expected_issuer:
|
|
issuer = str(claims.get("iss", ""))
|
|
if issuer != expected_issuer:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
detail="Invalid JWT issuer.",
|
|
)
|
|
|
|
if expected_audience:
|
|
audience = claims.get("aud")
|
|
if isinstance(audience, list):
|
|
audience_values = {str(value) for value in audience}
|
|
elif audience is None:
|
|
audience_values = set()
|
|
else:
|
|
audience_values = {str(audience)}
|
|
if expected_audience not in audience_values:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
detail="Invalid JWT audience.",
|
|
)
|
|
|
|
|
|
def _extract_scopes(claims: dict[str, Any]) -> set[str]:
|
|
raw_scope = claims.get("scope", "")
|
|
scope_values: set[str] = set()
|
|
if isinstance(raw_scope, str):
|
|
scope_values.update(value for value in raw_scope.strip().split(" ") if value)
|
|
elif isinstance(raw_scope, list):
|
|
scope_values.update(str(value) for value in raw_scope if str(value).strip())
|
|
|
|
raw_scp = claims.get("scp", "")
|
|
if isinstance(raw_scp, str):
|
|
scope_values.update(value for value in raw_scp.strip().split(" ") if value)
|
|
elif isinstance(raw_scp, list):
|
|
scope_values.update(str(value) for value in raw_scp if str(value).strip())
|
|
|
|
return scope_values
|