You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

397 lines
12 KiB
Rust

//! Google Gemini LLM provider implementation.
//!
//! Implements the `LlmProvider` trait using the Gemini REST API.
//! Supports both web search grounding (Pass 1) and plain structured
//! output (Pass 2) via the `generateContent` endpoint.
use async_trait::async_trait;
use serde_json::Value;
use super::LlmProvider;
use crate::errors::AppError;
/// Google Gemini provider.
///
/// Holds the API key and an HTTP client for making requests
/// to the Gemini `generateContent` API.
///
/// `Debug` is implemented by hand (rather than derived) so the API key is
/// never printed when the provider, or a struct containing it, is logged.
pub struct GeminiProvider {
    /// Gemini API key used to authenticate every request.
    api_key: String,
    /// HTTP client used to send the `generateContent` requests.
    http_client: reqwest::Client,
}

impl std::fmt::Debug for GeminiProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GeminiProvider")
            // Never expose the secret, even in debug output.
            .field("api_key", &"<redacted>")
            .field("http_client", &self.http_client)
            .finish()
    }
}
impl GeminiProvider {
/// Create a new Gemini provider with the given API key and HTTP client.
pub fn new(api_key: String, http_client: reqwest::Client) -> Self {
Self {
api_key,
http_client,
}
}
/// Build the Gemini API URL for a given model.
fn api_url(&self, model: &str) -> String {
format!(
"https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
model, self.api_key
)
}
/// Execute a generateContent request and parse the response.
async fn generate_content(&self, model: &str, body: &Value) -> Result<Value, AppError> {
let url = self.api_url(model);
let response = self
.http_client
.post(&url)
.json(body)
.send()
.await
.map_err(|e| {
// Log only the error kind, NOT the full Display (which includes the URL containing the API key)
let kind = if e.is_timeout() { "timeout" } else if e.is_connect() { "connection error" } else { "network error" };
tracing::error!("Gemini API request failed: {}", kind);
AppError::Internal(anyhow::anyhow!("Failed to connect to Gemini API"))
})?;
let status = response.status();
let response_body: Value = response.json().await.map_err(|e| {
tracing::error!("Failed to parse Gemini response: {}", e);
AppError::Internal(anyhow::anyhow!("Failed to parse Gemini API response"))
})?;
// Handle error responses
if !status.is_success() {
return Err(map_gemini_error(status.as_u16(), &response_body));
}
// Extract the text content from the Gemini response structure
extract_content(&response_body)
}
}
#[async_trait]
impl LlmProvider for GeminiProvider {
    /// Returns the identifier `"gemini"` for this provider.
    fn provider_id(&self) -> &str {
        "gemini"
    }

    /// Gemini exposes the built-in `googleSearch` tool, so search grounding is available.
    fn supports_web_search(&self) -> bool {
        true
    }

    /// Pass 1: structured generation with the `googleSearch` tool attached.
    async fn generate_search_pass(
        &self,
        model: &str,
        system_prompt: &str,
        user_prompt: &str,
        response_schema: &Value,
    ) -> Result<Value, AppError> {
        let attach_search_tool = true;
        let request = build_request_body(
            system_prompt,
            user_prompt,
            response_schema,
            attach_search_tool,
        );
        self.generate_content(model, &request).await
    }

