You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

288 lines
8.7 KiB
Python

from __future__ import annotations
import json
import logging
import re
from typing import Annotated, Any, cast
from fastapi import APIRouter, Header, HTTPException, Request, Response
from app.a2a.agent_card import build_agent_card
from app.a2a.models import A2ARpcError, A2ARpcRequest, A2ARpcResponse
from app.config import get_settings
from app.core.service import CoreAgentService
from app.security import AuthBackend
# Module-level singletons, created once at import time and shared by every
# request this router handles.
settings = get_settings()
auth_backend = AuthBackend(settings=settings)
core_service = CoreAgentService(settings=settings, logger=logging.getLogger("personal-agent.a2a"))
router = APIRouter(tags=["a2a"])
# JSON-RPC method names accepted as aliases for the availability lookup call;
# these require credentials with the "availability:read" scope.
SEND_MESSAGE_METHODS = {"SendMessage", "send_message", "messages.send"}
# JSON-RPC method names accepted as aliases for the liveness ping, which is
# answered without any credential check.
PING_METHODS = {"ping", "health.ping", "health/ping"}
@router.get("/.well-known/agent-card.json")
def get_agent_card(request: Request, response: Response) -> dict[str, Any]:
    """Serve this agent's A2A agent card at the well-known discovery path.

    The card itself is produced by ``build_agent_card`` from the current
    settings and the incoming request; this handler only tags the response
    with the ``A2A-Version`` header.
    """
    response.headers["A2A-Version"] = "1.0"
    card = build_agent_card(settings=settings, request=request)
    return card
@router.post("/a2a/rpc", response_model=A2ARpcResponse)
def a2a_rpc(
    payload: A2ARpcRequest,
    response: Response,
    x_api_key: Annotated[str | None, Header(alias="X-API-Key")] = None,
    authorization: Annotated[str | None, Header()] = None,
) -> A2ARpcResponse:
    """Dispatch a JSON-RPC 2.0 request sent to the A2A endpoint.

    Every response carries the ``A2A-Version: 1.0`` header. Dispatch order:

    * ``jsonrpc`` other than ``"2.0"`` -> JSON-RPC error -32600.
    * method in ``PING_METHODS`` -> ``{"status": "ok", "agent": ...}`` with
      no credential check.
    * method not in ``SEND_MESSAGE_METHODS`` -> JSON-RPC error -32601.
    * otherwise the caller's credentials are checked with
      ``_check_availability_access`` and, if accepted, the availability
      lookup runs via ``_handle_send_message``.

    Failures are reported as JSON-RPC error objects in the response body
    rather than raised as HTTP exceptions.
    """
    response.headers["A2A-Version"] = "1.0"
    request_id = payload.id
    method = payload.method

    if payload.jsonrpc != "2.0":
        return _error_response(
            request_id=request_id,
            code=-32600,
            message="Invalid Request: jsonrpc must be '2.0'.",
        )

    if method in PING_METHODS:
        # Liveness probe: intentionally answered before any auth check.
        return A2ARpcResponse(
            id=request_id,
            result={"status": "ok", "agent": settings.a2a_agent_name},
        )

    if method not in SEND_MESSAGE_METHODS:
        return _error_response(
            request_id=request_id,
            code=-32601,
            message=f"Method '{method}' is not implemented yet.",
        )

    denied = _check_availability_access(
        x_api_key=x_api_key,
        authorization=authorization,
        request_id=request_id,
    )
    if denied:
        return denied
    return _handle_send_message(payload)
def _error_response(request_id: str | int | None, code: int, message: str) -> A2ARpcResponse:
    """Build a JSON-RPC error response for ``request_id`` with the given code and message."""
    rpc_error = A2ARpcError(code=code, message=message)
    return A2ARpcResponse(id=request_id, error=rpc_error)
