//! Google Gemini LLM provider implementation. //! //! Implements the `LlmProvider` trait using the Gemini REST API. //! Supports both web search grounding (Pass 1) and plain structured //! output (Pass 2) via the `generateContent` endpoint. use async_trait::async_trait; use serde_json::Value; use super::LlmProvider; use crate::errors::AppError; /// Google Gemini provider. /// /// Holds the API key and an HTTP client for making requests /// to the Gemini `generateContent` API. pub struct GeminiProvider { api_key: String, http_client: reqwest::Client, } impl GeminiProvider { /// Create a new Gemini provider with the given API key and HTTP client. pub fn new(api_key: String, http_client: reqwest::Client) -> Self { Self { api_key, http_client, } } /// Build the Gemini API URL for a given model. fn api_url(&self, model: &str) -> String { format!( "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}", model, self.api_key ) } /// Execute a generateContent request and parse the response. async fn generate_content(&self, model: &str, body: &Value) -> Result { let url = self.api_url(model); let response = self .http_client .post(&url) .json(body) .send() .await .map_err(|e| { // Log only the error kind, NOT the full Display (which includes the URL containing the API key) let kind = if e.is_timeout() { "timeout" } else if e.is_connect() { "connection error" } else { "network error" }; tracing::error!("Gemini API request failed: {}", kind); AppError::Internal(anyhow::anyhow!("Failed to connect to Gemini API")) })?; let status = response.status(); let response_body: Value = response.json().await.map_err(|e| { tracing::error!("Failed to parse Gemini response: {}", e); AppError::Internal(anyhow::anyhow!("Failed to parse Gemini API response")) })?; // Handle error responses if !status.is_success() { return Err(map_gemini_error(status.as_u16(), &response_body)); } // Extract the text content from the Gemini response structure extract_content(&response_body) } } #[async_trait] impl LlmProvider for GeminiProvider { fn provider_id(&self) -> &str { "gemini" } async fn generate_search_pass( &self, model: &str, system_prompt: &str, user_prompt: &str, response_schema: &Value, ) -> Result { let body = build_request_body( system_prompt, user_prompt, response_schema, true, // include googleSearch tool ); self.generate_content(model, &body).await } async fn generate_rewrite_pass( &self, model: &str, system_prompt: &str, user_prompt: &str, response_schema: &Value, ) -> Result { let body = build_request_body( system_prompt, user_prompt, response_schema, false, // no tools for rewrite ); self.generate_content(model, &body).await } fn supports_web_search(&self) -> bool { true } } /// Build the JSON request body for the Gemini `generateContent` endpoint. /// /// When `include_search` is true, the `googleSearch` tool is included /// to enable web search grounding (Pass 1). fn build_request_body( system_prompt: &str, user_prompt: &str, response_schema: &Value, include_search: bool, ) -> Value { let mut body = serde_json::json!({ "contents": [{ "role": "user", "parts": [{ "text": user_prompt }] }], "systemInstruction": { "parts": [{ "text": system_prompt }] }, "generationConfig": { "responseMimeType": "application/json", "responseSchema": response_schema, "maxOutputTokens": 16384 } }); if include_search { body["tools"] = serde_json::json!([{ "googleSearch": {} }]); } body } /// Extract the text content from a Gemini API response. /// /// The response structure is: /// ```json /// { "candidates": [{ "content": { "parts": [{ "text": "..." }] } }] } /// ``` /// /// The text field contains a JSON string that we parse into a `Value`. fn extract_content(response: &Value) -> Result { let text = response .get("candidates") .and_then(|c| c.get(0)) .and_then(|c| c.get("content")) .and_then(|c| c.get("parts")) .and_then(|p| p.get(0)) .and_then(|p| p.get("text")) .and_then(|t| t.as_str()) .ok_or_else(|| { tracing::error!("Unexpected Gemini response structure: {:?}", response); AppError::Internal(anyhow::anyhow!( "Gemini API returned an unexpected response structure" )) })?; // The text content is a JSON string — parse it serde_json::from_str(text).map_err(|e| { tracing::error!("Failed to parse Gemini JSON output: {}", e); AppError::Internal(anyhow::anyhow!( "Gemini returned invalid JSON in structured output" )) }) } /// Map Gemini API error responses to appropriate `AppError` variants. /// /// Handles common error codes without exposing internal details. fn map_gemini_error(status: u16, body: &Value) -> AppError { let error_message = body .get("error") .and_then(|e| e.get("message")) .and_then(|m| m.as_str()) .unwrap_or("Unknown error"); let error_status = body .get("error") .and_then(|e| e.get("status")) .and_then(|s| s.as_str()) .unwrap_or(""); tracing::error!( "Gemini API error (HTTP {}): {} (status: {})", status, error_message, error_status ); match status { 400 => AppError::BadRequest("Invalid request to LLM provider".into()), 401 | 403 => AppError::BadRequest("Invalid or unauthorized API key".into()), 404 => AppError::BadRequest("Model not found or not available".into()), 429 => AppError::RateLimited("LLM provider rate limit exceeded. Please try again later.".into()), _ => AppError::Internal(anyhow::anyhow!("LLM provider returned an error")), } } #[cfg(test)] mod tests { use super::*; #[test] fn build_request_body_with_search() { let schema = serde_json::json!({ "type": "object", "properties": { "category_0": { "type": "array", "items": { "type": "object" } } } }); let body = build_request_body("system prompt", "user prompt", &schema, true); // Verify contents assert_eq!( body["contents"][0]["role"].as_str().unwrap(), "user" ); assert_eq!( body["contents"][0]["parts"][0]["text"].as_str().unwrap(), "user prompt" ); // Verify system instruction assert_eq!( body["systemInstruction"]["parts"][0]["text"].as_str().unwrap(), "system prompt" ); // Verify tools (googleSearch present) assert!(body["tools"].is_array()); assert!(body["tools"][0].get("googleSearch").is_some()); // Verify generation config assert_eq!( body["generationConfig"]["responseMimeType"].as_str().unwrap(), "application/json" ); assert!(body["generationConfig"]["responseSchema"].is_object()); } #[test] fn build_request_body_without_search() { let schema = serde_json::json!({"type": "object"}); let body = build_request_body("sys", "user", &schema, false); // No tools key when search is disabled assert!(body.get("tools").is_none()); } #[test] fn extract_content_valid_response() { let response = serde_json::json!({ "candidates": [{ "content": { "parts": [{ "text": "{\"category_0\": [{\"title\": \"Test\", \"url\": \"https://example.com\", \"summary\": \"A test article\"}]}" }] } }] }); let result = extract_content(&response).unwrap(); assert!(result["category_0"].is_array()); assert_eq!(result["category_0"][0]["title"].as_str().unwrap(), "Test"); } #[test] fn extract_content_missing_candidates() { let response = serde_json::json!({}); assert!(extract_content(&response).is_err()); } #[test] fn extract_content_empty_candidates() { let response = serde_json::json!({"candidates": []}); assert!(extract_content(&response).is_err()); } #[test] fn extract_content_invalid_json_text() { let response = serde_json::json!({ "candidates": [{ "content": { "parts": [{ "text": "this is not valid json" }] } }] }); assert!(extract_content(&response).is_err()); } #[test] fn map_gemini_error_invalid_key() { let body = serde_json::json!({ "error": { "code": 403, "message": "API key not valid", "status": "PERMISSION_DENIED" } }); let err = map_gemini_error(403, &body); match err { AppError::BadRequest(msg) => assert!(msg.contains("unauthorized")), _ => panic!("Expected BadRequest for 403"), } } #[test] fn map_gemini_error_rate_limited() { let body = serde_json::json!({ "error": { "code": 429, "message": "Resource exhausted", "status": "RESOURCE_EXHAUSTED" } }); let err = map_gemini_error(429, &body); match err { AppError::RateLimited(msg) => assert!(msg.contains("rate limit")), _ => panic!("Expected RateLimited for 429"), } } #[test] fn map_gemini_error_model_not_found() { let body = serde_json::json!({ "error": { "code": 404, "message": "Model not found", "status": "NOT_FOUND" } }); let err = map_gemini_error(404, &body); match err { AppError::BadRequest(msg) => assert!(msg.contains("not found")), _ => panic!("Expected BadRequest for 404"), } } #[test] fn map_gemini_error_server_error() { let body = serde_json::json!({ "error": { "code": 500, "message": "Internal error", "status": "INTERNAL" } }); let err = map_gemini_error(500, &body); match err { AppError::Internal(_) => {} // expected _ => panic!("Expected Internal for 500"), } } #[test] fn gemini_provider_supports_web_search() { let provider = GeminiProvider::new( "test-key".into(), reqwest::Client::new(), ); assert!(provider.supports_web_search()); assert_eq!(provider.provider_id(), "gemini"); } }