//! In-memory sliding window rate limiters. //! //! Two rate limiter types: //! - `RateLimiter`: simple per-key limiter (used for auth rate limiting) //! - `ProviderRateLimiter`: per-provider limiter with DB-backed config and hot-reload //! //! Both use `DashMap` for lock-free concurrent access. use std::collections::VecDeque; use std::sync::Arc; use std::time::{Duration, Instant}; use dashmap::DashMap; use sqlx::PgPool; use crate::errors::AppError; // =========================================================================== // RateLimiter — simple per-key sliding window (auth, etc.) // =========================================================================== /// A thread-safe, in-memory sliding window rate limiter. /// /// Tracks request timestamps per key and enforces a maximum number /// of requests within a sliding time window. #[derive(Clone)] pub struct RateLimiter { inner: Arc, } struct RateLimiterInner { /// Maximum allowed requests per key within the time window. max_requests: usize, /// Duration of the sliding window. window: Duration, /// Per-key request timestamps. entries: DashMap>, } impl RateLimiter { /// Create a new rate limiter. /// /// # Arguments /// * `max_requests` — maximum number of requests allowed per key within the window. /// * `window` — duration of the sliding time window. pub fn new(max_requests: usize, window: Duration) -> Self { Self { inner: Arc::new(RateLimiterInner { max_requests, window, entries: DashMap::new(), }), } } /// Check if a request for the given key is allowed. /// /// Returns `true` if the request is within the rate limit, /// `false` if the limit has been exceeded. /// If allowed, the current timestamp is recorded. pub fn check(&self, key: &str) -> bool { let now = Instant::now(); let cutoff = now - self.inner.window; let mut entry = self.inner.entries.entry(key.to_string()).or_default(); let timestamps = entry.value_mut(); // Evict timestamps outside the window timestamps.retain(|t| *t > cutoff); if timestamps.len() >= self.inner.max_requests { false } else { timestamps.push(now); true } } /// Returns the number of remaining requests allowed for the given key. pub fn remaining(&self, key: &str) -> usize { let now = Instant::now(); let cutoff = now - self.inner.window; match self.inner.entries.get(key) { Some(entry) => { let recent_count = entry.value().iter().filter(|t| **t > cutoff).count(); self.inner.max_requests.saturating_sub(recent_count) } None => self.inner.max_requests, } } } // =========================================================================== // ProviderRateLimiter — per-provider, DB-backed, hot-reloadable // =========================================================================== /// Configuration for a single provider's rate limit bucket. struct ProviderBucket { /// Request timestamps within the current window. timestamps: VecDeque, /// Maximum requests allowed within the time window. max_requests: u32, /// Duration of the sliding window. time_window: Duration, } /// Per-provider rate limiter with DB-backed configuration and hot-reload. /// /// Global buckets are loaded from the `admin_rate_limits` table at startup /// and can be reloaded at runtime when an admin changes the configuration. #[derive(Clone)] pub struct ProviderRateLimiter { inner: Arc, } struct ProviderRateLimiterInner { /// Per-provider global buckets (admin-configured defaults). global_buckets: DashMap, } /// Default rate limit: 29 requests per 60 seconds (conservative for Gemini). const DEFAULT_MAX_REQUESTS: u32 = 29; const DEFAULT_TIME_WINDOW_SECS: u64 = 60; impl ProviderRateLimiter { /// Create a new provider rate limiter with empty buckets. /// /// Call `reload_from_db` after creation to load admin-configured limits. pub fn new() -> Self { Self { inner: Arc::new(ProviderRateLimiterInner { global_buckets: DashMap::new(), }), } } /// Load (or reload) all provider rate limits from the database. /// /// Called at startup and after admin updates rate limits. /// Resets the sliding window timestamps for reloaded providers. pub async fn reload_from_db(&self, pool: &PgPool) -> Result<(), AppError> { // Use a direct query to avoid circular dependency with db module let rows = sqlx::query_as::<_, RateLimitConfigRow>( r#" SELECT provider_name, max_requests, time_window_seconds FROM admin_rate_limits "#, ) .fetch_all(pool) .await?; for row in rows { self.inner.global_buckets.insert( row.provider_name, ProviderBucket { timestamps: VecDeque::new(), max_requests: row.max_requests as u32, time_window: Duration::from_secs(row.time_window_seconds as u64), }, ); } tracing::info!( provider_count = self.inner.global_buckets.len(), "Provider rate limits loaded from DB" ); Ok(()) } /// Check if a request for the given provider is allowed. /// /// Returns `true` if the request is within the rate limit, /// `false` if the limit has been exceeded. pub fn check(&self, provider: &str) -> bool { let now = Instant::now(); let mut bucket = self.inner.global_buckets.entry(provider.to_string()).or_insert_with(|| { ProviderBucket { timestamps: VecDeque::new(), max_requests: DEFAULT_MAX_REQUESTS, time_window: Duration::from_secs(DEFAULT_TIME_WINDOW_SECS), } }); let bucket = bucket.value_mut(); // Evict timestamps outside the window let cutoff = now - bucket.time_window; while bucket.timestamps.front().is_some_and(|t| *t < cutoff) { bucket.timestamps.pop_front(); } if bucket.timestamps.len() >= bucket.max_requests as usize { false } else { bucket.timestamps.push_back(now); true } } /// Returns the number of remaining requests allowed for a provider. pub fn remaining(&self, provider: &str) -> usize { let now = Instant::now(); match self.inner.global_buckets.get(provider) { Some(bucket) => { let cutoff = now - bucket.time_window; let recent_count = bucket.timestamps.iter().filter(|t| **t >= cutoff).count(); (bucket.max_requests as usize).saturating_sub(recent_count) } None => DEFAULT_MAX_REQUESTS as usize, } } /// Hot-reload a single provider's rate limit configuration. /// /// Called after an admin updates the rate limit for a specific provider. /// Resets the sliding window timestamps for the provider. pub fn update_provider_limit( &self, provider: &str, max_requests: u32, time_window_secs: u64, ) { self.inner.global_buckets.insert( provider.to_string(), ProviderBucket { timestamps: VecDeque::new(), max_requests, time_window: Duration::from_secs(time_window_secs), }, ); tracing::info!( provider = provider, max_requests = max_requests, time_window_secs = time_window_secs, "Provider rate limit updated (hot-reload)" ); } /// Returns the configured limit for a provider, or the default if not configured. pub fn get_provider_config(&self, provider: &str) -> (u32, u64) { match self.inner.global_buckets.get(provider) { Some(bucket) => (bucket.max_requests, bucket.time_window.as_secs()), None => (DEFAULT_MAX_REQUESTS, DEFAULT_TIME_WINDOW_SECS), } } } /// Row type for reading rate limit configuration from DB. #[derive(sqlx::FromRow)] struct RateLimitConfigRow { provider_name: String, max_requests: i32, time_window_seconds: i32, } // =========================================================================== // Tests // =========================================================================== #[cfg(test)] mod tests { use super::*; // --- RateLimiter tests --- #[test] fn test_allows_within_limit() { let limiter = RateLimiter::new(3, Duration::from_secs(60)); assert!(limiter.check("key1")); assert!(limiter.check("key1")); assert!(limiter.check("key1")); } #[test] fn test_blocks_over_limit() { let limiter = RateLimiter::new(2, Duration::from_secs(60)); assert!(limiter.check("key1")); assert!(limiter.check("key1")); assert!(!limiter.check("key1")); // 3rd request exceeds limit of 2 } #[test] fn test_independent_keys() { let limiter = RateLimiter::new(1, Duration::from_secs(60)); assert!(limiter.check("key1")); assert!(limiter.check("key2")); // different key, should be allowed assert!(!limiter.check("key1")); // key1 is exhausted } #[test] fn test_remaining_count() { let limiter = RateLimiter::new(5, Duration::from_secs(60)); assert_eq!(limiter.remaining("key1"), 5); limiter.check("key1"); assert_eq!(limiter.remaining("key1"), 4); limiter.check("key1"); limiter.check("key1"); assert_eq!(limiter.remaining("key1"), 2); } #[test] fn test_window_expiry() { let limiter = RateLimiter::new(1, Duration::from_millis(1)); assert!(limiter.check("key1")); // Wait for the window to expire std::thread::sleep(Duration::from_millis(10)); assert!(limiter.check("key1")); // should be allowed again } #[test] fn test_clone_shares_state() { let limiter = RateLimiter::new(2, Duration::from_secs(60)); let limiter2 = limiter.clone(); assert!(limiter.check("key1")); assert!(limiter2.check("key1")); // consumes second request via clone assert!(!limiter.check("key1")); // limit reached } // --- ProviderRateLimiter tests --- #[test] fn test_provider_limiter_allows_within_limit() { let limiter = ProviderRateLimiter::new(); limiter.update_provider_limit("gemini", 3, 60); assert!(limiter.check("gemini")); assert!(limiter.check("gemini")); assert!(limiter.check("gemini")); } #[test] fn test_provider_limiter_blocks_over_limit() { let limiter = ProviderRateLimiter::new(); limiter.update_provider_limit("openai", 2, 60); assert!(limiter.check("openai")); assert!(limiter.check("openai")); assert!(!limiter.check("openai")); // 3rd exceeds limit of 2 } #[test] fn test_provider_limiter_independent_providers() { let limiter = ProviderRateLimiter::new(); limiter.update_provider_limit("gemini", 1, 60); limiter.update_provider_limit("openai", 1, 60); assert!(limiter.check("gemini")); assert!(limiter.check("openai")); // different provider assert!(!limiter.check("gemini")); // gemini exhausted } #[test] fn test_provider_limiter_uses_defaults_for_unknown() { let limiter = ProviderRateLimiter::new(); // Unknown provider gets the default (29 requests per 60 seconds) for _ in 0..29 { assert!(limiter.check("new_provider")); } assert!(!limiter.check("new_provider")); } #[test] fn test_provider_limiter_remaining() { let limiter = ProviderRateLimiter::new(); limiter.update_provider_limit("anthropic", 5, 60); assert_eq!(limiter.remaining("anthropic"), 5); limiter.check("anthropic"); assert_eq!(limiter.remaining("anthropic"), 4); } #[test] fn test_provider_limiter_hot_reload_resets_window() { let limiter = ProviderRateLimiter::new(); limiter.update_provider_limit("gemini", 2, 60); assert!(limiter.check("gemini")); assert!(limiter.check("gemini")); assert!(!limiter.check("gemini")); // exhausted // Hot-reload with a higher limit — resets the window limiter.update_provider_limit("gemini", 5, 60); assert!(limiter.check("gemini")); // allowed again assert_eq!(limiter.remaining("gemini"), 4); } #[test] fn test_provider_limiter_get_config() { let limiter = ProviderRateLimiter::new(); limiter.update_provider_limit("gemini", 30, 120); let (max, window) = limiter.get_provider_config("gemini"); assert_eq!(max, 30); assert_eq!(window, 120); // Unknown provider returns defaults let (max, window) = limiter.get_provider_config("unknown"); assert_eq!(max, DEFAULT_MAX_REQUESTS); assert_eq!(window, DEFAULT_TIME_WINDOW_SECS); } #[test] fn test_provider_limiter_window_expiry() { let limiter = ProviderRateLimiter::new(); limiter.update_provider_limit("gemini", 1, 0); // 0 second window (instant expiry) // With a 0-second window, everything expires immediately on next check assert!(limiter.check("gemini")); std::thread::sleep(Duration::from_millis(10)); assert!(limiter.check("gemini")); // should be allowed again } #[test] fn test_provider_limiter_clone_shares_state() { let limiter = ProviderRateLimiter::new(); limiter.update_provider_limit("openai", 2, 60); let limiter2 = limiter.clone(); assert!(limiter.check("openai")); assert!(limiter2.check("openai")); assert!(!limiter.check("openai")); // limit reached via shared state } }