def _check_availability_access(
    *,
    x_api_key: str | None,
    authorization: str | None,
    request_id: str | int | None,
) -> A2ARpcResponse | None:
    """Verify the caller holds the ``availability:read`` scope.

    Credentials (``X-API-Key`` and/or ``Authorization``) are passed to the
    module's ``AuthBackend``. An ``HTTPException`` from it is translated
    into a JSON-RPC error response (code -32001) whose ``data.http_status``
    carries the exception's status code.

    Returns:
        ``None`` when access is granted, otherwise the error response to send.
    """
    try:
        auth_backend.authenticate(
            x_api_key=x_api_key,
            authorization=authorization,
            required_scopes={"availability:read"},
        )
    except HTTPException as exc:
        rpc_error = A2ARpcError(
            code=-32001,
            message=str(exc.detail),
            data={"http_status": exc.status_code},
        )
        return A2ARpcResponse(id=request_id, error=rpc_error)
    return None
def _handle_send_message(payload: A2ARpcRequest) -> A2ARpcResponse:
    """Run an availability lookup for a SendMessage-style JSON-RPC request.

    The ``start``/``end`` strings and optional ``calendar_ids`` are extracted
    from ``payload.params`` and passed to ``core_service.check_availability``.

    Error mapping (all returned as JSON-RPC error responses, never raised):

    * ``ValueError`` while reading input or from the service -> -32602.
    * ``FileNotFoundError`` from the service -> -32000 with the exception text.
    * any other exception from the service -> -32000, prefixed with
      "Availability lookup failed: ".

    On success the result has ``type == "availability.result"`` and an
    ``availability`` object mirroring the service result's fields.
    """
    logger = logging.getLogger("personal-agent.a2a")

    try:
        request_payload = _extract_availability_payload(payload.params)
        start = _require_string(request_payload, "start")
        end = _require_string(request_payload, "end")
        calendar_ids = _parse_calendar_ids(request_payload.get("calendar_ids"))
    except ValueError as exc:
        # Missing or malformed input is the client's fault: "Invalid params".
        return _error_response(
            request_id=payload.id,
            code=-32602,
            message=str(exc),
        )

    try:
        result = core_service.check_availability(start, end, calendar_ids)
    except ValueError as exc:
        return _error_response(
            request_id=payload.id,
            code=-32602,
            message=str(exc),
        )
    except FileNotFoundError as exc:
        # Operator-side problem; record it so it is visible in server logs.
        logger.warning("Availability lookup failed: %s", exc)
        return _error_response(
            request_id=payload.id,
            code=-32000,
            message=str(exc),
        )
    except Exception as exc:
        # Last-resort boundary: keep the traceback server-side instead of
        # silently converting it into an RPC error.
        # NOTE(review): the exception text is still returned to the remote
        # caller; consider a generic message if internals must not leak.
        logger.exception("Availability lookup failed")
        return _error_response(
            request_id=payload.id,
            code=-32000,
            message=f"Availability lookup failed: {exc}",
        )

    availability = {
        "start": result.start,
        "end": result.end,
        "available": result.available,
        "busy_slots": [
            {
                "calendar_id": slot.calendar_id,
                "start": slot.start,
                "end": slot.end,
            }
            for slot in result.busy_slots
        ],
        "checked_calendars": result.checked_calendars,
    }
    return A2ARpcResponse(
        id=payload.id,
        result={
            "type": "availability.result",
            "availability": availability,
        },
    )
def _extract_availability_payload(params: dict[str, Any]) -> dict[str, Any]:
    """Locate the dict carrying availability input (``start`` and ``end``).

    Candidates are tried in order and the first one with both keys wins:

    1. ``params`` itself;
    2. ``params["input"]``, ``params["arguments"]``, ``params["data"]`` --
       each may be a dict or a string containing a JSON object;
    3. ``params["message"]`` when it is a dict (see ``_extract_from_message``);
    4. ``params["messages"]`` when it is a list, scanned from the last
       element backwards so the most recent message wins.

    Raises:
        ValueError: if ``params`` is not a JSON object, or no candidate
            carries both ``start`` and ``end``.
    """
    if not isinstance(params, dict):
        # JSON-RPC also permits by-position (array) or omitted params; reject
        # them as invalid input instead of crashing with TypeError below.
        raise ValueError("SendMessage params must be a JSON object.")

    direct = _dict_with_availability_fields(params)
    if direct is not None:
        return direct

    for key in ("input", "arguments", "data"):
        nested = params.get(key)
        if isinstance(nested, str):
            # Some clients send the arguments as a JSON-encoded string.
            nested = _parse_json_object(nested)
        if isinstance(nested, dict):
            extracted = _dict_with_availability_fields(nested)
            if extracted is not None:
                return extracted

    message = params.get("message")
    if isinstance(message, dict):
        extracted = _extract_from_message(message)
        if extracted is not None:
            return extracted

    messages = params.get("messages")
    if isinstance(messages, list):
        for item in reversed(messages):
            if isinstance(item, dict):
                extracted = _extract_from_message(item)
                if extracted is not None:
                    return extracted

    raise ValueError(
        "SendMessage requires availability input with 'start' and 'end'. "
        "Supported shapes: params.start/end, params.input.start/end, or message content JSON."
    )
