package middleware

import (
	"fmt"
	"strconv"
	"sync"
	"sync/atomic"
	"time"

	"github.com/gofiber/fiber/v3"
	"github.com/gofiber/fiber/v3/middleware/limiter"
	"github.com/gofiber/storage/redis/v3"

	"github.com/knowfoolery/backend/services/gateway-service/config"
)

const (
	// cleanupInterval is the minimum time between two sweeps of expired
	// in-memory counters.
	cleanupInterval = time.Minute

	// counterRetention is how long a counter is kept after its window ended
	// before cleanup removes it.
	counterRetention = time.Minute

	// maxStatsSamples caps the number of per-key samples GetStats reports.
	maxStatsSamples = 10
)

// RateLimiter limits requests per client key. When Redis storage is
// configured it delegates to fiber's limiter middleware with a sliding
// window; otherwise it keeps fixed-window counters in process memory.
//
// The limit is config.RequestsPerMinute requests per config.WindowSize; it is
// only "per minute" when WindowSize is one minute.
type RateLimiter struct {
	config   config.RateLimitConfig
	storage  fiber.Storage
	counters map[string]*Counter
	// mutex guards counters. It is always acquired before any Counter.mutex.
	mutex sync.RWMutex

	// nextCleanup holds the UnixNano instant at which the next sweep of
	// expired counters becomes due. It limits cleanup to one per interval.
	nextCleanup atomic.Int64
}

// Counter is the fixed-window state for one key of the in-memory limiter.
// Count and ResetTime must only be accessed while holding mutex.
type Counter struct {
	// Count is the number of requests seen in the current window, including
	// rejected ones.
	Count int
	// ResetTime is the instant at which the current window ends.
	ResetTime time.Time
	mutex     sync.Mutex
}

// NewRateLimiter builds a RateLimiter from cfg. Redis-backed storage is used
// only when cfg.Storage.Type is "redis" and cfg.Storage.RedisURL is set;
// otherwise the limiter falls back to in-memory counters.
func NewRateLimiter(cfg config.RateLimitConfig) *RateLimiter {
	var storage fiber.Storage
	if cfg.Storage.Type == "redis" && cfg.Storage.RedisURL != "" {
		// NOTE(review): no connection error is surfaced here; confirm how the
		// redis storage package reports an unreachable server.
		storage = redis.New(redis.Config{URL: cfg.Storage.RedisURL})
	}
	return &RateLimiter{
		config:   cfg,
		storage:  storage,
		counters: make(map[string]*Counter),
	}
}

// Handler returns the rate-limiting middleware. It returns:
//   - a pass-through when limiting is disabled;
//   - fiber's sliding-window limiter when Redis storage is configured;
//   - the in-memory fixed-window limiter otherwise.
func (rl *RateLimiter) Handler() fiber.Handler {
	if !rl.config.Enabled {
		return func(c fiber.Ctx) error { return c.Next() }
	}
	if rl.storage == nil {
		return rl.memoryLimiter()
	}
	return limiter.New(limiter.Config{
		Max:               rl.config.RequestsPerMinute,
		Expiration:        rl.config.WindowSize,
		LimiterMiddleware: limiter.SlidingWindow{},
		KeyGenerator:      rl.generateKey,
		LimitReached:      rl.limitReached,
		Storage:           rl.storage,
	})
}

// memoryLimiter returns a fixed-window limiter backed by rl.counters.
func (rl *RateLimiter) memoryLimiter() fiber.Handler {
	return func(c fiber.Ctx) error {
		now := time.Now()
		counter := rl.counterFor(rl.generateKey(c))

		// Reset-if-expired, increment and read back under one lock so that
		// concurrent requests for the same key are counted exactly once each.
		counter.mutex.Lock()
		if now.After(counter.ResetTime) {
			// The window elapsed, or this is a fresh counter whose ResetTime
			// is the zero time: start a new window.
			counter.Count = 0
			counter.ResetTime = now.Add(rl.config.WindowSize)
		}
		counter.Count++
		count := counter.Count
		counter.mutex.Unlock()

		rl.maybeCleanup(now)

		if count > rl.config.RequestsPerMinute {
			return rl.limitReached(c)
		}
		rl.setHeaders(c, counter)
		return c.Next()
	}
}

// counterFor returns the counter for key, creating it if absent. The check is
// repeated under the write lock so that concurrent first requests for one
// key share a single counter instead of overwriting each other's.
func (rl *RateLimiter) counterFor(key string) *Counter {
	rl.mutex.RLock()
	counter, ok := rl.counters[key]
	rl.mutex.RUnlock()
	if ok {
		return counter
	}

	rl.mutex.Lock()
	defer rl.mutex.Unlock()
	if counter, ok = rl.counters[key]; !ok {
		counter = &Counter{}
		rl.counters[key] = counter
	}
	return counter
}

