diff --git a/backend/src/services/synthesis.rs b/backend/src/services/synthesis.rs index be01c52..1ff12b5 100644 --- a/backend/src/services/synthesis.rs +++ b/backend/src/services/synthesis.rs @@ -1064,6 +1064,104 @@ fn sanitize_error_message(msg: &str) -> String { } } +/// Parse the LLM classification response and assign articles to categories. +/// +/// Returns a HashMap of category_key → Vec. +/// Invalid indices are ignored. Unknown categories default to "Autre". +/// Case-insensitive category matching. +fn parse_classification_response( + response: &serde_json::Value, + articles: &[ScrapedNewsItem], + categories: &[String], + max_per_category: i32, + filled_counts: &mut HashMap, +) -> HashMap> { + let max = max_per_category as usize; + let mut result: HashMap> = HashMap::new(); + + // Build category name → key mapping (case-insensitive) + // "Autre" always maps to "category_autre" + // User categories map to "category_0", "category_1", etc. + // The index skips "Autre" — only user categories get numeric keys + let mut name_to_key: HashMap = HashMap::new(); + let mut user_cat_idx = 0; + for cat in categories { + let key = if cat == "Autre" { + "category_autre".to_string() + } else { + let key = format!("category_{}", user_cat_idx); + user_cat_idx += 1; + key + }; + name_to_key.insert(cat.to_lowercase(), key); + } + + let assignments = response + .get("assignments") + .and_then(|a| a.as_array()) + .cloned() + .unwrap_or_default(); + + let mut assigned_indices = std::collections::HashSet::new(); + + for assignment in &assignments { + let index = match assignment.get("index").and_then(|i| i.as_u64()) { + Some(i) => i as usize, + None => continue, + }; + if index >= articles.len() || assigned_indices.contains(&index) { + continue; + } + + let cat_name = assignment + .get("category") + .and_then(|c| c.as_str()) + .unwrap_or("Autre") + .to_string(); + + let cat_key = name_to_key + .get(&cat_name.to_lowercase()) + .cloned() + .unwrap_or_else(|| "category_autre".to_string()); + + // Resolve the display name for counting + let cat_display = categories + .iter() + .find(|c| c.to_lowercase() == cat_name.to_lowercase()) + .cloned() + .unwrap_or_else(|| "Autre".to_string()); + + let filled = filled_counts.get(&cat_display).copied().unwrap_or(0); + if filled >= max { + // Category full — assign to Autre if Autre has room + let autre_filled = filled_counts.get("Autre").copied().unwrap_or(0); + if autre_filled < max { + result.entry("category_autre".to_string()).or_default().push(articles[index].clone()); + *filled_counts.entry("Autre".to_string()).or_insert(0) += 1; + assigned_indices.insert(index); + } + continue; + } + + result.entry(cat_key).or_default().push(articles[index].clone()); + *filled_counts.entry(cat_display).or_insert(0) += 1; + assigned_indices.insert(index); + } + + // Unclassified articles → Autre + for (i, article) in articles.iter().enumerate() { + if !assigned_indices.contains(&i) { + let autre_filled = filled_counts.get("Autre").copied().unwrap_or(0); + if autre_filled < max { + result.entry("category_autre".to_string()).or_default().push(article.clone()); + *filled_counts.entry("Autre".to_string()).or_insert(0) += 1; + } + } + } + + result +} + #[cfg(test)] mod tests { use super::*; @@ -1823,4 +1921,89 @@ mod tests { let result = limit_articles_per_source(parsed, 3); assert_eq!(result[0].1.len(), 2, "Unparseable URLs should be kept"); } + + // ── parse_classification_response tests ───────────────────── + + #[test] + fn classification_assigns_articles_to_categories() { + use crate::models::synthesis::ScrapedNewsItem; + let articles = vec![ + ScrapedNewsItem { title: "A".into(), url: "https://a.com/1".into(), summary: "s".into(), original_title: "t".into(), scraped_content: "c".into() }, + ScrapedNewsItem { title: "B".into(), url: "https://b.com/2".into(), summary: "s".into(), original_title: "t".into(), scraped_content: "c".into() }, + ]; + let categories = vec!["AI News".to_string(), "Autre".to_string()]; + let response = serde_json::json!({ + "assignments": [ + {"index": 0, "category": "AI News"}, + {"index": 1, "category": "Autre"} + ] + }); + let mut filled = HashMap::new(); + let result = parse_classification_response(&response, &articles, &categories, 4, &mut filled); + assert_eq!(result.get("category_0").map(|v| v.len()), Some(1)); + assert_eq!(result.get("category_autre").map(|v| v.len()), Some(1)); + } + + #[test] + fn classification_unknown_category_goes_to_autre() { + use crate::models::synthesis::ScrapedNewsItem; + let articles = vec![ + ScrapedNewsItem { title: "A".into(), url: "https://a.com/1".into(), summary: "s".into(), original_title: "t".into(), scraped_content: "c".into() }, + ]; + let categories = vec!["AI News".to_string(), "Autre".to_string()]; + let response = serde_json::json!({ + "assignments": [{"index": 0, "category": "Unknown Category"}] + }); + let mut filled = HashMap::new(); + let result = parse_classification_response(&response, &articles, &categories, 4, &mut filled); + assert_eq!(result.get("category_autre").map(|v| v.len()), Some(1)); + } + + #[test] + fn classification_respects_max_per_category() { + use crate::models::synthesis::ScrapedNewsItem; + let articles: Vec = (0..5).map(|i| ScrapedNewsItem { + title: format!("Art{}", i), url: format!("https://a.com/{}", i), + summary: "s".into(), original_title: "t".into(), scraped_content: "c".into(), + }).collect(); + let categories = vec!["AI News".to_string(), "Autre".to_string()]; + let response = serde_json::json!({ + "assignments": (0..5).map(|i| serde_json::json!({"index": i, "category": "AI News"})).collect::>() + }); + let mut filled = HashMap::new(); + let result = parse_classification_response(&response, &articles, &categories, 2, &mut filled); + assert_eq!(result.get("category_0").map(|v| v.len()), Some(2)); + assert!(result.get("category_autre").map(|v| v.len()).unwrap_or(0) > 0); + } + + #[test] + fn classification_invalid_index_ignored() { + use crate::models::synthesis::ScrapedNewsItem; + let articles = vec![ + ScrapedNewsItem { title: "A".into(), url: "https://a.com/1".into(), summary: "s".into(), original_title: "t".into(), scraped_content: "c".into() }, + ]; + let categories = vec!["AI News".to_string(), "Autre".to_string()]; + let response = serde_json::json!({ + "assignments": [{"index": 99, "category": "AI News"}] + }); + let mut filled = HashMap::new(); + let result = parse_classification_response(&response, &articles, &categories, 4, &mut filled); + // Index 99 is invalid → article 0 is unclassified → goes to Autre + assert_eq!(result.get("category_autre").map(|v| v.len()), Some(1)); + } + + #[test] + fn classification_case_insensitive() { + use crate::models::synthesis::ScrapedNewsItem; + let articles = vec![ + ScrapedNewsItem { title: "A".into(), url: "https://a.com/1".into(), summary: "s".into(), original_title: "t".into(), scraped_content: "c".into() }, + ]; + let categories = vec!["AI News".to_string(), "Autre".to_string()]; + let response = serde_json::json!({ + "assignments": [{"index": 0, "category": "ai news"}] + }); + let mut filled = HashMap::new(); + let result = parse_classification_response(&response, &articles, &categories, 4, &mut filled); + assert_eq!(result.get("category_0").map(|v| v.len()), Some(1)); + } }