Added openai classifier
parent
a14b02ad3c
commit
d9c7497acb
@ -1,6 +1,11 @@
|
|||||||
GOOGLE_CLIENT_SECRETS_FILE=credentials.json
|
GOOGLE_CLIENT_SECRETS_FILE=credentials.json
|
||||||
GOOGLE_TOKEN_FILE=token.json
|
GOOGLE_TOKEN_FILE=token.json
|
||||||
AGENT_API_KEY=change-me
|
AGENT_API_KEY=change-me
|
||||||
|
LLM_API_KEY=
|
||||||
|
LLM_MODEL=gpt-4.1-mini
|
||||||
|
LLM_BASE_URL=
|
||||||
|
LLM_TIMEOUT_SECONDS=20
|
||||||
|
LLM_FALLBACK_TO_RULES=false
|
||||||
GMAIL_SCAN_INTERVAL_MINUTES=5
|
GMAIL_SCAN_INTERVAL_MINUTES=5
|
||||||
GMAIL_QUERY=in:inbox -label:AgentProcessed newer_than:7d
|
GMAIL_QUERY=in:inbox -label:AgentProcessed newer_than:7d
|
||||||
LOG_LEVEL=INFO
|
LOG_LEVEL=INFO
|
||||||
|
|||||||
@ -0,0 +1,120 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
# Module-level logger for the LLM classifier.
logger = logging.getLogger("personal-agent.llm")

# The complete set of labels a classification may carry. Any other label
# returned by the model is replaced with "OTHER" by the classifier.
ALLOWED_LABELS = {"LINKEDIN", "ADVERTISING", "OTHER"}

