You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
363 lines
11 KiB
Rust
363 lines
11 KiB
Rust
//! Admin provider model and request/response types.
|
|
//!
|
|
//! Represents the admin-curated catalog of LLM providers and their models.
|
|
//! Users select from this catalog when configuring their settings.
|
|
|
|
use chrono::{DateTime, Utc};
|
|
use serde::{Deserialize, Serialize};
|
|
use uuid::Uuid;
|
|
|
|
/// A single model within a provider's model list (stored as JSONB).
|
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
|
pub struct ProviderModel {
|
|
pub model_id: String,
|
|
pub display_name: String,
|
|
#[serde(default)]
|
|
pub is_default: bool,
|
|
}
|
|
|
|
/// An admin provider record from the database.
|
|
#[derive(Debug, Clone, Serialize)]
|
|
pub struct AdminProvider {
|
|
pub id: Uuid,
|
|
pub provider_name: String,
|
|
pub display_name: String,
|
|
pub models_scraping: Vec<ProviderModel>,
|
|
pub models_websearch: Vec<ProviderModel>,
|
|
pub is_enabled: bool,
|
|
pub created_at: DateTime<Utc>,
|
|
pub updated_at: DateTime<Utc>,
|
|
}
|
|
|
|
/// Request body for `POST /api/v1/admin/providers` (create or update).
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct CreateProviderRequest {
|
|
pub provider_name: String,
|
|
pub display_name: String,
|
|
pub models_scraping: Vec<ProviderModel>,
|
|
pub models_websearch: Vec<ProviderModel>,
|
|
#[serde(default = "default_true")]
|
|
pub is_enabled: bool,
|
|
}
|
|
|
|
/// Serde default for boolean fields that should be `true` when omitted
/// (used by `CreateProviderRequest::is_enabled`).
fn default_true() -> bool {
    true
}
|
|
|
|
/// LLM provider keys accepted by the admin provider catalog.
///
/// [`CreateProviderRequest::validate`] rejects any trimmed `provider_name` that
/// is not listed here. Also used by `models::api_key` for validating user API
/// key requests.
///
/// NOTE(review): `VALID_API_KEY_PROVIDERS` is the broader list for API key
/// storage — confirm which of the two `models::api_key` actually checks.
pub const VALID_PROVIDERS: &[&str] = &[
    "gemini",    // Google Gemini
    "openai",    // OpenAI
    "anthropic", // Anthropic
];
|
|
|
|
/// Provider keys accepted for user API key storage.
///
/// Contains the LLM providers plus non-LLM services.
pub const VALID_API_KEY_PROVIDERS: &[&str] = &[
    // LLM providers.
    "gemini",
    "openai",
    "anthropic",
    // Non-LLM services.
    "brave_search",
];
|
|
|
|
impl CreateProviderRequest {
|
|
/// Validate the provider creation request.
|
|
///
|
|
/// Returns `Ok(())` if all fields are valid, or `Err(message)`
|
|
/// describing the first validation failure.
|
|
pub fn validate(&self) -> Result<(), String> {
|
|
let name = self.provider_name.trim();
|
|
if name.is_empty() {
|
|
return Err("Provider name cannot be empty".into());
|
|
}
|
|
if name.len() > 50 {
|
|
return Err("Provider name must be at most 50 characters".into());
|
|
}
|
|
if !VALID_PROVIDERS.contains(&name) {
|
|
return Err(format!(
|
|
"Invalid provider name '{}'. Must be one of: {}",
|
|
name,
|
|
VALID_PROVIDERS.join(", ")
|
|
));
|
|
}
|
|
|
|
validate_display_name(&self.display_name)?;
|
|
validate_models(&self.models_scraping)?;
|
|
validate_models(&self.models_websearch)?;
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
/// Request body for updating an existing provider.
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct UpdateProviderRequest {
|
|
pub display_name: Option<String>,
|
|
pub models_scraping: Option<Vec<ProviderModel>>,
|
|
pub models_websearch: Option<Vec<ProviderModel>>,
|
|
pub is_enabled: Option<bool>,
|
|
}
|
|
|
|
impl UpdateProviderRequest {
|
|
/// Validate the provider update request.
|
|
pub fn validate(&self) -> Result<(), String> {
|
|
if let Some(ref display) = self.display_name {
|
|
validate_display_name(display)?;
|
|
}
|
|
if let Some(ref models) = self.models_scraping {
|
|
validate_models(models)?;
|
|
}
|
|
if let Some(ref models) = self.models_websearch {
|
|
validate_models(models)?;
|
|
}
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
/// Validate a provider display name.
///
/// The name must contain at least one non-whitespace character and be at most
/// 100 characters long. Length is counted in Unicode scalar values (`char`s),
/// including any surrounding whitespace.
///
/// # Errors
///
/// Returns a human-readable message for the first rule the name violates.
fn validate_display_name(display_name: &str) -> Result<(), String> {
    if display_name.trim().is_empty() {
        return Err("Display name cannot be empty".into());
    }
    // `str::len` counts UTF-8 bytes. Count chars instead, so a 60-character name
    // in a non-Latin script is not rejected as "longer than 100 characters".
    if display_name.chars().count() > 100 {
        return Err("Display name must be at most 100 characters".into());
    }
    Ok(())
}
|
|
|
|
/// Validate a list of provider models.
|
|
fn validate_models(models: &[ProviderModel]) -> Result<(), String> {
|
|
if models.is_empty() {
|
|
return Err("At least one model must be provided".into());
|
|
}
|
|
|
|
let default_count = models.iter().filter(|m| m.is_default).count();
|
|
if default_count > 1 {
|
|
return Err("At most one model can be marked as default".into());
|
|
}
|
|
|
|
for model in models {
|
|
if model.model_id.trim().is_empty() {
|
|
return Err("Model ID cannot be empty".into());
|
|
}
|
|
if model.model_id.len() > 100 {
|
|
return Err("Model ID must be at most 100 characters".into());
|
|
}
|
|
if model.display_name.trim().is_empty() {
|
|
return Err("Model display name cannot be empty".into());
|
|
}
|
|
if model.display_name.len() > 200 {
|
|
return Err("Model display name must be at most 200 characters".into());
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Public response for enabled providers (no admin-only data).
|
|
///
|
|
/// Returned by `GET /api/v1/config/providers` for authenticated (non-admin) users.
|
|
#[derive(Debug, Serialize)]
|
|
pub struct ProviderConfigResponse {
|
|
pub provider_name: String,
|
|
pub display_name: String,
|
|
pub models_scraping: Vec<PublicModelInfo>,
|
|
pub models_websearch: Vec<PublicModelInfo>,
|
|
}
|
|
|
|
/// Public model info (subset of `ProviderModel`).
|
|
#[derive(Debug, Serialize)]
|
|
pub struct PublicModelInfo {
|
|
pub model_id: String,
|
|
pub display_name: String,
|
|
pub is_default: bool,
|
|
}
|
|
|
|
impl From<ProviderModel> for PublicModelInfo {
|
|
fn from(m: ProviderModel) -> Self {
|
|
Self {
|
|
model_id: m.model_id,
|
|
display_name: m.display_name,
|
|
is_default: m.is_default,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Full admin response for a provider (includes all fields).
|
|
#[derive(Debug, Serialize)]
|
|
pub struct AdminProviderResponse {
|
|
pub id: Uuid,
|
|
pub provider_name: String,
|
|
pub display_name: String,
|
|
pub models_scraping: Vec<ProviderModel>,
|
|
pub models_websearch: Vec<ProviderModel>,
|
|
pub is_enabled: bool,
|
|
pub created_at: DateTime<Utc>,
|
|
pub updated_at: DateTime<Utc>,
|
|
}
|
|
|
|
impl From<AdminProvider> for AdminProviderResponse {
|
|
fn from(p: AdminProvider) -> Self {
|
|
Self {
|
|
id: p.id,
|
|
provider_name: p.provider_name,
|
|
display_name: p.display_name,
|
|
models_scraping: p.models_scraping,
|
|
models_websearch: p.models_websearch,
|
|
is_enabled: p.is_enabled,
|
|
created_at: p.created_at,
|
|
updated_at: p.updated_at,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Helper to create a sample model list for tests.
    fn sample_models() -> Vec<ProviderModel> {
        vec![ProviderModel {
            model_id: "m1".into(),
            display_name: "Model 1".into(),
            is_default: true,
        }]
    }

    #[test]
    fn test_valid_create_request() {
        let req = CreateProviderRequest {
            provider_name: "gemini".into(),
            display_name: "Google Gemini".into(),
            models_scraping: vec![ProviderModel {
                model_id: "gemini-2.5-pro".into(),
                display_name: "Gemini 2.5 Pro".into(),
                is_default: true,
            }],
            models_websearch: vec![ProviderModel {
                model_id: "gemini-2.5-pro".into(),
                display_name: "Gemini 2.5 Pro".into(),
                is_default: true,
            }],
            is_enabled: true,
        };
        assert!(req.validate().is_ok());
    }

    #[test]
    fn test_invalid_provider_name() {
        let req = CreateProviderRequest {
            provider_name: "unknown_provider".into(),
            display_name: "Unknown".into(),
            models_scraping: sample_models(),
            models_websearch: sample_models(),
            is_enabled: true,
        };
        let err = req.validate().unwrap_err();
        assert!(err.contains("Invalid provider name"));
    }

    #[test]
    fn test_empty_provider_name() {
        let req = CreateProviderRequest {
            provider_name: " ".into(),
            display_name: "Some Provider".into(),
            models_scraping: sample_models(),
            models_websearch: sample_models(),
            is_enabled: true,
        };
        let err = req.validate().unwrap_err();
        assert!(err.contains("cannot be empty"));
    }

    /// Surrounding whitespace is ignored when matching against `VALID_PROVIDERS`.
    #[test]
    fn test_provider_name_is_trimmed_before_lookup() {
        let req = CreateProviderRequest {
            provider_name: "  gemini  ".into(),
            display_name: "Google Gemini".into(),
            models_scraping: sample_models(),
            models_websearch: sample_models(),
            is_enabled: true,
        };
        assert!(req.validate().is_ok());
    }

    #[test]
    fn test_display_name_too_long_rejected() {
        let req = CreateProviderRequest {
            provider_name: "openai".into(),
            display_name: "x".repeat(101),
            models_scraping: sample_models(),
            models_websearch: sample_models(),
            is_enabled: true,
        };
        let err = req.validate().unwrap_err();
        assert!(err.contains("at most 100 characters"));
    }

    #[test]
    fn test_empty_models_scraping_list() {
        let req = CreateProviderRequest {
            provider_name: "openai".into(),
            display_name: "OpenAI".into(),
            models_scraping: vec![],
            models_websearch: sample_models(),
            is_enabled: true,
        };
        let err = req.validate().unwrap_err();
        assert!(err.contains("At least one model"));
    }

    #[test]
    fn test_empty_models_websearch_list() {
        let req = CreateProviderRequest {
            provider_name: "openai".into(),
            display_name: "OpenAI".into(),
            models_scraping: sample_models(),
            models_websearch: vec![],
            is_enabled: true,
        };
        let err = req.validate().unwrap_err();
        assert!(err.contains("At least one model"));
    }

    #[test]
    fn test_multiple_defaults_rejected() {
        let req = CreateProviderRequest {
            provider_name: "openai".into(),
            display_name: "OpenAI".into(),
            models_scraping: vec![
                ProviderModel {
                    model_id: "gpt-4o".into(),
                    display_name: "GPT-4o".into(),
                    is_default: true,
                },
                ProviderModel {
                    model_id: "gpt-4o-mini".into(),
                    display_name: "GPT-4o Mini".into(),
                    is_default: true,
                },
            ],
            models_websearch: sample_models(),
            is_enabled: true,
        };
        let err = req.validate().unwrap_err();
        assert!(err.contains("At most one model"));
    }

    #[test]
    fn test_empty_model_id_rejected() {
        let req = CreateProviderRequest {
            provider_name: "anthropic".into(),
            display_name: "Anthropic".into(),
            models_scraping: vec![ProviderModel {
                model_id: "".into(),
                display_name: "Claude".into(),
                is_default: false,
            }],
            models_websearch: sample_models(),
            is_enabled: true,
        };
        let err = req.validate().unwrap_err();
        assert!(err.contains("Model ID cannot be empty"));
    }

    #[test]
    fn test_update_request_all_none() {
        let req = UpdateProviderRequest {
            display_name: None,
            models_scraping: None,
            models_websearch: None,
            is_enabled: None,
        };
        assert!(req.validate().is_ok());
    }

    #[test]
    fn test_update_request_empty_display_name() {
        let req = UpdateProviderRequest {
            display_name: Some("".into()),
            models_scraping: None,
            models_websearch: None,
            is_enabled: None,
        };
        let err = req.validate().unwrap_err();
        assert!(err.contains("Display name cannot be empty"));
    }

    /// Present-but-empty model lists are rejected on update too.
    #[test]
    fn test_update_request_empty_models_rejected() {
        let req = UpdateProviderRequest {
            display_name: None,
            models_scraping: None,
            models_websearch: Some(vec![]),
            is_enabled: None,
        };
        let err = req.validate().unwrap_err();
        assert!(err.contains("At least one model"));
    }

    #[test]
    fn test_update_request_valid_fields() {
        let req = UpdateProviderRequest {
            display_name: Some("OpenAI".into()),
            models_scraping: Some(sample_models()),
            models_websearch: Some(sample_models()),
            is_enabled: Some(false),
        };
        assert!(req.validate().is_ok());
    }

    /// Every LLM provider must also be storable as a user API key provider.
    #[test]
    fn test_valid_providers_subset_of_api_key_providers() {
        for provider in VALID_PROVIDERS {
            assert!(
                VALID_API_KEY_PROVIDERS.contains(provider),
                "{} missing from VALID_API_KEY_PROVIDERS",
                provider
            );
        }
    }

    #[test]
    fn test_public_model_info_from_provider_model() {
        let info = PublicModelInfo::from(ProviderModel {
            model_id: "m1".into(),
            display_name: "Model 1".into(),
            is_default: true,
        });
        assert_eq!(info.model_id, "m1");
        assert_eq!(info.display_name, "Model 1");
        assert!(info.is_default);
    }

    #[test]
    fn test_provider_model_deserialization() {
        let json = r#"{"model_id": "gpt-4o", "display_name": "GPT-4o"}"#;
        let model: ProviderModel = serde_json::from_str(json).unwrap();
        assert_eq!(model.model_id, "gpt-4o");
        assert_eq!(model.display_name, "GPT-4o");
        assert!(!model.is_default); // default is false
    }

    /// `is_enabled` falls back to `default_true` when omitted from the JSON body.
    #[test]
    fn test_create_request_is_enabled_defaults_to_true() {
        let json = r#"{
            "provider_name": "openai",
            "display_name": "OpenAI",
            "models_scraping": [{"model_id": "gpt-4o", "display_name": "GPT-4o"}],
            "models_websearch": [{"model_id": "gpt-4o", "display_name": "GPT-4o"}]
        }"#;
        let req: CreateProviderRequest = serde_json::from_str(json).unwrap();
        assert!(req.is_enabled);
        assert!(req.validate().is_ok());
    }
}
|