You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
184 lines
6.1 KiB
Rust
184 lines
6.1 KiB
Rust
//! Generation handlers: trigger generation, stop it, and stream progress via SSE.
//!
//! - `POST /api/v1/syntheses/generate` — start async generation
//! - `POST /api/v1/syntheses/generate/:job_id/stop` — request cancellation of a running job
//! - `GET /api/v1/syntheses/generate/:job_id/progress` — SSE progress stream
|
|
|
|
use std::convert::Infallible;
|
|
use std::sync::Arc;
|
|
use std::time::Duration;
|
|
|
|
use axum::extract::{Path, State};
|
|
use axum::http::StatusCode;
|
|
use axum::response::sse::{Event, KeepAlive, Sse};
|
|
use axum::response::IntoResponse;
|
|
use axum::Json;
|
|
use serde::Serialize;
|
|
use tokio_stream::wrappers::WatchStream;
|
|
use tokio_stream::StreamExt;
|
|
use uuid::Uuid;
|
|
|
|
use serde::Deserialize;
|
|
|
|
use crate::app_state::AppState;
|
|
use crate::errors::AppError;
|
|
use crate::middleware::auth::AuthUser;
|
|
use crate::services::job_store::ProgressEvent;
|
|
use crate::services::synthesis;
|
|
|
|
/// Response body for `POST /api/v1/syntheses/generate`.
|
|
#[derive(Debug, Serialize)]
|
|
pub struct GenerateResponse {
|
|
pub job_id: Uuid,
|
|
pub message: String,
|
|
}
|
|
|
|
/// Request body for `POST /api/v1/syntheses/generate`.
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct GenerateRequest {
|
|
pub theme_id: Uuid,
|
|
}
|
|
|
|
/// `POST /api/v1/syntheses/generate`
|
|
///
|
|
/// Triggers an asynchronous synthesis generation. Returns immediately
|
|
/// with a 202 Accepted status and a `job_id` that can be used to
|
|
/// subscribe to progress events via SSE.
|
|
///
|
|
/// Rejects the request if the user already has a generation in progress.
|
|
pub async fn trigger_generate(
|
|
auth_user: AuthUser,
|
|
State(state): State<AppState>,
|
|
Json(body): Json<GenerateRequest>,
|
|
) -> Result<impl IntoResponse, AppError> {
|
|
// Check if user already has an active job
|
|
if let Some(existing_job_id) = state.job_store.has_active_job(auth_user.id) {
|
|
tracing::warn!(
|
|
user_id = %auth_user.id,
|
|
existing_job_id = %existing_job_id,
|
|
"User tried to start generation while one is already in progress"
|
|
);
|
|
return Err(AppError::BadRequest(
|
|
"Une generation est deja en cours. Veuillez attendre qu'elle se termine.".into(),
|
|
));
|
|
}
|
|
|
|
// Create the job in the store
|
|
let (job_id, tx, cancelled) = state
|
|
.job_store
|
|
.create_job(auth_user.id)
|
|
.ok_or_else(|| {
|
|
AppError::BadRequest(
|
|
"Une generation est deja en cours. Veuillez attendre qu'elle se termine.".into(),
|
|
)
|
|
})?;
|
|
|
|
tracing::info!(
|
|
user_id = %auth_user.id,
|
|
job_id = %job_id,
|
|
"Starting synthesis generation"
|
|
);
|
|
|
|
// Spawn the generation pipeline as a background task
|
|
let state_clone = state.clone();
|
|
let user_id = auth_user.id;
|
|
let theme_id = body.theme_id;
|
|
let tx_for_panic = Arc::clone(&tx);
|
|
let state_for_panic = state.clone();
|
|
|
|
let join_handle = tokio::spawn(async move {
|
|
let timeout_duration = std::time::Duration::from_secs(900);
|
|
match tokio::time::timeout(timeout_duration, synthesis::run_generation(job_id, state_clone.clone(), user_id, theme_id, tx.clone(), None, cancelled)).await {
|
|
Ok(()) => {}
|
|
Err(_) => {
|
|
tracing::error!(job_id = %job_id, user_id = %user_id, "Generation timed out after 15 minutes");
|
|
let _ = tx.send(ProgressEvent::Error {
|
|
message: "La generation a depasse le delai maximum de 15 minutes.".into(),
|
|
});
|
|
}
|
|
}
|
|
state_clone.job_store.release_user(user_id);
|
|
});
|
|
|
|
tokio::spawn(async move {
|
|
if let Err(e) = join_handle.await {
|
|
tracing::error!(job_id = %job_id, error = %e, "Generation task panicked");
|
|
let _ = tx_for_panic.send(ProgressEvent::Error {
|
|
message: "Erreur interne lors de la generation.".into(),
|
|
});
|
|
state_for_panic.job_store.release_user(user_id);
|
|
}
|
|
});
|
|
|
|
Ok((
|
|
StatusCode::ACCEPTED,
|
|
Json(GenerateResponse {
|
|
job_id,
|
|
message: "Generation demarree.".into(),
|
|
}),
|
|
))
|
|
}
|
|
|
|
/// `POST /api/v1/syntheses/generate/:job_id/stop`
|
|
///
|
|
/// Signals a running generation job to stop. The pipeline will save
|
|
/// whatever articles have been collected so far and complete gracefully.
|
|
pub async fn stop_generate(
|
|
auth_user: AuthUser,
|
|
State(state): State<AppState>,
|
|
Path(job_id): Path<Uuid>,
|
|
) -> Result<impl IntoResponse, AppError> {
|
|
let cancelled = state.job_store.cancel_job(job_id, auth_user.id);
|
|
if !cancelled {
|
|
return Err(AppError::NotFound(
|
|
"Generation introuvable.".into(),
|
|
));
|
|
}
|
|
Ok(StatusCode::OK)
|
|
}
|
|
|
|
/// `GET /api/v1/syntheses/generate/:job_id/progress`
|
|
///
|
|
/// Server-Sent Events (SSE) endpoint that streams generation progress.
|
|
///
|
|
/// Event types:
|
|
/// - `progress`: `{type: "progress", step: "...", message: "...", percent: N}`
|
|
/// - `complete`: `{type: "complete", synthesis_id: "..."}`
|
|
/// - `error`: `{type: "error", message: "..."}`
|
|
///
|
|
/// The stream includes a keepalive ping every 15 seconds to prevent
|
|
/// connection timeouts through reverse proxies.
|
|
pub async fn progress_stream(
|
|
auth_user: AuthUser,
|
|
State(state): State<AppState>,
|
|
Path(job_id): Path<Uuid>,
|
|
) -> Result<Sse<impl tokio_stream::Stream<Item = Result<Event, Infallible>>>, AppError> {
|
|
// Get the watch receiver, verifying ownership
|
|
let rx = state
|
|
.job_store
|
|
.subscribe(job_id, auth_user.id)
|
|
.ok_or_else(|| {
|
|
AppError::NotFound("Generation introuvable ou deja terminee.".into())
|
|
})?;
|
|
|
|
// Convert the watch stream to an SSE event stream.
|
|
// The watch channel immediately delivers the latest value on subscribe,
|
|
// so clients that reconnect get caught up instantly.
|
|
let stream = WatchStream::new(rx).map(|event| {
|
|
let event_type = match &event {
|
|
ProgressEvent::Progress { .. } => "progress",
|
|
ProgressEvent::Complete { .. } => "complete",
|
|
ProgressEvent::Error { .. } => "error",
|
|
};
|
|
|
|
let data = serde_json::to_string(&event).unwrap_or_default();
|
|
|
|
Ok(Event::default().event(event_type).data(data))
|
|
});
|
|
|
|
Ok(Sse::new(stream).keep_alive(
|
|
KeepAlive::new()
|
|
.interval(Duration::from_secs(15))
|
|
.text("ping"),
|
|
))
|
|
}
|