You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

421 lines
14 KiB
Rust

//! In-memory sliding window rate limiters.
//!
//! Two rate limiter types:
//! - `RateLimiter`: simple per-key limiter (used for auth rate limiting)
//! - `ProviderRateLimiter`: per-provider limiter with DB-backed config and hot-reload
//!
//! Both store their state in a `DashMap`, a sharded concurrent map (it uses per-shard locks, so it is not lock-free, but different keys rarely contend).
use std::collections::VecDeque;
use std::sync::Arc;
use std::time::{Duration, Instant};
use dashmap::DashMap;
use sqlx::PgPool;
use crate::errors::AppError;
// ===========================================================================
// RateLimiter — simple per-key sliding window (auth, etc.)
// ===========================================================================
/// An in-memory, thread-safe sliding-window rate limiter keyed by string.
///
/// For every key it remembers when requests were admitted, and it admits a new
/// one only while fewer than `max_requests` of those fall inside the trailing
/// `window`. Cloning is cheap: clones share the same state through an `Arc`.
#[derive(Clone)]
pub struct RateLimiter {
    inner: Arc<RateLimiterInner>,
}

/// Shared state behind a `RateLimiter`.
struct RateLimiterInner {
    /// Admission timestamps, one list per key.
    entries: DashMap<String, Vec<Instant>>,
    /// Length of the sliding window.
    window: Duration,
    /// Per-key cap on admitted requests inside the window.
    max_requests: usize,
}
impl RateLimiter {
    /// Create a new rate limiter.
    ///
    /// # Arguments
    /// * `max_requests` — maximum number of requests allowed per key within the window.
    /// * `window` — duration of the sliding time window.
    pub fn new(max_requests: usize, window: Duration) -> Self {
        Self {
            inner: Arc::new(RateLimiterInner {
                max_requests,
                window,
                entries: DashMap::new(),
            }),
        }
    }

    /// Check if a request for the given key is allowed.
    ///
    /// Returns `true` if the request is within the rate limit, `false` if the
    /// limit has been exceeded. If allowed, the current timestamp is recorded.
    ///
    /// Expired timestamps for `key` are evicted on every call, but the key's map
    /// entry itself is never removed, so the map grows with the number of
    /// distinct keys ever seen.
    pub fn check(&self, key: &str) -> bool {
        let now = Instant::now();
        let mut entry = self.inner.entries.entry(key.to_string()).or_default();
        let timestamps = entry.value_mut();
        // Evict timestamps that have aged out of the window.
        timestamps.retain(|&t| self.is_recent(now, t));
        if timestamps.len() >= self.inner.max_requests {
            false
        } else {
            timestamps.push(now);
            true
        }
    }

    /// Returns the number of remaining requests allowed for the given key.
    ///
    /// Read-only: expired timestamps are ignored but not evicted.
    pub fn remaining(&self, key: &str) -> usize {
        let now = Instant::now();
        match self.inner.entries.get(key) {
            Some(entry) => {
                let recent_count = entry
                    .value()
                    .iter()
                    .filter(|&&t| self.is_recent(now, t))
                    .count();
                self.inner.max_requests.saturating_sub(recent_count)
            }
            None => self.inner.max_requests,
        }
    }

    /// Whether a request admitted at `t` still counts against the limit at `now`.
    ///
    /// For `t <= now` this is the same test as `t > now - window`. It is computed
    /// as an elapsed duration because `Instant - Duration` panics when the result
    /// would precede the clock's origin, for example a long window shortly after boot.
    fn is_recent(&self, now: Instant, t: Instant) -> bool {
        now.saturating_duration_since(t) < self.inner.window
    }
}
// ===========================================================================
// ProviderRateLimiter — per-provider, DB-backed, hot-reloadable
// ===========================================================================
/// Sliding-window state for one provider: its limit settings plus the
/// timestamps of requests admitted within the current window.
struct ProviderBucket {
    /// Upper bound on admitted requests per window.
    max_requests: u32,
    /// Length of the sliding window.
    time_window: Duration,
    /// Admission times, appended at the back as requests are admitted and
    /// evicted from the front once they age out.
    timestamps: VecDeque<Instant>,
}
/// Per-provider rate limiter with DB-backed configuration and hot-reload.
///
/// Global buckets are loaded from the `admin_rate_limits` table at startup
/// and can be reloaded at runtime when an admin changes the configuration.
/// Cloning is cheap: clones share the same buckets through an `Arc`.
/// `Default` yields an empty limiter, the same as `ProviderRateLimiter::new()`.
#[derive(Clone, Default)]
pub struct ProviderRateLimiter {
    inner: Arc<ProviderRateLimiterInner>,
}

