package middleware

import (
	"fmt"
	"strconv"
	"strings"
	"sync/atomic"
	"time"

	"github.com/gofiber/fiber/v3"
	redisv9 "github.com/redis/go-redis/v9"

	gconfig "knowfoolery/backend/services/gateway-service/internal/infra/config"
	"knowfoolery/backend/shared/infra/auth/zitadel"
	"knowfoolery/backend/shared/infra/observability/logging"
	"knowfoolery/backend/shared/infra/utils/httputil"
)

// rateTier is one rate-limit bucket: at most Limit requests per Window for a
// single requester, tracked under a Redis key that includes Name.
// RetrySec is not read anywhere in this file.
type rateTier struct {
	Name     string
	Limit    int
	Window   time.Duration
	RetrySec int
}

// rateLimitMemberSeq is appended to every sorted-set member so that two
// requests observing the same nanosecond timestamp from the same IP still
// yield distinct members. Without it ZADD would overwrite the earlier entry
// and the requester would be under-counted.
var rateLimitMemberSeq atomic.Uint64

// RateLimitMiddleware applies Redis-backed sliding window limits with fail-open fallback.
//
// Each request is classified into a tier by selectTier and counted in a Redis
// sorted set keyed by tier name and requester (see identifyRequester). OPTIONS
// requests are never counted. When client is nil, or the Redis call fails or
// returns an unexpected shape, the request is let through and the response
// carries "X-RateLimit-Policy: degraded". logger may be nil.
//
// Timestamps come from the gateway's own clock (sent to the script as
// ARGV[1]), so gateway instances sharing one Redis should keep their clocks
// synchronized.
func RateLimitMiddleware(
	client *redisv9.Client,
	cfg gconfig.RateLimitConfig,
	prefix string,
	logger *logging.Logger,
) fiber.Handler {
	// Every tier shares the same window; limits fall back to 1 and the window
	// to one minute when unset or non-positive.
	window := positiveDuration(cfg.Window, time.Minute)
	tiers := map[string]rateTier{
		"general": {Name: "general", Limit: maxInt(cfg.GeneralRequests, 1), Window: window},
		"auth":    {Name: "auth", Limit: maxInt(cfg.AuthRequests, 1), Window: window},
		"api":     {Name: "api", Limit: maxInt(cfg.APIRequests, 1), Window: window},
		"admin":   {Name: "admin", Limit: maxInt(cfg.AdminRequests, 1), Window: window},
	}

	// Sliding-window log, executed atomically in Redis. Returns
	// {allowed (0/1), count, resetMS}:
	//   - entries with score <= now-window are pruned;
	//   - if the remaining count has reached the limit, the request is rejected
	//     (and NOT recorded) and resetMS is when the oldest entry leaves the window;
	//   - otherwise the request is recorded and the key TTL is set to one window.
	script := redisv9.NewScript(`
local key = KEYS[1]
local now = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local limit = tonumber(ARGV[3])
local member = ARGV[4]

redis.call('ZREMRANGEBYSCORE', key, 0, now - window)
local count = redis.call('ZCARD', key)
if count >= limit then
	local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')
	local reset = now + window
	if oldest[2] ~= nil then
		reset = tonumber(oldest[2]) + window
	end
	return {0, count, reset}
end

redis.call('ZADD', key, now, member)
redis.call('PEXPIRE', key, window)
count = count + 1
return {1, count, now + window}
`)

	return func(c fiber.Ctx) error {
		// Preflight requests are never counted against a quota.
		if c.Method() == fiber.MethodOptions {
			return c.Next()
		}

		tier := selectTier(c.Path(), prefix, tiers)
		if client == nil {
			return rateLimitFailOpen(c, tier)
		}

		now := time.Now()
		nowMS := now.UnixMilli()
		requester := identifyRequester(c)
		key := fmt.Sprintf("gateway:rate:%s:%s", tier.Name, requester)
		member := fmt.Sprintf("%d:%s:%d", now.UnixNano(), c.IP(), rateLimitMemberSeq.Add(1))

		result, err := script.Run(
			c.Context(),
			client,
			[]string{key},
			nowMS,
			tier.Window.Milliseconds(),
			tier.Limit,
			member,
		).Result()
		if err != nil {
			if logger != nil {
				logger.WithError(err).Warn("rate limiter degraded (redis unavailable)")
			}
			return rateLimitFailOpen(c, tier)
		}

		allowed, count, resetMS, ok := parseLimiterResult(result)
		if !ok {
			if logger != nil {
				logger.Warn("rate limiter degraded (unexpected redis response)")
			}
			return rateLimitFailOpen(c, tier)
		}

		remaining := tier.Limit - int(count)
		if remaining < 0 {
			remaining = 0
		}
		// Round up so the advertised reset is never earlier than the moment the
		// oldest entry actually leaves the window.
		resetUnix := ceilMillisToSeconds(resetMS)
		c.Set("X-RateLimit-Limit", strconv.Itoa(tier.Limit))
		c.Set("X-RateLimit-Remaining", strconv.Itoa(remaining))
		c.Set("X-RateLimit-Reset", strconv.FormatInt(resetUnix, 10))

		if allowed {
			return c.Next()
		}

		// Round the remaining wait up to whole seconds. Truncating the two
		// timestamps to seconds before subtracting can yield a delay shorter
		// than the real wait, and a client honoring it would be rejected again.
		retry := int(ceilMillisToSeconds(resetMS - nowMS))
		if retry < 1 {
			retry = 1
		}
		if logger != nil {
			logger.SecurityEvent(
				"rate_limit_exceeded",
				requester,
				c.IP(),
				"medium",
				map[string]interface{}{"tier": tier.Name, "path": c.Path()},
			)
		}
		return httputil.TooManyRequests(c, "Rate limit exceeded", retry)
	}
}

