feat: add classification prompt and schema for article categorization

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
master
oabrivard 3 months ago
parent c06b5ba454
commit 104b6a0d7b

@ -82,6 +82,31 @@ pub fn build_category_schema(categories: &[String], max_items_per_category: i32)
}) })
} }
/// Build a JSON Schema for the article classification response.
///
/// The LLM returns an array of assignments mapping article indices to category names.
pub fn build_classification_schema() -> Value {
serde_json::json!({
"type": "object",
"properties": {
"assignments": {
"type": "array",
"items": {
"type": "object",
"properties": {
"index": { "type": "integer", "description": "Article index from the input list" },
"category": { "type": "string", "description": "Category name to assign this article to" }
},
"required": ["index", "category"],
"additionalProperties": false
}
}
},
"required": ["assignments"],
"additionalProperties": false
})
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@ -261,4 +286,17 @@ mod tests {
assert_eq!(props["category_0"]["description"], "AI & Machine Learning"); assert_eq!(props["category_0"]["description"], "AI & Machine Learning");
assert_eq!(props["category_1"]["description"], "R&D / Innovation"); assert_eq!(props["category_1"]["description"], "R&D / Innovation");
} }
#[test]
fn classification_schema_has_assignments_array() {
let schema = build_classification_schema();
assert_eq!(schema["type"], "object");
let assignments = &schema["properties"]["assignments"];
assert_eq!(assignments["type"], "array");
let item_props = &assignments["items"]["properties"];
assert!(item_props.get("index").is_some());
assert!(item_props.get("category").is_some());
assert_eq!(assignments["items"]["additionalProperties"], false);
assert_eq!(schema["additionalProperties"], false);
}
} }

@ -140,6 +140,64 @@ pub fn build_rewrite_prompt(
(system_prompt, user_prompt) (system_prompt, user_prompt)
} }
/// Build a prompt for classifying scraped articles into categories.
///
/// # Arguments
/// * `articles` — scraped articles to classify (title + body snippet used)
/// * `categories` — user categories + "Autre"
/// * `max_per_category` — max items allowed per category
/// * `filled_counts` — how many items already fill each category (for Phase 2)
pub fn build_classification_prompt(
articles: &[ScrapedNewsItem],
categories: &[String],
max_per_category: i32,
filled_counts: &std::collections::HashMap<String, usize>,
) -> (String, String) {
let system_prompt =
"Tu es un assistant qui classe des articles dans des categories. \
Reponds uniquement au format JSON demande."
.to_string();
let articles_json: Vec<serde_json::Value> = articles
.iter()
.enumerate()
.map(|(i, a)| {
let snippet: String = a.scraped_content.chars().take(500).collect();
serde_json::json!({
"index": i,
"title": a.title,
"url": a.url,
"snippet": snippet
})
})
.collect();
let categories_info: Vec<String> = categories
.iter()
.map(|cat| {
let filled = filled_counts.get(cat).copied().unwrap_or(0);
let remaining = (max_per_category as usize).saturating_sub(filled);
if remaining == 1 {
format!("- \"{}\" (encore 1 place)", cat)
} else {
format!("- \"{}\" (encore {} places)", cat, remaining)
}
})
.collect();
let user_prompt = format!(
"Voici une liste d'articles :\n{articles}\n\n\
Categories disponibles :\n{categories}\n\n\
Classe chaque article dans la categorie la plus appropriee. \
Si un article ne correspond a aucune categorie, classe-le dans \"Autre\".\n\
Respecte le nombre de places restantes par categorie.",
articles = serde_json::to_string_pretty(&articles_json).unwrap_or_default(),
categories = categories_info.join("\n"),
);
(system_prompt, user_prompt)
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@ -317,4 +375,49 @@ mod tests {
let (_, user_prompt) = build_search_prompt(&settings, &sources, date, &[]); let (_, user_prompt) = build_search_prompt(&settings, &sources, date, &[]);
assert!(!user_prompt.contains("Evite si possible")); assert!(!user_prompt.contains("Evite si possible"));
} }
#[test]
fn classification_prompt_includes_categories_and_articles() {
let articles = vec![
ScrapedNewsItem {
title: "GPT-5 Released".into(),
url: "https://openai.com/blog/gpt5".into(),
summary: "s".into(),
original_title: "t".into(),
scraped_content: "OpenAI released GPT-5 today with major improvements".into(),
},
];
let categories = vec!["AI News".to_string(), "Autre".to_string()];
let filled = std::collections::HashMap::new();
let (_, user_prompt) = build_classification_prompt(&articles, &categories, 4, &filled);
assert!(user_prompt.contains("GPT-5 Released"));
assert!(user_prompt.contains("AI News"));
assert!(user_prompt.contains("Autre"));
assert!(user_prompt.contains("encore 4 places"));
}
#[test]
fn classification_prompt_shows_reduced_capacity() {
let articles = vec![
ScrapedNewsItem {
title: "T".into(), url: "https://a.com/1".into(),
summary: "s".into(), original_title: "t".into(),
scraped_content: "Content".into(),
},
];
let categories = vec!["AI News".to_string(), "Autre".to_string()];
let mut filled = std::collections::HashMap::new();
filled.insert("AI News".to_string(), 3);
let (_, user_prompt) = build_classification_prompt(&articles, &categories, 4, &filled);
assert!(user_prompt.contains("encore 1 place"));
}
#[test]
fn classification_prompt_system_is_french() {
let articles = vec![];
let categories = vec!["Autre".to_string()];
let filled = std::collections::HashMap::new();
let (system, _) = build_classification_prompt(&articles, &categories, 4, &filled);
assert!(system.contains("classe"));
}
} }

Loading…
Cancel
Save