|
|
|
@ -44,16 +44,13 @@ class StrandsEmailClassifier:
|
|
|
|
if not api_key:
|
|
|
|
if not api_key:
|
|
|
|
raise ValueError("Strands/OpenAI API key is required for classification.")
|
|
|
|
raise ValueError("Strands/OpenAI API key is required for classification.")
|
|
|
|
|
|
|
|
|
|
|
|
client_args = {"api_key": api_key, "timeout": timeout_seconds}
|
|
|
|
self._api_key = api_key
|
|
|
|
if base_url:
|
|
|
|
self._model_id = model_id
|
|
|
|
client_args["base_url"] = base_url
|
|
|
|
self._base_url = base_url
|
|
|
|
|
|
|
|
self._timeout_seconds = timeout_seconds
|
|
|
|
model = OpenAIModel(
|
|
|
|
self._temperature = temperature
|
|
|
|
client_args=client_args,
|
|
|
|
self._temperature_enabled = True
|
|
|
|
model_id=model_id,
|
|
|
|
self.agent = self._build_agent(include_temperature=True)
|
|
|
|
params={"temperature": temperature},
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
self.agent = Agent(model=model, system_prompt=SYSTEM_PROMPT)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def classify(
|
|
|
|
def classify(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
@ -75,7 +72,7 @@ class StrandsEmailClassifier:
|
|
|
|
"output_json_only": True,
|
|
|
|
"output_json_only": True,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
response = self.agent(json.dumps(prompt_payload, ensure_ascii=True))
|
|
|
|
response = self._invoke_agent_with_temperature_fallback(prompt_payload)
|
|
|
|
parsed = _parse_json(str(response))
|
|
|
|
parsed = _parse_json(str(response))
|
|
|
|
label = str(parsed.get("label", "OTHER")).upper().strip()
|
|
|
|
label = str(parsed.get("label", "OTHER")).upper().strip()
|
|
|
|
if label not in ALLOWED_LABELS:
|
|
|
|
if label not in ALLOWED_LABELS:
|
|
|
|
@ -86,6 +83,41 @@ class StrandsEmailClassifier:
|
|
|
|
reason = str(parsed.get("reason", "")).strip()
|
|
|
|
reason = str(parsed.get("reason", "")).strip()
|
|
|
|
return EmailClassification(label=label, confidence=confidence, reason=reason)
|
|
|
|
return EmailClassification(label=label, confidence=confidence, reason=reason)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _invoke_agent_with_temperature_fallback(self, prompt_payload: dict) -> object:
|
|
|
|
|
|
|
|
prompt = json.dumps(prompt_payload, ensure_ascii=True)
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
return self.agent(prompt)
|
|
|
|
|
|
|
|
except Exception as exc:
|
|
|
|
|
|
|
|
if self._temperature_enabled and _is_temperature_unsupported(exc):
|
|
|
|
|
|
|
|
logger.warning(
|
|
|
|
|
|
|
|
"Model '%s' rejected temperature=%s; retrying without temperature.",
|
|
|
|
|
|
|
|
self._model_id,
|
|
|
|
|
|
|
|
self._temperature,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
self._temperature_enabled = False
|
|
|
|
|
|
|
|
self.agent = self._build_agent(include_temperature=False)
|
|
|
|
|
|
|
|
return self.agent(prompt)
|
|
|
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _build_agent(self, *, include_temperature: bool) -> Agent:
|
|
|
|
|
|
|
|
client_args = {"api_key": self._api_key, "timeout": self._timeout_seconds}
|
|
|
|
|
|
|
|
if self._base_url:
|
|
|
|
|
|
|
|
client_args["base_url"] = self._base_url
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
params: dict[str, float] | None = None
|
|
|
|
|
|
|
|
if include_temperature:
|
|
|
|
|
|
|
|
params = {"temperature": self._temperature}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model_kwargs = {
|
|
|
|
|
|
|
|
"client_args": client_args,
|
|
|
|
|
|
|
|
"model_id": self._model_id,
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if params is not None:
|
|
|
|
|
|
|
|
model_kwargs["params"] = params
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model = OpenAIModel(**model_kwargs)
|
|
|
|
|
|
|
|
return Agent(model=model, system_prompt=SYSTEM_PROMPT)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _parse_json(content: str) -> dict:
|
|
|
|
def _parse_json(content: str) -> dict:
|
|
|
|
if not content:
|
|
|
|
if not content:
|
|
|
|
@ -112,3 +144,8 @@ def _to_confidence(raw_value: object) -> float:
|
|
|
|
if confidence > 1:
|
|
|
|
if confidence > 1:
|
|
|
|
return 1.0
|
|
|
|
return 1.0
|
|
|
|
return confidence
|
|
|
|
return confidence
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_temperature_unsupported(exc: Exception) -> bool:
|
|
|
|
|
|
|
|
message = str(exc).lower()
|
|
|
|
|
|
|
|
return "temperature" in message and "unsupported" in message
|
|
|
|
|