/// Shared state behind a `ProviderRateLimiter`.
#[derive(Default)]
struct ProviderRateLimiterInner {
    /// Per-provider global buckets (admin-configured defaults).
    global_buckets: DashMap<String, ProviderBucket>,
}
/// Default rate limit: 29 requests per 60 seconds (conservative for Gemini).
///
/// Applied to any provider without a bucket: `ProviderRateLimiter::check`
/// creates one with these values, and `remaining` / `get_provider_config`
/// report them for unknown providers.
const DEFAULT_MAX_REQUESTS: u32 = 29;
/// Default sliding-window length in seconds, paired with `DEFAULT_MAX_REQUESTS`.
const DEFAULT_TIME_WINDOW_SECS: u64 = 60;
impl ProviderRateLimiter {
    /// Create a new provider rate limiter with empty buckets.
    ///
    /// Call `reload_from_db` after creation to load admin-configured limits.
    pub fn new() -> Self {
        Self {
            inner: Arc::new(ProviderRateLimiterInner {
                global_buckets: DashMap::new(),
            }),
        }
    }

    /// Load (or reload) all provider rate limits from the database.
    ///
    /// Called at startup and after admin updates rate limits. Each valid row
    /// replaces that provider's bucket, which resets its sliding-window
    /// timestamps. Providers that already have a bucket but no row in the
    /// table are left untouched (this does not remove stale entries).
    ///
    /// Rows with a negative `max_requests` or `time_window_seconds` are skipped
    /// with a warning. Casting them with `as` would wrap to an enormous limit or
    /// window.
    ///
    /// # Errors
    /// Returns an error if the query fails (converted into `AppError` via `?`).
    pub async fn reload_from_db(&self, pool: &PgPool) -> Result<(), AppError> {
        // Use a direct query to avoid circular dependency with db module
        let rows = sqlx::query_as::<_, RateLimitConfigRow>(
            r#"
            SELECT provider_name, max_requests, time_window_seconds
            FROM admin_rate_limits
            "#,
        )
        .fetch_all(pool)
        .await?;

        for row in rows {
            // A negative i32 cast to u32/u64 becomes u32::MAX / ~u64::MAX: an
            // effectively unlimited bucket, or a window that cannot be evaluated.
            if row.max_requests < 0 || row.time_window_seconds < 0 {
                tracing::warn!(
                    provider = row.provider_name.as_str(),
                    max_requests = row.max_requests,
                    time_window_seconds = row.time_window_seconds,
                    "Ignoring provider rate limit with negative value"
                );
                continue;
            }
            // Both values are non-negative here, so these casts are lossless.
            self.inner.global_buckets.insert(
                row.provider_name,
                ProviderBucket {
                    timestamps: VecDeque::new(),
                    max_requests: row.max_requests as u32,
                    time_window: Duration::from_secs(row.time_window_seconds as u64),
                },
            );
        }

        tracing::info!(
            provider_count = self.inner.global_buckets.len(),
            "Provider rate limits loaded from DB"
        );
        Ok(())
    }

    /// Check if a request for the given provider is allowed.
    ///
    /// Returns `true` if the request is within the rate limit,
    /// `false` if the limit has been exceeded. A provider without a bucket
    /// gets one with `DEFAULT_MAX_REQUESTS` / `DEFAULT_TIME_WINDOW_SECS`.
    pub fn check(&self, provider: &str) -> bool {
        let mut entry = self
            .inner
            .global_buckets
            .entry(provider.to_string())
            .or_insert_with(|| ProviderBucket {
                timestamps: VecDeque::new(),
                max_requests: DEFAULT_MAX_REQUESTS,
                time_window: Duration::from_secs(DEFAULT_TIME_WINDOW_SECS),
            });
        let bucket = entry.value_mut();

        // Read the clock only while holding the bucket's lock. Callers are
        // serialized here, so timestamps are pushed in non-decreasing order,
        // which the front-only eviction below relies on.
        let now = Instant::now();

        // Evict timestamps older than the window. Comparing elapsed durations
        // avoids `now - time_window`, which panics if the window reaches back
        // before the clock's origin (e.g. a very large admin-configured window).
        let window = bucket.time_window;
        while bucket
            .timestamps
            .front()
            .is_some_and(|t| now.saturating_duration_since(*t) > window)
        {
            bucket.timestamps.pop_front();
        }

        if bucket.timestamps.len() >= bucket.max_requests as usize {
            false
        } else {
            bucket.timestamps.push_back(now);
            true
        }
    }

