You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
391 lines
12 KiB
Python
391 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import re
|
|
from typing import Annotated, Any, cast
|
|
|
|
from fastapi import APIRouter, Header, HTTPException, Request, Response
|
|
|
|
from app.a2a.agent_card import build_agent_card
|
|
from app.a2a.models import A2ARpcError, A2ARpcRequest, A2ARpcResponse
|
|
from app.config import get_settings
|
|
from app.core.service import CoreAgentService
|
|
from app.security import AuthBackend
|
|
|
|
# Module-level singletons shared by every route defined below.
settings = get_settings()
auth_backend = AuthBackend(settings=settings)
core_service = CoreAgentService(
    settings=settings,
    logger=logging.getLogger("personal-agent.a2a"),
)

# Routes registered on this router carry the "a2a" tag.
router = APIRouter(tags=["a2a"])
|
|
|
|
# Action names a send-message request can resolve to.
DEFAULT_SEND_MESSAGE_ACTION = "check_availability"
MEETING_INTERVALS_ACTION = "available_meeting_intervals"

# RPC method names routed to the send-message handler.
SEND_MESSAGE_METHODS = {
    "SendMessage",
    "send_message",
    "messages.send",
}

# RPC method names answered with a status result.
PING_METHODS = {
    "ping",
    "health.ping",
    "health/ping",
}

# Scope passed to the auth backend for each resolved action.
_ACTION_SCOPE = dict(
    [
        (DEFAULT_SEND_MESSAGE_ACTION, "availability:read"),
        (MEETING_INTERVALS_ACTION, "available_meeting_intervals:read"),
    ]
)
|
|
|
|
|
|
@router.get("/.well-known/agent-card.json")
def get_agent_card(request: Request, response: Response) -> dict[str, Any]:
    """Return the agent card built for this request.

    The outgoing response is stamped with the ``A2A-Version: 1.0`` header
    before the card is built.
    """
    response.headers["A2A-Version"] = "1.0"
    card = build_agent_card(settings=settings, request=request)
    return card
|
|
|
|
|
|
@router.post("/a2a/rpc", response_model=A2ARpcResponse)
def a2a_rpc(
    payload: A2ARpcRequest,
    response: Response,
    x_api_key: Annotated[str | None, Header(alias="X-API-Key")] = None,
    authorization: Annotated[str | None, Header()] = None,
) -> A2ARpcResponse:
    """Dispatch one JSON-RPC request to the ping or send-message handler.

    Every response carries the ``A2A-Version: 1.0`` header. Failures are
    reported as RPC error objects rather than raised:

    * ``-32600`` when ``payload.jsonrpc`` is not ``"2.0"``;
    * ``-32602`` when the send-message action cannot be resolved;
    * ``-32601`` when the method is neither a ping nor a send-message alias;
    * whatever ``_check_scope_access`` returns when the scope check fails.

    Ping methods are answered without any credential check.
    """
    response.headers["A2A-Version"] = "1.0"
    request_id = payload.id
    method = payload.method

    if payload.jsonrpc != "2.0":
        return _error_response(
            request_id=request_id,
            code=-32600,
            message="Invalid Request: jsonrpc must be '2.0'.",
        )

    if method in PING_METHODS:
        status = {"status": "ok", "agent": settings.a2a_agent_name}
        return A2ARpcResponse(id=request_id, result=status)

    if method not in SEND_MESSAGE_METHODS:
        return _error_response(
            request_id=request_id,
            code=-32601,
            message=f"Method '{method}' is not implemented yet.",
        )

    # The action is resolved before authentication because it selects the
    # scope the caller must hold.
    try:
        action = _resolve_send_message_action(payload.params)
    except ValueError as exc:
        return _error_response(request_id=request_id, code=-32602, message=str(exc))

    denied = _check_scope_access(
        x_api_key=x_api_key,
        authorization=authorization,
        request_id=request_id,
        required_scope=_ACTION_SCOPE[action],
    )
    return denied or _handle_send_message(payload, action=action)
