"""A2A JSON-RPC endpoints: agent card discovery and the ``/a2a/rpc`` dispatcher."""

from __future__ import annotations

import json
import logging
import re
from typing import Annotated, Any, cast

from fastapi import APIRouter, Header, HTTPException, Request, Response

from app.a2a.agent_card import build_agent_card
from app.a2a.models import A2ARpcError, A2ARpcRequest, A2ARpcResponse
from app.config import get_settings
from app.core.service import CoreAgentService
from app.security import AuthBackend

settings = get_settings()
# Shared with CoreAgentService (logging.getLogger returns the same instance for
# the same name), so router-level failures land in the same log stream.
logger = logging.getLogger("personal-agent.a2a")
auth_backend = AuthBackend(settings=settings)
core_service = CoreAgentService(settings=settings, logger=logger)
router = APIRouter(tags=["a2a"])

# Value sent in the "A2A-Version" response header by every endpoint here.
A2A_VERSION = "1.0"

# Accepted spellings of the JSON-RPC method names handled by a2a_rpc.
SEND_MESSAGE_METHODS = {"SendMessage", "send_message", "messages.send"}
PING_METHODS = {"ping", "health.ping", "health/ping"}

DEFAULT_SEND_MESSAGE_ACTION = "check_availability"
MEETING_INTERVALS_ACTION = "available_meeting_intervals"

# Scope the caller's credentials must grant for each SendMessage action.
_ACTION_SCOPE = {
    DEFAULT_SEND_MESSAGE_ACTION: "availability:read",
    MEETING_INTERVALS_ACTION: "available_meeting_intervals:read",
}


@router.get("/.well-known/agent-card.json")
def get_agent_card(request: Request, response: Response) -> dict[str, Any]:
    """Return the agent card built from settings and the incoming request."""
    response.headers["A2A-Version"] = A2A_VERSION
    return build_agent_card(settings=settings, request=request)


@router.post("/a2a/rpc", response_model=A2ARpcResponse)
def a2a_rpc(
    payload: A2ARpcRequest,
    response: Response,
    x_api_key: Annotated[str | None, Header(alias="X-API-Key")] = None,
    authorization: Annotated[str | None, Header()] = None,
) -> A2ARpcResponse:
    """Dispatch a JSON-RPC 2.0 request to the matching A2A handler.

    Ping methods answer without authentication. SendMessage variants first
    resolve which action is requested, then authenticate the caller against
    that action's scope, then run the scheduling query. All failures are
    returned as JSON-RPC error objects in the response body rather than raised
    as HTTP exceptions.
    """
    response.headers["A2A-Version"] = A2A_VERSION

    if payload.jsonrpc != "2.0":
        return _error_response(
            request_id=payload.id,
            code=-32600,
            message="Invalid Request: jsonrpc must be '2.0'.",
        )

    if payload.method in PING_METHODS:
        return A2ARpcResponse(
            id=payload.id,
            result={"status": "ok", "agent": settings.a2a_agent_name},
        )

    if payload.method in SEND_MESSAGE_METHODS:
        try:
            action = _resolve_send_message_action(payload.params)
        except ValueError as exc:
            return _error_response(
                request_id=payload.id,
                code=-32602,
                message=str(exc),
            )

        # The action must be known before auth, because each action requires
        # a different scope.
        auth_error = _check_scope_access(
            x_api_key=x_api_key,
            authorization=authorization,
            request_id=payload.id,
            required_scope=_ACTION_SCOPE[action],
        )
        if auth_error:
            return auth_error
        return _handle_send_message(payload, action=action)

    return _error_response(
        request_id=payload.id,
        code=-32601,
        message=f"Method '{payload.method}' is not implemented yet.",
    )


def _error_response(request_id: str | int | None, code: int, message: str) -> A2ARpcResponse:
    """Build a JSON-RPC error response with the given code and message."""
    return A2ARpcResponse(
        id=request_id,
        error=A2ARpcError(code=code, message=message),
    )


