You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
48 KiB
48 KiB
Know Foolery - Detailed Security Implementation Guidelines
Authentication & Authorization
OAuth 2.0/OIDC Implementation
Zitadel Authentication Flow
// Secure authentication implementation
package auth
import (
	"context"
	"crypto/rand"
	"crypto/rsa"
	"crypto/x509"
	"encoding/base64"
	"encoding/json"
	"encoding/pem"
	"fmt"
	"sort"
	"time"

	"github.com/golang-jwt/jwt/v5"
	"golang.org/x/crypto/argon2"
)
// SecureAuthService validates access tokens presented to the API and decides
// which users must use multi-factor authentication.
type SecureAuthService struct {
// zitadelRepo is not used by the methods in this section.
zitadelRepo ZitadelRepository
// keyStore supplies the public key named by a token's "kid" header.
keyStore *KeyStore
// sessionManager confirms that the token's session ("sid") is still live.
sessionManager *SessionManager
// auditLogger records every rejection and every successful validation.
auditLogger *AuditLogger
// rateLimiter throttles validation attempts per client IP.
rateLimiter *RateLimiter
}
// AuthClaims is the claim set read from access tokens issued by Zitadel.
//
// Decoding goes through UnmarshalJSON so that the "amr" and project-roles
// claims, whose JSON shapes differ from these Go fields, decode correctly.
type AuthClaims struct {
	jwt.RegisteredClaims
	Email string `json:"email"`
	Name  string `json:"name"`
	// Roles lists the project role keys granted to the subject.
	Roles []string `json:"urn:zitadel:iam:org:project:roles"`
	// MFAVerified is true when the "amr" claim contains "mfa". It is derived
	// while decoding and is not a token claim of its own.
	MFAVerified bool   `json:"-"`
	SessionID   string `json:"sid"`
	DeviceID    string `json:"device_id,omitempty"`
	IPAddress   string `json:"ip_address,omitempty"`
	// AMR holds the "amr" (Authentication Methods References) claim, which
	// RFC 8176 defines as a JSON array of strings such as "pwd", "otp" or
	// "mfa". Decoding that array into a bool fails for every token carrying
	// the claim, so the array is kept here and MFAVerified is derived from it.
	AMR []string `json:"amr,omitempty"`
}

// UnmarshalJSON decodes a token payload into c.
//
// The roles claim is accepted either as a JSON array of role names or as
// Zitadel's object form, {"<role>": {"<orgID>": "<orgDomain>"}, ...}, from
// which the role names (sorted) are kept. MFAVerified is set from AMR.
func (c *AuthClaims) UnmarshalJSON(data []byte) error {
	// plain has AuthClaims' fields but not its methods, so decoding into it
	// does not recurse back into this function.
	type plain AuthClaims
	aux := struct {
		*plain
		// Shadows plain.Roles (same JSON name, shallower depth) so the raw
		// claim can be decoded by decodeZitadelRoles.
		Roles json.RawMessage `json:"urn:zitadel:iam:org:project:roles"`
	}{plain: (*plain)(c)}
	if err := json.Unmarshal(data, &aux); err != nil {
		return err
	}

	roles, err := decodeZitadelRoles(aux.Roles)
	if err != nil {
		return fmt.Errorf("decoding roles claim: %w", err)
	}
	c.Roles = roles

	c.MFAVerified = false
	for _, method := range c.AMR {
		if method == "mfa" {
			c.MFAVerified = true
			break
		}
	}
	return nil
}

// decodeZitadelRoles returns the role names from a roles claim encoded as
// either a JSON string array or a JSON object keyed by role. An absent or
// null claim yields nil.
func decodeZitadelRoles(raw json.RawMessage) ([]string, error) {
	if len(raw) == 0 || string(raw) == "null" {
		return nil, nil
	}
	var list []string
	if err := json.Unmarshal(raw, &list); err == nil {
		return list, nil
	}
	var grants map[string]json.RawMessage
	if err := json.Unmarshal(raw, &grants); err != nil {
		return nil, err
	}
	roles := make([]string, 0, len(grants))
	for role := range grants {
		roles = append(roles, role)
	}
	// Map iteration order is random; sort for a stable result.
	sort.Strings(roles)
	return roles, nil
}
// ValidateToken authenticates a bearer token presented from clientIP and
// returns its claims.
//
// Checks run in this order, and the first failure aborts:
//  1. the per-IP rate limit (key "token_validation:<ip>");
//  2. JWT parsing and signature verification — only RS* algorithms are
//     accepted, with the key chosen by the "kid" header;
//  3. the claim checks in validateTokenClaims;
//  4. the server-side session named by "sid" must exist and belong to the
//     token's subject.
//
// Rate-limit, parse, claim and session failures, as well as the final
// success, are written to the audit log. ctx is currently unused.
func (s *SecureAuthService) ValidateToken(ctx context.Context, tokenString string, clientIP string) (*AuthClaims, error) {
	limiterKey := fmt.Sprintf("token_validation:%s", clientIP)
	if allowed := s.rateLimiter.Allow(limiterKey); !allowed {
		s.auditLogger.LogSecurityEvent("rate_limit_exceeded", "", clientIP, "warning", map[string]interface{}{
			"operation": "token_validation",
		})
		return nil, fmt.Errorf("rate limit exceeded")
	}

	resolveKey := func(t *jwt.Token) (interface{}, error) {
		// Only RSASSA-PKCS1-v1_5 methods (RS256/RS384/RS512) pass this type check.
		if _, isRSA := t.Method.(*jwt.SigningMethodRSA); !isRSA {
			return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
		}
		kid, found := t.Header["kid"].(string)
		if !found {
			return nil, fmt.Errorf("missing key ID in token header")
		}
		return s.keyStore.GetPublicKey(kid)
	}

	parsed, parseErr := jwt.ParseWithClaims(tokenString, &AuthClaims{}, resolveKey)
	if parseErr != nil {
		s.auditLogger.LogSecurityEvent("invalid_token", "", clientIP, "warning", map[string]interface{}{
			"error": parseErr.Error(),
		})
		return nil, fmt.Errorf("token validation failed: %w", parseErr)
	}
	if !parsed.Valid {
		return nil, fmt.Errorf("invalid token")
	}

	claims, isAuthClaims := parsed.Claims.(*AuthClaims)
	if !isAuthClaims {
		return nil, fmt.Errorf("invalid token claims")
	}

	if claimErr := s.validateTokenClaims(claims, clientIP); claimErr != nil {
		s.auditLogger.LogSecurityEvent("token_validation_failed", claims.Subject, clientIP, "warning", map[string]interface{}{
			"error": claimErr.Error(),
		})
		return nil, claimErr
	}

	if sessionOK := s.sessionManager.IsValidSession(claims.SessionID, claims.Subject); !sessionOK {
		s.auditLogger.LogSecurityEvent("invalid_session", claims.Subject, clientIP, "warning", map[string]interface{}{
			"session_id": claims.SessionID,
		})
		return nil, fmt.Errorf("session no longer valid")
	}

	s.auditLogger.LogSecurityEvent("token_validated", claims.Subject, clientIP, "info", map[string]interface{}{
		"session_id": claims.SessionID,
		"roles":      claims.Roles,
	})
	return claims, nil
}
// validateTokenClaims applies application-level checks to already
// signature-verified claims:
//   - an "exp" claim is mandatory and must not be in the past;
//   - "nbf", when present, must not be in the future;
//   - "iat", when present, may be at most 5 minutes in the future;
//   - the audience must include "knowfoolery-quiz-game";
//   - the issuer must be "https://auth.knowfoolery.com";
//   - a non-empty "ip_address" claim must equal clientIP.
//
// A token without "exp" is rejected because nothing would otherwise ever
// expire it.
func (s *SecureAuthService) validateTokenClaims(claims *AuthClaims, clientIP string) error {
	const (
		expectedAudience = "knowfoolery-quiz-game"
		expectedIssuer   = "https://auth.knowfoolery.com"
		// maxIssuedAtSkew tolerates small clock differences with the issuer.
		maxIssuedAtSkew = 5 * time.Minute
	)
	now := time.Now()

	if claims.ExpiresAt == nil {
		return fmt.Errorf("token has no expiration")
	}
	if claims.ExpiresAt.Time.Before(now) {
		return fmt.Errorf("token expired")
	}
	if claims.NotBefore != nil && claims.NotBefore.Time.After(now) {
		return fmt.Errorf("token not yet valid")
	}
	if claims.IssuedAt != nil && claims.IssuedAt.Time.After(now.Add(maxIssuedAtSkew)) {
		return fmt.Errorf("token issued in the future")
	}

	audienceOK := false
	for _, aud := range claims.Audience {
		if aud == expectedAudience {
			audienceOK = true
			break
		}
	}
	if !audienceOK {
		return fmt.Errorf("invalid audience")
	}

	if claims.Issuer != expectedIssuer {
		return fmt.Errorf("invalid issuer")
	}

	// IP binding is only enforced for tokens that carry an ip_address claim.
	if claims.IPAddress != "" && claims.IPAddress != clientIP {
		return fmt.Errorf("IP address mismatch")
	}
	return nil
}
// Multi-Factor Authentication enforcement

// RequireMFA reports whether a user holding roles must complete multi-factor
// authentication. This is true for "admin" and for any role accepted by
// hasHighValuePermissions. userID is not consulted.
func (s *SecureAuthService) RequireMFA(userID string, roles []string) bool {
	isAdmin := false
	for i := range roles {
		if roles[i] == "admin" {
			isAdmin = true
			break
		}
	}
	return isAdmin || s.hasHighValuePermissions(roles)
}
// hasHighValuePermissions reports whether roles contains any of the
// privileged roles "moderator", "question_manager" or "user_manager".
func (s *SecureAuthService) hasHighValuePermissions(roles []string) bool {
	for _, role := range roles {
		switch role {
		case "moderator", "question_manager", "user_manager":
			return true
		}
	}
	return false
}
Session Management
// Secure session management
package session
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"time"
"github.com/go-redis/redis/v8"
)
// SessionManager stores server-side sessions in Redis. Each session is a
// JSON document under "session:<id>". Each user's session IDs are kept in
// the set "user_sessions:<userID>".
type SessionManager struct {
// redis is the client used for all session reads and writes.
redis *redis.Client
// entropy is the number of random bytes read per session ID (32 in NewSessionManager).
entropy int
// maxAge is the TTL applied to session keys (24h in NewSessionManager).
maxAge time.Duration
// secureCookie is set by NewSessionManager but not read by any method in this section.
secureCookie bool
}
// Session is the JSON document stored under "session:<ID>" in Redis.
type Session struct {
// ID is the session identifier produced by generateSecureSessionID.
ID string `json:"id"`
// UserID is the owner; IsValidSession compares it with the token subject.
UserID string `json:"user_id"`
CreatedAt time.Time `json:"created_at"`
// LastSeen is refreshed on each successful IsValidSession call.
LastSeen time.Time `json:"last_seen"`
IPAddress string `json:"ip_address"`
UserAgent string `json:"user_agent"`
// DeviceID is a truncated SHA-256 of "<IPAddress>:<UserAgent>".
DeviceID string `json:"device_id,omitempty"`
// Roles and MFAVerified are captured when the session is created.
Roles []string `json:"roles"`
MFAVerified bool `json:"mfa_verified"`
}
// NewSessionManager returns a SessionManager backed by the given Redis
// client. It uses 32-byte (256-bit) random session IDs and a 24-hour
// session TTL, and records secureCookie.
func NewSessionManager(redis *redis.Client, secureCookie bool) *SessionManager {
	const (
		sessionIDBytes = 32
		sessionTTL     = 24 * time.Hour
	)
	sm := new(SessionManager)
	sm.redis = redis
	sm.entropy = sessionIDBytes
	sm.maxAge = sessionTTL
	sm.secureCookie = secureCookie
	return sm
}
// CreateSession starts a new session for userID. It stores the session
// under "session:<id>" with a TTL of sm.maxAge, and adds the ID to the set
// "user_sessions:<userID>" so InvalidateAllUserSessions can find it.
//
// If the session cannot be added to that set, or the set's TTL cannot be
// refreshed, the session is removed again and an error is returned. An
// unindexed session would survive a "log out everywhere".
func (sm *SessionManager) CreateSession(ctx context.Context, userID, ipAddress, userAgent string, roles []string, mfaVerified bool) (*Session, error) {
	sessionID, err := sm.generateSecureSessionID()
	if err != nil {
		return nil, fmt.Errorf("failed to generate session ID: %w", err)
	}

	now := time.Now()
	session := &Session{
		ID:          sessionID,
		UserID:      userID,
		CreatedAt:   now,
		LastSeen:    now,
		IPAddress:   ipAddress,
		UserAgent:   userAgent,
		DeviceID:    sm.generateDeviceFingerprint(ipAddress, userAgent),
		Roles:       roles,
		MFAVerified: mfaVerified,
	}

	sessionKey := fmt.Sprintf("session:%s", sessionID)
	sessionData, err := json.Marshal(session)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal session: %w", err)
	}
	if err := sm.redis.SetEX(ctx, sessionKey, sessionData, sm.maxAge).Err(); err != nil {
		return nil, fmt.Errorf("failed to store session: %w", err)
	}

	userSessionKey := fmt.Sprintf("user_sessions:%s", userID)
	if err := sm.redis.SAdd(ctx, userSessionKey, sessionID).Err(); err != nil {
		// Best-effort rollback; the SAdd failure is what gets reported.
		sm.redis.Del(ctx, sessionKey)
		return nil, fmt.Errorf("failed to index session: %w", err)
	}
	// The index gets the same TTL as the newest session, so it outlives it.
	if err := sm.redis.Expire(ctx, userSessionKey, sm.maxAge).Err(); err != nil {
		sm.redis.SRem(ctx, userSessionKey, sessionID)
		sm.redis.Del(ctx, sessionKey)
		return nil, fmt.Errorf("failed to set session index expiry: %w", err)
	}
	return session, nil
}
// GetSession loads and decodes the session stored under "session:<sessionID>".
// A missing key yields the error "session not found"; other Redis and
// decoding failures are wrapped.
func (sm *SessionManager) GetSession(ctx context.Context, sessionID string) (*Session, error) {
	raw, err := sm.redis.Get(ctx, fmt.Sprintf("session:%s", sessionID)).Result()
	switch {
	case err == redis.Nil:
		return nil, fmt.Errorf("session not found")
	case err != nil:
		return nil, fmt.Errorf("failed to retrieve session: %w", err)
	}

	session := new(Session)
	if err := json.Unmarshal([]byte(raw), session); err != nil {
		return nil, fmt.Errorf("failed to unmarshal session: %w", err)
	}
	return session, nil
}
// IsValidSession reports whether sessionID names a stored session that is
// owned by userID. On success it stamps LastSeen and writes the session
// back through updateSession; the result of that write is ignored.
func (sm *SessionManager) IsValidSession(sessionID, userID string) bool {
	// The signature carries no context, so these Redis calls cannot be cancelled.
	ctx := context.Background()

	session, err := sm.GetSession(ctx, sessionID)
	if err != nil || session.UserID != userID {
		return false
	}

	session.LastSeen = time.Now()
	sm.updateSession(ctx, session)
	return true
}
// InvalidateSession deletes the session stored under "session:<sessionID>"
// and removes its ID from the owner's "user_sessions:<userID>" set.
//
// It returns an error when the session cannot be loaded, or when deleting
// the session key fails. Reporting success in that case would tell the
// caller a logout happened while the session stayed valid. Failure to
// update the index set is not reported: a stale entry there refers to a key
// that no longer exists.
func (sm *SessionManager) InvalidateSession(ctx context.Context, sessionID string) error {
	session, err := sm.GetSession(ctx, sessionID)
	if err != nil {
		return err
	}

	sessionKey := fmt.Sprintf("session:%s", sessionID)
	if err := sm.redis.Del(ctx, sessionKey).Err(); err != nil {
		return fmt.Errorf("failed to delete session: %w", err)
	}

	userSessionKey := fmt.Sprintf("user_sessions:%s", session.UserID)
	_ = sm.redis.SRem(ctx, userSessionKey, sessionID).Err()
	return nil
}
// InvalidateAllUserSessions deletes every session listed in the set
// "user_sessions:<userID>", together with the set itself, using a single
// DEL command.
//
// It returns an error if the set cannot be read or the deletion fails, so
// callers never report a "log out everywhere" that did not happen.
func (sm *SessionManager) InvalidateAllUserSessions(ctx context.Context, userID string) error {
	userSessionKey := fmt.Sprintf("user_sessions:%s", userID)
	sessionIDs, err := sm.redis.SMembers(ctx, userSessionKey).Result()
	if err != nil {
		return err
	}

	keys := make([]string, 0, len(sessionIDs)+1)
	for _, sessionID := range sessionIDs {
		keys = append(keys, fmt.Sprintf("session:%s", sessionID))
	}
	keys = append(keys, userSessionKey)

	if err := sm.redis.Del(ctx, keys...).Err(); err != nil {
		return fmt.Errorf("failed to delete user sessions: %w", err)
	}
	return nil
}
// generateSecureSessionID returns a new random session ID: 64 hex
// characters (the SHA-256 of sm.entropy bytes from crypto/rand).
//
// The hash does not add entropy; it only fixes the ID format. If
// sm.entropy is below 16 bytes, 32 bytes are used instead. With a
// zero-value SessionManager (entropy 0), every ID would otherwise be the
// same constant, sha256 of an empty input.
func (sm *SessionManager) generateSecureSessionID() (string, error) {
	const (
		minBytes     = 16 // 128 bits
		defaultBytes = 32 // the value NewSessionManager uses
	)
	n := sm.entropy
	if n < minBytes {
		n = defaultBytes
	}

	buf := make([]byte, n)
	if _, err := rand.Read(buf); err != nil {
		return "", err
	}
	sum := sha256.Sum256(buf)
	return hex.EncodeToString(sum[:]), nil
}
// generateDeviceFingerprint derives a coarse, non-secret device identifier.
// It is the first 16 bytes (32 hex characters) of the SHA-256 of
// "<ipAddress>:<userAgent>".
func (sm *SessionManager) generateDeviceFingerprint(ipAddress, userAgent string) string {
	digest := sha256.Sum256([]byte(ipAddress + ":" + userAgent))
	truncated := digest[:16]
	return hex.EncodeToString(truncated)
}
// updateSession rewrites the stored session document and resets its TTL to
// sm.maxAge. This gives sessions a sliding expiry while they are in use.
//
// The owner's "user_sessions:<userID>" set gets the same TTL, so it never
// expires before a session it lists. Otherwise a session kept alive past
// its creation day would drop out of the index, and
// InvalidateAllUserSessions could not revoke it.
func (sm *SessionManager) updateSession(ctx context.Context, session *Session) error {
	sessionKey := fmt.Sprintf("session:%s", session.ID)
	sessionData, err := json.Marshal(session)
	if err != nil {
		return err
	}
	if err := sm.redis.SetEX(ctx, sessionKey, sessionData, sm.maxAge).Err(); err != nil {
		return err
	}

	userSessionKey := fmt.Sprintf("user_sessions:%s", session.UserID)
	return sm.redis.Expire(ctx, userSessionKey, sm.maxAge).Err()
}
Input Validation & Sanitization
Comprehensive Input Validation
// Input validation and sanitization framework
package validation
import (
"fmt"
"html"
"regexp"
"strings"
"unicode"
"github.com/go-playground/validator/v10"
)
// InputValidator combines a go-playground struct validator (with the custom
// tags registered) and a table of per-field rules keyed by field type, such
// as "player_name" or "answer".
type InputValidator struct {
// validator is the struct-tag validator; ValidateAndSanitize does not use it.
validator *validator.Validate
// rules maps a field type to the ValidationRule applied by ValidateAndSanitize.
rules map[string]*ValidationRule
}
// ValidationRule describes how one kind of input field is checked and cleaned.
type ValidationRule struct {
// MaxLength and MinLength bound the length of the input.
MaxLength int
MinLength int
// Pattern is not consulted by ValidateAndSanitize.
Pattern *regexp.Regexp
// AllowedChars, when non-nil, must match the raw input; the rules built in
// setupValidationRules anchor it with ^...$ so it covers the whole string.
AllowedChars *regexp.Regexp
// Sanitizer, when non-nil, transforms input that passed the checks.
Sanitizer func(string) string
}
// NewInputValidator returns an InputValidator whose struct validator knows
// the custom tags "alphanum_space", "no_html" and "safe_text", and whose
// rule table holds the built-in field rules from setupValidationRules.
func NewInputValidator() *InputValidator {
	structValidator := validator.New()
	customTags := map[string]validator.Func{
		"alphanum_space": validateAlphanumSpace,
		"no_html":        validateNoHTML,
		"safe_text":      validateSafeText,
	}
	for tag, fn := range customTags {
		// The returned error is ignored; none of these fixed tag/function
		// pairs is expected to be rejected.
		structValidator.RegisterValidation(tag, fn)
	}

	iv := &InputValidator{
		rules:     make(map[string]*ValidationRule),
		validator: structValidator,
	}
	iv.setupValidationRules()
	return iv
}
// setupValidationRules installs the built-in field rules used by
// ValidateAndSanitize:
//
//	player_name    length 2-50     letters, digits, whitespace, - _ .
//	answer         length 1-500    as above plus , ' ! ? ( )
//	question_text  length 10-1000  no character restriction (admin-authored)
//	theme          length 2-100    letters, digits, whitespace, - _
//
// The character classes are ASCII-only.
func (iv *InputValidator) setupValidationRules() {
	builtin := map[string]*ValidationRule{
		"player_name": {
			MinLength:    2,
			MaxLength:    50,
			AllowedChars: regexp.MustCompile(`^[a-zA-Z0-9\s\-_.]+$`),
			Sanitizer:    iv.sanitizePlayerName,
		},
		"answer": {
			MinLength:    1,
			MaxLength:    500,
			AllowedChars: regexp.MustCompile(`^[a-zA-Z0-9\s\-_.,'!?()]+$`),
			Sanitizer:    iv.sanitizeAnswer,
		},
		"question_text": {
			MinLength: 10,
			MaxLength: 1000,
			Sanitizer: iv.sanitizeQuestionText,
		},
		"theme": {
			MinLength:    2,
			MaxLength:    100,
			AllowedChars: regexp.MustCompile(`^[a-zA-Z0-9\s\-_]+$`),
			Sanitizer:    iv.sanitizeTheme,
		},
	}
	for fieldType, rule := range builtin {
		iv.rules[fieldType] = rule
	}
}
// ValidateAndSanitize checks input against the rule registered for fieldType
// and returns the sanitized value.
//
// Lengths are counted in characters (Unicode code points), as the error
// messages state, rather than in bytes. The minimum length is checked again
// after sanitization. The sanitizers trim whitespace and may strip
// characters, so without the second check input made only of whitespace or
// punctuation would pass as an empty value.
func (iv *InputValidator) ValidateAndSanitize(fieldType, input string) (string, error) {
	rule, exists := iv.rules[fieldType]
	if !exists {
		return "", fmt.Errorf("unknown field type: %s", fieldType)
	}

	length := len([]rune(input))
	if length < rule.MinLength {
		return "", fmt.Errorf("input too short: minimum %d characters", rule.MinLength)
	}
	if length > rule.MaxLength {
		return "", fmt.Errorf("input too long: maximum %d characters", rule.MaxLength)
	}
	if rule.AllowedChars != nil && !rule.AllowedChars.MatchString(input) {
		return "", fmt.Errorf("input contains invalid characters")
	}

	if rule.Sanitizer == nil {
		return input, nil
	}
	sanitized := rule.Sanitizer(input)
	if len([]rune(sanitized)) < rule.MinLength {
		return "", fmt.Errorf("input too short: minimum %d characters", rule.MinLength)
	}
	return sanitized, nil
}
// Sanitization functions

// playerNameWhitespaceRun matches runs of whitespace. It is compiled once
// rather than on every call.
var playerNameWhitespaceRun = regexp.MustCompile(`\s+`)

// sanitizePlayerName HTML-escapes input, trims surrounding whitespace, and
// collapses each internal run of whitespace to a single space.
//
// NOTE(review): the returned value is already HTML-escaped; anything that
// escapes again on output will show "&amp;amp;"-style artifacts.
func (iv *InputValidator) sanitizePlayerName(input string) string {
	sanitized := html.EscapeString(input)
	sanitized = strings.TrimSpace(sanitized)
	return playerNameWhitespaceRun.ReplaceAllString(sanitized, " ")
}
// answerDisallowedChars matches everything except ASCII letters, digits,
// underscore, whitespace, hyphen, apostrophe and period.
var answerDisallowedChars = regexp.MustCompile(`[^\w\s\-'.]`)

// sanitizeAnswer normalizes a player's answer for comparison. It lowercases
// the text, removes characters matched by answerDisallowedChars, trims
// surrounding whitespace, and HTML-escapes the result.
//
// Escaping must be the last step. Escaping first turns "O'Brien" into
// "o&#39;brien", and the punctuation filter then deletes '&', '#' and ';',
// leaving "o39brien".
func (iv *InputValidator) sanitizeAnswer(input string) string {
	normalized := strings.ToLower(input)
	normalized = answerDisallowedChars.ReplaceAllString(normalized, "")
	normalized = strings.TrimSpace(normalized)
	return html.EscapeString(normalized)
}
// questionScriptElement matches <script ...>...</script> elements,
// case-insensitively and across line breaks.
var questionScriptElement = regexp.MustCompile(`(?is)<script[^>]*>.*?</script\s*>`)

// sanitizeQuestionText removes <script> elements from admin-authored
// question text, HTML-escapes what remains, and trims surrounding whitespace.
//
// Script removal has to run before escaping. Once '<' has become "&lt;"
// the pattern can never match, so running it after escaping made it dead
// code.
func (iv *InputValidator) sanitizeQuestionText(input string) string {
	withoutScripts := questionScriptElement.ReplaceAllString(input, "")
	return strings.TrimSpace(html.EscapeString(withoutScripts))
}
// sanitizeTheme HTML-escapes input and splits it into whitespace-separated
// words. Each word gets its first character upper-cased and the rest
// lower-cased, and the words are rejoined with single spaces.
//
// Capitalization works on runes rather than bytes. Indexing word[0] would
// cut a multi-byte first character in half and produce invalid UTF-8.
func (iv *InputValidator) sanitizeTheme(input string) string {
	words := strings.Fields(html.EscapeString(input))
	for i, word := range words {
		runes := []rune(word) // non-empty: strings.Fields never yields ""
		words[i] = string(unicode.ToUpper(runes[0])) + strings.ToLower(string(runes[1:]))
	}
	return strings.Join(words, " ")
}
// Custom validation functions

// validateAlphanumSpace implements the "alphanum_space" tag. It accepts
// strings made only of Unicode letters, Unicode numbers, whitespace and the
// characters '-', '_' and '.'. The empty string is accepted.
func validateAlphanumSpace(fl validator.FieldLevel) bool {
	for _, r := range fl.Field().String() {
		switch {
		case unicode.IsLetter(r), unicode.IsNumber(r), unicode.IsSpace(r):
			continue
		case r == '-', r == '_', r == '.':
			continue
		default:
			return false
		}
	}
	return true
}
// validateNoHTML implements the "no_html" tag: it rejects any string that
// contains '<' or '>'.
func validateNoHTML(fl validator.FieldLevel) bool {
	return !strings.ContainsAny(fl.Field().String(), "<>")
}
// safeTextDenylist is the union of the patterns validateSafeText rejects,
// matched against lower-cased input. It is compiled once, where the
// original compiled six patterns on every call.
var safeTextDenylist = regexp.MustCompile(`javascript:|data:|vbscript:|on\w+\s*=|<script|</script`)

// validateSafeText implements the "safe_text" tag. It rejects text that
// contains, case-insensitively, any of the following:
//   - the javascript:, data: or vbscript: schemes;
//   - an "on<word>=" event-handler assignment;
//   - "<script" or "</script".
//
// NOTE(review): these are plain substring matches, so ordinary prose such
// as "big data: trends" or "one = two" is also rejected.
func validateSafeText(fl validator.FieldLevel) bool {
	return !safeTextDenylist.MatchString(strings.ToLower(fl.Field().String()))
}
// Fiber middleware for input validation

// InputValidationMiddleware validates and sanitizes the JSON body of POST,
// PUT and PATCH requests before the next handler runs.
//
// The body must be a JSON object (or null); otherwise the request is rejected
// with 400. For each top-level field whose name determineFieldType maps to a
// rule and whose value is a string, the value is run through
// ValidateAndSanitize. A validation failure returns 400; otherwise the
// sanitized string replaces the original.
//
// All other values are forwarded as the raw JSON they arrived as
// (whitespace may be compacted). Decoding them into interface{} and
// re-encoding would turn every number into a float64 and silently corrupt
// integers above 2^53, such as 64-bit IDs. Nested objects and arrays are
// not inspected.
func InputValidationMiddleware(validator *InputValidator) fiber.Handler {
	return func(c *fiber.Ctx) error {
		switch c.Method() {
		case "POST", "PUT", "PATCH":
		default:
			return c.Next()
		}

		body := c.Body()
		if len(body) == 0 {
			return c.Next()
		}

		var fields map[string]json.RawMessage
		if err := json.Unmarshal(body, &fields); err != nil {
			return c.Status(400).JSON(fiber.Map{
				"error":   true,
				"message": "Invalid JSON format",
			})
		}

		for key, raw := range fields {
			fieldType := determineFieldType(key, c.Path())
			if fieldType == "" {
				continue
			}
			var decoded interface{}
			if err := json.Unmarshal(raw, &decoded); err != nil {
				continue
			}
			strValue, ok := decoded.(string)
			if !ok {
				continue // only string values are validated
			}

			sanitized, err := validator.ValidateAndSanitize(fieldType, strValue)
			if err != nil {
				return c.Status(400).JSON(fiber.Map{
					"error":   true,
					"message": fmt.Sprintf("Validation error for field %s: %s", key, err.Error()),
				})
			}
			encoded, err := json.Marshal(sanitized)
			if err != nil {
				return c.Status(500).JSON(fiber.Map{
					"error":   true,
					"message": "Failed to process request body",
				})
			}
			// Overwriting an existing key during range is permitted.
			fields[key] = encoded
		}

		sanitizedBody, err := json.Marshal(fields)
		if err != nil {
			return c.Status(500).JSON(fiber.Map{
				"error":   true,
				"message": "Failed to process request body",
			})
		}
		c.Request().SetBody(sanitizedBody)
		return c.Next()
	}
}
// determineFieldType maps a JSON field name to the name of the validation
// rule applied to it, or "" when the field is not validated. endpoint is
// currently unused.
func determineFieldType(fieldName, endpoint string) string {
	switch fieldName {
	case "player_name", "name":
		return "player_name"
	case "answer":
		return "answer"
	case "question", "text":
		return "question_text"
	case "theme", "category":
		return "theme"
	default:
		return ""
	}
}
Data Security & Encryption
Encryption Implementation
// Data encryption for sensitive fields
package encryption
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"fmt"
"io"
"golang.org/x/crypto/pbkdf2"
)
// EncryptionService encrypts and decrypts strings with AES-GCM under a key
// derived by NewEncryptionService.
type EncryptionService struct {
// key is the derived 32-byte AES key.
key []byte
// gcm is the AEAD built from key.
gcm cipher.AEAD
// keyID is copied into every EncryptedData this service produces.
keyID string
}
// EncryptedData is the serialized form of one encrypted value. Data (the
// ciphertext including the GCM tag) and Nonce are standard base64, and
// KeyID names the key that produced them.
type EncryptedData struct {
Data string `json:"data"`
Nonce string `json:"nonce"`
KeyID string `json:"key_id"`
}
// NewEncryptionService derives an AES-256-GCM key from masterKey and returns
// a service that tags its output with keyID.
//
// Key derivation is PBKDF2-HMAC-SHA256 with a fixed salt and 10,000
// iterations, both weaker than current password-storage guidance. Changing
// the salt, iteration count or hash changes the derived key and makes data
// encrypted earlier undecryptable. Migrate by introducing a new keyID
// rather than editing these values in place.
//
// An empty masterKey is rejected, because it would derive a key that
// anyone can reproduce from this source.
func NewEncryptionService(masterKey string, keyID string) (*EncryptionService, error) {
	if masterKey == "" {
		return nil, fmt.Errorf("master key must not be empty")
	}

	salt := []byte("knowfoolery-salt-2024")
	key := pbkdf2.Key([]byte(masterKey), salt, 10000, 32, sha256.New)

	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, fmt.Errorf("failed to create cipher: %w", err)
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("failed to create GCM: %w", err)
	}

	return &EncryptionService{
		key:   key,
		gcm:   gcm,
		keyID: keyID,
	}, nil
}
// Encrypt seals plaintext under a fresh random nonce, with no additional
// authenticated data. It returns the ciphertext and nonce as standard
// base64, together with the service's key ID.
func (es *EncryptionService) Encrypt(plaintext string) (*EncryptedData, error) {
	nonce := make([]byte, es.gcm.NonceSize())
	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
		return nil, fmt.Errorf("failed to generate nonce: %w", err)
	}

	sealed := es.gcm.Seal(nil, nonce, []byte(plaintext), nil)
	enc := base64.StdEncoding
	out := &EncryptedData{KeyID: es.keyID}
	out.Data = enc.EncodeToString(sealed)
	out.Nonce = enc.EncodeToString(nonce)
	return out, nil
}
// Decrypt reverses Encrypt and returns the plaintext.
//
// It returns an error, and never panics, for any of these inputs:
//   - nil input or malformed base64;
//   - a nonce of the wrong length (cipher.AEAD.Open panics on those, and
//     the nonce comes from stored data that may have been tampered with);
//   - data tagged with a different key ID;
//   - a failed authentication check.
//
// An empty KeyID is not treated as a mismatch.
func (es *EncryptionService) Decrypt(encData *EncryptedData) (string, error) {
	if encData == nil {
		return "", fmt.Errorf("no encrypted data")
	}
	if encData.KeyID != "" && encData.KeyID != es.keyID {
		return "", fmt.Errorf("key ID mismatch: data uses %q, service uses %q", encData.KeyID, es.keyID)
	}

	ciphertext, err := base64.StdEncoding.DecodeString(encData.Data)
	if err != nil {
		return "", fmt.Errorf("failed to decode ciphertext: %w", err)
	}
	nonce, err := base64.StdEncoding.DecodeString(encData.Nonce)
	if err != nil {
		return "", fmt.Errorf("failed to decode nonce: %w", err)
	}
	if len(nonce) != es.gcm.NonceSize() {
		return "", fmt.Errorf("invalid nonce length %d, want %d", len(nonce), es.gcm.NonceSize())
	}

	plaintext, err := es.gcm.Open(nil, nonce, ciphertext, nil)
	if err != nil {
		return "", fmt.Errorf("failed to decrypt data: %w", err)
	}
	return string(plaintext), nil
}
// Field-level encryption for Ent
//
// EncryptedField pairs an encrypted value with the service that created it.
// Its JSON form is the EncryptedData alone (see MarshalJSON).
type EncryptedField struct {
// encryptionService is set by NewEncryptedField; UnmarshalJSON leaves it unset.
encryptionService *EncryptionService
encryptedData *EncryptedData
}
// NewEncryptedField encrypts plaintext with es and wraps the result
// together with es, so the field can later be decrypted or serialized.
func NewEncryptedField(es *EncryptionService, plaintext string) (*EncryptedField, error) {
	encData, err := es.Encrypt(plaintext)
	if err != nil {
		return nil, err
	}
	field := new(EncryptedField)
	field.encryptionService = es
	field.encryptedData = encData
	return field, nil
}
// Decrypt returns the plaintext using the service stored in the field.
// A field populated only by UnmarshalJSON has no service (it is nil).
func (ef *EncryptedField) Decrypt() (string, error) {
	svc, data := ef.encryptionService, ef.encryptedData
	return svc.Decrypt(data)
}
// MarshalJSON serializes only the encrypted payload (data, nonce, key_id);
// the plaintext is never written.
func (ef *EncryptedField) MarshalJSON() ([]byte, error) {
	payload := ef.encryptedData
	return json.Marshal(payload)
}
// UnmarshalJSON restores the encrypted payload. It does not attach an
// EncryptionService, so the result cannot be decrypted until one is set.
func (ef *EncryptedField) UnmarshalJSON(data []byte) error {
	target := &ef.encryptedData
	return json.Unmarshal(data, target)
}
// Database encryption hooks for Ent

// EncryptionHook returns an Ent hook for create and update mutations. Each
// non-empty string value of the named sensitiveFields is replaced with its
// encrypted form before the mutation runs.
//
// The value written back is a string: the JSON encoding of EncryptedData,
// {"data":...,"nonce":...,"key_id":...}. The hook only encrypts string
// values, and Ent's generated SetField for a string field rejects values
// of any other type. Passing a *EncryptedField therefore made every such
// mutation fail.
//
// NOTE(review): values that are already encrypted are not detected and
// would be encrypted again if passed back through a mutation.
func EncryptionHook(encryptionService *EncryptionService, sensitiveFields []string) ent.Hook {
	return hook.On(
		func(next ent.Mutator) ent.Mutator {
			return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
				for _, field := range sensitiveFields {
					value, exists := m.Field(field)
					if !exists {
						continue
					}
					plaintext, ok := value.(string)
					if !ok || plaintext == "" {
						continue
					}

					encData, err := encryptionService.Encrypt(plaintext)
					if err != nil {
						return nil, fmt.Errorf("failed to encrypt field %s: %w", field, err)
					}
					encoded, err := json.Marshal(encData)
					if err != nil {
						return nil, fmt.Errorf("failed to encode encrypted field %s: %w", field, err)
					}
					if err := m.SetField(field, string(encoded)); err != nil {
						return nil, err
					}
				}
				return next.Mutate(ctx, m)
			})
		},
		ent.OpCreate|ent.OpUpdate|ent.OpUpdateOne,
	)
}
Security Headers & HTTPS
HTTP Security Implementation
// HTTP security headers middleware
package security
import (
"fmt"
"strings"
"time"
"github.com/gofiber/fiber/v3"
)
// SecurityConfig selects the headers SecurityHeadersMiddleware adds to
// responses. An empty string or false leaves the corresponding header out.
type SecurityConfig struct {
	// ContentSecurityPolicy is sent verbatim as Content-Security-Policy.
	ContentSecurityPolicy string
	// StrictTransportSecurity enables HSTS on requests served over HTTPS.
	StrictTransportSecurity bool
	// FrameOptions is sent as X-Frame-Options.
	FrameOptions string
	// ContentTypeOptions adds "X-Content-Type-Options: nosniff".
	ContentTypeOptions bool
	// ReferrerPolicy is sent as Referrer-Policy.
	ReferrerPolicy string
	// PermissionsPolicy is sent as Permissions-Policy.
	PermissionsPolicy string
}

// DefaultSecurityConfig returns the baseline header settings: a same-origin
// CSP that also allows the auth server, HSTS, DENY framing, nosniff,
// strict-origin-when-cross-origin referrers, and a Permissions-Policy that
// disables the listed browser features.
//
// NOTE(review): script-src and style-src allow 'unsafe-inline', which lets
// injected inline scripts run and largely defeats CSP as an XSS defence.
func DefaultSecurityConfig() SecurityConfig {
	const csp = "default-src 'self'; " +
		"script-src 'self' 'unsafe-inline'; " +
		"style-src 'self' 'unsafe-inline'; " +
		"img-src 'self' data: https:; " +
		"font-src 'self' https:; " +
		"connect-src 'self' https://auth.knowfoolery.com; " +
		"frame-ancestors 'none';"
	const permissions = "geolocation=(), microphone=(), camera=(), " +
		"payment=(), usb=(), screen-wake-lock=()"

	var cfg SecurityConfig
	cfg.ContentSecurityPolicy = csp
	cfg.StrictTransportSecurity = true
	cfg.FrameOptions = "DENY"
	cfg.ContentTypeOptions = true
	cfg.ReferrerPolicy = "strict-origin-when-cross-origin"
	cfg.PermissionsPolicy = permissions
	return cfg
}
// SecurityHeadersMiddleware returns a handler that sets the security
// headers enabled in config on the response and then calls the next handler.
//
// HSTS is only sent when the request arrived over HTTPS; browsers ignore
// the header on plain HTTP.
//
// X-XSS-Protection is sent as "0". The legacy "1; mode=block" value turns on
// the browser XSS auditor, which modern browsers have removed and which
// could itself be abused to leak page content. The recommended practice is
// to disable it and rely on Content-Security-Policy.
func SecurityHeadersMiddleware(config SecurityConfig) fiber.Handler {
	return func(c *fiber.Ctx) error {
		if config.ContentSecurityPolicy != "" {
			c.Set("Content-Security-Policy", config.ContentSecurityPolicy)
		}
		if config.StrictTransportSecurity && c.Protocol() == "https" {
			c.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload")
		}
		if config.FrameOptions != "" {
			c.Set("X-Frame-Options", config.FrameOptions)
		}
		if config.ContentTypeOptions {
			c.Set("X-Content-Type-Options", "nosniff")
		}
		if config.ReferrerPolicy != "" {
			c.Set("Referrer-Policy", config.ReferrerPolicy)
		}
		if config.PermissionsPolicy != "" {
			c.Set("Permissions-Policy", config.PermissionsPolicy)
		}

		c.Set("X-XSS-Protection", "0")

		// NOTE(review): this sets an empty Server value rather than deleting
		// the header; confirm the framework then omits it from the response.
		c.Set("Server", "")
		return c.Next()
	}
}
// CORS configuration with security focus
//
// SecureCORSConfig describes the allowed cross-origin callers: the app and
// admin front-ends, with credentials permitted and a 300-second preflight
// cache.
//
// NOTE(review): fiber.Config is the application config and, as far as the
// Fiber docs show, has no AllowOrigins/AllowMethods/AllowHeaders/
// AllowCredentials/MaxAge fields. These look like the CORS middleware's
// cors.Config, so this presumably does not compile as written; confirm the
// intended type. PATCH is also missing from AllowMethods, although
// InputValidationMiddleware handles PATCH requests.
func SecureCORSConfig() fiber.Config {
return fiber.Config{
AllowOrigins: "https://app.knowfoolery.com,https://admin.knowfoolery.com",
AllowMethods: "GET,POST,PUT,DELETE,OPTIONS",
AllowHeaders: "Origin,Content-Type,Accept,Authorization,X-CSRF-Token",
AllowCredentials: true,
MaxAge: 300, // 5 minutes
}
}
// Rate limiting with different tiers

// RateLimitConfig holds one limit per endpoint tier. RateLimitMiddleware
// picks the tier from the request path.
type RateLimitConfig struct {
	General RateLimit
	Auth    RateLimit
	API     RateLimit
	Admin   RateLimit
}

// RateLimit allows at most Requests requests per Window.
type RateLimit struct {
	Requests int
	Window   time.Duration
}

// SecurityRateLimitConfig returns the default per-minute limits: 100 for
// general traffic, 5 for authentication, 60 for the API and 30 for admin.
func SecurityRateLimitConfig() RateLimitConfig {
	perMinute := func(requests int) RateLimit {
		return RateLimit{Requests: requests, Window: time.Minute}
	}
	return RateLimitConfig{
		General: perMinute(100),
		Auth:    perMinute(5),
		API:     perMinute(60),
		Admin:   perMinute(30),
	}
}
// RateLimitMiddleware enforces fixed-window request limits using Redis
// counters.
//
// Each request is assigned a tier by path prefix: /api/v1/auth,
// /api/v1/admin, other /api/v1, or everything else. It is counted under a
// key built from three parts:
//   - the tier;
//   - the caller: the "user_id" local when set, otherwise the client IP;
//   - the current window number.
//
// Counting per tier rather than per raw path means a caller cannot escape a
// limit by varying the path (trailing slashes, resource IDs, ...). Because
// the window number is part of the key, a counter whose EXPIRE was never
// applied (for example after a Redis error between INCR and EXPIRE) only
// leaks a key; it cannot lock the caller out permanently.
//
// Redis errors fail open: the request proceeds without throttling.
// Responses carry X-RateLimit-Limit/Remaining/Reset, where Reset is the Unix
// time at which the current window ends. 429 responses also carry
// Retry-After.
func RateLimitMiddleware(redis *redis.Client, config RateLimitConfig) fiber.Handler {
	return func(c *fiber.Ctx) error {
		path := c.Path()
		var (
			limit RateLimit
			tier  string
		)
		switch {
		case strings.HasPrefix(path, "/api/v1/auth"):
			limit, tier = config.Auth, "auth"
		case strings.HasPrefix(path, "/api/v1/admin"):
			limit, tier = config.Admin, "admin"
		case strings.HasPrefix(path, "/api/v1"):
			limit, tier = config.API, "api"
		default:
			limit, tier = config.General, "general"
		}

		// Windows are aligned to multiples of the window length (at least 1s).
		windowSecs := int64(limit.Window / time.Second)
		if windowSecs < 1 {
			windowSecs = 1
		}
		now := time.Now().Unix()
		windowNum := now / windowSecs
		resetAt := (windowNum + 1) * windowSecs

		caller := "ip:" + c.IP()
		if userID := c.Locals("user_id"); userID != nil {
			caller = fmt.Sprintf("user:%v", userID)
		}
		key := fmt.Sprintf("rate_limit:%s:%s:%d", tier, caller, windowNum)

		ctx := context.Background()
		current, err := redis.Incr(ctx, key).Result()
		if err != nil {
			// Fail open: an unavailable limiter must not take the API down.
			return c.Next()
		}
		if current == 1 {
			// Set on the first hit, this TTL outlasts the end of the window.
			// If it fails, only a stale key is left behind.
			redis.Expire(ctx, key, time.Duration(windowSecs)*time.Second)
		}

		retryAfter := resetAt - now
		if current > int64(limit.Requests) {
			c.Set("Retry-After", fmt.Sprintf("%d", retryAfter))
			return c.Status(429).JSON(fiber.Map{
				"error":       true,
				"message":     "Rate limit exceeded",
				"retry_after": int(retryAfter),
			})
		}

		c.Set("X-RateLimit-Limit", fmt.Sprintf("%d", limit.Requests))
		c.Set("X-RateLimit-Remaining", fmt.Sprintf("%d", int64(limit.Requests)-current))
		c.Set("X-RateLimit-Reset", fmt.Sprintf("%d", resetAt))
		return c.Next()
	}
}
Game Integrity & Anti-Cheating
Server-Side Validation
// Game integrity and anti-cheating measures
package integrity
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"math"
"time"
)
// GameIntegrityService runs the anti-cheat checks on answer submissions and
// signs game state with an HMAC.
type GameIntegrityService struct {
// secretKey is the HMAC-SHA256 key used by calculateGameStateHash.
secretKey []byte
// auditLogger receives one "game_integrity_check" event per validated submission.
auditLogger *AuditLogger
// anomalyDetector tracks per-session metric history for behavioural checks.
anomalyDetector *AnomalyDetector
}
// GameState is the state of one player's game, used as the reference when
// validating an answer submission.
type GameState struct {
SessionID string `json:"session_id"`
PlayerID string `json:"player_id"`
QuestionID string `json:"question_id"`
// StartTime is compared with the current time and with the submission's
// timestamp in validateTiming.
StartTime time.Time `json:"start_time"`
AttemptsUsed int `json:"attempts_used"`
HintsUsed int `json:"hints_used"`
Score int `json:"score"`
// ServerHash is the HMAC over the fields above (second-precision StartTime),
// set by UpdateGameState.
ServerHash string `json:"server_hash"`
}
// AnswerSubmission is one answer attempt for the current question.
//
// NOTE(review): this looks like the client's request payload, so AttemptNum,
// TimeTaken and Timestamp are presumably client-controlled and should only
// be trusted when cross-checked against GameState.
type AnswerSubmission struct {
SessionID string `json:"session_id"`
QuestionID string `json:"question_id"`
Answer string `json:"answer"`
// AttemptNum is the 1-based attempt number the submitter claims.
AttemptNum int `json:"attempt_num"`
// TimeTaken is a time.Duration, which encoding/json writes as integer nanoseconds.
TimeTaken time.Duration `json:"time_taken"`
// ClientHash is not read by the checks in this section.
ClientHash string `json:"client_hash,omitempty"`
Timestamp time.Time `json:"timestamp"`
}
// IntegrityCheck is the verdict for one submission. Each violated rule
// appends a name to Violations and adds to RiskScore.
// calculateRiskAndAction then maps the score to Action.
type IntegrityCheck struct {
// IsValid is false when a rule marked the submission invalid or Action is "block".
IsValid bool `json:"is_valid"`
Violations []string `json:"violations"`
RiskScore float64 `json:"risk_score"`
Action string `json:"action"` // allow, warn, block
}
// NewGameIntegrityService returns a service that signs game state with
// secretKey, reports to auditLogger, and starts with an empty anomaly
// detector.
func NewGameIntegrityService(secretKey []byte, auditLogger *AuditLogger) *GameIntegrityService {
	svc := &GameIntegrityService{anomalyDetector: NewAnomalyDetector()}
	svc.secretKey = secretKey
	svc.auditLogger = auditLogger
	return svc
}
// ValidateAnswerSubmission runs the anti-cheat checks on one answer and
// returns the verdict.
//
// The checks are timing, attempt/question sequence, the game-state HMAC,
// and behavioural anomalies. The accumulated risk score is then mapped to
// an action (allow/warn/block). The verdict is written to the audit log
// under the player's ID, at "warning" severity unless the action is
// "allow".
//
// Rule violations are reported in the returned check, not as errors. A
// non-nil error means a check could not run or an input was nil. In the
// nil-input case a blocking check is still returned, so callers that read
// it never see a nil verdict.
func (gis *GameIntegrityService) ValidateAnswerSubmission(ctx context.Context, submission *AnswerSubmission, gameState *GameState) (*IntegrityCheck, error) {
	if submission == nil || gameState == nil {
		return &IntegrityCheck{
			IsValid:    false,
			Violations: []string{"missing_input"},
			RiskScore:  1.0,
			Action:     "block",
		}, fmt.Errorf("submission and game state are required")
	}

	check := &IntegrityCheck{
		IsValid:    true,
		Violations: []string{},
		RiskScore:  0.0,
	}

	if err := gis.validateTiming(submission, gameState, check); err != nil {
		return check, err
	}
	if err := gis.validateSequence(submission, gameState, check); err != nil {
		return check, err
	}
	if err := gis.validateGameStateHash(submission, gameState, check); err != nil {
		return check, err
	}
	if err := gis.checkBehavioralAnomalies(ctx, submission, check); err != nil {
		return check, err
	}

	gis.calculateRiskAndAction(check)

	severity := "info"
	if check.Action != "allow" {
		severity = "warning"
	}
	// The second argument is the user ID, as in the other audit calls; the
	// session ID goes in the details.
	gis.auditLogger.LogSecurityEvent("game_integrity_check", gameState.PlayerID, "", severity, map[string]interface{}{
		"session_id":  submission.SessionID,
		"question_id": submission.QuestionID,
		"risk_score":  check.RiskScore,
		"violations":  check.Violations,
		"action":      check.Action,
	})
	return check, nil
}
// validateTiming scores timing-related violations:
//   - answer_too_fast (+0.3): TimeTaken under 2 seconds;
//   - session_expired (+0.8, invalid): more than 30 minutes since StartTime;
//   - timestamp_manipulation (+0.5): the submission's timing disagrees with
//     the server's view. That is, Timestamp is before StartTime, Timestamp
//     is more than 5 minutes in the future, or TimeTaken is longer than the
//     time actually elapsed since StartTime (plus 5s slack).
//
// The last two conditions matter because TimeTaken and Timestamp arrive
// with the submission. Without them, a client could claim any duration and
// so defeat the answer_too_fast check. It always returns nil.
func (gis *GameIntegrityService) validateTiming(submission *AnswerSubmission, gameState *GameState, check *IntegrityCheck) error {
	const (
		minTime       = 2 * time.Second  // prevent instant answers
		maxTime       = 30 * time.Minute // session timeout
		clockSkew     = 5 * time.Minute  // tolerated client clock offset
		durationSlack = 5 * time.Second  // tolerated error in the claimed duration
	)
	now := time.Now()
	elapsed := now.Sub(gameState.StartTime)

	if submission.TimeTaken < minTime {
		check.Violations = append(check.Violations, "answer_too_fast")
		check.RiskScore += 0.3
	}

	if elapsed > maxTime {
		check.Violations = append(check.Violations, "session_expired")
		check.RiskScore += 0.8
		check.IsValid = false
	}

	manipulated := submission.Timestamp.Before(gameState.StartTime) ||
		submission.Timestamp.After(now.Add(clockSkew)) ||
		submission.TimeTaken > elapsed+durationSlack
	if manipulated {
		check.Violations = append(check.Violations, "timestamp_manipulation")
		check.RiskScore += 0.5
	}
	return nil
}
// validateSequence scores attempt-ordering and question-matching violations:
//   - invalid_attempt_sequence (+0.4): AttemptNum is not AttemptsUsed+1;
//   - exceeded_attempt_limit (+0.6, invalid): the server's count shows all
//     3 attempts already used, or the claimed AttemptNum is above 3;
//   - question_mismatch (+0.5, invalid): the submission targets a question
//     other than the current one.
//
// The limit is enforced on the server's AttemptsUsed as well as the claimed
// AttemptNum. Checking only AttemptNum would let a submission that sends
// AttemptNum=1 on a fourth attempt incur just the +0.4 sequence penalty,
// which on its own stays below the 0.5 "warn" threshold and is allowed.
// It always returns nil.
func (gis *GameIntegrityService) validateSequence(submission *AnswerSubmission, gameState *GameState, check *IntegrityCheck) error {
	const maxAttempts = 3

	if submission.AttemptNum != gameState.AttemptsUsed+1 {
		check.Violations = append(check.Violations, "invalid_attempt_sequence")
		check.RiskScore += 0.4
	}

	if gameState.AttemptsUsed >= maxAttempts || submission.AttemptNum > maxAttempts {
		check.Violations = append(check.Violations, "exceeded_attempt_limit")
		check.RiskScore += 0.6
		check.IsValid = false
	}

	if submission.QuestionID != gameState.QuestionID {
		check.Violations = append(check.Violations, "question_mismatch")
		check.RiskScore += 0.5
		check.IsValid = false
	}
	return nil
}
// validateGameStateHash recomputes the HMAC over gameState's fields. If it
// differs from gameState.ServerHash, it records game_state_tampering
// (+0.7) and marks the check invalid. submission is not used.
//
// hmac.Equal compares in constant time (for equal lengths), so the
// comparison does not reveal how much of a forged hash was correct. It
// always returns nil.
func (gis *GameIntegrityService) validateGameStateHash(submission *AnswerSubmission, gameState *GameState, check *IntegrityCheck) error {
	expectedHash := gis.calculateGameStateHash(gameState)
	if !hmac.Equal([]byte(gameState.ServerHash), []byte(expectedHash)) {
		check.Violations = append(check.Violations, "game_state_tampering")
		check.RiskScore += 0.7
		check.IsValid = false
	}
	return nil
}
// checkBehavioralAnomalies feeds this submission's metrics (time taken in
// seconds, answer length in bytes, attempt number) to the anomaly detector.
// For each metric it flags, it adds an "anomaly_<metric>" violation
// weighted by the reported severity. It always returns nil.
//
// Every metric is passed as float64. DetectAnomalies only evaluates
// float64 values, so the int-valued answer_length and attempt_num were
// silently skipped.
//
// NOTE(review): attempt_num legitimately varies between 1 and 3, so it may
// be a noisy signal once it is actually evaluated; monitor false positives.
func (gis *GameIntegrityService) checkBehavioralAnomalies(ctx context.Context, submission *AnswerSubmission, check *IntegrityCheck) error {
	metrics := map[string]interface{}{
		"time_taken":    submission.TimeTaken.Seconds(),
		"answer_length": float64(len(submission.Answer)),
		"attempt_num":   float64(submission.AttemptNum),
	}
	for _, anomaly := range gis.anomalyDetector.DetectAnomalies(ctx, submission.SessionID, metrics) {
		check.Violations = append(check.Violations, fmt.Sprintf("anomaly_%s", anomaly.Type))
		check.RiskScore += anomaly.Severity
	}
	return nil
}
// calculateRiskAndAction finalizes check in two steps:
//  1. With more than one violation, it adds 0.1 per violation to the risk
//     score.
//  2. It maps the score to an action: "block" at 0.8 or above (which also
//     marks the check invalid), "warn" at 0.5 or above, and "allow"
//     otherwise.
func (gis *GameIntegrityService) calculateRiskAndAction(check *IntegrityCheck) {
	if n := len(check.Violations); n > 1 {
		check.RiskScore += float64(n) * 0.1
	}

	if check.RiskScore >= 0.8 {
		check.Action = "block"
		check.IsValid = false
		return
	}
	if check.RiskScore >= 0.5 {
		check.Action = "warn"
		return
	}
	check.Action = "allow"
}
// calculateGameStateHash returns the hex-encoded HMAC-SHA256, keyed with
// gis.secretKey, of
// "<SessionID>:<PlayerID>:<QuestionID>:<StartTime unix>:<AttemptsUsed>:<HintsUsed>:<Score>".
//
// NOTE(review): fields are joined with ':' without escaping, so IDs that
// contain ':' could yield the same input for different states. Changing the
// encoding would invalidate every stored ServerHash.
func (gis *GameIntegrityService) calculateGameStateHash(gameState *GameState) string {
	mac := hmac.New(sha256.New, gis.secretKey)
	fmt.Fprintf(mac, "%s:%s:%s:%d:%d:%d:%d",
		gameState.SessionID,
		gameState.PlayerID,
		gameState.QuestionID,
		gameState.StartTime.Unix(),
		gameState.AttemptsUsed,
		gameState.HintsUsed,
		gameState.Score,
	)
	return hex.EncodeToString(mac.Sum(nil))
}
// UpdateGameState stores in gameState.ServerHash the HMAC of gameState's
// current fields, as computed by calculateGameStateHash.
func (gis *GameIntegrityService) UpdateGameState(gameState *GameState) {
	signature := gis.calculateGameStateHash(gameState)
	gameState.ServerHash = signature
}
// Anomaly detection for behavioral patterns

const (
	// anomalyMinBaseline is how many earlier samples of a metric must exist
	// before a new value is scored.
	anomalyMinBaseline = 3
	// anomalyMaxSamples bounds the history kept per session and metric, so
	// memory use and the per-call statistics pass stay bounded.
	anomalyMaxSamples = 100
	// anomalyZThreshold is the |z-score| above which a value is reported.
	anomalyZThreshold = 2.0
)

// AnomalyDetector keeps, per session and metric, a window of recent numeric
// samples. It reports values that lie far from the mean of that window.
//
// NOTE(review): patterns is a plain map with no locking, so a detector
// shared between goroutines needs external synchronization. Entries are
// never removed when a session ends.
type AnomalyDetector struct {
	// patterns is keyed by "<sessionID>:<metricName>".
	patterns map[string]*PatternTracker
}

// PatternTracker is the sample history of one metric within one session.
type PatternTracker struct {
	// Samples holds up to anomalyMaxSamples of the most recent values, oldest first.
	Samples []float64
	// Mean and StdDev (population standard deviation) describe the baseline
	// that the most recently scored value was compared against.
	Mean   float64
	StdDev float64
	// SampleCount counts every value recorded, including values since
	// dropped from Samples.
	SampleCount int
}

// Anomaly reports one metric value that deviated from its baseline.
type Anomaly struct {
	// Type is the metric name.
	Type string
	// Severity is the risk contribution, min(z/5, 0.3). Only z > 2 is
	// reported, so at the current threshold this is always 0.3.
	Severity float64
	// Details is a human-readable description of the z-score.
	Details string
}

// NewAnomalyDetector returns a detector with no recorded history.
func NewAnomalyDetector() *AnomalyDetector {
	return &AnomalyDetector{
		patterns: make(map[string]*PatternTracker),
	}
}

// DetectAnomalies records each float64 metric for sessionID. It returns an
// Anomaly for every metric whose value is unusual compared with that
// session's earlier values of the same metric. Values of any other type
// are ignored. ctx is currently unused. The order of the returned
// anomalies follows map iteration and is not stable.
func (ad *AnomalyDetector) DetectAnomalies(ctx context.Context, sessionID string, metrics map[string]interface{}) []Anomaly {
	var anomalies []Anomaly
	for metricName, raw := range metrics {
		value, ok := raw.(float64)
		if !ok {
			continue
		}
		if anomaly := ad.checkMetricAnomaly(sessionID, metricName, value); anomaly != nil {
			anomalies = append(anomalies, *anomaly)
		}
	}
	return anomalies
}

// checkMetricAnomaly scores value against the session's earlier samples of
// metricName, then adds value to the history. It returns nil in three cases:
//   - the value is not anomalous;
//   - fewer than anomalyMinBaseline earlier samples exist;
//   - the earlier samples have zero spread.
//
// The value is deliberately left out of its own baseline. Including it
// bounds |z| by sqrt(n-1) for n samples (Samuelson's inequality), so with
// 3 to 5 samples no value could ever exceed the threshold of 2.
func (ad *AnomalyDetector) checkMetricAnomaly(sessionID, metricName string, value float64) *Anomaly {
	if ad.patterns == nil {
		// Makes a zero-value AnomalyDetector usable.
		ad.patterns = make(map[string]*PatternTracker)
	}
	key := fmt.Sprintf("%s:%s", sessionID, metricName)
	tracker, exists := ad.patterns[key]
	if !exists {
		tracker = &PatternTracker{}
		ad.patterns[key] = tracker
	}

	var anomaly *Anomaly
	if len(tracker.Samples) >= anomalyMinBaseline {
		ad.updateStatistics(tracker)
		if tracker.StdDev > 0 {
			zScore := math.Abs(value-tracker.Mean) / tracker.StdDev
			if zScore > anomalyZThreshold {
				anomaly = &Anomaly{
					Type:     metricName,
					Severity: math.Min(zScore/5.0, 0.3),
					Details:  fmt.Sprintf("z-score: %.2f", zScore),
				}
			}
		}
	}

	if len(tracker.Samples) >= anomalyMaxSamples {
		// Drop the oldest sample in place so the backing array stops growing.
		n := copy(tracker.Samples, tracker.Samples[1:])
		tracker.Samples = tracker.Samples[:n]
	}
	tracker.Samples = append(tracker.Samples, value)
	tracker.SampleCount++
	return anomaly
}

// updateStatistics sets tracker.Mean and tracker.StdDev (population standard
// deviation) from tracker.Samples. Both are 0 when there are no samples.
func (ad *AnomalyDetector) updateStatistics(tracker *PatternTracker) {
	if len(tracker.Samples) == 0 {
		tracker.Mean, tracker.StdDev = 0, 0
		return
	}
	n := float64(len(tracker.Samples))

	sum := 0.0
	for _, sample := range tracker.Samples {
		sum += sample
	}
	mean := sum / n

	sumSquares := 0.0
	for _, sample := range tracker.Samples {
		diff := sample - mean
		sumSquares += diff * diff
	}

	tracker.Mean = mean
	tracker.StdDev = math.Sqrt(sumSquares / n)
}
Security Monitoring & Incident Response
Security Event Monitoring
// Security monitoring and incident response
package security
import (
"context"
"encoding/json"
"fmt"
"time"
)
// SecurityMonitor turns security events into incident responses. It enriches
// each event, grades its threat level, and triggers automated actions.
type SecurityMonitor struct {
alertManager *AlertManager
incidentManager *IncidentManager
// auditLogger records an "incident_response" event for each processed event.
auditLogger *AuditLogger
redis *redis.Client
}
// SecurityEvent is one security-relevant occurrence submitted to
// ProcessSecurityEvent.
type SecurityEvent struct {
ID string `json:"id"`
// Type names the kind of event; presumably values like "invalid_token" — TODO confirm.
Type string `json:"type"`
Severity string `json:"severity"`
UserID string `json:"user_id,omitempty"`
IPAddress string `json:"ip_address"`
UserAgent string `json:"user_agent,omitempty"`
Timestamp time.Time `json:"timestamp"`
// Details carries event-specific key/value data.
Details map[string]interface{} `json:"details"`
Context string `json:"context"`
}
// IncidentResponse records what ProcessSecurityEvent decided for one event.
type IncidentResponse struct {
IncidentID string `json:"incident_id"`
// Action is one of "block_immediately", "temporary_restriction",
// "monitor_closely", "log_and_continue" or "no_action".
Action string `json:"action"`
// AutomatedActions lists the automated steps chosen for the threat level.
AutomatedActions []string `json:"automated_actions"`
// RequiresManualReview is set for critical and high threat levels.
RequiresManualReview bool `json:"requires_manual_review"`
Timestamp time.Time `json:"timestamp"`
}
// NewSecurityMonitor assembles a SecurityMonitor from its collaborators.
// Every dependency is stored exactly as given; none is validated or copied.
func NewSecurityMonitor(alertManager *AlertManager, incidentManager *IncidentManager, auditLogger *AuditLogger, redis *redis.Client) *SecurityMonitor {
	monitor := new(SecurityMonitor)
	monitor.alertManager = alertManager
	monitor.incidentManager = incidentManager
	monitor.auditLogger = auditLogger
	monitor.redis = redis
	return monitor
}
// ProcessSecurityEvent enriches event with extra context, scores its threat
// level, applies the automated response for that level and records an
// "incident_response" entry in the audit log.
//
// Threat level to response:
//   - "critical": block_immediately (block user, invalidate sessions, notify security team)
//   - "high": temporary_restriction (strict rate limit, step-up auth, alert security team)
//   - "medium": monitor_closely (increase logging, flag for review)
//   - "low": log_and_continue (no automated step is executed)
//   - anything else: no_action
//
// Automated actions are best-effort. If any of them fails, the failure is
// added to the audit entry as "automation_error" and the incident is
// escalated to manual review rather than being dropped silently.
//
// An error is returned only for a nil event or a failed enrichment. In that
// case no response is produced.
func (sm *SecurityMonitor) ProcessSecurityEvent(ctx context.Context, event *SecurityEvent) (*IncidentResponse, error) {
	if event == nil {
		return nil, fmt.Errorf("security event is nil")
	}
	if err := sm.enrichSecurityEvent(ctx, event); err != nil {
		return nil, fmt.Errorf("failed to enrich security event: %w", err)
	}

	threatLevel := sm.analyzeThreatLevel(ctx, event)

	response := &IncidentResponse{
		IncidentID: generateIncidentID(),
		Timestamp:  time.Now(),
		// Non-nil so the JSON field encodes as [] rather than null.
		AutomatedActions: []string{},
	}

	runAutomation := false
	switch threatLevel {
	case "critical":
		response.Action = "block_immediately"
		response.RequiresManualReview = true
		response.AutomatedActions = append(response.AutomatedActions,
			"block_user", "invalidate_sessions", "notify_security_team")
		runAutomation = true
	case "high":
		response.Action = "temporary_restriction"
		response.RequiresManualReview = true
		response.AutomatedActions = append(response.AutomatedActions,
			"rate_limit_user", "require_additional_auth", "alert_security_team")
		runAutomation = true
	case "medium":
		response.Action = "monitor_closely"
		response.AutomatedActions = append(response.AutomatedActions,
			"increase_logging", "flag_for_review")
		runAutomation = true
	case "low":
		// No automated step runs here; the audit entry below is always written.
		response.Action = "log_and_continue"
		response.AutomatedActions = append(response.AutomatedActions, "log_event")
	default:
		response.Action = "no_action"
	}

	var automationErr error
	if runAutomation {
		automationErr = sm.executeAutomatedResponse(ctx, event, response)
		if automationErr != nil {
			// A failed block or restriction must not go unnoticed.
			response.RequiresManualReview = true
		}
	}

	auditDetails := map[string]interface{}{
		"incident_id":       response.IncidentID,
		"original_event":    event.Type,
		"threat_level":      threatLevel,
		"response_action":   response.Action,
		"automated_actions": response.AutomatedActions,
	}
	if automationErr != nil {
		auditDetails["automation_error"] = automationErr.Error()
	}
	sm.auditLogger.LogSecurityEvent("incident_response", event.UserID, event.IPAddress, "info", auditDetails)

	return response, nil
}
// enrichSecurityEvent attaches best-effort context to event.Details:
// "geolocation" and "ip_reputation" for event.IPAddress, and "user_history"
// when event.UserID is set. A lookup that fails leaves its key absent;
// lookup failures are never returned as errors.
//
// Details is allocated when nil. Writing to a nil map panics, and an event
// built without Details, or decoded from JSON without a "details" field,
// has a nil map.
func (sm *SecurityMonitor) enrichSecurityEvent(ctx context.Context, event *SecurityEvent) error {
	if event == nil {
		return fmt.Errorf("security event is nil")
	}
	if event.Details == nil {
		event.Details = make(map[string]interface{})
	}

	if geoData, err := sm.getGeolocation(event.IPAddress); err == nil {
		event.Details["geolocation"] = geoData
	}

	if event.UserID != "" {
		if userHistory, err := sm.getUserBehaviorHistory(ctx, event.UserID); err == nil {
			event.Details["user_history"] = userHistory
		}
	}

	if reputation, err := sm.getIPReputation(event.IPAddress); err == nil {
		event.Details["ip_reputation"] = reputation
	}

	return nil
}
// eventTypeThreatScores is the fixed score analyzeThreatLevel adds for each
// known SecurityEvent.Type; types not listed add nothing. It is built once
// at package initialisation instead of on every call and must be treated
// as read-only.
var eventTypeThreatScores = map[string]int{
	"brute_force_attack":     35,
	"sql_injection_attempt":  40,
	"xss_attempt":            30,
	"session_hijack":         35,
	"privilege_escalation":   40,
	"data_exfiltration":      45,
	"unusual_access_pattern": 20,
	"rate_limit_exceeded":    15,
	"authentication_failure": 10,
	"suspicious_user_agent":  15,
	"geo_anomaly":            25,
}

// analyzeThreatLevel converts event into a threat level by summing a score.
// The score has these components:
//
//   - Severity: critical 40, high 30, medium 20, low 10
//   - Type: the value from eventTypeThreatScores, if listed
//   - Details["user_history"]["incident_count"]: 5 per previous incident
//   - Details["ip_reputation"]: +25 if is_malicious, +10 if is_proxy
//   - Details["geolocation"]["country"]: +15 if sm.isHighRiskCountry reports it
//
// The total maps to "critical" (>= 70), "high" (>= 50), "medium" (>= 30),
// "low" (>= 15) or otherwise "minimal". ctx is currently unused.
func (sm *SecurityMonitor) analyzeThreatLevel(ctx context.Context, event *SecurityEvent) string {
	score := 0

	switch event.Severity {
	case "critical":
		score += 40
	case "high":
		score += 30
	case "medium":
		score += 20
	case "low":
		score += 10
	}

	// A missing key yields 0, so unknown event types add nothing.
	score += eventTypeThreatScores[event.Type]

	if history, ok := event.Details["user_history"].(map[string]interface{}); ok {
		// encoding/json decodes numbers into interface{} as float64, so
		// accept that alongside the integer types.
		switch count := history["incident_count"].(type) {
		case int:
			score += count * 5
		case int64:
			score += int(count) * 5
		case float64:
			score += int(count) * 5
		}
	}

	if reputation, ok := event.Details["ip_reputation"].(map[string]interface{}); ok {
		if malicious, ok := reputation["is_malicious"].(bool); ok && malicious {
			score += 25
		}
		if proxy, ok := reputation["is_proxy"].(bool); ok && proxy {
			score += 10
		}
	}

	if geoData, ok := event.Details["geolocation"].(map[string]interface{}); ok {
		if country, ok := geoData["country"].(string); ok && sm.isHighRiskCountry(country) {
			score += 15
		}
	}

	switch {
	case score >= 70:
		return "critical"
	case score >= 50:
		return "high"
	case score >= 30:
		return "medium"
	case score >= 15:
		return "low"
	default:
		return "minimal"
	}
}
// executeAutomatedResponse runs each action in response.AutomatedActions, in
// order. User-scoped actions are skipped when event.UserID is empty, and
// unrecognised action names are ignored.
//
// Execution is best-effort: a failing action does not stop the remaining
// ones. Failures from the persisting actions (block_user, block_ip,
// rate_limit_user) are collected. The returned error is the single failure,
// or a summary wrapping the first one when several fail. nil means none of
// the checked actions failed.
//
// NOTE(review): the remaining actions are called only for their effect;
// whether they can report errors is not visible here.
func (sm *SecurityMonitor) executeAutomatedResponse(ctx context.Context, event *SecurityEvent, response *IncidentResponse) error {
	var failures []error
	record := func(action string, err error) {
		if err != nil {
			failures = append(failures, fmt.Errorf("%s: %w", action, err))
		}
	}

	for _, action := range response.AutomatedActions {
		switch action {
		case "block_user":
			if event.UserID != "" {
				record(action, sm.blockUser(ctx, event.UserID, response.IncidentID))
			}
		case "invalidate_sessions":
			if event.UserID != "" {
				sm.invalidateUserSessions(ctx, event.UserID)
			}
		case "block_ip":
			record(action, sm.blockIP(ctx, event.IPAddress, 24*time.Hour)) // 24-hour block
		case "rate_limit_user":
			if event.UserID != "" {
				record(action, sm.applyRateLimit(ctx, event.UserID, time.Hour)) // 1-hour rate limit
			}
		case "require_additional_auth":
			if event.UserID != "" {
				sm.requireAdditionalAuth(ctx, event.UserID)
			}
		case "notify_security_team":
			sm.notifySecurityTeam(event, response)
		case "alert_security_team":
			sm.alertSecurityTeam(event, response)
		case "increase_logging":
			sm.increaseLoggingLevel(event.UserID, event.IPAddress)
		case "flag_for_review":
			sm.flagForManualReview(event, response)
		}
	}

	switch len(failures) {
	case 0:
		return nil
	case 1:
		return failures[0]
	default:
		return fmt.Errorf("%d automated actions failed, first: %w", len(failures), failures[0])
	}
}
// blockUser blocks userID for 24 hours by writing a JSON record to the Redis
// key "blocked_users:<userID>" with a matching TTL.
//
// The record holds the block start and end as Unix seconds (the same shape
// blockIP writes), the triggering incident ID and the reason. An empty
// userID is rejected rather than writing the catch-all key "blocked_users:".
func (sm *SecurityMonitor) blockUser(ctx context.Context, userID, incidentID string) error {
	const blockDuration = 24 * time.Hour

	if userID == "" {
		return fmt.Errorf("blocking user: empty user ID")
	}

	now := time.Now() // single clock read keeps both timestamps consistent
	blockData := map[string]interface{}{
		"blocked_at":  now.Unix(),
		"expires_at":  now.Add(blockDuration).Unix(),
		"incident_id": incidentID,
		"reason":      "automated_security_response",
	}
	blockJSON, err := json.Marshal(blockData)
	if err != nil {
		return fmt.Errorf("encoding block record for user %q: %w", userID, err)
	}
	key := fmt.Sprintf("blocked_users:%s", userID)
	return sm.redis.Set(ctx, key, blockJSON, blockDuration).Err()
}
// blockIP blocks ipAddress for duration by writing a JSON record to the Redis
// key "blocked_ips:<ipAddress>" with a matching TTL. The record holds the
// block start and end as Unix seconds, plus the reason.
//
// duration must be positive. Assuming sm.redis is a go-redis client, an
// expiration of 0 means "no expiry", so a zero (or negative) duration would
// silently turn a temporary block into a permanent one.
func (sm *SecurityMonitor) blockIP(ctx context.Context, ipAddress string, duration time.Duration) error {
	if duration <= 0 {
		return fmt.Errorf("blocking IP %q: duration must be positive, got %v", ipAddress, duration)
	}

	now := time.Now() // single clock read keeps both timestamps consistent
	blockData := map[string]interface{}{
		"blocked_at": now.Unix(),
		"expires_at": now.Add(duration).Unix(),
		"reason":     "automated_security_response",
	}
	blockJSON, err := json.Marshal(blockData)
	if err != nil {
		return fmt.Errorf("encoding block record for IP %q: %w", ipAddress, err)
	}
	key := fmt.Sprintf("blocked_ips:%s", ipAddress)
	return sm.redis.Set(ctx, key, blockJSON, duration).Err()
}
// applyRateLimit stores a strict rate-limit override for userID under the
// Redis key "rate_limit_strict:<userID>" for duration. The stored policy is
// 5 requests per 60-second window, together with when it was applied
// (Unix seconds).
// NOTE(review): the component that enforces this record is not visible here.
//
// duration must be positive. Assuming sm.redis is a go-redis client, an
// expiration of 0 means "no expiry", so a zero (or negative) duration would
// silently make the restriction permanent.
func (sm *SecurityMonitor) applyRateLimit(ctx context.Context, userID string, duration time.Duration) error {
	if duration <= 0 {
		return fmt.Errorf("rate limiting user %q: duration must be positive, got %v", userID, duration)
	}

	limitData := map[string]interface{}{
		"limit":      5,  // requests allowed per window
		"window":     60, // window length in seconds (1 minute)
		"applied_at": time.Now().Unix(),
	}
	limitJSON, err := json.Marshal(limitData)
	if err != nil {
		return fmt.Errorf("encoding rate limit for user %q: %w", userID, err)
	}
	key := fmt.Sprintf("rate_limit_strict:%s", userID)
	return sm.redis.Set(ctx, key, limitJSON, duration).Err()
}
// notifySecurityTeam sends a "security_incident" notification, always tagged
// with severity "critical", to three channels in order:
//   - Slack (#security-incidents)
//   - PagerDuty
//   - email (security@knowfoolery.com)
//
// The payload carries the incident ID and the event's type, user, IP,
// timestamp and details. Results from the alert manager are not inspected.
func (sm *SecurityMonitor) notifySecurityTeam(event *SecurityEvent, response *IncidentResponse) {
	payload := make(map[string]interface{}, 8)
	payload["type"] = "security_incident"
	payload["severity"] = "critical"
	payload["incident_id"] = response.IncidentID
	payload["event_type"] = event.Type
	payload["user_id"] = event.UserID
	payload["ip_address"] = event.IPAddress
	payload["timestamp"] = event.Timestamp
	payload["details"] = event.Details

	alerts := sm.alertManager
	alerts.SendSlackAlert("#security-incidents", payload)
	alerts.SendPagerDutyAlert(payload)
	alerts.SendEmailAlert("security@knowfoolery.com", payload)
}
// Security health checks
// PerformSecurityHealthCheck runs each security-subsystem probe in a fixed
// order: auth system, rate limiting, encryption, monitoring, blocked
// entities. It collects the results keyed by subsystem name and derives an
// overall verdict from them. The timestamp is taken before any probe runs.
// The returned error is currently always nil.
func (sm *SecurityMonitor) PerformSecurityHealthCheck(ctx context.Context) (*SecurityHealthStatus, error) {
	startedAt := time.Now()

	// Calls inside a composite literal are evaluated left to right, so the
	// probes run in the order listed.
	results := map[string]CheckResult{
		"auth_system":      sm.checkAuthSystemHealth(ctx),
		"rate_limiting":    sm.checkRateLimitingHealth(ctx),
		"encryption":       sm.checkEncryptionHealth(ctx),
		"monitoring":       sm.checkMonitoringHealth(ctx),
		"blocked_entities": sm.checkBlockedEntitiesHealth(ctx),
	}

	return &SecurityHealthStatus{
		Timestamp:     startedAt,
		Checks:        results,
		OverallHealth: sm.calculateOverallHealth(results),
	}, nil
}
// SecurityHealthStatus is the report produced by
// SecurityMonitor.PerformSecurityHealthCheck.
type SecurityHealthStatus struct {
// Timestamp is taken when the health check starts, before any probe runs.
Timestamp time.Time `json:"timestamp"`
// OverallHealth is the verdict calculateOverallHealth derives from Checks.
OverallHealth string `json:"overall_health"`
// Checks holds one result per subsystem, keyed "auth_system",
// "rate_limiting", "encryption", "monitoring" and "blocked_entities".
Checks map[string]CheckResult `json:"checks"`
}
// CheckResult is the outcome of one subsystem probe in a security health
// check.
type CheckResult struct {
// Status is expected to be one of the values listed on the right.
Status string `json:"status"` // healthy, degraded, unhealthy
// Message is presumably a human-readable summary of the outcome.
Message string `json:"message"`
// Details holds optional probe-specific data; it is omitted from JSON when
// nil or empty.
Details map[string]interface{} `json:"details,omitempty"`
}
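Taken together, these measures give Know Foolery layered protection against common threats while preserving usability and performance. The layers are proactive monitoring, automated incident response, and detailed audit trails for compliance and forensic analysis. The snippets are reference guidance rather than a complete drop-in implementation:
- Several helpers they call still need to be implemented, such as the geolocation and IP-reputation lookups, session invalidation, and the individual health checks.
- The monitoring snippet also needs the Redis client package that provides `redis.Client` in its import list.
- The resulting protection should be confirmed through testing and security review.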