|
|
|
|
|
|
def _error_response(request_id: str | int | None, code: int, message: str) -> A2ARpcResponse:
    """Build a response for ``request_id`` whose ``error`` holds ``code`` and ``message``."""
    rpc_error = A2ARpcError(code=code, message=message)
    return A2ARpcResponse(id=request_id, error=rpc_error)
|
|
|
|
|
|
def _check_scope_access(
    *,
    x_api_key: str | None,
    authorization: str | None,
    request_id: str | int | None,
    required_scope: str,
) -> A2ARpcResponse | None:
    """Authenticate the caller for ``required_scope``.

    Returns ``None`` when ``auth_backend.authenticate`` accepts the
    credentials. When it raises ``HTTPException``, returns an RPC error
    response with code ``-32001``, the exception detail as the message and
    the HTTP status under ``data["http_status"]``.
    """
    try:
        auth_backend.authenticate(
            x_api_key=x_api_key,
            authorization=authorization,
            required_scopes={required_scope},
        )
    except HTTPException as exc:
        denial = A2ARpcError(
            code=-32001,
            message=str(exc.detail),
            data={"http_status": exc.status_code},
        )
        return A2ARpcResponse(id=request_id, error=denial)
    else:
        return None
|
|
|
|
|
|
def _handle_send_message(payload: A2ARpcRequest, *, action: str) -> A2ARpcResponse:
    """Run the resolved send-message action and wrap the outcome as an RPC response.

    Args:
        payload: The incoming RPC request; ``payload.params`` is searched for
            the scheduling input and ``payload.id`` is echoed in the response.
        action: The resolved action name. ``MEETING_INTERVALS_ACTION`` selects
            the meeting-interval lookup; anything else runs the availability
            check.

    Returns:
        A result response on success, otherwise an error response:
        ``-32602`` for a ``ValueError`` (bad input, or raised by the service),
        ``-32000`` for a ``FileNotFoundError`` or any other exception.
    """
    # Input validation: every failure here is reported as invalid params.
    try:
        request_payload = _extract_schedule_payload(payload.params)
        start = _require_string(request_payload, "start")
        end = _require_string(request_payload, "end")
        calendar_ids = _parse_calendar_ids(request_payload.get("calendar_ids"))
    except ValueError as exc:
        return _error_response(
            request_id=payload.id,
            code=-32602,
            message=str(exc),
        )

    try:
        if action == MEETING_INTERVALS_ACTION:
            return _meeting_intervals_response(payload, start, end, calendar_ids)
        result = core_service.check_availability(start, end, calendar_ids)
    except ValueError as exc:
        return _error_response(
            request_id=payload.id,
            code=-32602,
            message=str(exc),
        )
    except FileNotFoundError as exc:
        return _error_response(
            request_id=payload.id,
            code=-32000,
            message=str(exc),
        )
    except Exception as exc:
        failure_label = (
            "Meeting interval lookup failed"
            if action == MEETING_INTERVALS_ACTION
            else "Availability lookup failed"
        )
        # The exception is converted into an RPC error below, so record the
        # traceback here; otherwise unexpected failures leave no trace.
        logging.getLogger("personal-agent.a2a").exception(
            "%s (request id=%r)", failure_label, payload.id
        )
        return _error_response(
            request_id=payload.id,
            code=-32000,
            message=f"{failure_label}: {exc}",
        )

    availability = {
        "start": result.start,
        "end": result.end,
        "available": result.available,
        "busy_slots": [
            {
                "calendar_id": slot.calendar_id,
                "start": slot.start,
                "end": slot.end,
            }
            for slot in result.busy_slots
        ],
        "checked_calendars": result.checked_calendars,
    }
    return A2ARpcResponse(
        id=payload.id,
        result={
            "type": "availability.result",
            "availability": availability,
        },
    )
