You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

316 lines
9.7 KiB
Python

from __future__ import annotations

import asyncio
import hmac
import logging
from datetime import datetime
from typing import Annotated

from apscheduler.schedulers.asyncio import AsyncIOScheduler
from fastapi import Depends, FastAPI, Header, HTTPException, Query, status
from pydantic import BaseModel

from app.calendar_agent import CalendarAvailabilityAgent
from app.config import get_settings
from app.gmail_agent import GmailTriageAgent
from app.google_clients import build_calendar_service, build_gmail_service
from app.llm_classifier import LLMEmailClassifier
from app.unsubscribe_agent import UnsubscribeDigestAgent
# Configuration is loaded once at import time and shared by every handler.
settings = get_settings()

# An unrecognised LOG_LEVEL name falls back to INFO instead of failing.
_root_log_level = getattr(logging, settings.log_level.upper(), logging.INFO)
logging.basicConfig(level=_root_log_level)
logger = logging.getLogger("personal-agent")

app = FastAPI(title="Personal Agent", version="0.1.0")

# Process-wide state. The scheduler is created in startup_event; the locks
# are created lazily by their _get_*_lock accessors.
scheduler: AsyncIOScheduler | None = None
scan_lock: asyncio.Lock | None = None
unsubscribe_lock: asyncio.Lock | None = None

# Set once the "LLM_API_KEY not set" warning has been logged so that it is
# not repeated on every scan.
llm_key_warning_logged = False
# Counters from one Gmail triage pass (POST /scan and the scheduled scan).
# Field order is kept as-is because it determines the JSON/model_dump order.
class ScanResponse(BaseModel):
    scanned: int  # messages examined in this pass
    linkedin: int  # messages the triage agent counted as LinkedIn
    advertising: int  # messages the triage agent counted as advertising
    skipped: int  # messages the agent reported as skipped
    failed: int  # messages the agent reported as failed
# Request body for POST /availability. start/end are passed through as raw
# strings to CalendarAvailabilityAgent.get_availability, which rejects bad
# input with ValueError (mapped to HTTP 400 by the endpoint).
class AvailabilityRequest(BaseModel):
    start: str
    end: str
    # None lets the agent pick its own calendar set.
    calendar_ids: list[str] | None = None
# One busy interval on one calendar, as reported in AvailabilityResponse.
class BusySlot(BaseModel):
    calendar_id: str  # calendar the busy interval belongs to
    start: str  # interval start, as a string supplied by the agent
    end: str  # interval end, as a string supplied by the agent
# Response body for POST /availability, copied field-by-field from the
# result of CalendarAvailabilityAgent.get_availability.
class AvailabilityResponse(BaseModel):
    start: str
    end: str
    available: bool  # whether the agent found the window free
    busy_slots: list[BusySlot]  # conflicting intervals, if any
    checked_calendars: list[str]  # calendars actually queried
# Summary of one unsubscribe-digest run (POST /unsubscribe-digest and the
# scheduled job), copied from UnsubscribeDigestAgent.scan_and_send_digest.
class UnsubscribeDigestResponse(BaseModel):
    scanned_messages: int  # messages matched by the unsubscribe query
    extracted_unique_links: int  # distinct links extracted in this run
    new_links: int  # links the agent reported as new
    sent_to: str | None  # digest recipient, or None if nothing was sent
    email_sent: bool  # whether a digest email went out
def verify_api_key(
    x_api_key: Annotated[str | None, Header(alias="X-API-Key")] = None,
    authorization: Annotated[str | None, Header()] = None,
) -> None:
    """FastAPI dependency that enforces the configured API key.

    The key is read from the ``X-API-Key`` header. When that header is
    missing or empty, it is read from ``Authorization: Bearer <key>``, with
    the scheme matched case-insensitively. If ``settings.agent_api_key`` is
    empty or unset, authentication is disabled and every request passes.

    Raises:
        HTTPException: 401 if no key was supplied or it does not match.
    """
    expected = settings.agent_api_key
    if not expected:
        return

    provided = x_api_key
    if not provided and authorization:
        parts = authorization.split(" ", 1)
        if len(parts) == 2 and parts[0].lower() == "bearer":
            provided = parts[1]

    # Compare in constant time so response latency does not reveal how much
    # of a guessed key was correct. Encoding to bytes stops compare_digest
    # from raising TypeError on non-ASCII header values.
    if provided is None or not hmac.compare_digest(
        provided.encode("utf-8"), expected.encode("utf-8")
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key."
        )
def _run_scan_once(max_results: int) -> ScanResponse:
    """Run one blocking Gmail triage pass and return its counters.

    A fresh Gmail client, classifier and GmailTriageAgent are built on every
    call. Callers run this via ``asyncio.to_thread`` so the blocking Gmail
    work stays off the event loop.
    """
    triage_agent = GmailTriageAgent(
        gmail_service=build_gmail_service(settings),
        query=settings.gmail_query,
        classifier=_build_llm_classifier(),
        fallback_to_rules=settings.llm_fallback_to_rules,
    )
    outcome = triage_agent.scan_and_route_messages(max_results=max_results)

    counters = {
        "scanned": outcome.scanned,
        "linkedin": outcome.linkedin,
        "advertising": outcome.advertising,
        "skipped": outcome.skipped,
        "failed": outcome.failed,
    }
    return ScanResponse(**counters)
