You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

210 lines
6.9 KiB
Rust

//! Database queries for the `admin_providers` table.
//!
//! Provides CRUD operations for the admin-curated LLM provider catalog.
use sqlx::PgPool;
use uuid::Uuid;
use crate::errors::AppError;
use crate::models::provider::{AdminProvider, ProviderModel};
/// Row type returned by sqlx queries against the `admin_providers` table.
///
/// Mirrors the column list selected/returned by every query in this module.
/// It is converted into the public `AdminProvider` model via `TryFrom`, which
/// decodes the two JSON model-list columns.
#[derive(Debug, sqlx::FromRow)]
struct ProviderRow {
/// Primary key. `create` does not supply it in its INSERT, so it is presumably
/// generated by a column default — TODO confirm against the migration.
id: Uuid,
/// Machine name of the provider (e.g. "gemini"); looked up by `get_by_name`
/// and used as the sort key of the list queries.
provider_name: String,
/// Human-readable name; can be changed through `update`.
display_name: String,
/// Raw JSON for the scraping model list; decoded into `Vec<ProviderModel>`
/// when converting to `AdminProvider`.
models_scraping: serde_json::Value,
/// Raw JSON for the web-search model list; decoded into `Vec<ProviderModel>`
/// when converting to `AdminProvider`.
models_websearch: serde_json::Value,
/// Only rows with `is_enabled = true` are returned by `list_enabled`.
is_enabled: bool,
/// Creation timestamp. Not set by `create`'s INSERT, so presumably a
/// database default — confirm against the schema.
created_at: chrono::DateTime<chrono::Utc>,
/// Last-modification timestamp; `update` sets it to `now()`.
updated_at: chrono::DateTime<chrono::Utc>,
}
impl TryFrom<ProviderRow> for AdminProvider {
type Error = AppError;
fn try_from(row: ProviderRow) -> Result<Self, Self::Error> {
let models_scraping: Vec<ProviderModel> =
serde_json::from_value(row.models_scraping).map_err(|e| {
AppError::Internal(anyhow::anyhow!("Failed to parse provider models_scraping JSON: {}", e))
})?;
let models_websearch: Vec<ProviderModel> =
serde_json::from_value(row.models_websearch).map_err(|e| {
AppError::Internal(anyhow::anyhow!("Failed to parse provider models_websearch JSON: {}", e))
})?;
Ok(Self {
id: row.id,
provider_name: row.provider_name,
display_name: row.display_name,
models_scraping,
models_websearch,
is_enabled: row.is_enabled,
created_at: row.created_at,
updated_at: row.updated_at,
})
}
}
/// Fetches every provider in the catalog — enabled or not — sorted by `provider_name`.
///
/// Intended for the admin view.
///
/// # Errors
///
/// Propagates database errors, and fails on the first row whose JSON model
/// columns cannot be decoded.
pub async fn list_all(pool: &PgPool) -> Result<Vec<AdminProvider>, AppError> {
    let sql = r#"
SELECT id, provider_name, display_name, models_scraping, models_websearch, is_enabled, created_at, updated_at
FROM admin_providers
ORDER BY provider_name
"#;
    let fetched = sqlx::query_as::<_, ProviderRow>(sql).fetch_all(pool).await?;
    fetched
        .into_iter()
        .map(AdminProvider::try_from)
        .collect::<Result<Vec<_>, _>>()
}
/// Looks up a single provider by primary key.
///
/// Returns `Ok(None)` when no row has the given `id`.
///
/// # Errors
///
/// Propagates database errors and JSON-decoding failures of the found row.
pub async fn get_by_id(pool: &PgPool, id: Uuid) -> Result<Option<AdminProvider>, AppError> {
    let lookup = sqlx::query_as::<_, ProviderRow>(
        r#"
SELECT id, provider_name, display_name, models_scraping, models_websearch, is_enabled, created_at, updated_at
FROM admin_providers
WHERE id = $1
"#,
    )
    .bind(id);

    match lookup.fetch_optional(pool).await? {
        Some(found) => AdminProvider::try_from(found).map(Some),
        None => Ok(None),
    }
}
/// Looks up a single provider by its unique machine name (e.g. "gemini", "openai").
///
/// Returns `Ok(None)` when no provider carries that name.
///
/// # Errors
///
/// Propagates database errors and JSON-decoding failures of the found row.
pub async fn get_by_name(pool: &PgPool, name: &str) -> Result<Option<AdminProvider>, AppError> {
    let sql = r#"
SELECT id, provider_name, display_name, models_scraping, models_websearch, is_enabled, created_at, updated_at
FROM admin_providers
WHERE provider_name = $1
"#;
    let maybe_row = sqlx::query_as::<_, ProviderRow>(sql)
        .bind(name)
        .fetch_optional(pool)
        .await?;

    maybe_row.map_or(Ok(None), |r| AdminProvider::try_from(r).map(Some))
}
/// Inserts a new provider and returns it as stored (including DB-generated columns).
///
/// The model lists are serialized to JSON before being bound. Input validation is the
/// caller's responsibility; this function performs none.
///
/// # Errors
///
/// Returns `AppError::Internal` if a model list cannot be serialized (scraping is
/// checked first). Also propagates database errors from the INSERT and
/// JSON-decoding failures of the returned row.
pub async fn create(
    pool: &PgPool,
    provider_name: &str,
    display_name: &str,
    models_scraping: &[ProviderModel],
    models_websearch: &[ProviderModel],
    is_enabled: bool,
) -> Result<AdminProvider, AppError> {
    let scraping_value = serde_json::to_value(models_scraping).map_err(|err| {
        AppError::Internal(anyhow::anyhow!("Failed to serialize models_scraping: {}", err))
    })?;
    let websearch_value = serde_json::to_value(models_websearch).map_err(|err| {
        AppError::Internal(anyhow::anyhow!("Failed to serialize models_websearch: {}", err))
    })?;

    let sql = r#"
INSERT INTO admin_providers (provider_name, display_name, models_scraping, models_websearch, is_enabled)
VALUES ($1, $2, $3, $4, $5)
RETURNING id, provider_name, display_name, models_scraping, models_websearch, is_enabled, created_at, updated_at
"#;
    let inserted = sqlx::query_as::<_, ProviderRow>(sql)
        .bind(provider_name)
        .bind(display_name)
        .bind(&scraping_value)
        .bind(&websearch_value)
        .bind(is_enabled)
        .fetch_one(pool)
        .await?;

    AdminProvider::try_from(inserted)
}
/// Partially updates the provider with the given `id`.
///
/// Any argument that is `None` leaves its column unchanged (the SQL falls back to
/// the current value via `COALESCE`). `updated_at` is always set to `now()`, even
/// when every optional argument is `None`. Returns the updated provider, or
/// `Ok(None)` if no row has that `id`.
///
/// # Errors
///
/// Returns `AppError::Internal` if a supplied model list cannot be serialized
/// (scraping is checked first). Also propagates database errors and
/// JSON-decoding failures of the returned row.
pub async fn update(
    pool: &PgPool,
    id: Uuid,
    display_name: Option<&str>,
    models_scraping: Option<&[ProviderModel]>,
    models_websearch: Option<&[ProviderModel]>,
    is_enabled: Option<bool>,
) -> Result<Option<AdminProvider>, AppError> {
    // Option<&[_]> -> Option<Result<Value, _>> -> Result<Option<Value>, _>
    let scraping_json = models_scraping
        .map(serde_json::to_value)
        .transpose()
        .map_err(|e| AppError::Internal(anyhow::anyhow!("Failed to serialize models_scraping: {}", e)))?;
    let websearch_json = models_websearch
        .map(serde_json::to_value)
        .transpose()
        .map_err(|e| AppError::Internal(anyhow::anyhow!("Failed to serialize models_websearch: {}", e)))?;

    let sql = r#"
UPDATE admin_providers SET
display_name = COALESCE($2, display_name),
models_scraping = COALESCE($3, models_scraping),
models_websearch = COALESCE($4, models_websearch),
is_enabled = COALESCE($5, is_enabled),
updated_at = now()
WHERE id = $1
RETURNING id, provider_name, display_name, models_scraping, models_websearch, is_enabled, created_at, updated_at
"#;
    let updated = sqlx::query_as::<_, ProviderRow>(sql)
        .bind(id)
        .bind(display_name)
        .bind(scraping_json)
        .bind(websearch_json)
        .bind(is_enabled)
        .fetch_optional(pool)
        .await?;

    match updated {
        Some(r) => AdminProvider::try_from(r).map(Some),
        None => Ok(None),
    }
}
/// Removes the provider with the given `id`.
///
/// Returns `Ok(true)` when a row was deleted and `Ok(false)` when no row matched.
///
/// # Errors
///
/// Propagates database errors from the DELETE statement.
pub async fn delete(pool: &PgPool, id: Uuid) -> Result<bool, AppError> {
    let statement = sqlx::query("DELETE FROM admin_providers WHERE id = $1").bind(id);
    let outcome = statement.execute(pool).await?;
    let removed = outcome.rows_affected();
    Ok(removed != 0)
}
/// Fetches the providers whose `is_enabled` flag is set, sorted by `provider_name`.
///
/// Used by the public config endpoint.
///
/// # Errors
///
/// Propagates database errors, and stops at the first row whose JSON model
/// columns cannot be decoded.
pub async fn list_enabled(pool: &PgPool) -> Result<Vec<AdminProvider>, AppError> {
    let rows = sqlx::query_as::<_, ProviderRow>(
        r#"
SELECT id, provider_name, display_name, models_scraping, models_websearch, is_enabled, created_at, updated_at
FROM admin_providers
WHERE is_enabled = true
ORDER BY provider_name
"#,
    )
    .fetch_all(pool)
    .await?;

    let mut providers = Vec::with_capacity(rows.len());
    for row in rows {
        providers.push(AdminProvider::try_from(row)?);
    }
    Ok(providers)
}