You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
115 lines
3.4 KiB
Python
115 lines
3.4 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
import json
|
|
import logging
|
|
import re
|
|
|
|
from strands import Agent
|
|
from strands.models.openai import OpenAIModel
|
|
|
|
# Module logger. The dotted name places it under the "personal-agent" logger,
# so handlers/levels configured there also apply to these records.
logger = logging.getLogger("personal-agent.strands")

# The complete set of labels a classification may carry. classify() replaces
# any other label returned by the model with "OTHER".
ALLOWED_LABELS = {"LINKEDIN", "ADVERTISING", "OTHER"}

# System prompt handed to the agent. It defines the three labels, the
# tie-breaking rules, and the JSON schema the reply is expected to follow.
SYSTEM_PROMPT = """You classify incoming emails into exactly one label:
- LINKEDIN: official LinkedIn platform emails (job alerts, invites, network updates, LinkedIn newsletters).
- ADVERTISING: marketing/promotional/sales emails, newsletters, coupons, deals, brand campaigns.
- OTHER: anything else.

Rules:
1) If sender/content clearly belongs to LinkedIn, choose LINKEDIN even if promotional.
2) If uncertain between ADVERTISING and OTHER, choose OTHER.
3) Return only JSON with this schema:
{"label":"LINKEDIN|ADVERTISING|OTHER","confidence":0.0-1.0,"reason":"short reason"}"""
|
|
|
|
|
|
@dataclass(frozen=True)
class EmailClassification:
    """Immutable outcome of classifying a single email.

    Instances are frozen, so they are hashable and cannot be modified after
    construction. The dataclass itself performs no validation of its fields.
    """

    # Category assigned to the email. StrandsEmailClassifier.classify() only
    # ever supplies one of the ALLOWED_LABELS values here.
    label: str
    # Confidence score. classify() passes it through _to_confidence first,
    # which clamps it into [0.0, 1.0].
    confidence: float
    # Short free-text justification taken from the model's reply.
    reason: str
|
|
|
|
|
|
class StrandsEmailClassifier:
    """Classify emails as LINKEDIN, ADVERTISING or OTHER via a Strands agent.

    The agent is backed by an OpenAI-compatible model and is primed with
    ``SYSTEM_PROMPT``. Each call to :meth:`classify` sends the email metadata
    as a JSON document and interprets the JSON reply.
    """

    def __init__(
        self,
        *,
        api_key: str,
        model_id: str,
        base_url: str | None = None,
        timeout_seconds: float = 20.0,
        temperature: float = 0.0,
    ) -> None:
        """Create the model client and the agent used for classification.

        Args:
            api_key: API key forwarded to the OpenAI client. Must be non-empty.
            model_id: Identifier of the model to use.
            base_url: Optional endpoint override; omitted from the client
                arguments when falsy.
            timeout_seconds: Timeout forwarded to the OpenAI client.
            temperature: Sampling temperature forwarded as a model parameter.

        Raises:
            ValueError: If ``api_key`` is empty.
        """
        if not api_key:
            raise ValueError("Strands/OpenAI API key is required for classification.")

        client_args: dict[str, object] = {
            "api_key": api_key,
            "timeout": timeout_seconds,
        }
        if base_url:
            client_args["base_url"] = base_url

        model = OpenAIModel(
            client_args=client_args,
            model_id=model_id,
            params={"temperature": temperature},
        )
        # NOTE(review): this single Agent instance is reused for every
        # classify() call. If Agent retains conversation history between
        # invocations, earlier emails would be carried into later prompts --
        # confirm against the Strands documentation.
        self.agent = Agent(model=model, system_prompt=SYSTEM_PROMPT)

    def classify(
        self,
        *,
        sender: str,
        subject: str,
        snippet: str,
        list_unsubscribe: str,
        precedence: str,
        message_label_ids: set[str],
    ) -> EmailClassification:
        """Ask the agent to label one email and normalise its answer.

        Args:
            sender: Sender of the message.
            subject: Subject line.
            snippet: Short excerpt of the message body.
            list_unsubscribe: Value of the List-Unsubscribe header. Only its
                presence (non-blank) is sent to the model.
            precedence: Value of the Precedence header.
            message_label_ids: Label ids already attached to the message;
                sent sorted so the prompt is deterministic.

        Returns:
            An ``EmailClassification``. The label falls back to ``"OTHER"``
            when the reply is unusable or names an unknown label; the
            confidence is clamped by ``_to_confidence``.
        """
        prompt_payload = {
            "sender": sender,
            "subject": subject,
            "snippet": snippet,
            # Tolerate a missing header (None) as well as a blank one.
            "list_unsubscribe_present": bool((list_unsubscribe or "").strip()),
            "precedence": precedence,
            "gmail_label_ids": sorted(message_label_ids or ()),
            "output_json_only": True,
        }

        response = self.agent(json.dumps(prompt_payload, ensure_ascii=True))
        parsed = _parse_json(str(response))
        if not isinstance(parsed, dict):
            # Valid JSON that is not an object (list, string, number) has no
            # fields to read; treat it like an empty reply instead of
            # crashing on .get().
            logger.warning("Strands response was not a JSON object, falling back to OTHER.")
            parsed = {}

        label = str(parsed.get("label", "OTHER")).upper().strip()
        if label not in ALLOWED_LABELS:
            logger.warning("Unexpected Strands label '%s', falling back to OTHER.", label)
            label = "OTHER"

        confidence = _to_confidence(parsed.get("confidence", 0))
        # `or ""` keeps a JSON null reason from turning into the text "None".
        reason = str(parsed.get("reason") or "").strip()
        return EmailClassification(label=label, confidence=confidence, reason=reason)
