You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

219 lines
5.2 KiB
Go

package middleware
import (
"fmt"
"strconv"
"strings"
"time"
"github.com/gofiber/fiber/v3"
redisv9 "github.com/redis/go-redis/v9"
gconfig "knowfoolery/backend/services/gateway-service/internal/infra/config"
"knowfoolery/backend/shared/infra/auth/zitadel"
"knowfoolery/backend/shared/infra/observability/logging"
"knowfoolery/backend/shared/infra/utils/httputil"
)
// rateTier describes one rate-limit bucket: at most Limit requests per
// Window, counted separately for each tier Name.
type rateTier struct {
	// Name identifies the tier and is embedded in the Redis counter key built
	// by RateLimitMiddleware ("gateway:rate:<Name>:<requester>").
	Name string
	// Limit is the maximum number of requests admitted within Window; the
	// limiter script rejects a request once the window already holds Limit
	// entries.
	Limit int
	// Window is the length of the sliding window.
	Window time.Duration
	// RetrySec is not set or read by the rate-limit code in this file.
	// NOTE(review): confirm whether anything else in the package uses it
	// before removing it.
	RetrySec int
}
// RateLimitMiddleware applies Redis-backed sliding-window rate limits with a
// fail-open fallback.
//
// Each request is mapped to a tier by selectTier ("auth", "admin", "api" or
// "general") and counted against the Redis key
// "gateway:rate:<tier>:<requester>", where the requester comes from
// identifyRequester (authenticated user ID if present, client IP otherwise).
//
// Per-tier limits come from cfg. A non-positive request count falls back to 1
// and a non-positive window falls back to one minute. OPTIONS requests are
// passed through without being counted.
//
// When client is nil, or the Redis call errors or returns an unexpected
// payload, the request is allowed through and the response carries
// "X-RateLimit-Policy: degraded" plus X-RateLimit-Limit.
//
// Otherwise the response carries X-RateLimit-Limit, X-RateLimit-Remaining and
// X-RateLimit-Reset (Unix seconds). Requests over the limit are rejected via
// httputil.TooManyRequests with a retry delay in whole seconds, rounded up so
// a client that honours it does not return before a slot has freed.
func RateLimitMiddleware(
	client *redisv9.Client,
	cfg gconfig.RateLimitConfig,
	prefix string,
	logger *logging.Logger,
) fiber.Handler {
	window := positiveDuration(cfg.Window, time.Minute)
	tiers := map[string]rateTier{
		"general": {Name: "general", Limit: maxInt(cfg.GeneralRequests, 1), Window: window},
		"auth":    {Name: "auth", Limit: maxInt(cfg.AuthRequests, 1), Window: window},
		"api":     {Name: "api", Limit: maxInt(cfg.APIRequests, 1), Window: window},
		"admin":   {Name: "admin", Limit: maxInt(cfg.AdminRequests, 1), Window: window},
	}

	// Sliding-window counter over a sorted set scored by request time (ms).
	// KEYS[1]: counter key. ARGV: now (ms), window (ms), limit, unique member.
	// Returns {allowed (0|1), count, reset (ms)}. Denied requests are not
	// recorded, so they do not extend the window. Trim, count and add run as
	// one script, so the check-and-add is not interleaved with other clients.
	// NOTE(review): "now" is the gateway's local clock; multiple gateway
	// instances with clock skew will disagree slightly on window edges.
	script := redisv9.NewScript(`
local key = KEYS[1]
local now = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local limit = tonumber(ARGV[3])
local member = ARGV[4]
redis.call('ZREMRANGEBYSCORE', key, 0, now - window)
local count = redis.call('ZCARD', key)
if count >= limit then
local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')
local reset = now + window
if oldest[2] ~= nil then
reset = tonumber(oldest[2]) + window
end
return {0, count, reset}
end
redis.call('ZADD', key, now, member)
redis.call('PEXPIRE', key, window)
count = count + 1
return {1, count, now + window}
`)

	// degraded lets a request through unenforced (fail open) and advertises the
	// configured limit. It is built once here rather than per request.
	degraded := func(c fiber.Ctx, tier rateTier) error {
		c.Set("X-RateLimit-Policy", "degraded")
		c.Set("X-RateLimit-Limit", strconv.Itoa(tier.Limit))
		return c.Next()
	}

	return func(c fiber.Ctx) error {
		if c.Method() == fiber.MethodOptions {
			return c.Next()
		}

		tier := selectTier(c.Path(), prefix, tiers)
		if client == nil {
			return degraded(c, tier)
		}

		requester := identifyRequester(c)
		now := time.Now()
		nowMS := now.UnixMilli()
		// The nanosecond timestamp keeps sorted-set members distinct for
		// back-to-back requests; ZADD would otherwise just update a score.
		member := fmt.Sprintf("%d:%s", now.UnixNano(), c.IP())
		key := fmt.Sprintf("gateway:rate:%s:%s", tier.Name, requester)

		result, err := script.Run(
			c.Context(),
			client,
			[]string{key},
			nowMS,
			tier.Window.Milliseconds(),
			tier.Limit,
			member,
		).Result()
		if err != nil {
			if logger != nil {
				logger.WithError(err).Warn("rate limiter degraded (redis unavailable)")
			}
			return degraded(c, tier)
		}

		allowed, count, resetMS, ok := parseLimiterResult(result)
		if !ok {
			if logger != nil {
				logger.Warn("rate limiter degraded (unexpected redis response)")
			}
			return degraded(c, tier)
		}

		remaining := tier.Limit - int(count)
		if remaining < 0 {
			remaining = 0
		}
		c.Set("X-RateLimit-Limit", strconv.Itoa(tier.Limit))
		c.Set("X-RateLimit-Remaining", strconv.Itoa(remaining))
		c.Set("X-RateLimit-Reset", strconv.FormatInt(time.UnixMilli(resetMS).Unix(), 10))

		if !allowed {
			// Round the remaining wait up to whole seconds. Subtracting the
			// truncated Unix seconds of reset and now can under-report the
			// wait by up to a second, sending the client back too early.
			retry := int((resetMS - nowMS + 999) / 1000)
			if retry <= 0 {
				retry = 1
			}
			if logger != nil {
				logger.SecurityEvent(
					"rate_limit_exceeded",
					requester,
					c.IP(),
					"medium",
					map[string]interface{}{"tier": tier.Name, "path": c.Path()},
				)
			}
			return httputil.TooManyRequests(c, "Rate limit exceeded", retry)
		}
		return c.Next()
	}
}
// selectTier maps a request path to its rate-limit tier using plain string
// prefix matching relative to prefix, in priority order:
//   - prefix+"/admin/auth", prefix+"/users/register" and
//     prefix+"/users/verify-email" use the "auth" tier;
//   - any other path under prefix+"/admin" uses the "admin" tier;
//   - any other path under prefix uses the "api" tier;
//   - everything else uses the "general" tier.
//
// NOTE(review): matching is not segment-aware, so e.g. prefix+"/administrators"
// also lands in the "admin" tier. Confirm this is intended.
func selectTier(path string, prefix string, tiers map[string]rateTier) rateTier {
	rest, underPrefix := strings.CutPrefix(path, prefix)
	if !underPrefix {
		return tiers["general"]
	}
	switch {
	case strings.HasPrefix(rest, "/admin/auth"),
		strings.HasPrefix(rest, "/users/register"),
		strings.HasPrefix(rest, "/users/verify-email"):
		return tiers["auth"]
	case strings.HasPrefix(rest, "/admin"):
		return tiers["admin"]
	default:
		return tiers["api"]
	}
}
// identifyRequester returns the identity a request is rate limited under:
// "user:<id>" when a non-empty string is stored in the request Locals under
// zitadel.ContextKeyUserID, otherwise "ip:<client IP>".
//
// NOTE(review): a user ID is only visible here if the authentication
// middleware has already populated Locals for this request; confirm the
// middleware ordering.
func identifyRequester(c fiber.Ctx) string {
	// A missing or non-string value yields "" from the comma-ok assertion.
	userID, _ := c.Locals(string(zitadel.ContextKeyUserID)).(string)
	if userID == "" {
		return "ip:" + c.IP()
	}
	return "user:" + userID
}
// parseLimiterResult decodes the {allowed, count, reset} array returned by the
// limiter Lua script.
//
// allowed is true only when the first element equals 1; count and resetMS are
// the second and third elements. When v is not a []interface{} with at least
// three elements, or any of the first three elements cannot be converted by
// toInt64, every result is its zero value and ok is false.
func parseLimiterResult(v interface{}) (allowed bool, count int64, resetMS int64, ok bool) {
	fields, isSlice := v.([]interface{})
	if !isSlice || len(fields) < 3 {
		return false, 0, 0, false
	}

	var decoded [3]int64
	for i := range decoded {
		n, converted := toInt64(fields[i])
		if !converted {
			return false, 0, 0, false
		}
		decoded[i] = n
	}
	return decoded[0] == 1, decoded[1], decoded[2], true
}
// toInt64 converts one element of a Redis script reply to int64.
//
// int64 and int values are returned directly and strings are parsed as
// base-10 integers. Any other type, or a string that does not parse (including
// out-of-range values), yields (0, false).
func toInt64(v interface{}) (int64, bool) {
	if s, isString := v.(string); isString {
		n, err := strconv.ParseInt(s, 10, 64)
		if err != nil {
			// ParseInt returns a clamped value on range errors; report 0 instead.
			return 0, false
		}
		return n, true
	}
	if n, isInt64 := v.(int64); isInt64 {
		return n, true
	}
	if n, isInt := v.(int); isInt {
		return int64(n), true
	}
	return 0, false
}
// positiveDuration returns v when it is strictly positive and fallback
// otherwise, treating zero and negative durations as "not configured".
func positiveDuration(v time.Duration, fallback time.Duration) time.Duration {
	if v > 0 {
		return v
	}
	return fallback
}
// maxInt returns v when it is strictly positive and fallback otherwise.
//
// Despite its name this is a "positive-or-fallback" helper, not max(v,
// fallback): maxInt(1, 10) returns 1.
func maxInt(v int, fallback int) int {
	if v > 0 {
		return v
	}
	return fallback
}