You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
210 lines
6.9 KiB
Rust
210 lines
6.9 KiB
Rust
//! Database queries for the `admin_providers` table.
|
|
//!
|
|
//! Provides CRUD operations for the admin-curated LLM provider catalog.
|
|
|
|
use sqlx::PgPool;
|
|
use uuid::Uuid;
|
|
|
|
use crate::errors::AppError;
|
|
use crate::models::provider::{AdminProvider, ProviderModel};
|
|
|
|
/// Row type returned by sqlx queries against the `admin_providers` table.
|
|
#[derive(Debug, sqlx::FromRow)]
|
|
struct ProviderRow {
|
|
id: Uuid,
|
|
provider_name: String,
|
|
display_name: String,
|
|
models_scraping: serde_json::Value,
|
|
models_websearch: serde_json::Value,
|
|
is_enabled: bool,
|
|
created_at: chrono::DateTime<chrono::Utc>,
|
|
updated_at: chrono::DateTime<chrono::Utc>,
|
|
}
|
|
|
|
impl TryFrom<ProviderRow> for AdminProvider {
|
|
type Error = AppError;
|
|
|
|
fn try_from(row: ProviderRow) -> Result<Self, Self::Error> {
|
|
let models_scraping: Vec<ProviderModel> =
|
|
serde_json::from_value(row.models_scraping).map_err(|e| {
|
|
AppError::Internal(anyhow::anyhow!("Failed to parse provider models_scraping JSON: {}", e))
|
|
})?;
|
|
let models_websearch: Vec<ProviderModel> =
|
|
serde_json::from_value(row.models_websearch).map_err(|e| {
|
|
AppError::Internal(anyhow::anyhow!("Failed to parse provider models_websearch JSON: {}", e))
|
|
})?;
|
|
|
|
Ok(Self {
|
|
id: row.id,
|
|
provider_name: row.provider_name,
|
|
display_name: row.display_name,
|
|
models_scraping,
|
|
models_websearch,
|
|
is_enabled: row.is_enabled,
|
|
created_at: row.created_at,
|
|
updated_at: row.updated_at,
|
|
})
|
|
}
|
|
}
|
|
|
|
/// List all providers (admin view, includes disabled).
|
|
pub async fn list_all(pool: &PgPool) -> Result<Vec<AdminProvider>, AppError> {
|
|
let rows = sqlx::query_as::<_, ProviderRow>(
|
|
r#"
|
|
SELECT id, provider_name, display_name, models_scraping, models_websearch, is_enabled, created_at, updated_at
|
|
FROM admin_providers
|
|
ORDER BY provider_name
|
|
"#,
|
|
)
|
|
.fetch_all(pool)
|
|
.await?;
|
|
|
|
rows.into_iter().map(AdminProvider::try_from).collect()
|
|
}
|
|
|
|
/// Get a provider by its UUID.
|
|
pub async fn get_by_id(pool: &PgPool, id: Uuid) -> Result<Option<AdminProvider>, AppError> {
|
|
let row = sqlx::query_as::<_, ProviderRow>(
|
|
r#"
|
|
SELECT id, provider_name, display_name, models_scraping, models_websearch, is_enabled, created_at, updated_at
|
|
FROM admin_providers
|
|
WHERE id = $1
|
|
"#,
|
|
)
|
|
.bind(id)
|
|
.fetch_optional(pool)
|
|
.await?;
|
|
|
|
row.map(AdminProvider::try_from).transpose()
|
|
}
|
|
|
|
/// Get a provider by its unique name (e.g., "gemini", "openai").
|
|
pub async fn get_by_name(pool: &PgPool, name: &str) -> Result<Option<AdminProvider>, AppError> {
|
|
let row = sqlx::query_as::<_, ProviderRow>(
|
|
r#"
|
|
SELECT id, provider_name, display_name, models_scraping, models_websearch, is_enabled, created_at, updated_at
|
|
FROM admin_providers
|
|
WHERE provider_name = $1
|
|
"#,
|
|
)
|
|
.bind(name)
|
|
.fetch_optional(pool)
|
|
.await?;
|
|
|
|
row.map(AdminProvider::try_from).transpose()
|
|
}
|
|
|
|
/// Create a new provider. Returns the created provider.
|
|
///
|
|
/// The caller is responsible for validating the request before calling this.
|
|
pub async fn create(
|
|
pool: &PgPool,
|
|
provider_name: &str,
|
|
display_name: &str,
|
|
models_scraping: &[ProviderModel],
|
|
models_websearch: &[ProviderModel],
|
|
is_enabled: bool,
|
|
) -> Result<AdminProvider, AppError> {
|
|
let models_scraping_json = serde_json::to_value(models_scraping).map_err(|e| {
|
|
AppError::Internal(anyhow::anyhow!("Failed to serialize models_scraping: {}", e))
|
|
})?;
|
|
let models_websearch_json = serde_json::to_value(models_websearch).map_err(|e| {
|
|
AppError::Internal(anyhow::anyhow!("Failed to serialize models_websearch: {}", e))
|
|
})?;
|
|
|
|
let row = sqlx::query_as::<_, ProviderRow>(
|
|
r#"
|
|
INSERT INTO admin_providers (provider_name, display_name, models_scraping, models_websearch, is_enabled)
|
|
VALUES ($1, $2, $3, $4, $5)
|
|
RETURNING id, provider_name, display_name, models_scraping, models_websearch, is_enabled, created_at, updated_at
|
|
"#,
|
|
)
|
|
.bind(provider_name)
|
|
.bind(display_name)
|
|
.bind(&models_scraping_json)
|
|
.bind(&models_websearch_json)
|
|
.bind(is_enabled)
|
|
.fetch_one(pool)
|
|
.await?;
|
|
|
|
AdminProvider::try_from(row)
|
|
}
|
|
|
|
/// Update an existing provider by ID.
|
|
///
|
|
/// Only updates the fields that are `Some` in the arguments.
|
|
/// Returns the updated provider, or `None` if the ID was not found.
|
|
pub async fn update(
|
|
pool: &PgPool,
|
|
id: Uuid,
|
|
display_name: Option<&str>,
|
|
models_scraping: Option<&[ProviderModel]>,
|
|
models_websearch: Option<&[ProviderModel]>,
|
|
is_enabled: Option<bool>,
|
|
) -> Result<Option<AdminProvider>, AppError> {
|
|
let models_scraping_json = models_scraping
|
|
.map(|m| {
|
|
serde_json::to_value(m)
|
|
.map_err(|e| AppError::Internal(anyhow::anyhow!("Failed to serialize models_scraping: {}", e)))
|
|
})
|
|
.transpose()?;
|
|
let models_websearch_json = models_websearch
|
|
.map(|m| {
|
|
serde_json::to_value(m)
|
|
.map_err(|e| AppError::Internal(anyhow::anyhow!("Failed to serialize models_websearch: {}", e)))
|
|
})
|
|
.transpose()?;
|
|
|
|
let row = sqlx::query_as::<_, ProviderRow>(
|
|
r#"
|
|
UPDATE admin_providers SET
|
|
display_name = COALESCE($2, display_name),
|
|
models_scraping = COALESCE($3, models_scraping),
|
|
models_websearch = COALESCE($4, models_websearch),
|
|
is_enabled = COALESCE($5, is_enabled),
|
|
updated_at = now()
|
|
WHERE id = $1
|
|
RETURNING id, provider_name, display_name, models_scraping, models_websearch, is_enabled, created_at, updated_at
|
|
"#,
|
|
)
|
|
.bind(id)
|
|
.bind(display_name)
|
|
.bind(models_scraping_json)
|
|
.bind(models_websearch_json)
|
|
.bind(is_enabled)
|
|
.fetch_optional(pool)
|
|
.await?;
|
|
|
|
row.map(AdminProvider::try_from).transpose()
|
|
}
|
|
|
|
/// Delete a provider by ID.
|
|
///
|
|
/// Returns `true` if a row was deleted, `false` if the ID was not found.
|
|
pub async fn delete(pool: &PgPool, id: Uuid) -> Result<bool, AppError> {
|
|
let result = sqlx::query("DELETE FROM admin_providers WHERE id = $1")
|
|
.bind(id)
|
|
.execute(pool)
|
|
.await?;
|
|
|
|
Ok(result.rows_affected() > 0)
|
|
}
|
|
|
|
/// List only enabled providers (for the public config endpoint).
|
|
///
|
|
/// Returns providers where `is_enabled = true`, ordered by provider name.
|
|
pub async fn list_enabled(pool: &PgPool) -> Result<Vec<AdminProvider>, AppError> {
|
|
let rows = sqlx::query_as::<_, ProviderRow>(
|
|
r#"
|
|
SELECT id, provider_name, display_name, models_scraping, models_websearch, is_enabled, created_at, updated_at
|
|
FROM admin_providers
|
|
WHERE is_enabled = true
|
|
ORDER BY provider_name
|
|
"#,
|
|
)
|
|
.fetch_all(pool)
|
|
.await?;
|
|
|
|
rows.into_iter().map(AdminProvider::try_from).collect()
|
|
}
|