feat: add stop generation button — saves partial synthesis on cancel

Adds Arc<AtomicBool> cancellation flag to JobStore/JobEntry. The pipeline
checks the flag before each wave and after each batch, then saves whatever
articles have been collected. A new POST /syntheses/generate/:job_id/stop
endpoint sets the flag. The frontend shows a red stop button during generation
and POSTs to the stop endpoint on click.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
master
oabrivard 3 months ago
parent 7ba1da4d92
commit 6f3e6883c9

@ -62,7 +62,7 @@ pub async fn trigger_generate(
}
// Create the job in the store
let (job_id, tx) = state
let (job_id, tx, cancelled) = state
.job_store
.create_job(auth_user.id)
.ok_or_else(|| {
@ -86,7 +86,7 @@ pub async fn trigger_generate(
let join_handle = tokio::spawn(async move {
let timeout_duration = std::time::Duration::from_secs(900);
match tokio::time::timeout(timeout_duration, synthesis::run_generation(job_id, state_clone.clone(), user_id, theme_id, tx.clone(), None)).await {
match tokio::time::timeout(timeout_duration, synthesis::run_generation(job_id, state_clone.clone(), user_id, theme_id, tx.clone(), None, cancelled)).await {
Ok(()) => {}
Err(_) => {
tracing::error!(job_id = %job_id, user_id = %user_id, "Generation timed out after 15 minutes");
@ -117,6 +117,24 @@ pub async fn trigger_generate(
))
}
/// `POST /api/v1/syntheses/generate/:job_id/stop`
///
/// Signals a running generation job to stop. The pipeline will save
/// whatever articles have been collected so far and complete gracefully.
pub async fn stop_generate(
auth_user: AuthUser,
State(state): State<AppState>,
Path(job_id): Path<Uuid>,
) -> Result<impl IntoResponse, AppError> {
let cancelled = state.job_store.cancel_job(job_id, auth_user.id);
if !cancelled {
return Err(AppError::NotFound(
"Generation introuvable.".into(),
));
}
Ok(StatusCode::OK)
}
/// `GET /api/v1/syntheses/generate/:job_id/progress`
///
/// Server-Sent Events (SSE) endpoint that streams generation progress.

@ -59,6 +59,7 @@ pub fn build_router(state: AppState, config: &AppConfig) -> Router {
// to avoid ambiguity with path parameter matching
.route("/syntheses/generate", post(handlers::generation::trigger_generate))
.route("/syntheses/generate/{job_id}/progress", get(handlers::generation::progress_stream))
.route("/syntheses/generate/{job_id}/stop", post(handlers::generation::stop_generate))
// Article history & provenance routes (authenticated)
.route("/article-history", get(handlers::article_history::list_history).delete(handlers::article_history::clear_history))
.route("/syntheses/{id}/provenance", get(handlers::article_history::get_provenance))

@ -8,6 +8,7 @@
//! consumed by SSE endpoints for real-time client updates.
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
@ -74,6 +75,8 @@ struct JobEntry {
user_id: Uuid,
/// When the job was created (for TTL cleanup).
created_at: Instant,
/// Flag set to true when the user requests cancellation.
cancelled: Arc<AtomicBool>,
}
/// In-memory store for active generation jobs.
@ -104,11 +107,11 @@ impl JobStore {
}
}
/// Create a new job for a user, returning the job ID and the watch Sender.
/// Create a new job for a user, returning the job ID, the watch Sender, and a cancellation flag.
///
/// Returns `None` if the user already has an active job.
/// Uses an atomic DashSet insert to prevent race conditions on double-click.
pub fn create_job(&self, user_id: Uuid) -> Option<(Uuid, Arc<watch::Sender<ProgressEvent>>)> {
pub fn create_job(&self, user_id: Uuid) -> Option<(Uuid, Arc<watch::Sender<ProgressEvent>>, Arc<AtomicBool>)> {
if !self.generating_users.insert(user_id) {
return None;
}
@ -119,10 +122,12 @@ impl JobStore {
percent: 0,
});
let tx = Arc::new(tx);
let cancelled = Arc::new(AtomicBool::new(false));
self.inner.insert(job_id, JobEntry {
tx: Arc::clone(&tx), _rx: rx, user_id, created_at: Instant::now(),
cancelled: Arc::clone(&cancelled),
});
Some((job_id, tx))
Some((job_id, tx, cancelled))
}
/// Get a watch receiver for a job, if it exists and belongs to the given user.
@ -145,6 +150,17 @@ impl JobStore {
None
}
/// Signal a job to stop. Returns true if the job was found and belongs to the user.
pub fn cancel_job(&self, job_id: Uuid, user_id: Uuid) -> bool {
if let Some(entry) = self.inner.get(&job_id) {
if entry.value().user_id == user_id {
entry.value().cancelled.store(true, Ordering::Relaxed);
return true;
}
}
false
}
/// Release the generating lock for a user (called when job completes, errors, or times out).
pub fn release_user(&self, user_id: Uuid) {
self.generating_users.remove(&user_id);
@ -196,8 +212,9 @@ pub async fn run_generation(
theme_id: Uuid,
tx: Arc<watch::Sender<ProgressEvent>>,
provider_override: Option<Arc<dyn crate::services::llm::LlmProvider>>,
cancelled: Arc<AtomicBool>,
) {
let result = run_generation_inner(job_id, &state, user_id, theme_id, &tx, provider_override).await;
let result = run_generation_inner(job_id, &state, user_id, theme_id, &tx, provider_override, &cancelled).await;
match result {
Ok(synthesis_id) => {
@ -233,6 +250,7 @@ pub async fn run_generation_inner(
theme_id: Uuid,
tx: &watch::Sender<ProgressEvent>,
provider_override: Option<Arc<dyn crate::services::llm::LlmProvider>>,
cancelled: &AtomicBool,
) -> Result<Uuid, AppError> {
// Batch buffer for article history traces (flushed at logical boundaries)
let mut pending_traces: Vec<db::article_history::ArticleHistoryEntry> = Vec::new();
@ -309,6 +327,13 @@ pub async fn run_generation_inner(
let total_waves = source_chunks.len();
'wave_loop: for (wave_idx, wave_sources) in source_chunks.iter().enumerate() {
// Check cancellation before each wave
if cancelled.load(Ordering::Relaxed) {
tracing::info!(job_id = %job_id, "Generation cancelled by user (before wave)");
emit_progress(tx, "saving", "Generation arretee, sauvegarde...", 90);
break 'wave_loop;
}
let articles_so_far: usize = article_scraped.values().map(|v| v.len()).sum();
let pct = 5 + ((articles_so_far as u32 * 60) / max_total.max(1) as u32).min(60);
emit_progress(tx, "sources",
@ -569,6 +594,13 @@ pub async fn run_generation_inner(
processed += batch.len();
// Check cancellation after each batch
if cancelled.load(Ordering::Relaxed) {
tracing::info!(job_id = %job_id, "Generation cancelled by user (after batch)");
emit_progress(tx, "saving", "Generation arretee, sauvegarde...", 90);
break;
}
// Check if we've reached the maximum after this batch
let total: usize = article_scraped.values().map(|v| v.len()).sum();
if total >= max_total {
@ -593,13 +625,19 @@ pub async fn run_generation_inner(
}
// === PHASE 2: Web Search Fallback ===
// Skip Phase 2 if cancelled
let is_cancelled = cancelled.load(Ordering::Relaxed);
if is_cancelled {
tracing::info!(job_id = %job_id, "Skipping Phase 2 — generation cancelled by user");
}
let category_gaps: Vec<(String, i32)> = user_categories.iter().filter_map(|cat| {
let filled = filled_counts.get(cat).copied().unwrap_or(0);
let needed = (theme.max_items_per_category as usize).saturating_sub(filled);
if needed > 0 { Some((cat.clone(), needed as i32)) } else { None }
}).collect();
if !category_gaps.is_empty() {
if !category_gaps.is_empty() && !is_cancelled {
if settings.use_brave_search {
// === BRAVE SEARCH PATH ===
emit_progress(tx, "websearch", "Recherche Brave Search...", 70);
@ -907,11 +945,21 @@ pub async fn run_generation_inner(
}
// === SAVE ===
if article_scraped.values().all(|items| items.is_empty()) {
let is_cancelled = cancelled.load(Ordering::Relaxed);
let has_articles = article_scraped.values().any(|items| !items.is_empty());
if !has_articles {
if is_cancelled {
return Err(AppError::BadRequest("Generation arretee. Aucun article n'avait encore ete collecte.".into()));
}
return Err(AppError::BadRequest("Aucun article valide trouve. Verifiez vos sources et categories.".into()));
}
emit_progress(tx, "saving", "Sauvegarde de la synthese...", 90);
if is_cancelled {
emit_progress(tx, "saving", "Generation arretee, sauvegarde des articles collectes...", 90);
} else {
emit_progress(tx, "saving", "Sauvegarde de la synthese...", 90);
}
let mut final_sections: Vec<NewsSection> = Vec::new();
for (i, cat_name) in user_categories.iter().enumerate() {
@ -1483,7 +1531,7 @@ mod tests {
let store = JobStore::new();
let user_id = Uuid::new_v4();
let (job_id, tx) = store.create_job(user_id).unwrap();
let (job_id, tx, _cancelled) = store.create_job(user_id).unwrap();
assert_eq!(store.len(), 1);
// Subscribe
@ -1506,8 +1554,8 @@ mod tests {
let store = JobStore::new();
let user_id = Uuid::new_v4();
let result1 = store.create_job(user_id);
assert!(result1.is_some());
let _result1 = store.create_job(user_id);
assert!(_result1.is_some());
// Second job for same user should fail
let result2 = store.create_job(user_id);
@ -1524,7 +1572,7 @@ mod tests {
let store = JobStore::new();
let user_id = Uuid::new_v4();
let (_job_id, tx) = store.create_job(user_id).unwrap();
let (_job_id, tx, _cancelled) = store.create_job(user_id).unwrap();
// Complete the job and release the user lock (as the pipeline does)
tx.send(ProgressEvent::Complete {
@ -1543,7 +1591,7 @@ mod tests {
let store = JobStore::new();
let user_id = Uuid::new_v4();
let (_job_id, tx) = store.create_job(user_id).unwrap();
let (_job_id, tx, _cancelled) = store.create_job(user_id).unwrap();
// Fail the job and release the user lock (as the pipeline does)
tx.send(ProgressEvent::Error {
@ -1563,7 +1611,7 @@ mod tests {
let user_id = Uuid::new_v4();
// Create a job and manually set its created_at to the past
let (_job_id, _tx) = store.create_job(user_id).unwrap();
let (_job_id, _tx, _cancelled) = store.create_job(user_id).unwrap();
assert_eq!(store.len(), 1);
// Cleanup should not remove recent jobs
@ -1576,7 +1624,7 @@ mod tests {
let store = JobStore::new();
let user_id = Uuid::new_v4();
let (job_id, _tx) = store.create_job(user_id).unwrap();
let (job_id, _tx, _cancelled) = store.create_job(user_id).unwrap();
assert_eq!(store.len(), 1);
store.remove(&job_id);

@ -2,6 +2,7 @@ mod common;
use ai_synth_backend::services::llm::mock::MockLlmProvider;
use ai_synth_backend::services::synthesis;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use tokio::sync::watch;
use wiremock::matchers::{method, path};
@ -117,7 +118,7 @@ async fn phase1_heuristic_extraction_classifies_articles() {
);
let result = synthesis::run_generation_inner(
job_id, &state, user_id, theme_id, &tx, Some(mock_provider),
job_id, &state, user_id, theme_id, &tx, Some(mock_provider), &AtomicBool::new(false),
).await;
assert!(result.is_ok(), "Generation should succeed: {:?}", result.err());
@ -177,7 +178,7 @@ async fn phase2_search_fills_gaps_when_no_sources() {
);
let result = synthesis::run_generation_inner(
job_id, &state, user_id, theme_id, &tx, Some(mock_provider),
job_id, &state, user_id, theme_id, &tx, Some(mock_provider), &AtomicBool::new(false),
).await;
assert!(result.is_ok(), "Generation should succeed: {:?}", result.err());
@ -220,7 +221,7 @@ async fn category_overflow_spills_to_autre() {
);
let result = synthesis::run_generation_inner(
job_id, &state, user_id, theme_id, &tx, Some(mock_provider),
job_id, &state, user_id, theme_id, &tx, Some(mock_provider), &AtomicBool::new(false),
).await;
assert!(result.is_ok(), "Generation should succeed");
@ -334,7 +335,7 @@ async fn source_diversity_limits_articles_per_source() {
);
let result = synthesis::run_generation_inner(
job_id, &state, user_id, theme_id, &tx, Some(mock_provider),
job_id, &state, user_id, theme_id, &tx, Some(mock_provider), &AtomicBool::new(false),
)
.await;
@ -448,7 +449,7 @@ async fn article_history_dedup_prevents_repeat_articles() {
);
let result1 = synthesis::run_generation_inner(
job1, &state1, user_id, theme_id, &tx1, Some(mock1),
job1, &state1, user_id, theme_id, &tx1, Some(mock1), &AtomicBool::new(false),
)
.await;
assert!(result1.is_ok(), "First generation should succeed: {:?}", result1.err());
@ -479,7 +480,7 @@ async fn article_history_dedup_prevents_repeat_articles() {
// The second run scrapes the same URLs, which are already in article_history.
// They should be filtered out as "filtered_history".
let _result2 = synthesis::run_generation_inner(
job2, &state2, user_id, theme_id, &tx2, Some(mock2),
job2, &state2, user_id, theme_id, &tx2, Some(mock2), &AtomicBool::new(false),
)
.await;
// The second run may succeed (empty synthesis) or fail (no valid articles).

@ -88,6 +88,8 @@ const fr = {
'generate.noWebSearchWarning':
'Le fournisseur selectionne ne supporte pas la recherche web. Les resultats seront bases uniquement sur les connaissances du modele.',
'generate.selectTheme': 'Theme a generer',
'generate.stop': 'Arreter la generation',
'generate.stopped': 'Generation arretee. Les articles collectes ont ete sauvegardes.',
'generate.noThemes': 'Aucun theme configure. Creez un theme pour pouvoir generer une synthese.',
'generate.createThemeLink': 'Creer un theme',

@ -60,6 +60,7 @@ const GenerateSynthesis: Component = () => {
const [error, setError] = createSignal<string | null>(null);
const [success, setSuccess] = createSignal(false);
const [sseConnection, setSSEConnection] = createSignal<SSEConnection | null>(null);
const [jobId, setJobId] = createSignal<string | null>(null);
onMount(async () => {
try {
@ -219,6 +220,7 @@ const GenerateSynthesis: Component = () => {
try {
const response = await synthesesApi.generate(selectedThemeId());
setJobId(response.job_id);
const url = synthesesApi.progressUrl(response.job_id);
const conn = createSSEConnection(url);
setSSEConnection(conn);
@ -232,6 +234,20 @@ const GenerateSynthesis: Component = () => {
}
};
const handleStop = async () => {
const id = jobId();
if (!id) return;
try {
await fetch(`/api/v1/syntheses/generate/${id}/stop`, {
method: 'POST',
headers: { 'X-Requested-With': 'XMLHttpRequest' },
credentials: 'same-origin',
});
} catch {
// ignore errors — the pipeline will stop on its own
}
};
const handleRetry = () => {
// Close existing connection
const conn = sseConnection();
@ -239,6 +255,7 @@ const GenerateSynthesis: Component = () => {
conn.close();
}
setSSEConnection(null);
setJobId(null);
setError(null);
setSuccess(false);
setGenerating(false);
@ -403,6 +420,15 @@ const GenerateSynthesis: Component = () => {
<p class="text-xs text-gray-400 mt-4 italic">
{t('generate.canLeave')}
</p>
{/* Stop generation button */}
<button
type="button"
onClick={handleStop}
class="mt-4 px-4 py-2 text-sm font-medium text-red-700 bg-red-50 border border-red-300 rounded-md hover:bg-red-100"
>
{t('generate.stop')}
</button>
</div>
</Show>

Loading…
Cancel
Save