You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

193 lines
4.3 KiB
Go

package security
import (
	"errors"
	"net"
	"strings"

	"github.com/gofiber/fiber/v3"
	"github.com/knowfoolery/backend/services/gateway-service/config"
)
// Manager holds the gateway's request-security state: the security
// configuration plus the IP whitelist and blacklist, each kept as a lookup
// set in which a key mapped to true means "listed".
//
// NOTE(review): the IP maps are read while handling requests and can be
// written by the exported Add*/Remove* methods, with no locking anywhere.
// Concurrent use of both would be a data race — confirm the lists are only
// mutated before the server starts serving.
type Manager struct {
	config config.SecurityConfig

	// IP sets keyed by the exact client-IP string (no normalisation).
	ipWhitelist, ipBlacklist map[string]bool
}
// NewManager builds a Manager from cfg, turning the configured IP whitelist
// and blacklist slices into lookup sets. Entries are stored verbatim (no
// trimming or normalisation); lookups against them are exact string matches.
// Both sets are always non-nil, so the Add*/Remove* methods are safe to call
// on the result.
func NewManager(cfg config.SecurityConfig) *Manager {
	toSet := func(addrs []string) map[string]bool {
		set := make(map[string]bool, len(addrs))
		for _, addr := range addrs {
			set[addr] = true
		}
		return set
	}

	return &Manager{
		config:      cfg,
		ipWhitelist: toSet(cfg.IPWhitelist),
		ipBlacklist: toSet(cfg.IPBlacklist),
	}
}
// errRequestRejected is returned by the check* methods after they have already
// written a rejection response to the client. Handler treats it as "stop
// here": it returns nil so Fiber sends the response that was written, and the
// rest of the handler chain is not run.
var errRequestRejected = errors.New("security: request rejected")

// Handler returns the security middleware. For every request it:
//
//   - sets the configured security response headers (first, so that
//     rejection responses carry them as well);
//   - enforces the IP blacklist/whitelist (403);
//   - enforces RequestSizeLimit (413);
//   - validates the Content-Type when ContentTypeValidation is enabled (415).
//
// A request that fails a check is answered with a JSON error body and is not
// forwarded; otherwise control passes to the next handler via c.Next().
func (sm *Manager) Handler() fiber.Handler {
	checks := []func(fiber.Ctx) error{
		sm.checkIPRestrictions,
		sm.checkRequestSize,
		sm.validateContentType,
	}

	return func(c fiber.Ctx) error {
		sm.setSecurityHeaders(c)

		for _, check := range checks {
			if err := check(c); err != nil {
				// c.JSON returns nil on success, so a nil result cannot mean
				// "passed"; rejections are signalled with errRequestRejected.
				if errors.Is(err, errRequestRejected) {
					return nil
				}
				return err
			}
		}
		return c.Next()
	}
}

// rejectRequest writes a JSON error response with the given status and
// returns errRequestRejected so Handler stops processing the request. If
// writing the JSON body fails, that error is returned instead.
func (sm *Manager) rejectRequest(c fiber.Ctx, status int, body fiber.Map) error {
	if err := c.Status(status).JSON(body); err != nil {
		return err
	}
	return errRequestRejected
}

// checkIPRestrictions rejects the request with 403 when the client IP (as
// resolved by getClientIP) is blacklisted, or when a whitelist is configured
// and the IP is not on it. The blacklist is consulted first. Returns
// errRequestRejected after writing the rejection, nil when the IP is allowed.
func (sm *Manager) checkIPRestrictions(c fiber.Ctx) error {
	clientIP := sm.getClientIP(c)

	// A lookup on an empty or nil map is false, so no length guard is needed.
	if sm.ipBlacklist[clientIP] {
		return sm.rejectRequest(c, fiber.StatusForbidden, fiber.Map{
			"error": "IP address is blacklisted",
		})
	}
	// An empty whitelist means "no whitelist": every non-blacklisted IP passes.
	if len(sm.ipWhitelist) > 0 && !sm.ipWhitelist[clientIP] {
		return sm.rejectRequest(c, fiber.StatusForbidden, fiber.Map{
			"error": "IP address is not whitelisted",
		})
	}
	return nil
}

