You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
397 lines
12 KiB
Rust
397 lines
12 KiB
Rust
//! Google Gemini LLM provider implementation.
|
|
//!
|
|
//! Implements the `LlmProvider` trait using the Gemini REST API.
|
|
//! Supports both web search grounding (Pass 1) and plain structured
|
|
//! output (Pass 2) via the `generateContent` endpoint.
|
|
|
|
use async_trait::async_trait;
|
|
use serde_json::Value;
|
|
|
|
use super::LlmProvider;
|
|
use crate::errors::AppError;
|
|
|
|
/// Google Gemini provider.
|
|
///
|
|
/// Holds the API key and an HTTP client for making requests
|
|
/// to the Gemini `generateContent` API.
|
|
pub struct GeminiProvider {
|
|
api_key: String,
|
|
http_client: reqwest::Client,
|
|
}
|
|
|
|
impl GeminiProvider {
|
|
/// Create a new Gemini provider with the given API key and HTTP client.
|
|
pub fn new(api_key: String, http_client: reqwest::Client) -> Self {
|
|
Self {
|
|
api_key,
|
|
http_client,
|
|
}
|
|
}
|
|
|
|
/// Build the Gemini API URL for a given model.
|
|
fn api_url(&self, model: &str) -> String {
|
|
format!(
|
|
"https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
|
|
model, self.api_key
|
|
)
|
|
}
|
|
|
|
/// Execute a generateContent request and parse the response.
|
|
async fn generate_content(&self, model: &str, body: &Value) -> Result<Value, AppError> {
|
|
let url = self.api_url(model);
|
|
|
|
let response = self
|
|
.http_client
|
|
.post(&url)
|
|
.json(body)
|
|
.send()
|
|
.await
|
|
.map_err(|e| {
|
|
// Log only the error kind, NOT the full Display (which includes the URL containing the API key)
|
|
let kind = if e.is_timeout() { "timeout" } else if e.is_connect() { "connection error" } else { "network error" };
|
|
tracing::error!("Gemini API request failed: {}", kind);
|
|
AppError::Internal(anyhow::anyhow!("Failed to connect to Gemini API"))
|
|
})?;
|
|
|
|
let status = response.status();
|
|
let response_body: Value = response.json().await.map_err(|e| {
|
|
tracing::error!("Failed to parse Gemini response: {}", e);
|
|
AppError::Internal(anyhow::anyhow!("Failed to parse Gemini API response"))
|
|
})?;
|
|
|
|
// Handle error responses
|
|
if !status.is_success() {
|
|
return Err(map_gemini_error(status.as_u16(), &response_body));
|
|
}
|
|
|
|
// Extract the text content from the Gemini response structure
|
|
extract_content(&response_body)
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl LlmProvider for GeminiProvider {
|
|
fn provider_id(&self) -> &str {
|
|
"gemini"
|
|
}
|
|
|
|
async fn generate_search_pass(
|
|
&self,
|
|
model: &str,
|
|
system_prompt: &str,
|
|
user_prompt: &str,
|
|
response_schema: &Value,
|
|
) -> Result<Value, AppError> {
|
|
let body = build_request_body(
|
|
system_prompt,
|
|
user_prompt,
|
|
response_schema,
|
|
true, // include googleSearch tool
|
|
);
|
|
|
|
self.generate_content(model, &body).await
|
|
}
|
|
|
|
async fn generate_rewrite_pass(
|
|
&self,
|
|
model: &str,
|
|
system_prompt: &str,
|
|
user_prompt: &str,
|
|
response_schema: &Value,
|
|
) -> Result<Value, AppError> {
|
|
let body = build_request_body(
|
|
system_prompt,
|
|
user_prompt,
|
|
response_schema,
|
|
false, // no tools for rewrite
|
|
);
|
|
|
|
self.generate_content(model, &body).await
|
|
}
|
|
|
|
fn supports_web_search(&self) -> bool {
|
|
true
|
|
}
|
|
}
|
|
|
|
/// Build the JSON request body for the Gemini `generateContent` endpoint.
|
|
///
|
|
/// When `include_search` is true, the `googleSearch` tool is included
|
|
/// to enable web search grounding (Pass 1).
|
|
fn build_request_body(
|
|
system_prompt: &str,
|
|
user_prompt: &str,
|
|
response_schema: &Value,
|
|
include_search: bool,
|
|
) -> Value {
|
|
let mut body = serde_json::json!({
|
|
"contents": [{
|
|
"role": "user",
|
|
"parts": [{
|
|
"text": user_prompt
|
|
}]
|
|
}],
|
|
"systemInstruction": {
|
|
"parts": [{
|
|
"text": system_prompt
|
|
}]
|
|
},
|
|
"generationConfig": {
|
|
"responseMimeType": "application/json",
|
|
"responseSchema": response_schema,
|
|
"maxOutputTokens": 16384
|
|
}
|
|
});
|
|
|
|
if include_search {
|
|
body["tools"] = serde_json::json!([{
|
|
"googleSearch": {}
|
|
}]);
|
|
}
|
|
|
|
body
|
|
}
|
|
|
|
/// Extract the text content from a Gemini API response.
|
|
///
|
|
/// The response structure is:
|
|
/// ```json
|
|
/// { "candidates": [{ "content": { "parts": [{ "text": "..." }] } }] }
|
|
/// ```
|
|
///
|
|
/// The text field contains a JSON string that we parse into a `Value`.
|
|
fn extract_content(response: &Value) -> Result<Value, AppError> {
|
|
let text = response
|
|
.get("candidates")
|
|
.and_then(|c| c.get(0))
|
|
.and_then(|c| c.get("content"))
|
|
.and_then(|c| c.get("parts"))
|
|
.and_then(|p| p.get(0))
|
|
.and_then(|p| p.get("text"))
|
|
.and_then(|t| t.as_str())
|
|
.ok_or_else(|| {
|
|
tracing::error!("Unexpected Gemini response structure: {:?}", response);
|
|
AppError::Internal(anyhow::anyhow!(
|
|
"Gemini API returned an unexpected response structure"
|
|
))
|
|
})?;
|
|
|
|
// The text content is a JSON string — parse it
|
|
serde_json::from_str(text).map_err(|e| {
|
|
tracing::error!("Failed to parse Gemini JSON output: {}", e);
|
|
AppError::Internal(anyhow::anyhow!(
|
|
"Gemini returned invalid JSON in structured output"
|
|
))
|
|
})
|
|
}
|
|
|
|
/// Map Gemini API error responses to appropriate `AppError` variants.
|
|
///
|
|
/// Handles common error codes without exposing internal details.
|
|
fn map_gemini_error(status: u16, body: &Value) -> AppError {
|
|
let error_message = body
|
|
.get("error")
|
|
.and_then(|e| e.get("message"))
|
|
.and_then(|m| m.as_str())
|
|
.unwrap_or("Unknown error");
|
|
|
|
let error_status = body
|
|
.get("error")
|
|
.and_then(|e| e.get("status"))
|
|
.and_then(|s| s.as_str())
|
|
.unwrap_or("");
|
|
|
|
tracing::error!(
|
|
"Gemini API error (HTTP {}): {} (status: {})",
|
|
status,
|
|
error_message,
|
|
error_status
|
|
);
|
|
|
|
match status {
|
|
400 => AppError::BadRequest("Invalid request to LLM provider".into()),
|
|
401 | 403 => AppError::BadRequest("Invalid or unauthorized API key".into()),
|
|
404 => AppError::BadRequest("Model not found or not available".into()),
|
|
429 => AppError::RateLimited("LLM provider rate limit exceeded. Please try again later.".into()),
|
|
_ => AppError::Internal(anyhow::anyhow!("LLM provider returned an error")),
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Wrap `text` in the minimal `candidates[0].content.parts[0]` envelope.
    fn single_part_response(text: &str) -> Value {
        serde_json::json!({
            "candidates": [{ "content": { "parts": [{ "text": text }] } }]
        })
    }

    /// Build a Gemini-style error body: `{ "error": { code, message, status } }`.
    fn error_body(code: u16, message: &str, status: &str) -> Value {
        serde_json::json!({
            "error": { "code": code, "message": message, "status": status }
        })
    }

    #[test]
    fn build_request_body_with_search() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "category_0": { "type": "array", "items": { "type": "object" } }
            }
        });

