You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

184 lines
6.1 KiB
Rust

//! Generation handlers: trigger generation, stop it, and stream progress via SSE.
//!
//! - `POST /api/v1/syntheses/generate` — start async generation
//! - `POST /api/v1/syntheses/generate/:job_id/stop` — request cancellation of a running job
//! - `GET /api/v1/syntheses/generate/:job_id/progress` — SSE progress stream
use std::convert::Infallible;
use std::sync::Arc;
use std::time::Duration;
use axum::extract::{Path, State};
use axum::http::StatusCode;
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::IntoResponse;
use axum::Json;
use serde::Serialize;
use tokio_stream::wrappers::WatchStream;
use tokio_stream::StreamExt;
use uuid::Uuid;
use serde::Deserialize;
use crate::app_state::AppState;
use crate::errors::AppError;
use crate::middleware::auth::AuthUser;
use crate::services::job_store::ProgressEvent;
use crate::services::synthesis;
/// JSON payload returned by `POST /api/v1/syntheses/generate` once the
/// background generation job has been registered.
#[derive(Serialize, Debug)]
pub struct GenerateResponse {
    /// Identifier of the newly created job; it is the `:job_id` path segment
    /// for the progress stream and stop endpoints.
    pub job_id: Uuid,
    /// Human-readable confirmation text.
    pub message: String,
}
/// JSON body accepted by `POST /api/v1/syntheses/generate`.
#[derive(Deserialize, Debug)]
pub struct GenerateRequest {
    /// Theme to generate the synthesis for; forwarded unchanged to
    /// `synthesis::run_generation`.
    pub theme_id: Uuid,
}
/// `POST /api/v1/syntheses/generate`
///
/// Triggers an asynchronous synthesis generation. Returns immediately
/// with a 202 Accepted status and a `job_id` that can be used to
/// subscribe to progress events via SSE.
///
/// Rejects the request if the user already has a generation in progress.
pub async fn trigger_generate(
auth_user: AuthUser,
State(state): State<AppState>,
Json(body): Json<GenerateRequest>,
) -> Result<impl IntoResponse, AppError> {
// Check if user already has an active job
if let Some(existing_job_id) = state.job_store.has_active_job(auth_user.id) {
tracing::warn!(
user_id = %auth_user.id,
existing_job_id = %existing_job_id,
"User tried to start generation while one is already in progress"
);
return Err(AppError::BadRequest(
"Une generation est deja en cours. Veuillez attendre qu'elle se termine.".into(),
));
}
// Create the job in the store
let (job_id, tx, cancelled) = state
.job_store
.create_job(auth_user.id)
.ok_or_else(|| {
AppError::BadRequest(
"Une generation est deja en cours. Veuillez attendre qu'elle se termine.".into(),
)
})?;
tracing::info!(
user_id = %auth_user.id,
job_id = %job_id,
"Starting synthesis generation"
);
// Spawn the generation pipeline as a background task
let state_clone = state.clone();
let user_id = auth_user.id;
let theme_id = body.theme_id;
let tx_for_panic = Arc::clone(&tx);
let state_for_panic = state.clone();
let join_handle = tokio::spawn(async move {
let timeout_duration = std::time::Duration::from_secs(900);
match tokio::time::timeout(timeout_duration, synthesis::run_generation(job_id, state_clone.clone(), user_id, theme_id, tx.clone(), None, cancelled)).await {
Ok(()) => {}
Err(_) => {
tracing::error!(job_id = %job_id, user_id = %user_id, "Generation timed out after 15 minutes");
let _ = tx.send(ProgressEvent::Error {
message: "La generation a depasse le delai maximum de 15 minutes.".into(),
});
}
}
state_clone.job_store.release_user(user_id);
});
tokio::spawn(async move {
if let Err(e) = join_handle.await {
tracing::error!(job_id = %job_id, error = %e, "Generation task panicked");
let _ = tx_for_panic.send(ProgressEvent::Error {
message: "Erreur interne lors de la generation.".into(),
});
state_for_panic.job_store.release_user(user_id);
}
});
Ok((
StatusCode::ACCEPTED,
Json(GenerateResponse {
job_id,
message: "Generation demarree.".into(),
}),
))
}
/// `POST /api/v1/syntheses/generate/:job_id/stop`
///
/// Asks the job store to cancel `job_id` on behalf of the authenticated user
/// and answers `200 OK` when the store accepts the request. The pipeline is
/// expected to observe the cancellation, keep the articles collected so far,
/// and finish gracefully (that behaviour lives in the pipeline, not here).
///
/// # Errors
///
/// Returns `AppError::NotFound` when `cancel_job` reports failure,
/// presumably for an unknown job or one owned by another user.
pub async fn stop_generate(
    auth_user: AuthUser,
    State(state): State<AppState>,
    Path(job_id): Path<Uuid>,
) -> Result<impl IntoResponse, AppError> {
    if state.job_store.cancel_job(job_id, auth_user.id) {
        Ok(StatusCode::OK)
    } else {
        Err(AppError::NotFound("Generation introuvable.".into()))
    }
}
/// `GET /api/v1/syntheses/generate/:job_id/progress`
///
/// Server-Sent Events (SSE) endpoint that streams generation progress.
///
/// Event types:
/// - `progress`: `{type: "progress", step: "...", message: "...", percent: N}`
/// - `complete`: `{type: "complete", synthesis_id: "..."}`
/// - `error`: `{type: "error", message: "..."}`
///
/// The stream includes a keepalive ping every 15 seconds to prevent
/// connection timeouts through reverse proxies.
pub async fn progress_stream(
auth_user: AuthUser,
State(state): State<AppState>,
Path(job_id): Path<Uuid>,
) -> Result<Sse<impl tokio_stream::Stream<Item = Result<Event, Infallible>>>, AppError> {
// Get the watch receiver, verifying ownership
let rx = state
.job_store
.subscribe(job_id, auth_user.id)
.ok_or_else(|| {
AppError::NotFound("Generation introuvable ou deja terminee.".into())
})?;
// Convert the watch stream to an SSE event stream.
// The watch channel immediately delivers the latest value on subscribe,
// so clients that reconnect get caught up instantly.
let stream = WatchStream::new(rx).map(|event| {
let event_type = match &event {
ProgressEvent::Progress { .. } => "progress",
ProgressEvent::Complete { .. } => "complete",
ProgressEvent::Error { .. } => "error",
};
let data = serde_json::to_string(&event).unwrap_or_default();
Ok(Event::default().event(event_type).data(data))
});
Ok(Sse::new(stream).keep_alive(
KeepAlive::new()
.interval(Duration::from_secs(15))
.text("ping"),
))
}