"""Classify emails as LINKEDIN, ADVERTISING or OTHER using a Strands agent."""

from __future__ import annotations

from dataclasses import dataclass
import json
import logging
import math
import re

from strands import Agent
from strands.models.openai import OpenAIModel

logger = logging.getLogger("personal-agent.strands")

ALLOWED_LABELS = {"LINKEDIN", "ADVERTISING", "OTHER"}

# NOTE(review): the prompt text is reproduced exactly as found, with single
# spaces between sections; confirm whether line breaks were intended.
SYSTEM_PROMPT = (
    "You classify incoming emails into exactly one label: "
    "- LINKEDIN: official LinkedIn platform emails "
    "(job alerts, invites, network updates, LinkedIn newsletters). "
    "- ADVERTISING: marketing/promotional/sales emails, newsletters, "
    "coupons, deals, brand campaigns. "
    "- OTHER: anything else. "
    "Rules: "
    "1) If sender/content clearly belongs to LinkedIn, "
    "choose LINKEDIN even if promotional. "
    "2) If uncertain between ADVERTISING and OTHER, choose OTHER. "
    "3) Return only JSON with this schema: "
    '{"label":"LINKEDIN|ADVERTISING|OTHER","confidence":0.0-1.0,"reason":"short reason"}'
)

# Marks every position where a JSON object could begin inside free-form text.
_OBJECT_START = re.compile(r"\{")


@dataclass(frozen=True)
class EmailClassification:
    """Immutable result of classifying a single email."""

    # One of ALLOWED_LABELS when produced by StrandsEmailClassifier.classify.
    label: str
    # Clamped to the closed interval [0.0, 1.0] by _to_confidence.
    confidence: float
    # Explanation text taken from the model reply; may be empty.
    reason: str


class StrandsEmailClassifier:
    """Wrap a Strands ``Agent`` configured with ``SYSTEM_PROMPT`` for email labelling."""

    def __init__(
        self,
        *,
        api_key: str,
        model_id: str,
        base_url: str | None = None,
        timeout_seconds: float = 20.0,
        temperature: float = 0.0,
    ) -> None:
        """Build the underlying model and agent.

        Args:
            api_key: Passed to the model client as ``api_key``; must be non-empty.
            model_id: Passed to ``OpenAIModel`` as ``model_id``.
            base_url: Optional ``base_url`` for the client; omitted when falsy.
            timeout_seconds: Passed to the client as ``timeout``.
            temperature: Passed to the model as the ``temperature`` parameter.

        Raises:
            ValueError: If ``api_key`` is empty.
        """
        if not api_key:
            raise ValueError("Strands/OpenAI API key is required for classification.")

        client_args: dict[str, object] = {
            "api_key": api_key,
            "timeout": timeout_seconds,
        }
        if base_url:
            client_args["base_url"] = base_url

        model = OpenAIModel(
            client_args=client_args,
            model_id=model_id,
            params={"temperature": temperature},
        )
        self.agent = Agent(model=model, system_prompt=SYSTEM_PROMPT)

    def classify(
        self,
        *,
        sender: str,
        subject: str,
        snippet: str,
        list_unsubscribe: str,
        precedence: str,
        message_label_ids: set[str],
    ) -> EmailClassification:
        """Send the email metadata to the agent and parse its verdict.

        The fields are serialised as one JSON document and passed to the
        agent; the textual form of the reply is parsed as JSON. A missing or
        unrecognised label falls back to ``"OTHER"``, an unparseable
        confidence becomes ``0.0``, and a missing reason becomes ``""``.

        Args:
            sender: Sender field forwarded verbatim.
            subject: Subject field forwarded verbatim.
            snippet: Body snippet forwarded verbatim.
            list_unsubscribe: Only its presence (non-blank) is forwarded;
                ``None`` is treated as absent.
            precedence: Forwarded verbatim.
            message_label_ids: Forwarded as a sorted list; ``None`` is
                treated as empty.

        Returns:
            The parsed and normalised classification.
        """
        prompt_payload = {
            "sender": sender,
            "subject": subject,
            "snippet": snippet,
            # Tolerate a missing header (None) instead of raising on .strip().
            "list_unsubscribe_present": bool(
                list_unsubscribe and list_unsubscribe.strip()
            ),
            "precedence": precedence,
            # Sorting gives a deterministic payload for the same label set.
            "gmail_label_ids": sorted(message_label_ids or ()),
            "output_json_only": True,
        }
        response = self.agent(json.dumps(prompt_payload, ensure_ascii=True))
        parsed = _parse_json(str(response))

        label = str(parsed.get("label", "OTHER")).upper().strip()
        if label not in ALLOWED_LABELS:
            logger.warning("Unexpected Strands label '%s', falling back to OTHER.", label)
            label = "OTHER"

        confidence = _to_confidence(parsed.get("confidence", 0))

        # A JSON null reason would otherwise surface as the literal "None".
        raw_reason = parsed.get("reason")
        reason = "" if raw_reason is None else str(raw_reason).strip()

        return EmailClassification(label=label, confidence=confidence, reason=reason)


def _parse_json(content: str) -> dict:
    """Return the JSON object found in ``content``, or ``{}`` if there is none.

    The whole string is tried first. If that fails, each ``{`` in the text is
    tried in order as the start of an object and the first one that decodes is
    returned, so prose or code fences around the object are tolerated. Valid
    JSON that is not an object (list, string, number, null) yields ``{}``
    because callers rely on ``dict.get``.
    """
    if not content:
        return {}

    try:
        parsed = json.loads(content)
    except json.JSONDecodeError:
        return _first_embedded_object(content)
    return parsed if isinstance(parsed, dict) else {}


def _first_embedded_object(content: str) -> dict:
    """Return the first decodable JSON object embedded in ``content``, else ``{}``."""
    decoder = json.JSONDecoder()
    for start in _OBJECT_START.finditer(content):
        try:
            candidate, _end = decoder.raw_decode(content, start.start())
        except json.JSONDecodeError:
            # This brace did not open a valid object; try the next one.
            continue
        if isinstance(candidate, dict):
            return candidate
    return {}


def _to_confidence(raw_value: object) -> float:
    """Coerce ``raw_value`` to a float clamped to ``[0.0, 1.0]``.

    Values that cannot be converted, that overflow a float, or that are NaN
    yield ``0.0``.
    """
    try:
        confidence = float(raw_value)
    except (TypeError, ValueError, OverflowError):
        return 0.0
    # NaN compares false against both bounds, so it must be rejected explicitly.
    if math.isnan(confidence):
        return 0.0
    if confidence < 0:
        return 0.0
    if confidence > 1:
        return 1.0
    return confidence