from __future__ import annotations from dataclasses import dataclass import json import logging import re from strands import Agent from strands.models.openai import OpenAIModel logger = logging.getLogger("personal-agent.strands") ALLOWED_LABELS = {"LINKEDIN", "ADVERTISING", "VEILLE_TECHNO", "OTHER"} SYSTEM_PROMPT = """You classify incoming emails into exactly one label: - LINKEDIN: official LinkedIn platform emails (job alerts, invites, network updates, LinkedIn newsletters). - ADVERTISING: marketing/promotional/sales emails, newsletters, coupons, deals, brand campaigns. Do not label as ADVERTISING if the email is purely transactional (e.g. order confirmation, password reset) even if it contains some marketing language. Also do not label as ADVERTISING if the sender is Cybernetica. But if the sender is Cybernetica and the content is clearly promotional (e.g. "Check out our new product"), then label as ADVERTISING. And if the sender is Castorama and the content is about Communauté d'entraide, the label should be ADVERTISING. - VEILLE_TECHNO: Cybernetica emails that are clearly about technology watch, sharing interesting articles, insights, trends, etc. without a promotional angle. - OTHER: anything else. Rules: 1) If sender/content clearly belongs to LinkedIn, choose LINKEDIN even if promotional. 2) If uncertain between ADVERTISING and OTHER, choose OTHER. 3) Return only JSON with this schema: {"label":"LINKEDIN|ADVERTISING|VEILLE_TECHNO|OTHER","confidence":0.0-1.0,"reason":"short reason"}""" @dataclass(frozen=True) class EmailClassification: label: str confidence: float reason: str class StrandsEmailClassifier: def __init__( self, *, api_key: str, model_id: str, base_url: str | None = None, timeout_seconds: float = 20.0, temperature: float = 0.0, ) -> None: if not api_key: raise ValueError("Strands/OpenAI API key is required for classification.") self._api_key = api_key self._model_id = model_id self._base_url = base_url self._timeout_seconds = timeout_seconds self._temperature = temperature self._temperature_enabled = True self.agent = self._build_agent(include_temperature=True) def classify( self, *, sender: str, subject: str, snippet: str, list_unsubscribe: str, precedence: str, message_label_ids: set[str], ) -> EmailClassification: prompt_payload: dict[str, object] = { "sender": sender, "subject": subject, "snippet": snippet, "list_unsubscribe_present": bool(list_unsubscribe.strip()), "precedence": precedence, "gmail_label_ids": sorted(message_label_ids), "output_json_only": True, } response = self._invoke_agent_with_temperature_fallback(prompt_payload) parsed: dict[str, object] = _parse_json(str(response)) label = str(parsed.get("label", "OTHER")).upper().strip() if label not in ALLOWED_LABELS: logger.warning("Unexpected Strands label '%s', falling back to OTHER.", label) label = "OTHER" confidence = _to_confidence(parsed.get("confidence", 0)) reason = str(parsed.get("reason", "")).strip() return EmailClassification(label=label, confidence=confidence, reason=reason) def _invoke_agent_with_temperature_fallback(self, prompt_payload: dict[str, object]) -> object: prompt = json.dumps(prompt_payload, ensure_ascii=True) try: return self.agent(prompt) except Exception as exc: if self._temperature_enabled and _is_temperature_unsupported(exc): logger.warning( "Model '%s' rejected temperature=%s; retrying without temperature.", self._model_id, self._temperature, ) self._temperature_enabled = False self.agent = self._build_agent(include_temperature=False) return self.agent(prompt) raise def _build_agent(self, *, include_temperature: bool) -> Agent: client_args: dict[str, object] = {"api_key": self._api_key, "timeout": self._timeout_seconds} if self._base_url: client_args["base_url"] = self._base_url params: dict[str, float] | None = None if include_temperature: params = {"temperature": self._temperature} model_kwargs: dict[str, object] = { "client_args": client_args, "model_id": self._model_id, } if params is not None: model_kwargs["params"] = params model = OpenAIModel(**model_kwargs) # type: ignore return Agent(model=model, system_prompt=SYSTEM_PROMPT) def _parse_json(content: str) -> dict[str, object]: if not content: return {} try: return json.loads(content) except json.JSONDecodeError: match = re.search(r"\{.*\}", content, re.DOTALL) if not match: return {} try: return json.loads(match.group(0)) except json.JSONDecodeError: return {} def _to_confidence(raw_value: object) -> float: try: confidence = float(raw_value) # type: ignore except (TypeError, ValueError): return 0.0 if confidence < 0: return 0.0 if confidence > 1: return 1.0 return confidence def _is_temperature_unsupported(exc: Exception) -> bool: message = str(exc).lower() return "temperature" in message and "unsupported" in message