refactor: add provider_override for pipeline dependency injection

Adds an optional LlmProvider override to run_generation and
run_generation_inner, allowing tests to inject a mock provider without
touching real credentials or the provider-resolution path. Makes
run_generation_inner pub so integration tests can call it directly.
Production callers pass None and behaviour is unchanged.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
master
oabrivard 3 months ago
parent 17e054c257
commit ccecaa2d13

@ -76,7 +76,7 @@ pub async fn trigger_generate(
let join_handle = tokio::spawn(async move {
let timeout_duration = std::time::Duration::from_secs(900);
match tokio::time::timeout(timeout_duration, synthesis::run_generation(job_id, state_clone.clone(), user_id, tx.clone())).await {
match tokio::time::timeout(timeout_duration, synthesis::run_generation(job_id, state_clone.clone(), user_id, tx.clone(), None)).await {
Ok(()) => {}
Err(_) => {
tracing::error!(job_id = %job_id, user_id = %user_id, "Generation timed out after 15 minutes");

@ -194,8 +194,9 @@ pub async fn run_generation(
state: AppState,
user_id: Uuid,
tx: Arc<watch::Sender<ProgressEvent>>,
provider_override: Option<Arc<dyn crate::services::llm::LlmProvider>>,
) {
let result = run_generation_inner(job_id, &state, user_id, &tx).await;
let result = run_generation_inner(job_id, &state, user_id, &tx, provider_override).await;
match result {
Ok(synthesis_id) => {
@ -224,11 +225,12 @@ pub async fn run_generation(
}
/// Inner implementation of the generation pipeline, returning a Result.
async fn run_generation_inner(
pub async fn run_generation_inner(
job_id: Uuid,
state: &AppState,
user_id: Uuid,
tx: &watch::Sender<ProgressEvent>,
provider_override: Option<Arc<dyn crate::services::llm::LlmProvider>>,
) -> Result<Uuid, AppError> {
// Batch buffer for article history traces (flushed at logical boundaries)
let mut pending_traces: Vec<db::article_history::ArticleHistoryEntry> = Vec::new();
@ -254,10 +256,22 @@ async fn run_generation_inner(
let sources = db::sources::list_for_user(&state.pool, user_id).await?;
emit_progress(tx, "provider", "Configuration du fournisseur IA...", 12);
let (provider_name, api_key) = resolve_provider_and_key(state, user_id, &settings).await?;
let provider = create_provider(&provider_name, api_key)?;
let model_research = if !settings.ai_model.is_empty() { settings.ai_model.clone() } else { resolve_model(state, &provider_name).await? };
let model_websearch = if !settings.ai_model_websearch.is_empty() { settings.ai_model_websearch.clone() } else { model_research.clone() };
let (provider_name, provider) = if let Some(mock_provider) = provider_override {
("mock".to_string(), mock_provider)
} else {
let (pname, api_key) = resolve_provider_and_key(state, user_id, &settings).await?;
let p = create_provider(&pname, api_key)?;
(pname, p)
};
let (model_research, model_websearch) = if provider_name == "mock" {
let research = if settings.ai_model.is_empty() { "mock-model".to_string() } else { settings.ai_model.clone() };
let websearch = if settings.ai_model_websearch.is_empty() { "mock-model".to_string() } else { settings.ai_model_websearch.clone() };
(research, websearch)
} else {
let model_research = if !settings.ai_model.is_empty() { settings.ai_model.clone() } else { resolve_model(state, &provider_name).await? };
let model_websearch = if !settings.ai_model_websearch.is_empty() { settings.ai_model_websearch.clone() } else { model_research.clone() };
(model_research, model_websearch)
};
let user_rate_limiter = get_user_rate_limiter(state, &settings, user_id);
// Tracking structures

Loading…
Cancel
Save