        let body = build_request_body("system prompt", "user prompt", &schema, true);

        let first_turn = &body["contents"][0];
        assert_eq!(first_turn["role"].as_str(), Some("user"));
        assert_eq!(first_turn["parts"][0]["text"].as_str(), Some("user prompt"));

        assert_eq!(
            body["systemInstruction"]["parts"][0]["text"].as_str(),
            Some("system prompt")
        );

        let tools = &body["tools"];
        assert!(tools.is_array());
        assert!(tools[0].get("googleSearch").is_some());

        let config = &body["generationConfig"];
        assert_eq!(config["responseMimeType"].as_str(), Some("application/json"));
        assert!(config["responseSchema"].is_object());
    }

    #[test]
    fn build_request_body_without_search() {
        let schema = serde_json::json!({"type": "object"});
        let body = build_request_body("sys", "user", &schema, false);

        // Search disabled => the `tools` key must be absent entirely.
        assert!(body.get("tools").is_none());
    }

    #[test]
    fn extract_content_valid_response() {
        let payload = r#"{"category_0": [{"title": "Test", "url": "https://example.com", "summary": "A test article"}]}"#;

        let result = extract_content(&single_part_response(payload)).unwrap();

        assert!(result["category_0"].is_array());
        assert_eq!(result["category_0"][0]["title"].as_str(), Some("Test"));
    }

    #[test]
    fn extract_content_missing_candidates() {
        assert!(extract_content(&serde_json::json!({})).is_err());
    }

    #[test]
    fn extract_content_empty_candidates() {
        assert!(extract_content(&serde_json::json!({ "candidates": [] })).is_err());
    }

    #[test]
    fn extract_content_invalid_json_text() {
        let response = single_part_response("this is not valid json");
        assert!(extract_content(&response).is_err());
    }

    #[test]
    fn map_gemini_error_invalid_key() {
        let body = error_body(403, "API key not valid", "PERMISSION_DENIED");

        if let AppError::BadRequest(msg) = map_gemini_error(403, &body) {
            assert!(msg.contains("unauthorized"));
        } else {
            panic!("Expected BadRequest for 403");
        }
    }

    #[test]
    fn map_gemini_error_rate_limited() {
        let body = error_body(429, "Resource exhausted", "RESOURCE_EXHAUSTED");

        if let AppError::RateLimited(msg) = map_gemini_error(429, &body) {
            assert!(msg.contains("rate limit"));
        } else {
            panic!("Expected RateLimited for 429");
        }
    }

    #[test]
    fn map_gemini_error_model_not_found() {
        let body = error_body(404, "Model not found", "NOT_FOUND");

        if let AppError::BadRequest(msg) = map_gemini_error(404, &body) {
            assert!(msg.contains("not found"));
        } else {
            panic!("Expected BadRequest for 404");
        }
    }

    #[test]
    fn map_gemini_error_server_error() {
        let body = error_body(500, "Internal error", "INTERNAL");

        let err = map_gemini_error(500, &body);
        assert!(
            matches!(err, AppError::Internal(_)),
            "Expected Internal for 500"
        );
    }

    #[test]
    fn gemini_provider_supports_web_search() {
        let provider = GeminiProvider::new("test-key".into(), reqwest::Client::new());

        assert!(provider.supports_web_search());
        assert_eq!(provider.provider_id(), "gemini");
    }
}
|