# System message sent with every classification request. It defines the three
# labels, the tie-breaking rules, and the JSON schema the reply must follow.
SYSTEM_PROMPT = """You classify incoming emails into exactly one label:
- LINKEDIN: official LinkedIn platform emails (job alerts, invites, network updates, LinkedIn newsletters).
- ADVERTISING: marketing/promotional/sales emails, newsletters, coupons, deals, brand campaigns.
- OTHER: anything else.

Rules:
1) If sender/content clearly belongs to LinkedIn, choose LINKEDIN even if promotional.
2) If uncertain between ADVERTISING and OTHER, choose OTHER.
3) Return only JSON with this schema:
{"label":"LINKEDIN|ADVERTISING|OTHER","confidence":0.0-1.0,"reason":"short reason"}"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class LLMClassification:
    """Immutable result of classifying a single email with the LLM."""

    # One of the labels in ALLOWED_LABELS.
    label: str
    # Model-reported confidence, clamped to the range [0.0, 1.0].
    confidence: float
    # Short free-text justification from the model; may be empty.
    reason: str
|
||||||
|
|
||||||
|
|
||||||
|
class LLMEmailClassifier:
    """Classify emails as LINKEDIN, ADVERTISING or OTHER via a chat-completions API.

    Each call sends the email's metadata as a JSON user message together with
    ``SYSTEM_PROMPT`` and turns the model's JSON reply into an
    ``LLMClassification``. Malformed or unexpected replies degrade to the
    label ``"OTHER"`` with confidence ``0.0`` instead of raising.
    """

    def __init__(
        self,
        *,
        api_key: str,
        model: str,
        base_url: str | None = None,
        timeout_seconds: float = 20.0,
    ) -> None:
        """Create the classifier and its underlying OpenAI client.

        Args:
            api_key: API key passed to the OpenAI client; must be non-empty.
            model: Model name used for every classification request.
            base_url: Optional alternative API endpoint. ``None`` or an empty
                string selects the client's default endpoint.
            timeout_seconds: Request timeout handed to the client.

        Raises:
            ValueError: If ``api_key`` is empty.
        """
        if not api_key:
            raise ValueError("LLM API key is required for LLM classification.")

        self.model = model
        self.client = OpenAI(
            api_key=api_key,
            # An unset env var typically arrives as "" rather than None; the
            # client only applies its default endpoint when it receives None.
            base_url=base_url or None,
            timeout=timeout_seconds,
        )

    def classify(
        self,
        *,
        sender: str,
        subject: str,
        snippet: str,
        list_unsubscribe: str,
        precedence: str,
        message_label_ids: set[str],
    ) -> LLMClassification:
        """Classify one email from its headers, snippet and Gmail labels.

        Args:
            sender: Value of the sender header.
            subject: Email subject line.
            snippet: Short preview of the message body.
            list_unsubscribe: Raw List-Unsubscribe header value. Only whether
                it is non-blank is sent to the model.
            precedence: Value of the Precedence header.
            message_label_ids: Gmail label ids attached to the message; sent
                in sorted order so the request payload is deterministic.

        Returns:
            The classification. The label is always a member of
            ``ALLOWED_LABELS`` and the confidence is within ``[0.0, 1.0]``.

        Note:
            Errors raised by the API call itself are not caught here and
            propagate to the caller.
        """
        email_payload = {
            "sender": sender,
            "subject": subject,
            "snippet": snippet,
            # Tolerate a missing header (None) as "not present".
            "list_unsubscribe_present": bool((list_unsubscribe or "").strip()),
            "precedence": precedence,
            "gmail_label_ids": sorted(message_label_ids),
        }

        completion = self.client.chat.completions.create(
            model=self.model,
            temperature=0,
            response_format={"type": "json_object"},
            max_tokens=120,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": json.dumps(email_payload, ensure_ascii=True)},
            ],
        )

        # A response without choices is handled like one without content:
        # both fall through to the empty-object default below.
        choices = completion.choices
        content = (choices[0].message.content if choices else None) or "{}"

        parsed = _parse_json(content)
        if not isinstance(parsed, dict):
            # Valid JSON that is not an object (e.g. a list or bare string)
            # carries none of the expected fields.
            logger.warning("LLM reply was not a JSON object, falling back to OTHER.")
            parsed = {}

        label = str(parsed.get("label", "OTHER")).upper().strip()
        if label not in ALLOWED_LABELS:
            logger.warning("Unexpected LLM label '%s', falling back to OTHER.", label)
            label = "OTHER"

        confidence = _to_confidence(parsed.get("confidence", 0.0))

        # A JSON null reason must become "", not the literal string "None".
        raw_reason = parsed.get("reason")
        reason = "" if raw_reason is None else str(raw_reason).strip()

        return LLMClassification(label=label, confidence=confidence, reason=reason)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_json(content: str) -> dict:
    """Decode a model reply into a JSON object, tolerating surrounding text.

    The reply is first decoded as-is. If that fails, the span from the first
    ``{`` to the last ``}`` is extracted and decoded instead, which recovers
    an object wrapped in prose or a Markdown code fence.

    Args:
        content: Raw text returned by the model.

    Returns:
        The decoded object, or an empty dict when ``content`` is empty, cannot
        be decoded, or decodes to anything other than a JSON object.
    """
    if not content:
        return {}

    try:
        decoded = json.loads(content)
    except json.JSONDecodeError:
        # Greedy match with DOTALL: widest brace-delimited span, across lines.
        embedded = re.search(r"\{.*\}", content, re.DOTALL)
        if embedded is None:
            return {}
        try:
            decoded = json.loads(embedded.group(0))
        except json.JSONDecodeError:
            return {}

    # Valid JSON may be a list, string, number or null; callers rely on
    # dict methods, so only an object is passed through.
    return decoded if isinstance(decoded, dict) else {}
|
||||||
|
|
||||||
|
|
||||||
|
def _to_confidence(raw_value: object) -> float:
    """Coerce an arbitrary value into a confidence within ``[0.0, 1.0]``.

    Args:
        raw_value: Value taken from the model's reply; may be a number, a
            numeric string, or anything else.

    Returns:
        ``raw_value`` as a float, clamped to ``[0.0, 1.0]``. Values that
        cannot be converted to a float (wrong type, non-numeric text, an
        integer too large for a float) and NaN yield ``0.0``.
    """
    try:
        confidence = float(raw_value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: an integer too large to represent as a float.
        return 0.0

    if 0.0 <= confidence <= 1.0:
        return confidence
    # Out of range or NaN. Every comparison with NaN is false, so NaN
    # lands on the conservative 0.0 branch instead of leaking through.
    return 1.0 if confidence > 1.0 else 0.0
|
||||||
Loading…
Reference in New Issue