Implement MCP OAuth 2.1 resource-server authorization flow
parent
0b9886fc56
commit
eca444f04a
@ -0,0 +1,80 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import replace
|
||||
|
||||
import anyio
|
||||
from fastapi import HTTPException
|
||||
from mcp.server.auth.provider import AccessToken, TokenVerifier
|
||||
from mcp.server.auth.settings import AuthSettings
|
||||
|
||||
from app.config import Settings
|
||||
from app.security import AuthBackend
|
||||
|
||||
_MCP_BASE_SCOPE = "availability:read"
|
||||
|
||||
|
||||
def resolve_mcp_auth_mode(settings: Settings) -> str:
|
||||
if settings.mcp_auth_mode == "inherit":
|
||||
return settings.auth_mode
|
||||
return settings.mcp_auth_mode
|
||||
|
||||
|
||||
def mcp_supported_scopes(settings: Settings) -> list[str]:
|
||||
scopes = [_MCP_BASE_SCOPE]
|
||||
if settings.mcp_enable_mutation_tools:
|
||||
scopes.extend(["mail:scan", "unsubscribe:read", "unsubscribe:execute"])
|
||||
return scopes
|
||||
|
||||
|
||||
def build_mcp_oauth_auth_settings(settings: Settings) -> AuthSettings:
|
||||
if not settings.mcp_oauth_issuer:
|
||||
raise ValueError("MCP_OAUTH_ISSUER is required when MCP_AUTH_MODE=oauth.")
|
||||
if not settings.mcp_resource_server_url:
|
||||
raise ValueError("MCP_RESOURCE_SERVER_URL is required when MCP_AUTH_MODE=oauth.")
|
||||
return AuthSettings(
|
||||
issuer_url=settings.mcp_oauth_issuer,
|
||||
resource_server_url=settings.mcp_resource_server_url,
|
||||
required_scopes=[_MCP_BASE_SCOPE],
|
||||
)
|
||||
|
||||
|
||||
def build_mcp_oauth_token_verifier(settings: Settings) -> TokenVerifier:
|
||||
oauth_settings = replace(settings, auth_mode="oauth")
|
||||
auth_backend = AuthBackend(
|
||||
settings=oauth_settings,
|
||||
oauth_introspection_url=settings.mcp_oauth_introspection_url,
|
||||
oauth_client_id=settings.mcp_oauth_client_id,
|
||||
oauth_client_secret=settings.mcp_oauth_client_secret or None,
|
||||
oauth_issuer=settings.mcp_oauth_issuer,
|
||||
oauth_audience=settings.mcp_oauth_audience,
|
||||
oauth_timeout_seconds=settings.mcp_oauth_timeout_seconds,
|
||||
)
|
||||
return OAuthIntrospectionTokenVerifier(auth_backend)
|
||||
|
||||
|
||||
class OAuthIntrospectionTokenVerifier(TokenVerifier):
|
||||
"""FastMCP TokenVerifier backed by MCP OAuth introspection settings."""
|
||||
|
||||
def __init__(self, auth_backend: AuthBackend) -> None:
|
||||
self._auth_backend = auth_backend
|
||||
|
||||
async def verify_token(self, token: str) -> AccessToken | None:
|
||||
auth_context = await anyio.to_thread.run_sync(self._authenticate_token, token)
|
||||
if auth_context is None:
|
||||
return None
|
||||
scopes = sorted(scope for scope in auth_context.scopes if scope and scope != "*")
|
||||
return AccessToken(
|
||||
token=token,
|
||||
client_id=auth_context.subject,
|
||||
scopes=scopes,
|
||||
)
|
||||
|
||||
def _authenticate_token(self, token: str):
|
||||
try:
|
||||
return self._auth_backend.authenticate(
|
||||
x_api_key=None,
|
||||
authorization=f"Bearer {token}",
|
||||
required_scopes=set(),
|
||||
)
|
||||
except HTTPException:
|
||||
return None
|
||||
@ -1,33 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
|
||||
from starlette.applications import Starlette
|
||||
from starlette.routing import Mount
|
||||
|
||||
from app.config import get_settings
|
||||
from app.mcp import mcp
|
||||
|
||||
settings = get_settings()
|
||||
logger = logging.getLogger("personal-agent.mcp")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_: Starlette):
|
||||
effective_mode = settings.auth_mode if settings.mcp_auth_mode == "inherit" else settings.mcp_auth_mode
|
||||
logger.info(
|
||||
"MCP authentication mode=%s (base AUTH_MODE=%s)",
|
||||
effective_mode,
|
||||
settings.auth_mode,
|
||||
)
|
||||
async with mcp.session_manager.run():
|
||||
yield
|
||||
|
||||
|
||||
app = Starlette(
|
||||
routes=[
|
||||
Mount("/mcp", app=mcp.streamable_http_app()),
|
||||
],
|
||||
lifespan=lifespan,
|
||||
effective_mode = settings.auth_mode if settings.mcp_auth_mode == "inherit" else settings.mcp_auth_mode
|
||||
logger.info(
|
||||
"MCP authentication mode=%s (base AUTH_MODE=%s)",
|
||||
effective_mode,
|
||||
settings.auth_mode,
|
||||
)
|
||||
app = mcp.streamable_http_app()
|
||||
|
||||
@ -0,0 +1,71 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import replace
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from mcp.server.auth.provider import AccessToken, TokenVerifier
|
||||
|
||||
from app.config import get_settings
|
||||
from app.mcp.server import build_mcp_server
|
||||
|
||||
|
||||
class _StaticTokenVerifier(TokenVerifier):
|
||||
async def verify_token(self, token: str) -> AccessToken | None:
|
||||
if token != "valid-token":
|
||||
return None
|
||||
return AccessToken(
|
||||
token=token,
|
||||
client_id="oauth-client",
|
||||
scopes=["availability:read"],
|
||||
)
|
||||
|
||||
|
||||
def _oauth_settings():
|
||||
return replace(
|
||||
get_settings(),
|
||||
auth_mode="api_key",
|
||||
mcp_auth_mode="oauth",
|
||||
mcp_oauth_issuer="https://issuer.example",
|
||||
mcp_resource_server_url="https://mcp.example.com/mcp",
|
||||
)
|
||||
|
||||
|
||||
def test_mcp_oauth_exposes_protected_resource_metadata() -> None:
|
||||
server = build_mcp_server(settings=_oauth_settings(), token_verifier=_StaticTokenVerifier())
|
||||
|
||||
with TestClient(server.streamable_http_app()) as client:
|
||||
response = client.get("/.well-known/oauth-protected-resource/mcp")
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["resource"] == "https://mcp.example.com/mcp"
|
||||
assert [value.rstrip("/") for value in payload["authorization_servers"]] == [
|
||||
"https://issuer.example"
|
||||
]
|
||||
assert "availability:read" in payload["scopes_supported"]
|
||||
|
||||
|
||||
def test_mcp_oauth_requires_bearer_token_with_challenge() -> None:
|
||||
server = build_mcp_server(settings=_oauth_settings(), token_verifier=_StaticTokenVerifier())
|
||||
|
||||
with TestClient(server.streamable_http_app()) as client:
|
||||
response = client.post("/mcp", json={})
|
||||
|
||||
assert response.status_code == 401
|
||||
challenge = response.headers.get("www-authenticate", "")
|
||||
assert challenge.startswith("Bearer ")
|
||||
assert 'error="invalid_token"' in challenge
|
||||
assert "resource_metadata=" in challenge
|
||||
assert "/.well-known/oauth-protected-resource/mcp" in challenge
|
||||
|
||||
|
||||
def test_mcp_oauth_mode_requires_resource_server_url() -> None:
|
||||
settings = replace(
|
||||
get_settings(),
|
||||
mcp_auth_mode="oauth",
|
||||
mcp_oauth_issuer="https://issuer.example",
|
||||
mcp_resource_server_url=None,
|
||||
)
|
||||
with pytest.raises(ValueError, match="MCP_RESOURCE_SERVER_URL"):
|
||||
build_mcp_server(settings=settings, token_verifier=_StaticTokenVerifier())
|
||||
Loading…
Reference in New Issue