|
|
|
|
|
|
def _meeting_intervals_response(
    payload: A2ARpcRequest,
    start: str,
    end: str,
    calendar_ids: list[str] | None,
) -> A2ARpcResponse:
    """Query the core service for meeting intervals and wrap them as an RPC result.

    Exceptions raised by ``core_service.available_meeting_intervals`` are not
    handled here; they propagate to the caller.
    """
    lookup = core_service.available_meeting_intervals(start, end, calendar_ids)

    intervals = [
        {"start": interval.start, "end": interval.end}
        for interval in lookup.meeting_intervals
    ]
    body = {
        "start": lookup.start,
        "end": lookup.end,
        "timezone": lookup.timezone,
        "meeting_intervals": intervals,
        "checked_calendars": lookup.checked_calendars,
    }
    rpc_result = {
        "type": "available_meeting_intervals.result",
        "meeting_intervals": body,
    }
    return A2ARpcResponse(id=payload.id, result=rpc_result)
|
|
|
|
|
|
def _resolve_send_message_action(params: dict[str, Any]) -> str:
    """Map the ``action`` found in ``params`` onto a supported action name.

    A missing or blank action, ``"check_availability"`` and ``"availability"``
    all resolve to ``DEFAULT_SEND_MESSAGE_ACTION``. Matching ignores case and
    surrounding whitespace.

    Raises:
        ValueError: If an action is present but not one of the supported names.
    """
    requested = _extract_action(params)
    # Treat "no action" and "blank action" the same way: as an empty name.
    lowered = "" if requested is None else requested.strip().lower()

    if lowered in {"", "check_availability", "availability"}:
        return DEFAULT_SEND_MESSAGE_ACTION
    if lowered == MEETING_INTERVALS_ACTION:
        return MEETING_INTERVALS_ACTION

    raise ValueError(
        "Unsupported 'action'. Expected 'available_meeting_intervals' or omitted for availability."
    )
|
|
|
|
|
|
def _extract_action(params: dict[str, Any]) -> str | None:
|
|
if "action" in params and isinstance(params["action"], str):
|
|
return params["action"]
|
|
|
|
for key in ("input", "arguments", "data"):
|
|
nested = params.get(key)
|
|
if isinstance(nested, dict):
|
|
action = _extract_action(nested)
|
|
if action is not None:
|
|
return action
|
|
elif isinstance(nested, str):
|
|
parsed = _parse_json_object(nested)
|
|
if parsed is not None:
|
|
action = _extract_action(parsed)
|
|
if action is not None:
|
|
return action
|
|
|
|
message = params.get("message")
|
|
if isinstance(message, dict):
|
|
action = _extract_action(message)
|
|
if action is not None:
|
|
return action
|
|
|
|
messages = params.get("messages")
|
|
if isinstance(messages, list):
|
|
for item in reversed(messages):
|
|
if isinstance(item, dict):
|
|
action = _extract_action(item)
|
|
if action is not None:
|
|
return action
|
|
|
|
return None
|
|
|
|
|
|
def _extract_schedule_payload(params: dict[str, Any]) -> dict[str, Any]:
    """Locate the dict carrying ``start`` and ``end`` within ``params``.

    Search order: ``params`` itself; the ``input``, ``arguments`` and ``data``
    entries (a string entry is first parsed as a JSON object); the ``message``
    dict; then the dict items of ``messages``, last item first.

    Raises:
        ValueError: If no candidate contains both ``start`` and ``end``.
    """
    found = _dict_with_availability_fields(params)
    if found is not None:
        return found

    for key in ("input", "arguments", "data"):
        candidate = params.get(key)
        if isinstance(candidate, str):
            candidate = _parse_json_object(candidate)
        if isinstance(candidate, dict):
            found = _dict_with_availability_fields(candidate)
            if found is not None:
                return found

    # The single message is tried first, then the history newest-first.
    pending: list[dict[str, Any]] = []
    single = params.get("message")
    if isinstance(single, dict):
        pending.append(single)
    history = params.get("messages")
    if isinstance(history, list):
        pending.extend(entry for entry in reversed(history) if isinstance(entry, dict))

    for entry in pending:
        found = _extract_from_message(entry)
        if found is not None:
            return found

    raise ValueError(
        "SendMessage requires scheduling input with 'start' and 'end'. "
        "Supported shapes: params.start/end, params.input.start/end, or message content JSON."
    )
