diff --git a/backend/src/services/synthesis.rs b/backend/src/services/synthesis.rs index 464a0fc..cc71fe3 100644 --- a/backend/src/services/synthesis.rs +++ b/backend/src/services/synthesis.rs @@ -314,6 +314,9 @@ async fn run_generation_inner( // Step 7b: Filter out homepage URLs (path == "/" or empty) let parsed = filter_homepage_urls(parsed); + // Step 7c: Limit articles per source for diversity + let parsed = limit_articles_per_source(parsed, settings.max_articles_per_source); + // Step 8: Scrape + rewrite pass // // Always run the full pipeline: the search pass URLs can be hallucinated @@ -486,6 +489,89 @@ fn filter_homepage_urls( result } +/// Limit the number of articles from the same domain across all categories. +/// +/// Spreads articles across categories first (at most 1 per domain per category), +/// then fills remaining slots from dropped articles in encounter order. +fn limit_articles_per_source( + parsed: Vec<(String, Vec)>, + max_per_source: i32, +) -> Vec<(String, Vec)> { + let max = max_per_source as usize; + + // Pass 1: keep at most 1 article per domain per category + let mut kept: Vec<(String, Vec)> = Vec::new(); + let mut dropped: Vec<(usize, NewsItem)> = Vec::new(); // (category_index, item) + let mut domain_counts: std::collections::HashMap = + std::collections::HashMap::new(); + + for (cat_idx, (cat_key, items)) in parsed.into_iter().enumerate() { + let mut cat_kept = Vec::new(); + let mut seen_in_cat: std::collections::HashSet = std::collections::HashSet::new(); + + for item in items { + let domain = extract_domain(&item.url); + if let Some(ref d) = domain { + if seen_in_cat.contains(d) { + dropped.push((cat_idx, item)); + continue; + } + seen_in_cat.insert(d.clone()); + } + cat_kept.push(item); + } + + kept.push((cat_key, cat_kept)); + } + + // Cap enforcement: if any domain exceeds max after pass 1 (when categories > max), + // keep the first max articles in category order, drop the rest. + let mut cap_counts: std::collections::HashMap = std::collections::HashMap::new(); + for (_, items) in &mut kept { + items.retain(|item| { + let domain = extract_domain(&item.url); + match domain { + Some(ref d) => { + let count = cap_counts.entry(d.clone()).or_insert(0); + if *count >= max { + false + } else { + *count += 1; + true + } + } + None => true, // keep unparseable URLs + } + }); + } + + // Use cap_counts as the authoritative domain counts going forward + let mut domain_counts = cap_counts; + + // Pass 2: fill from dropped articles, back into their original category + for (cat_idx, item) in dropped { + if let Some(d) = extract_domain(&item.url) { + let count = domain_counts.get(&d).copied().unwrap_or(0); + if count < max { + *domain_counts.entry(d).or_insert(0) += 1; + kept[cat_idx].1.push(item); + } + } else { + // Unparseable URL — keep it + kept[cat_idx].1.push(item); + } + } + + kept +} + +/// Extract the domain (host) from a URL, or None if unparseable. +fn extract_domain(url: &str) -> Option { + url::Url::parse(url) + .ok() + .and_then(|u| u.host_str().map(|h| h.to_lowercase())) +} + /// Resolve the LLM provider and decrypt the user's API key. /// /// If the user has a preferred provider in settings, looks for a key matching @@ -1297,4 +1383,110 @@ mod tests { let sanitized = sanitize_json_null_bytes(json.clone()); assert_eq!(sanitized, json); } + + // ── limit_articles_per_source tests ──────────────────────────── + + #[test] + fn source_limit_spreads_across_categories() { + let parsed = vec![ + ("category_0".into(), vec![ + NewsItem { title: "A1".into(), url: "https://openai.com/blog/a".into(), summary: "s".into() }, + NewsItem { title: "A2".into(), url: "https://openai.com/blog/b".into(), summary: "s".into() }, + NewsItem { title: "A3".into(), url: "https://openai.com/blog/c".into(), summary: "s".into() }, + NewsItem { title: "A4".into(), url: "https://techcrunch.com/x".into(), summary: "s".into() }, + ]), + ("category_1".into(), vec![ + NewsItem { title: "B1".into(), url: "https://openai.com/research/d".into(), summary: "s".into() }, + NewsItem { title: "B2".into(), url: "https://openai.com/research/e".into(), summary: "s".into() }, + NewsItem { title: "B3".into(), url: "https://theverge.com/y".into(), summary: "s".into() }, + ]), + ]; + + let result = limit_articles_per_source(parsed, 3); + + // Count openai.com articles across all categories + let openai_count: usize = result.iter() + .flat_map(|(_, items)| items) + .filter(|i| i.url.contains("openai.com")) + .count(); + assert_eq!(openai_count, 3, "Should keep exactly 3 openai.com articles"); + + // Both categories should have at least 1 openai article (spread) + let cat0_openai = result[0].1.iter().filter(|i| i.url.contains("openai.com")).count(); + let cat1_openai = result[1].1.iter().filter(|i| i.url.contains("openai.com")).count(); + assert!(cat0_openai >= 1, "Category 0 should have at least 1 openai article"); + assert!(cat1_openai >= 1, "Category 1 should have at least 1 openai article"); + + // techcrunch and theverge should be untouched + let tc_count: usize = result.iter().flat_map(|(_, items)| items).filter(|i| i.url.contains("techcrunch")).count(); + assert_eq!(tc_count, 1); + } + + #[test] + fn source_limit_all_different_domains() { + let parsed = vec![ + ("category_0".into(), vec![ + NewsItem { title: "A".into(), url: "https://a.com/1".into(), summary: "s".into() }, + NewsItem { title: "B".into(), url: "https://b.com/2".into(), summary: "s".into() }, + ]), + ]; + + let result = limit_articles_per_source(parsed, 3); + assert_eq!(result[0].1.len(), 2, "Nothing dropped when all domains are unique"); + } + + #[test] + fn source_limit_max_one() { + let parsed = vec![ + ("category_0".into(), vec![ + NewsItem { title: "A".into(), url: "https://openai.com/a".into(), summary: "s".into() }, + NewsItem { title: "B".into(), url: "https://openai.com/b".into(), summary: "s".into() }, + ]), + ("category_1".into(), vec![ + NewsItem { title: "C".into(), url: "https://openai.com/c".into(), summary: "s".into() }, + ]), + ]; + + let result = limit_articles_per_source(parsed, 1); + let total: usize = result.iter().flat_map(|(_, items)| items).filter(|i| i.url.contains("openai.com")).count(); + assert_eq!(total, 1, "max=1 should keep exactly 1 openai article"); + } + + #[test] + fn source_limit_more_categories_than_max() { + // 5 categories, each with 1 openai article, max=2 + let parsed: Vec<(String, Vec)> = (0..5) + .map(|i| ( + format!("category_{}", i), + vec![NewsItem { + title: format!("Art{}", i), + url: format!("https://openai.com/{}", i), + summary: "s".into(), + }], + )) + .collect(); + + let result = limit_articles_per_source(parsed, 2); + let total: usize = result.iter().flat_map(|(_, items)| items).count(); + assert_eq!(total, 2, "Should cap at max_per_source even with more categories"); + } + + #[test] + fn source_limit_empty_input() { + let result = limit_articles_per_source(vec![], 3); + assert!(result.is_empty()); + } + + #[test] + fn source_limit_unparseable_urls_kept() { + let parsed = vec![ + ("category_0".into(), vec![ + NewsItem { title: "Good".into(), url: "https://openai.com/a".into(), summary: "s".into() }, + NewsItem { title: "Bad".into(), url: "not-a-url".into(), summary: "s".into() }, + ]), + ]; + + let result = limit_articles_per_source(parsed, 3); + assert_eq!(result[0].1.len(), 2, "Unparseable URLs should be kept"); + } }