def _check_scope_access(
    *,
    x_api_key: str | None,
    authorization: str | None,
    request_id: str | int | None,
    required_scope: str,
) -> A2ARpcResponse | None:
    """Authenticate the caller for ``required_scope``.

    Returns ``None`` when access is granted. Otherwise returns a JSON-RPC error
    (code -32001) that carries the auth backend's detail message and its HTTP
    status under ``data.http_status``.
    """
    try:
        auth_backend.authenticate(
            x_api_key=x_api_key,
            authorization=authorization,
            required_scopes={required_scope},
        )
    except HTTPException as exc:
        return A2ARpcResponse(
            id=request_id,
            error=A2ARpcError(
                code=-32001,
                message=str(exc.detail),
                data={"http_status": exc.status_code},
            ),
        )
    return None


def _handle_send_message(payload: A2ARpcRequest, *, action: str) -> A2ARpcResponse:
    """Validate the scheduling input in ``payload`` and run ``action`` on it.

    Input problems and ``ValueError`` raised by the core service map to -32602.
    ``FileNotFoundError`` maps to -32000 with its message. Any other exception
    is logged with its traceback and reported as -32000 with a labelled message.
    """
    try:
        request_payload = _extract_schedule_payload(payload.params)
        start = _require_string(request_payload, "start")
        end = _require_string(request_payload, "end")
        calendar_ids = _parse_calendar_ids(request_payload.get("calendar_ids"))
    except ValueError as exc:
        return _error_response(
            request_id=payload.id,
            code=-32602,
            message=str(exc),
        )

    try:
        if action == MEETING_INTERVALS_ACTION:
            return _meeting_intervals_response(payload, start, end, calendar_ids)
        result = core_service.check_availability(start, end, calendar_ids)
    except ValueError as exc:
        return _error_response(
            request_id=payload.id,
            code=-32602,
            message=str(exc),
        )
    except FileNotFoundError as exc:
        return _error_response(
            request_id=payload.id,
            code=-32000,
            message=str(exc),
        )
    except Exception as exc:
        # JSON-RPC boundary: every failure must become an error object, but
        # keep the traceback in the server log instead of discarding it.
        failure_label = (
            "Meeting interval lookup failed"
            if action == MEETING_INTERVALS_ACTION
            else "Availability lookup failed"
        )
        logger.exception("%s (action=%s)", failure_label, action)
        return _error_response(
            request_id=payload.id,
            code=-32000,
            message=f"{failure_label}: {exc}",
        )

    return _availability_response(payload, result)


def _availability_response(payload: A2ARpcRequest, result: Any) -> A2ARpcResponse:
    """Serialize a ``core_service.check_availability`` result as a JSON-RPC result.

    ``result`` is whatever the core service returns; only its ``start``,
    ``end``, ``available``, ``busy_slots`` and ``checked_calendars`` attributes
    are read here.
    """
    availability = {
        "start": result.start,
        "end": result.end,
        "available": result.available,
        "busy_slots": [
            {
                "calendar_id": slot.calendar_id,
                "start": slot.start,
                "end": slot.end,
            }
            for slot in result.busy_slots
        ],
        "checked_calendars": result.checked_calendars,
    }
    return A2ARpcResponse(
        id=payload.id,
        result={
            "type": "availability.result",
            "availability": availability,
        },
    )


def _meeting_intervals_response(
    payload: A2ARpcRequest,
    start: str,
    end: str,
    calendar_ids: list[str] | None,
) -> A2ARpcResponse:
    """Query the core service for free meeting intervals and serialize them."""
    result = core_service.available_meeting_intervals(start, end, calendar_ids)
    meeting_intervals = {
        "start": result.start,
        "end": result.end,
        "timezone": result.timezone,
        "meeting_intervals": [
            {
                "start": interval.start,
                "end": interval.end,
            }
            for interval in result.meeting_intervals
        ],
        "checked_calendars": result.checked_calendars,
    }
    return A2ARpcResponse(
        id=payload.id,
        result={
            "type": "available_meeting_intervals.result",
            "meeting_intervals": meeting_intervals,
        },
    )