|
|
|
|
|
|
def _parse_json(content: str) -> dict:
    """Recover a JSON object from model output, tolerating surrounding text.

    The whole string is tried first. If it is not valid JSON, or is valid
    JSON that is not an object, the span from the first ``{`` to the last
    ``}`` is tried instead. This covers replies wrapped in prose or code
    fences.

    Args:
        content: Raw text returned by the model.

    Returns:
        The decoded JSON object, or an empty dict when no object can be
        recovered. The result is always a ``dict``, so callers may use
        ``.get`` without further checks.
    """
    if not content:
        return {}

    candidates = [content]
    # Greedy on purpose: the outermost braces delimit the full object even
    # when it contains nested objects.
    embedded = re.search(r"\{.*\}", content, re.DOTALL)
    if embedded and embedded.group(0) != content:
        candidates.append(embedded.group(0))

    for candidate in candidates:
        try:
            decoded = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        # json.loads also accepts lists, strings, numbers, booleans and
        # null; only an object satisfies this function's contract.
        if isinstance(decoded, dict):
            return decoded
    return {}
|
|
|
|
|
|
def _to_confidence(raw_value: object) -> float:
    """Coerce a model-supplied confidence into a float within [0.0, 1.0].

    Args:
        raw_value: Whatever the model placed in the "confidence" field
            (number, numeric string, ``None``, ...).

    Returns:
        The value clamped to ``[0.0, 1.0]``. Values that cannot be
        converted to a float, and NaN, yield ``0.0``.
    """
    import math  # local import: this is the only place the module needs it

    try:
        confidence = float(raw_value)
    except (TypeError, ValueError):
        return 0.0
    except OverflowError:
        # Raised for ints too large to fit in a float; clamp by sign just
        # like any other out-of-range number.
        return 1.0 if isinstance(raw_value, int) and raw_value > 0 else 0.0

    # NaN compares False against every bound, so it would slip through the
    # range checks unchanged; treat it as "no confidence".
    if math.isnan(confidence):
        return 0.0
    if confidence < 0:
        return 0.0
    if confidence > 1:
        return 1.0
    return confidence
|