You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

153 lines
5.4 KiB
Python

from __future__ import annotations
from dataclasses import dataclass
import json
import logging
import re
from strands import Agent
from strands.models.openai import OpenAIModel
# Module-level logger. The name is a fixed string (not ``__name__``), so log
# configuration must target "personal-agent.strands" explicitly.
logger = logging.getLogger("personal-agent.strands")
# Closed set of labels the classifier may return. ``classify`` upper-cases the
# model's answer and replaces anything outside this set with "OTHER".
ALLOWED_LABELS = {"LINKEDIN", "ADVERTISING", "VEILLE_TECHNO", "OTHER"}
# System prompt handed to the agent in ``_build_agent``. It describes the four
# labels, the tie-breaking rules, and the JSON schema the reply must follow.
# The text is runtime behaviour: changing a word changes classification.
SYSTEM_PROMPT = """You classify incoming emails into exactly one label:
- LINKEDIN: official LinkedIn platform emails (job alerts, invites, network updates, LinkedIn newsletters).
- ADVERTISING: marketing/promotional/sales emails, newsletters, coupons, deals, brand campaigns. Do not label as ADVERTISING if the email is purely transactional (e.g. order confirmation, password reset) even if it contains some marketing language. Also do not label as ADVERTISING if the sender is Cybernetica. But if the sender is Cybernetica and the content is clearly promotional (e.g. "Check out our new product"), then label as ADVERTISING. And if the sender is Castorama and the content is about Communauté d'entraide, the label should be ADVERTISING.
- VEILLE_TECHNO: Cybernetica emails that are clearly about technology watch, sharing interesting articles, insights, trends, etc. without a promotional angle.
- OTHER: anything else.
Rules:
1) If sender/content clearly belongs to LinkedIn, choose LINKEDIN even if promotional.
2) If uncertain between ADVERTISING and OTHER, choose OTHER.
3) Return only JSON with this schema:
{"label":"LINKEDIN|ADVERTISING|VEILLE_TECHNO|OTHER","confidence":0.0-1.0,"reason":"short reason"}"""
@dataclass(frozen=True)
class EmailClassification:
    """Immutable outcome of classifying one email.

    Instances are frozen: fields cannot be reassigned after construction.
    """

    # Label chosen for the email.
    label: str
    # Confidence score reported for the label.
    confidence: float
    # Short free-text justification for the label.
    reason: str