// rateLimitFailOpen lets the request through without quota enforcement and
// marks the response as degraded, advertising only the tier's limit.
func rateLimitFailOpen(c fiber.Ctx, tier rateTier) error {
	c.Set("X-RateLimit-Policy", "degraded")
	c.Set("X-RateLimit-Limit", strconv.Itoa(tier.Limit))
	return c.Next()
}

// ceilMillisToSeconds converts milliseconds to seconds, rounding toward
// positive infinity (e.g. 1500 -> 2, 1000 -> 1, -1500 -> -1).
func ceilMillisToSeconds(ms int64) int64 {
	s := ms / 1000
	if ms%1000 > 0 {
		s++
	}
	return s
}

// selectTier picks the tier for path. Registration, email verification and
// admin-auth endpoints use the "auth" tier; other paths under prefix+"/admin"
// use "admin"; other paths under prefix use "api"; everything else uses
// "general".
//
// NOTE(review): matching is plain string-prefix, not path-segment aware, so
// e.g. prefix+"/administrator" also lands in "admin". Confirm that is intended.
func selectTier(path string, prefix string, tiers map[string]rateTier) rateTier {
	authPaths := []string{
		prefix + "/admin/auth",
		prefix + "/users/register",
		prefix + "/users/verify-email",
	}
	for _, authPath := range authPaths {
		if strings.HasPrefix(path, authPath) {
			return tiers["auth"]
		}
	}
	if strings.HasPrefix(path, prefix+"/admin") {
		return tiers["admin"]
	}
	if strings.HasPrefix(path, prefix) {
		return tiers["api"]
	}
	return tiers["general"]
}

// identifyRequester returns the rate-limit identity for the request:
// "user:<id>" when a non-empty string user ID is stored in Locals under
// zitadel.ContextKeyUserID, otherwise "ip:<client IP>".
func identifyRequester(c fiber.Ctx) string {
	if raw := c.Locals(string(zitadel.ContextKeyUserID)); raw != nil {
		if userID, ok := raw.(string); ok && userID != "" {
			return "user:" + userID
		}
	}
	return "ip:" + c.IP()
}

// parseLimiterResult decodes the script's {allowed, count, resetMS} reply.
// ok is false when v is not a slice of at least three integer-like values.
func parseLimiterResult(v interface{}) (allowed bool, count int64, resetMS int64, ok bool) {
	arr, isSlice := v.([]interface{})
	if !isSlice || len(arr) < 3 {
		return false, 0, 0, false
	}
	allowedVal, okA := toInt64(arr[0])
	count, okC := toInt64(arr[1])
	resetMS, okR := toInt64(arr[2])
	if !okA || !okC || !okR {
		return false, 0, 0, false
	}
	return allowedVal == 1, count, resetMS, true
}

// toInt64 converts an int64, int or base-10 string reply value to int64.
// Any other type, or an unparsable string, reports false.
func toInt64(v interface{}) (int64, bool) {
	switch t := v.(type) {
	case int64:
		return t, true
	case int:
		return int64(t), true
	case string:
		parsed, err := strconv.ParseInt(t, 10, 64)
		if err != nil {
			return 0, false
		}
		return parsed, true
	default:
		return 0, false
	}
}

// positiveDuration returns v when it is strictly positive, otherwise fallback.
func positiveDuration(v time.Duration, fallback time.Duration) time.Duration {
	if v <= 0 {
		return fallback
	}
	return v
}

// maxInt returns v when it is strictly positive, otherwise fallback.
// Despite its name it does not compute a maximum.
func maxInt(v int, fallback int) int {
	if v <= 0 {
		return fallback
	}
	return v
}