// maybeCleanup starts a background sweep of expired counters if one is due.
// The compare-and-swap ensures only one caller per interval launches it, so
// request traffic cannot pile up sweeps contending for rl.mutex.
func (rl *RateLimiter) maybeCleanup(now time.Time) {
	due := rl.nextCleanup.Load()
	if now.UnixNano() < due {
		return
	}
	if !rl.nextCleanup.CompareAndSwap(due, now.Add(cleanupInterval).UnixNano()) {
		return // another request already claimed this sweep
	}
	go rl.cleanup()
}

// generateKey derives the rate-limit bucket for a request from
// config.KeyGenerator:
//   - "user": the UserID from GetUserFromContext;
//   - "ip_user": "<ip>:<UserID>";
//   - "ip" or anything else: the client IP.
//
// The "user" and "ip_user" modes fall back to the client IP when no user is
// present.
func (rl *RateLimiter) generateKey(c fiber.Ctx) string {
	switch rl.config.KeyGenerator {
	case "user":
		if userCtx := GetUserFromContext(c); userCtx != nil {
			return userCtx.UserID
		}
	case "ip_user":
		if userCtx := GetUserFromContext(c); userCtx != nil {
			return c.IP() + ":" + userCtx.UserID
		}
	}
	return c.IP()
}

// limitReached writes a 429 JSON response. retry_after is always the full
// window length in seconds, not the time left in the current window.
func (rl *RateLimiter) limitReached(c fiber.Ctx) error {
	return c.Status(fiber.StatusTooManyRequests).JSON(fiber.Map{
		"error":       "Rate limit exceeded",
		"retry_after": int(rl.config.WindowSize.Seconds()),
		"limit":       rl.config.RequestsPerMinute,
		"window":      rl.config.WindowSize.String(),
	})
}

// setHeaders writes the X-RateLimit-* headers from a locked snapshot of
// counter. Remaining is clamped at zero.
func (rl *RateLimiter) setHeaders(c fiber.Ctx, counter *Counter) {
	counter.mutex.Lock()
	count, resetAt := counter.Count, counter.ResetTime
	counter.mutex.Unlock()

	remaining := max(rl.config.RequestsPerMinute-count, 0)
	c.Set("X-RateLimit-Limit", strconv.Itoa(rl.config.RequestsPerMinute))
	c.Set("X-RateLimit-Remaining", strconv.Itoa(remaining))
	c.Set("X-RateLimit-Reset", strconv.FormatInt(resetAt.Unix(), 10))
	c.Set("X-RateLimit-Window", rl.config.WindowSize.String())
}

// cleanup removes counters whose window ended more than counterRetention ago.
// A request that fetched such a counter just before removal is counted on the
// detached counter. This only affects keys idle for at least a window plus
// counterRetention.
func (rl *RateLimiter) cleanup() {
	rl.mutex.Lock()
	defer rl.mutex.Unlock()

	cutoff := time.Now().Add(-counterRetention)
	for key, counter := range rl.counters {
		counter.mutex.Lock()
		expired := counter.ResetTime.Before(cutoff)
		counter.mutex.Unlock()
		if expired {
			delete(rl.counters, key) // deleting during range is safe in Go
		}
	}
}

// GetStats reports a snapshot of the in-memory limiter. The result contains
// a "summary" entry plus up to maxStatsSamples per-key samples named "key_1",
// "key_2", and so on. Only keys whose window is still open are counted or
// sampled. Redis-backed limiting never populates the in-memory counters, so
// it is not reflected here.
func (rl *RateLimiter) GetStats() map[string]any {
	rl.mutex.RLock()
	defer rl.mutex.RUnlock()

	now := time.Now()
	stats := make(map[string]any, maxStatsSamples+1)
	activeKeys, totalRequests := 0, 0
	for key, counter := range rl.counters {
		counter.mutex.Lock()
		count, resetAt := counter.Count, counter.ResetTime
		counter.mutex.Unlock()

		if !now.Before(resetAt) {
			continue // window is over; the counter is awaiting cleanup
		}
		activeKeys++
		totalRequests += count
		if activeKeys <= maxStatsSamples {
			stats[fmt.Sprintf("key_%d", activeKeys)] = map[string]any{
				"key":      key,
				"count":    count,
				"reset_at": resetAt,
			}
		}
	}

	stats["summary"] = map[string]any{
		"active_keys":    activeKeys,
		"total_requests": totalRequests,
		"limit_per_key":  rl.config.RequestsPerMinute,
		"window_size":    rl.config.WindowSize.String(),
	}
	return stats
}