def _resolve_send_message_action(params: dict[str, Any]) -> str:
    """Map the (optional, case-insensitive) requested action to a canonical name.

    A missing or blank action, "check_availability" or "availability" resolves
    to the availability check.

    Raises:
        ValueError: for any other action string.
    """
    action = _extract_action(params)
    if action is None:
        return DEFAULT_SEND_MESSAGE_ACTION

    normalized = action.strip()
    if not normalized:
        return DEFAULT_SEND_MESSAGE_ACTION

    lowered = normalized.lower()
    if lowered in {"check_availability", "availability"}:
        return DEFAULT_SEND_MESSAGE_ACTION
    if lowered == MEETING_INTERVALS_ACTION:
        return MEETING_INTERVALS_ACTION
    raise ValueError(
        "Unsupported 'action'. Expected 'available_meeting_intervals' or omitted for availability."
    )


def _extract_action(params: dict[str, Any]) -> str | None:
    """Find the requested action string in ``params``, or ``None`` if absent.

    Explicit ``action`` keys (see ``_extract_explicit_action``) take priority.
    Otherwise, the ``action`` key of the scheduling payload that
    ``_extract_schedule_payload`` would use is honoured. That payload may live
    in message text/content JSON, which the explicit search does not inspect.
    """
    explicit = _extract_explicit_action(params)
    if explicit is not None:
        return explicit

    try:
        schedule_payload = _extract_schedule_payload(params)
    except ValueError:
        # No usable scheduling input; _handle_send_message reports that later.
        return None
    action = schedule_payload.get("action")
    return action if isinstance(action, str) else None


def _extract_explicit_action(params: dict[str, Any]) -> str | None:
    """Search ``params`` for a string ``action`` key.

    Looks at the top level, then input/arguments/data (dicts or JSON strings),
    then ``message``, then ``messages`` from newest to oldest.
    """
    if "action" in params and isinstance(params["action"], str):
        return params["action"]

    for key in ("input", "arguments", "data"):
        nested = params.get(key)
        if isinstance(nested, dict):
            action = _extract_explicit_action(nested)
            if action is not None:
                return action
        elif isinstance(nested, str):
            parsed = _parse_json_object(nested)
            if parsed is not None:
                action = _extract_explicit_action(parsed)
                if action is not None:
                    return action

    message = params.get("message")
    if isinstance(message, dict):
        action = _extract_explicit_action(message)
        if action is not None:
            return action

    messages = params.get("messages")
    if isinstance(messages, list):
        for item in reversed(messages):
            if isinstance(item, dict):
                action = _extract_explicit_action(item)
                if action is not None:
                    return action

    return None


def _extract_schedule_payload(params: dict[str, Any]) -> dict[str, Any]:
    """Locate the dict that carries ``start`` and ``end`` within ``params``.

    Checked in order: ``params`` itself, input/arguments/data (dicts or JSON
    strings), ``message``, then ``messages`` from newest to oldest.

    Raises:
        ValueError: if no such dict is found.
    """
    direct = _dict_with_availability_fields(params)
    if direct is not None:
        return direct

    for key in ("input", "arguments", "data"):
        nested = params.get(key)
        if isinstance(nested, dict):
            extracted = _dict_with_availability_fields(nested)
            if extracted is not None:
                return extracted
        elif isinstance(nested, str):
            parsed = _parse_json_object(nested)
            if parsed is not None:
                extracted = _dict_with_availability_fields(parsed)
                if extracted is not None:
                    return extracted

    message = params.get("message")
    if isinstance(message, dict):
        extracted = _extract_from_message(message)
        if extracted is not None:
            return extracted

    messages = params.get("messages")
    if isinstance(messages, list):
        for item in reversed(messages):
            if isinstance(item, dict):
                extracted = _extract_from_message(item)
                if extracted is not None:
                    return extracted

    raise ValueError(
        "SendMessage requires scheduling input with 'start' and 'end'. "
        "Supported shapes: params.start/end, params.input.start/end, or message content JSON."
    )


