diff --git a/backend/src/handlers/generation.rs b/backend/src/handlers/generation.rs index efd31d5..92252fe 100644 --- a/backend/src/handlers/generation.rs +++ b/backend/src/handlers/generation.rs @@ -62,7 +62,7 @@ pub async fn trigger_generate( } // Create the job in the store - let (job_id, tx) = state + let (job_id, tx, cancelled) = state .job_store .create_job(auth_user.id) .ok_or_else(|| { @@ -86,7 +86,7 @@ pub async fn trigger_generate( let join_handle = tokio::spawn(async move { let timeout_duration = std::time::Duration::from_secs(900); - match tokio::time::timeout(timeout_duration, synthesis::run_generation(job_id, state_clone.clone(), user_id, theme_id, tx.clone(), None)).await { + match tokio::time::timeout(timeout_duration, synthesis::run_generation(job_id, state_clone.clone(), user_id, theme_id, tx.clone(), None, cancelled)).await { Ok(()) => {} Err(_) => { tracing::error!(job_id = %job_id, user_id = %user_id, "Generation timed out after 15 minutes"); @@ -117,6 +117,24 @@ pub async fn trigger_generate( )) } +/// `POST /api/v1/syntheses/generate/:job_id/stop` +/// +/// Signals a running generation job to stop. The pipeline will save +/// whatever articles have been collected so far and complete gracefully. +pub async fn stop_generate( + auth_user: AuthUser, + State(state): State, + Path(job_id): Path, +) -> Result { + let cancelled = state.job_store.cancel_job(job_id, auth_user.id); + if !cancelled { + return Err(AppError::NotFound( + "Generation introuvable.".into(), + )); + } + Ok(StatusCode::OK) +} + /// `GET /api/v1/syntheses/generate/:job_id/progress` /// /// Server-Sent Events (SSE) endpoint that streams generation progress. diff --git a/backend/src/router.rs b/backend/src/router.rs index 12d8fbd..ab22e4a 100644 --- a/backend/src/router.rs +++ b/backend/src/router.rs @@ -59,6 +59,7 @@ pub fn build_router(state: AppState, config: &AppConfig) -> Router { // to avoid ambiguity with path parameter matching .route("/syntheses/generate", post(handlers::generation::trigger_generate)) .route("/syntheses/generate/{job_id}/progress", get(handlers::generation::progress_stream)) + .route("/syntheses/generate/{job_id}/stop", post(handlers::generation::stop_generate)) // Article history & provenance routes (authenticated) .route("/article-history", get(handlers::article_history::list_history).delete(handlers::article_history::clear_history)) .route("/syntheses/{id}/provenance", get(handlers::article_history::get_provenance)) diff --git a/backend/src/services/synthesis.rs b/backend/src/services/synthesis.rs index 4b349ef..94da595 100644 --- a/backend/src/services/synthesis.rs +++ b/backend/src/services/synthesis.rs @@ -8,6 +8,7 @@ //! consumed by SSE endpoints for real-time client updates. use std::collections::HashMap; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -74,6 +75,8 @@ struct JobEntry { user_id: Uuid, /// When the job was created (for TTL cleanup). created_at: Instant, + /// Flag set to true when the user requests cancellation. + cancelled: Arc, } /// In-memory store for active generation jobs. @@ -104,11 +107,11 @@ impl JobStore { } } - /// Create a new job for a user, returning the job ID and the watch Sender. + /// Create a new job for a user, returning the job ID, the watch Sender, and a cancellation flag. /// /// Returns `None` if the user already has an active job. /// Uses an atomic DashSet insert to prevent race conditions on double-click. - pub fn create_job(&self, user_id: Uuid) -> Option<(Uuid, Arc>)> { + pub fn create_job(&self, user_id: Uuid) -> Option<(Uuid, Arc>, Arc)> { if !self.generating_users.insert(user_id) { return None; } @@ -119,10 +122,12 @@ impl JobStore { percent: 0, }); let tx = Arc::new(tx); + let cancelled = Arc::new(AtomicBool::new(false)); self.inner.insert(job_id, JobEntry { tx: Arc::clone(&tx), _rx: rx, user_id, created_at: Instant::now(), + cancelled: Arc::clone(&cancelled), }); - Some((job_id, tx)) + Some((job_id, tx, cancelled)) } /// Get a watch receiver for a job, if it exists and belongs to the given user. @@ -145,6 +150,17 @@ impl JobStore { None } + /// Signal a job to stop. Returns true if the job was found and belongs to the user. + pub fn cancel_job(&self, job_id: Uuid, user_id: Uuid) -> bool { + if let Some(entry) = self.inner.get(&job_id) { + if entry.value().user_id == user_id { + entry.value().cancelled.store(true, Ordering::Relaxed); + return true; + } + } + false + } + /// Release the generating lock for a user (called when job completes, errors, or times out). pub fn release_user(&self, user_id: Uuid) { self.generating_users.remove(&user_id); @@ -196,8 +212,9 @@ pub async fn run_generation( theme_id: Uuid, tx: Arc>, provider_override: Option>, + cancelled: Arc, ) { - let result = run_generation_inner(job_id, &state, user_id, theme_id, &tx, provider_override).await; + let result = run_generation_inner(job_id, &state, user_id, theme_id, &tx, provider_override, &cancelled).await; match result { Ok(synthesis_id) => { @@ -233,6 +250,7 @@ pub async fn run_generation_inner( theme_id: Uuid, tx: &watch::Sender, provider_override: Option>, + cancelled: &AtomicBool, ) -> Result { // Batch buffer for article history traces (flushed at logical boundaries) let mut pending_traces: Vec = Vec::new(); @@ -309,6 +327,13 @@ pub async fn run_generation_inner( let total_waves = source_chunks.len(); 'wave_loop: for (wave_idx, wave_sources) in source_chunks.iter().enumerate() { + // Check cancellation before each wave + if cancelled.load(Ordering::Relaxed) { + tracing::info!(job_id = %job_id, "Generation cancelled by user (before wave)"); + emit_progress(tx, "saving", "Generation arretee, sauvegarde...", 90); + break 'wave_loop; + } + let articles_so_far: usize = article_scraped.values().map(|v| v.len()).sum(); let pct = 5 + ((articles_so_far as u32 * 60) / max_total.max(1) as u32).min(60); emit_progress(tx, "sources", @@ -569,6 +594,13 @@ pub async fn run_generation_inner( processed += batch.len(); + // Check cancellation after each batch + if cancelled.load(Ordering::Relaxed) { + tracing::info!(job_id = %job_id, "Generation cancelled by user (after batch)"); + emit_progress(tx, "saving", "Generation arretee, sauvegarde...", 90); + break; + } + // Check if we've reached the maximum after this batch let total: usize = article_scraped.values().map(|v| v.len()).sum(); if total >= max_total { @@ -593,13 +625,19 @@ pub async fn run_generation_inner( } // === PHASE 2: Web Search Fallback === + // Skip Phase 2 if cancelled + let is_cancelled = cancelled.load(Ordering::Relaxed); + if is_cancelled { + tracing::info!(job_id = %job_id, "Skipping Phase 2 — generation cancelled by user"); + } + let category_gaps: Vec<(String, i32)> = user_categories.iter().filter_map(|cat| { let filled = filled_counts.get(cat).copied().unwrap_or(0); let needed = (theme.max_items_per_category as usize).saturating_sub(filled); if needed > 0 { Some((cat.clone(), needed as i32)) } else { None } }).collect(); - if !category_gaps.is_empty() { + if !category_gaps.is_empty() && !is_cancelled { if settings.use_brave_search { // === BRAVE SEARCH PATH === emit_progress(tx, "websearch", "Recherche Brave Search...", 70); @@ -907,11 +945,21 @@ pub async fn run_generation_inner( } // === SAVE === - if article_scraped.values().all(|items| items.is_empty()) { + let is_cancelled = cancelled.load(Ordering::Relaxed); + let has_articles = article_scraped.values().any(|items| !items.is_empty()); + + if !has_articles { + if is_cancelled { + return Err(AppError::BadRequest("Generation arretee. Aucun article n'avait encore ete collecte.".into())); + } return Err(AppError::BadRequest("Aucun article valide trouve. Verifiez vos sources et categories.".into())); } - emit_progress(tx, "saving", "Sauvegarde de la synthese...", 90); + if is_cancelled { + emit_progress(tx, "saving", "Generation arretee, sauvegarde des articles collectes...", 90); + } else { + emit_progress(tx, "saving", "Sauvegarde de la synthese...", 90); + } let mut final_sections: Vec = Vec::new(); for (i, cat_name) in user_categories.iter().enumerate() { @@ -1483,7 +1531,7 @@ mod tests { let store = JobStore::new(); let user_id = Uuid::new_v4(); - let (job_id, tx) = store.create_job(user_id).unwrap(); + let (job_id, tx, _cancelled) = store.create_job(user_id).unwrap(); assert_eq!(store.len(), 1); // Subscribe @@ -1506,8 +1554,8 @@ mod tests { let store = JobStore::new(); let user_id = Uuid::new_v4(); - let result1 = store.create_job(user_id); - assert!(result1.is_some()); + let _result1 = store.create_job(user_id); + assert!(_result1.is_some()); // Second job for same user should fail let result2 = store.create_job(user_id); @@ -1524,7 +1572,7 @@ mod tests { let store = JobStore::new(); let user_id = Uuid::new_v4(); - let (_job_id, tx) = store.create_job(user_id).unwrap(); + let (_job_id, tx, _cancelled) = store.create_job(user_id).unwrap(); // Complete the job and release the user lock (as the pipeline does) tx.send(ProgressEvent::Complete { @@ -1543,7 +1591,7 @@ mod tests { let store = JobStore::new(); let user_id = Uuid::new_v4(); - let (_job_id, tx) = store.create_job(user_id).unwrap(); + let (_job_id, tx, _cancelled) = store.create_job(user_id).unwrap(); // Fail the job and release the user lock (as the pipeline does) tx.send(ProgressEvent::Error { @@ -1563,7 +1611,7 @@ mod tests { let user_id = Uuid::new_v4(); // Create a job and manually set its created_at to the past - let (_job_id, _tx) = store.create_job(user_id).unwrap(); + let (_job_id, _tx, _cancelled) = store.create_job(user_id).unwrap(); assert_eq!(store.len(), 1); // Cleanup should not remove recent jobs @@ -1576,7 +1624,7 @@ mod tests { let store = JobStore::new(); let user_id = Uuid::new_v4(); - let (job_id, _tx) = store.create_job(user_id).unwrap(); + let (job_id, _tx, _cancelled) = store.create_job(user_id).unwrap(); assert_eq!(store.len(), 1); store.remove(&job_id); diff --git a/backend/tests/pipeline_test.rs b/backend/tests/pipeline_test.rs index 3d5ff5c..7b68f8f 100644 --- a/backend/tests/pipeline_test.rs +++ b/backend/tests/pipeline_test.rs @@ -2,6 +2,7 @@ mod common; use ai_synth_backend::services::llm::mock::MockLlmProvider; use ai_synth_backend::services::synthesis; +use std::sync::atomic::AtomicBool; use std::sync::Arc; use tokio::sync::watch; use wiremock::matchers::{method, path}; @@ -117,7 +118,7 @@ async fn phase1_heuristic_extraction_classifies_articles() { ); let result = synthesis::run_generation_inner( - job_id, &state, user_id, theme_id, &tx, Some(mock_provider), + job_id, &state, user_id, theme_id, &tx, Some(mock_provider), &AtomicBool::new(false), ).await; assert!(result.is_ok(), "Generation should succeed: {:?}", result.err()); @@ -177,7 +178,7 @@ async fn phase2_search_fills_gaps_when_no_sources() { ); let result = synthesis::run_generation_inner( - job_id, &state, user_id, theme_id, &tx, Some(mock_provider), + job_id, &state, user_id, theme_id, &tx, Some(mock_provider), &AtomicBool::new(false), ).await; assert!(result.is_ok(), "Generation should succeed: {:?}", result.err()); @@ -220,7 +221,7 @@ async fn category_overflow_spills_to_autre() { ); let result = synthesis::run_generation_inner( - job_id, &state, user_id, theme_id, &tx, Some(mock_provider), + job_id, &state, user_id, theme_id, &tx, Some(mock_provider), &AtomicBool::new(false), ).await; assert!(result.is_ok(), "Generation should succeed"); @@ -334,7 +335,7 @@ async fn source_diversity_limits_articles_per_source() { ); let result = synthesis::run_generation_inner( - job_id, &state, user_id, theme_id, &tx, Some(mock_provider), + job_id, &state, user_id, theme_id, &tx, Some(mock_provider), &AtomicBool::new(false), ) .await; @@ -448,7 +449,7 @@ async fn article_history_dedup_prevents_repeat_articles() { ); let result1 = synthesis::run_generation_inner( - job1, &state1, user_id, theme_id, &tx1, Some(mock1), + job1, &state1, user_id, theme_id, &tx1, Some(mock1), &AtomicBool::new(false), ) .await; assert!(result1.is_ok(), "First generation should succeed: {:?}", result1.err()); @@ -479,7 +480,7 @@ async fn article_history_dedup_prevents_repeat_articles() { // The second run scrapes the same URLs, which are already in article_history. // They should be filtered out as "filtered_history". let _result2 = synthesis::run_generation_inner( - job2, &state2, user_id, theme_id, &tx2, Some(mock2), + job2, &state2, user_id, theme_id, &tx2, Some(mock2), &AtomicBool::new(false), ) .await; // The second run may succeed (empty synthesis) or fail (no valid articles). diff --git a/frontend/src/i18n/fr.ts b/frontend/src/i18n/fr.ts index fd823a2..1087a3e 100644 --- a/frontend/src/i18n/fr.ts +++ b/frontend/src/i18n/fr.ts @@ -88,6 +88,8 @@ const fr = { 'generate.noWebSearchWarning': 'Le fournisseur selectionne ne supporte pas la recherche web. Les resultats seront bases uniquement sur les connaissances du modele.', 'generate.selectTheme': 'Theme a generer', + 'generate.stop': 'Arreter la generation', + 'generate.stopped': 'Generation arretee. Les articles collectes ont ete sauvegardes.', 'generate.noThemes': 'Aucun theme configure. Creez un theme pour pouvoir generer une synthese.', 'generate.createThemeLink': 'Creer un theme', diff --git a/frontend/src/pages/GenerateSynthesis.tsx b/frontend/src/pages/GenerateSynthesis.tsx index 7a28776..ab78967 100644 --- a/frontend/src/pages/GenerateSynthesis.tsx +++ b/frontend/src/pages/GenerateSynthesis.tsx @@ -60,6 +60,7 @@ const GenerateSynthesis: Component = () => { const [error, setError] = createSignal(null); const [success, setSuccess] = createSignal(false); const [sseConnection, setSSEConnection] = createSignal(null); + const [jobId, setJobId] = createSignal(null); onMount(async () => { try { @@ -219,6 +220,7 @@ const GenerateSynthesis: Component = () => { try { const response = await synthesesApi.generate(selectedThemeId()); + setJobId(response.job_id); const url = synthesesApi.progressUrl(response.job_id); const conn = createSSEConnection(url); setSSEConnection(conn); @@ -232,6 +234,20 @@ const GenerateSynthesis: Component = () => { } }; + const handleStop = async () => { + const id = jobId(); + if (!id) return; + try { + await fetch(`/api/v1/syntheses/generate/${id}/stop`, { + method: 'POST', + headers: { 'X-Requested-With': 'XMLHttpRequest' }, + credentials: 'same-origin', + }); + } catch { + // ignore errors — the pipeline will stop on its own + } + }; + const handleRetry = () => { // Close existing connection const conn = sseConnection(); @@ -239,6 +255,7 @@ const GenerateSynthesis: Component = () => { conn.close(); } setSSEConnection(null); + setJobId(null); setError(null); setSuccess(false); setGenerating(false); @@ -403,6 +420,15 @@ const GenerateSynthesis: Component = () => {

{t('generate.canLeave')}

+ + {/* Stop generation button */} +