feat: add classification response parsing with category filling and Autre fallback

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
master
oabrivard 3 months ago
parent 104b6a0d7b
commit ba7024e280

@ -1064,6 +1064,104 @@ fn sanitize_error_message(msg: &str) -> String {
}
}
/// Parse the LLM classification response and distribute articles into category buckets.
///
/// # Parameters
/// - `response`: JSON value expected to contain an `"assignments"` array whose entries
///   carry an `"index"` (position in `articles`) and a `"category"` (category name).
///   A missing or non-array `"assignments"` is treated as an empty list.
/// - `articles`: the articles being classified; assigned articles are cloned into the result.
/// - `categories`: configured category names. `"Autre"` maps to the key `"category_autre"`;
///   every other name gets `"category_0"`, `"category_1"`, … in order of appearance
///   (the counter skips `"Autre"`).
/// - `max_per_category`: maximum number of articles per category. Zero or a negative
///   value leaves no room anywhere, so nothing is assigned.
/// - `filled_counts`: running per-category counters keyed by category display name
///   (as spelled in `categories`, or `"Autre"`). Read to enforce the limit and
///   incremented in place for every article placed.
///
/// # Returns
/// A map of category key → articles placed in that category by this call.
///
/// # Behavior
/// - Category names are matched case-insensitively. If two configured names differ
///   only by case, the first one wins for both the result key and the counter name.
/// - Assignments whose `index` is missing, not an unsigned integer, out of range, or
///   already placed are ignored.
/// - A missing, non-string, or unknown `category` falls back to "Autre".
/// - When the target category is full the article overflows into "Autre" if "Autre"
///   still has room; otherwise it is left unplaced by that assignment.
/// - Articles left unplaced after all assignments are offered to "Autre" in input
///   order, and are dropped once "Autre" is full.
fn parse_classification_response(
    response: &serde_json::Value,
    articles: &[ScrapedNewsItem],
    categories: &[String],
    max_per_category: i32,
    filled_counts: &mut HashMap<String, usize>,
) -> HashMap<String, Vec<ScrapedNewsItem>> {
    const AUTRE_NAME: &str = "Autre";
    const AUTRE_KEY: &str = "category_autre";

    // Clamp before converting: a bare `as usize` on a negative value wraps to a huge
    // limit, which would silently disable the cap.
    let max = max_per_category.max(0) as usize;

    let mut result: HashMap<String, Vec<ScrapedNewsItem>> = HashMap::new();

    // lowercased category name → (result key, display name used in `filled_counts`).
    // Built once so each assignment costs a single hash lookup.
    let mut lookup: HashMap<String, (String, &str)> = HashMap::with_capacity(categories.len());
    let mut user_cat_idx = 0usize;
    for cat in categories {
        let key = if cat == AUTRE_NAME {
            AUTRE_KEY.to_string()
        } else {
            // Only user categories consume a numeric slot.
            let key = format!("category_{}", user_cat_idx);
            user_cat_idx += 1;
            key
        };
        // First occurrence wins so the key and the counter name always agree.
        lookup
            .entry(cat.to_lowercase())
            .or_insert((key, cat.as_str()));
    }

    // Borrow the assignments array instead of cloning it.
    let assignments: &[serde_json::Value] = response
        .get("assignments")
        .and_then(|a| a.as_array())
        .map(Vec::as_slice)
        .unwrap_or_default();

    let mut assigned_indices = std::collections::HashSet::new();
    for assignment in assignments {
        // Range-check in u64 space so an oversized index cannot be truncated into a
        // valid one on targets where usize is narrower than 64 bits.
        let index = match assignment.get("index").and_then(|i| i.as_u64()) {
            Some(i) if i < articles.len() as u64 => i as usize,
            _ => continue,
        };
        if assigned_indices.contains(&index) {
            continue;
        }

        let requested = assignment
            .get("category")
            .and_then(|c| c.as_str())
            .unwrap_or(AUTRE_NAME);
        let (cat_key, cat_display) = match lookup.get(&requested.to_lowercase()) {
            Some((key, display)) => (key.as_str(), *display),
            None => (AUTRE_KEY, AUTRE_NAME),
        };

        let filled = filled_counts.get(cat_display).copied().unwrap_or(0);
        let (target_key, target_display) = if filled < max {
            (cat_key, cat_display)
        } else {
            // Category full — overflow into Autre only if Autre has room. The index
            // is deliberately not marked as assigned when there is no room.
            let autre_filled = filled_counts.get(AUTRE_NAME).copied().unwrap_or(0);
            if autre_filled >= max {
                continue;
            }
            (AUTRE_KEY, AUTRE_NAME)
        };

        result
            .entry(target_key.to_string())
            .or_default()
            .push(articles[index].clone());
        *filled_counts.entry(target_display.to_string()).or_insert(0) += 1;
        assigned_indices.insert(index);
    }

    // Unclassified articles → Autre, while it has room.
    for (i, article) in articles.iter().enumerate() {
        if assigned_indices.contains(&i) {
            continue;
        }
        let autre_filled = filled_counts.get(AUTRE_NAME).copied().unwrap_or(0);
        if autre_filled < max {
            result
                .entry(AUTRE_KEY.to_string())
                .or_default()
                .push(article.clone());
            *filled_counts.entry(AUTRE_NAME.to_string()).or_insert(0) += 1;
        }
    }

    result
}
#[cfg(test)]
mod tests {
use super::*;
@ -1823,4 +1921,89 @@ mod tests {
let result = limit_articles_per_source(parsed, 3);
assert_eq!(result[0].1.len(), 2, "Unparseable URLs should be kept");
}
// ── parse_classification_response tests ─────────────────────
#[test]
fn classification_assigns_articles_to_categories() {
    use crate::models::synthesis::ScrapedNewsItem;

    // Small factory so the two fixtures differ only in title and URL.
    let make_item = |title: &'static str, url: &'static str| ScrapedNewsItem {
        title: title.into(),
        url: url.into(),
        summary: "s".into(),
        original_title: "t".into(),
        scraped_content: "c".into(),
    };
    let pool = vec![
        make_item("A", "https://a.com/1"),
        make_item("B", "https://b.com/2"),
    ];
    let known_categories = vec![String::from("AI News"), String::from("Autre")];

    // One article per category, both named exactly as configured.
    let llm_reply = serde_json::json!({
        "assignments": [
            {"index": 0, "category": "AI News"},
            {"index": 1, "category": "Autre"}
        ]
    });

    let mut counts = HashMap::new();
    let buckets =
        parse_classification_response(&llm_reply, &pool, &known_categories, 4, &mut counts);

    // The user category lands under its numeric key, "Autre" under its dedicated key.
    let user_bucket_len = buckets.get("category_0").map(Vec::len);
    let autre_bucket_len = buckets.get("category_autre").map(Vec::len);
    assert_eq!(user_bucket_len, Some(1));
    assert_eq!(autre_bucket_len, Some(1));
}
#[test]
fn classification_unknown_category_goes_to_autre() {
    use crate::models::synthesis::ScrapedNewsItem;

    // The reply names a category that is not in the configured list.
    let llm_reply = serde_json::json!({
        "assignments": [{"index": 0, "category": "Unknown Category"}]
    });
    let known_categories = vec![String::from("AI News"), String::from("Autre")];

    let only_article = ScrapedNewsItem {
        title: "A".into(),
        url: "https://a.com/1".into(),
        summary: "s".into(),
        original_title: "t".into(),
        scraped_content: "c".into(),
    };
    let pool = vec![only_article];

    let mut counts = HashMap::new();
    let buckets =
        parse_classification_response(&llm_reply, &pool, &known_categories, 4, &mut counts);

    // The article must fall back to the "Autre" bucket.
    let autre_bucket_len = buckets.get("category_autre").map(Vec::len);
    assert_eq!(autre_bucket_len, Some(1));
}
#[test]
fn classification_respects_max_per_category() {
    use crate::models::synthesis::ScrapedNewsItem;

    // Five articles, all classified into the same user category, with a cap of two.
    let articles: Vec<ScrapedNewsItem> = (0..5)
        .map(|i| ScrapedNewsItem {
            title: format!("Art{}", i),
            url: format!("https://a.com/{}", i),
            summary: "s".into(),
            original_title: "t".into(),
            scraped_content: "c".into(),
        })
        .collect();
    let categories = vec!["AI News".to_string(), "Autre".to_string()];
    let response = serde_json::json!({
        "assignments": (0..5)
            .map(|i| serde_json::json!({"index": i, "category": "AI News"}))
            .collect::<Vec<_>>()
    });

    let mut filled = HashMap::new();
    let result = parse_classification_response(&response, &articles, &categories, 2, &mut filled);

    // The first two articles fill the requested category, in input order.
    assert_eq!(result.get("category_0").map(|v| v.len()), Some(2));
    let kept: Vec<&str> = result["category_0"]
        .iter()
        .map(|a| a.title.as_str())
        .collect();
    assert_eq!(kept, ["Art0", "Art1"]);

    // The cap applies to Autre as well: exactly two overflow there and the fifth
    // article is dropped. An `> 0` check would miss an ignored Autre cap.
    assert_eq!(result.get("category_autre").map(|v| v.len()), Some(2));

    // The running counters reflect what was placed.
    assert_eq!(filled.get("AI News").copied(), Some(2));
    assert_eq!(filled.get("Autre").copied(), Some(2));
}
#[test]
fn classification_invalid_index_ignored() {
    use crate::models::synthesis::ScrapedNewsItem;

    let pool = vec![ScrapedNewsItem {
        title: "A".into(),
        url: "https://a.com/1".into(),
        summary: "s".into(),
        original_title: "t".into(),
        scraped_content: "c".into(),
    }];
    let known_categories = vec![String::from("AI News"), String::from("Autre")];

    // Index 99 points past the end of the single-article pool.
    let llm_reply = serde_json::json!({
        "assignments": [{"index": 99, "category": "AI News"}]
    });

    let mut counts = HashMap::new();
    let buckets =
        parse_classification_response(&llm_reply, &pool, &known_categories, 4, &mut counts);

    // The bogus assignment is skipped, so article 0 stays unclassified and is
    // swept into "Autre".
    let autre_bucket_len = buckets.get("category_autre").map(Vec::len);
    assert_eq!(autre_bucket_len, Some(1));
}
#[test]
fn classification_case_insensitive() {
    use crate::models::synthesis::ScrapedNewsItem;

    // The reply spells the category in lowercase; the configured name is "AI News".
    let llm_reply = serde_json::json!({
        "assignments": [{"index": 0, "category": "ai news"}]
    });
    let known_categories = vec![String::from("AI News"), String::from("Autre")];

    let pool = vec![ScrapedNewsItem {
        title: "A".into(),
        url: "https://a.com/1".into(),
        summary: "s".into(),
        original_title: "t".into(),
        scraped_content: "c".into(),
    }];

    let mut counts = HashMap::new();
    let buckets =
        parse_classification_response(&llm_reply, &pool, &known_categories, 4, &mut counts);

    // Matching ignores case, so the article lands in the first user category.
    let user_bucket_len = buckets.get("category_0").map(Vec::len);
    assert_eq!(user_bucket_len, Some(1));
}
}

Loading…
Cancel
Save