|
|
|
|
|
|
def _extract_from_message(message: dict[str, Any]) -> dict[str, Any] | None:
    """Pull the scheduling dict out of one message, or return ``None``.

    Tries, in order: the message dict itself, its ``text`` field parsed as a
    JSON object, and finally its ``content`` field.
    """
    found = _dict_with_availability_fields(message)

    if found is None:
        text = message.get("text")
        parsed = _parse_json_object(text) if isinstance(text, str) else None
        if parsed is not None:
            found = _dict_with_availability_fields(parsed)

    if found is None:
        found = _extract_from_content(message.get("content"))

    return found
|
|
|
|
|
|
def _extract_from_content(content: Any) -> dict[str, Any] | None:
|
|
if isinstance(content, dict):
|
|
direct = _dict_with_availability_fields(content)
|
|
if direct is not None:
|
|
return direct
|
|
|
|
if "text" in content and isinstance(content["text"], str):
|
|
parsed = _parse_json_object(content["text"])
|
|
if parsed is not None:
|
|
extracted = _dict_with_availability_fields(parsed)
|
|
if extracted is not None:
|
|
return extracted
|
|
|
|
nested = content.get("content")
|
|
if nested is not None:
|
|
return _extract_from_content(nested)
|
|
|
|
if isinstance(content, list):
|
|
for part in content:
|
|
extracted = _extract_from_content(part)
|
|
if extracted is not None:
|
|
return extracted
|
|
|
|
if isinstance(content, str):
|
|
parsed = _parse_json_object(content)
|
|
if parsed is not None:
|
|
return _dict_with_availability_fields(parsed)
|
|
|
|
return None
|
|
|
|
|
|
def _dict_with_availability_fields(value: dict[str, Any]) -> dict[str, Any] | None:
|
|
if "start" in value and "end" in value:
|
|
return value
|
|
return None
|
|
|
|
|
|
def _parse_json_object(raw_value: str) -> dict[str, Any] | None:
|
|
stripped = raw_value.strip()
|
|
if not stripped:
|
|
return None
|
|
try:
|
|
loaded = json.loads(stripped)
|
|
if isinstance(loaded, dict):
|
|
return cast(dict[str, Any], loaded)
|
|
except json.JSONDecodeError:
|
|
pass
|
|
|
|
match = re.search(r"\{.*\}", stripped, flags=re.DOTALL)
|
|
if not match:
|
|
return None
|
|
try:
|
|
loaded = json.loads(match.group(0))
|
|
except json.JSONDecodeError:
|
|
return None
|
|
if isinstance(loaded, dict):
|
|
return cast(dict[str, Any], loaded)
|
|
return None
|
|
|
|
|
|
def _require_string(payload: dict[str, Any], key: str) -> str:
    """Return ``payload[key]`` with surrounding whitespace removed.

    Raises:
        ValueError: If the value is missing, not a string, or blank.
    """
    candidate = payload.get(key)
    cleaned = candidate.strip() if isinstance(candidate, str) else ""
    if cleaned:
        return cleaned
    raise ValueError(f"'{key}' must be a non-empty string.")
|
|
|
|
|
|
def _parse_calendar_ids(value: Any) -> list[str] | None:
|
|
if value is None:
|
|
return None
|
|
if not isinstance(value, list):
|
|
raise ValueError("'calendar_ids' must be an array of strings.")
|
|
|
|
calendar_ids = [str(item).strip() for item in value if str(item).strip()]
|
|
return calendar_ids or None
|