You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

121 lines
3.4 KiB
Python

from __future__ import annotations

import json
import logging
import math
import re
from dataclasses import dataclass

from openai import OpenAI
# Logger for everything this module reports (e.g. unexpected model labels).
logger = logging.getLogger("personal-agent.llm")

# Closed vocabulary of labels; LLMEmailClassifier.classify replaces any other
# model-supplied label with "OTHER".
ALLOWED_LABELS = {"LINKEDIN", "ADVERTISING", "OTHER"}

# System instructions sent ahead of every email. Keep the word "JSON" in it:
# classify() requests response_format={"type": "json_object"}, and JSON mode
# expects the messages themselves to ask for JSON output.
SYSTEM_PROMPT = "\n".join(
    (
        "You classify incoming emails into exactly one label:",
        "- LINKEDIN: official LinkedIn platform emails (job alerts, invites, network updates, LinkedIn newsletters).",
        "- ADVERTISING: marketing/promotional/sales emails, newsletters, coupons, deals, brand campaigns.",
        "- OTHER: anything else.",
        "Rules:",
        "1) If sender/content clearly belongs to LinkedIn, choose LINKEDIN even if promotional.",
        "2) If uncertain between ADVERTISING and OTHER, choose OTHER.",
        "3) Return only JSON with this schema:",
        '{"label":"LINKEDIN|ADVERTISING|OTHER","confidence":0.0-1.0,"reason":"short reason"}',
    )
)
@dataclass(frozen=True)
class LLMClassification:
    """Immutable outcome of classifying a single email with the LLM."""

    # Label chosen for the email; LLMEmailClassifier.classify only ever
    # produces values from ALLOWED_LABELS.
    label: str
    # Model-reported confidence after conversion to float.
    confidence: float
    # Short free-text justification from the model ("" when the key is absent).
    reason: str
class LLMEmailClassifier:
    """Label emails as LINKEDIN / ADVERTISING / OTHER using an OpenAI chat model."""

    def __init__(
        self,
        *,
        api_key: str,
        model: str,
        base_url: str | None = None,
        timeout_seconds: float = 20.0,
        max_tokens: int = 120,
    ) -> None:
        """Create the chat client used for every classification.

        Args:
            api_key: Key handed to the ``OpenAI`` client; must be non-empty.
            model: Model name sent with each completion request.
            base_url: Optional endpoint override, passed to the client unchanged.
            timeout_seconds: Request timeout, passed to the client as ``timeout``.
            max_tokens: Completion-token cap per request. Defaults to 120 (the
                value that used to be hard-coded); raise it if replies are
                being truncated.

        Raises:
            ValueError: If ``api_key`` is empty.
        """
        if not api_key:
            raise ValueError("LLM API key is required for LLM classification.")
        self.model = model
        self.max_tokens = max_tokens
        self.client = OpenAI(
            api_key=api_key,
            base_url=base_url,
            timeout=timeout_seconds,
        )

    def classify(
        self,
        *,
        sender: str,
        subject: str,
        snippet: str,
        list_unsubscribe: str,
        precedence: str,
        message_label_ids: set[str],
    ) -> LLMClassification:
        """Ask the model for a label and normalize its JSON reply.

        Args:
            sender: Sent to the model under ``"sender"``.
            subject: Sent to the model under ``"subject"``.
            snippet: Sent to the model under ``"snippet"``.
            list_unsubscribe: Only whether it is non-blank is sent
                (``"list_unsubscribe_present"``).
            precedence: Sent to the model under ``"precedence"``.
            message_label_ids: Sent sorted under ``"gmail_label_ids"``.

        Returns:
            An ``LLMClassification``. A label outside ``ALLOWED_LABELS`` is
            logged and replaced by ``"OTHER"``; a missing or null reason
            becomes ``""``.
        """
        email_payload = {
            "sender": sender,
            "subject": subject,
            "snippet": snippet,
            "list_unsubscribe_present": bool(list_unsubscribe.strip()),
            "precedence": precedence,
            # sorted() turns the set into a JSON-serializable list with a
            # deterministic order.
            "gmail_label_ids": sorted(message_label_ids),
        }
        completion = self.client.chat.completions.create(
            model=self.model,
            temperature=0,
            response_format={"type": "json_object"},
            max_tokens=self.max_tokens,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": json.dumps(email_payload, ensure_ascii=True)},
            ],
        )
        choice = completion.choices[0]
        if choice.finish_reason == "length":
            # A reply cut off at the token cap is usually invalid JSON, which
            # would otherwise degrade silently to OTHER with zero confidence.
            logger.warning(
                "LLM reply hit max_tokens=%s and may be truncated.", self.max_tokens
            )
        content = choice.message.content or "{}"
        parsed = _parse_json(content)
        label = str(parsed.get("label", "OTHER")).upper().strip()
        if label not in ALLOWED_LABELS:
            logger.warning("Unexpected LLM label '%s', falling back to OTHER.", label)
            label = "OTHER"
        confidence = _to_confidence(parsed.get("confidence", 0.0))
        raw_reason = parsed.get("reason")
        # A JSON null must not surface as the literal text "None".
        reason = "" if raw_reason is None else str(raw_reason).strip()
        return LLMClassification(label=label, confidence=confidence, reason=reason)
def _parse_json(content: str) -> dict:
if not content:
return {}
try:
return json.loads(content)
except json.JSONDecodeError:
match = re.search(r"\{.*\}", content, re.DOTALL)
if not match:
return {}
try:
return json.loads(match.group(0))
except json.JSONDecodeError:
return {}
def _to_confidence(raw_value: object) -> float:
try:
confidence = float(raw_value)
except (TypeError, ValueError):
return 0.0
if confidence < 0:
return 0.0
if confidence > 1:
return 1.0
return confidence