class StrandsEmailClassifier:
    """Classify emails into one of ``ALLOWED_LABELS`` with a Strands agent.

    A single agent is built at construction time and reused. If the model
    rejects the ``temperature`` parameter, the agent is rebuilt once without
    it and stays that way for the lifetime of the instance.
    """

    def __init__(
        self,
        *,
        api_key: str,
        model_id: str,
        base_url: str | None = None,
        timeout_seconds: float = 20.0,
        temperature: float = 0.0,
    ) -> None:
        """Store the connection settings and build the initial agent.

        Args:
            api_key: API key passed to the OpenAI-compatible client.
            model_id: Identifier of the model to use.
            base_url: Optional alternative endpoint; omitted when falsy.
            timeout_seconds: Client timeout passed as ``timeout``.
            temperature: Sampling temperature sent while it is supported.

        Raises:
            ValueError: If ``api_key`` is empty.
        """
        if not api_key:
            raise ValueError("Strands/OpenAI API key is required for classification.")
        self._api_key = api_key
        self._model_id = model_id
        self._base_url = base_url
        self._timeout_seconds = timeout_seconds
        self._temperature = temperature
        self._temperature_enabled = True
        self.agent = self._build_agent(include_temperature=True)

    def classify(
        self,
        *,
        sender: str,
        subject: str,
        snippet: str,
        list_unsubscribe: str,
        precedence: str,
        message_label_ids: set[str],
    ) -> EmailClassification:
        """Classify one email and return its label, confidence and reason.

        The email metadata is sent to the agent as a JSON payload. A reply
        that cannot be read as a JSON object, or whose label is not in
        ``ALLOWED_LABELS``, yields the label ``"OTHER"``.

        Args:
            sender: Sender of the email.
            subject: Subject line.
            snippet: Short excerpt of the body.
            list_unsubscribe: Raw List-Unsubscribe header value; only its
                presence is forwarded. ``None`` is treated as absent.
            precedence: Raw Precedence header value.
            message_label_ids: Gmail label ids; sent sorted so the payload
                is deterministic. ``None`` is treated as empty.

        Returns:
            The resulting ``EmailClassification``.
        """
        prompt_payload = {
            "sender": sender,
            "subject": subject,
            "snippet": snippet,
            # Tolerate a missing header (None) instead of raising.
            "list_unsubscribe_present": bool((list_unsubscribe or "").strip()),
            "precedence": precedence,
            "gmail_label_ids": sorted(message_label_ids or ()),
            "output_json_only": True,
        }
        response = self._invoke_agent_with_temperature_fallback(prompt_payload)
        parsed = _parse_json(str(response))
        if not isinstance(parsed, dict):
            # Valid JSON that is not an object (list, string, number) has no
            # ``.get``; treat it like an unparseable reply.
            logger.warning(
                "Strands response was not a JSON object (%s); falling back to OTHER.",
                type(parsed).__name__,
            )
            parsed = {}

        label = str(parsed.get("label", "OTHER")).upper().strip()
        if label not in ALLOWED_LABELS:
            logger.warning("Unexpected Strands label '%s', falling back to OTHER.", label)
            label = "OTHER"
        confidence = _to_confidence(parsed.get("confidence", 0))
        reason = str(parsed.get("reason", "")).strip()
        return EmailClassification(label=label, confidence=confidence, reason=reason)

    def _invoke_agent_with_temperature_fallback(self, prompt_payload: dict) -> object:
        """Send the payload to the agent, retrying once without temperature.

        Any exception that is not a temperature rejection, or that occurs
        after temperature has already been disabled, is re-raised unchanged.
        """
        prompt = json.dumps(prompt_payload, ensure_ascii=True)
        # Each email must be classified independently of earlier ones.
        self._reset_conversation()
        try:
            return self.agent(prompt)
        except Exception as exc:
            if not (self._temperature_enabled and _is_temperature_unsupported(exc)):
                raise
            logger.warning(
                "Model '%s' rejected temperature=%s; retrying without temperature.",
                self._model_id,
                self._temperature,
            )
            self._temperature_enabled = False
            # The rebuilt agent is kept so later calls skip the failing attempt.
            self.agent = self._build_agent(include_temperature=False)
            return self.agent(prompt)

    def _reset_conversation(self) -> None:
        """Drop any conversation history held by the current agent.

        The agent object is reused across calls; without this, messages
        accumulated in ``agent.messages`` would be replayed into every later
        classification. Done defensively so an agent without a list-valued
        ``messages`` attribute is left untouched.
        """
        history = getattr(self.agent, "messages", None)
        if isinstance(history, list):
            history.clear()

    def _build_agent(self, *, include_temperature: bool) -> Agent:
        """Create an agent for the configured model and ``SYSTEM_PROMPT``.

        Args:
            include_temperature: When true, pass ``temperature`` in the
                model ``params``; otherwise omit ``params`` entirely.
        """
        client_args: dict[str, object] = {
            "api_key": self._api_key,
            "timeout": self._timeout_seconds,
        }
        if self._base_url:
            client_args["base_url"] = self._base_url

        model_kwargs: dict[str, object] = {
            "client_args": client_args,
            "model_id": self._model_id,
        }
        if include_temperature:
            model_kwargs["params"] = {"temperature": self._temperature}

        model = OpenAIModel(**model_kwargs)
        return Agent(model=model, system_prompt=SYSTEM_PROMPT)
def _parse_json(content: str) -> dict:
if not content:
return {}
try:
return json.loads(content)
except json.JSONDecodeError:
match = re.search(r"\{.*\}", content, re.DOTALL)
if not match:
return {}
try:
return json.loads(match.group(0))
except json.JSONDecodeError:
return {}
def _to_confidence(raw_value: object) -> float:
try:
confidence = float(raw_value)
except (TypeError, ValueError):
return 0.0
if confidence < 0:
return 0.0
if confidence > 1:
return 1.0
return confidence
def _is_temperature_unsupported(exc: Exception) -> bool:
message = str(exc).lower()
return "temperature" in message and "unsupported" in message