    /// Pass 2: plain structured generation; no tools are attached.
    async fn generate_rewrite_pass(
        &self,
        model: &str,
        system_prompt: &str,
        user_prompt: &str,
        response_schema: &Value,
    ) -> Result<Value, AppError> {
        let attach_search_tool = false;
        let request = build_request_body(
            system_prompt,
            user_prompt,
            response_schema,
            attach_search_tool,
        );
        self.generate_content(model, &request).await
    }
}
/// Assemble the JSON payload sent to Gemini's `generateContent` endpoint.
///
/// The payload always carries a single user turn, the system instruction, and
/// a generation config requesting `application/json` output that conforms to
/// `response_schema` (capped at 16384 output tokens). When `include_search` is
/// set, a `tools` entry enabling `googleSearch` is added for grounded output
/// (Pass 1); otherwise no `tools` key is present.
fn build_request_body(
    system_prompt: &str,
    user_prompt: &str,
    response_schema: &Value,
    include_search: bool,
) -> Value {
    let user_turn = serde_json::json!({
        "role": "user",
        "parts": [{ "text": user_prompt }]
    });
    let system_instruction = serde_json::json!({
        "parts": [{ "text": system_prompt }]
    });
    let generation_config = serde_json::json!({
        "responseMimeType": "application/json",
        "responseSchema": response_schema,
        "maxOutputTokens": 16384
    });

    let mut request = serde_json::json!({
        "contents": [user_turn],
        "systemInstruction": system_instruction,
        "generationConfig": generation_config
    });

    if include_search {
        request["tools"] = serde_json::json!([{ "googleSearch": {} }]);
    }

    request
}
/// Extract the text content from a Gemini API response.
///
/// The response structure is:
/// ```json
/// { "candidates": [{ "content": { "parts": [{ "text": "..." }] } }] }
/// ```
///
/// The text field contains a JSON string that we parse into a `Value`.
fn extract_content(response: &Value) -> Result<Value, AppError> {
let text = response
.get("candidates")
.and_then(|c| c.get(0))
.and_then(|c| c.get("content"))
.and_then(|c| c.get("parts"))
.and_then(|p| p.get(0))
.and_then(|p| p.get("text"))
.and_then(|t| t.as_str())
.ok_or_else(|| {
tracing::error!("Unexpected Gemini response structure: {:?}", response);
AppError::Internal(anyhow::anyhow!(
"Gemini API returned an unexpected response structure"
))
})?;
// The text content is a JSON string — parse it
serde_json::from_str(text).map_err(|e| {
tracing::error!("Failed to parse Gemini JSON output: {}", e);
AppError::Internal(anyhow::anyhow!(
"Gemini returned invalid JSON in structured output"
))
})
}
/// Map Gemini API error responses to appropriate `AppError` variants.
///
/// Handles common error codes without exposing internal details.
fn map_gemini_error(status: u16, body: &Value) -> AppError {
let error_message = body
.get("error")
.and_then(|e| e.get("message"))
.and_then(|m| m.as_str())
.unwrap_or("Unknown error");
let error_status = body
.get("error")
.and_then(|e| e.get("status"))
.and_then(|s| s.as_str())
.unwrap_or("");
tracing::error!(
"Gemini API error (HTTP {}): {} (status: {})",
status,
error_message,
error_status
);
match status {
400 => AppError::BadRequest("Invalid request to LLM provider".into()),
401 | 403 => AppError::BadRequest("Invalid or unauthorized API key".into()),
404 => AppError::BadRequest("Model not found or not available".into()),
429 => AppError::RateLimited("LLM provider rate limit exceeded. Please try again later.".into()),
_ => AppError::Internal(anyhow::anyhow!("LLM provider returned an error")),
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a Gemini-shaped `{"error": {...}}` envelope for the error-mapping tests.
    fn error_envelope(code: u16, message: &str, status: &str) -> Value {
        serde_json::json!({
            "error": { "code": code, "message": message, "status": status }
        })
    }

    #[test]
    fn build_request_body_with_search() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "category_0": { "type": "array", "items": { "type": "object" } }
            }
        });
        let body = build_request_body("system prompt", "user prompt", &schema, true);

        // Single user turn carrying the user prompt.
        let user_turn = &body["contents"][0];
        assert_eq!(user_turn["role"].as_str(), Some("user"));
        assert_eq!(user_turn["parts"][0]["text"].as_str(), Some("user prompt"));

        // System instruction carries the system prompt.
        assert_eq!(
            body["systemInstruction"]["parts"][0]["text"].as_str(),
            Some("system prompt")
        );

        // Search grounding requested.
        let tools = &body["tools"];
        assert!(tools.is_array());
        assert!(tools[0].get("googleSearch").is_some());

        // Structured JSON output configured.
        let config = &body["generationConfig"];
        assert_eq!(config["responseMimeType"].as_str(), Some("application/json"));
        assert!(config["responseSchema"].is_object());
    }

    #[test]
    fn build_request_body_without_search() {
        let body = build_request_body("sys", "user", &serde_json::json!({"type": "object"}), false);
        // Pass 2 must not attach any tools at all.
        assert!(body.get("tools").is_none());
    }

    #[test]
    fn extract_content_valid_response() {
        let payload = serde_json::json!({
            "category_0": [{
                "title": "Test",
                "url": "https://example.com",
                "summary": "A test article"
            }]
        });
        let response = serde_json::json!({
            "candidates": [{ "content": { "parts": [{ "text": payload.to_string() }] } }]
        });

        let parsed = extract_content(&response).unwrap();
        assert!(parsed["category_0"].is_array());
        assert_eq!(parsed["category_0"][0]["title"].as_str(), Some("Test"));
    }

    #[test]
    fn extract_content_missing_candidates() {
        assert!(extract_content(&serde_json::json!({})).is_err());
    }

    #[test]
    fn extract_content_empty_candidates() {
        assert!(extract_content(&serde_json::json!({ "candidates": [] })).is_err());
    }

    #[test]
    fn extract_content_invalid_json_text() {
        let response = serde_json::json!({
            "candidates": [{ "content": { "parts": [{ "text": "this is not valid json" }] } }]
        });
        assert!(extract_content(&response).is_err());
    }

    #[test]
    fn map_gemini_error_invalid_key() {
        let body = error_envelope(403, "API key not valid", "PERMISSION_DENIED");
        let err = map_gemini_error(403, &body);
        assert!(
            matches!(&err, AppError::BadRequest(msg) if msg.contains("unauthorized")),
            "Expected BadRequest for 403"
        );
    }

    #[test]
    fn map_gemini_error_rate_limited() {
        let body = error_envelope(429, "Resource exhausted", "RESOURCE_EXHAUSTED");
        let err = map_gemini_error(429, &body);
        assert!(
            matches!(&err, AppError::RateLimited(msg) if msg.contains("rate limit")),
            "Expected RateLimited for 429"
        );
    }

    #[test]
    fn map_gemini_error_model_not_found() {
        let body = error_envelope(404, "Model not found", "NOT_FOUND");
        let err = map_gemini_error(404, &body);
        assert!(
            matches!(&err, AppError::BadRequest(msg) if msg.contains("not found")),
            "Expected BadRequest for 404"
        );
    }

    #[test]
    fn map_gemini_error_server_error() {
        let body = error_envelope(500, "Internal error", "INTERNAL");
        let err = map_gemini_error(500, &body);
        assert!(
            matches!(err, AppError::Internal(_)),
            "Expected Internal for 500"
        );
    }

    #[test]
    fn gemini_provider_supports_web_search() {
        let provider = GeminiProvider::new(String::from("test-key"), reqwest::Client::new());
        assert_eq!(provider.provider_id(), "gemini");
        assert!(provider.supports_web_search());
    }
}