You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

363 lines
11 KiB
Rust

//! Admin provider model and request/response types.
//!
//! Represents the admin-curated catalog of LLM providers and their models.
//! Users select from this catalog when configuring their settings.
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// One entry in a provider's model catalog.
///
/// Lists of these are stored as JSONB. When `is_default` is absent from the
/// JSON it deserializes as `false`.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct ProviderModel {
    /// Identifier of the model, e.g. `"gpt-4o"`.
    pub model_id: String,
    /// Human-readable label for the model.
    pub display_name: String,
    /// Marks this model as the list's default; omitted in JSON means `false`.
    #[serde(default)]
    pub is_default: bool,
}
/// A provider row as loaded from the database.
///
/// Only `Serialize` is derived; incoming payloads use [`CreateProviderRequest`]
/// and [`UpdateProviderRequest`] instead.
#[derive(Clone, Debug, Serialize)]
pub struct AdminProvider {
    /// Primary key.
    pub id: Uuid,
    /// Provider key, e.g. `"gemini"`.
    pub provider_name: String,
    /// Human-readable provider label.
    pub display_name: String,
    /// Models offered for scraping.
    pub models_scraping: Vec<ProviderModel>,
    /// Models offered for web search.
    pub models_websearch: Vec<ProviderModel>,
    /// Whether the provider is enabled.
    pub is_enabled: bool,
    /// Creation timestamp (UTC).
    pub created_at: DateTime<Utc>,
    /// Last-modification timestamp (UTC).
    pub updated_at: DateTime<Utc>,
}
/// Request body for `POST /api/v1/admin/providers` (create or update).
///
/// `is_enabled` may be omitted, in which case it defaults to `true`.
#[derive(Deserialize, Debug)]
pub struct CreateProviderRequest {
    /// Provider key; validated against [`VALID_PROVIDERS`].
    pub provider_name: String,
    /// Human-readable provider label.
    pub display_name: String,
    /// Non-empty list of models offered for scraping.
    pub models_scraping: Vec<ProviderModel>,
    /// Non-empty list of models offered for web search.
    pub models_websearch: Vec<ProviderModel>,
    /// Whether the provider is enabled; `true` when omitted.
    #[serde(default = "default_true")]
    pub is_enabled: bool,
}
/// Serde default hook that yields `true`.
///
/// Referenced by name from `#[serde(default = "default_true")]`, so the
/// function name must stay as-is.
fn default_true() -> bool {
    true
}
/// Provider names accepted in the admin catalog.
///
/// `models::api_key` also consults this list when validating user API key requests.
pub const VALID_PROVIDERS: &[&str] = &["gemini", "openai", "anthropic"];

