From 104b6a0d7b71ded6218eae2e5a1c862b323819d4 Mon Sep 17 00:00:00 2001 From: oabrivard Date: Tue, 24 Mar 2026 01:42:56 +0100 Subject: [PATCH] feat: add classification prompt and schema for article categorization Co-Authored-By: Claude Sonnet 4.6 --- backend/src/services/llm/schema.rs | 38 +++++++++++ backend/src/services/prompts.rs | 103 +++++++++++++++++++++++++++++ 2 files changed, 141 insertions(+) diff --git a/backend/src/services/llm/schema.rs b/backend/src/services/llm/schema.rs index 810b4a7..f137b92 100644 --- a/backend/src/services/llm/schema.rs +++ b/backend/src/services/llm/schema.rs @@ -82,6 +82,31 @@ pub fn build_category_schema(categories: &[String], max_items_per_category: i32) }) } +/// Build a JSON Schema for the article classification response. +/// +/// The LLM returns an array of assignments mapping article indices to category names. +pub fn build_classification_schema() -> Value { + serde_json::json!({ + "type": "object", + "properties": { + "assignments": { + "type": "array", + "items": { + "type": "object", + "properties": { + "index": { "type": "integer", "description": "Article index from the input list" }, + "category": { "type": "string", "description": "Category name to assign this article to" } + }, + "required": ["index", "category"], + "additionalProperties": false + } + } + }, + "required": ["assignments"], + "additionalProperties": false + }) +} + #[cfg(test)] mod tests { use super::*; @@ -261,4 +286,17 @@ mod tests { assert_eq!(props["category_0"]["description"], "AI & Machine Learning"); assert_eq!(props["category_1"]["description"], "R&D / Innovation"); } + + #[test] + fn classification_schema_has_assignments_array() { + let schema = build_classification_schema(); + assert_eq!(schema["type"], "object"); + let assignments = &schema["properties"]["assignments"]; + assert_eq!(assignments["type"], "array"); + let item_props = &assignments["items"]["properties"]; + assert!(item_props.get("index").is_some()); + assert!(item_props.get("category").is_some()); + assert_eq!(assignments["items"]["additionalProperties"], false); + assert_eq!(schema["additionalProperties"], false); + } } diff --git a/backend/src/services/prompts.rs b/backend/src/services/prompts.rs index 20b8a6a..10436ca 100644 --- a/backend/src/services/prompts.rs +++ b/backend/src/services/prompts.rs @@ -140,6 +140,64 @@ pub fn build_rewrite_prompt( (system_prompt, user_prompt) } +/// Build a prompt for classifying scraped articles into categories. +/// +/// # Arguments +/// * `articles` — scraped articles to classify (title + body snippet used) +/// * `categories` — user categories + "Autre" +/// * `max_per_category` — max items allowed per category +/// * `filled_counts` — how many items already fill each category (for Phase 2) +pub fn build_classification_prompt( + articles: &[ScrapedNewsItem], + categories: &[String], + max_per_category: i32, + filled_counts: &std::collections::HashMap, +) -> (String, String) { + let system_prompt = + "Tu es un assistant qui classe des articles dans des categories. \ + Reponds uniquement au format JSON demande." + .to_string(); + + let articles_json: Vec = articles + .iter() + .enumerate() + .map(|(i, a)| { + let snippet: String = a.scraped_content.chars().take(500).collect(); + serde_json::json!({ + "index": i, + "title": a.title, + "url": a.url, + "snippet": snippet + }) + }) + .collect(); + + let categories_info: Vec = categories + .iter() + .map(|cat| { + let filled = filled_counts.get(cat).copied().unwrap_or(0); + let remaining = (max_per_category as usize).saturating_sub(filled); + if remaining == 1 { + format!("- \"{}\" (encore 1 place)", cat) + } else { + format!("- \"{}\" (encore {} places)", cat, remaining) + } + }) + .collect(); + + let user_prompt = format!( + "Voici une liste d'articles :\n{articles}\n\n\ + Categories disponibles :\n{categories}\n\n\ + Classe chaque article dans la categorie la plus appropriee. \ + Si un article ne correspond a aucune categorie, classe-le dans \"Autre\".\n\ + Respecte le nombre de places restantes par categorie.", + articles = serde_json::to_string_pretty(&articles_json).unwrap_or_default(), + categories = categories_info.join("\n"), + ); + + (system_prompt, user_prompt) +} + #[cfg(test)] mod tests { use super::*; @@ -317,4 +375,49 @@ mod tests { let (_, user_prompt) = build_search_prompt(&settings, &sources, date, &[]); assert!(!user_prompt.contains("Evite si possible")); } + + #[test] + fn classification_prompt_includes_categories_and_articles() { + let articles = vec![ + ScrapedNewsItem { + title: "GPT-5 Released".into(), + url: "https://openai.com/blog/gpt5".into(), + summary: "s".into(), + original_title: "t".into(), + scraped_content: "OpenAI released GPT-5 today with major improvements".into(), + }, + ]; + let categories = vec!["AI News".to_string(), "Autre".to_string()]; + let filled = std::collections::HashMap::new(); + let (_, user_prompt) = build_classification_prompt(&articles, &categories, 4, &filled); + assert!(user_prompt.contains("GPT-5 Released")); + assert!(user_prompt.contains("AI News")); + assert!(user_prompt.contains("Autre")); + assert!(user_prompt.contains("encore 4 places")); + } + + #[test] + fn classification_prompt_shows_reduced_capacity() { + let articles = vec![ + ScrapedNewsItem { + title: "T".into(), url: "https://a.com/1".into(), + summary: "s".into(), original_title: "t".into(), + scraped_content: "Content".into(), + }, + ]; + let categories = vec!["AI News".to_string(), "Autre".to_string()]; + let mut filled = std::collections::HashMap::new(); + filled.insert("AI News".to_string(), 3); + let (_, user_prompt) = build_classification_prompt(&articles, &categories, 4, &filled); + assert!(user_prompt.contains("encore 1 place")); + } + + #[test] + fn classification_prompt_system_is_french() { + let articles = vec![]; + let categories = vec!["Autre".to_string()]; + let filled = std::collections::HashMap::new(); + let (system, _) = build_classification_prompt(&articles, &categories, 4, &filled); + assert!(system.contains("classe")); + } }