You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
421 lines
14 KiB
Rust
421 lines
14 KiB
Rust
//! In-memory sliding window rate limiters.
|
|
//!
|
|
//! Two rate limiter types:
|
|
//! - `RateLimiter`: simple per-key limiter (used for auth rate limiting)
|
|
//! - `ProviderRateLimiter`: per-provider limiter with DB-backed config and hot-reload
|
|
//!
|
|
//! Both use `DashMap` (a sharded concurrent map) for thread-safe access without a single global lock.
|
|
|
|
use std::collections::VecDeque;
|
|
use std::sync::Arc;
|
|
use std::time::{Duration, Instant};
|
|
|
|
use dashmap::DashMap;
|
|
use sqlx::PgPool;
|
|
|
|
use crate::errors::AppError;
|
|
|
|
// ===========================================================================
|
|
// RateLimiter — simple per-key sliding window (auth, etc.)
|
|
// ===========================================================================
|
|
|
|
/// A thread-safe, in-memory sliding window rate limiter.
|
|
///
|
|
/// Tracks request timestamps per key and enforces a maximum number
|
|
/// of requests within a sliding time window.
|
|
#[derive(Clone)]
|
|
pub struct RateLimiter {
|
|
inner: Arc<RateLimiterInner>,
|
|
}
|
|
|
|
struct RateLimiterInner {
|
|
/// Maximum allowed requests per key within the time window.
|
|
max_requests: usize,
|
|
/// Duration of the sliding window.
|
|
window: Duration,
|
|
/// Per-key request timestamps.
|
|
entries: DashMap<String, Vec<Instant>>,
|
|
}
|
|
|
|
impl RateLimiter {
|
|
/// Create a new rate limiter.
|
|
///
|
|
/// # Arguments
|
|
/// * `max_requests` — maximum number of requests allowed per key within the window.
|
|
/// * `window` — duration of the sliding time window.
|
|
pub fn new(max_requests: usize, window: Duration) -> Self {
|
|
Self {
|
|
inner: Arc::new(RateLimiterInner {
|
|
max_requests,
|
|
window,
|
|
entries: DashMap::new(),
|
|
}),
|
|
}
|
|
}
|
|
|
|
/// Check if a request for the given key is allowed.
|
|
///
|
|
/// Returns `true` if the request is within the rate limit,
|
|
/// `false` if the limit has been exceeded.
|
|
/// If allowed, the current timestamp is recorded.
|
|
pub fn check(&self, key: &str) -> bool {
|
|
let now = Instant::now();
|
|
let cutoff = now - self.inner.window;
|
|
|
|
let mut entry = self.inner.entries.entry(key.to_string()).or_default();
|
|
let timestamps = entry.value_mut();
|
|
|
|
// Evict timestamps outside the window
|
|
timestamps.retain(|t| *t > cutoff);
|
|
|
|
if timestamps.len() >= self.inner.max_requests {
|
|
false
|
|
} else {
|
|
timestamps.push(now);
|
|
true
|
|
}
|
|
}
|
|
|
|
/// Returns the number of remaining requests allowed for the given key.
|
|
pub fn remaining(&self, key: &str) -> usize {
|
|
let now = Instant::now();
|
|
let cutoff = now - self.inner.window;
|
|
|
|
match self.inner.entries.get(key) {
|
|
Some(entry) => {
|
|
let recent_count = entry.value().iter().filter(|t| **t > cutoff).count();
|
|
self.inner.max_requests.saturating_sub(recent_count)
|
|
}
|
|
None => self.inner.max_requests,
|
|
}
|
|
}
|
|
}
|
|
|
|
// ===========================================================================
|
|
// ProviderRateLimiter — per-provider, DB-backed, hot-reloadable
|
|
// ===========================================================================
|
|
|
|
/// Sliding-window state and limits for one provider.
struct ProviderBucket {
    /// Times at which requests were admitted, oldest at the front.
    timestamps: VecDeque<Instant>,
    /// Cap on admitted requests per window.
    max_requests: u32,
    /// Length of the sliding window.
    time_window: Duration,
}
|
|
|
|
/// Per-provider rate limiter with DB-backed configuration and hot-reload.
|
|
///
|
|
/// Global buckets are loaded from the `admin_rate_limits` table at startup
|
|
/// and can be reloaded at runtime when an admin changes the configuration.
|
|
#[derive(Clone)]
|
|
pub struct ProviderRateLimiter {
|
|
inner: Arc<ProviderRateLimiterInner>,
|
|
}
|
|
|
|
struct ProviderRateLimiterInner {
|
|
/// Per-provider global buckets (admin-configured defaults).
|
|
global_buckets: DashMap<String, ProviderBucket>,
|
|
}
|
|
|
|
/// Fallback request cap for providers that have no configured bucket:
/// 29 requests per window (conservative for Gemini).
const DEFAULT_MAX_REQUESTS: u32 = 29;
/// Fallback window length in seconds, paired with `DEFAULT_MAX_REQUESTS`.
const DEFAULT_TIME_WINDOW_SECS: u64 = 60;
|
|
|
|
impl ProviderRateLimiter {
|
|
/// Create a new provider rate limiter with empty buckets.
|
|
///
|
|
/// Call `reload_from_db` after creation to load admin-configured limits.
|
|
pub fn new() -> Self {
|
|
Self {
|
|
inner: Arc::new(ProviderRateLimiterInner {
|
|
global_buckets: DashMap::new(),
|
|
}),
|
|
}
|
|
}
|
|
|
|
/// Load (or reload) all provider rate limits from the database.
|
|
///
|
|
/// Called at startup and after admin updates rate limits.
|
|
/// Resets the sliding window timestamps for reloaded providers.
|
|
pub async fn reload_from_db(&self, pool: &PgPool) -> Result<(), AppError> {
|
|
// Use a direct query to avoid circular dependency with db module
|
|
let rows = sqlx::query_as::<_, RateLimitConfigRow>(
|
|
r#"
|
|
SELECT provider_name, max_requests, time_window_seconds
|
|
FROM admin_rate_limits
|
|
"#,
|
|
)
|
|
.fetch_all(pool)
|
|
.await?;
|
|
|
|
for row in rows {
|
|
self.inner.global_buckets.insert(
|
|
row.provider_name,
|
|
ProviderBucket {
|
|
timestamps: VecDeque::new(),
|
|
max_requests: row.max_requests as u32,
|
|
time_window: Duration::from_secs(row.time_window_seconds as u64),
|
|
},
|
|
);
|
|
}
|
|
|
|
tracing::info!(
|
|
provider_count = self.inner.global_buckets.len(),
|
|
"Provider rate limits loaded from DB"
|
|
);
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Check if a request for the given provider is allowed.
|
|
///
|
|
/// Returns `true` if the request is within the rate limit,
|
|
/// `false` if the limit has been exceeded.
|
|
pub fn check(&self, provider: &str) -> bool {
|
|
let now = Instant::now();
|
|
|
|
let mut bucket = self.inner.global_buckets.entry(provider.to_string()).or_insert_with(|| {
|
|
ProviderBucket {
|
|
timestamps: VecDeque::new(),
|
|
max_requests: DEFAULT_MAX_REQUESTS,
|
|
time_window: Duration::from_secs(DEFAULT_TIME_WINDOW_SECS),
|
|
}
|
|
});
|
|
|
|
let bucket = bucket.value_mut();
|
|
|
|
// Evict timestamps outside the window
|
|
let cutoff = now - bucket.time_window;
|
|
while bucket.timestamps.front().is_some_and(|t| *t < cutoff) {
|
|
bucket.timestamps.pop_front();
|
|
}
|
|
|
|
if bucket.timestamps.len() >= bucket.max_requests as usize {
|
|
false
|
|
} else {
|
|
bucket.timestamps.push_back(now);
|
|
true
|
|
}
|
|
}
|
|
|
|
/// Returns the number of remaining requests allowed for a provider.
|
|
pub fn remaining(&self, provider: &str) -> usize {
|
|
let now = Instant::now();
|
|
|
|
match self.inner.global_buckets.get(provider) {
|
|
Some(bucket) => {
|
|
let cutoff = now - bucket.time_window;
|
|
let recent_count = bucket.timestamps.iter().filter(|t| **t >= cutoff).count();
|
|
(bucket.max_requests as usize).saturating_sub(recent_count)
|
|
}
|
|
None => DEFAULT_MAX_REQUESTS as usize,
|
|
}
|
|
}
|
|
|
|
/// Hot-reload a single provider's rate limit configuration.
|
|
///
|
|
/// Called after an admin updates the rate limit for a specific provider.
|
|
/// Resets the sliding window timestamps for the provider.
|
|
pub fn update_provider_limit(
|
|
&self,
|
|
provider: &str,
|
|
max_requests: u32,
|
|
time_window_secs: u64,
|
|
) {
|
|
self.inner.global_buckets.insert(
|
|
provider.to_string(),
|
|
ProviderBucket {
|
|
timestamps: VecDeque::new(),
|
|
max_requests,
|
|
time_window: Duration::from_secs(time_window_secs),
|
|
},
|
|
);
|
|
|
|
tracing::info!(
|
|
provider = provider,
|
|
max_requests = max_requests,
|
|
time_window_secs = time_window_secs,
|
|
"Provider rate limit updated (hot-reload)"
|
|
);
|
|
}
|
|
|
|
/// Returns the configured limit for a provider, or the default if not configured.
|
|
pub fn get_provider_config(&self, provider: &str) -> (u32, u64) {
|
|
match self.inner.global_buckets.get(provider) {
|
|
Some(bucket) => (bucket.max_requests, bucket.time_window.as_secs()),
|
|
None => (DEFAULT_MAX_REQUESTS, DEFAULT_TIME_WINDOW_SECS),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Row type for reading rate limit configuration from DB.
|
|
#[derive(sqlx::FromRow)]
|
|
struct RateLimitConfigRow {
|
|
provider_name: String,
|
|
max_requests: i32,
|
|
time_window_seconds: i32,
|
|
}
|
|
|
|
// ===========================================================================
|
|
// Tests
|
|
// ===========================================================================
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    // --- RateLimiter tests ---

    /// Every request up to the configured maximum is admitted.
    #[test]
    fn test_allows_within_limit() {
        let rl = RateLimiter::new(3, Duration::from_secs(60));
        for _ in 0..3 {
            assert!(rl.check("key1"));
        }
    }

    /// The first request past the maximum is rejected.
    #[test]
    fn test_blocks_over_limit() {
        let rl = RateLimiter::new(2, Duration::from_secs(60));
        for _ in 0..2 {
            assert!(rl.check("key1"));
        }
        // 3rd request exceeds limit of 2
        assert!(!rl.check("key1"));
    }

    /// Exhausting one key leaves other keys untouched.
    #[test]
    fn test_independent_keys() {
        let rl = RateLimiter::new(1, Duration::from_secs(60));
        assert!(rl.check("key1"));
        // different key, should be allowed
        assert!(rl.check("key2"));
        // key1 is exhausted
        assert!(!rl.check("key1"));
    }

    /// `remaining` shrinks by one for each admitted request.
    #[test]
    fn test_remaining_count() {
        let rl = RateLimiter::new(5, Duration::from_secs(60));
        assert_eq!(rl.remaining("key1"), 5);

        let _ = rl.check("key1");
        assert_eq!(rl.remaining("key1"), 4);

        for _ in 0..2 {
            let _ = rl.check("key1");
        }
        assert_eq!(rl.remaining("key1"), 2);
    }

    /// Once the window has elapsed, the key may be used again.
    #[test]
    fn test_window_expiry() {
        let rl = RateLimiter::new(1, Duration::from_millis(1));
        assert!(rl.check("key1"));
        // Wait for the window to expire
        std::thread::sleep(Duration::from_millis(10));
        // should be allowed again
        assert!(rl.check("key1"));
    }

    /// Clones consume from the same budget.
    #[test]
    fn test_clone_shares_state() {
        let original = RateLimiter::new(2, Duration::from_secs(60));
        let copy = original.clone();
        assert!(original.check("key1"));
        // consumes second request via clone
        assert!(copy.check("key1"));
        // limit reached
        assert!(!original.check("key1"));
    }

    // --- ProviderRateLimiter tests ---

    /// Every request up to the provider's maximum is admitted.
    #[test]
    fn test_provider_limiter_allows_within_limit() {
        let prl = ProviderRateLimiter::new();
        prl.update_provider_limit("gemini", 3, 60);
        for _ in 0..3 {
            assert!(prl.check("gemini"));
        }
    }

    /// The first request past the provider's maximum is rejected.
    #[test]
    fn test_provider_limiter_blocks_over_limit() {
        let prl = ProviderRateLimiter::new();
        prl.update_provider_limit("openai", 2, 60);
        for _ in 0..2 {
            assert!(prl.check("openai"));
        }
        // 3rd exceeds limit of 2
        assert!(!prl.check("openai"));
    }

    /// Exhausting one provider leaves the others untouched.
    #[test]
    fn test_provider_limiter_independent_providers() {
        let prl = ProviderRateLimiter::new();
        for name in ["gemini", "openai"] {
            prl.update_provider_limit(name, 1, 60);
        }
        assert!(prl.check("gemini"));
        // different provider
        assert!(prl.check("openai"));
        // gemini exhausted
        assert!(!prl.check("gemini"));
    }

    /// An unconfigured provider is held to the built-in defaults.
    #[test]
    fn test_provider_limiter_uses_defaults_for_unknown() {
        let prl = ProviderRateLimiter::new();
        // Unknown provider gets the default (29 requests per 60 seconds)
        for _ in 0..29 {
            assert!(prl.check("new_provider"));
        }
        assert!(!prl.check("new_provider"));
    }

    /// `remaining` shrinks by one for each admitted provider request.
    #[test]
    fn test_provider_limiter_remaining() {
        let prl = ProviderRateLimiter::new();
        prl.update_provider_limit("anthropic", 5, 60);
        assert_eq!(prl.remaining("anthropic"), 5);

        let _ = prl.check("anthropic");
        assert_eq!(prl.remaining("anthropic"), 4);
    }

    /// Updating a provider's limit discards its recorded requests.
    #[test]
    fn test_provider_limiter_hot_reload_resets_window() {
        let prl = ProviderRateLimiter::new();
        prl.update_provider_limit("gemini", 2, 60);
        for _ in 0..2 {
            assert!(prl.check("gemini"));
        }
        // exhausted
        assert!(!prl.check("gemini"));

        // Hot-reload with a higher limit — resets the window
        prl.update_provider_limit("gemini", 5, 60);
        // allowed again
        assert!(prl.check("gemini"));
        assert_eq!(prl.remaining("gemini"), 4);
    }

    /// Configured providers report their own limits; others report the defaults.
    #[test]
    fn test_provider_limiter_get_config() {
        let prl = ProviderRateLimiter::new();
        prl.update_provider_limit("gemini", 30, 120);

        let (configured_max, configured_window) = prl.get_provider_config("gemini");
        assert_eq!(configured_max, 30);
        assert_eq!(configured_window, 120);

        // Unknown provider returns defaults
        let (fallback_max, fallback_window) = prl.get_provider_config("unknown");
        assert_eq!(fallback_max, DEFAULT_MAX_REQUESTS);
        assert_eq!(fallback_window, DEFAULT_TIME_WINDOW_SECS);
    }

    /// With a zero-length window, earlier requests expire by the next check.
    #[test]
    fn test_provider_limiter_window_expiry() {
        let prl = ProviderRateLimiter::new();
        // 0 second window (instant expiry)
        prl.update_provider_limit("gemini", 1, 0);
        // With a 0-second window, everything expires immediately on next check
        assert!(prl.check("gemini"));
        std::thread::sleep(Duration::from_millis(10));
        // should be allowed again
        assert!(prl.check("gemini"));
    }

    /// Clones of the provider limiter consume from the same budget.
    #[test]
    fn test_provider_limiter_clone_shares_state() {
        let original = ProviderRateLimiter::new();
        original.update_provider_limit("openai", 2, 60);
        let copy = original.clone();
        assert!(original.check("openai"));
        assert!(copy.check("openai"));
        // limit reached via shared state
        assert!(!original.check("openai"));
    }
}
|