You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
153 lines
5.6 KiB
Python
153 lines
5.6 KiB
Python
from __future__ import annotations

from dataclasses import dataclass
import json
import logging
import math
import re

from strands import Agent
from strands.models.openai import OpenAIModel
|
|
|
|
# Module logger; the dotted name places its records under the
# "personal-agent" logger in the stdlib logging hierarchy.
logger = logging.getLogger("personal-agent.strands")

# Labels the classifier may emit; any other label returned by the model is
# replaced with "OTHER". Frozen so this shared module-level constant cannot be
# mutated by accident at runtime.
ALLOWED_LABELS = frozenset({"LINKEDIN", "ADVERTISING", "VEILLE_TECHNO", "OTHER"})
|
|
|
|
SYSTEM_PROMPT = """You classify incoming emails into exactly one label:
|
|
- LINKEDIN: official LinkedIn platform emails (job alerts, invites, network updates, LinkedIn newsletters).
|
|
- ADVERTISING: marketing/promotional/sales emails, newsletters, coupons, deals, brand campaigns. Do not label as ADVERTISING if the email is purely transactional (e.g. order confirmation, password reset) even if it contains some marketing language. Also do not label as ADVERTISING if the sender is Cybernetica. But if the sender is Cybernetica and the content is clearly promotional (e.g. "Check out our new product"), then label as ADVERTISING. And if the sender is Castorama and the content is about Communauté d'entraide, the label should be ADVERTISING.
|
|
- VEILLE_TECHNO: Cybernetica emails that are clearly about technology watch, sharing interesting articles, insights, trends, etc. without a promotional angle.
|
|
- OTHER: anything else.
|
|
|
|
Rules:
|
|
1) If sender/content clearly belongs to LinkedIn, choose LINKEDIN even if promotional.
|
|
2) If uncertain between ADVERTISING and OTHER, choose OTHER.
|
|
3) Return only JSON with this schema:
|
|
{"label":"LINKEDIN|ADVERTISING|VEILLE_TECHNO|OTHER","confidence":0.0-1.0,"reason":"short reason"}"""
|
|
|
|
|
|
@dataclass(frozen=True)
class EmailClassification:
    """Immutable outcome of classifying a single email.

    Instances compare by value and are hashable because the dataclass is
    frozen; assigning to a field after construction raises
    ``dataclasses.FrozenInstanceError``.
    """

    # As produced by StrandsEmailClassifier.classify: one of ALLOWED_LABELS.
    label: str
    # As produced by classify: coerced and clamped by _to_confidence.
    confidence: float
    # Short free-text justification from the model; may be an empty string.
    reason: str