def _run_unsubscribe_digest_once(max_results: int) -> UnsubscribeDigestResponse:
    """Scan for unsubscribe links, send the digest, and summarise the run.

    ``max_results`` is clamped to the range 1..500 before it reaches the
    agent. Blocking; callers run this via ``asyncio.to_thread``.
    """
    clamped = min(max(max_results, 1), 500)
    digest_agent = UnsubscribeDigestAgent(
        gmail_service=build_gmail_service(settings),
        query=settings.unsubscribe_query,
        state_file=settings.unsubscribe_state_file,
        recipient_email=settings.unsubscribe_digest_recipient,
        send_empty_digest=settings.unsubscribe_send_empty_digest,
    )
    outcome = digest_agent.scan_and_send_digest(max_results=clamped)

    summary = {
        "scanned_messages": outcome.scanned_messages,
        "extracted_unique_links": outcome.extracted_unique_links,
        "new_links": outcome.new_links,
        "sent_to": outcome.sent_to,
        "email_sent": outcome.email_sent,
    }
    return UnsubscribeDigestResponse(**summary)
def _build_llm_classifier() -> LLMEmailClassifier | None:
    """Return an LLM classifier, or None to use rules-based classification.

    Without ``settings.llm_api_key``: returns None if rules fallback is
    enabled, logging a warning once per process. Otherwise it raises
    RuntimeError. If constructing the classifier raises: with fallback
    enabled, the error is logged and None is returned. Otherwise the original
    exception propagates.
    """
    global llm_key_warning_logged

    if settings.llm_api_key:
        try:
            return LLMEmailClassifier(
                api_key=settings.llm_api_key,
                model=settings.llm_model,
                base_url=settings.llm_base_url,
                timeout_seconds=settings.llm_timeout_seconds,
            )
        except Exception:
            if not settings.llm_fallback_to_rules:
                raise
            logger.exception("Could not initialize LLM classifier; using rules fallback.")
            return None

    if not settings.llm_fallback_to_rules:
        raise RuntimeError(
            "LLM_API_KEY is required when LLM_FALLBACK_TO_RULES is disabled."
        )

    # Only warn the first time; this function runs on every scan.
    if not llm_key_warning_logged:
        logger.warning(
            "LLM_API_KEY not set. Falling back to rules-based classification."
        )
        llm_key_warning_logged = True
    return None
def _get_scan_lock() -> asyncio.Lock:
    """Return the process-wide lock that serialises Gmail scans.

    The lock is created on first call and then reused. Presumably creation is
    deferred so it happens once the app's event loop exists; confirm if the
    supported Python versions change.
    """
    global scan_lock
    lock = scan_lock
    if lock is None:
        lock = scan_lock = asyncio.Lock()
    return lock
def _get_unsubscribe_lock() -> asyncio.Lock:
    """Return the process-wide lock that serialises unsubscribe-digest runs.

    Created on first call and reused afterwards.
    """
    global unsubscribe_lock
    lock = unsubscribe_lock
    if lock is None:
        lock = unsubscribe_lock = asyncio.Lock()
    return lock
async def _scheduled_scan() -> None:
    """Scheduler job: run one triage pass of up to 100 messages.

    Skips the tick if the shared scan lock is already held (by a previous
    tick or a manual POST /scan). Exceptions are logged, not propagated.
    """
    guard = _get_scan_lock()
    if guard.locked():
        logger.info("Previous scan still running, skipping this tick.")
    else:
        async with guard:
            try:
                summary = await asyncio.to_thread(_run_scan_once, 100)
                logger.info("Scheduled scan complete: %s", summary.model_dump())
            except Exception:
                logger.exception("Scheduled scan failed")
async def _scheduled_unsubscribe_digest() -> None:
    """Scheduler job: build and send one unsubscribe digest.

    Uses ``settings.unsubscribe_max_results`` as the message cap. Skips the
    tick while another digest run holds the lock. Exceptions are logged, not
    propagated.
    """
    guard = _get_unsubscribe_lock()
    if guard.locked():
        logger.info("Previous unsubscribe digest still running, skipping this tick.")
    else:
        async with guard:
            try:
                digest = await asyncio.to_thread(
                    _run_unsubscribe_digest_once, settings.unsubscribe_max_results
                )
                logger.info("Scheduled unsubscribe digest complete: %s", digest.model_dump())
            except Exception:
                logger.exception("Scheduled unsubscribe digest failed")