/// Provider names a user may store an API key for.
///
/// Contains every entry of [`VALID_PROVIDERS`] plus non-LLM services
/// (currently `"brave_search"`).
pub const VALID_API_KEY_PROVIDERS: &[&str] = &[
    "gemini",
    "openai",
    "anthropic",
    "brave_search",
];
impl CreateProviderRequest {
    /// Checks every field of a create request.
    ///
    /// The provider name is trimmed, then checked for emptiness, length and
    /// membership in [`VALID_PROVIDERS`]. After that the display name, the
    /// scraping models and the web-search models are validated, in that order.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message for the first rule that fails.
    pub fn validate(&self) -> Result<(), String> {
        let provider = self.provider_name.trim();

        match provider {
            "" => return Err("Provider name cannot be empty".into()),
            p if p.len() > 50 => {
                return Err("Provider name must be at most 50 characters".into());
            }
            p if !VALID_PROVIDERS.contains(&p) => {
                return Err(format!(
                    "Invalid provider name '{}'. Must be one of: {}",
                    p,
                    VALID_PROVIDERS.join(", ")
                ));
            }
            _ => {}
        }

        validate_display_name(&self.display_name)?;
        [&self.models_scraping, &self.models_websearch]
            .iter()
            .try_for_each(|models| validate_models(models))
    }
}
/// Request body for updating an existing provider.
///
/// Every field is optional. `None` fields are skipped by
/// [`UpdateProviderRequest::validate`] and are presumably left unchanged by the
/// update handler (not visible here — confirm).
#[derive(Deserialize, Debug)]
pub struct UpdateProviderRequest {
    /// New human-readable provider label.
    pub display_name: Option<String>,
    /// Replacement list of scraping models.
    pub models_scraping: Option<Vec<ProviderModel>>,
    /// Replacement list of web-search models.
    pub models_websearch: Option<Vec<ProviderModel>>,
    /// New enabled flag.
    pub is_enabled: Option<bool>,
}
impl UpdateProviderRequest {
/// Validate the provider update request.
pub fn validate(&self) -> Result<(), String> {
if let Some(ref display) = self.display_name {
validate_display_name(display)?;
}
if let Some(ref models) = self.models_scraping {
validate_models(models)?;
}
if let Some(ref models) = self.models_websearch {
validate_models(models)?;
}
Ok(())
}
}
/// Validate a provider display name.
///
/// The name must contain at least one non-whitespace character and be at most
/// 100 characters long. Length is counted in `char`s (Unicode scalar values),
/// not UTF-8 bytes. This matches the "characters" wording of the error, so
/// non-ASCII names are not rejected merely because they encode to more bytes.
///
/// NOTE(review): if the backing storage limits bytes rather than characters,
/// the byte-based `len()` check would be the right one — confirm against the schema.
///
/// # Errors
///
/// Returns a human-readable message describing the first rule that fails.
fn validate_display_name(display_name: &str) -> Result<(), String> {
    if display_name.trim().is_empty() {
        return Err("Display name cannot be empty".into());
    }
    if display_name.chars().count() > 100 {
        return Err("Display name must be at most 100 characters".into());
    }
    Ok(())
}
/// Validate a list of provider models.
///
/// Rules, checked in this order:
/// - the list must not be empty;
/// - at most one model may have `is_default` set (having none is allowed);
/// - for each model: `model_id` must be non-blank and at most 100 characters,
///   `display_name` must be non-blank and at most 200 characters, and `model_id`
///   must not repeat an earlier entry's `model_id` (exact, case-sensitive match).
///   A repeated ID would make selecting a model by ID ambiguous.
///
/// Lengths are counted in `char`s (Unicode scalar values), not UTF-8 bytes, to
/// match the "characters" wording of the error messages.
///
/// # Errors
///
/// Returns a human-readable message for the first rule that fails.
fn validate_models(models: &[ProviderModel]) -> Result<(), String> {
    if models.is_empty() {
        return Err("At least one model must be provided".into());
    }
    if models.iter().filter(|m| m.is_default).count() > 1 {
        return Err("At most one model can be marked as default".into());
    }

    let mut seen_ids = std::collections::HashSet::with_capacity(models.len());
    for model in models {
        if model.model_id.trim().is_empty() {
            return Err("Model ID cannot be empty".into());
        }
        if model.model_id.chars().count() > 100 {
            return Err("Model ID must be at most 100 characters".into());
        }
        if model.display_name.trim().is_empty() {
            return Err("Model display name cannot be empty".into());
        }
        if model.display_name.chars().count() > 200 {
            return Err("Model display name must be at most 200 characters".into());
        }
        // `insert` returns false when the ID was already present.
        if !seen_ids.insert(model.model_id.as_str()) {
            return Err(format!("Duplicate model ID '{}'", model.model_id));
        }
    }
    Ok(())
}
/// Provider entry exposed to regular users, without admin-only data.
///
/// Serialized for authenticated (non-admin) users by
/// `GET /api/v1/config/providers`, which returns enabled providers only.
#[derive(Serialize, Debug)]
pub struct ProviderConfigResponse {
    /// Provider key, e.g. `"gemini"`.
    pub provider_name: String,
    /// Human-readable provider label.
    pub display_name: String,
    /// Models offered for scraping.
    pub models_scraping: Vec<PublicModelInfo>,
    /// Models offered for web search.
    pub models_websearch: Vec<PublicModelInfo>,
}
/// Public view of a [`ProviderModel`].
///
/// It currently carries the same fields as [`ProviderModel`] but derives only
/// `Serialize`.
#[derive(Serialize, Debug)]
pub struct PublicModelInfo {
    /// Identifier of the model.
    pub model_id: String,
    /// Human-readable label for the model.
    pub display_name: String,
    /// Whether this model is marked as the default.
    pub is_default: bool,
}
impl From<ProviderModel> for PublicModelInfo {
fn from(m: ProviderModel) -> Self {
Self {
model_id: m.model_id,
display_name: m.display_name,
is_default: m.is_default,
}
}
}
/// Complete provider record returned to admins.
///
/// It carries every field of [`AdminProvider`].
#[derive(Serialize, Debug)]
pub struct AdminProviderResponse {
    /// Primary key.
    pub id: Uuid,
    /// Provider key, e.g. `"gemini"`.
    pub provider_name: String,
    /// Human-readable provider label.
    pub display_name: String,
    /// Models offered for scraping.
    pub models_scraping: Vec<ProviderModel>,
    /// Models offered for web search.
    pub models_websearch: Vec<ProviderModel>,
    /// Whether the provider is enabled.
    pub is_enabled: bool,
    /// Creation timestamp (UTC).
    pub created_at: DateTime<Utc>,
    /// Last-modification timestamp (UTC).
    pub updated_at: DateTime<Utc>,
}
impl From<AdminProvider> for AdminProviderResponse {
fn from(p: AdminProvider) -> Self {
Self {
id: p.id,
provider_name: p.provider_name,
display_name: p.display_name,
models_scraping: p.models_scraping,
models_websearch: p.models_websearch,
is_enabled: p.is_enabled,
created_at: p.created_at,
updated_at: p.updated_at,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ProviderModel` from its three fields.
    fn model(model_id: &str, display_name: &str, is_default: bool) -> ProviderModel {
        ProviderModel {
            model_id: model_id.into(),
            display_name: display_name.into(),
            is_default,
        }
    }

    /// Helper to create a sample model list for tests: one model, marked default.
    fn sample_models() -> Vec<ProviderModel> {
        vec![model("m1", "Model 1", true)]
    }

    /// Builds a create request that passes validation, so each test can break
    /// exactly one field.
    fn valid_request(provider_name: &str, display_name: &str) -> CreateProviderRequest {
        CreateProviderRequest {
            provider_name: provider_name.into(),
            display_name: display_name.into(),
            models_scraping: sample_models(),
            models_websearch: sample_models(),
            is_enabled: true,
        }
    }

    /// Builds an update request with every field left unset.
    fn empty_update() -> UpdateProviderRequest {
        UpdateProviderRequest {
            display_name: None,
            models_scraping: None,
            models_websearch: None,
            is_enabled: None,
        }
    }

    #[test]
    fn test_valid_create_request() {
        let req = CreateProviderRequest {
            models_scraping: vec![model("gemini-2.5-pro", "Gemini 2.5 Pro", true)],
            models_websearch: vec![model("gemini-2.5-pro", "Gemini 2.5 Pro", true)],
            ..valid_request("gemini", "Google Gemini")
        };
        assert!(req.validate().is_ok());
    }

    #[test]
    fn test_invalid_provider_name() {
        let err = valid_request("unknown_provider", "Unknown")
            .validate()
            .unwrap_err();
        assert!(err.contains("Invalid provider name"));
    }

    #[test]
    fn test_empty_provider_name() {
        let err = valid_request(" ", "Some Provider").validate().unwrap_err();
        assert!(err.contains("cannot be empty"));
    }

    #[test]
    fn test_provider_name_surrounding_whitespace_is_ignored() {
        assert!(valid_request("  gemini  ", "Google Gemini").validate().is_ok());
    }

    #[test]
    fn test_overlong_provider_name_rejected() {
        let err = valid_request(&"x".repeat(51), "Too Long")
            .validate()
            .unwrap_err();
        assert!(err.contains("at most 50 characters"));
    }

    #[test]
    fn test_overlong_display_name_rejected() {
        let err = valid_request("gemini", &"a".repeat(101))
            .validate()
            .unwrap_err();
        assert!(err.contains("Display name must be at most 100 characters"));
    }

    #[test]
    fn test_empty_models_scraping_list() {
        let mut req = valid_request("openai", "OpenAI");
        req.models_scraping.clear();
        let err = req.validate().unwrap_err();
        assert!(err.contains("At least one model"));
    }

    #[test]
    fn test_empty_models_websearch_list() {
        let mut req = valid_request("openai", "OpenAI");
        req.models_websearch.clear();
        let err = req.validate().unwrap_err();
        assert!(err.contains("At least one model"));
    }

    #[test]
    fn test_multiple_defaults_rejected() {
        let mut req = valid_request("openai", "OpenAI");
        req.models_scraping = vec![
            model("gpt-4o", "GPT-4o", true),
            model("gpt-4o-mini", "GPT-4o Mini", true),
        ];
        let err = req.validate().unwrap_err();
        assert!(err.contains("At most one model"));
    }

    #[test]
    fn test_no_default_model_is_allowed() {
        let mut req = valid_request("openai", "OpenAI");
        req.models_scraping = vec![model("gpt-4o", "GPT-4o", false)];
        assert!(req.validate().is_ok());
    }

    #[test]
    fn test_empty_model_id_rejected() {
        let mut req = valid_request("anthropic", "Anthropic");
        req.models_scraping = vec![model("", "Claude", false)];
        let err = req.validate().unwrap_err();
        assert!(err.contains("Model ID cannot be empty"));
    }

    #[test]
    fn test_blank_model_display_name_rejected() {
        let mut req = valid_request("anthropic", "Anthropic");
        req.models_websearch = vec![model("claude", "   ", false)];
        let err = req.validate().unwrap_err();
        assert!(err.contains("Model display name cannot be empty"));
    }

    #[test]
    fn test_update_request_all_none() {
        assert!(empty_update().validate().is_ok());
    }

    #[test]
    fn test_update_request_only_is_enabled() {
        let req = UpdateProviderRequest {
            is_enabled: Some(false),
            ..empty_update()
        };
        assert!(req.validate().is_ok());
    }

    #[test]
    fn test_update_request_empty_display_name() {
        let req = UpdateProviderRequest {
            display_name: Some("".into()),
            ..empty_update()
        };
        let err = req.validate().unwrap_err();
        assert!(err.contains("Display name cannot be empty"));
    }

    #[test]
    fn test_update_request_empty_model_list_rejected() {
        let req = UpdateProviderRequest {
            models_websearch: Some(vec![]),
            ..empty_update()
        };
        let err = req.validate().unwrap_err();
        assert!(err.contains("At least one model"));
    }

    #[test]
    fn test_provider_model_deserialization() {
        let json = r#"{"model_id": "gpt-4o", "display_name": "GPT-4o"}"#;
        let model: ProviderModel = serde_json::from_str(json).unwrap();
        assert_eq!(model.model_id, "gpt-4o");
        assert!(!model.is_default); // default is false
    }

    #[test]
    fn test_public_model_info_from_provider_model() {
        let info = PublicModelInfo::from(model("gpt-4o", "GPT-4o", true));
        assert_eq!(info.model_id, "gpt-4o");
        assert_eq!(info.display_name, "GPT-4o");
        assert!(info.is_default);
    }
}