You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

115 lines
3.4 KiB
Python

from __future__ import annotations

import json
import logging
import math
import re
from dataclasses import dataclass

from strands import Agent
from strands.models.openai import OpenAIModel
# Module logger; uses a fixed name rather than ``__name__``.
logger = logging.getLogger("personal-agent.strands")
# The only labels accepted from the model; StrandsEmailClassifier.classify
# replaces any other label with "OTHER".
ALLOWED_LABELS = {"LINKEDIN", "ADVERTISING", "OTHER"}
# System prompt handed to the Strands Agent. It defines the three labels, the
# tie-breaking rules, and the JSON reply schema that classify() parses.
# NOTE: this text is sent to the model verbatim, so edits change behavior.
SYSTEM_PROMPT = """You classify incoming emails into exactly one label:
- LINKEDIN: official LinkedIn platform emails (job alerts, invites, network updates, LinkedIn newsletters).
- ADVERTISING: marketing/promotional/sales emails, newsletters, coupons, deals, brand campaigns.
- OTHER: anything else.
Rules:
1) If sender/content clearly belongs to LinkedIn, choose LINKEDIN even if promotional.
2) If uncertain between ADVERTISING and OTHER, choose OTHER.
3) Return only JSON with this schema:
{"label":"LINKEDIN|ADVERTISING|OTHER","confidence":0.0-1.0,"reason":"short reason"}"""
@dataclass(frozen=True)
class EmailClassification:
    """Immutable result of classifying a single email."""

    # Category assigned to the email.
    label: str
    # Score attached to the label.
    confidence: float
    # Short free-text justification for the label.
    reason: str
class StrandsEmailClassifier:
    """Classify emails as LINKEDIN, ADVERTISING or OTHER using a Strands agent.

    The agent wraps an OpenAI-compatible chat model and is primed with
    ``SYSTEM_PROMPT``. Each call to :meth:`classify` sends one JSON payload
    describing the email and parses the model's JSON reply.
    """

    def __init__(
        self,
        *,
        api_key: str,
        model_id: str,
        base_url: str | None = None,
        timeout_seconds: float = 20.0,
        temperature: float = 0.0,
    ) -> None:
        """Create the model client and the agent used for classification.

        Args:
            api_key: API key passed to the OpenAI client; must be non-empty.
            model_id: Identifier of the model to use.
            base_url: Optional alternative endpoint; omitted from the client
                arguments when empty or ``None``.
            timeout_seconds: Value passed as the client's ``timeout``.
            temperature: Sampling temperature passed in the model params.

        Raises:
            ValueError: If ``api_key`` is empty.
        """
        if not api_key:
            raise ValueError("Strands/OpenAI API key is required for classification.")
        client_args: dict[str, object] = {
            "api_key": api_key,
            "timeout": timeout_seconds,
        }
        if base_url:
            client_args["base_url"] = base_url
        model = OpenAIModel(
            client_args=client_args,
            model_id=model_id,
            params={"temperature": temperature},
        )
        self.agent = Agent(model=model, system_prompt=SYSTEM_PROMPT)

    def classify(
        self,
        *,
        sender: str,
        subject: str,
        snippet: str,
        list_unsubscribe: str,
        precedence: str,
        message_label_ids: set[str],
    ) -> EmailClassification:
        """Classify one email and return its label, confidence and reason.

        Args:
            sender: Sender of the email.
            subject: Subject line.
            snippet: Short excerpt of the email.
            list_unsubscribe: Raw List-Unsubscribe value; only whether it is
                non-blank is forwarded to the model.
            precedence: Precedence value, forwarded unchanged.
            message_label_ids: Label ids attached to the message; sent sorted
                so the prompt is deterministic.

        Returns:
            An ``EmailClassification``. Labels outside ``ALLOWED_LABELS``
            (including a missing label) become ``"OTHER"``; the confidence is
            normalised by ``_to_confidence``; a missing or null reason
            becomes ``""``.
        """
        prompt_payload = {
            "sender": sender,
            "subject": subject,
            "snippet": snippet,
            "list_unsubscribe_present": bool(list_unsubscribe.strip()),
            "precedence": precedence,
            "gmail_label_ids": sorted(message_label_ids),
            "output_json_only": True,
        }

        # The agent object is reused for every email and a Strands agent keeps
        # its conversation history between calls. Drop that history so earlier
        # emails (and their verdicts) cannot influence this classification or
        # inflate the request. The getattr/isinstance guard keeps this a no-op
        # for stand-in agents that expose no ``messages`` list.
        history = getattr(self.agent, "messages", None)
        if isinstance(history, list):
            history.clear()

        response = self.agent(json.dumps(prompt_payload, ensure_ascii=True))
        parsed = _parse_json(str(response))
        if not isinstance(parsed, dict):
            # A reply that is valid JSON but not an object carries no fields.
            parsed = {}

        label = str(parsed.get("label", "OTHER")).upper().strip()
        if label not in ALLOWED_LABELS:
            logger.warning("Unexpected Strands label '%s', falling back to OTHER.", label)
            label = "OTHER"

        confidence = _to_confidence(parsed.get("confidence", 0))

        # A JSON null reason must not turn into the literal string "None".
        raw_reason = parsed.get("reason")
        reason = "" if raw_reason is None else str(raw_reason).strip()

        return EmailClassification(label=label, confidence=confidence, reason=reason)
def _parse_json(content: str) -> dict:
if not content:
return {}
try:
return json.loads(content)
except json.JSONDecodeError:
match = re.search(r"\{.*\}", content, re.DOTALL)
if not match:
return {}
try:
return json.loads(match.group(0))
except json.JSONDecodeError:
return {}
def _to_confidence(raw_value: object) -> float:
try:
confidence = float(raw_value)
except (TypeError, ValueError):
return 0.0
if confidence < 0:
return 0.0
if confidence > 1:
return 1.0
return confidence