@app.on_event("startup")
async def startup_event() -> None:
    """Create the shared locks and start the background job scheduler."""
    global scheduler

    # Create both locks now, from inside the running event loop.
    _get_scan_lock()
    _get_unsubscribe_lock()

    scheduler = AsyncIOScheduler()
    interval_jobs = (
        (_scheduled_scan, settings.gmail_scan_interval_minutes),
        (_scheduled_unsubscribe_digest, settings.unsubscribe_digest_interval_minutes),
    )
    for job, every_minutes in interval_jobs:
        # next_run_time=now makes the first run happen immediately rather
        # than one full interval after startup.
        scheduler.add_job(
            job,
            "interval",
            minutes=every_minutes,
            next_run_time=datetime.now(),
        )
    scheduler.start()

    logger.info(
        "Scheduler started (scan interval=%s min, unsubscribe interval=%s min)",
        settings.gmail_scan_interval_minutes,
        settings.unsubscribe_digest_interval_minutes,
    )
@app.on_event("shutdown")
async def shutdown_event() -> None:
    """Stop the background scheduler without waiting for in-flight jobs.

    Does nothing if the scheduler was never created or is not running.
    """
    # APScheduler's shutdown() raises SchedulerNotRunningError when start()
    # never ran, e.g. if startup_event failed after creating the scheduler
    # object. Checking `running` keeps shutdown clean in that case and makes
    # repeated shutdown calls no-ops.
    if scheduler is not None and scheduler.running:
        scheduler.shutdown(wait=False)
@app.get("/health")
def health() -> dict[str, object]:
    # Lightweight status endpoint. It has no API-key dependency and echoes the
    # configured job intervals.
    return dict(
        status="ok",
        scan_interval_minutes=settings.gmail_scan_interval_minutes,
        unsubscribe_digest_interval_minutes=settings.unsubscribe_digest_interval_minutes,
    )
@app.post(
    "/scan",
    response_model=ScanResponse,
    dependencies=[Depends(verify_api_key)],
)
async def scan_now(max_results: int = Query(100, ge=1, le=500)) -> ScanResponse:
    # Run one Gmail triage pass on demand. Unlike the scheduled job, this
    # waits for the shared scan lock instead of skipping.
    #
    # Every failure becomes an HTTP 500. FastAPI does not log handled
    # HTTPExceptions, so the traceback is logged here first; otherwise the
    # root cause of a failed manual scan would never reach the server logs.
    async with _get_scan_lock():
        try:
            return await asyncio.to_thread(_run_scan_once, max_results)
        except FileNotFoundError as exc:
            # Presumably a missing credential/token file; the message is
            # returned to the caller unprefixed.
            logger.exception("Manual scan failed")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=str(exc),
            ) from exc
        except Exception as exc:
            logger.exception("Manual scan failed")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Gmail scan failed: {exc}",
            ) from exc
@app.post(
    "/availability",
    response_model=AvailabilityResponse,
    dependencies=[Depends(verify_api_key)],
)
async def availability(request: AvailabilityRequest) -> AvailabilityResponse:
    # Report whether [start, end] is free across the requested calendars.
    # ValueError from the lookup maps to 400; FileNotFoundError and any other
    # error map to 500 and are logged with a traceback first, because FastAPI
    # does not log handled HTTPExceptions.

    def _lookup():
        # Building the Google client is blocking work (its FileNotFoundError
        # handling below suggests it reads credential files). It therefore
        # runs in the worker thread together with the query, as the Gmail
        # paths do, instead of stalling the event loop.
        calendar_service = build_calendar_service(settings)
        availability_agent = CalendarAvailabilityAgent(calendar_service=calendar_service)
        return availability_agent.get_availability(
            request.start,
            request.end,
            request.calendar_ids,
        )

    try:
        result = await asyncio.to_thread(_lookup)
        return AvailabilityResponse(
            start=result.start,
            end=result.end,
            available=result.available,
            busy_slots=result.busy_slots,
            checked_calendars=result.checked_calendars,
        )
    except ValueError as exc:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
    except FileNotFoundError as exc:
        logger.exception("Availability lookup failed")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(exc),
        ) from exc
    except Exception as exc:
        logger.exception("Availability lookup failed")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Availability lookup failed: {exc}",
        ) from exc
@app.post(
    "/unsubscribe-digest",
    response_model=UnsubscribeDigestResponse,
    dependencies=[Depends(verify_api_key)],
)
async def unsubscribe_digest_now(
    max_results: int = Query(default=settings.unsubscribe_max_results, ge=1, le=500),
) -> UnsubscribeDigestResponse:
    # Build and send the unsubscribe digest on demand. Unlike the scheduled
    # job, this waits for the shared lock instead of skipping.
    #
    # Every failure becomes an HTTP 500. The traceback is logged first
    # because FastAPI does not log handled HTTPExceptions.
    async with _get_unsubscribe_lock():
        try:
            return await asyncio.to_thread(_run_unsubscribe_digest_once, max_results)
        except FileNotFoundError as exc:
            logger.exception("Manual unsubscribe digest failed")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=str(exc),
            ) from exc
        except Exception as exc:
            logger.exception("Manual unsubscribe digest failed")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Unsubscribe digest failed: {exc}",
            ) from exc