Implement MCP OAuth 2.1 resource-server authorization flow
parent
0b9886fc56
commit
eca444f04a
@ -0,0 +1,80 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import replace
|
||||||
|
|
||||||
|
import anyio
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from mcp.server.auth.provider import AccessToken, TokenVerifier
|
||||||
|
from mcp.server.auth.settings import AuthSettings
|
||||||
|
|
||||||
|
from app.config import Settings
|
||||||
|
from app.security import AuthBackend
|
||||||
|
|
||||||
|
_MCP_BASE_SCOPE = "availability:read"
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_mcp_auth_mode(settings: Settings) -> str:
|
||||||
|
if settings.mcp_auth_mode == "inherit":
|
||||||
|
return settings.auth_mode
|
||||||
|
return settings.mcp_auth_mode
|
||||||
|
|
||||||
|
|
||||||
|
def mcp_supported_scopes(settings: Settings) -> list[str]:
|
||||||
|
scopes = [_MCP_BASE_SCOPE]
|
||||||
|
if settings.mcp_enable_mutation_tools:
|
||||||
|
scopes.extend(["mail:scan", "unsubscribe:read", "unsubscribe:execute"])
|
||||||
|
return scopes
|
||||||
|
|
||||||
|
|
||||||
|
def build_mcp_oauth_auth_settings(settings: Settings) -> AuthSettings:
|
||||||
|
if not settings.mcp_oauth_issuer:
|
||||||
|
raise ValueError("MCP_OAUTH_ISSUER is required when MCP_AUTH_MODE=oauth.")
|
||||||
|
if not settings.mcp_resource_server_url:
|
||||||
|
raise ValueError("MCP_RESOURCE_SERVER_URL is required when MCP_AUTH_MODE=oauth.")
|
||||||
|
return AuthSettings(
|
||||||
|
issuer_url=settings.mcp_oauth_issuer,
|
||||||
|
resource_server_url=settings.mcp_resource_server_url,
|
||||||
|
required_scopes=[_MCP_BASE_SCOPE],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_mcp_oauth_token_verifier(settings: Settings) -> TokenVerifier:
|
||||||
|
oauth_settings = replace(settings, auth_mode="oauth")
|
||||||
|
auth_backend = AuthBackend(
|
||||||
|
settings=oauth_settings,
|
||||||
|
oauth_introspection_url=settings.mcp_oauth_introspection_url,
|
||||||
|
oauth_client_id=settings.mcp_oauth_client_id,
|
||||||
|
oauth_client_secret=settings.mcp_oauth_client_secret or None,
|
||||||
|
oauth_issuer=settings.mcp_oauth_issuer,
|
||||||
|
oauth_audience=settings.mcp_oauth_audience,
|
||||||
|
oauth_timeout_seconds=settings.mcp_oauth_timeout_seconds,
|
||||||
|
)
|
||||||
|
return OAuthIntrospectionTokenVerifier(auth_backend)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthIntrospectionTokenVerifier(TokenVerifier):
|
||||||
|
"""FastMCP TokenVerifier backed by MCP OAuth introspection settings."""
|
||||||
|
|
||||||
|
def __init__(self, auth_backend: AuthBackend) -> None:
|
||||||
|
self._auth_backend = auth_backend
|
||||||
|
|
||||||
|
async def verify_token(self, token: str) -> AccessToken | None:
|
||||||
|
auth_context = await anyio.to_thread.run_sync(self._authenticate_token, token)
|
||||||
|
if auth_context is None:
|
||||||
|
return None
|
||||||
|
scopes = sorted(scope for scope in auth_context.scopes if scope and scope != "*")
|
||||||
|
return AccessToken(
|
||||||
|
token=token,
|
||||||
|
client_id=auth_context.subject,
|
||||||
|
scopes=scopes,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _authenticate_token(self, token: str):
|
||||||
|
try:
|
||||||
|
return self._auth_backend.authenticate(
|
||||||
|
x_api_key=None,
|
||||||
|
authorization=f"Bearer {token}",
|
||||||
|
required_scopes=set(),
|
||||||
|
)
|
||||||
|
except HTTPException:
|
||||||
|
return None
|
||||||
@ -1,33 +1,16 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from starlette.applications import Starlette
|
|
||||||
from starlette.routing import Mount
|
|
||||||
|
|
||||||
from app.config import get_settings
|
from app.config import get_settings
|
||||||
from app.mcp import mcp
|
from app.mcp import mcp
|
||||||
|
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
logger = logging.getLogger("personal-agent.mcp")
|
logger = logging.getLogger("personal-agent.mcp")
|
||||||
|
effective_mode = settings.auth_mode if settings.mcp_auth_mode == "inherit" else settings.mcp_auth_mode
|
||||||
|
logger.info(
|
||||||
@asynccontextmanager
|
"MCP authentication mode=%s (base AUTH_MODE=%s)",
|
||||||
async def lifespan(_: Starlette):
|
effective_mode,
|
||||||
effective_mode = settings.auth_mode if settings.mcp_auth_mode == "inherit" else settings.mcp_auth_mode
|
settings.auth_mode,
|
||||||
logger.info(
|
|
||||||
"MCP authentication mode=%s (base AUTH_MODE=%s)",
|
|
||||||
effective_mode,
|
|
||||||
settings.auth_mode,
|
|
||||||
)
|
|
||||||
async with mcp.session_manager.run():
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
app = Starlette(
|
|
||||||
routes=[
|
|
||||||
Mount("/mcp", app=mcp.streamable_http_app()),
|
|
||||||
],
|
|
||||||
lifespan=lifespan,
|
|
||||||
)
|
)
|
||||||
|
app = mcp.streamable_http_app()
|
||||||
|
|||||||
@ -0,0 +1,71 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import replace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from mcp.server.auth.provider import AccessToken, TokenVerifier
|
||||||
|
|
||||||
|
from app.config import get_settings
|
||||||
|
from app.mcp.server import build_mcp_server
|
||||||
|
|
||||||
|
|
||||||
|
class _StaticTokenVerifier(TokenVerifier):
|
||||||
|
async def verify_token(self, token: str) -> AccessToken | None:
|
||||||
|
if token != "valid-token":
|
||||||
|
return None
|
||||||
|
return AccessToken(
|
||||||
|
token=token,
|
||||||
|
client_id="oauth-client",
|
||||||
|
scopes=["availability:read"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _oauth_settings():
|
||||||
|
return replace(
|
||||||
|
get_settings(),
|
||||||
|
auth_mode="api_key",
|
||||||
|
mcp_auth_mode="oauth",
|
||||||
|
mcp_oauth_issuer="https://issuer.example",
|
||||||
|
mcp_resource_server_url="https://mcp.example.com/mcp",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_mcp_oauth_exposes_protected_resource_metadata() -> None:
|
||||||
|
server = build_mcp_server(settings=_oauth_settings(), token_verifier=_StaticTokenVerifier())
|
||||||
|
|
||||||
|
with TestClient(server.streamable_http_app()) as client:
|
||||||
|
response = client.get("/.well-known/oauth-protected-resource/mcp")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
payload = response.json()
|
||||||
|
assert payload["resource"] == "https://mcp.example.com/mcp"
|
||||||
|
assert [value.rstrip("/") for value in payload["authorization_servers"]] == [
|
||||||
|
"https://issuer.example"
|
||||||
|
]
|
||||||
|
assert "availability:read" in payload["scopes_supported"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_mcp_oauth_requires_bearer_token_with_challenge() -> None:
|
||||||
|
server = build_mcp_server(settings=_oauth_settings(), token_verifier=_StaticTokenVerifier())
|
||||||
|
|
||||||
|
with TestClient(server.streamable_http_app()) as client:
|
||||||
|
response = client.post("/mcp", json={})
|
||||||
|
|
||||||
|
assert response.status_code == 401
|
||||||
|
challenge = response.headers.get("www-authenticate", "")
|
||||||
|
assert challenge.startswith("Bearer ")
|
||||||
|
assert 'error="invalid_token"' in challenge
|
||||||
|
assert "resource_metadata=" in challenge
|
||||||
|
assert "/.well-known/oauth-protected-resource/mcp" in challenge
|
||||||
|
|
||||||
|
|
||||||
|
def test_mcp_oauth_mode_requires_resource_server_url() -> None:
|
||||||
|
settings = replace(
|
||||||
|
get_settings(),
|
||||||
|
mcp_auth_mode="oauth",
|
||||||
|
mcp_oauth_issuer="https://issuer.example",
|
||||||
|
mcp_resource_server_url=None,
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError, match="MCP_RESOURCE_SERVER_URL"):
|
||||||
|
build_mcp_server(settings=settings, token_verifier=_StaticTokenVerifier())
|
||||||
Loading…
Reference in New Issue