|
|
|
@ -7,6 +7,9 @@ import hashlib
|
|
|
|
import hmac
|
|
|
|
import hmac
|
|
|
|
import json
|
|
|
|
import json
|
|
|
|
from typing import Any
|
|
|
|
from typing import Any
|
|
|
|
|
|
|
|
from urllib import error as urllib_error
|
|
|
|
|
|
|
|
from urllib import parse as urllib_parse
|
|
|
|
|
|
|
|
from urllib import request as urllib_request
|
|
|
|
|
|
|
|
|
|
|
|
from fastapi import HTTPException, status
|
|
|
|
from fastapi import HTTPException, status
|
|
|
|
|
|
|
|
|
|
|
|
@ -21,8 +24,24 @@ class AuthContext:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AuthBackend:
|
|
|
|
class AuthBackend:
|
|
|
|
def __init__(self, settings: Settings) -> None:
|
|
|
|
def __init__(
|
|
|
|
|
|
|
|
self,
|
|
|
|
|
|
|
|
settings: Settings,
|
|
|
|
|
|
|
|
*,
|
|
|
|
|
|
|
|
oauth_introspection_url: str | None = None,
|
|
|
|
|
|
|
|
oauth_client_id: str | None = None,
|
|
|
|
|
|
|
|
oauth_client_secret: str | None = None,
|
|
|
|
|
|
|
|
oauth_issuer: str | None = None,
|
|
|
|
|
|
|
|
oauth_audience: str | None = None,
|
|
|
|
|
|
|
|
oauth_timeout_seconds: float = 8.0,
|
|
|
|
|
|
|
|
) -> None:
|
|
|
|
self.settings = settings
|
|
|
|
self.settings = settings
|
|
|
|
|
|
|
|
self.oauth_introspection_url = oauth_introspection_url
|
|
|
|
|
|
|
|
self.oauth_client_id = oauth_client_id
|
|
|
|
|
|
|
|
self.oauth_client_secret = oauth_client_secret
|
|
|
|
|
|
|
|
self.oauth_issuer = oauth_issuer
|
|
|
|
|
|
|
|
self.oauth_audience = oauth_audience
|
|
|
|
|
|
|
|
self.oauth_timeout_seconds = oauth_timeout_seconds
|
|
|
|
|
|
|
|
|
|
|
|
def is_enabled(self) -> bool:
|
|
|
|
def is_enabled(self) -> bool:
|
|
|
|
if self.settings.auth_mode == "api_key":
|
|
|
|
if self.settings.auth_mode == "api_key":
|
|
|
|
@ -33,6 +52,8 @@ class AuthBackend:
|
|
|
|
return bool(
|
|
|
|
return bool(
|
|
|
|
self.settings.agent_api_key.strip() or self.settings.auth_jwt_secret.strip()
|
|
|
|
self.settings.agent_api_key.strip() or self.settings.auth_jwt_secret.strip()
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
if self.settings.auth_mode == "oauth":
|
|
|
|
|
|
|
|
return bool((self.oauth_introspection_url or "").strip())
|
|
|
|
return bool(self.settings.agent_api_key.strip())
|
|
|
|
return bool(self.settings.agent_api_key.strip())
|
|
|
|
|
|
|
|
|
|
|
|
def authenticate(
|
|
|
|
def authenticate(
|
|
|
|
@ -57,6 +78,11 @@ class AuthBackend:
|
|
|
|
authorization=authorization,
|
|
|
|
authorization=authorization,
|
|
|
|
required_scopes=required_scopes,
|
|
|
|
required_scopes=required_scopes,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
if mode == "oauth":
|
|
|
|
|
|
|
|
return self._authenticate_oauth(
|
|
|
|
|
|
|
|
authorization=authorization,
|
|
|
|
|
|
|
|
required_scopes=required_scopes,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
raise HTTPException(
|
|
|
|
raise HTTPException(
|
|
|
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
|
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
|
|
@ -151,6 +177,48 @@ class AuthBackend:
|
|
|
|
subject = str(claims.get("sub") or "jwt-subject")
|
|
|
|
subject = str(claims.get("sub") or "jwt-subject")
|
|
|
|
return AuthContext(subject=subject, auth_type="jwt", scopes=scope_values)
|
|
|
|
return AuthContext(subject=subject, auth_type="jwt", scopes=scope_values)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _authenticate_oauth(
|
|
|
|
|
|
|
|
self,
|
|
|
|
|
|
|
|
*,
|
|
|
|
|
|
|
|
authorization: str | None,
|
|
|
|
|
|
|
|
required_scopes: set[str],
|
|
|
|
|
|
|
|
) -> AuthContext:
|
|
|
|
|
|
|
|
if not self.oauth_introspection_url:
|
|
|
|
|
|
|
|
raise HTTPException(
|
|
|
|
|
|
|
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
|
|
|
|
|
|
detail="OAuth introspection URL is required in oauth auth mode.",
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
token = self._resolve_bearer(authorization)
|
|
|
|
|
|
|
|
if not token:
|
|
|
|
|
|
|
|
raise HTTPException(
|
|
|
|
|
|
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
|
|
|
|
|
detail="Missing Bearer token.",
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
claims = self._introspect_oauth_token(token)
|
|
|
|
|
|
|
|
if not bool(claims.get("active")):
|
|
|
|
|
|
|
|
raise HTTPException(
|
|
|
|
|
|
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
|
|
|
|
|
detail="Inactive OAuth token.",
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_validate_jwt_claims(
|
|
|
|
|
|
|
|
claims=claims,
|
|
|
|
|
|
|
|
expected_issuer=self.oauth_issuer,
|
|
|
|
|
|
|
|
expected_audience=self.oauth_audience,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scope_values = _extract_scopes(claims)
|
|
|
|
|
|
|
|
if required_scopes and not required_scopes.issubset(scope_values):
|
|
|
|
|
|
|
|
raise HTTPException(
|
|
|
|
|
|
|
|
status_code=status.HTTP_403_FORBIDDEN,
|
|
|
|
|
|
|
|
detail="Missing required scope.",
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
subject = str(claims.get("sub") or claims.get("client_id") or "oauth-subject")
|
|
|
|
|
|
|
|
return AuthContext(subject=subject, auth_type="oauth", scopes=scope_values)
|
|
|
|
|
|
|
|
|
|
|
|
def _resolve_api_key(self, *, x_api_key: str | None, authorization: str | None) -> str | None:
|
|
|
|
def _resolve_api_key(self, *, x_api_key: str | None, authorization: str | None) -> str | None:
|
|
|
|
if x_api_key:
|
|
|
|
if x_api_key:
|
|
|
|
return x_api_key
|
|
|
|
return x_api_key
|
|
|
|
@ -166,6 +234,55 @@ class AuthBackend:
|
|
|
|
return parts[1]
|
|
|
|
return parts[1]
|
|
|
|
return None
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _introspect_oauth_token(self, token: str) -> dict[str, Any]:
|
|
|
|
|
|
|
|
if self.oauth_client_secret and not self.oauth_client_id:
|
|
|
|
|
|
|
|
raise HTTPException(
|
|
|
|
|
|
|
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
|
|
|
|
|
|
detail="OAuth client_id is required when client_secret is configured.",
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
request_body: dict[str, str] = {"token": token}
|
|
|
|
|
|
|
|
if self.oauth_client_id and not self.oauth_client_secret:
|
|
|
|
|
|
|
|
request_body["client_id"] = self.oauth_client_id
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
headers = {
|
|
|
|
|
|
|
|
"Content-Type": "application/x-www-form-urlencoded",
|
|
|
|
|
|
|
|
"Accept": "application/json",
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if self.oauth_client_id and self.oauth_client_secret:
|
|
|
|
|
|
|
|
basic_secret = base64.b64encode(
|
|
|
|
|
|
|
|
f"{self.oauth_client_id}:{self.oauth_client_secret}".encode("utf-8")
|
|
|
|
|
|
|
|
).decode("ascii")
|
|
|
|
|
|
|
|
headers["Authorization"] = f"Basic {basic_secret}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
body_bytes = urllib_parse.urlencode(request_body).encode("utf-8")
|
|
|
|
|
|
|
|
request = urllib_request.Request(
|
|
|
|
|
|
|
|
self.oauth_introspection_url, # type: ignore[arg-type]
|
|
|
|
|
|
|
|
data=body_bytes,
|
|
|
|
|
|
|
|
headers=headers,
|
|
|
|
|
|
|
|
method="POST",
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
with urllib_request.urlopen(request, timeout=self.oauth_timeout_seconds) as response:
|
|
|
|
|
|
|
|
payload = json.loads(response.read().decode("utf-8"))
|
|
|
|
|
|
|
|
except (urllib_error.HTTPError, urllib_error.URLError) as exc:
|
|
|
|
|
|
|
|
raise HTTPException(
|
|
|
|
|
|
|
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
|
|
|
|
|
|
|
detail="OAuth introspection request failed.",
|
|
|
|
|
|
|
|
) from exc
|
|
|
|
|
|
|
|
except Exception as exc:
|
|
|
|
|
|
|
|
raise HTTPException(
|
|
|
|
|
|
|
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
|
|
|
|
|
|
|
detail="Invalid OAuth introspection response.",
|
|
|
|
|
|
|
|
) from exc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not isinstance(payload, dict):
|
|
|
|
|
|
|
|
raise HTTPException(
|
|
|
|
|
|
|
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
|
|
|
|
|
|
|
detail="Invalid OAuth introspection response.",
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
return payload
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _decode_hs256_jwt(*, token: str, secret: str) -> dict[str, Any]:
|
|
|
|
def _decode_hs256_jwt(*, token: str, secret: str) -> dict[str, Any]:
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
|