Compare commits
10 Commits
42189e972d
...
e74b29381c
| Author | SHA1 | Date |
|---|---|---|
|
|
e74b29381c | 1 week ago |
|
|
be9bbf4f83 | 1 week ago |
|
|
bfd752ac39 | 1 week ago |
|
|
8bfe6c518f | 1 week ago |
|
|
54da3efdc9 | 1 week ago |
|
|
1b23493167 | 1 week ago |
|
|
bb5ce5e71a | 1 week ago |
|
|
9603483648 | 1 week ago |
|
|
d1238f186c | 1 week ago |
|
|
d25d429519 | 1 week ago |
@ -0,0 +1,3 @@
|
||||
from app.a2a.router import router as a2a_router
|
||||
|
||||
__all__ = ["a2a_router"]
|
||||
@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from app.config import Settings
|
||||
|
||||
|
||||
def build_agent_card(settings: Settings, request: Request) -> dict[str, Any]:
|
||||
base_url = _resolve_base_url(settings=settings, request=request)
|
||||
return {
|
||||
"name": settings.a2a_agent_name,
|
||||
"description": settings.a2a_agent_description,
|
||||
"url": f"{base_url}/a2a/rpc",
|
||||
"version": "0.1.0",
|
||||
"protocolVersion": "1.0",
|
||||
"defaultInputModes": ["application/json"],
|
||||
"defaultOutputModes": ["application/json"],
|
||||
"capabilities": {
|
||||
"streaming": False,
|
||||
"pushNotifications": False,
|
||||
"stateTransitionHistory": False,
|
||||
},
|
||||
"securitySchemes": {
|
||||
"bearerAuth": {
|
||||
"type": "http",
|
||||
"scheme": "bearer",
|
||||
"description": "Use the same Bearer/API key auth as the REST API.",
|
||||
}
|
||||
},
|
||||
"security": [{"bearerAuth": []}],
|
||||
"skills": [
|
||||
{
|
||||
"id": "availability.query",
|
||||
"name": "Check Availability",
|
||||
"description": "Checks Google Calendar availability for a given time range.",
|
||||
"tags": ["calendar", "availability", "scheduling"],
|
||||
"examples": [
|
||||
"SendMessage with params.start/end/calendar_ids to check free/busy for a time window."
|
||||
],
|
||||
"inputModes": ["application/json"],
|
||||
"outputModes": ["application/json"],
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _resolve_base_url(*, settings: Settings, request: Request) -> str:
|
||||
if settings.a2a_public_base_url:
|
||||
return settings.a2a_public_base_url.rstrip("/")
|
||||
return str(request.base_url).rstrip("/")
|
||||
@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class A2ARpcRequest(BaseModel):
|
||||
jsonrpc: str = "2.0"
|
||||
id: str | int | None = None
|
||||
method: str
|
||||
params: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class A2ARpcError(BaseModel):
|
||||
code: int
|
||||
message: str
|
||||
data: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class A2ARpcResponse(BaseModel):
|
||||
jsonrpc: str = "2.0"
|
||||
id: str | int | None = None
|
||||
result: dict[str, Any] | None = None
|
||||
error: A2ARpcError | None = None
|
||||
@ -0,0 +1,287 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Annotated, Any, cast
|
||||
|
||||
from fastapi import APIRouter, Header, HTTPException, Request, Response
|
||||
|
||||
from app.a2a.agent_card import build_agent_card
|
||||
from app.a2a.models import A2ARpcError, A2ARpcRequest, A2ARpcResponse
|
||||
from app.config import get_settings
|
||||
from app.core.service import CoreAgentService
|
||||
from app.security import AuthBackend
|
||||
|
||||
settings = get_settings()
|
||||
auth_backend = AuthBackend(settings=settings)
|
||||
core_service = CoreAgentService(settings=settings, logger=logging.getLogger("personal-agent.a2a"))
|
||||
|
||||
router = APIRouter(tags=["a2a"])
|
||||
|
||||
SEND_MESSAGE_METHODS = {"SendMessage", "send_message", "messages.send"}
|
||||
PING_METHODS = {"ping", "health.ping", "health/ping"}
|
||||
|
||||
|
||||
@router.get("/.well-known/agent-card.json")
|
||||
def get_agent_card(request: Request, response: Response) -> dict[str, Any]:
|
||||
response.headers["A2A-Version"] = "1.0"
|
||||
return build_agent_card(settings=settings, request=request)
|
||||
|
||||
|
||||
@router.post("/a2a/rpc", response_model=A2ARpcResponse)
|
||||
def a2a_rpc(
|
||||
payload: A2ARpcRequest,
|
||||
response: Response,
|
||||
x_api_key: Annotated[str | None, Header(alias="X-API-Key")] = None,
|
||||
authorization: Annotated[str | None, Header()] = None,
|
||||
) -> A2ARpcResponse:
|
||||
response.headers["A2A-Version"] = "1.0"
|
||||
if payload.jsonrpc != "2.0":
|
||||
return _error_response(
|
||||
request_id=payload.id,
|
||||
code=-32600,
|
||||
message="Invalid Request: jsonrpc must be '2.0'.",
|
||||
)
|
||||
|
||||
if payload.method in PING_METHODS:
|
||||
return A2ARpcResponse(
|
||||
id=payload.id,
|
||||
result={"status": "ok", "agent": settings.a2a_agent_name},
|
||||
)
|
||||
|
||||
if payload.method in SEND_MESSAGE_METHODS:
|
||||
auth_error = _check_availability_access(
|
||||
x_api_key=x_api_key,
|
||||
authorization=authorization,
|
||||
request_id=payload.id,
|
||||
)
|
||||
if auth_error:
|
||||
return auth_error
|
||||
return _handle_send_message(payload)
|
||||
|
||||
return _error_response(
|
||||
request_id=payload.id,
|
||||
code=-32601,
|
||||
message=f"Method '{payload.method}' is not implemented yet.",
|
||||
)
|
||||
|
||||
|
||||
def _error_response(request_id: str | int | None, code: int, message: str) -> A2ARpcResponse:
|
||||
return A2ARpcResponse(
|
||||
id=request_id,
|
||||
error=A2ARpcError(code=code, message=message),
|
||||
)
|
||||
|
||||
|
||||
def _check_availability_access(
|
||||
*,
|
||||
x_api_key: str | None,
|
||||
authorization: str | None,
|
||||
request_id: str | int | None,
|
||||
) -> A2ARpcResponse | None:
|
||||
try:
|
||||
auth_backend.authenticate(
|
||||
x_api_key=x_api_key,
|
||||
authorization=authorization,
|
||||
required_scopes={"availability:read"},
|
||||
)
|
||||
except HTTPException as exc:
|
||||
return A2ARpcResponse(
|
||||
id=request_id,
|
||||
error=A2ARpcError(
|
||||
code=-32001,
|
||||
message=str(exc.detail),
|
||||
data={"http_status": exc.status_code},
|
||||
),
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _handle_send_message(payload: A2ARpcRequest) -> A2ARpcResponse:
|
||||
try:
|
||||
request_payload = _extract_availability_payload(payload.params)
|
||||
start = _require_string(request_payload, "start")
|
||||
end = _require_string(request_payload, "end")
|
||||
calendar_ids = _parse_calendar_ids(request_payload.get("calendar_ids"))
|
||||
except ValueError as exc:
|
||||
return _error_response(
|
||||
request_id=payload.id,
|
||||
code=-32602,
|
||||
message=str(exc),
|
||||
)
|
||||
|
||||
try:
|
||||
result = core_service.check_availability(start, end, calendar_ids)
|
||||
except ValueError as exc:
|
||||
return _error_response(
|
||||
request_id=payload.id,
|
||||
code=-32602,
|
||||
message=str(exc),
|
||||
)
|
||||
except FileNotFoundError as exc:
|
||||
return _error_response(
|
||||
request_id=payload.id,
|
||||
code=-32000,
|
||||
message=str(exc),
|
||||
)
|
||||
except Exception as exc:
|
||||
return _error_response(
|
||||
request_id=payload.id,
|
||||
code=-32000,
|
||||
message=f"Availability lookup failed: {exc}",
|
||||
)
|
||||
|
||||
availability = {
|
||||
"start": result.start,
|
||||
"end": result.end,
|
||||
"available": result.available,
|
||||
"busy_slots": [
|
||||
{
|
||||
"calendar_id": slot.calendar_id,
|
||||
"start": slot.start,
|
||||
"end": slot.end,
|
||||
}
|
||||
for slot in result.busy_slots
|
||||
],
|
||||
"checked_calendars": result.checked_calendars,
|
||||
}
|
||||
return A2ARpcResponse(
|
||||
id=payload.id,
|
||||
result={
|
||||
"type": "availability.result",
|
||||
"availability": availability,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _extract_availability_payload(params: dict[str, Any]) -> dict[str, Any]:
|
||||
direct = _dict_with_availability_fields(params)
|
||||
if direct is not None:
|
||||
return direct
|
||||
|
||||
for key in ("input", "arguments", "data"):
|
||||
nested = params.get(key)
|
||||
if isinstance(nested, dict):
|
||||
extracted = _dict_with_availability_fields(nested)
|
||||
if extracted is not None:
|
||||
return extracted
|
||||
elif isinstance(nested, str):
|
||||
parsed = _parse_json_object(nested)
|
||||
if parsed is not None:
|
||||
extracted = _dict_with_availability_fields(parsed)
|
||||
if extracted is not None:
|
||||
return extracted
|
||||
|
||||
message = params.get("message")
|
||||
if isinstance(message, dict):
|
||||
extracted = _extract_from_message(message)
|
||||
if extracted is not None:
|
||||
return extracted
|
||||
|
||||
messages = params.get("messages")
|
||||
if isinstance(messages, list):
|
||||
for item in reversed(messages):
|
||||
if isinstance(item, dict):
|
||||
extracted = _extract_from_message(item)
|
||||
if extracted is not None:
|
||||
return extracted
|
||||
|
||||
raise ValueError(
|
||||
"SendMessage requires availability input with 'start' and 'end'. "
|
||||
"Supported shapes: params.start/end, params.input.start/end, or message content JSON."
|
||||
)
|
||||
|
||||
|
||||
def _extract_from_message(message: dict[str, Any]) -> dict[str, Any] | None:
|
||||
direct = _dict_with_availability_fields(message)
|
||||
if direct is not None:
|
||||
return direct
|
||||
|
||||
text = message.get("text")
|
||||
if isinstance(text, str):
|
||||
parsed = _parse_json_object(text)
|
||||
if parsed is not None:
|
||||
extracted = _dict_with_availability_fields(parsed)
|
||||
if extracted is not None:
|
||||
return extracted
|
||||
|
||||
content = message.get("content")
|
||||
return _extract_from_content(content)
|
||||
|
||||
|
||||
def _extract_from_content(content: Any) -> dict[str, Any] | None:
|
||||
if isinstance(content, dict):
|
||||
direct = _dict_with_availability_fields(content)
|
||||
if direct is not None:
|
||||
return direct
|
||||
|
||||
if "text" in content and isinstance(content["text"], str):
|
||||
parsed = _parse_json_object(content["text"])
|
||||
if parsed is not None:
|
||||
extracted = _dict_with_availability_fields(parsed)
|
||||
if extracted is not None:
|
||||
return extracted
|
||||
|
||||
nested = content.get("content")
|
||||
if nested is not None:
|
||||
return _extract_from_content(nested)
|
||||
|
||||
if isinstance(content, list):
|
||||
for part in content:
|
||||
extracted = _extract_from_content(part)
|
||||
if extracted is not None:
|
||||
return extracted
|
||||
|
||||
if isinstance(content, str):
|
||||
parsed = _parse_json_object(content)
|
||||
if parsed is not None:
|
||||
return _dict_with_availability_fields(parsed)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _dict_with_availability_fields(value: dict[str, Any]) -> dict[str, Any] | None:
|
||||
if "start" in value and "end" in value:
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def _parse_json_object(raw_value: str) -> dict[str, Any] | None:
|
||||
stripped = raw_value.strip()
|
||||
if not stripped:
|
||||
return None
|
||||
try:
|
||||
loaded = json.loads(stripped)
|
||||
if isinstance(loaded, dict):
|
||||
return cast(dict[str, Any], loaded)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
match = re.search(r"\{.*\}", stripped, flags=re.DOTALL)
|
||||
if not match:
|
||||
return None
|
||||
try:
|
||||
loaded = json.loads(match.group(0))
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
if isinstance(loaded, dict):
|
||||
return cast(dict[str, Any], loaded)
|
||||
return None
|
||||
|
||||
|
||||
def _require_string(payload: dict[str, Any], key: str) -> str:
|
||||
value = payload.get(key)
|
||||
if not isinstance(value, str) or not value.strip():
|
||||
raise ValueError(f"'{key}' must be a non-empty string.")
|
||||
return value.strip()
|
||||
|
||||
|
||||
def _parse_calendar_ids(value: Any) -> list[str] | None:
|
||||
if value is None:
|
||||
return None
|
||||
if not isinstance(value, list):
|
||||
raise ValueError("'calendar_ids' must be an array of strings.")
|
||||
|
||||
calendar_ids = [str(item).strip() for item in value if str(item).strip()]
|
||||
return calendar_ids or None
|
||||
@ -0,0 +1,25 @@
|
||||
from app.core.models import (
|
||||
CoreAvailabilityResult,
|
||||
CoreBusySlot,
|
||||
CoreMailingListCandidate,
|
||||
CoreMethodExecution,
|
||||
CoreScanResult,
|
||||
CoreUnsubscribeCandidatesResult,
|
||||
CoreUnsubscribeDigestResult,
|
||||
CoreUnsubscribeExecutionResult,
|
||||
CoreUnsubscribeMethod,
|
||||
)
|
||||
from app.core.service import CoreAgentService
|
||||
|
||||
__all__ = [
|
||||
"CoreAgentService",
|
||||
"CoreScanResult",
|
||||
"CoreAvailabilityResult",
|
||||
"CoreBusySlot",
|
||||
"CoreUnsubscribeDigestResult",
|
||||
"CoreUnsubscribeMethod",
|
||||
"CoreMailingListCandidate",
|
||||
"CoreUnsubscribeCandidatesResult",
|
||||
"CoreMethodExecution",
|
||||
"CoreUnsubscribeExecutionResult",
|
||||
]
|
||||
@ -0,0 +1,86 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CoreScanResult:
|
||||
scanned: int
|
||||
linkedin: int
|
||||
advertising: int
|
||||
veille_techno: int
|
||||
skipped: int
|
||||
failed: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CoreBusySlot:
|
||||
calendar_id: str
|
||||
start: str
|
||||
end: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CoreAvailabilityResult:
|
||||
start: str
|
||||
end: str
|
||||
available: bool
|
||||
busy_slots: list[CoreBusySlot]
|
||||
checked_calendars: list[str]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CoreUnsubscribeDigestResult:
|
||||
scanned_messages: int
|
||||
extracted_unique_links: int
|
||||
new_links: int
|
||||
sent_to: str | None
|
||||
email_sent: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CoreUnsubscribeMethod:
|
||||
method_id: str
|
||||
method_type: str
|
||||
value: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CoreMailingListCandidate:
|
||||
candidate_id: str
|
||||
list_name: str
|
||||
sender_domain: str
|
||||
message_count: int
|
||||
sample_senders: list[str]
|
||||
sample_subjects: list[str]
|
||||
methods: list[CoreUnsubscribeMethod]
|
||||
approved: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CoreUnsubscribeCandidatesResult:
|
||||
scanned_messages: int
|
||||
candidates: list[CoreMailingListCandidate]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CoreMethodExecution:
|
||||
candidate_id: str
|
||||
list_name: str
|
||||
method_id: str
|
||||
method_type: str
|
||||
value: str
|
||||
success: bool
|
||||
detail: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CoreUnsubscribeExecutionResult:
|
||||
scanned_messages: int
|
||||
candidates_considered: int
|
||||
selected_candidates: int
|
||||
executed_methods: int
|
||||
skipped_already_executed: int
|
||||
failed_methods: int
|
||||
updated_approved_count: int
|
||||
results: list[CoreMethodExecution]
|
||||
@ -0,0 +1,202 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from app.calendar_agent import CalendarAvailabilityAgent
|
||||
from app.config import Settings
|
||||
from app.gmail_agent import GmailTriageAgent
|
||||
from app.google_clients import build_calendar_service, build_gmail_service
|
||||
from app.strands_classifier import StrandsEmailClassifier
|
||||
from app.unsubscribe_agent import UnsubscribeDigestAgent
|
||||
from app.unsubscribe_hil_agent import (
|
||||
CandidateSnapshot,
|
||||
UnsubscribeExecutionResult,
|
||||
UnsubscribeHumanLoopAgent,
|
||||
)
|
||||
|
||||
from app.core.models import (
|
||||
CoreAvailabilityResult,
|
||||
CoreBusySlot,
|
||||
CoreMailingListCandidate,
|
||||
CoreMethodExecution,
|
||||
CoreScanResult,
|
||||
CoreUnsubscribeCandidatesResult,
|
||||
CoreUnsubscribeDigestResult,
|
||||
CoreUnsubscribeExecutionResult,
|
||||
CoreUnsubscribeMethod,
|
||||
)
|
||||
|
||||
|
||||
class CoreAgentService:
|
||||
def __init__(self, settings: Settings, *, logger: logging.Logger | None = None) -> None:
|
||||
self.settings = settings
|
||||
self.logger = logger or logging.getLogger("personal-agent.core")
|
||||
self._strands_key_warning_logged = False
|
||||
|
||||
def scan_mailbox(self, max_results: int) -> CoreScanResult:
|
||||
gmail_service = build_gmail_service(self.settings)
|
||||
gmail_agent = GmailTriageAgent(
|
||||
gmail_service=gmail_service,
|
||||
query=self.settings.gmail_query,
|
||||
classifier=self._build_strands_classifier(),
|
||||
fallback_to_rules=self.settings.llm_fallback_to_rules,
|
||||
)
|
||||
result = gmail_agent.scan_and_route_messages(max_results=max_results)
|
||||
return CoreScanResult(
|
||||
scanned=result.scanned,
|
||||
linkedin=result.linkedin,
|
||||
advertising=result.advertising,
|
||||
veille_techno=result.veille_techno,
|
||||
skipped=result.skipped,
|
||||
failed=result.failed,
|
||||
)
|
||||
|
||||
def check_availability(
|
||||
self, start: str, end: str, calendar_ids: list[str] | None
|
||||
) -> CoreAvailabilityResult:
|
||||
calendar_service = build_calendar_service(self.settings)
|
||||
availability_agent = CalendarAvailabilityAgent(calendar_service=calendar_service)
|
||||
result = availability_agent.get_availability(start, end, calendar_ids)
|
||||
return CoreAvailabilityResult(
|
||||
start=result.start,
|
||||
end=result.end,
|
||||
available=result.available,
|
||||
busy_slots=[
|
||||
CoreBusySlot(
|
||||
calendar_id=slot["calendar_id"],
|
||||
start=slot["start"],
|
||||
end=slot["end"],
|
||||
)
|
||||
for slot in result.busy_slots
|
||||
],
|
||||
checked_calendars=result.checked_calendars,
|
||||
)
|
||||
|
||||
def scan_unsubscribe_digest(self, max_results: int) -> CoreUnsubscribeDigestResult:
|
||||
bounded_max_results = max(1, min(max_results, 500))
|
||||
gmail_service = build_gmail_service(self.settings)
|
||||
unsubscribe_agent = UnsubscribeDigestAgent(
|
||||
gmail_service=gmail_service,
|
||||
query=self.settings.unsubscribe_query,
|
||||
state_file=self.settings.unsubscribe_state_file,
|
||||
recipient_email=self.settings.unsubscribe_digest_recipient,
|
||||
send_empty_digest=self.settings.unsubscribe_send_empty_digest,
|
||||
)
|
||||
result = unsubscribe_agent.scan_and_send_digest(max_results=bounded_max_results)
|
||||
return CoreUnsubscribeDigestResult(
|
||||
scanned_messages=result.scanned_messages,
|
||||
extracted_unique_links=result.extracted_unique_links,
|
||||
new_links=result.new_links,
|
||||
sent_to=result.sent_to,
|
||||
email_sent=result.email_sent,
|
||||
)
|
||||
|
||||
def list_unsubscribe_candidates(self, max_results: int) -> CoreUnsubscribeCandidatesResult:
|
||||
snapshot = self._build_unsubscribe_hil_agent().discover_candidates(max_results=max_results)
|
||||
return self._snapshot_to_core(snapshot)
|
||||
|
||||
def execute_unsubscribe_selected(
|
||||
self,
|
||||
selected_candidate_ids: list[str],
|
||||
max_results: int,
|
||||
remember_selection: bool,
|
||||
) -> CoreUnsubscribeExecutionResult:
|
||||
result = self._build_unsubscribe_hil_agent().execute_selected(
|
||||
selected_candidate_ids=selected_candidate_ids,
|
||||
max_results=max_results,
|
||||
remember_selection=remember_selection,
|
||||
)
|
||||
return self._execution_to_core(result)
|
||||
|
||||
def run_unsubscribe_auto(self, max_results: int) -> CoreUnsubscribeExecutionResult:
|
||||
result = self._build_unsubscribe_hil_agent().execute_for_approved(max_results=max_results)
|
||||
return self._execution_to_core(result)
|
||||
|
||||
def _build_strands_classifier(self) -> StrandsEmailClassifier | None:
|
||||
if not self.settings.strands_api_key:
|
||||
if self.settings.llm_fallback_to_rules:
|
||||
if not self._strands_key_warning_logged:
|
||||
self.logger.warning(
|
||||
"Strands API key not set. Falling back to rules-based classification."
|
||||
)
|
||||
self._strands_key_warning_logged = True
|
||||
return None
|
||||
raise RuntimeError(
|
||||
"STRANDS_OPENAI_API_KEY (or LLM_API_KEY) is required when LLM_FALLBACK_TO_RULES is disabled."
|
||||
)
|
||||
|
||||
try:
|
||||
return StrandsEmailClassifier(
|
||||
api_key=self.settings.strands_api_key,
|
||||
model_id=self.settings.strands_model_id,
|
||||
base_url=self.settings.strands_base_url,
|
||||
timeout_seconds=self.settings.strands_timeout_seconds,
|
||||
temperature=self.settings.strands_temperature,
|
||||
)
|
||||
except Exception:
|
||||
if self.settings.llm_fallback_to_rules:
|
||||
self.logger.exception(
|
||||
"Could not initialize Strands classifier; using rules fallback."
|
||||
)
|
||||
return None
|
||||
raise
|
||||
|
||||
def _build_unsubscribe_hil_agent(self) -> UnsubscribeHumanLoopAgent:
|
||||
gmail_service = build_gmail_service(self.settings)
|
||||
return UnsubscribeHumanLoopAgent(
|
||||
gmail_service=gmail_service,
|
||||
query=self.settings.unsubscribe_hil_query,
|
||||
state_file=self.settings.unsubscribe_hil_state_file,
|
||||
http_timeout_seconds=self.settings.unsubscribe_http_timeout_seconds,
|
||||
user_agent=self.settings.unsubscribe_user_agent,
|
||||
)
|
||||
|
||||
def _snapshot_to_core(self, snapshot: CandidateSnapshot) -> CoreUnsubscribeCandidatesResult:
|
||||
return CoreUnsubscribeCandidatesResult(
|
||||
scanned_messages=snapshot.scanned_messages,
|
||||
candidates=[
|
||||
CoreMailingListCandidate(
|
||||
candidate_id=candidate.candidate_id,
|
||||
list_name=candidate.list_name,
|
||||
sender_domain=candidate.sender_domain,
|
||||
message_count=candidate.message_count,
|
||||
sample_senders=candidate.sample_senders,
|
||||
sample_subjects=candidate.sample_subjects,
|
||||
methods=[
|
||||
CoreUnsubscribeMethod(
|
||||
method_id=method.method_id,
|
||||
method_type=method.method_type,
|
||||
value=method.value,
|
||||
)
|
||||
for method in candidate.methods
|
||||
],
|
||||
approved=candidate.approved,
|
||||
)
|
||||
for candidate in snapshot.candidates
|
||||
],
|
||||
)
|
||||
|
||||
def _execution_to_core(
|
||||
self, result: UnsubscribeExecutionResult
|
||||
) -> CoreUnsubscribeExecutionResult:
|
||||
return CoreUnsubscribeExecutionResult(
|
||||
scanned_messages=result.scanned_messages,
|
||||
candidates_considered=result.candidates_considered,
|
||||
selected_candidates=result.selected_candidates,
|
||||
executed_methods=result.executed_methods,
|
||||
skipped_already_executed=result.skipped_already_executed,
|
||||
failed_methods=result.failed_methods,
|
||||
updated_approved_count=result.updated_approved_count,
|
||||
results=[
|
||||
CoreMethodExecution(
|
||||
candidate_id=item.candidate_id,
|
||||
list_name=item.list_name,
|
||||
method_id=item.method_id,
|
||||
method_type=item.method_type,
|
||||
value=item.value,
|
||||
success=item.success,
|
||||
detail=item.detail,
|
||||
)
|
||||
for item in result.results
|
||||
],
|
||||
)
|
||||
@ -0,0 +1,3 @@
|
||||
from app.mcp.server import mcp
|
||||
|
||||
__all__ = ["mcp"]
|
||||
@ -0,0 +1,66 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from mcp.server.fastmcp import Context, FastMCP
|
||||
|
||||
from app.config import get_settings
|
||||
from app.mcp.tools import (
|
||||
check_availability as check_availability_impl,
|
||||
execute_unsubscribe as execute_unsubscribe_impl,
|
||||
list_unsubscribe_candidates as list_unsubscribe_candidates_impl,
|
||||
scan_mailbox as scan_mailbox_impl,
|
||||
)
|
||||
|
||||
settings = get_settings()
|
||||
mcp = FastMCP(
|
||||
"Personal Agent MCP",
|
||||
streamable_http_path="/",
|
||||
)
|
||||
|
||||
|
||||
@mcp.tool(description="Check Google Calendar availability for a time range.")
|
||||
def check_availability(
|
||||
start: str,
|
||||
end: str,
|
||||
calendar_ids: list[str] | None = None,
|
||||
ctx: Context | None = None,
|
||||
) -> dict[str, object]:
|
||||
return check_availability_impl(
|
||||
start=start,
|
||||
end=end,
|
||||
calendar_ids=calendar_ids,
|
||||
ctx=ctx,
|
||||
)
|
||||
|
||||
|
||||
if settings.mcp_enable_mutation_tools:
|
||||
|
||||
@mcp.tool(
|
||||
description="Scan unread root-inbox Gmail messages and apply classification labels."
|
||||
)
|
||||
def scan_mailbox(max_results: int = 100, ctx: Context | None = None) -> dict[str, object]:
|
||||
return scan_mailbox_impl(max_results=max_results, ctx=ctx)
|
||||
|
||||
@mcp.tool(
|
||||
description="List unsubscribe candidates discovered from advertising emails."
|
||||
)
|
||||
def list_unsubscribe_candidates(
|
||||
max_results: int = 500,
|
||||
ctx: Context | None = None,
|
||||
) -> dict[str, object]:
|
||||
return list_unsubscribe_candidates_impl(max_results=max_results, ctx=ctx)
|
||||
|
||||
@mcp.tool(
|
||||
description="Execute unsubscribe actions for selected candidate IDs."
|
||||
)
|
||||
def execute_unsubscribe(
|
||||
selected_candidate_ids: list[str],
|
||||
max_results: int = 500,
|
||||
remember_selection: bool = True,
|
||||
ctx: Context | None = None,
|
||||
) -> dict[str, object]:
|
||||
return execute_unsubscribe_impl(
|
||||
selected_candidate_ids=selected_candidate_ids,
|
||||
max_results=max_results,
|
||||
remember_selection=remember_selection,
|
||||
ctx=ctx,
|
||||
)
|
||||
@ -0,0 +1,147 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
from app.config import get_settings
|
||||
from app.core.service import CoreAgentService
|
||||
from app.security import AuthBackend
|
||||
|
||||
settings = get_settings()
|
||||
core_service = CoreAgentService(settings=settings, logger=logging.getLogger("personal-agent.mcp"))
|
||||
auth_backend = AuthBackend(settings=settings)
|
||||
|
||||
|
||||
def check_availability(
|
||||
start: str,
|
||||
end: str,
|
||||
calendar_ids: list[str] | None = None,
|
||||
ctx: Context | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Return free/busy availability for a time range on one or more calendars."""
|
||||
_require_scope(ctx, "availability:read")
|
||||
result = core_service.check_availability(start=start, end=end, calendar_ids=calendar_ids)
|
||||
return {
|
||||
"start": result.start,
|
||||
"end": result.end,
|
||||
"available": result.available,
|
||||
"busy_slots": [
|
||||
{
|
||||
"calendar_id": slot.calendar_id,
|
||||
"start": slot.start,
|
||||
"end": slot.end,
|
||||
}
|
||||
for slot in result.busy_slots
|
||||
],
|
||||
"checked_calendars": result.checked_calendars,
|
||||
}
|
||||
|
||||
|
||||
def scan_mailbox(max_results: int = 100, ctx: Context | None = None) -> dict[str, Any]:
|
||||
"""Scan inbox emails and classify/move them according to current routing rules."""
|
||||
_require_scope(ctx, "mail:scan")
|
||||
result = core_service.scan_mailbox(max_results=max_results)
|
||||
return {
|
||||
"scanned": result.scanned,
|
||||
"linkedin": result.linkedin,
|
||||
"advertising": result.advertising,
|
||||
"veille_techno": result.veille_techno,
|
||||
"skipped": result.skipped,
|
||||
"failed": result.failed,
|
||||
}
|
||||
|
||||
|
||||
def list_unsubscribe_candidates(
|
||||
max_results: int = 500, ctx: Context | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""List unsubscribe candidates discovered from advertising emails."""
|
||||
_require_scope(ctx, "unsubscribe:read")
|
||||
result = core_service.list_unsubscribe_candidates(max_results=max_results)
|
||||
return {
|
||||
"scanned_messages": result.scanned_messages,
|
||||
"candidates": [
|
||||
{
|
||||
"candidate_id": candidate.candidate_id,
|
||||
"list_name": candidate.list_name,
|
||||
"sender_domain": candidate.sender_domain,
|
||||
"message_count": candidate.message_count,
|
||||
"sample_senders": candidate.sample_senders,
|
||||
"sample_subjects": candidate.sample_subjects,
|
||||
"approved": candidate.approved,
|
||||
"methods": [
|
||||
{
|
||||
"method_id": method.method_id,
|
||||
"method_type": method.method_type,
|
||||
"value": method.value,
|
||||
}
|
||||
for method in candidate.methods
|
||||
],
|
||||
}
|
||||
for candidate in result.candidates
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def execute_unsubscribe(
|
||||
selected_candidate_ids: list[str],
|
||||
max_results: int = 500,
|
||||
remember_selection: bool = True,
|
||||
ctx: Context | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Execute unsubscribe actions for selected mailing list candidate IDs."""
|
||||
_require_scope(ctx, "unsubscribe:execute")
|
||||
result = core_service.execute_unsubscribe_selected(
|
||||
selected_candidate_ids=selected_candidate_ids,
|
||||
max_results=max_results,
|
||||
remember_selection=remember_selection,
|
||||
)
|
||||
return {
|
||||
"scanned_messages": result.scanned_messages,
|
||||
"candidates_considered": result.candidates_considered,
|
||||
"selected_candidates": result.selected_candidates,
|
||||
"executed_methods": result.executed_methods,
|
||||
"skipped_already_executed": result.skipped_already_executed,
|
||||
"failed_methods": result.failed_methods,
|
||||
"updated_approved_count": result.updated_approved_count,
|
||||
"results": [
|
||||
{
|
||||
"candidate_id": item.candidate_id,
|
||||
"list_name": item.list_name,
|
||||
"method_id": item.method_id,
|
||||
"method_type": item.method_type,
|
||||
"value": item.value,
|
||||
"success": item.success,
|
||||
"detail": item.detail,
|
||||
}
|
||||
for item in result.results
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _require_scope(ctx: Context | None, scope: str) -> None:
|
||||
x_api_key, authorization = _extract_auth_headers(ctx)
|
||||
try:
|
||||
auth_backend.authenticate(
|
||||
x_api_key=x_api_key,
|
||||
authorization=authorization,
|
||||
required_scopes={scope},
|
||||
)
|
||||
except HTTPException as exc:
|
||||
raise PermissionError(f"Unauthorized for scope '{scope}': {exc.detail}") from exc
|
||||
|
||||
|
||||
def _extract_auth_headers(ctx: Context | None) -> tuple[str | None, str | None]:
|
||||
if ctx is None:
|
||||
return None, None
|
||||
|
||||
request = ctx.request_context.request
|
||||
headers = getattr(request, "headers", None)
|
||||
if headers is None:
|
||||
return None, None
|
||||
|
||||
x_api_key = headers.get("x-api-key")
|
||||
authorization = headers.get("authorization")
|
||||
return x_api_key, authorization
|
||||
@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from starlette.applications import Starlette
|
||||
from starlette.routing import Mount
|
||||
|
||||
from app.mcp import mcp
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_: Starlette):
|
||||
async with mcp.session_manager.run():
|
||||
yield
|
||||
|
||||
|
||||
app = Starlette(
|
||||
routes=[
|
||||
Mount("/mcp", app=mcp.streamable_http_app()),
|
||||
],
|
||||
lifespan=lifespan,
|
||||
)
|
||||
@ -0,0 +1,3 @@
|
||||
from app.security.auth import AuthBackend, AuthContext
|
||||
|
||||
__all__ = ["AuthBackend", "AuthContext"]
|
||||
@ -0,0 +1,300 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from app.config import Settings
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AuthContext:
|
||||
subject: str
|
||||
auth_type: str
|
||||
scopes: set[str]
|
||||
|
||||
|
||||
class AuthBackend:
|
||||
def __init__(self, settings: Settings) -> None:
|
||||
self.settings = settings
|
||||
|
||||
def is_enabled(self) -> bool:
|
||||
if self.settings.auth_mode == "api_key":
|
||||
return bool(self.settings.agent_api_key.strip())
|
||||
if self.settings.auth_mode == "jwt":
|
||||
return bool(self.settings.auth_jwt_secret.strip())
|
||||
if self.settings.auth_mode == "hybrid":
|
||||
return bool(
|
||||
self.settings.agent_api_key.strip() or self.settings.auth_jwt_secret.strip()
|
||||
)
|
||||
return bool(self.settings.agent_api_key.strip())
|
||||
|
||||
def authenticate(
|
||||
self,
|
||||
*,
|
||||
x_api_key: str | None,
|
||||
authorization: str | None,
|
||||
required_scopes: set[str],
|
||||
) -> AuthContext:
|
||||
mode = self.settings.auth_mode
|
||||
if mode == "api_key":
|
||||
return self._authenticate_api_key(
|
||||
x_api_key=x_api_key,
|
||||
authorization=authorization,
|
||||
required_scopes=required_scopes,
|
||||
)
|
||||
if mode == "jwt":
|
||||
return self._authenticate_jwt(authorization=authorization, required_scopes=required_scopes)
|
||||
if mode == "hybrid":
|
||||
return self._authenticate_hybrid(
|
||||
x_api_key=x_api_key,
|
||||
authorization=authorization,
|
||||
required_scopes=required_scopes,
|
||||
)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Unsupported AUTH_MODE '{mode}'.",
|
||||
)
|
||||
|
||||
def _authenticate_hybrid(
|
||||
self,
|
||||
*,
|
||||
x_api_key: str | None,
|
||||
authorization: str | None,
|
||||
required_scopes: set[str],
|
||||
) -> AuthContext:
|
||||
if self.settings.agent_api_key:
|
||||
api_key = self._resolve_api_key(
|
||||
x_api_key=x_api_key,
|
||||
authorization=authorization,
|
||||
)
|
||||
if api_key == self.settings.agent_api_key:
|
||||
return AuthContext(subject="api-key", auth_type="api_key", scopes={"*"})
|
||||
|
||||
if self.settings.auth_jwt_secret:
|
||||
return self._authenticate_jwt(
|
||||
authorization=authorization,
|
||||
required_scopes=required_scopes,
|
||||
)
|
||||
|
||||
return self._authenticate_api_key(
|
||||
x_api_key=x_api_key,
|
||||
authorization=authorization,
|
||||
required_scopes=required_scopes,
|
||||
)
|
||||
|
||||
def _authenticate_api_key(
|
||||
self,
|
||||
*,
|
||||
x_api_key: str | None,
|
||||
authorization: str | None,
|
||||
required_scopes: set[str],
|
||||
) -> AuthContext:
|
||||
expected = self.settings.agent_api_key
|
||||
if not expected:
|
||||
return AuthContext(subject="anonymous", auth_type="none", scopes={"*"})
|
||||
|
||||
provided = self._resolve_api_key(
|
||||
x_api_key=x_api_key,
|
||||
authorization=authorization,
|
||||
)
|
||||
if provided != expected:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid API key.",
|
||||
)
|
||||
|
||||
# API key access is treated as admin-level compatibility mode.
|
||||
return AuthContext(subject="api-key", auth_type="api_key", scopes={"*"})
|
||||
|
||||
def _authenticate_jwt(
|
||||
self,
|
||||
*,
|
||||
authorization: str | None,
|
||||
required_scopes: set[str],
|
||||
) -> AuthContext:
|
||||
secret = self.settings.auth_jwt_secret
|
||||
if not secret:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="AUTH_JWT_SECRET is required in jwt auth mode.",
|
||||
)
|
||||
|
||||
token = self._resolve_bearer(authorization)
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing Bearer token.",
|
||||
)
|
||||
|
||||
claims = _decode_hs256_jwt(token=token, secret=secret)
|
||||
_validate_jwt_claims(
|
||||
claims=claims,
|
||||
expected_issuer=self.settings.auth_jwt_issuer,
|
||||
expected_audience=self.settings.auth_jwt_audience,
|
||||
)
|
||||
|
||||
scope_values = _extract_scopes(claims)
|
||||
if required_scopes and not required_scopes.issubset(scope_values):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Missing required scope.",
|
||||
)
|
||||
|
||||
subject = str(claims.get("sub") or "jwt-subject")
|
||||
return AuthContext(subject=subject, auth_type="jwt", scopes=scope_values)
|
||||
|
||||
def _resolve_api_key(self, *, x_api_key: str | None, authorization: str | None) -> str | None:
|
||||
if x_api_key:
|
||||
return x_api_key
|
||||
|
||||
token = self._resolve_bearer(authorization)
|
||||
return token
|
||||
|
||||
def _resolve_bearer(self, authorization: str | None) -> str | None:
|
||||
if not authorization:
|
||||
return None
|
||||
parts = authorization.split(" ", 1)
|
||||
if len(parts) == 2 and parts[0].lower() == "bearer":
|
||||
return parts[1]
|
||||
return None
|
||||
|
||||
|
||||
def _decode_hs256_jwt(*, token: str, secret: str) -> dict[str, Any]:
|
||||
try:
|
||||
header_segment, payload_segment, signature_segment = token.split(".")
|
||||
except ValueError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Malformed JWT.",
|
||||
) from exc
|
||||
|
||||
signing_input = f"{header_segment}.{payload_segment}".encode("utf-8")
|
||||
expected_signature = hmac.new(
|
||||
secret.encode("utf-8"),
|
||||
signing_input,
|
||||
hashlib.sha256,
|
||||
).digest()
|
||||
actual_signature = _urlsafe_b64decode(signature_segment)
|
||||
if not hmac.compare_digest(expected_signature, actual_signature):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid JWT signature.",
|
||||
)
|
||||
|
||||
header = _decode_jwt_json_segment(header_segment)
|
||||
algorithm = str(header.get("alg", "")).upper()
|
||||
if algorithm != "HS256":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Unsupported JWT algorithm.",
|
||||
)
|
||||
|
||||
return _decode_jwt_json_segment(payload_segment)
|
||||
|
||||
|
||||
def _decode_jwt_json_segment(segment: str) -> dict[str, Any]:
|
||||
try:
|
||||
decoded = _urlsafe_b64decode(segment).decode("utf-8")
|
||||
payload = json.loads(decoded)
|
||||
except Exception as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid JWT payload.",
|
||||
) from exc
|
||||
|
||||
if not isinstance(payload, dict):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid JWT object.",
|
||||
)
|
||||
return payload
|
||||
|
||||
|
||||
def _urlsafe_b64decode(segment: str) -> bytes:
|
||||
padded = segment + "=" * (-len(segment) % 4)
|
||||
return base64.urlsafe_b64decode(padded.encode("utf-8"))
|
||||
|
||||
|
||||
def _validate_jwt_claims(
|
||||
*,
|
||||
claims: dict[str, Any],
|
||||
expected_issuer: str | None,
|
||||
expected_audience: str | None,
|
||||
) -> None:
|
||||
now_ts = int(datetime.now(tz=timezone.utc).timestamp())
|
||||
exp = claims.get("exp")
|
||||
if exp is not None:
|
||||
try:
|
||||
exp_ts = int(exp)
|
||||
except Exception as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid exp claim.",
|
||||
) from exc
|
||||
if exp_ts < now_ts:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="JWT has expired.",
|
||||
)
|
||||
|
||||
nbf = claims.get("nbf")
|
||||
if nbf is not None:
|
||||
try:
|
||||
nbf_ts = int(nbf)
|
||||
except Exception as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid nbf claim.",
|
||||
) from exc
|
||||
if nbf_ts > now_ts:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="JWT not valid yet.",
|
||||
)
|
||||
|
||||
if expected_issuer:
|
||||
issuer = str(claims.get("iss", ""))
|
||||
if issuer != expected_issuer:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid JWT issuer.",
|
||||
)
|
||||
|
||||
if expected_audience:
|
||||
audience = claims.get("aud")
|
||||
if isinstance(audience, list):
|
||||
audience_values = {str(value) for value in audience}
|
||||
elif audience is None:
|
||||
audience_values = set()
|
||||
else:
|
||||
audience_values = {str(audience)}
|
||||
if expected_audience not in audience_values:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid JWT audience.",
|
||||
)
|
||||
|
||||
|
||||
def _extract_scopes(claims: dict[str, Any]) -> set[str]:
|
||||
raw_scope = claims.get("scope", "")
|
||||
scope_values: set[str] = set()
|
||||
if isinstance(raw_scope, str):
|
||||
scope_values.update(value for value in raw_scope.strip().split(" ") if value)
|
||||
elif isinstance(raw_scope, list):
|
||||
scope_values.update(str(value) for value in raw_scope if str(value).strip())
|
||||
|
||||
raw_scp = claims.get("scp", "")
|
||||
if isinstance(raw_scp, str):
|
||||
scope_values.update(value for value in raw_scp.strip().split(" ") if value)
|
||||
elif isinstance(raw_scp, list):
|
||||
scope_values.update(str(value) for value in raw_scp if str(value).strip())
|
||||
|
||||
return scope_values
|
||||
@ -0,0 +1,107 @@
|
||||
# A2A Runbook
|
||||
|
||||
## Scope
|
||||
|
||||
This runbook covers the Agent-to-Agent (A2A) adapter exposed by the REST service in `app.main`.
|
||||
|
||||
## Endpoints
|
||||
|
||||
- Agent Card: `GET /.well-known/agent-card.json`
|
||||
- RPC endpoint: `POST /a2a/rpc`
|
||||
|
||||
When using Docker Compose:
|
||||
|
||||
- Base URL: `http://127.0.0.1:8000`
|
||||
|
||||
## Protocol contract
|
||||
|
||||
- `protocolVersion` advertised in Agent Card: `1.0`
|
||||
- Response header on A2A routes: `A2A-Version: 1.0`
|
||||
- JSON-RPC version: `2.0`
|
||||
|
||||
Implemented methods:
|
||||
|
||||
- `ping` / `health.ping` / `health/ping`
|
||||
- `SendMessage` (availability only)
|
||||
|
||||
## Authentication
|
||||
|
||||
The A2A adapter uses the same auth backend as REST:
|
||||
|
||||
- `AUTH_MODE=api_key`: `X-API-Key` or `Authorization: Bearer <AGENT_API_KEY>`
|
||||
- `AUTH_MODE=jwt`: `Authorization: Bearer <JWT>`
|
||||
- `AUTH_MODE=hybrid`: API key first, then JWT
|
||||
|
||||
Required scope for `SendMessage`:
|
||||
|
||||
- `availability:read`
|
||||
|
||||
## Request shape for SendMessage
|
||||
|
||||
`SendMessage` accepts several input shapes; the request must contain at least:
|
||||
|
||||
- `start` ISO datetime with timezone offset
|
||||
- `end` ISO datetime with timezone offset
|
||||
|
||||
Accepted locations:
|
||||
|
||||
- `params.start` / `params.end`
|
||||
- `params.input.start` / `params.input.end`
|
||||
- `params.arguments.start` / `params.arguments.end`
|
||||
- JSON embedded in message content
|
||||
|
||||
Optional:
|
||||
|
||||
- `calendar_ids`: array of calendar ids (defaults to `["primary"]`)
|
||||
|
||||
## Smoke tests
|
||||
|
||||
Agent Card:
|
||||
|
||||
```bash
|
||||
curl http://127.0.0.1:8000/.well-known/agent-card.json
|
||||
```
|
||||
|
||||
Availability:
|
||||
|
||||
```bash
|
||||
curl -X POST "http://127.0.0.1:8000/a2a/rpc" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "X-API-Key: $AGENT_API_KEY" \
|
||||
-d '{
|
||||
"jsonrpc":"2.0",
|
||||
"id":"req-1",
|
||||
"method":"SendMessage",
|
||||
"params":{
|
||||
"start":"2026-03-10T09:00:00+01:00",
|
||||
"end":"2026-03-10T10:00:00+01:00",
|
||||
"calendar_ids":["primary"]
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
## Error mapping
|
||||
|
||||
- `-32600`: invalid JSON-RPC request
|
||||
- `-32601`: unknown method
|
||||
- `-32602`: invalid params (including bad time window)
|
||||
- `-32001`: unauthorized
|
||||
- `-32000`: backend/runtime error
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
If you get `-32001`:
|
||||
|
||||
- Verify `AUTH_MODE`
|
||||
- Verify API key/JWT and scope `availability:read`
|
||||
|
||||
If you get `-32602`:
|
||||
|
||||
- Ensure `start` and `end` include timezone offsets
|
||||
- Ensure `end > start`
|
||||
|
||||
If you get `-32000` with OAuth file errors:
|
||||
|
||||
- Check `GOOGLE_CLIENT_SECRETS_FILE` path
|
||||
- Check `GOOGLE_TOKEN_FILE` path
|
||||
|
||||
@ -0,0 +1,102 @@
|
||||
# MCP Runbook
|
||||
|
||||
## Scope
|
||||
|
||||
This runbook covers the MCP adapter exposed by `app.mcp_main`.
|
||||
|
||||
## Endpoint
|
||||
|
||||
- Streamable HTTP endpoint: `POST /mcp` (mounted under `app.mcp_main`)
|
||||
|
||||
When using Docker Compose:
|
||||
|
||||
- Base URL: `http://127.0.0.1:8001/mcp`
|
||||
|
||||
## Runtime modes
|
||||
|
||||
Local:
|
||||
|
||||
```bash
|
||||
uv run uvicorn app.mcp_main:app --host 0.0.0.0 --port 8001
|
||||
```
|
||||
|
||||
Docker Compose service:
|
||||
|
||||
- `personal-agent-mcp`
|
||||
|
||||
## Tool surface
|
||||
|
||||
Always enabled:
|
||||
|
||||
- `check_availability`
|
||||
|
||||
Optional mutation tools (disabled by default):
|
||||
|
||||
- `scan_mailbox`
|
||||
- `list_unsubscribe_candidates`
|
||||
- `execute_unsubscribe`
|
||||
|
||||
Enable optional tools with:
|
||||
|
||||
```bash
|
||||
MCP_ENABLE_MUTATION_TOOLS=true
|
||||
```
|
||||
|
||||
## Authorization and scope gates
|
||||
|
||||
MCP tools call the shared auth backend and read auth headers from request context.
|
||||
|
||||
Supported auth headers:
|
||||
|
||||
- `X-API-Key`
|
||||
- `Authorization: Bearer ...`
|
||||
|
||||
Required scopes:
|
||||
|
||||
- `check_availability`: `availability:read`
|
||||
- `scan_mailbox`: `mail:scan`
|
||||
- `list_unsubscribe_candidates`: `unsubscribe:read`
|
||||
- `execute_unsubscribe`: `unsubscribe:execute`
|
||||
|
||||
## Tool verification
|
||||
|
||||
Verify tool list from Python:
|
||||
|
||||
```bash
|
||||
uv run python - <<'PY'
|
||||
import asyncio
|
||||
from app.mcp.server import mcp
|
||||
|
||||
async def main():
|
||||
tools = await mcp.list_tools()
|
||||
print([t.name for t in tools])
|
||||
|
||||
asyncio.run(main())
|
||||
PY
|
||||
```
|
||||
|
||||
Expected output by mode:
|
||||
|
||||
- default: `['check_availability']`
|
||||
- with `MCP_ENABLE_MUTATION_TOOLS=true`: all four tools
|
||||
|
||||
## Protocol notes
|
||||
|
||||
- The server uses FastMCP Streamable HTTP.
|
||||
- Basic GET to `/mcp` is not a health endpoint; MCP expects protocol-compliant requests.
|
||||
- In local development, FastMCP may enforce host/origin checks. If you see `421 Misdirected Request`, verify host/port and reverse-proxy headers.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
If tools fail with auth errors:
|
||||
|
||||
- Check `AUTH_MODE` and credentials
|
||||
- Confirm JWT contains required scopes
|
||||
- For API key mode, verify `AGENT_API_KEY`
|
||||
|
||||
If tool calls fail with Google errors:
|
||||
|
||||
- Verify OAuth file mounts in Docker:
|
||||
- `GOOGLE_CLIENT_SECRETS_FILE`
|
||||
- `GOOGLE_TOKEN_FILE`
|
||||
|
||||
@ -0,0 +1,110 @@
|
||||
# Security Runbook
|
||||
|
||||
## Security model
|
||||
|
||||
The service supports three auth modes via `AUTH_MODE`:
|
||||
|
||||
- `api_key`: static shared secret (`AGENT_API_KEY`)
|
||||
- `jwt`: bearer JWT with scope checks
|
||||
- `hybrid`: API key accepted first, then JWT fallback
|
||||
|
||||
The same backend is used across:
|
||||
|
||||
- REST API
|
||||
- A2A adapter
|
||||
- MCP tools
|
||||
|
||||
## Recommended deployment posture
|
||||
|
||||
External traffic:
|
||||
|
||||
- Use `AUTH_MODE=jwt`
|
||||
- Require HTTPS at reverse proxy/gateway
|
||||
- Restrict exposed routes to required protocol endpoints
|
||||
|
||||
Internal traffic:
|
||||
|
||||
- `AUTH_MODE=hybrid` is acceptable during migration
|
||||
- Prefer mTLS/private network for service-to-service traffic
|
||||
|
||||
## Scope matrix
|
||||
|
||||
- `availability:read`: availability access
|
||||
- `mail:scan`: inbox scan and triage
|
||||
- `unsubscribe:read`: candidate discovery
|
||||
- `unsubscribe:execute`: unsubscribe execution
|
||||
- `unsubscribe:digest`: digest scan and send
|
||||
|
||||
## Secret and token handling
|
||||
|
||||
Never commit secrets:
|
||||
|
||||
- `.env`
|
||||
- `token.json`
|
||||
- Google OAuth client secret files
|
||||
|
||||
Always persist and back up:
|
||||
|
||||
- `token.json`
|
||||
- `data/sent_unsubscribe_links.json`
|
||||
- `data/unsubscribed_methods.json`
|
||||
|
||||
## Key and token rotation
|
||||
|
||||
### API key rotation (api_key/hybrid)
|
||||
|
||||
1. Generate new strong key.
|
||||
2. Update environment (`AGENT_API_KEY`) in deployment.
|
||||
3. Restart services.
|
||||
4. Update all clients.
|
||||
5. Remove old key from all stores.
|
||||
|
||||
### JWT secret rotation (jwt/hybrid)
|
||||
|
||||
1. Generate new signing secret.
|
||||
2. Roll out issuer/signing config first.
|
||||
3. Update server `AUTH_JWT_SECRET`.
|
||||
4. Restart services.
|
||||
5. Force token refresh for clients.
|
||||
|
||||
## Incident response checklist
|
||||
|
||||
If credential leak is suspected:
|
||||
|
||||
1. Revoke compromised key/secret immediately.
|
||||
2. Rotate API key and JWT secret.
|
||||
3. Invalidate active tokens (issuer-side).
|
||||
4. Review logs for unusual scans/unsubscribe operations.
|
||||
5. Disable mutation MCP tools (`MCP_ENABLE_MUTATION_TOOLS=false`) until investigation completes.
|
||||
6. Re-enable features after containment and verification.
|
||||
|
||||
## Release rollout checklist
|
||||
|
||||
Preflight:
|
||||
|
||||
1. `uv run pytest -q`
|
||||
2. `uv run python -c "import app.main, app.mcp_main; print('import_ok')"`
|
||||
3. `docker compose config --services`
|
||||
|
||||
Canary:
|
||||
|
||||
1. Deploy to one node/environment.
|
||||
2. Validate:
|
||||
- `GET /health`
|
||||
- `GET /.well-known/agent-card.json`
|
||||
- A2A `SendMessage`
|
||||
- MCP tool listing
|
||||
3. Monitor errors for 30-60 minutes.
|
||||
|
||||
Full rollout:
|
||||
|
||||
1. Deploy all nodes.
|
||||
2. Re-run smoke checks.
|
||||
3. Confirm scheduler jobs continue as expected.
|
||||
|
||||
Rollback:
|
||||
|
||||
1. Redeploy previous image/tag.
|
||||
2. Verify health and protocol smoke checks.
|
||||
3. Keep state files (`data/*.json`) unchanged during rollback.
|
||||
|
||||
@ -0,0 +1,65 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
from fastapi import Response
|
||||
|
||||
import app.a2a.router as a2a_module
|
||||
from app.a2a.models import A2ARpcRequest
|
||||
|
||||
|
||||
class _Slot(dict):
|
||||
def __getattr__(self, item: str) -> str:
|
||||
return self[item]
|
||||
|
||||
|
||||
class _DummyCoreService:
|
||||
def check_availability(
|
||||
self,
|
||||
start: str,
|
||||
end: str,
|
||||
calendar_ids: list[str] | None,
|
||||
) -> SimpleNamespace:
|
||||
checked = calendar_ids or ["primary"]
|
||||
busy_slots = [_Slot(calendar_id=checked[0], start=start, end=end)]
|
||||
return SimpleNamespace(
|
||||
start=start,
|
||||
end=end,
|
||||
available=False,
|
||||
busy_slots=busy_slots,
|
||||
checked_calendars=checked,
|
||||
)
|
||||
|
||||
|
||||
class _AllowAuthBackend:
|
||||
def authenticate(
|
||||
self,
|
||||
*,
|
||||
x_api_key: str | None,
|
||||
authorization: str | None,
|
||||
required_scopes: set[str],
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def test_a2a_send_message_returns_availability(monkeypatch) -> None:
|
||||
monkeypatch.setattr(a2a_module, "core_service", _DummyCoreService())
|
||||
monkeypatch.setattr(a2a_module, "auth_backend", _AllowAuthBackend())
|
||||
|
||||
request = A2ARpcRequest(
|
||||
jsonrpc="2.0",
|
||||
id="req-1",
|
||||
method="SendMessage",
|
||||
params={
|
||||
"start": "2026-03-10T09:00:00+01:00",
|
||||
"end": "2026-03-10T10:00:00+01:00",
|
||||
"calendar_ids": ["primary"],
|
||||
},
|
||||
)
|
||||
response = a2a_module.a2a_rpc(request, Response())
|
||||
|
||||
assert response.error is None
|
||||
assert response.result is not None
|
||||
assert response.result["type"] == "availability.result"
|
||||
assert response.result["availability"]["available"] is False
|
||||
assert response.result["availability"]["checked_calendars"] == ["primary"]
|
||||
@ -0,0 +1,112 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import replace
|
||||
from types import SimpleNamespace
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import app.a2a.router as a2a_module
|
||||
from app.config import get_settings
|
||||
from app.security.auth import AuthBackend
|
||||
|
||||
|
||||
class _DummyCoreService:
|
||||
def check_availability(
|
||||
self,
|
||||
start: str,
|
||||
end: str,
|
||||
calendar_ids: list[str] | None,
|
||||
) -> SimpleNamespace:
|
||||
checked = calendar_ids or ["primary"]
|
||||
return SimpleNamespace(
|
||||
start=start,
|
||||
end=end,
|
||||
available=True,
|
||||
busy_slots=[],
|
||||
checked_calendars=checked,
|
||||
)
|
||||
|
||||
|
||||
def _build_test_app() -> FastAPI:
|
||||
app = FastAPI()
|
||||
app.include_router(a2a_module.router)
|
||||
return app
|
||||
|
||||
|
||||
def test_a2a_agent_card_endpoint(monkeypatch) -> None:
|
||||
monkeypatch.setattr(a2a_module, "core_service", _DummyCoreService())
|
||||
app = _build_test_app()
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/.well-known/agent-card.json")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["A2A-Version"] == "1.0"
|
||||
payload = response.json()
|
||||
assert payload["protocolVersion"] == "1.0"
|
||||
assert payload["url"].endswith("/a2a/rpc")
|
||||
|
||||
|
||||
def test_a2a_send_message_requires_auth(monkeypatch) -> None:
|
||||
auth_settings = replace(
|
||||
get_settings(),
|
||||
auth_mode="api_key",
|
||||
agent_api_key="integration-key",
|
||||
auth_jwt_secret="",
|
||||
)
|
||||
monkeypatch.setattr(a2a_module, "auth_backend", AuthBackend(auth_settings))
|
||||
monkeypatch.setattr(a2a_module, "core_service", _DummyCoreService())
|
||||
app = _build_test_app()
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.post(
|
||||
"/a2a/rpc",
|
||||
json={
|
||||
"jsonrpc": "2.0",
|
||||
"id": "r1",
|
||||
"method": "SendMessage",
|
||||
"params": {
|
||||
"start": "2026-03-10T09:00:00+01:00",
|
||||
"end": "2026-03-10T10:00:00+01:00",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["error"]["code"] == -32001
|
||||
|
||||
|
||||
def test_a2a_send_message_with_api_key(monkeypatch) -> None:
|
||||
auth_settings = replace(
|
||||
get_settings(),
|
||||
auth_mode="api_key",
|
||||
agent_api_key="integration-key",
|
||||
auth_jwt_secret="",
|
||||
)
|
||||
monkeypatch.setattr(a2a_module, "auth_backend", AuthBackend(auth_settings))
|
||||
monkeypatch.setattr(a2a_module, "core_service", _DummyCoreService())
|
||||
app = _build_test_app()
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.post(
|
||||
"/a2a/rpc",
|
||||
headers={"X-API-Key": "integration-key"},
|
||||
json={
|
||||
"jsonrpc": "2.0",
|
||||
"id": "r2",
|
||||
"method": "SendMessage",
|
||||
"params": {
|
||||
"start": "2026-03-10T09:00:00+01:00",
|
||||
"end": "2026-03-10T10:00:00+01:00",
|
||||
"calendar_ids": ["primary"],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["error"] is None
|
||||
assert payload["result"]["availability"]["available"] is True
|
||||
assert payload["result"]["availability"]["checked_calendars"] == ["primary"]
|
||||
@ -0,0 +1,162 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from dataclasses import replace
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.config import get_settings
|
||||
from app.security.auth import AuthBackend
|
||||
|
||||
|
||||
def _make_settings(**overrides: object):
|
||||
return replace(get_settings(), **overrides)
|
||||
|
||||
|
||||
def _build_jwt(secret: str, claims: dict[str, object]) -> str:
|
||||
header = {"alg": "HS256", "typ": "JWT"}
|
||||
header_b64 = _b64url_json(header)
|
||||
payload_b64 = _b64url_json(claims)
|
||||
signing_input = f"{header_b64}.{payload_b64}".encode("utf-8")
|
||||
signature = hmac.new(secret.encode("utf-8"), signing_input, hashlib.sha256).digest()
|
||||
signature_b64 = base64.urlsafe_b64encode(signature).decode("utf-8").rstrip("=")
|
||||
return f"{header_b64}.{payload_b64}.{signature_b64}"
|
||||
|
||||
|
||||
def _b64url_json(value: dict[str, object]) -> str:
|
||||
raw = json.dumps(value, separators=(",", ":"), sort_keys=True).encode("utf-8")
|
||||
return base64.urlsafe_b64encode(raw).decode("utf-8").rstrip("=")
|
||||
|
||||
|
||||
def test_auth_backend_api_key_mode_accepts_x_api_key() -> None:
|
||||
settings = _make_settings(
|
||||
auth_mode="api_key",
|
||||
agent_api_key="test-api-key",
|
||||
auth_jwt_secret="",
|
||||
)
|
||||
backend = AuthBackend(settings=settings)
|
||||
|
||||
context = backend.authenticate(
|
||||
x_api_key="test-api-key",
|
||||
authorization=None,
|
||||
required_scopes={"availability:read"},
|
||||
)
|
||||
|
||||
assert context.auth_type == "api_key"
|
||||
assert context.subject == "api-key"
|
||||
assert "*" in context.scopes
|
||||
|
||||
|
||||
def test_auth_backend_api_key_mode_rejects_invalid_key() -> None:
|
||||
settings = _make_settings(
|
||||
auth_mode="api_key",
|
||||
agent_api_key="expected",
|
||||
auth_jwt_secret="",
|
||||
)
|
||||
backend = AuthBackend(settings=settings)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
backend.authenticate(
|
||||
x_api_key="wrong",
|
||||
authorization=None,
|
||||
required_scopes={"availability:read"},
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert str(exc_info.value.detail) == "Invalid API key."
|
||||
|
||||
|
||||
def test_auth_backend_jwt_mode_validates_scope_and_claims() -> None:
|
||||
secret = "jwt-secret"
|
||||
settings = _make_settings(
|
||||
auth_mode="jwt",
|
||||
auth_jwt_secret=secret,
|
||||
auth_jwt_issuer="https://issuer.example",
|
||||
auth_jwt_audience="personal-agent",
|
||||
agent_api_key="",
|
||||
)
|
||||
backend = AuthBackend(settings=settings)
|
||||
token = _build_jwt(
|
||||
secret=secret,
|
||||
claims={
|
||||
"sub": "agent-123",
|
||||
"iss": "https://issuer.example",
|
||||
"aud": "personal-agent",
|
||||
"scope": "availability:read unsubscribe:read",
|
||||
"exp": int(time.time()) + 3600,
|
||||
},
|
||||
)
|
||||
|
||||
context = backend.authenticate(
|
||||
x_api_key=None,
|
||||
authorization=f"Bearer {token}",
|
||||
required_scopes={"availability:read"},
|
||||
)
|
||||
|
||||
assert context.auth_type == "jwt"
|
||||
assert context.subject == "agent-123"
|
||||
assert "availability:read" in context.scopes
|
||||
|
||||
|
||||
def test_auth_backend_jwt_mode_rejects_missing_scope() -> None:
|
||||
secret = "jwt-secret"
|
||||
settings = _make_settings(
|
||||
auth_mode="jwt",
|
||||
auth_jwt_secret=secret,
|
||||
auth_jwt_issuer=None,
|
||||
auth_jwt_audience=None,
|
||||
agent_api_key="",
|
||||
)
|
||||
backend = AuthBackend(settings=settings)
|
||||
token = _build_jwt(
|
||||
secret=secret,
|
||||
claims={
|
||||
"sub": "agent-123",
|
||||
"scope": "unsubscribe:read",
|
||||
"exp": int(time.time()) + 3600,
|
||||
},
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
backend.authenticate(
|
||||
x_api_key=None,
|
||||
authorization=f"Bearer {token}",
|
||||
required_scopes={"availability:read"},
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert str(exc_info.value.detail) == "Missing required scope."
|
||||
|
||||
|
||||
def test_auth_backend_hybrid_mode_uses_jwt_when_api_key_missing() -> None:
|
||||
secret = "jwt-secret"
|
||||
settings = _make_settings(
|
||||
auth_mode="hybrid",
|
||||
agent_api_key="expected-api-key",
|
||||
auth_jwt_secret=secret,
|
||||
auth_jwt_issuer=None,
|
||||
auth_jwt_audience=None,
|
||||
)
|
||||
backend = AuthBackend(settings=settings)
|
||||
token = _build_jwt(
|
||||
secret=secret,
|
||||
claims={
|
||||
"sub": "fallback-jwt",
|
||||
"scope": "availability:read",
|
||||
"exp": int(time.time()) + 3600,
|
||||
},
|
||||
)
|
||||
|
||||
context = backend.authenticate(
|
||||
x_api_key=None,
|
||||
authorization=f"Bearer {token}",
|
||||
required_scopes={"availability:read"},
|
||||
)
|
||||
|
||||
assert context.auth_type == "jwt"
|
||||
assert context.subject == "fallback-jwt"
|
||||
@ -0,0 +1,57 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import app.core.service as core_module
|
||||
from app.config import get_settings
|
||||
from app.core.service import CoreAgentService
|
||||
|
||||
|
||||
class _FakeFreeBusy:
|
||||
def __init__(self, payload: dict[str, object]) -> None:
|
||||
self.payload = payload
|
||||
self.last_query_body: dict[str, object] | None = None
|
||||
|
||||
def query(self, body: dict[str, object]) -> _FakeFreeBusy:
|
||||
self.last_query_body = body
|
||||
return self
|
||||
|
||||
def execute(self) -> dict[str, object]:
|
||||
return self.payload
|
||||
|
||||
|
||||
class _FakeCalendarService:
|
||||
def __init__(self, payload: dict[str, object]) -> None:
|
||||
self._freebusy = _FakeFreeBusy(payload)
|
||||
|
||||
def freebusy(self) -> _FakeFreeBusy:
|
||||
return self._freebusy
|
||||
|
||||
|
||||
def test_core_availability_maps_busy_slots(monkeypatch) -> None:
|
||||
payload = {
|
||||
"calendars": {
|
||||
"primary": {
|
||||
"busy": [
|
||||
{
|
||||
"start": "2026-03-10T09:00:00+01:00",
|
||||
"end": "2026-03-10T10:00:00+01:00",
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
fake_service = _FakeCalendarService(payload)
|
||||
monkeypatch.setattr(core_module, "build_calendar_service", lambda _: fake_service)
|
||||
|
||||
service = CoreAgentService(settings=get_settings())
|
||||
result = service.check_availability(
|
||||
start="2026-03-10T09:00:00+01:00",
|
||||
end="2026-03-10T10:00:00+01:00",
|
||||
calendar_ids=["primary"],
|
||||
)
|
||||
|
||||
assert result.available is False
|
||||
assert result.checked_calendars == ["primary"]
|
||||
assert len(result.busy_slots) == 1
|
||||
assert result.busy_slots[0].calendar_id == "primary"
|
||||
assert result.busy_slots[0].start == "2026-03-10T09:00:00+01:00"
|
||||
assert result.busy_slots[0].end == "2026-03-10T10:00:00+01:00"
|
||||
@ -0,0 +1,66 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.gmail_agent import GmailTriageAgent
|
||||
|
||||
|
||||
class _FailingClassifier:
|
||||
def classify(self, **kwargs): # type: ignore[no-untyped-def]
|
||||
raise RuntimeError("model unavailable")
|
||||
|
||||
|
||||
def test_build_effective_query_enforces_inbox_and_unread() -> None:
|
||||
agent = GmailTriageAgent(gmail_service=object(), query="-label:AgentProcessed")
|
||||
assert (
|
||||
agent._build_effective_query()
|
||||
== "-label:AgentProcessed in:inbox is:unread"
|
||||
)
|
||||
|
||||
|
||||
def test_build_effective_query_keeps_existing_requirements() -> None:
|
||||
agent = GmailTriageAgent(
|
||||
gmail_service=object(),
|
||||
query="IN:INBOX is:unread -label:AgentProcessed",
|
||||
)
|
||||
assert agent._build_effective_query() == "IN:INBOX is:unread -label:AgentProcessed"
|
||||
|
||||
|
||||
def test_classify_email_returns_other_when_model_fails_and_no_rules_fallback() -> None:
|
||||
agent = GmailTriageAgent(
|
||||
gmail_service=object(),
|
||||
query="",
|
||||
classifier=_FailingClassifier(), # type: ignore[arg-type]
|
||||
fallback_to_rules=False,
|
||||
)
|
||||
|
||||
label = agent._classify_email(
|
||||
message_id="m1",
|
||||
sender="newsletter@example.com",
|
||||
subject="50% OFF today",
|
||||
snippet="promo content",
|
||||
list_unsubscribe="<https://example.com/unsubscribe>",
|
||||
precedence="bulk",
|
||||
message_label_ids={"CATEGORY_PROMOTIONS"},
|
||||
)
|
||||
|
||||
assert label == "OTHER"
|
||||
|
||||
|
||||
def test_classify_email_prioritizes_linkedin_over_advertising_signals() -> None:
|
||||
agent = GmailTriageAgent(
|
||||
gmail_service=object(),
|
||||
query="",
|
||||
classifier=None,
|
||||
fallback_to_rules=True,
|
||||
)
|
||||
|
||||
label = agent._classify_email(
|
||||
message_id="m2",
|
||||
sender="jobs-noreply@linkedin.com",
|
||||
subject="Limited time offer for your profile",
|
||||
snippet="promotional snippet",
|
||||
list_unsubscribe="<https://example.com/unsubscribe>",
|
||||
precedence="bulk",
|
||||
message_label_ids={"CATEGORY_PROMOTIONS"},
|
||||
)
|
||||
|
||||
assert label == "LINKEDIN"
|
||||
@ -0,0 +1,93 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import replace
|
||||
from types import SimpleNamespace
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import app.main as main_module
|
||||
from app.config import get_settings
|
||||
from app.security.auth import AuthBackend
|
||||
|
||||
|
||||
class _DummyCoreService:
|
||||
def scan_mailbox(self, max_results: int) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
scanned=max_results,
|
||||
linkedin=1,
|
||||
advertising=2,
|
||||
veille_techno=0,
|
||||
skipped=3,
|
||||
failed=0,
|
||||
)
|
||||
|
||||
def check_availability(
|
||||
self,
|
||||
start: str,
|
||||
end: str,
|
||||
calendar_ids: list[str] | None,
|
||||
) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
start=start,
|
||||
end=end,
|
||||
available=False,
|
||||
busy_slots=[
|
||||
{
|
||||
"calendar_id": "primary",
|
||||
"start": start,
|
||||
"end": end,
|
||||
}
|
||||
],
|
||||
checked_calendars=calendar_ids or ["primary"],
|
||||
)
|
||||
|
||||
|
||||
async def _noop_task() -> None:
|
||||
return None
|
||||
|
||||
|
||||
def _setup_main_test_context(monkeypatch) -> None:
|
||||
auth_settings = replace(
|
||||
get_settings(),
|
||||
auth_mode="api_key",
|
||||
agent_api_key="integration-key",
|
||||
auth_jwt_secret="",
|
||||
)
|
||||
monkeypatch.setattr(main_module, "auth_backend", AuthBackend(auth_settings))
|
||||
monkeypatch.setattr(main_module, "core_service", _DummyCoreService())
|
||||
# Prevent scheduler jobs from executing real background work during lifespan startup.
|
||||
monkeypatch.setattr(main_module, "_scheduled_scan", _noop_task)
|
||||
monkeypatch.setattr(main_module, "_scheduled_unsubscribe_digest", _noop_task)
|
||||
monkeypatch.setattr(main_module, "_scheduled_unsubscribe_auto", _noop_task)
|
||||
|
||||
|
||||
def test_main_scan_endpoint_with_api_key(monkeypatch) -> None:
|
||||
_setup_main_test_context(monkeypatch)
|
||||
|
||||
with TestClient(main_module.app) as client:
|
||||
response = client.post(
|
||||
"/scan?max_results=15",
|
||||
headers={"X-API-Key": "integration-key"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["scanned"] == 15
|
||||
assert payload["linkedin"] == 1
|
||||
assert payload["advertising"] == 2
|
||||
|
||||
|
||||
def test_main_availability_endpoint_rejects_missing_key(monkeypatch) -> None:
|
||||
_setup_main_test_context(monkeypatch)
|
||||
|
||||
with TestClient(main_module.app) as client:
|
||||
response = client.post(
|
||||
"/availability",
|
||||
json={
|
||||
"start": "2026-03-10T09:00:00+01:00",
|
||||
"end": "2026-03-10T10:00:00+01:00",
|
||||
"calendar_ids": ["primary"],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
@ -0,0 +1,93 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
|
||||
from fastapi import Response
|
||||
|
||||
import app.a2a.router as a2a_module
|
||||
from app.a2a.models import A2ARpcRequest
|
||||
import app.main as main_module
|
||||
import app.mcp.server as mcp_server_module
|
||||
import app.mcp.tools as mcp_tools_module
|
||||
|
||||
|
||||
class _Slot(dict):
|
||||
def __getattr__(self, item: str) -> str:
|
||||
return self[item]
|
||||
|
||||
|
||||
class _DummyCoreService:
|
||||
def check_availability(
|
||||
self,
|
||||
start: str,
|
||||
end: str,
|
||||
calendar_ids: list[str] | None,
|
||||
) -> SimpleNamespace:
|
||||
checked = calendar_ids or ["primary"]
|
||||
busy_slots = [_Slot(calendar_id=checked[0], start=start, end=end)]
|
||||
return SimpleNamespace(
|
||||
start=start,
|
||||
end=end,
|
||||
available=False,
|
||||
busy_slots=busy_slots,
|
||||
checked_calendars=checked,
|
||||
)
|
||||
|
||||
|
||||
class _AllowAuthBackend:
|
||||
def authenticate(
|
||||
self,
|
||||
*,
|
||||
x_api_key: str | None,
|
||||
authorization: str | None,
|
||||
required_scopes: set[str],
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def test_availability_parity_rest_a2a_mcp(monkeypatch) -> None:
|
||||
dummy_core = _DummyCoreService()
|
||||
allow_auth = _AllowAuthBackend()
|
||||
|
||||
monkeypatch.setattr(main_module, "core_service", dummy_core)
|
||||
monkeypatch.setattr(a2a_module, "core_service", dummy_core)
|
||||
monkeypatch.setattr(mcp_tools_module, "core_service", dummy_core)
|
||||
monkeypatch.setattr(a2a_module, "auth_backend", allow_auth)
|
||||
monkeypatch.setattr(mcp_tools_module, "auth_backend", allow_auth)
|
||||
|
||||
rest_response = asyncio.run(
|
||||
main_module.availability(
|
||||
main_module.AvailabilityRequest(
|
||||
start="2026-03-10T09:00:00+01:00",
|
||||
end="2026-03-10T10:00:00+01:00",
|
||||
calendar_ids=["primary"],
|
||||
)
|
||||
)
|
||||
).model_dump()
|
||||
|
||||
a2a_response = a2a_module.a2a_rpc(
|
||||
A2ARpcRequest(
|
||||
jsonrpc="2.0",
|
||||
id="req-1",
|
||||
method="SendMessage",
|
||||
params={
|
||||
"start": "2026-03-10T09:00:00+01:00",
|
||||
"end": "2026-03-10T10:00:00+01:00",
|
||||
"calendar_ids": ["primary"],
|
||||
},
|
||||
),
|
||||
Response(),
|
||||
)
|
||||
assert a2a_response.error is None
|
||||
assert a2a_response.result is not None
|
||||
a2a_payload = a2a_response.result["availability"]
|
||||
|
||||
mcp_payload = mcp_server_module.check_availability(
|
||||
start="2026-03-10T09:00:00+01:00",
|
||||
end="2026-03-10T10:00:00+01:00",
|
||||
calendar_ids=["primary"],
|
||||
ctx=None,
|
||||
)
|
||||
|
||||
assert rest_response == a2a_payload == mcp_payload
|
||||
@ -0,0 +1,100 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import replace
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
import app.mcp.tools as mcp_tools_module
|
||||
from app.config import get_settings
|
||||
from app.security.auth import AuthBackend
|
||||
|
||||
|
||||
class _DummyCoreService:
|
||||
def check_availability(
|
||||
self,
|
||||
start: str,
|
||||
end: str,
|
||||
calendar_ids: list[str] | None,
|
||||
) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
start=start,
|
||||
end=end,
|
||||
available=True,
|
||||
busy_slots=[],
|
||||
checked_calendars=calendar_ids or ["primary"],
|
||||
)
|
||||
|
||||
def scan_mailbox(self, max_results: int) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
scanned=max_results,
|
||||
linkedin=0,
|
||||
advertising=0,
|
||||
veille_techno=0,
|
||||
skipped=0,
|
||||
failed=0,
|
||||
)
|
||||
|
||||
|
||||
class _DummyCtx:
|
||||
def __init__(self, headers: dict[str, str]) -> None:
|
||||
self.request_context = SimpleNamespace(
|
||||
request=SimpleNamespace(headers=headers)
|
||||
)
|
||||
|
||||
|
||||
def test_mcp_check_availability_requires_auth(monkeypatch) -> None:
|
||||
auth_settings = replace(
|
||||
get_settings(),
|
||||
auth_mode="api_key",
|
||||
agent_api_key="mcp-key",
|
||||
auth_jwt_secret="",
|
||||
)
|
||||
monkeypatch.setattr(mcp_tools_module, "auth_backend", AuthBackend(auth_settings))
|
||||
monkeypatch.setattr(mcp_tools_module, "core_service", _DummyCoreService())
|
||||
|
||||
with pytest.raises(PermissionError):
|
||||
mcp_tools_module.check_availability(
|
||||
start="2026-03-10T09:00:00+01:00",
|
||||
end="2026-03-10T10:00:00+01:00",
|
||||
calendar_ids=["primary"],
|
||||
ctx=_DummyCtx(headers={}),
|
||||
)
|
||||
|
||||
|
||||
def test_mcp_check_availability_with_api_key(monkeypatch) -> None:
|
||||
auth_settings = replace(
|
||||
get_settings(),
|
||||
auth_mode="api_key",
|
||||
agent_api_key="mcp-key",
|
||||
auth_jwt_secret="",
|
||||
)
|
||||
monkeypatch.setattr(mcp_tools_module, "auth_backend", AuthBackend(auth_settings))
|
||||
monkeypatch.setattr(mcp_tools_module, "core_service", _DummyCoreService())
|
||||
|
||||
payload = mcp_tools_module.check_availability(
|
||||
start="2026-03-10T09:00:00+01:00",
|
||||
end="2026-03-10T10:00:00+01:00",
|
||||
calendar_ids=["primary"],
|
||||
ctx=_DummyCtx(headers={"x-api-key": "mcp-key"}),
|
||||
)
|
||||
|
||||
assert payload["available"] is True
|
||||
assert payload["checked_calendars"] == ["primary"]
|
||||
|
||||
|
||||
def test_mcp_scan_mailbox_requires_mail_scan_scope(monkeypatch) -> None:
|
||||
auth_settings = replace(
|
||||
get_settings(),
|
||||
auth_mode="api_key",
|
||||
agent_api_key="mcp-key",
|
||||
auth_jwt_secret="",
|
||||
)
|
||||
monkeypatch.setattr(mcp_tools_module, "auth_backend", AuthBackend(auth_settings))
|
||||
monkeypatch.setattr(mcp_tools_module, "core_service", _DummyCoreService())
|
||||
|
||||
payload = mcp_tools_module.scan_mailbox(
|
||||
max_results=10,
|
||||
ctx=_DummyCtx(headers={"x-api-key": "mcp-key"}),
|
||||
)
|
||||
assert payload["scanned"] == 10
|
||||
@ -0,0 +1,60 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
|
||||
import app.main as main_module
|
||||
|
||||
|
||||
class _Slot(dict):
|
||||
def __getattr__(self, item: str) -> str:
|
||||
return self[item]
|
||||
|
||||
|
||||
class _DummyCoreService:
|
||||
def check_availability(
|
||||
self,
|
||||
start: str,
|
||||
end: str,
|
||||
calendar_ids: list[str] | None,
|
||||
) -> SimpleNamespace:
|
||||
checked = calendar_ids or ["primary"]
|
||||
busy_slots = [
|
||||
_Slot(
|
||||
calendar_id=checked[0],
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
]
|
||||
return SimpleNamespace(
|
||||
start=start,
|
||||
end=end,
|
||||
available=False,
|
||||
busy_slots=busy_slots,
|
||||
checked_calendars=checked,
|
||||
)
|
||||
|
||||
|
||||
def test_rest_availability_adapter_returns_expected_payload(monkeypatch) -> None:
|
||||
monkeypatch.setattr(main_module, "core_service", _DummyCoreService())
|
||||
|
||||
response = asyncio.run(
|
||||
main_module.availability(
|
||||
main_module.AvailabilityRequest(
|
||||
start="2026-03-10T09:00:00+01:00",
|
||||
end="2026-03-10T10:00:00+01:00",
|
||||
calendar_ids=["primary"],
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
payload = response.model_dump()
|
||||
assert payload["available"] is False
|
||||
assert payload["checked_calendars"] == ["primary"]
|
||||
assert payload["busy_slots"] == [
|
||||
{
|
||||
"calendar_id": "primary",
|
||||
"start": "2026-03-10T09:00:00+01:00",
|
||||
"end": "2026-03-10T10:00:00+01:00",
|
||||
}
|
||||
]
|
||||
@ -0,0 +1,107 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from app.unsubscribe_agent import UnsubscribeDigestAgent
|
||||
|
||||
|
||||
def _b64url_text(value: str) -> str:
|
||||
return base64.urlsafe_b64encode(value.encode("utf-8")).decode("utf-8").rstrip("=")
|
||||
|
||||
|
||||
class _Executable:
|
||||
def __init__(self, callback):
|
||||
self._callback = callback
|
||||
|
||||
def execute(self): # type: ignore[no-untyped-def]
|
||||
return self._callback()
|
||||
|
||||
|
||||
class _FakeMessagesApi:
|
||||
def __init__(self, message_payload_by_id: dict[str, dict[str, Any]]) -> None:
|
||||
self._message_payload_by_id = message_payload_by_id
|
||||
self.sent_messages: list[dict[str, Any]] = []
|
||||
|
||||
def list(self, userId: str, q: str, maxResults: int): # type: ignore[no-untyped-def]
|
||||
message_ids = [{"id": key} for key in self._message_payload_by_id.keys()]
|
||||
return _Executable(lambda: {"messages": message_ids[:maxResults]})
|
||||
|
||||
def get(self, userId: str, id: str, format: str): # type: ignore[no-untyped-def]
|
||||
return _Executable(lambda: self._message_payload_by_id[id])
|
||||
|
||||
def send(self, userId: str, body: dict[str, Any]): # type: ignore[no-untyped-def]
|
||||
self.sent_messages.append(body)
|
||||
return _Executable(lambda: {"id": "sent-1"})
|
||||
|
||||
|
||||
class _FakeUsersApi:
|
||||
def __init__(self, messages_api: _FakeMessagesApi) -> None:
|
||||
self._messages_api = messages_api
|
||||
|
||||
def messages(self) -> _FakeMessagesApi:
|
||||
return self._messages_api
|
||||
|
||||
def getProfile(self, userId: str): # type: ignore[no-untyped-def]
|
||||
return _Executable(lambda: {"emailAddress": "owner@example.com"})
|
||||
|
||||
|
||||
class _FakeGmailService:
|
||||
def __init__(self, payload_by_id: dict[str, dict[str, Any]]) -> None:
|
||||
self.messages_api = _FakeMessagesApi(payload_by_id)
|
||||
self.users_api = _FakeUsersApi(self.messages_api)
|
||||
|
||||
def users(self) -> _FakeUsersApi:
|
||||
return self.users_api
|
||||
|
||||
|
||||
def test_unsubscribe_digest_deduplicates_and_persists_state(tmp_path: Path) -> None:
|
||||
unsubscribe_url_1 = "https://example.com/unsubscribe?u=abc&utm_source=mail"
|
||||
unsubscribe_url_2 = "https://example.com/unsubscribe?fbclid=tracking&u=abc"
|
||||
|
||||
message_payloads = {
|
||||
"m1": {
|
||||
"payload": {
|
||||
"headers": [
|
||||
{"name": "List-Unsubscribe", "value": f"<{unsubscribe_url_1}>"},
|
||||
],
|
||||
"mimeType": "text/plain",
|
||||
"body": {"data": _b64url_text(f"Unsubscribe here: {unsubscribe_url_1}")},
|
||||
}
|
||||
},
|
||||
"m2": {
|
||||
"payload": {
|
||||
"headers": [],
|
||||
"mimeType": "text/plain",
|
||||
"body": {"data": _b64url_text(f"Click to unsubscribe: {unsubscribe_url_2}")},
|
||||
}
|
||||
},
|
||||
}
|
||||
state_file = tmp_path / "sent_links.json"
|
||||
service = _FakeGmailService(message_payloads)
|
||||
agent = UnsubscribeDigestAgent(
|
||||
gmail_service=service,
|
||||
query="label:Advertising",
|
||||
state_file=str(state_file),
|
||||
recipient_email="owner@example.com",
|
||||
send_empty_digest=False,
|
||||
)
|
||||
|
||||
first = agent.scan_and_send_digest(max_results=50)
|
||||
second = agent.scan_and_send_digest(max_results=50)
|
||||
|
||||
assert first.scanned_messages == 2
|
||||
assert first.extracted_unique_links == 1
|
||||
assert first.new_links == 1
|
||||
assert first.email_sent is True
|
||||
|
||||
assert second.scanned_messages == 2
|
||||
assert second.extracted_unique_links == 1
|
||||
assert second.new_links == 0
|
||||
assert second.email_sent is False
|
||||
|
||||
assert len(service.messages_api.sent_messages) == 1
|
||||
persisted = json.loads(state_file.read_text(encoding="utf-8"))
|
||||
assert persisted["sent_links"] == ["https://example.com/unsubscribe?u=abc"]
|
||||
Loading…
Reference in New Issue