"""Classify emails into a small fixed label set via a chat-completions API."""

from __future__ import annotations

from dataclasses import dataclass
import json
import logging
import math
import re

from openai import OpenAI

logger = logging.getLogger("personal-agent.llm")

# Labels the model is allowed to return; anything else is coerced to "OTHER".
ALLOWED_LABELS = {"LINKEDIN", "ADVERTISING", "OTHER"}

# NOTE(review): the prompt text is reproduced character-for-character as a
# single space-separated line; confirm whether line breaks were intended.
SYSTEM_PROMPT = (
    "You classify incoming emails into exactly one label: "
    "- LINKEDIN: official LinkedIn platform emails "
    "(job alerts, invites, network updates, LinkedIn newsletters). "
    "- ADVERTISING: marketing/promotional/sales emails, newsletters, "
    "coupons, deals, brand campaigns. "
    "- OTHER: anything else. "
    "Rules: "
    "1) If sender/content clearly belongs to LinkedIn, "
    "choose LINKEDIN even if promotional. "
    "2) If uncertain between ADVERTISING and OTHER, choose OTHER. "
    "3) Return only JSON with this schema: "
    '{"label":"LINKEDIN|ADVERTISING|OTHER","confidence":0.0-1.0,'
    '"reason":"short reason"}'
)


@dataclass(frozen=True)
class LLMClassification:
    """Immutable result of one classification request."""

    # One of ALLOWED_LABELS.
    label: str
    # Clamped to the closed range [0.0, 1.0].
    confidence: float
    # Short free-text explanation from the model; may be empty.
    reason: str


class LLMEmailClassifier:
    """Email classifier backed by the ``openai`` client's chat completions."""

    def __init__(
        self,
        *,
        api_key: str,
        model: str,
        base_url: str | None = None,
        timeout_seconds: float = 20.0,
    ) -> None:
        """Create the underlying API client.

        Args:
            api_key: Credential passed to the client; must be non-empty.
            model: Model name sent with every request.
            base_url: Optional endpoint override passed to the client.
            timeout_seconds: Passed to the client as its ``timeout``.

        Raises:
            ValueError: If ``api_key`` is empty.
        """
        if not api_key:
            raise ValueError("LLM API key is required for LLM classification.")
        self.model = model
        self.client = OpenAI(
            api_key=api_key,
            base_url=base_url,
            timeout=timeout_seconds,
        )

    def classify(
        self,
        *,
        sender: str,
        subject: str,
        snippet: str,
        list_unsubscribe: str,
        precedence: str,
        message_label_ids: set[str],
    ) -> LLMClassification:
        """Ask the model to label one email and normalise its answer.

        The email fields are serialised to JSON and sent as the user message.
        Only the presence of ``list_unsubscribe`` is forwarded, not its value,
        and ``message_label_ids`` is sent sorted under ``gmail_label_ids``.

        A reply that cannot be parsed into a JSON object yields label
        ``"OTHER"`` with confidence ``0.0``; a label outside ALLOWED_LABELS is
        logged and replaced by ``"OTHER"``. Exceptions raised by the client
        call itself are not caught here.

        Returns:
            The normalised classification.
        """
        email_payload = {
            "sender": sender,
            "subject": subject,
            "snippet": snippet,
            "list_unsubscribe_present": bool(list_unsubscribe.strip()),
            "precedence": precedence,
            "gmail_label_ids": sorted(message_label_ids),
        }

        completion = self.client.chat.completions.create(
            model=self.model,
            temperature=0,
            response_format={"type": "json_object"},
            max_tokens=120,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {
                    "role": "user",
                    "content": json.dumps(email_payload, ensure_ascii=True),
                },
            ],
        )

        content = completion.choices[0].message.content or "{}"
        parsed = _parse_json(content)

        label = str(parsed.get("label", "OTHER")).upper().strip()
        if label not in ALLOWED_LABELS:
            logger.warning(
                "Unexpected LLM label '%s', falling back to OTHER.", label
            )
            label = "OTHER"

        confidence = _to_confidence(parsed.get("confidence", 0.0))

        # A JSON null must not surface as the literal text "None".
        raw_reason = parsed.get("reason", "")
        reason = "" if raw_reason is None else str(raw_reason).strip()

        return LLMClassification(
            label=label, confidence=confidence, reason=reason
        )


def _parse_json(content: str) -> dict:
    """Best-effort extraction of a JSON object from the model's reply.

    Tries to parse ``content`` as a whole; if that fails, retries on the span
    from the first ``{`` to the last ``}``. Returns ``{}`` when nothing
    parses or when the parsed value is not a JSON object, so callers can
    always use ``.get``.
    """
    if not content:
        return {}

    try:
        parsed = json.loads(content)
    except json.JSONDecodeError:
        # The reply may wrap the object in extra text; grab the outermost
        # brace-delimited span and try again.
        match = re.search(r"\{.*\}", content, re.DOTALL)
        if not match:
            return {}
        try:
            parsed = json.loads(match.group(0))
        except json.JSONDecodeError:
            return {}

    # Valid JSON is not necessarily an object (e.g. a list, string or null).
    return parsed if isinstance(parsed, dict) else {}


def _to_confidence(raw_value: object) -> float:
    """Convert ``raw_value`` to a float clamped to ``[0.0, 1.0]``.

    Values that cannot be converted, overflow, or are NaN map to ``0.0``.
    """
    try:
        confidence = float(raw_value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: integers too large to represent as a float.
        return 0.0

    # NaN fails both range comparisons below, so reject it explicitly.
    if math.isnan(confidence):
        return 0.0
    if confidence < 0:
        return 0.0
    if confidence > 1:
        return 1.0
    return confidence