You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

201 lines
4.3 KiB
Go

package middleware
import (
"fmt"
"sync"
"time"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/limiter"
"github.com/gofiber/storage/redis/v3"
"github.com/knowfoolery/backend/services/gateway-service/config"
)
// RateLimiter enforces per-key request limits for the gateway.
//
// When storage is non-nil (set up by NewRateLimiter for Redis), Handler
// delegates to Fiber's limiter middleware backed by that storage; otherwise
// Handler uses the in-process counters map via memoryLimiter.
type RateLimiter struct {
	config config.RateLimitConfig
	// storage is the external limiter backend; nil selects the in-memory limiter.
	storage fiber.Storage
	// counters holds one window per key (as produced by generateKey) for the
	// in-memory limiter.
	counters map[string]*Counter
	// mutex guards the counters map itself; each Counter has its own mutex
	// for its fields.
	mutex sync.RWMutex
}
// Counter tracks requests seen for a single key within the current window of
// the in-memory limiter.
type Counter struct {
	// Count is the number of requests counted in the current window.
	Count int
	// ResetTime is when the current window ends; a request arriving after it
	// starts a new window.
	ResetTime time.Time
	// mutex is taken around reads/updates of Count and ResetTime.
	mutex sync.Mutex
}
// NewRateLimiter creates a RateLimiter from cfg.
//
// A Redis storage backend is created only when cfg.Storage.Type is "redis"
// and cfg.Storage.RedisURL is non-empty. In every other case storage stays
// nil, which makes Handler fall back to the in-memory limiter.
func NewRateLimiter(cfg config.RateLimitConfig) *RateLimiter {
	rl := &RateLimiter{
		config:   cfg,
		counters: make(map[string]*Counter),
	}

	useRedis := cfg.Storage.Type == "redis" && cfg.Storage.RedisURL != ""
	if useRedis {
		rl.storage = redis.New(redis.Config{URL: cfg.Storage.RedisURL})
	}
	return rl
}
// Handler returns the Fiber middleware implementing this limiter.
//
// If limiting is disabled, the middleware just calls c.Next(). If an external
// storage backend is configured, Fiber's limiter middleware is used with a
// sliding-window algorithm and that storage. Otherwise the in-process
// memoryLimiter is returned.
func (rl *RateLimiter) Handler() fiber.Handler {
	switch {
	case !rl.config.Enabled:
		// Limiting switched off: pass every request straight through.
		return func(c fiber.Ctx) error { return c.Next() }
	case rl.storage == nil:
		// No shared backend configured: count requests in process memory.
		return rl.memoryLimiter()
	}

	// A shared backend is available: delegate counting to Fiber's limiter.
	return limiter.New(limiter.Config{
		Max:               rl.config.RequestsPerMinute,
		Expiration:        rl.config.WindowSize,
		LimiterMiddleware: limiter.SlidingWindow{},
		KeyGenerator:      rl.generateKey,
		LimitReached:      rl.limitReached,
		Storage:           rl.storage,
	})
}
// memoryLimiter returns an in-process, fixed-window rate limiter used when no
// external storage is configured.
//
// Each key produced by generateKey owns a Counter whose window lasts
// config.WindowSize. A request that pushes the window's count above
// config.RequestsPerMinute is rejected through limitReached. Allowed requests
// get X-RateLimit-* headers via setHeaders and continue down the chain. State
// lives only in this process's counters map.
func (rl *RateLimiter) memoryLimiter() fiber.Handler {
	// cleanup takes the map's write lock and walks every entry, so it is
	// throttled to once per interval instead of being launched per request.
	// cleanup only drops counters a minute past their window, so sweeping more
	// often than once a minute gains nothing.
	cleanupInterval := rl.config.WindowSize
	if cleanupInterval < time.Minute {
		cleanupInterval = time.Minute
	}
	var (
		cleanupMu   sync.Mutex
		nextCleanup time.Time
	)

	return func(c fiber.Ctx) error {
		key := rl.generateKey(c)
		now := time.Now()

		rl.mutex.RLock()
		counter, exists := rl.counters[key]
		rl.mutex.RUnlock()
		if !exists {
			rl.mutex.Lock()
			// Re-check under the write lock: a concurrent request for the same
			// key may have created the counter after the read lock was
			// released, and replacing it would discard that request's count.
			if counter, exists = rl.counters[key]; !exists {
				counter = &Counter{ResetTime: now.Add(rl.config.WindowSize)}
				rl.counters[key] = counter
			}
			rl.mutex.Unlock()
		}

		// Every request, including the first one for a key, goes through the
		// same increment-and-check so the limit is applied uniformly.
		counter.mutex.Lock()
		if now.After(counter.ResetTime) {
			counter.Count = 0
			counter.ResetTime = now.Add(rl.config.WindowSize)
		}
		counter.Count++
		limited := counter.Count > rl.config.RequestsPerMinute
		counter.mutex.Unlock()

		cleanupMu.Lock()
		sweep := !now.Before(nextCleanup)
		if sweep {
			nextCleanup = now.Add(cleanupInterval)
		}
		cleanupMu.Unlock()
		if sweep {
			go rl.cleanup()
		}

		if limited {
			return rl.limitReached(c)
		}
		// counter.mutex must not be held here: setHeaders may lock it to read
		// the counter.
		rl.setHeaders(c, counter)
		return c.Next()
	}
}
// generateKey derives the rate-limit bucket key for a request according to
// config.KeyGenerator:
//   - "user":    the authenticated user's ID
//   - "ip_user": "<ip>:<userID>"
//   - "ip" or anything else: the client IP
//
// The user-based modes fall back to the client IP when GetUserFromContext
// returns nil.
func (rl *RateLimiter) generateKey(c fiber.Ctx) string {
	switch mode := rl.config.KeyGenerator; mode {
	case "user", "ip_user":
		user := GetUserFromContext(c)
		if user == nil {
			break // no user attached: fall back to the IP below
		}
		if mode == "user" {
			return user.UserID
		}
		return c.IP() + ":" + user.UserID
	}
	return c.IP()
}
// limitReached writes the 429 Too Many Requests JSON response.
//
// The body reports the configured limit and window. retry_after is the full
// window length in whole seconds, not the time remaining until this key's
// window resets.
func (rl *RateLimiter) limitReached(c fiber.Ctx) error {
	window := rl.config.WindowSize
	payload := fiber.Map{
		"error":       "Rate limit exceeded",
		"retry_after": int(window.Seconds()),
		"limit":       rl.config.RequestsPerMinute,
		"window":      window.String(),
	}
	return c.Status(fiber.StatusTooManyRequests).JSON(payload)
}
// setHeaders writes the X-RateLimit-* response headers for counter:
//   - Limit: the configured requests per window
//   - Remaining: requests left in the window, clamped at 0
//   - Reset: Unix seconds at which the window ends
//   - Window: the window length as a Go duration string
//
// The counter is snapshotted under its mutex because concurrent requests for
// the same key update Count and ResetTime under that mutex. The caller must
// NOT already hold counter.mutex.
func (rl *RateLimiter) setHeaders(c fiber.Ctx, counter *Counter) {
	counter.mutex.Lock()
	count, resetAt := counter.Count, counter.ResetTime
	counter.mutex.Unlock()

	remaining := rl.config.RequestsPerMinute - count
	if remaining < 0 {
		remaining = 0
	}

	c.Set("X-RateLimit-Limit", fmt.Sprintf("%d", rl.config.RequestsPerMinute))
	c.Set("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining))
	c.Set("X-RateLimit-Reset", fmt.Sprintf("%d", resetAt.Unix()))
	c.Set("X-RateLimit-Window", rl.config.WindowSize.String())
}
// cleanup removes in-memory counters whose window ended more than one minute
// ago. It holds the counters map's write lock for the whole sweep and briefly
// locks each Counter to read its ResetTime.
func (rl *RateLimiter) cleanup() {
	rl.mutex.Lock()
	defer rl.mutex.Unlock()

	now := time.Now()
	stale := func(ctr *Counter) bool {
		ctr.mutex.Lock()
		defer ctr.mutex.Unlock()
		return now.After(ctr.ResetTime.Add(time.Minute))
	}

	// Deleting from a map while ranging over it is permitted in Go.
	for k, ctr := range rl.counters {
		if stale(ctr) {
			delete(rl.counters, k)
		}
	}
}
// GetStats returns a diagnostic snapshot of the in-memory limiter.
//
// The result contains a "summary" entry with:
//   - active_keys: keys whose window has not yet ended
//   - total_requests: the sum of their counts
//   - limit_per_key and window_size: the configured limit and window
//
// It also contains up to maxSampledKeys entries named "key_1", "key_2", ….
// Each describes one active key (key, count, reset_at). Which keys are sampled
// is arbitrary because map iteration order is random. Only the counters map is
// inspected; in this file that map is populated solely by memoryLimiter.
func (rl *RateLimiter) GetStats() map[string]interface{} {
	const maxSampledKeys = 10

	rl.mutex.RLock()
	defer rl.mutex.RUnlock()

	now := time.Now()
	stats := make(map[string]interface{})
	activeKeys := 0
	totalRequests := 0
	for key, counter := range rl.counters {
		// Snapshot under the counter's lock; requests mutate these fields
		// concurrently.
		counter.mutex.Lock()
		count, resetAt := counter.Count, counter.ResetTime
		counter.mutex.Unlock()

		if !now.Before(resetAt) {
			continue // window over; the counter is only awaiting cleanup
		}
		activeKeys++
		totalRequests += count
		if activeKeys <= maxSampledKeys {
			stats[fmt.Sprintf("key_%d", activeKeys)] = map[string]interface{}{
				"key":      key,
				"count":    count,
				"reset_at": resetAt,
			}
		}
	}

	stats["summary"] = map[string]interface{}{
		"active_keys":    activeKeys,
		"total_requests": totalRequests,
		"limit_per_key":  rl.config.RequestsPerMinute,
		"window_size":    rl.config.WindowSize.String(),
	}
	return stats
}