You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
219 lines
5.2 KiB
Go
219 lines
5.2 KiB
Go
package middleware
|
|
|
|
import (
	"fmt"
	"strconv"
	"strings"
	"sync/atomic"
	"time"

	"github.com/gofiber/fiber/v3"
	redisv9 "github.com/redis/go-redis/v9"

	gconfig "knowfoolery/backend/services/gateway-service/internal/infra/config"
	"knowfoolery/backend/shared/infra/auth/zitadel"
	"knowfoolery/backend/shared/infra/observability/logging"
	"knowfoolery/backend/shared/infra/utils/httputil"
)
|
|
|
|
// rateTier is one rate-limit bucket: a requester may make at most Limit
// requests within any sliding Window.
type rateTier struct {
	// Name identifies the tier; the middleware embeds it in the Redis key.
	Name string
	// Limit is the maximum number of requests allowed per Window.
	Limit int
	// Window is the length of the sliding window.
	Window time.Duration
	// RetrySec is not read by the limiter code in this file.
	// NOTE(review): presumably a fixed Retry-After hint — confirm or remove.
	RetrySec int
}
|
|
|
|
// RateLimitMiddleware applies Redis-backed sliding window limits with fail-open fallback.
|
|
func RateLimitMiddleware(
|
|
client *redisv9.Client,
|
|
cfg gconfig.RateLimitConfig,
|
|
prefix string,
|
|
logger *logging.Logger,
|
|
) fiber.Handler {
|
|
tiers := map[string]rateTier{
|
|
"general": {Name: "general", Limit: maxInt(cfg.GeneralRequests, 1),
|
|
Window: positiveDuration(cfg.Window, time.Minute)},
|
|
"auth": {Name: "auth", Limit: maxInt(cfg.AuthRequests, 1), Window: positiveDuration(cfg.Window, time.Minute)},
|
|
"api": {Name: "api", Limit: maxInt(cfg.APIRequests, 1), Window: positiveDuration(cfg.Window, time.Minute)},
|
|
"admin": {Name: "admin", Limit: maxInt(cfg.AdminRequests, 1), Window: positiveDuration(cfg.Window, time.Minute)},
|
|
}
|
|
|
|
script := redisv9.NewScript(`
|
|
local key = KEYS[1]
|
|
local now = tonumber(ARGV[1])
|
|
local window = tonumber(ARGV[2])
|
|
local limit = tonumber(ARGV[3])
|
|
local member = ARGV[4]
|
|
|
|
redis.call('ZREMRANGEBYSCORE', key, 0, now - window)
|
|
local count = redis.call('ZCARD', key)
|
|
|
|
if count >= limit then
|
|
local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')
|
|
local reset = now + window
|
|
if oldest[2] ~= nil then
|
|
reset = tonumber(oldest[2]) + window
|
|
end
|
|
return {0, count, reset}
|
|
end
|
|
|
|
redis.call('ZADD', key, now, member)
|
|
redis.call('PEXPIRE', key, window)
|
|
count = count + 1
|
|
|
|
return {1, count, now + window}
|
|
`)
|
|
|
|
return func(c fiber.Ctx) error {
|
|
if c.Method() == fiber.MethodOptions {
|
|
return c.Next()
|
|
}
|
|
|
|
tier := selectTier(c.Path(), prefix, tiers)
|
|
if client == nil {
|
|
c.Set("X-RateLimit-Policy", "degraded")
|
|
c.Set("X-RateLimit-Limit", strconv.Itoa(tier.Limit))
|
|
return c.Next()
|
|
}
|
|
|
|
now := time.Now()
|
|
nowMS := now.UnixMilli()
|
|
windowMS := tier.Window.Milliseconds()
|
|
|
|
member := fmt.Sprintf("%d:%s", now.UnixNano(), c.IP())
|
|
key := fmt.Sprintf("gateway:rate:%s:%s", tier.Name, identifyRequester(c))
|
|
|
|
result, err := script.Run(
|
|
c.Context(),
|
|
client,
|
|
[]string{key},
|
|
nowMS,
|
|
windowMS,
|
|
tier.Limit,
|
|
member,
|
|
).Result()
|
|
if err != nil {
|
|
if logger != nil {
|
|
logger.WithError(err).Warn("rate limiter degraded (redis unavailable)")
|
|
}
|
|
c.Set("X-RateLimit-Policy", "degraded")
|
|
c.Set("X-RateLimit-Limit", strconv.Itoa(tier.Limit))
|
|
return c.Next()
|
|
}
|
|
|
|
allowed, count, resetMS, ok := parseLimiterResult(result)
|
|
if !ok {
|
|
if logger != nil {
|
|
logger.Warn("rate limiter degraded (unexpected redis response)")
|
|
}
|
|
c.Set("X-RateLimit-Policy", "degraded")
|
|
c.Set("X-RateLimit-Limit", strconv.Itoa(tier.Limit))
|
|
return c.Next()
|
|
}
|
|
|
|
remaining := tier.Limit - int(count)
|
|
if remaining < 0 {
|
|
remaining = 0
|
|
}
|
|
resetUnix := time.UnixMilli(resetMS).Unix()
|
|
c.Set("X-RateLimit-Limit", strconv.Itoa(tier.Limit))
|
|
c.Set("X-RateLimit-Remaining", strconv.Itoa(remaining))
|
|
c.Set("X-RateLimit-Reset", strconv.FormatInt(resetUnix, 10))
|
|
|
|
if !allowed {
|
|
retry := int(resetUnix - now.Unix())
|
|
if retry <= 0 {
|
|
retry = 1
|
|
}
|
|
if logger != nil {
|
|
logger.SecurityEvent(
|
|
"rate_limit_exceeded",
|
|
identifyRequester(c),
|
|
c.IP(),
|
|
"medium",
|
|
map[string]interface{}{"tier": tier.Name, "path": c.Path()},
|
|
)
|
|
}
|
|
return httputil.TooManyRequests(c, "Rate limit exceeded", retry)
|
|
}
|
|
|
|
return c.Next()
|
|
}
|
|
}
|
|
|
|
func selectTier(path string, prefix string, tiers map[string]rateTier) rateTier {
|
|
authPaths := []string{
|
|
prefix + "/admin/auth",
|
|
prefix + "/users/register",
|
|
prefix + "/users/verify-email",
|
|
}
|
|
for _, authPath := range authPaths {
|
|
if strings.HasPrefix(path, authPath) {
|
|
return tiers["auth"]
|
|
}
|
|
}
|
|
|
|
if strings.HasPrefix(path, prefix+"/admin") {
|
|
return tiers["admin"]
|
|
}
|
|
if strings.HasPrefix(path, prefix) {
|
|
return tiers["api"]
|
|
}
|
|
return tiers["general"]
|
|
}
|
|
|
|
func identifyRequester(c fiber.Ctx) string {
|
|
if raw := c.Locals(string(zitadel.ContextKeyUserID)); raw != nil {
|
|
if userID, ok := raw.(string); ok && userID != "" {
|
|
return "user:" + userID
|
|
}
|
|
}
|
|
return "ip:" + c.IP()
|
|
}
|
|
|
|
func parseLimiterResult(v interface{}) (allowed bool, count int64, resetMS int64, ok bool) {
|
|
arr, isSlice := v.([]interface{})
|
|
if !isSlice || len(arr) < 3 {
|
|
return false, 0, 0, false
|
|
}
|
|
|
|
allowedVal, okA := toInt64(arr[0])
|
|
count, okC := toInt64(arr[1])
|
|
resetMS, okR := toInt64(arr[2])
|
|
if !okA || !okC || !okR {
|
|
return false, 0, 0, false
|
|
}
|
|
|
|
return allowedVal == 1, count, resetMS, true
|
|
}
|
|
|
|
// toInt64 converts a Redis reply value to int64.
//
// It accepts int64, int, and base-10 integer strings. Any other type, or a
// string that fails to parse (including out-of-range values), yields (0, false).
func toInt64(v interface{}) (int64, bool) {
	if n, isInt64 := v.(int64); isInt64 {
		return n, true
	}
	if n, isInt := v.(int); isInt {
		return int64(n), true
	}

	s, isString := v.(string)
	if !isString {
		return 0, false
	}
	n, err := strconv.ParseInt(s, 10, 64)
	if err != nil {
		return 0, false
	}
	return n, true
}
|
|
|
|
// positiveDuration returns v when it is strictly positive and fallback when it
// is zero or negative.
func positiveDuration(v time.Duration, fallback time.Duration) time.Duration {
	if v > 0 {
		return v
	}
	return fallback
}
|
|
|
|
// maxInt returns v when it is strictly positive and fallback otherwise.
//
// Despite the name it is not a maximum of its arguments: maxInt(0, 1) is 1,
// but maxInt(5, 10) is 5.
func maxInt(v int, fallback int) int {
	if v > 0 {
		return v
	}
	return fallback
}
|