def _extract_from_message(message: dict[str, Any]) -> dict[str, Any] | None:
    """Find start/end input in a single message dict, or return ``None``.

    Checks the message itself, JSON in its ``text``, its ``content``, and
    finally its ``parts`` list (A2A-style messages carry payload there).
    """
    direct = _dict_with_availability_fields(message)
    if direct is not None:
        return direct

    text = message.get("text")
    if isinstance(text, str):
        parsed = _parse_json_object(text)
        if parsed is not None:
            extracted = _dict_with_availability_fields(parsed)
            if extracted is not None:
                return extracted

    extracted = _extract_from_content(message.get("content"))
    if extracted is not None:
        return extracted
    # Only reached when "content" yielded nothing, so existing shapes win.
    return _extract_from_content(message.get("parts"))


def _extract_from_content(content: Any) -> dict[str, Any] | None:
    """Recursively search message content (dict, list of parts, or string).

    For dicts, checks the dict itself, JSON in ``text``, nested ``content``,
    then a dict-valued ``data`` field (e.g. an A2A data part).
    """
    if isinstance(content, dict):
        direct = _dict_with_availability_fields(content)
        if direct is not None:
            return direct
        if "text" in content and isinstance(content["text"], str):
            parsed = _parse_json_object(content["text"])
            if parsed is not None:
                extracted = _dict_with_availability_fields(parsed)
                if extracted is not None:
                    return extracted
        nested = content.get("content")
        if nested is not None:
            extracted = _extract_from_content(nested)
            if extracted is not None:
                return extracted
        data = content.get("data")
        if isinstance(data, dict):
            return _dict_with_availability_fields(data)
        return None

    if isinstance(content, list):
        for part in content:
            extracted = _extract_from_content(part)
            if extracted is not None:
                return extracted

    if isinstance(content, str):
        parsed = _parse_json_object(content)
        if parsed is not None:
            return _dict_with_availability_fields(parsed)

    return None


def _dict_with_availability_fields(value: dict[str, Any]) -> dict[str, Any] | None:
    """Return ``value`` if it has both ``start`` and ``end`` keys, else ``None``."""
    if "start" in value and "end" in value:
        return value
    return None


def _parse_json_object(raw_value: str) -> dict[str, Any] | None:
    """Parse a JSON object from ``raw_value``, tolerating surrounding prose.

    The whole string is tried first. If that fails, the span from the first
    ``{`` to the last ``}`` is tried. Returns ``None`` when neither attempt
    yields a JSON object.
    """
    stripped = raw_value.strip()
    if not stripped:
        return None

    try:
        loaded = json.loads(stripped)
        if isinstance(loaded, dict):
            return cast(dict[str, Any], loaded)
    except json.JSONDecodeError:
        pass

    # Same span a greedy DOTALL r"\{.*\}" search would match, but linear time.
    # The regex backtracks quadratically on untrusted text with no closing brace.
    first_brace = stripped.find("{")
    last_brace = stripped.rfind("}")
    if first_brace == -1 or last_brace < first_brace:
        return None

    try:
        loaded = json.loads(stripped[first_brace : last_brace + 1])
    except json.JSONDecodeError:
        return None
    if isinstance(loaded, dict):
        return cast(dict[str, Any], loaded)
    return None


def _require_string(payload: dict[str, Any], key: str) -> str:
    """Return ``payload[key]`` stripped of whitespace.

    Raises:
        ValueError: if the value is missing, not a string, or blank.
    """
    value = payload.get(key)
    if not isinstance(value, str) or not value.strip():
        raise ValueError(f"'{key}' must be a non-empty string.")
    return value.strip()


def _parse_calendar_ids(value: Any) -> list[str] | None:
    """Validate an optional list of calendar ids.

    Blank entries are dropped. ``None`` is returned when no non-blank ids
    remain, or when the value is absent.

    Raises:
        ValueError: if ``value`` is not a list or contains non-string items.
    """
    if value is None:
        return None
    # Reject non-strings instead of coercing them (None would become "None").
    if not isinstance(value, list) or not all(isinstance(item, str) for item in value):
        raise ValueError("'calendar_ids' must be an array of strings.")
    calendar_ids = [stripped for stripped in (item.strip() for item in value) if stripped]
    return calendar_ids or None