|
|
|
|
|
|
class StrandsEmailClassifier:
    """Classify emails into one of ``ALLOWED_LABELS`` using a Strands agent.

    The agent wraps an ``OpenAIModel`` (optionally pointed at a custom
    ``base_url``) and is driven by ``SYSTEM_PROMPT``. If the model rejects the
    configured temperature, the agent is rebuilt once without it and that
    choice is kept for all later calls.
    """

    def __init__(
        self,
        *,
        api_key: str,
        model_id: str,
        base_url: str | None = None,
        timeout_seconds: float = 20.0,
        temperature: float = 0.0,
    ) -> None:
        """Store connection settings and build the initial agent.

        Args:
            api_key: Key passed to the OpenAI client; must be non-empty.
            model_id: Model identifier passed to ``OpenAIModel``.
            base_url: Optional alternative endpoint for the OpenAI client.
            timeout_seconds: Forwarded to the client as ``timeout``.
            temperature: Sampling temperature sent while the model accepts it.

        Raises:
            ValueError: If ``api_key`` is empty.
        """
        if not api_key:
            raise ValueError("Strands/OpenAI API key is required for classification.")

        self._api_key = api_key
        self._model_id = model_id
        self._base_url = base_url
        self._timeout_seconds = timeout_seconds
        self._temperature = temperature
        self._temperature_enabled = True
        self.agent = self._build_agent(include_temperature=True)

    def classify(
        self,
        *,
        sender: str,
        subject: str,
        snippet: str,
        list_unsubscribe: str,
        precedence: str,
        message_label_ids: set[str],
    ) -> EmailClassification:
        """Ask the model to label one email and normalize its answer.

        The email metadata is sent as a JSON payload. A label outside
        ``ALLOWED_LABELS``, or a reply that cannot be parsed into a JSON
        object, yields ``"OTHER"``. Confidence is coerced by
        ``_to_confidence``, and the reason is stripped of surrounding whitespace.
        """
        prompt_payload: dict[str, object] = {
            "sender": sender,
            "subject": subject,
            "snippet": snippet,
            "list_unsubscribe_present": bool(list_unsubscribe.strip()),
            "precedence": precedence,
            "gmail_label_ids": sorted(message_label_ids),
            "output_json_only": True,
        }

        response = self._invoke_agent_with_temperature_fallback(prompt_payload)
        parsed: dict[str, object] = _parse_json(str(response))
        if not isinstance(parsed, dict):
            # Valid JSON that is not an object (e.g. a bare string or list) has no
            # .get(); treat it like an unparseable reply.
            parsed = {}

        label = str(parsed.get("label", "OTHER")).upper().strip()
        if label not in ALLOWED_LABELS:
            logger.warning("Unexpected Strands label '%s', falling back to OTHER.", label)
            label = "OTHER"

        confidence = _to_confidence(parsed.get("confidence", 0))
        reason = str(parsed.get("reason", "")).strip()
        return EmailClassification(label=label, confidence=confidence, reason=reason)

    def _invoke_agent_with_temperature_fallback(self, prompt_payload: dict[str, object]) -> object:
        """Run the agent on ``prompt_payload``, retrying once without temperature.

        If the call fails with an error that ``_is_temperature_unsupported``
        recognizes while temperature is still enabled, the agent is rebuilt
        without ``temperature`` (permanently for this instance) and the same
        prompt is sent again. Any other error is re-raised unchanged.
        """
        prompt = json.dumps(prompt_payload, ensure_ascii=True)
        try:
            return self._call_agent_fresh(prompt)
        except Exception as exc:
            if not (self._temperature_enabled and _is_temperature_unsupported(exc)):
                raise
            logger.warning(
                "Model '%s' rejected temperature=%s; retrying without temperature.",
                self._model_id,
                self._temperature,
            )
            self._temperature_enabled = False
            self.agent = self._build_agent(include_temperature=False)
            return self._call_agent_fresh(prompt)

    def _call_agent_fresh(self, prompt: str) -> object:
        """Invoke the agent on ``prompt`` with an empty conversation history.

        Strands agents keep prior exchanges in ``agent.messages`` and send them
        with later calls. Reusing one agent without clearing that history would
        let earlier emails leak into, and inflate, every later classification.
        It would also keep the user turn of a failed call.
        """
        history = getattr(self.agent, "messages", None)
        if isinstance(history, list):
            history.clear()
        return self.agent(prompt)

    def _build_agent(self, *, include_temperature: bool) -> Agent:
        """Create a Strands agent for the configured OpenAI-compatible model.

        Args:
            include_temperature: When true, send ``temperature`` in the model
                ``params``; when false, omit ``params`` entirely.
        """
        client_args: dict[str, object] = {"api_key": self._api_key, "timeout": self._timeout_seconds}
        if self._base_url:
            client_args["base_url"] = self._base_url

        model_kwargs: dict[str, object] = {
            "client_args": client_args,
            "model_id": self._model_id,
        }
        if include_temperature:
            model_kwargs["params"] = {"temperature": self._temperature}

        model = OpenAIModel(**model_kwargs)  # type: ignore
        # NOTE(review): the Agent uses its default callback handler here; if that
        # streams model output to stdout, consider callback_handler=None — confirm.
        return Agent(model=model, system_prompt=SYSTEM_PROMPT)
|
|
|
|
|
|
def _parse_json(content: str) -> dict[str, object]:
|
|
if not content:
|
|
return {}
|
|
try:
|
|
return json.loads(content)
|
|
except json.JSONDecodeError:
|
|
match = re.search(r"\{.*\}", content, re.DOTALL)
|
|
if not match:
|
|
return {}
|
|
try:
|
|
return json.loads(match.group(0))
|
|
except json.JSONDecodeError:
|
|
return {}
|
|
|
|
|
|
def _to_confidence(raw_value: object) -> float:
|
|
try:
|
|
confidence = float(raw_value) # type: ignore
|
|
except (TypeError, ValueError):
|
|
return 0.0
|
|
if confidence < 0:
|
|
return 0.0
|
|
if confidence > 1:
|
|
return 1.0
|
|
return confidence
|
|
|
|
|
|
def _is_temperature_unsupported(exc: Exception) -> bool:
|
|
message = str(exc).lower()
|
|
return "temperature" in message and "unsupported" in message
|