def _extract_from_message(message: dict[str, Any]) -> dict[str, Any] | None:
    """Find availability input inside a single message object.

    Checked in order, first match wins:

    1. the message dict itself carrying ``start``/``end``;
    2. a JSON object embedded in ``message["text"]``;
    3. ``message["content"]`` (string, part dict, or list of parts);
    4. ``message["parts"]`` -- the part list used by A2A ``Message``
       objects, e.g. ``[{"kind": "text", "text": "{\"start\": ...}"}]``.

    Returns:
        The matching dict, or ``None`` if nothing in the message qualifies.
    """
    direct = _dict_with_availability_fields(message)
    if direct is not None:
        return direct

    text = message.get("text")
    if isinstance(text, str):
        parsed = _parse_json_object(text)
        if parsed is not None:
            extracted = _dict_with_availability_fields(parsed)
            if extracted is not None:
                return extracted

    # "content" is checked first to keep the original precedence; "parts"
    # covers A2A-conformant clients, which never send "content".
    for key in ("content", "parts"):
        extracted = _extract_from_content(message.get(key))
        if extracted is not None:
            return extracted
    return None
def _extract_from_content(content: Any) -> dict[str, Any] | None:
    """Recursively search a message content value for availability input.

    Accepted shapes:

    * dict: matches if it carries ``start``/``end`` itself, if its ``text``
      field holds such a JSON object, or if its nested ``content`` or
      ``data`` value does (``data`` is where A2A data parts keep their
      structured payload);
    * list: each element is searched in order, first match wins;
    * str: parsed as (or searched for) a JSON object.

    Any other value, or no match, yields ``None``.
    """
    if isinstance(content, dict):
        direct = _dict_with_availability_fields(content)
        if direct is not None:
            return direct
        text = content.get("text")
        if isinstance(text, str):
            parsed = _parse_json_object(text)
            if parsed is not None:
                extracted = _dict_with_availability_fields(parsed)
                if extracted is not None:
                    return extracted
        # "content" keeps its original precedence; "data" handles data parts.
        for key in ("content", "data"):
            nested = content.get(key)
            if nested is not None:
                extracted = _extract_from_content(nested)
                if extracted is not None:
                    return extracted
        return None

    if isinstance(content, list):
        for part in content:
            extracted = _extract_from_content(part)
            if extracted is not None:
                return extracted
        return None

    if isinstance(content, str):
        parsed = _parse_json_object(content)
        if parsed is not None:
            return _dict_with_availability_fields(parsed)
    return None
def _dict_with_availability_fields(value: dict[str, Any]) -> dict[str, Any] | None:
if "start" in value and "end" in value:
return value
return None
def _parse_json_object(raw_value: str) -> dict[str, Any] | None:
stripped = raw_value.strip()
if not stripped:
return None
try:
loaded = json.loads(stripped)
if isinstance(loaded, dict):
return cast(dict[str, Any], loaded)
except json.JSONDecodeError:
pass
match = re.search(r"\{.*\}", stripped, flags=re.DOTALL)
if not match:
return None
try:
loaded = json.loads(match.group(0))
except json.JSONDecodeError:
return None
if isinstance(loaded, dict):
return cast(dict[str, Any], loaded)
return None
def _require_string(payload: dict[str, Any], key: str) -> str:
value = payload.get(key)
if not isinstance(value, str) or not value.strip():
raise ValueError(f"'{key}' must be a non-empty string.")
return value.strip()
def _parse_calendar_ids(value: Any) -> list[str] | None:
if value is None:
return None
if not isinstance(value, list):
raise ValueError("'calendar_ids' must be an array of strings.")
calendar_ids = [str(item).strip() for item in value if str(item).strip()]
return calendar_ids or None