You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
201 lines
4.3 KiB
Go
201 lines
4.3 KiB
Go
package middleware
|
|
|
|
import (
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gofiber/fiber/v3"
|
|
"github.com/gofiber/fiber/v3/middleware/limiter"
|
|
"github.com/gofiber/storage/redis/v3"
|
|
|
|
"github.com/knowfoolery/backend/services/gateway-service/config"
|
|
)
|
|
|
|
// RateLimiter enforces per-key request limits for the gateway.
//
// Handler picks the backend: when storage is non-nil it delegates to fiber's
// limiter middleware backed by that storage; otherwise it uses the in-process
// counters map via memoryLimiter.
type RateLimiter struct {
	// config holds the limit (RequestsPerMinute), window size, key strategy
	// and storage settings this limiter was built from.
	config config.RateLimitConfig

	// storage is the shared backend for fiber's limiter; NewRateLimiter only
	// sets it for Redis, so nil means "use the in-memory limiter".
	storage fiber.Storage

	// counters maps a rate-limit key (see generateKey) to its window state.
	// Only the in-memory limiter, cleanup and GetStats touch it.
	counters map[string]*Counter

	// mutex guards the counters map itself (insertions, deletions, lookups);
	// each Counter's fields are guarded by its own mutex.
	mutex sync.RWMutex
}
|
|
|
|
// Counter is the per-key window state of the in-memory limiter.
//
// Count and ResetTime must only be read or written while holding mutex.
// Counter contains a sync.Mutex, so it must not be copied after first use;
// share it by pointer.
type Counter struct {
	// Count is the number of requests seen in the current window.
	Count int

	// ResetTime is when the current window ends; the first request after it
	// starts a new window.
	ResetTime time.Time

	// mutex guards Count and ResetTime.
	mutex sync.Mutex
}
|
|
|
|
func NewRateLimiter(cfg config.RateLimitConfig) *RateLimiter {
|
|
var storage fiber.Storage
|
|
|
|
if cfg.Storage.Type == "redis" && cfg.Storage.RedisURL != "" {
|
|
storage = redis.New(redis.Config{
|
|
URL: cfg.Storage.RedisURL,
|
|
})
|
|
}
|
|
|
|
return &RateLimiter{
|
|
config: cfg,
|
|
storage: storage,
|
|
counters: make(map[string]*Counter),
|
|
}
|
|
}
|
|
|
|
func (rl *RateLimiter) Handler() fiber.Handler {
|
|
if !rl.config.Enabled {
|
|
return func(c fiber.Ctx) error {
|
|
return c.Next()
|
|
}
|
|
}
|
|
|
|
if rl.storage != nil {
|
|
return limiter.New(limiter.Config{
|
|
Max: rl.config.RequestsPerMinute,
|
|
Expiration: rl.config.WindowSize,
|
|
LimiterMiddleware: limiter.SlidingWindow{},
|
|
KeyGenerator: func(c fiber.Ctx) string {
|
|
return rl.generateKey(c)
|
|
},
|
|
LimitReached: rl.limitReached,
|
|
Storage: rl.storage,
|
|
})
|
|
}
|
|
|
|
return rl.memoryLimiter()
|
|
}
|
|
|
|
func (rl *RateLimiter) memoryLimiter() fiber.Handler {
|
|
return func(c fiber.Ctx) error {
|
|
key := rl.generateKey(c)
|
|
|
|
rl.mutex.RLock()
|
|
counter, exists := rl.counters[key]
|
|
rl.mutex.RUnlock()
|
|
|
|
now := time.Now()
|
|
|
|
if !exists {
|
|
rl.mutex.Lock()
|
|
counter = &Counter{
|
|
Count: 1,
|
|
ResetTime: now.Add(rl.config.WindowSize),
|
|
}
|
|
rl.counters[key] = counter
|
|
rl.mutex.Unlock()
|
|
} else {
|
|
counter.mutex.Lock()
|
|
|
|
if now.After(counter.ResetTime) {
|
|
counter.Count = 1
|
|
counter.ResetTime = now.Add(rl.config.WindowSize)
|
|
} else {
|
|
counter.Count++
|
|
}
|
|
|
|
if counter.Count > rl.config.RequestsPerMinute {
|
|
counter.mutex.Unlock()
|
|
return rl.limitReached(c)
|
|
}
|
|
|
|
counter.mutex.Unlock()
|
|
}
|
|
|
|
rl.setHeaders(c, counter)
|
|
|
|
go rl.cleanup()
|
|
|
|
return c.Next()
|
|
}
|
|
}
|
|
|
|
func (rl *RateLimiter) generateKey(c fiber.Ctx) string {
|
|
switch rl.config.KeyGenerator {
|
|
case "ip":
|
|
return c.IP()
|
|
case "user":
|
|
if userCtx := GetUserFromContext(c); userCtx != nil {
|
|
return userCtx.UserID
|
|
}
|
|
return c.IP()
|
|
case "ip_user":
|
|
if userCtx := GetUserFromContext(c); userCtx != nil {
|
|
return fmt.Sprintf("%s:%s", c.IP(), userCtx.UserID)
|
|
}
|
|
return c.IP()
|
|
default:
|
|
return c.IP()
|
|
}
|
|
}
|
|
|
|
func (rl *RateLimiter) limitReached(c fiber.Ctx) error {
|
|
return c.Status(fiber.StatusTooManyRequests).JSON(fiber.Map{
|
|
"error": "Rate limit exceeded",
|
|
"retry_after": int(rl.config.WindowSize.Seconds()),
|
|
"limit": rl.config.RequestsPerMinute,
|
|
"window": rl.config.WindowSize.String(),
|
|
})
|
|
}
|
|
|
|
func (rl *RateLimiter) setHeaders(c fiber.Ctx, counter *Counter) {
|
|
remaining := rl.config.RequestsPerMinute - counter.Count
|
|
if remaining < 0 {
|
|
remaining = 0
|
|
}
|
|
|
|
resetTime := counter.ResetTime.Unix()
|
|
|
|
c.Set("X-RateLimit-Limit", fmt.Sprintf("%d", rl.config.RequestsPerMinute))
|
|
c.Set("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining))
|
|
c.Set("X-RateLimit-Reset", fmt.Sprintf("%d", resetTime))
|
|
c.Set("X-RateLimit-Window", rl.config.WindowSize.String())
|
|
}
|
|
|
|
func (rl *RateLimiter) cleanup() {
|
|
rl.mutex.Lock()
|
|
defer rl.mutex.Unlock()
|
|
|
|
now := time.Now()
|
|
|
|
for key, counter := range rl.counters {
|
|
counter.mutex.Lock()
|
|
if now.After(counter.ResetTime.Add(time.Minute)) {
|
|
delete(rl.counters, key)
|
|
}
|
|
counter.mutex.Unlock()
|
|
}
|
|
}
|
|
|
|
func (rl *RateLimiter) GetStats() map[string]interface{} {
|
|
rl.mutex.RLock()
|
|
defer rl.mutex.RUnlock()
|
|
|
|
stats := make(map[string]interface{})
|
|
activeKeys := 0
|
|
totalRequests := 0
|
|
|
|
for key, counter := range rl.counters {
|
|
counter.mutex.Lock()
|
|
if time.Now().Before(counter.ResetTime) {
|
|
activeKeys++
|
|
totalRequests += counter.Count
|
|
}
|
|
counter.mutex.Unlock()
|
|
|
|
if activeKeys < 10 {
|
|
stats[fmt.Sprintf("key_%d", activeKeys)] = map[string]interface{}{
|
|
"key": key,
|
|
"count": counter.Count,
|
|
"reset_at": counter.ResetTime,
|
|
}
|
|
}
|
|
}
|
|
|
|
stats["summary"] = map[string]interface{}{
|
|
"active_keys": activeKeys,
|
|
"total_requests": totalRequests,
|
|
"limit_per_key": rl.config.RequestsPerMinute,
|
|
"window_size": rl.config.WindowSize.String(),
|
|
}
|
|
|
|
return stats
|
|
} |