// checkRequestSize rejects the request with 413 when RequestSizeLimit is
// positive and the request body is longer than it (in bytes). A non-positive
// limit disables the check. Returns errRequestRejected after writing the
// rejection, nil otherwise.
//
// NOTE(review): in Fiber, c.Body() may decode the payload according to the
// Content-Encoding header, in which case this measures (and first produces)
// the decompressed body. Confirm whether the raw on-the-wire size was meant.
func (sm *Manager) checkRequestSize(c fiber.Ctx) error {
	limit := sm.config.RequestSizeLimit
	if limit <= 0 {
		return nil
	}

	size := len(c.Body())
	if int64(size) <= limit {
		return nil
	}
	return sm.rejectRequest(c, fiber.StatusRequestEntityTooLarge, fiber.Map{
		"error":       "Request entity too large",
		"max_size":    limit,
		"actual_size": size,
	})
}

// validateContentType rejects the request with 415 when ContentTypeValidation
// is enabled and the media type of a non-empty Content-Type header is not in
// AllowedContentTypes. GET, HEAD and OPTIONS requests and requests without a
// Content-Type header are not checked. The comparison ignores case,
// parameters (e.g. "; charset=utf-8") and surrounding whitespace. Returns
// errRequestRejected after writing the rejection, nil otherwise.
func (sm *Manager) validateContentType(c fiber.Ctx) error {
	if !sm.config.ContentTypeValidation {
		return nil
	}

	switch c.Method() {
	case fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions:
		return nil
	}

	header := c.Get(fiber.HeaderContentType)
	if header == "" {
		return nil
	}

	// Optional whitespace is permitted before ";" in a media type, so trim
	// after cutting the parameters off: "application/json ; charset=utf-8"
	// must match "application/json".
	mediaType, _, _ := strings.Cut(header, ";")
	mediaType = strings.ToLower(strings.TrimSpace(mediaType))

	for _, allowed := range sm.config.AllowedContentTypes {
		if strings.EqualFold(mediaType, strings.TrimSpace(allowed)) {
			return nil
		}
	}

	return sm.rejectRequest(c, fiber.StatusUnsupportedMediaType, fiber.Map{
		"error":         "Unsupported content type",
		"provided_type": mediaType,
		"allowed_types": sm.config.AllowedContentTypes,
	})
}
// setSecurityHeaders writes the security response headers. The custom
// headers from config.SecurityHeaders are applied first. The two flag-driven
// headers are applied afterwards, so they override a custom entry of the same
// name.
//
// NOTE(review): EnableCSRFProtection only adds "X-Content-Type-Options:
// nosniff", which guards against MIME sniffing, not CSRF. No CSRF token or
// origin check happens here — confirm whether one exists elsewhere.
func (sm *Manager) setSecurityHeaders(c fiber.Ctx) {
	cfg := &sm.config

	for name, val := range cfg.SecurityHeaders {
		c.Set(name, val)
	}

	if cfg.EnableXSSProtection {
		c.Set("X-XSS-Protection", "1; mode=block")
	}
	if cfg.EnableCSRFProtection {
		c.Set("X-Content-Type-Options", "nosniff")
	}
}
// getClientIP returns the client IP used for whitelist/blacklist matching.
// It checks three sources in order:
//
//  1. the first X-Forwarded-For entry, if it is non-blank;
//  2. the X-Real-IP header, if it is non-blank;
//  3. the TCP peer address (host part only when it carries a port).
//
// Whitespace around header values is trimmed.
//
// SECURITY NOTE(review): both headers are taken from the request as sent.
// A client that can reach this service directly can set them and bypass the
// IP lists. Unless a trusted proxy in front of the gateway always overwrites
// these headers, consider Fiber's c.IP() together with its trusted-proxy
// configuration instead.
func (sm *Manager) getClientIP(c fiber.Ctx) string {
	if xff := c.Get("X-Forwarded-For"); xff != "" {
		first, _, _ := strings.Cut(xff, ",")
		// A blank first entry (e.g. ", 1.2.3.4") used to yield "", which
		// never matches a whitelist; fall through to the next source instead.
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}

	if realIP := strings.TrimSpace(c.Get("X-Real-IP")); realIP != "" {
		return realIP
	}

	remoteAddr := c.Context().RemoteAddr().String()
	if host, _, err := net.SplitHostPort(remoteAddr); err == nil {
		return host
	}
	// No port present (or unparsable): use the address as-is.
	return remoteAddr
}
// IsAllowedIP reports whether ip would pass the IP restrictions. A
// blacklisted IP is always refused. Otherwise, when no whitelist is
// configured every IP is allowed, and when one is configured only listed
// IPs are allowed. Matching is by exact string.
func (sm *Manager) IsAllowedIP(ip string) bool {
	switch {
	case sm.ipBlacklist[ip]: // lookup on an empty/nil map is simply false
		return false
	case len(sm.ipWhitelist) == 0:
		return true
	default:
		return sm.ipWhitelist[ip]
	}
}
// AddToWhitelist adds ip to the whitelist. Matching is by exact string, so ip
// must be written exactly as the client IP is resolved per request. Once the
// whitelist is non-empty, every IP not on it is rejected.
// Requires a Manager built by NewManager: writing to a nil map panics.
//
// NOTE(review): this and the other Add*/Remove* methods write the shared IP
// maps without any locking, while request handling reads them. Calling them
// while the server is serving would be a data race — confirm usage or add a
// mutex to Manager.
func (sm *Manager) AddToWhitelist(ip string) {
sm.ipWhitelist[ip] = true
}
// RemoveFromWhitelist removes ip from the whitelist (no-op if it is absent).
// Removing the last entry leaves the whitelist empty, which disables
// whitelist enforcement entirely.
func (sm *Manager) RemoveFromWhitelist(ip string) {
delete(sm.ipWhitelist, ip)
}
// AddToBlacklist adds ip to the blacklist; requests whose resolved client IP
// equals ip exactly are rejected. Requires a Manager built by NewManager:
// writing to a nil map panics.
func (sm *Manager) AddToBlacklist(ip string) {
sm.ipBlacklist[ip] = true
}
// RemoveFromBlacklist removes ip from the blacklist (no-op if it is absent).
func (sm *Manager) RemoveFromBlacklist(ip string) {
delete(sm.ipBlacklist, ip)
}
// GetSecurityStats returns a freshly built map describing the manager for
// diagnostics:
//
//   - the sizes of the IP whitelist and blacklist;
//   - the request size limit and the feature flags;
//   - the allowed content types;
//   - the number of custom security headers (under "security_headers").
//
// NOTE(review): "allowed_content_types" is the configured slice itself, not a
// copy; callers that modify it would modify the live configuration.
func (sm *Manager) GetSecurityStats() map[string]interface{} {
	cfg := &sm.config
	stats := make(map[string]interface{}, 8)

	// IP list sizes.
	stats["ip_whitelist_count"] = len(sm.ipWhitelist)
	stats["ip_blacklist_count"] = len(sm.ipBlacklist)

	// Limits and feature flags straight from the configuration.
	stats["request_size_limit"] = cfg.RequestSizeLimit
	stats["csrf_protection_enabled"] = cfg.EnableCSRFProtection
	stats["xss_protection_enabled"] = cfg.EnableXSSProtection
	stats["content_type_validation"] = cfg.ContentTypeValidation
	stats["allowed_content_types"] = cfg.AllowedContentTypes
	stats["security_headers"] = len(cfg.SecurityHeaders)

	return stats
}