feat: parallelize Phase 1 scrape+classify in batches of 5

master
oabrivard 3 months ago
parent a5f4239157
commit ed399e9a6e

@ -389,50 +389,121 @@ async fn run_generation_inner(
url_source.insert(url.clone(), source_url.clone()); url_source.insert(url.clone(), source_url.clone());
} }
// 1b. Scrape, classify, summarize each article // 1b. Scrape, classify, summarize in batches of 5
emit_progress(tx, "processing", "Traitement des articles...", 25); emit_progress(tx, "processing", "Traitement des articles...", 25);
let total_candidates = candidate_urls.len(); let total_candidates = candidate_urls.len();
let batch_size = 5;
for (idx, (url, source_url)) in candidate_urls.into_iter().enumerate() { let mut processed = 0usize;
let pct = 25 + ((idx as u32 * 40) / total_candidates.max(1) as u32).min(40); let mut candidates_iter = candidate_urls.into_iter();
emit_progress(tx, "processing", &format!("Article {}/{}...", idx + 1, total_candidates), pct as u8); let mut done = false;
// Check source limit while !done {
// Take next batch of candidates (up to 5), filtering source limits
let mut batch: Vec<(String, String)> = Vec::new();
while batch.len() < batch_size {
let Some((url, source_url)) = candidates_iter.next() else {
break;
};
let source_domain = extract_domain(&source_url).unwrap_or_default(); let source_domain = extract_domain(&source_url).unwrap_or_default();
let source_count = source_counts.get(&source_domain).copied().unwrap_or(0); let source_count = source_counts.get(&source_domain).copied().unwrap_or(0);
if source_count >= settings.max_articles_per_source as usize { if source_count >= settings.max_articles_per_source as usize {
trace_article(&state.pool, user_id, job_id, &url, "", "personalized_source", Some(&source_url), None, None, "filtered_diversity", false).await; trace_article(&state.pool, user_id, job_id, &url, "", "personalized_source", Some(&source_url), None, None, "filtered_diversity", false).await;
continue; continue;
} }
batch.push((url, source_url));
}
// Scrape if batch.is_empty() {
let (body_text, page_title, final_url, drop_reason) = scrape_single_article(&state.http_client, &url, settings.max_age_days as i64).await; break;
}
let pct = 25 + ((processed as u32 * 40) / total_candidates.max(1) as u32).min(40);
emit_progress(tx, "processing", &format!("Articles {}-{}/{}...", processed + 1, processed + batch.len(), total_candidates), pct as u8);
// Phase A: Scrape batch in parallel
let mut scrape_set = tokio::task::JoinSet::new();
for (url, source_url) in &batch {
let client = state.http_client.clone();
let u = url.clone();
let su = source_url.clone();
let mad = settings.max_age_days as i64;
scrape_set.spawn(async move {
let result = scrape_single_article(&client, &u, mad).await;
(u, su, result)
});
}
let mut scraped_articles: Vec<(String, String, String, String)> = Vec::new(); // (url, source_url, body_text, page_title)
while let Some(join_result) = scrape_set.join_next().await {
if let Ok((url, source_url, (body_text, page_title, final_url, drop_reason))) = join_result {
if let Some(reason) = drop_reason { if let Some(reason) = drop_reason {
trace_article(&state.pool, user_id, job_id, &final_url, &page_title, "personalized_source", Some(&source_url), None, None, reason, false).await; trace_article(&state.pool, user_id, job_id, &final_url, &page_title, "personalized_source", Some(&source_url), None, None, reason, false).await;
} else {
scraped_articles.push((final_url, source_url, body_text, page_title));
}
}
}
if scraped_articles.is_empty() {
processed += batch.len();
continue; continue;
} }
// LLM classify + summarize // Phase B: Classify/summarize batch in parallel
check_rate_limit(state, &user_rate_limiter, &provider_name)?; check_rate_limit(state, &user_rate_limiter, &provider_name)?;
let mut classify_set = tokio::task::JoinSet::new();
for (final_url, source_url, body_text, page_title) in &scraped_articles {
let provider_clone = std::sync::Arc::clone(&provider);
let model = model_research.clone();
let schema = classify_schema.clone();
let cats = classification_categories.clone();
let body_snippet: String = body_text.chars().take(500).collect(); let body_snippet: String = body_text.chars().take(500).collect();
let (class_sys, class_user) = crate::services::prompts::build_article_classify_prompt(&page_title, &body_snippet, &classification_categories); let title = page_title.clone();
let url = final_url.clone();
let su = source_url.clone();
let pool = state.pool.clone();
let uid = user_id;
let jid = job_id;
let (class_sys, class_user) = crate::services::prompts::build_article_classify_prompt(&title, &body_snippet, &cats);
let sys = class_sys.clone();
let usr = class_user.clone();
let mdl = model.clone();
classify_set.spawn(async move {
let llm_start = std::time::Instant::now(); let llm_start = std::time::Instant::now();
let class_response = provider.call_llm(&model_research, &class_sys, &class_user, &classify_schema).await?; let result = provider_clone.call_llm(&mdl, &sys, &usr, &schema).await;
let llm_duration = llm_start.elapsed().as_millis() as u64; let duration = llm_start.elapsed().as_millis() as u64;
log_llm_call(&state.pool, user_id, job_id, "classify_summarize", &model_research, &class_sys, &class_user, &class_response, llm_duration).await;
// Log the LLM call
if let Ok(ref resp) = result {
let resp_str = serde_json::to_string_pretty(resp).unwrap_or_default();
crate::db::llm_call_log::insert(&pool, uid, jid, "classify_summarize", &mdl, &sys, &usr, &resp_str, duration as i32).await.ok();
}
(url, su, title, result)
});
}
while let Some(join_result) = classify_set.join_next().await {
if let Ok((final_url, source_url, page_title, llm_result)) = join_result {
let class_response = match llm_result {
Ok(resp) => resp,
Err(e) => {
tracing::warn!(url = %final_url, error = %e, "LLM classify failed, skipping article");
continue;
}
};
let llm_title = class_response.get("title").and_then(|t| t.as_str()).unwrap_or(&page_title).to_string(); let llm_title = class_response.get("title").and_then(|t| t.as_str()).unwrap_or(&page_title).to_string();
let llm_summary = class_response.get("summary").and_then(|s| s.as_str()).unwrap_or("").to_string(); let llm_summary = class_response.get("summary").and_then(|s| s.as_str()).unwrap_or("").to_string();
let mut llm_category = class_response.get("category").and_then(|c| c.as_str()).unwrap_or("Autre").to_string(); let mut llm_category = class_response.get("category").and_then(|c| c.as_str()).unwrap_or("Autre").to_string();
// Validate category
if !classification_categories.iter().any(|c| c.to_lowercase() == llm_category.to_lowercase()) { if !classification_categories.iter().any(|c| c.to_lowercase() == llm_category.to_lowercase()) {
llm_category = "Autre".to_string(); llm_category = "Autre".to_string();
} }
// Map category to key
let cat_key = if llm_category.to_lowercase() == "autre" { let cat_key = if llm_category.to_lowercase() == "autre" {
"category_autre".to_string() "category_autre".to_string()
} else { } else {
@ -441,12 +512,11 @@ async fn run_generation_inner(
.unwrap_or_else(|| "category_autre".to_string()) .unwrap_or_else(|| "category_autre".to_string())
}; };
// Check if category is full -> overflow to "Autre"
let cat_filled = filled_counts.get(&llm_category).copied().unwrap_or(0); let cat_filled = filled_counts.get(&llm_category).copied().unwrap_or(0);
let (final_cat_key, final_cat_name) = if cat_filled >= settings.max_items_per_category as usize && llm_category.to_lowercase() != "autre" { let (final_cat_key, final_cat_name) = if cat_filled >= settings.max_items_per_category as usize && llm_category.to_lowercase() != "autre" {
let autre_filled = filled_counts.get("Autre").copied().unwrap_or(0); let autre_filled = filled_counts.get("Autre").copied().unwrap_or(0);
if autre_filled >= settings.max_items_per_category as usize { if autre_filled >= settings.max_items_per_category as usize {
continue; // Both full -- skip continue;
} }
("category_autre".to_string(), "Autre".to_string()) ("category_autre".to_string(), "Autre".to_string())
} else { } else {
@ -459,11 +529,18 @@ async fn run_generation_inner(
summary: llm_summary, summary: llm_summary,
}); });
*filled_counts.entry(final_cat_name).or_insert(0) += 1; *filled_counts.entry(final_cat_name).or_insert(0) += 1;
let source_domain = extract_domain(&source_url).unwrap_or_default();
*source_counts.entry(source_domain).or_insert(0) += 1; *source_counts.entry(source_domain).or_insert(0) += 1;
}
}
processed += batch.len();
// Check if we've reached the maximum after this batch
let total: usize = article_scraped.values().map(|v| v.len()).sum(); let total: usize = article_scraped.values().map(|v| v.len()).sum();
if total >= max_total { if total >= max_total {
break; done = true;
} }
} }
} }

Loading…
Cancel
Save