You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

400 lines
12 KiB
Rust

//! OpenAI LLM provider implementation.
//!
//! Implements the `LlmProvider` trait using the OpenAI Responses API (`/v1/responses`)
//! with structured JSON output via `json_schema` text format.
use async_trait::async_trait;
use serde_json::Value;
use super::LlmProvider;
use crate::errors::AppError;
/// `LlmProvider` implementation backed by the OpenAI Responses API.
///
/// Bundles the credential sent as a bearer token with the HTTP client used to
/// issue requests; both are supplied by the caller at construction time.
pub struct OpenAiProvider {
    http_client: reqwest::Client,
    api_key: String,
}

impl OpenAiProvider {
    /// Build a provider from an OpenAI API key and an existing HTTP client.
    ///
    /// Construction performs no I/O; the key is only sent when a request is made.
    pub fn new(api_key: String, http_client: reqwest::Client) -> Self {
        OpenAiProvider {
            http_client,
            api_key,
        }
    }
}
#[async_trait]
impl LlmProvider for OpenAiProvider {
fn provider_id(&self) -> &str {
"openai"
}
async fn call_llm(
&self,
model: &str,
system_prompt: &str,
user_prompt: &str,
response_schema: &Value,
) -> Result<Value, AppError> {
let body = serde_json::json!({
"model": model,
"instructions": system_prompt,
"input": user_prompt,
"max_output_tokens": 16384,
"text": {
"format": {
"type": "json_schema",
"name": "synthesis",
"strict": true,
"schema": response_schema
}
}
});
let response = self
.http_client
.post("https://api.openai.com/v1/responses")
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(&body)
.send()
.await
.map_err(|e| {
let kind = if e.is_timeout() {
"timeout"
} else if e.is_connect() {
"connection error"
} else {
"network error"
};
tracing::error!("OpenAI Responses API request failed: {}", kind);
AppError::Internal(anyhow::anyhow!("Failed to connect to OpenAI API"))
})?;
let status = response.status();
let response_body: Value = response.json().await.map_err(|e| {
tracing::error!("Failed to parse OpenAI response body: {}", e);
AppError::Internal(anyhow::anyhow!("Failed to parse OpenAI API response"))
})?;
if !status.is_success() {
return Err(map_openai_error(status.as_u16(), &response_body));
}
extract_responses_api_content(&response_body)
}
}
/// Extract the text content from an OpenAI Responses API response.
///
/// The Responses API returns:
/// ```json
/// {
/// "output": [
/// { "type": "web_search_call", ... },
/// { "type": "message", "content": [{ "type": "output_text", "text": "..." }] }
/// ]
/// }
/// ```
///
/// We scan for `output_text` blocks and parse the first one as JSON.
fn extract_responses_api_content(response: &Value) -> Result<Value, AppError> {
let output = response
.get("output")
.and_then(|o| o.as_array())
.ok_or_else(|| {
tracing::error!(
"Unexpected OpenAI Responses API structure: missing 'output' array"
);
AppError::Internal(anyhow::anyhow!(
"OpenAI Responses API returned an unexpected response structure"
))
})?;
// Scan output items for type "message" with content containing "output_text"
for item in output {
if item.get("type").and_then(|t| t.as_str()) != Some("message") {
continue;
}
let content = match item.get("content").and_then(|c| c.as_array()) {
Some(c) => c,
None => continue,
};
for block in content {
if block.get("type").and_then(|t| t.as_str()) == Some("output_text") {
let text = block.get("text").and_then(|t| t.as_str()).ok_or_else(|| {
tracing::error!("OpenAI output_text block missing 'text' field");
AppError::Internal(anyhow::anyhow!(
"OpenAI Responses API returned output_text without text"
))
})?;
return serde_json::from_str(text).map_err(|e| {
tracing::error!("Failed to parse OpenAI JSON output: {}", e);
AppError::Internal(anyhow::anyhow!(
"OpenAI returned invalid JSON in structured output"
))
});
}
}
}
tracing::error!("No output_text found in OpenAI Responses API response");
Err(AppError::Internal(anyhow::anyhow!(
"OpenAI Responses API returned no text output"
)))
}
/// Map OpenAI API error responses to appropriate `AppError` variants.
///
/// Handles common error codes without exposing internal details.
fn map_openai_error(status: u16, body: &Value) -> AppError {
let error_message = body
.get("error")
.and_then(|e| e.get("message"))
.and_then(|m| m.as_str())
.unwrap_or("Unknown error");
let error_type = body
.get("error")
.and_then(|e| e.get("type"))
.and_then(|t| t.as_str())
.unwrap_or("");
// Log error details but NEVER the API key
tracing::error!(
"OpenAI API error (HTTP {}): {} (type: {})",
status,
error_message,
error_type
);
match status {
400 => AppError::BadRequest("Invalid request to LLM provider".into()),
401 => AppError::BadRequest("Invalid or unauthorized API key".into()),
403 => AppError::BadRequest("Access denied by LLM provider".into()),
404 => AppError::BadRequest("Model not found or not available".into()),
429 => AppError::RateLimited(
"LLM provider rate limit exceeded. Please try again later.".into(),
),
_ => AppError::Internal(anyhow::anyhow!("LLM provider returned an error")),
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // ── Provider metadata ───────────────────────────────────────
    #[test]
    fn openai_provider_metadata() {
        let provider = OpenAiProvider::new("test-key".into(), reqwest::Client::new());
        assert_eq!(provider.provider_id(), "openai");
    }

    // ── Responses API response parsing ──────────────────────────
    #[test]
    fn extract_responses_api_content_valid() {
        let response = serde_json::json!({
            "id": "resp_123",
            "output": [
                {
                    "type": "web_search_call",
                    "id": "ws_1",
                    "status": "completed"
                },
                {
                    "type": "message",
                    "content": [
                        {
                            "type": "output_text",
                            "text": "{\"category_0\": [{\"title\": \"Test\", \"url\": \"https://example.com\", \"summary\": \"A test article\"}]}"
                        }
                    ]
                }
            ]
        });
        let result = extract_responses_api_content(&response).unwrap();
        assert!(result["category_0"].is_array());
        assert_eq!(result["category_0"][0]["title"].as_str().unwrap(), "Test");
    }

    #[test]
    fn extract_responses_api_content_multiple_output_items() {
        // Multiple web_search_call items before the message
        let response = serde_json::json!({
            "output": [
                { "type": "web_search_call", "id": "ws_1", "status": "completed" },
                { "type": "web_search_call", "id": "ws_2", "status": "completed" },
                {
                    "type": "message",
                    "content": [
                        {
                            "type": "output_text",
                            "text": "{\"category_0\": []}"
                        }
                    ]
                }
            ]
        });
        let result = extract_responses_api_content(&response).unwrap();
        assert!(result["category_0"].is_array());
    }

    #[test]
    fn extract_responses_api_content_missing_output() {
        let response = serde_json::json!({});
        assert!(extract_responses_api_content(&response).is_err());
    }

    #[test]
    fn extract_responses_api_content_output_not_array() {
        let response = serde_json::json!({ "output": "not an array" });
        assert!(extract_responses_api_content(&response).is_err());
    }

    #[test]
    fn extract_responses_api_content_no_message_item() {
        let response = serde_json::json!({
            "output": [
                { "type": "web_search_call", "id": "ws_1", "status": "completed" }
            ]
        });
        assert!(extract_responses_api_content(&response).is_err());
    }

    #[test]
    fn extract_responses_api_content_skips_message_without_content() {
        let response = serde_json::json!({
            "output": [
                { "type": "message" },
                {
                    "type": "message",
                    "content": [
                        { "type": "output_text", "text": "{\"category_0\": []}" }
                    ]
                }
            ]
        });
        let result = extract_responses_api_content(&response).unwrap();
        assert!(result["category_0"].is_array());
    }

    #[test]
    fn extract_responses_api_content_uses_first_output_text() {
        let response = serde_json::json!({
            "output": [
                {
                    "type": "message",
                    "content": [
                        { "type": "output_text", "text": "{\"n\": 1}" },
                        { "type": "output_text", "text": "{\"n\": 2}" }
                    ]
                }
            ]
        });
        let result = extract_responses_api_content(&response).unwrap();
        assert_eq!(result["n"].as_i64(), Some(1));
    }

    #[test]
    fn extract_responses_api_content_output_text_missing_text() {
        let response = serde_json::json!({
            "output": [
                {
                    "type": "message",
                    "content": [ { "type": "output_text" } ]
                }
            ]
        });
        assert!(extract_responses_api_content(&response).is_err());
    }

    #[test]
    fn extract_responses_api_content_invalid_json_text() {
        let response = serde_json::json!({
            "output": [
                {
                    "type": "message",
                    "content": [
                        {
                            "type": "output_text",
                            "text": "not valid json"
                        }
                    ]
                }
            ]
        });
        assert!(extract_responses_api_content(&response).is_err());
    }

    #[test]
    fn extract_responses_api_content_refusal_only() {
        let response = serde_json::json!({
            "output": [
                {
                    "type": "message",
                    "content": [
                        { "type": "refusal", "refusal": "I can't help with that." }
                    ]
                }
            ]
        });
        assert!(extract_responses_api_content(&response).is_err());
    }

    #[test]
    fn extract_responses_api_content_incomplete_truncated_json() {
        let response = serde_json::json!({
            "status": "incomplete",
            "incomplete_details": { "reason": "max_output_tokens" },
            "output": [
                {
                    "type": "message",
                    "content": [
                        { "type": "output_text", "text": "{\"category_0\": [" }
                    ]
                }
            ]
        });
        assert!(extract_responses_api_content(&response).is_err());
    }

    #[test]
    fn extract_responses_api_content_incomplete_but_valid_json() {
        // A complete JSON payload is still accepted even if the status says incomplete.
        let response = serde_json::json!({
            "status": "incomplete",
            "output": [
                {
                    "type": "message",
                    "content": [
                        { "type": "output_text", "text": "{\"category_0\": []}" }
                    ]
                }
            ]
        });
        let result = extract_responses_api_content(&response).unwrap();
        assert!(result["category_0"].is_array());
    }

    // ── Error mapping tests ─────────────────────────────────────
    #[test]
    fn map_openai_error_invalid_key() {
        let body = serde_json::json!({
            "error": {
                "message": "Incorrect API key provided",
                "type": "invalid_request_error",
                "code": "invalid_api_key"
            }
        });
        let err = map_openai_error(401, &body);
        match err {
            AppError::BadRequest(msg) => assert!(msg.contains("unauthorized")),
            _ => panic!("Expected BadRequest for 401"),
        }
    }

    #[test]
    fn map_openai_error_rate_limited() {
        let body = serde_json::json!({
            "error": {
                "message": "Rate limit reached for model gpt-4o",
                "type": "rate_limit_error",
                "code": "rate_limit_exceeded"
            }
        });
        let err = map_openai_error(429, &body);
        match err {
            AppError::RateLimited(msg) => assert!(msg.contains("rate limit")),
            _ => panic!("Expected RateLimited for 429"),
        }
    }

    #[test]
    fn map_openai_error_bad_request() {
        let body = serde_json::json!({
            "error": {
                "message": "Invalid model specified",
                "type": "invalid_request_error"
            }
        });
        let err = map_openai_error(400, &body);
        match err {
            AppError::BadRequest(msg) => assert!(msg.contains("Invalid request")),
            _ => panic!("Expected BadRequest for 400"),
        }
    }

    #[test]
    fn map_openai_error_non_string_fields() {
        // Malformed error fields must not panic and still map by status.
        let body = serde_json::json!({
            "error": { "message": 42, "type": null }
        });
        let err = map_openai_error(400, &body);
        assert!(matches!(err, AppError::BadRequest(_)));
    }

    #[test]
    fn map_openai_error_model_not_found() {
        let body = serde_json::json!({
            "error": {
                "message": "The model does not exist",
                "type": "invalid_request_error"
            }
        });
        let err = map_openai_error(404, &body);
        match err {
            AppError::BadRequest(msg) => assert!(msg.contains("not found")),
            _ => panic!("Expected BadRequest for 404"),
        }
    }

    #[test]
    fn map_openai_error_server_error() {
        let body = serde_json::json!({
            "error": {
                "message": "Internal server error",
                "type": "server_error"
            }
        });
        let err = map_openai_error(500, &body);
        match err {
            AppError::Internal(_) => {} // expected
            _ => panic!("Expected Internal for 500"),
        }
    }

    #[test]
    fn map_openai_error_forbidden() {
        let body = serde_json::json!({
            "error": {
                "message": "Access denied",
                "type": "forbidden"
            }
        });
        let err = map_openai_error(403, &body);
        match err {
            AppError::BadRequest(msg) => assert!(msg.contains("Access denied")),
            _ => panic!("Expected BadRequest for 403"),
        }
    }

    #[test]
    fn map_openai_error_unmapped_client_error() {
        let body = serde_json::json!({
            "error": { "message": "Unprocessable", "type": "invalid_request_error" }
        });
        let err = map_openai_error(422, &body);
        assert!(matches!(err, AppError::Internal(_)));
    }

    #[test]
    fn map_openai_error_unknown_body() {
        // Sometimes the body may lack the standard error structure
        let body = serde_json::json!({});
        let err = map_openai_error(502, &body);
        match err {
            AppError::Internal(_) => {} // expected
            _ => panic!("Expected Internal for 502"),
        }
    }
}