from __future__ import annotations from dataclasses import dataclass import json import logging import re from strands import Agent from strands.models.openai import OpenAIModel logger = logging.getLogger("personal-agent.strands") ALLOWED_LABELS = {"LINKEDIN", "ADVERTISING", "OTHER"} SYSTEM_PROMPT = """You classify incoming emails into exactly one label: - LINKEDIN: official LinkedIn platform emails (job alerts, invites, network updates, LinkedIn newsletters). - ADVERTISING: marketing/promotional/sales emails, newsletters, coupons, deals, brand campaigns. - OTHER: anything else. Rules: 1) If sender/content clearly belongs to LinkedIn, choose LINKEDIN even if promotional. 2) If uncertain between ADVERTISING and OTHER, choose OTHER. 3) Return only JSON with this schema: {"label":"LINKEDIN|ADVERTISING|OTHER","confidence":0.0-1.0,"reason":"short reason"}""" @dataclass(frozen=True) class EmailClassification: label: str confidence: float reason: str class StrandsEmailClassifier: def __init__( self, *, api_key: str, model_id: str, base_url: str | None = None, timeout_seconds: float = 20.0, temperature: float = 0.0, ) -> None: if not api_key: raise ValueError("Strands/OpenAI API key is required for classification.") self._api_key = api_key self._model_id = model_id self._base_url = base_url self._timeout_seconds = timeout_seconds self._temperature = temperature self._temperature_enabled = True self.agent = self._build_agent(include_temperature=True) def classify( self, *, sender: str, subject: str, snippet: str, list_unsubscribe: str, precedence: str, message_label_ids: set[str], ) -> EmailClassification: prompt_payload = { "sender": sender, "subject": subject, "snippet": snippet, "list_unsubscribe_present": bool(list_unsubscribe.strip()), "precedence": precedence, "gmail_label_ids": sorted(message_label_ids), "output_json_only": True, } response = self._invoke_agent_with_temperature_fallback(prompt_payload) parsed = _parse_json(str(response)) label = str(parsed.get("label", "OTHER")).upper().strip() if label not in ALLOWED_LABELS: logger.warning("Unexpected Strands label '%s', falling back to OTHER.", label) label = "OTHER" confidence = _to_confidence(parsed.get("confidence", 0)) reason = str(parsed.get("reason", "")).strip() return EmailClassification(label=label, confidence=confidence, reason=reason) def _invoke_agent_with_temperature_fallback(self, prompt_payload: dict) -> object: prompt = json.dumps(prompt_payload, ensure_ascii=True) try: return self.agent(prompt) except Exception as exc: if self._temperature_enabled and _is_temperature_unsupported(exc): logger.warning( "Model '%s' rejected temperature=%s; retrying without temperature.", self._model_id, self._temperature, ) self._temperature_enabled = False self.agent = self._build_agent(include_temperature=False) return self.agent(prompt) raise def _build_agent(self, *, include_temperature: bool) -> Agent: client_args = {"api_key": self._api_key, "timeout": self._timeout_seconds} if self._base_url: client_args["base_url"] = self._base_url params: dict[str, float] | None = None if include_temperature: params = {"temperature": self._temperature} model_kwargs = { "client_args": client_args, "model_id": self._model_id, } if params is not None: model_kwargs["params"] = params model = OpenAIModel(**model_kwargs) return Agent(model=model, system_prompt=SYSTEM_PROMPT) def _parse_json(content: str) -> dict: if not content: return {} try: return json.loads(content) except json.JSONDecodeError: match = re.search(r"\{.*\}", content, re.DOTALL) if not match: return {} try: return json.loads(match.group(0)) except json.JSONDecodeError: return {} def _to_confidence(raw_value: object) -> float: try: confidence = float(raw_value) except (TypeError, ValueError): return 0.0 if confidence < 0: return 0.0 if confidence > 1: return 1.0 return confidence def _is_temperature_unsupported(exc: Exception) -> bool: message = str(exc).lower() return "temperature" in message and "unsupported" in message