feat: add stop generation button — saves partial synthesis on cancel

Adds Arc<AtomicBool> cancellation flag to JobStore/JobEntry. The pipeline
checks the flag before each wave and after each batch, then saves whatever
articles have been collected. A new POST /syntheses/generate/:job_id/stop
endpoint sets the flag. The frontend shows a red stop button during generation
and POSTs to the stop endpoint on click.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
master
oabrivard 3 months ago
parent 7ba1da4d92
commit 6f3e6883c9

@ -62,7 +62,7 @@ pub async fn trigger_generate(
} }
// Create the job in the store // Create the job in the store
let (job_id, tx) = state let (job_id, tx, cancelled) = state
.job_store .job_store
.create_job(auth_user.id) .create_job(auth_user.id)
.ok_or_else(|| { .ok_or_else(|| {
@ -86,7 +86,7 @@ pub async fn trigger_generate(
let join_handle = tokio::spawn(async move { let join_handle = tokio::spawn(async move {
let timeout_duration = std::time::Duration::from_secs(900); let timeout_duration = std::time::Duration::from_secs(900);
match tokio::time::timeout(timeout_duration, synthesis::run_generation(job_id, state_clone.clone(), user_id, theme_id, tx.clone(), None)).await { match tokio::time::timeout(timeout_duration, synthesis::run_generation(job_id, state_clone.clone(), user_id, theme_id, tx.clone(), None, cancelled)).await {
Ok(()) => {} Ok(()) => {}
Err(_) => { Err(_) => {
tracing::error!(job_id = %job_id, user_id = %user_id, "Generation timed out after 15 minutes"); tracing::error!(job_id = %job_id, user_id = %user_id, "Generation timed out after 15 minutes");
@ -117,6 +117,24 @@ pub async fn trigger_generate(
)) ))
} }
/// `POST /api/v1/syntheses/generate/:job_id/stop`
///
/// Signals a running generation job to stop. The pipeline will save
/// whatever articles have been collected so far and complete gracefully.
pub async fn stop_generate(
auth_user: AuthUser,
State(state): State<AppState>,
Path(job_id): Path<Uuid>,
) -> Result<impl IntoResponse, AppError> {
let cancelled = state.job_store.cancel_job(job_id, auth_user.id);
if !cancelled {
return Err(AppError::NotFound(
"Generation introuvable.".into(),
));
}
Ok(StatusCode::OK)
}
/// `GET /api/v1/syntheses/generate/:job_id/progress` /// `GET /api/v1/syntheses/generate/:job_id/progress`
/// ///
/// Server-Sent Events (SSE) endpoint that streams generation progress. /// Server-Sent Events (SSE) endpoint that streams generation progress.

@ -59,6 +59,7 @@ pub fn build_router(state: AppState, config: &AppConfig) -> Router {
// to avoid ambiguity with path parameter matching // to avoid ambiguity with path parameter matching
.route("/syntheses/generate", post(handlers::generation::trigger_generate)) .route("/syntheses/generate", post(handlers::generation::trigger_generate))
.route("/syntheses/generate/{job_id}/progress", get(handlers::generation::progress_stream)) .route("/syntheses/generate/{job_id}/progress", get(handlers::generation::progress_stream))
.route("/syntheses/generate/{job_id}/stop", post(handlers::generation::stop_generate))
// Article history & provenance routes (authenticated) // Article history & provenance routes (authenticated)
.route("/article-history", get(handlers::article_history::list_history).delete(handlers::article_history::clear_history)) .route("/article-history", get(handlers::article_history::list_history).delete(handlers::article_history::clear_history))
.route("/syntheses/{id}/provenance", get(handlers::article_history::get_provenance)) .route("/syntheses/{id}/provenance", get(handlers::article_history::get_provenance))

@ -8,6 +8,7 @@
//! consumed by SSE endpoints for real-time client updates. //! consumed by SSE endpoints for real-time client updates.
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
@ -74,6 +75,8 @@ struct JobEntry {
user_id: Uuid, user_id: Uuid,
/// When the job was created (for TTL cleanup). /// When the job was created (for TTL cleanup).
created_at: Instant, created_at: Instant,
/// Flag set to true when the user requests cancellation.
cancelled: Arc<AtomicBool>,
} }
/// In-memory store for active generation jobs. /// In-memory store for active generation jobs.
@ -104,11 +107,11 @@ impl JobStore {
} }
} }
/// Create a new job for a user, returning the job ID and the watch Sender. /// Create a new job for a user, returning the job ID, the watch Sender, and a cancellation flag.
/// ///
/// Returns `None` if the user already has an active job. /// Returns `None` if the user already has an active job.
/// Uses an atomic DashSet insert to prevent race conditions on double-click. /// Uses an atomic DashSet insert to prevent race conditions on double-click.
pub fn create_job(&self, user_id: Uuid) -> Option<(Uuid, Arc<watch::Sender<ProgressEvent>>)> { pub fn create_job(&self, user_id: Uuid) -> Option<(Uuid, Arc<watch::Sender<ProgressEvent>>, Arc<AtomicBool>)> {
if !self.generating_users.insert(user_id) { if !self.generating_users.insert(user_id) {
return None; return None;
} }
@ -119,10 +122,12 @@ impl JobStore {
percent: 0, percent: 0,
}); });
let tx = Arc::new(tx); let tx = Arc::new(tx);
let cancelled = Arc::new(AtomicBool::new(false));
self.inner.insert(job_id, JobEntry { self.inner.insert(job_id, JobEntry {
tx: Arc::clone(&tx), _rx: rx, user_id, created_at: Instant::now(), tx: Arc::clone(&tx), _rx: rx, user_id, created_at: Instant::now(),
cancelled: Arc::clone(&cancelled),
}); });
Some((job_id, tx)) Some((job_id, tx, cancelled))
} }
/// Get a watch receiver for a job, if it exists and belongs to the given user. /// Get a watch receiver for a job, if it exists and belongs to the given user.
@ -145,6 +150,17 @@ impl JobStore {
None None
} }
/// Signal a job to stop. Returns true if the job was found and belongs to the user.
pub fn cancel_job(&self, job_id: Uuid, user_id: Uuid) -> bool {
if let Some(entry) = self.inner.get(&job_id) {
if entry.value().user_id == user_id {
entry.value().cancelled.store(true, Ordering::Relaxed);
return true;
}
}
false
}
/// Release the generating lock for a user (called when job completes, errors, or times out). /// Release the generating lock for a user (called when job completes, errors, or times out).
pub fn release_user(&self, user_id: Uuid) { pub fn release_user(&self, user_id: Uuid) {
self.generating_users.remove(&user_id); self.generating_users.remove(&user_id);
@ -196,8 +212,9 @@ pub async fn run_generation(
theme_id: Uuid, theme_id: Uuid,
tx: Arc<watch::Sender<ProgressEvent>>, tx: Arc<watch::Sender<ProgressEvent>>,
provider_override: Option<Arc<dyn crate::services::llm::LlmProvider>>, provider_override: Option<Arc<dyn crate::services::llm::LlmProvider>>,
cancelled: Arc<AtomicBool>,
) { ) {
let result = run_generation_inner(job_id, &state, user_id, theme_id, &tx, provider_override).await; let result = run_generation_inner(job_id, &state, user_id, theme_id, &tx, provider_override, &cancelled).await;
match result { match result {
Ok(synthesis_id) => { Ok(synthesis_id) => {
@ -233,6 +250,7 @@ pub async fn run_generation_inner(
theme_id: Uuid, theme_id: Uuid,
tx: &watch::Sender<ProgressEvent>, tx: &watch::Sender<ProgressEvent>,
provider_override: Option<Arc<dyn crate::services::llm::LlmProvider>>, provider_override: Option<Arc<dyn crate::services::llm::LlmProvider>>,
cancelled: &AtomicBool,
) -> Result<Uuid, AppError> { ) -> Result<Uuid, AppError> {
// Batch buffer for article history traces (flushed at logical boundaries) // Batch buffer for article history traces (flushed at logical boundaries)
let mut pending_traces: Vec<db::article_history::ArticleHistoryEntry> = Vec::new(); let mut pending_traces: Vec<db::article_history::ArticleHistoryEntry> = Vec::new();
@ -309,6 +327,13 @@ pub async fn run_generation_inner(
let total_waves = source_chunks.len(); let total_waves = source_chunks.len();
'wave_loop: for (wave_idx, wave_sources) in source_chunks.iter().enumerate() { 'wave_loop: for (wave_idx, wave_sources) in source_chunks.iter().enumerate() {
// Check cancellation before each wave
if cancelled.load(Ordering::Relaxed) {
tracing::info!(job_id = %job_id, "Generation cancelled by user (before wave)");
emit_progress(tx, "saving", "Generation arretee, sauvegarde...", 90);
break 'wave_loop;
}
let articles_so_far: usize = article_scraped.values().map(|v| v.len()).sum(); let articles_so_far: usize = article_scraped.values().map(|v| v.len()).sum();
let pct = 5 + ((articles_so_far as u32 * 60) / max_total.max(1) as u32).min(60); let pct = 5 + ((articles_so_far as u32 * 60) / max_total.max(1) as u32).min(60);
emit_progress(tx, "sources", emit_progress(tx, "sources",
@ -569,6 +594,13 @@ pub async fn run_generation_inner(
processed += batch.len(); processed += batch.len();
// Check cancellation after each batch
if cancelled.load(Ordering::Relaxed) {
tracing::info!(job_id = %job_id, "Generation cancelled by user (after batch)");
emit_progress(tx, "saving", "Generation arretee, sauvegarde...", 90);
break;
}
// Check if we've reached the maximum after this batch // Check if we've reached the maximum after this batch
let total: usize = article_scraped.values().map(|v| v.len()).sum(); let total: usize = article_scraped.values().map(|v| v.len()).sum();
if total >= max_total { if total >= max_total {
@ -593,13 +625,19 @@ pub async fn run_generation_inner(
} }
// === PHASE 2: Web Search Fallback === // === PHASE 2: Web Search Fallback ===
// Skip Phase 2 if cancelled
let is_cancelled = cancelled.load(Ordering::Relaxed);
if is_cancelled {
tracing::info!(job_id = %job_id, "Skipping Phase 2 — generation cancelled by user");
}
let category_gaps: Vec<(String, i32)> = user_categories.iter().filter_map(|cat| { let category_gaps: Vec<(String, i32)> = user_categories.iter().filter_map(|cat| {
let filled = filled_counts.get(cat).copied().unwrap_or(0); let filled = filled_counts.get(cat).copied().unwrap_or(0);
let needed = (theme.max_items_per_category as usize).saturating_sub(filled); let needed = (theme.max_items_per_category as usize).saturating_sub(filled);
if needed > 0 { Some((cat.clone(), needed as i32)) } else { None } if needed > 0 { Some((cat.clone(), needed as i32)) } else { None }
}).collect(); }).collect();
if !category_gaps.is_empty() { if !category_gaps.is_empty() && !is_cancelled {
if settings.use_brave_search { if settings.use_brave_search {
// === BRAVE SEARCH PATH === // === BRAVE SEARCH PATH ===
emit_progress(tx, "websearch", "Recherche Brave Search...", 70); emit_progress(tx, "websearch", "Recherche Brave Search...", 70);
@ -907,11 +945,21 @@ pub async fn run_generation_inner(
} }
// === SAVE === // === SAVE ===
if article_scraped.values().all(|items| items.is_empty()) { let is_cancelled = cancelled.load(Ordering::Relaxed);
let has_articles = article_scraped.values().any(|items| !items.is_empty());
if !has_articles {
if is_cancelled {
return Err(AppError::BadRequest("Generation arretee. Aucun article n'avait encore ete collecte.".into()));
}
return Err(AppError::BadRequest("Aucun article valide trouve. Verifiez vos sources et categories.".into())); return Err(AppError::BadRequest("Aucun article valide trouve. Verifiez vos sources et categories.".into()));
} }
emit_progress(tx, "saving", "Sauvegarde de la synthese...", 90); if is_cancelled {
emit_progress(tx, "saving", "Generation arretee, sauvegarde des articles collectes...", 90);
} else {
emit_progress(tx, "saving", "Sauvegarde de la synthese...", 90);
}
let mut final_sections: Vec<NewsSection> = Vec::new(); let mut final_sections: Vec<NewsSection> = Vec::new();
for (i, cat_name) in user_categories.iter().enumerate() { for (i, cat_name) in user_categories.iter().enumerate() {
@ -1483,7 +1531,7 @@ mod tests {
let store = JobStore::new(); let store = JobStore::new();
let user_id = Uuid::new_v4(); let user_id = Uuid::new_v4();
let (job_id, tx) = store.create_job(user_id).unwrap(); let (job_id, tx, _cancelled) = store.create_job(user_id).unwrap();
assert_eq!(store.len(), 1); assert_eq!(store.len(), 1);
// Subscribe // Subscribe
@ -1506,8 +1554,8 @@ mod tests {
let store = JobStore::new(); let store = JobStore::new();
let user_id = Uuid::new_v4(); let user_id = Uuid::new_v4();
let result1 = store.create_job(user_id); let _result1 = store.create_job(user_id);
assert!(result1.is_some()); assert!(_result1.is_some());
// Second job for same user should fail // Second job for same user should fail
let result2 = store.create_job(user_id); let result2 = store.create_job(user_id);
@ -1524,7 +1572,7 @@ mod tests {
let store = JobStore::new(); let store = JobStore::new();
let user_id = Uuid::new_v4(); let user_id = Uuid::new_v4();
let (_job_id, tx) = store.create_job(user_id).unwrap(); let (_job_id, tx, _cancelled) = store.create_job(user_id).unwrap();
// Complete the job and release the user lock (as the pipeline does) // Complete the job and release the user lock (as the pipeline does)
tx.send(ProgressEvent::Complete { tx.send(ProgressEvent::Complete {
@ -1543,7 +1591,7 @@ mod tests {
let store = JobStore::new(); let store = JobStore::new();
let user_id = Uuid::new_v4(); let user_id = Uuid::new_v4();
let (_job_id, tx) = store.create_job(user_id).unwrap(); let (_job_id, tx, _cancelled) = store.create_job(user_id).unwrap();
// Fail the job and release the user lock (as the pipeline does) // Fail the job and release the user lock (as the pipeline does)
tx.send(ProgressEvent::Error { tx.send(ProgressEvent::Error {
@ -1563,7 +1611,7 @@ mod tests {
let user_id = Uuid::new_v4(); let user_id = Uuid::new_v4();
// Create a job and manually set its created_at to the past // Create a job and manually set its created_at to the past
let (_job_id, _tx) = store.create_job(user_id).unwrap(); let (_job_id, _tx, _cancelled) = store.create_job(user_id).unwrap();
assert_eq!(store.len(), 1); assert_eq!(store.len(), 1);
// Cleanup should not remove recent jobs // Cleanup should not remove recent jobs
@ -1576,7 +1624,7 @@ mod tests {
let store = JobStore::new(); let store = JobStore::new();
let user_id = Uuid::new_v4(); let user_id = Uuid::new_v4();
let (job_id, _tx) = store.create_job(user_id).unwrap(); let (job_id, _tx, _cancelled) = store.create_job(user_id).unwrap();
assert_eq!(store.len(), 1); assert_eq!(store.len(), 1);
store.remove(&job_id); store.remove(&job_id);

@ -2,6 +2,7 @@ mod common;
use ai_synth_backend::services::llm::mock::MockLlmProvider; use ai_synth_backend::services::llm::mock::MockLlmProvider;
use ai_synth_backend::services::synthesis; use ai_synth_backend::services::synthesis;
use std::sync::atomic::AtomicBool;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::watch; use tokio::sync::watch;
use wiremock::matchers::{method, path}; use wiremock::matchers::{method, path};
@ -117,7 +118,7 @@ async fn phase1_heuristic_extraction_classifies_articles() {
); );
let result = synthesis::run_generation_inner( let result = synthesis::run_generation_inner(
job_id, &state, user_id, theme_id, &tx, Some(mock_provider), job_id, &state, user_id, theme_id, &tx, Some(mock_provider), &AtomicBool::new(false),
).await; ).await;
assert!(result.is_ok(), "Generation should succeed: {:?}", result.err()); assert!(result.is_ok(), "Generation should succeed: {:?}", result.err());
@ -177,7 +178,7 @@ async fn phase2_search_fills_gaps_when_no_sources() {
); );
let result = synthesis::run_generation_inner( let result = synthesis::run_generation_inner(
job_id, &state, user_id, theme_id, &tx, Some(mock_provider), job_id, &state, user_id, theme_id, &tx, Some(mock_provider), &AtomicBool::new(false),
).await; ).await;
assert!(result.is_ok(), "Generation should succeed: {:?}", result.err()); assert!(result.is_ok(), "Generation should succeed: {:?}", result.err());
@ -220,7 +221,7 @@ async fn category_overflow_spills_to_autre() {
); );
let result = synthesis::run_generation_inner( let result = synthesis::run_generation_inner(
job_id, &state, user_id, theme_id, &tx, Some(mock_provider), job_id, &state, user_id, theme_id, &tx, Some(mock_provider), &AtomicBool::new(false),
).await; ).await;
assert!(result.is_ok(), "Generation should succeed"); assert!(result.is_ok(), "Generation should succeed");
@ -334,7 +335,7 @@ async fn source_diversity_limits_articles_per_source() {
); );
let result = synthesis::run_generation_inner( let result = synthesis::run_generation_inner(
job_id, &state, user_id, theme_id, &tx, Some(mock_provider), job_id, &state, user_id, theme_id, &tx, Some(mock_provider), &AtomicBool::new(false),
) )
.await; .await;
@ -448,7 +449,7 @@ async fn article_history_dedup_prevents_repeat_articles() {
); );
let result1 = synthesis::run_generation_inner( let result1 = synthesis::run_generation_inner(
job1, &state1, user_id, theme_id, &tx1, Some(mock1), job1, &state1, user_id, theme_id, &tx1, Some(mock1), &AtomicBool::new(false),
) )
.await; .await;
assert!(result1.is_ok(), "First generation should succeed: {:?}", result1.err()); assert!(result1.is_ok(), "First generation should succeed: {:?}", result1.err());
@ -479,7 +480,7 @@ async fn article_history_dedup_prevents_repeat_articles() {
// The second run scrapes the same URLs, which are already in article_history. // The second run scrapes the same URLs, which are already in article_history.
// They should be filtered out as "filtered_history". // They should be filtered out as "filtered_history".
let _result2 = synthesis::run_generation_inner( let _result2 = synthesis::run_generation_inner(
job2, &state2, user_id, theme_id, &tx2, Some(mock2), job2, &state2, user_id, theme_id, &tx2, Some(mock2), &AtomicBool::new(false),
) )
.await; .await;
// The second run may succeed (empty synthesis) or fail (no valid articles). // The second run may succeed (empty synthesis) or fail (no valid articles).

@ -88,6 +88,8 @@ const fr = {
'generate.noWebSearchWarning': 'generate.noWebSearchWarning':
'Le fournisseur selectionne ne supporte pas la recherche web. Les resultats seront bases uniquement sur les connaissances du modele.', 'Le fournisseur selectionne ne supporte pas la recherche web. Les resultats seront bases uniquement sur les connaissances du modele.',
'generate.selectTheme': 'Theme a generer', 'generate.selectTheme': 'Theme a generer',
'generate.stop': 'Arreter la generation',
'generate.stopped': 'Generation arretee. Les articles collectes ont ete sauvegardes.',
'generate.noThemes': 'Aucun theme configure. Creez un theme pour pouvoir generer une synthese.', 'generate.noThemes': 'Aucun theme configure. Creez un theme pour pouvoir generer une synthese.',
'generate.createThemeLink': 'Creer un theme', 'generate.createThemeLink': 'Creer un theme',

@ -60,6 +60,7 @@ const GenerateSynthesis: Component = () => {
const [error, setError] = createSignal<string | null>(null); const [error, setError] = createSignal<string | null>(null);
const [success, setSuccess] = createSignal(false); const [success, setSuccess] = createSignal(false);
const [sseConnection, setSSEConnection] = createSignal<SSEConnection | null>(null); const [sseConnection, setSSEConnection] = createSignal<SSEConnection | null>(null);
const [jobId, setJobId] = createSignal<string | null>(null);
onMount(async () => { onMount(async () => {
try { try {
@ -219,6 +220,7 @@ const GenerateSynthesis: Component = () => {
try { try {
const response = await synthesesApi.generate(selectedThemeId()); const response = await synthesesApi.generate(selectedThemeId());
setJobId(response.job_id);
const url = synthesesApi.progressUrl(response.job_id); const url = synthesesApi.progressUrl(response.job_id);
const conn = createSSEConnection(url); const conn = createSSEConnection(url);
setSSEConnection(conn); setSSEConnection(conn);
@ -232,6 +234,20 @@ const GenerateSynthesis: Component = () => {
} }
}; };
const handleStop = async () => {
const id = jobId();
if (!id) return;
try {
await fetch(`/api/v1/syntheses/generate/${id}/stop`, {
method: 'POST',
headers: { 'X-Requested-With': 'XMLHttpRequest' },
credentials: 'same-origin',
});
} catch {
// ignore errors — the pipeline will stop on its own
}
};
const handleRetry = () => { const handleRetry = () => {
// Close existing connection // Close existing connection
const conn = sseConnection(); const conn = sseConnection();
@ -239,6 +255,7 @@ const GenerateSynthesis: Component = () => {
conn.close(); conn.close();
} }
setSSEConnection(null); setSSEConnection(null);
setJobId(null);
setError(null); setError(null);
setSuccess(false); setSuccess(false);
setGenerating(false); setGenerating(false);
@ -403,6 +420,15 @@ const GenerateSynthesis: Component = () => {
<p class="text-xs text-gray-400 mt-4 italic"> <p class="text-xs text-gray-400 mt-4 italic">
{t('generate.canLeave')} {t('generate.canLeave')}
</p> </p>
{/* Stop generation button */}
<button
type="button"
onClick={handleStop}
class="mt-4 px-4 py-2 text-sm font-medium text-red-700 bg-red-50 border border-red-300 rounded-md hover:bg-red-100"
>
{t('generate.stop')}
</button>
</div> </div>
</Show> </Show>

Loading…
Cancel
Save