package security

import (
	"net"
	"strings"
	"sync"

	"github.com/gofiber/fiber/v3"

	"github.com/knowfoolery/backend/services/gateway-service/config"
)

// Manager enforces the gateway's request security policy: IP allow/deny
// lists, a request body size limit, Content-Type validation and response
// security headers.
//
// The IP lists can be changed at runtime (AddToWhitelist, RemoveFromBlacklist,
// ...) while Handler serves requests concurrently, so every access to them is
// guarded by mu.
type Manager struct {
	config config.SecurityConfig

	mu          sync.RWMutex    // guards ipWhitelist and ipBlacklist
	ipWhitelist map[string]bool // normalized IPs; when non-empty, only these are admitted
	ipBlacklist map[string]bool // normalized IPs that are always rejected
}

// NewManager builds a Manager from cfg. IP list entries are normalized with
// normalizeIP so that equivalent spellings of the same address match.
func NewManager(cfg config.SecurityConfig) *Manager {
	whitelist := make(map[string]bool, len(cfg.IPWhitelist))
	for _, ip := range cfg.IPWhitelist {
		whitelist[normalizeIP(ip)] = true
	}

	blacklist := make(map[string]bool, len(cfg.IPBlacklist))
	for _, ip := range cfg.IPBlacklist {
		blacklist[normalizeIP(ip)] = true
	}

	return &Manager{
		config:      cfg,
		ipWhitelist: whitelist,
		ipBlacklist: blacklist,
	}
}

// Handler returns Fiber middleware that applies the security policy to every
// request.
//
// Security headers are set first so that they also accompany rejection
// responses. Then the IP restrictions, body size limit and Content-Type checks
// run in that order. The first failing check writes a JSON error response and
// ends the chain; c.Next() is called only when every check passes.
func (sm *Manager) Handler() fiber.Handler {
	checks := []func(fiber.Ctx) (bool, error){
		sm.checkIPRestrictions,
		sm.checkRequestSize,
		sm.validateContentType,
	}

	return func(c fiber.Ctx) error {
		sm.setSecurityHeaders(c)
		for _, check := range checks {
			if passed, err := check(c); !passed {
				// err is the result of writing the rejection response (normally
				// nil). Either way the request must not reach downstream handlers.
				return err
			}
		}
		return c.Next()
	}
}

// checkIPRestrictions rejects the request with 403 Forbidden when the client
// IP is blacklisted, or when a whitelist is configured and the IP is not on
// it. It reports whether the request may continue; when it returns false, the
// error is the result of writing the response.
func (sm *Manager) checkIPRestrictions(c fiber.Ctx) (bool, error) {
	if reason := sm.denyReason(sm.getClientIP(c)); reason != "" {
		return false, c.Status(fiber.StatusForbidden).JSON(fiber.Map{
			"error": reason,
		})
	}
	return true, nil
}

// checkRequestSize rejects the request with 413 Request Entity Too Large when
// a positive RequestSizeLimit is configured and the body is longer than it.
// It reports whether the request may continue.
//
// NOTE(review): the size is measured on c.Body(), i.e. after the body has
// been read. Memory use is bounded by the server-level fiber.Config BodyLimit,
// not by this check — confirm it is configured at least as strictly.
func (sm *Manager) checkRequestSize(c fiber.Ctx) (bool, error) {
	limit := sm.config.RequestSizeLimit
	if limit <= 0 {
		return true, nil
	}

	size := len(c.Body())
	if int64(size) <= limit {
		return true, nil
	}

	return false, c.Status(fiber.StatusRequestEntityTooLarge).JSON(fiber.Map{
		"error":       "Request entity too large",
		"max_size":    limit,
		"actual_size": size,
	})
}

// validateContentType rejects the request with 415 Unsupported Media Type
// when validation is enabled and the Content-Type media type is not one of
// AllowedContentTypes (compared case-insensitively).
//
// GET, HEAD and OPTIONS requests are exempt. So are requests without a
// Content-Type header. It reports whether the request may continue.
func (sm *Manager) validateContentType(c fiber.Ctx) (bool, error) {
	if !sm.config.ContentTypeValidation {
		return true, nil
	}

	switch c.Method() {
	case fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions:
		return true, nil
	}

	header := c.Get("Content-Type")
	if header == "" {
		return true, nil
	}

	// Compare only the media type. Parameters such as "; charset=utf-8" are
	// dropped, and surrounding whitespace is trimmed so that
	// "application/json ; charset=utf-8" is accepted.
	mediaType, _, _ := strings.Cut(header, ";")
	mediaType = strings.ToLower(strings.TrimSpace(mediaType))

	for _, allowed := range sm.config.AllowedContentTypes {
		if mediaType == strings.ToLower(strings.TrimSpace(allowed)) {
			return true, nil
		}
	}

	return false, c.Status(fiber.StatusUnsupportedMediaType).JSON(fiber.Map{
		"error":         "Unsupported content type",
		"provided_type": mediaType,
		"allowed_types": sm.config.AllowedContentTypes,
	})
}

