You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

152 lines
4.8 KiB
Python

from __future__ import annotations
from dataclasses import dataclass
import json
import logging
import re
from strands import Agent
from strands.models.openai import OpenAIModel
# Module logger; registered under a fixed name instead of ``__name__``.
logger = logging.getLogger("personal-agent.strands")
# Labels accepted from the model. ``StrandsEmailClassifier.classify`` replaces
# any other value with "OTHER".
ALLOWED_LABELS = {"LINKEDIN", "ADVERTISING", "OTHER"}
# System prompt handed to the agent: defines the three labels, the tie-break
# rules between them, and the JSON schema the reply must follow.
SYSTEM_PROMPT = """You classify incoming emails into exactly one label:
- LINKEDIN: official LinkedIn platform emails (job alerts, invites, network updates, LinkedIn newsletters).
- ADVERTISING: marketing/promotional/sales emails, newsletters, coupons, deals, brand campaigns.
- OTHER: anything else.
Rules:
1) If sender/content clearly belongs to LinkedIn, choose LINKEDIN even if promotional.
2) If uncertain between ADVERTISING and OTHER, choose OTHER.
3) Return only JSON with this schema:
{"label":"LINKEDIN|ADVERTISING|OTHER","confidence":0.0-1.0,"reason":"short reason"}"""
@dataclass(frozen=True)
class EmailClassification:
    """Immutable verdict produced for a single email.

    Attributes:
        label: Category assigned to the email.
        confidence: Score reported alongside the label.
        reason: Short free-text justification for the label.
    """

    label: str
    confidence: float
    reason: str
class StrandsEmailClassifier:
    """Classify emails as LINKEDIN, ADVERTISING or OTHER through a Strands agent.

    The agent is built once at construction time and reused. If the model
    rejects the ``temperature`` parameter, the agent is rebuilt without it and
    the request is retried once; later calls keep using the rebuilt agent.
    """

    def __init__(
        self,
        *,
        api_key: str,
        model_id: str,
        base_url: str | None = None,
        timeout_seconds: float = 20.0,
        temperature: float = 0.0,
    ) -> None:
        """Store the connection settings and build the initial agent.

        Args:
            api_key: API key passed to the OpenAI client; must be non-empty.
            model_id: Identifier of the model to use.
            base_url: Optional endpoint override for the OpenAI client.
            timeout_seconds: Client timeout passed to the OpenAI client.
            temperature: Sampling temperature sent until the model rejects it.

        Raises:
            ValueError: If ``api_key`` is empty.
        """
        if not api_key:
            raise ValueError("Strands/OpenAI API key is required for classification.")
        self._api_key = api_key
        self._model_id = model_id
        self._base_url = base_url
        self._timeout_seconds = timeout_seconds
        self._temperature = temperature
        self._temperature_enabled = True
        self.agent = self._build_agent(include_temperature=True)

    def classify(
        self,
        *,
        sender: str,
        subject: str,
        snippet: str,
        list_unsubscribe: str,
        precedence: str,
        message_label_ids: set[str],
    ) -> EmailClassification:
        """Ask the agent to label one email and normalise its answer.

        Args:
            sender: Sender of the email.
            subject: Subject line.
            snippet: Short excerpt of the body.
            list_unsubscribe: List-Unsubscribe header value; only its presence
                is forwarded to the model. ``None`` is treated as absent.
            precedence: Precedence header value.
            message_label_ids: Gmail label ids on the message; sent sorted so
                the prompt is deterministic. ``None`` is treated as empty.

        Returns:
            An ``EmailClassification`` whose label is always one of
            ``ALLOWED_LABELS``. Unparseable replies yield label ``OTHER`` with
            confidence 0.

        Raises:
            Exception: Whatever the agent raises, other than a rejected
                ``temperature`` parameter on the first attempt.
        """
        prompt_payload = {
            "sender": sender,
            "subject": subject,
            "snippet": snippet,
            "list_unsubscribe_present": bool((list_unsubscribe or "").strip()),
            "precedence": precedence,
            "gmail_label_ids": sorted(message_label_ids or ()),
            "output_json_only": True,
        }
        response = self._invoke_agent_with_temperature_fallback(prompt_payload)
        parsed = _parse_json(str(response))
        if not isinstance(parsed, dict):
            # The model may answer with valid JSON that is not an object
            # (a list, a bare string, ...); treat that like an empty reply
            # instead of failing on ``.get``.
            parsed = {}

        label = str(parsed.get("label", "OTHER")).upper().strip()
        if label not in ALLOWED_LABELS:
            logger.warning("Unexpected Strands label '%s', falling back to OTHER.", label)
            label = "OTHER"
        confidence = _to_confidence(parsed.get("confidence", 0))
        reason = str(parsed.get("reason", "")).strip()
        return EmailClassification(label=label, confidence=confidence, reason=reason)

    def _reset_conversation(self) -> None:
        """Drop any conversation history held by the current agent.

        Strands agents keep prior turns in ``agent.messages`` and send them
        with the next request. Each email must be classified independently,
        so the history is cleared before every call; otherwise earlier emails
        would influence later verdicts and the prompt would keep growing.
        """
        history = getattr(self.agent, "messages", None)
        if isinstance(history, list):
            history.clear()

    def _invoke_agent_with_temperature_fallback(self, prompt_payload: dict) -> object:
        """Call the agent, retrying once without temperature if it is rejected.

        Args:
            prompt_payload: JSON-serialisable payload sent as the user prompt.

        Returns:
            The raw agent result.

        Raises:
            Exception: Any agent error that is not a temperature rejection,
                or any error raised by the retry itself.
        """
        prompt = json.dumps(prompt_payload, ensure_ascii=True)
        self._reset_conversation()
        try:
            return self.agent(prompt)
        except Exception as exc:
            if not (self._temperature_enabled and _is_temperature_unsupported(exc)):
                raise
            logger.warning(
                "Model '%s' rejected temperature=%s; retrying without temperature.",
                self._model_id,
                self._temperature,
            )
            # Remember the rejection so later calls skip the failing attempt.
            self._temperature_enabled = False
            self.agent = self._build_agent(include_temperature=False)
            return self.agent(prompt)

    def _build_agent(self, *, include_temperature: bool) -> Agent:
        """Create an agent bound to the configured OpenAI-compatible model.

        Args:
            include_temperature: Whether to send the configured temperature
                as a model parameter.

        Returns:
            A new ``Agent`` using ``SYSTEM_PROMPT``.
        """
        client_args = {"api_key": self._api_key, "timeout": self._timeout_seconds}
        if self._base_url:
            client_args["base_url"] = self._base_url

        model_kwargs: dict[str, object] = {
            "client_args": client_args,
            "model_id": self._model_id,
        }
        if include_temperature:
            model_kwargs["params"] = {"temperature": self._temperature}

        model = OpenAIModel(**model_kwargs)
        return Agent(model=model, system_prompt=SYSTEM_PROMPT)
def _parse_json(content: str) -> dict:
if not content:
return {}
try:
return json.loads(content)
except json.JSONDecodeError:
match = re.search(r"\{.*\}", content, re.DOTALL)
if not match:
return {}
try:
return json.loads(match.group(0))
except json.JSONDecodeError:
return {}
def _to_confidence(raw_value: object) -> float:
try:
confidence = float(raw_value)
except (TypeError, ValueError):
return 0.0
if confidence < 0:
return 0.0
if confidence > 1:
return 1.0
return confidence
def _is_temperature_unsupported(exc: Exception) -> bool:
message = str(exc).lower()
return "temperature" in message and "unsupported" in message