    /// Returns the number of remaining requests allowed for a provider.
    ///
    /// Read-only: expired timestamps are ignored but not evicted. Unknown
    /// providers report `DEFAULT_MAX_REQUESTS`.
    pub fn remaining(&self, provider: &str) -> usize {
        let now = Instant::now();
        match self.inner.global_buckets.get(provider) {
            Some(bucket) => {
                // Same boundary as `check`: a timestamp exactly `time_window`
                // old still counts.
                let recent_count = bucket
                    .timestamps
                    .iter()
                    .filter(|t| now.saturating_duration_since(**t) <= bucket.time_window)
                    .count();
                (bucket.max_requests as usize).saturating_sub(recent_count)
            }
            None => DEFAULT_MAX_REQUESTS as usize,
        }
    }

    /// Hot-reload a single provider's rate limit configuration.
    ///
    /// Called after an admin updates the rate limit for a specific provider.
    /// Resets the sliding window timestamps for the provider.
    pub fn update_provider_limit(&self, provider: &str, max_requests: u32, time_window_secs: u64) {
        self.inner.global_buckets.insert(
            provider.to_string(),
            ProviderBucket {
                timestamps: VecDeque::new(),
                max_requests,
                time_window: Duration::from_secs(time_window_secs),
            },
        );
        tracing::info!(
            provider = provider,
            max_requests = max_requests,
            time_window_secs = time_window_secs,
            "Provider rate limit updated (hot-reload)"
        );
    }

