//! OpenAI LLM provider implementation.
//!
//! Implements the `LlmProvider` trait using two OpenAI APIs:
//! - **Pass 1 (search)**: Responses API (`/v1/responses`) with the `web_search_preview` tool
//! - **Pass 2 (rewrite)**: Chat Completions API (`/v1/chat/completions`) with structured output
|
|
|
|
use async_trait::async_trait;
|
|
use serde_json::Value;
|
|
|
|
use super::LlmProvider;
|
|
use crate::errors::AppError;
|
|
|
|
/// OpenAI provider.
|
|
///
|
|
/// Holds the API key and an HTTP client for making requests
|
|
/// to the OpenAI Responses and Chat Completions APIs.
|
|
pub struct OpenAiProvider {
|
|
api_key: String,
|
|
http_client: reqwest::Client,
|
|
}
|
|
|
|
impl OpenAiProvider {
|
|
/// Create a new OpenAI provider with the given API key and HTTP client.
|
|
pub fn new(api_key: String, http_client: reqwest::Client) -> Self {
|
|
Self {
|
|
api_key,
|
|
http_client,
|
|
}
|
|
}
|
|
|
|
/// Execute a request to the OpenAI Responses API (Pass 1).
|
|
///
|
|
/// Uses the Responses API with `web_search_preview` tool for grounded search results
|
|
/// and structured output via `json_schema` text format.
|
|
async fn call_responses_api(
|
|
&self,
|
|
model: &str,
|
|
system_prompt: &str,
|
|
user_prompt: &str,
|
|
response_schema: &Value,
|
|
include_web_search: bool,
|
|
) -> Result<Value, AppError> {
|
|
let mut body = serde_json::json!({
|
|
"model": model,
|
|
"instructions": system_prompt,
|
|
"input": user_prompt,
|
|
"text": {
|
|
"format": {
|
|
"type": "json_schema",
|
|
"name": "synthesis",
|
|
"strict": true,
|
|
"schema": response_schema
|
|
}
|
|
}
|
|
});
|
|
|
|
if include_web_search {
|
|
body["tools"] = serde_json::json!([{
|
|
"type": "web_search_preview"
|
|
}]);
|
|
}
|
|
|
|
let response = self
|
|
.http_client
|
|
.post("https://api.openai.com/v1/responses")
|
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
|
.header("Content-Type", "application/json")
|
|
.json(&body)
|
|
.send()
|
|
.await
|
|
.map_err(|e| {
|
|
let kind = if e.is_timeout() {
|
|
"timeout"
|
|
} else if e.is_connect() {
|
|
"connection error"
|
|
} else {
|
|
"network error"
|
|
};
|
|
tracing::error!("OpenAI Responses API request failed: {}", kind);
|
|
AppError::Internal(anyhow::anyhow!("Failed to connect to OpenAI API"))
|
|
})?;
|
|
|
|
let status = response.status();
|
|
let response_body: Value = response.json().await.map_err(|e| {
|
|
tracing::error!("Failed to parse OpenAI response body: {}", e);
|
|
AppError::Internal(anyhow::anyhow!("Failed to parse OpenAI API response"))
|
|
})?;
|
|
|
|
if !status.is_success() {
|
|
return Err(map_openai_error(status.as_u16(), &response_body));
|
|
}
|
|
|
|
extract_responses_api_content(&response_body)
|
|
}
|
|
|
|
/// Execute a request to the OpenAI Chat Completions API (Pass 2).
|
|
///
|
|
/// Uses the Chat Completions API with `json_schema` response format
|
|
/// for structured output without web search.
|
|
async fn call_chat_completions_api(
|
|
&self,
|
|
model: &str,
|
|
system_prompt: &str,
|
|
user_prompt: &str,
|
|
response_schema: &Value,
|
|
) -> Result<Value, AppError> {
|
|
let body = serde_json::json!({
|
|
"model": model,
|
|
"messages": [
|
|
{
|
|
"role": "system",
|
|
"content": system_prompt
|
|
},
|
|
{
|
|
"role": "user",
|
|
"content": user_prompt
|
|
}
|
|
],
|
|
"response_format": {
|
|
"type": "json_schema",
|
|
"json_schema": {
|
|
"name": "synthesis",
|
|
"strict": true,
|
|
"schema": response_schema
|
|
}
|
|
}
|
|
});
|
|
|
|
let response = self
|
|
.http_client
|
|
.post("https://api.openai.com/v1/chat/completions")
|
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
|
.header("Content-Type", "application/json")
|
|
.json(&body)
|
|
.send()
|
|
.await
|
|
.map_err(|e| {
|
|
let kind = if e.is_timeout() {
|
|
"timeout"
|
|
} else if e.is_connect() {
|
|
"connection error"
|
|
} else {
|
|
"network error"
|
|
};
|
|
tracing::error!("OpenAI Chat Completions API request failed: {}", kind);
|
|
AppError::Internal(anyhow::anyhow!("Failed to connect to OpenAI API"))
|
|
})?;
|
|
|
|
let status = response.status();
|
|
let response_body: Value = response.json().await.map_err(|e| {
|
|
tracing::error!("Failed to parse OpenAI response body: {}", e);
|
|
AppError::Internal(anyhow::anyhow!("Failed to parse OpenAI API response"))
|
|
})?;
|
|
|
|
if !status.is_success() {
|
|
return Err(map_openai_error(status.as_u16(), &response_body));
|
|
}
|
|
|
|
extract_chat_completions_content(&response_body)
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl LlmProvider for OpenAiProvider {
|
|
fn provider_id(&self) -> &str {
|
|
"openai"
|
|
}
|
|
|
|
async fn generate_search_pass(
|
|
&self,
|
|
model: &str,
|
|
system_prompt: &str,
|
|
user_prompt: &str,
|
|
response_schema: &Value,
|
|
) -> Result<Value, AppError> {
|
|
self.call_responses_api(model, system_prompt, user_prompt, response_schema, true)
|
|
.await
|
|
}
|
|
|
|
async fn generate_rewrite_pass(
|
|
&self,
|
|
model: &str,
|
|
system_prompt: &str,
|
|
user_prompt: &str,
|
|
response_schema: &Value,
|
|
) -> Result<Value, AppError> {
|
|
self.call_chat_completions_api(model, system_prompt, user_prompt, response_schema)
|
|
.await
|
|
}
|
|
|
|
fn supports_web_search(&self) -> bool {
|
|
true
|
|
}
|
|
}
|
|
|
|
/// Extract the text content from an OpenAI Responses API response.
|
|
///
|
|
/// The Responses API returns:
|
|
/// ```json
|
|
/// {
|
|
/// "output": [
|
|
/// { "type": "web_search_call", ... },
|
|
/// { "type": "message", "content": [{ "type": "output_text", "text": "..." }] }
|
|
/// ]
|
|
/// }
|
|
/// ```
|
|
///
|
|
/// We scan for `output_text` blocks and parse the first one as JSON.
|
|
fn extract_responses_api_content(response: &Value) -> Result<Value, AppError> {
|
|
let output = response
|
|
.get("output")
|
|
.and_then(|o| o.as_array())
|
|
.ok_or_else(|| {
|
|
tracing::error!(
|
|
"Unexpected OpenAI Responses API structure: missing 'output' array"
|
|
);
|
|
AppError::Internal(anyhow::anyhow!(
|
|
"OpenAI Responses API returned an unexpected response structure"
|
|
))
|
|
})?;
|
|
|
|
// Scan output items for type "message" with content containing "output_text"
|
|
for item in output {
|
|
if item.get("type").and_then(|t| t.as_str()) != Some("message") {
|
|
continue;
|
|
}
|
|
|
|
let content = match item.get("content").and_then(|c| c.as_array()) {
|
|
Some(c) => c,
|
|
None => continue,
|
|
};
|
|
|
|
for block in content {
|
|
if block.get("type").and_then(|t| t.as_str()) == Some("output_text") {
|
|
let text = block.get("text").and_then(|t| t.as_str()).ok_or_else(|| {
|
|
tracing::error!("OpenAI output_text block missing 'text' field");
|
|
AppError::Internal(anyhow::anyhow!(
|
|
"OpenAI Responses API returned output_text without text"
|
|
))
|
|
})?;
|
|
|
|
return serde_json::from_str(text).map_err(|e| {
|
|
tracing::error!("Failed to parse OpenAI JSON output: {}", e);
|
|
AppError::Internal(anyhow::anyhow!(
|
|
"OpenAI returned invalid JSON in structured output"
|
|
))
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
tracing::error!("No output_text found in OpenAI Responses API response");
|
|
Err(AppError::Internal(anyhow::anyhow!(
|
|
"OpenAI Responses API returned no text output"
|
|
)))
|
|
}
|
|
|
|
/// Extract the text content from an OpenAI Chat Completions API response.
|
|
///
|
|
/// The response structure is:
|
|
/// ```json
|
|
/// { "choices": [{ "message": { "content": "..." } }] }
|
|
/// ```
|
|
fn extract_chat_completions_content(response: &Value) -> Result<Value, AppError> {
|
|
let text = response
|
|
.get("choices")
|
|
.and_then(|c| c.get(0))
|
|
.and_then(|c| c.get("message"))
|
|
.and_then(|m| m.get("content"))
|
|
.and_then(|t| t.as_str())
|
|
.ok_or_else(|| {
|
|
tracing::error!("Unexpected OpenAI Chat Completions response structure");
|
|
AppError::Internal(anyhow::anyhow!(
|
|
"OpenAI Chat Completions API returned an unexpected response structure"
|
|
))
|
|
})?;
|
|
|
|
serde_json::from_str(text).map_err(|e| {
|
|
tracing::error!("Failed to parse OpenAI Chat Completions JSON output: {}", e);
|
|
AppError::Internal(anyhow::anyhow!(
|
|
"OpenAI returned invalid JSON in structured output"
|
|
))
|
|
})
|
|
}
|
|
|
|
/// Map OpenAI API error responses to appropriate `AppError` variants.
|
|
///
|
|
/// Handles common error codes without exposing internal details.
|
|
fn map_openai_error(status: u16, body: &Value) -> AppError {
|
|
let error_message = body
|
|
.get("error")
|
|
.and_then(|e| e.get("message"))
|
|
.and_then(|m| m.as_str())
|
|
.unwrap_or("Unknown error");
|
|
|
|
let error_type = body
|
|
.get("error")
|
|
.and_then(|e| e.get("type"))
|
|
.and_then(|t| t.as_str())
|
|
.unwrap_or("");
|
|
|
|
// Log error details but NEVER the API key
|
|
tracing::error!(
|
|
"OpenAI API error (HTTP {}): {} (type: {})",
|
|
status,
|
|
error_message,
|
|
error_type
|
|
);
|
|
|
|
match status {
|
|
400 => AppError::BadRequest("Invalid request to LLM provider".into()),
|
|
401 => AppError::BadRequest("Invalid or unauthorized API key".into()),
|
|
403 => AppError::BadRequest("Access denied by LLM provider".into()),
|
|
404 => AppError::BadRequest("Model not found or not available".into()),
|
|
429 => AppError::RateLimited(
|
|
"LLM provider rate limit exceeded. Please try again later.".into(),
|
|
),
|
|
_ => AppError::Internal(anyhow::anyhow!("LLM provider returned an error")),
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    // ── Provider metadata ───────────────────────────────────────

    #[test]
    fn openai_provider_metadata() {
        let provider = OpenAiProvider::new("test-key".into(), reqwest::Client::new());
        assert_eq!(provider.provider_id(), "openai");
        assert!(provider.supports_web_search());
    }

    // ── Responses API response parsing ──────────────────────────

    #[test]
    fn extract_responses_api_content_valid() {
        let response = serde_json::json!({
            "id": "resp_123",
            "output": [
                {
                    "type": "web_search_call",
                    "id": "ws_1",
                    "status": "completed"
                },
                {
                    "type": "message",
                    "content": [
                        {
                            "type": "output_text",
                            "text": "{\"category_0\": [{\"title\": \"Test\", \"url\": \"https://example.com\", \"summary\": \"A test article\"}]}"
                        }
                    ]
                }
            ]
        });

        let result = extract_responses_api_content(&response).unwrap();
        assert!(result["category_0"].is_array());
        assert_eq!(result["category_0"][0]["title"].as_str().unwrap(), "Test");
    }

    #[test]
    fn extract_responses_api_content_multiple_output_items() {
        // Multiple web_search_call items before the message
        let response = serde_json::json!({
            "output": [
                { "type": "web_search_call", "id": "ws_1", "status": "completed" },
                { "type": "web_search_call", "id": "ws_2", "status": "completed" },
                {
                    "type": "message",
                    "content": [
                        {
                            "type": "output_text",
                            "text": "{\"category_0\": []}"
                        }
                    ]
                }
            ]
        });

        let result = extract_responses_api_content(&response).unwrap();
        assert!(result["category_0"].is_array());
    }

    #[test]
    fn extract_responses_api_content_skips_message_without_content() {
        // A message item lacking a `content` array must not stop the scan.
        let response = serde_json::json!({
            "output": [
                { "type": "message" },
                {
                    "type": "message",
                    "content": [
                        { "type": "output_text", "text": "{\"category_0\": []}" }
                    ]
                }
            ]
        });

        let result = extract_responses_api_content(&response).unwrap();
        assert!(result["category_0"].is_array());
    }

    #[test]
    fn extract_responses_api_content_skips_non_output_text_blocks() {
        let response = serde_json::json!({
            "output": [
                {
                    "type": "message",
                    "content": [
                        { "type": "refusal", "refusal": "nope" },
                        { "type": "output_text", "text": "{\"ok\": true}" }
                    ]
                }
            ]
        });

        let result = extract_responses_api_content(&response).unwrap();
        assert_eq!(result["ok"].as_bool(), Some(true));
    }

    #[test]
    fn extract_responses_api_content_output_text_without_text() {
        let response = serde_json::json!({
            "output": [
                {
                    "type": "message",
                    "content": [{ "type": "output_text" }]
                }
            ]
        });
        assert!(extract_responses_api_content(&response).is_err());
    }

    #[test]
    fn extract_responses_api_content_missing_output() {
        let response = serde_json::json!({});
        assert!(extract_responses_api_content(&response).is_err());
    }

    #[test]
    fn extract_responses_api_content_no_message_item() {
        let response = serde_json::json!({
            "output": [
                { "type": "web_search_call", "id": "ws_1", "status": "completed" }
            ]
        });
        assert!(extract_responses_api_content(&response).is_err());
    }

    #[test]
    fn extract_responses_api_content_invalid_json_text() {
        let response = serde_json::json!({
            "output": [
                {
                    "type": "message",
                    "content": [
                        {
                            "type": "output_text",
                            "text": "not valid json"
                        }
                    ]
                }
            ]
        });
        assert!(extract_responses_api_content(&response).is_err());
    }

    // ── Chat Completions response parsing ───────────────────────

    #[test]
    fn extract_chat_completions_content_valid() {
        let response = serde_json::json!({
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": "{\"category_0\": [{\"title\": \"Rewritten\", \"url\": \"https://example.com\", \"summary\": \"Rewritten summary\"}]}"
                },
                "finish_reason": "stop"
            }]
        });

        let result = extract_chat_completions_content(&response).unwrap();
        assert!(result["category_0"].is_array());
        assert_eq!(
            result["category_0"][0]["title"].as_str().unwrap(),
            "Rewritten"
        );
    }

    #[test]
    fn extract_chat_completions_content_missing_choices() {
        let response = serde_json::json!({});
        assert!(extract_chat_completions_content(&response).is_err());
    }

    #[test]
    fn extract_chat_completions_content_empty_choices() {
        let response = serde_json::json!({"choices": []});
        assert!(extract_chat_completions_content(&response).is_err());
    }

    #[test]
    fn extract_chat_completions_content_invalid_json() {
        let response = serde_json::json!({
            "choices": [{
                "message": {
                    "content": "this is not json"
                }
            }]
        });
        assert!(extract_chat_completions_content(&response).is_err());
    }

    #[test]
    fn extract_chat_completions_content_null_content() {
        let response = serde_json::json!({
            "choices": [{
                "message": { "role": "assistant", "content": null }
            }]
        });
        assert!(extract_chat_completions_content(&response).is_err());
    }

    #[test]
    fn extract_chat_completions_content_refusal() {
        // Structured outputs report refusals via `message.refusal`.
        let response = serde_json::json!({
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": null,
                    "refusal": "I can't help with that."
                },
                "finish_reason": "stop"
            }]
        });
        assert!(extract_chat_completions_content(&response).is_err());
    }

    // ── Error mapping tests ─────────────────────────────────────

    #[test]
    fn map_openai_error_invalid_key() {
        let body = serde_json::json!({
            "error": {
                "message": "Incorrect API key provided",
                "type": "invalid_request_error",
                "code": "invalid_api_key"
            }
        });

        let err = map_openai_error(401, &body);
        match err {
            AppError::BadRequest(msg) => assert!(msg.contains("unauthorized")),
            _ => panic!("Expected BadRequest for 401"),
        }
    }

    #[test]
    fn map_openai_error_rate_limited() {
        let body = serde_json::json!({
            "error": {
                "message": "Rate limit reached for model gpt-4o",
                "type": "rate_limit_error",
                "code": "rate_limit_exceeded"
            }
        });

        let err = map_openai_error(429, &body);
        match err {
            AppError::RateLimited(msg) => assert!(msg.contains("rate limit")),
            _ => panic!("Expected RateLimited for 429"),
        }
    }

    #[test]
    fn map_openai_error_bad_request() {
        let body = serde_json::json!({
            "error": {
                "message": "Invalid model specified",
                "type": "invalid_request_error"
            }
        });

        let err = map_openai_error(400, &body);
        match err {
            AppError::BadRequest(msg) => assert!(msg.contains("Invalid request")),
            _ => panic!("Expected BadRequest for 400"),
        }
    }

    #[test]
    fn map_openai_error_model_not_found() {
        let body = serde_json::json!({
            "error": {
                "message": "The model does not exist",
                "type": "invalid_request_error"
            }
        });

        let err = map_openai_error(404, &body);
        match err {
            AppError::BadRequest(msg) => assert!(msg.contains("not found")),
            _ => panic!("Expected BadRequest for 404"),
        }
    }

    #[test]
    fn map_openai_error_server_error() {
        let body = serde_json::json!({
            "error": {
                "message": "Internal server error",
                "type": "server_error"
            }
        });

        let err = map_openai_error(500, &body);
        match err {
            AppError::Internal(_) => {} // expected
            _ => panic!("Expected Internal for 500"),
        }
    }

    #[test]
    fn map_openai_error_forbidden() {
        let body = serde_json::json!({
            "error": {
                "message": "Access denied",
                "type": "forbidden"
            }
        });

        let err = map_openai_error(403, &body);
        match err {
            AppError::BadRequest(msg) => assert!(msg.contains("Access denied")),
            _ => panic!("Expected BadRequest for 403"),
        }
    }

    #[test]
    fn map_openai_error_unknown_body() {
        // Sometimes the body may lack the standard error structure
        let body = serde_json::json!({});
        let err = map_openai_error(502, &body);
        match err {
            AppError::Internal(_) => {} // expected
            _ => panic!("Expected Internal for 502"),
        }
    }

    #[test]
    fn map_openai_error_null_body() {
        // A non-object body (no error object at all) still maps by status.
        let err = map_openai_error(503, &Value::Null);
        match err {
            AppError::Internal(_) => {} // expected
            _ => panic!("Expected Internal for 503"),
        }

        match map_openai_error(429, &Value::Null) {
            AppError::RateLimited(_) => {} // expected
            _ => panic!("Expected RateLimited for 429"),
        }
    }
}
|