// setSecurityHeaders writes the configured response headers.
//
// Entries from config.SecurityHeaders are applied first. The flag-driven
// headers below are set afterwards and therefore override a configured header
// with the same name.
func (sm *Manager) setSecurityHeaders(c fiber.Ctx) {
	for name, value := range sm.config.SecurityHeaders {
		c.Set(name, value)
	}

	if sm.config.EnableXSSProtection {
		// NOTE(review): X-XSS-Protection is deprecated and ignored by current
		// browsers. OWASP's secure-headers guidance recommends "0" together
		// with a Content-Security-Policy instead.
		c.Set("X-XSS-Protection", "1; mode=block")
	}

	if sm.config.EnableCSRFProtection {
		// NOTE(review): despite the flag name, this only disables MIME
		// sniffing. Nothing in this middleware performs CSRF checks
		// (tokens, Origin/Referer validation, SameSite cookies).
		c.Set("X-Content-Type-Options", "nosniff")
	}
}

// getClientIP returns the normalized client address used for IP filtering.
//
// It relies on c.IP(), which consults a forwarding header only when the Fiber
// app is configured to trust it (fiber.Config ProxyHeader plus the
// trusted-proxy settings). Headers such as X-Forwarded-For and X-Real-IP are
// fully client-controlled unless set by a trusted proxy, so reading them
// directly would let any caller choose an address and bypass both IP lists.
//
// Deployments behind a load balancer must configure those Fiber settings.
// Otherwise every request appears to come from the proxy.
func (sm *Manager) getClientIP(c fiber.Ctx) string {
	return normalizeIP(c.IP())
}

// denyReason checks an already-normalized ip against the IP lists. It returns
// "" when the address is allowed, otherwise the rejection message. The
// blacklist takes precedence over the whitelist.
func (sm *Manager) denyReason(ip string) string {
	sm.mu.RLock()
	defer sm.mu.RUnlock()

	if sm.ipBlacklist[ip] {
		return "IP address is blacklisted"
	}
	if len(sm.ipWhitelist) > 0 && !sm.ipWhitelist[ip] {
		return "IP address is not whitelisted"
	}
	return ""
}

// IsAllowedIP reports whether ip passes the IP lists. An address passes when
// it is not blacklisted and, if a whitelist is configured, it is on the
// whitelist. It is safe to call concurrently with the list-mutating methods.
func (sm *Manager) IsAllowedIP(ip string) bool {
	return sm.denyReason(normalizeIP(ip)) == ""
}

// AddToWhitelist adds ip to the whitelist. Once the whitelist is non-empty,
// only whitelisted addresses are admitted.
func (sm *Manager) AddToWhitelist(ip string) {
	sm.mu.Lock()
	defer sm.mu.Unlock()
	sm.ipWhitelist[normalizeIP(ip)] = true
}

// RemoveFromWhitelist removes ip from the whitelist; it is a no-op if absent.
func (sm *Manager) RemoveFromWhitelist(ip string) {
	sm.mu.Lock()
	defer sm.mu.Unlock()
	delete(sm.ipWhitelist, normalizeIP(ip))
}

// AddToBlacklist adds ip to the blacklist. Blacklisted addresses are rejected
// even if they are also whitelisted.
func (sm *Manager) AddToBlacklist(ip string) {
	sm.mu.Lock()
	defer sm.mu.Unlock()
	sm.ipBlacklist[normalizeIP(ip)] = true
}

// RemoveFromBlacklist removes ip from the blacklist; it is a no-op if absent.
func (sm *Manager) RemoveFromBlacklist(ip string) {
	sm.mu.Lock()
	defer sm.mu.Unlock()
	delete(sm.ipBlacklist, normalizeIP(ip))
}

// GetSecurityStats returns a snapshot of the current policy for diagnostics:
// list sizes, the size limit, the feature flags, the allowed content types and
// the number of configured security headers.
func (sm *Manager) GetSecurityStats() map[string]any {
	sm.mu.RLock()
	whitelistCount, blacklistCount := len(sm.ipWhitelist), len(sm.ipBlacklist)
	sm.mu.RUnlock()

	return map[string]any{
		"ip_whitelist_count":      whitelistCount,
		"ip_blacklist_count":      blacklistCount,
		"request_size_limit":      sm.config.RequestSizeLimit,
		"csrf_protection_enabled": sm.config.EnableCSRFProtection,
		"xss_protection_enabled":  sm.config.EnableXSSProtection,
		"content_type_validation": sm.config.ContentTypeValidation,
		"allowed_content_types":   sm.config.AllowedContentTypes,
		"security_headers":        len(sm.config.SecurityHeaders),
	}
}

// normalizeIP canonicalizes an IP address string so that equivalent spellings
// compare equal. For example, "::ffff:10.0.0.1" and " 10.0.0.1 " both become
// "10.0.0.1", and IPv6 zero runs are compressed. Strings that do not parse as
// an IP (CIDR ranges, hostnames, zoned IPv6) are only trimmed.
func normalizeIP(ip string) string {
	ip = strings.TrimSpace(ip)
	if parsed := net.ParseIP(ip); parsed != nil {
		return parsed.String()
	}
	return ip
}