    /// Returns the configured limit for a provider, or the default if not configured.
    ///
    /// The tuple is `(max_requests, time_window_secs)`.
    pub fn get_provider_config(&self, provider: &str) -> (u32, u64) {
        match self.inner.global_buckets.get(provider) {
            Some(bucket) => (bucket.max_requests, bucket.time_window.as_secs()),
            None => (DEFAULT_MAX_REQUESTS, DEFAULT_TIME_WINDOW_SECS),
        }
    }
}
/// One row of the `admin_rate_limits` table, as decoded by
/// `ProviderRateLimiter::reload_from_db`.
///
/// `FromRow` maps fields by column name, so field order here is irrelevant.
#[derive(sqlx::FromRow)]
struct RateLimitConfigRow {
    /// Provider key the limit applies to.
    provider_name: String,
    /// Window length in seconds, decoded as a signed `i32`.
    time_window_seconds: i32,
    /// Requests allowed per window, decoded as a signed `i32`.
    max_requests: i32,
}
// ===========================================================================
// Tests
// ===========================================================================
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a provider limiter with one provider already configured.
    fn provider_limiter(provider: &str, max_requests: u32, window_secs: u64) -> ProviderRateLimiter {
        let limiter = ProviderRateLimiter::new();
        limiter.update_provider_limit(provider, max_requests, window_secs);
        limiter
    }

    // --- RateLimiter tests ---

    #[test]
    fn test_allows_within_limit() {
        let rl = RateLimiter::new(3, Duration::from_secs(60));
        for _ in 0..3 {
            assert!(rl.check("key1"));
        }
    }

    #[test]
    fn test_blocks_over_limit() {
        let rl = RateLimiter::new(2, Duration::from_secs(60));
        // The third request exceeds the limit of two.
        let outcomes: Vec<bool> = (0..3).map(|_| rl.check("key1")).collect();
        assert_eq!(outcomes, [true, true, false]);
    }

    #[test]
    fn test_independent_keys() {
        let rl = RateLimiter::new(1, Duration::from_secs(60));
        // key2 has its own budget; key1 is exhausted after a single request.
        let outcomes = [rl.check("key1"), rl.check("key2"), rl.check("key1")];
        assert_eq!(outcomes, [true, true, false]);
    }

    #[test]
    fn test_remaining_count() {
        let rl = RateLimiter::new(5, Duration::from_secs(60));
        assert_eq!(rl.remaining("key1"), 5);
        rl.check("key1");
        assert_eq!(rl.remaining("key1"), 4);
        for _ in 0..2 {
            rl.check("key1");
        }
        assert_eq!(rl.remaining("key1"), 2);
    }

    #[test]
    fn test_window_expiry() {
        let rl = RateLimiter::new(1, Duration::from_millis(1));
        assert!(rl.check("key1"));
        // Sleep well past the 1 ms window so the first timestamp ages out.
        std::thread::sleep(Duration::from_millis(10));
        assert!(rl.check("key1"));
    }

    #[test]
    fn test_clone_shares_state() {
        let first = RateLimiter::new(2, Duration::from_secs(60));
        let second = first.clone();
        assert!(first.check("key1"));
        assert!(second.check("key1")); // the clone takes the second slot
        assert!(!first.check("key1")); // so the original is now at its limit
    }

    // --- ProviderRateLimiter tests ---

    #[test]
    fn test_provider_limiter_allows_within_limit() {
        let limiter = provider_limiter("gemini", 3, 60);
        assert!((0..3).all(|_| limiter.check("gemini")));
    }

    #[test]
    fn test_provider_limiter_blocks_over_limit() {
        let limiter = provider_limiter("openai", 2, 60);
        // The third request exceeds the limit of two.
        let outcomes: Vec<bool> = (0..3).map(|_| limiter.check("openai")).collect();
        assert_eq!(outcomes, [true, true, false]);
    }

    #[test]
    fn test_provider_limiter_independent_providers() {
        let limiter = provider_limiter("gemini", 1, 60);
        limiter.update_provider_limit("openai", 1, 60);
        let outcomes = [
            limiter.check("gemini"),
            limiter.check("openai"),
            limiter.check("gemini"),
        ];
        assert_eq!(outcomes, [true, true, false]);
    }

    #[test]
    fn test_provider_limiter_uses_defaults_for_unknown() {
        let limiter = ProviderRateLimiter::new();
        // An unconfigured provider falls back to 29 requests per 60 seconds.
        assert!((0..29).all(|_| limiter.check("new_provider")));
        assert!(!limiter.check("new_provider"));
    }

    #[test]
    fn test_provider_limiter_remaining() {
        let limiter = provider_limiter("anthropic", 5, 60);
        assert_eq!(limiter.remaining("anthropic"), 5);
        limiter.check("anthropic");
        assert_eq!(limiter.remaining("anthropic"), 4);
    }

    #[test]
    fn test_provider_limiter_hot_reload_resets_window() {
        let limiter = provider_limiter("gemini", 2, 60);
        assert!(limiter.check("gemini"));
        assert!(limiter.check("gemini"));
        assert!(!limiter.check("gemini"));
        // Raising the limit replaces the bucket, which also clears its history.
        limiter.update_provider_limit("gemini", 5, 60);
        assert!(limiter.check("gemini"));
        assert_eq!(limiter.remaining("gemini"), 4);
    }

    #[test]
    fn test_provider_limiter_get_config() {
        let limiter = provider_limiter("gemini", 30, 120);
        assert_eq!(limiter.get_provider_config("gemini"), (30, 120));
        // Providers without a bucket report the defaults.
        assert_eq!(
            limiter.get_provider_config("unknown"),
            (DEFAULT_MAX_REQUESTS, DEFAULT_TIME_WINDOW_SECS)
        );
    }

    #[test]
    fn test_provider_limiter_window_expiry() {
        // A zero-length window treats every earlier timestamp as expired.
        let limiter = provider_limiter("gemini", 1, 0);
        assert!(limiter.check("gemini"));
        std::thread::sleep(Duration::from_millis(10));
        assert!(limiter.check("gemini"));
    }

    #[test]
    fn test_provider_limiter_clone_shares_state() {
        let limiter = provider_limiter("openai", 2, 60);
        let twin = limiter.clone();
        assert!(limiter.check("openai"));
        assert!(twin.check("openai"));
        assert!(!limiter.check("openai")); // limit reached through shared state
    }
}