package middleware

// middleware_test.go contains backend tests for package behavior, error paths, and regressions.

import (
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/alicebob/miniredis/v2"
	"github.com/gofiber/fiber/v3"
	redisv9 "github.com/redis/go-redis/v9"

	gconfig "knowfoolery/backend/services/gateway-service/internal/infra/config"
	"knowfoolery/backend/shared/infra/auth/zitadel"
	"knowfoolery/backend/shared/infra/observability/logging"
)

// TestCORSMiddleware exercises the CORS middleware with a single allowed origin.
// It checks three scenarios in order: a GET without an Origin header receives no
// allow-origin header, a preflight from an unlisted origin is answered with 403,
// and a preflight from the listed origin is answered with 204 together with the
// allow-origin and allow-credentials headers.
func TestCORSMiddleware(t *testing.T) {
	corsCfg := gconfig.CORSConfig{
		AllowedOrigins:   []string{"http://localhost:5173"},
		AllowedMethods:   "GET,POST,OPTIONS",
		AllowedHeaders:   "Content-Type,Authorization",
		AllowCredentials: true,
		MaxAgeSeconds:    120,
	}
	okHandler := func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) }

	app := fiber.New()
	app.Use(CORS(corsCfg))
	app.Get("/ok", okHandler)

	// Scenario 1: plain GET that carries no Origin header.
	plainReq := httptest.NewRequest(http.MethodGet, "/ok", nil)
	plainRes, err := app.Test(plainReq)
	if err != nil {
		t.Fatalf("app.Test: %v", err)
	}
	defer plainRes.Body.Close()
	if origin := plainRes.Header.Get("Access-Control-Allow-Origin"); origin != "" {
		t.Fatalf("expected no CORS header without origin")
	}

	// Scenario 2: preflight from an origin that is not on the allow-list.
	deniedReq := httptest.NewRequest(http.MethodOptions, "/ok", nil)
	deniedReq.Header.Set("Origin", "http://evil.local")
	deniedRes, err := app.Test(deniedReq)
	if err != nil {
		t.Fatalf("app.Test: %v", err)
	}
	defer deniedRes.Body.Close()
	if status := deniedRes.StatusCode; status != http.StatusForbidden {
		t.Fatalf("expected forbidden preflight for disallowed origin, got %d", status)
	}

	// Scenario 3: preflight from the allowed origin.
	allowedReq := httptest.NewRequest(http.MethodOptions, "/ok", nil)
	allowedReq.Header.Set("Origin", "http://localhost:5173")
	allowedRes, err := app.Test(allowedReq)
	if err != nil {
		t.Fatalf("app.Test: %v", err)
	}
	defer allowedRes.Body.Close()
	if status := allowedRes.StatusCode; status != http.StatusNoContent {
		t.Fatalf("expected no content preflight, got %d", status)
	}
	if origin := allowedRes.Header.Get("Access-Control-Allow-Origin"); origin != "http://localhost:5173" {
		t.Fatalf("unexpected allow origin: %s", origin)
	}
	if creds := allowedRes.Header.Get("Access-Control-Allow-Credentials"); creds != "true" {
		t.Fatalf("missing allow credentials")
	}
}

// TestCORSAllowAll checks that a wildcard entry in AllowedOrigins makes the
// middleware answer a request from an arbitrary origin with a "*" allow-origin
// header.
func TestCORSAllowAll(t *testing.T) {
	wildcardCfg := gconfig.CORSConfig{AllowedOrigins: []string{"*"}}

	app := fiber.New()
	app.Use(CORS(wildcardCfg))
	app.Get("/ok", func(c fiber.Ctx) error {
		return c.SendStatus(http.StatusOK)
	})

	request := httptest.NewRequest(http.MethodGet, "/ok", nil)
	request.Header.Set("Origin", "http://anything.local")

	response, err := app.Test(request)
	if err != nil {
		t.Fatalf("app.Test: %v", err)
	}
	defer response.Body.Close()

	if origin := response.Header.Get("Access-Control-Allow-Origin"); origin != "*" {
		t.Fatalf("expected wildcard origin header")
	}
}

// TestSecurityHeadersMiddleware checks that, for a request marked as HTTPS via
// X-Forwarded-Proto, the middleware emits a Content-Security-Policy header, the
// configured X-Frame-Options value, X-Content-Type-Options: nosniff, and a
// Strict-Transport-Security header.
func TestSecurityHeadersMiddleware(t *testing.T) {
	headersCfg := gconfig.SecurityHeadersConfig{
		ContentSecurityPolicy: "default-src 'self'",
		EnableHSTS:            true,
		HSTSMaxAge:            31536000,
		FrameOptions:          "DENY",
		ContentTypeOptions:    true,
		ReferrerPolicy:        "strict-origin-when-cross-origin",
		PermissionsPolicy:     "geolocation=()",
	}

	app := fiber.New()
	app.Use(SecurityHeaders(headersCfg))
	app.Get("/ok", func(c fiber.Ctx) error {
		return c.SendStatus(http.StatusOK)
	})

	request := httptest.NewRequest(http.MethodGet, "/ok", nil)
	request.Header.Set("X-Forwarded-Proto", "https")

	response, err := app.Test(request)
	if err != nil {
		t.Fatalf("app.Test: %v", err)
	}
	defer response.Body.Close()

	// The cases are evaluated top to bottom and the first failing check aborts
	// the test, which matches a sequence of independent fatal assertions.
	got := response.Header
	switch {
	case got.Get("Content-Security-Policy") == "":
		t.Fatalf("missing CSP header")
	case got.Get("X-Frame-Options") != "DENY":
		t.Fatalf("missing frame options")
	case got.Get("X-Content-Type-Options") != "nosniff":
		t.Fatalf("missing nosniff")
	case got.Get("Strict-Transport-Security") == "":
		t.Fatalf("expected HSTS header for https request")
	}
}

// TestSecurityHeadersNoHSTSOnHTTP checks that no Strict-Transport-Security header
// is emitted for a request sent without X-Forwarded-Proto, even though HSTS is
// enabled in the configuration.
func TestSecurityHeadersNoHSTSOnHTTP(t *testing.T) { app := fiber.New() app.Use(SecurityHeaders(gconfig.SecurityHeadersConfig{ EnableHSTS: true, HSTSMaxAge: 31536000, })) app.Get("/ok", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) }) req := httptest.NewRequest(http.MethodGet, "/ok", nil) res, err := app.Test(req) if err != nil { t.Fatalf("app.Test: %v", err) } defer res.Body.Close() if res.Header.Get("Strict-Transport-Security") != "" { t.Fatalf("did not expect HSTS header on non-https request") } } // TestRequestContextMiddlewareAndRequestID ensures request context middleware and // request id behavior is handled correctly. func TestRequestContextMiddlewareAndRequestID(t *testing.T) { app := fiber.New() app.Use(RequestContext(nil)) app.Get("/id", func(c fiber.Ctx) error { if RequestID(c) == "" { return c.SendStatus(http.StatusInternalServerError) } return c.SendStatus(http.StatusOK) }) req := httptest.NewRequest(http.MethodGet, "/id", nil) res, err := app.Test(req) if err != nil { t.Fatalf("app.Test: %v", err) } defer res.Body.Close() if res.StatusCode != http.StatusOK { t.Fatalf("expected 200, got %d", res.StatusCode) } if res.Header.Get("X-Request-ID") == "" { t.Fatalf("expected generated request id") } req = httptest.NewRequest(http.MethodGet, "/id", nil) req.Header.Set("X-Request-ID", "req-123") res, err = app.Test(req) if err != nil { t.Fatalf("app.Test: %v", err) } defer res.Body.Close() if res.Header.Get("X-Request-ID") != "req-123" { t.Fatalf("expected propagated request id") } } // TestRequestContextWithLoggerAndInvalidRequestIDLocal ensures request context with // logger and invalid request id local behavior is handled correctly. 
func TestRequestContextWithLoggerAndInvalidRequestIDLocal(t *testing.T) { logger := logging.NewLogger(logging.DefaultConfig()) app := fiber.New() app.Use(RequestContext(logger)) app.Get("/id", func(c fiber.Ctx) error { c.Locals(requestIDKey, 123) if RequestID(c) != "" { return c.SendStatus(http.StatusInternalServerError) } return c.SendStatus(http.StatusOK) }) req := httptest.NewRequest(http.MethodGet, "/id", nil) res, err := app.Test(req) if err != nil { t.Fatalf("app.Test: %v", err) } defer res.Body.Close() if res.StatusCode != http.StatusOK { t.Fatalf("expected 200, got %d", res.StatusCode) } } // TestRateLimitMiddlewareDegradedModeAndHelpers ensures rate limit middleware degraded // mode and helpers behavior is handled correctly. func TestRateLimitMiddlewareDegradedModeAndHelpers(t *testing.T) { app := fiber.New() mw := RateLimitMiddleware(nil, gconfig.RateLimitConfig{ GeneralRequests: 7, APIRequests: 7, }, "/api/v1", nil) app.Use(mw) app.Get("/api/v1/x", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) }) req := httptest.NewRequest(http.MethodGet, "/api/v1/x", nil) res, err := app.Test(req) if err != nil { t.Fatalf("app.Test: %v", err) } defer res.Body.Close() if res.StatusCode != http.StatusOK { t.Fatalf("expected 200, got %d", res.StatusCode) } if res.Header.Get("X-RateLimit-Policy") != "degraded" { t.Fatalf("expected degraded policy") } if res.Header.Get("X-RateLimit-Limit") != "7" { t.Fatalf("expected limit header") } app = fiber.New() app.Use(mw) app.Options("/api/v1/x", func(c fiber.Ctx) error { return c.SendStatus(http.StatusNoContent) }) req = httptest.NewRequest(http.MethodOptions, "/api/v1/x", nil) res, err = app.Test(req) if err != nil { t.Fatalf("app.Test: %v", err) } defer res.Body.Close() if res.StatusCode != http.StatusNoContent { t.Fatalf("expected options pass-through, got %d", res.StatusCode) } tiers := map[string]rateTier{ "general": {Name: "general"}, "auth": {Name: "auth"}, "api": {Name: "api"}, "admin": {Name: "admin"}, } if 
got := selectTier("/api/v1/admin/auth", "/api/v1", tiers); got.Name != "auth" { t.Fatalf("expected auth tier, got %s", got.Name) } if got := selectTier("/api/v1/admin/x", "/api/v1", tiers); got.Name != "admin" { t.Fatalf("expected admin tier, got %s", got.Name) } if got := selectTier("/api/v1/x", "/api/v1", tiers); got.Name != "api" { t.Fatalf("expected api tier, got %s", got.Name) } if got := selectTier("/x", "/api/v1", tiers); got.Name != "general" { t.Fatalf("expected general tier, got %s", got.Name) } if v, ok := toInt64(int64(1)); !ok || v != 1 { t.Fatalf("toInt64 int64 failed") } if v, ok := toInt64(int(2)); !ok || v != 2 { t.Fatalf("toInt64 int failed") } if v, ok := toInt64("3"); !ok || v != 3 { t.Fatalf("toInt64 string failed") } if _, ok := toInt64("bad"); ok { t.Fatalf("expected invalid string parse failure") } if _, ok := toInt64(struct{}{}); ok { t.Fatalf("expected unsupported type failure") } if allow, count, reset, ok := parseLimiterResult([]interface{}{int64(1), int64(2), int64(3)}); !ok || !allow || count != 2 || reset != 3 { t.Fatalf("unexpected parseLimiterResult output") } if _, _, _, ok := parseLimiterResult("bad"); ok { t.Fatalf("expected parse failure for non-slice") } if _, _, _, ok := parseLimiterResult([]interface{}{1, 2}); ok { t.Fatalf("expected parse failure for short slice") } if _, _, _, ok := parseLimiterResult([]interface{}{1, 2, struct{}{}}); ok { t.Fatalf("expected parse failure for invalid reset") } if d := positiveDuration(0, time.Second); d != time.Second { t.Fatalf("expected fallback duration") } if d := positiveDuration(2*time.Second, time.Second); d != 2*time.Second { t.Fatalf("expected provided duration") } if m := maxInt(0, 9); m != 9 { t.Fatalf("expected fallback int") } if m := maxInt(10, 9); m != 10 { t.Fatalf("expected provided int") } } // TestRateLimitMiddlewareRedisAllowedAndBlocked ensures rate limit middleware redis // allowed and blocked behavior is handled correctly. 
func TestRateLimitMiddlewareRedisAllowedAndBlocked(t *testing.T) { mr, err := miniredis.Run() if err != nil { t.Skipf("miniredis unavailable in this environment: %v", err) } t.Cleanup(func() { mr.Close() }) client := redisv9.NewClient(&redisv9.Options{Addr: mr.Addr()}) t.Cleanup(func() { _ = client.Close() }) app := fiber.New() app.Use(func(c fiber.Ctx) error { c.Locals(string(zitadel.ContextKeyUserID), "user-1") return c.Next() }) app.Use(RateLimitMiddleware( client, gconfig.RateLimitConfig{ APIRequests: 1, Window: time.Minute, }, "/api/v1", nil, )) app.Get("/api/v1/x", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) }) req := httptest.NewRequest(http.MethodGet, "/api/v1/x", nil) res, err := app.Test(req) if err != nil { t.Fatalf("app.Test first request: %v", err) } defer res.Body.Close() if res.StatusCode != http.StatusOK { t.Fatalf("expected first request to pass, got %d", res.StatusCode) } if res.Header.Get("X-RateLimit-Limit") != "1" { t.Fatalf("expected rate-limit headers") } req = httptest.NewRequest(http.MethodGet, "/api/v1/x", nil) res, err = app.Test(req) if err != nil { t.Fatalf("app.Test second request: %v", err) } defer res.Body.Close() if res.StatusCode != http.StatusTooManyRequests { t.Fatalf("expected second request to be throttled, got %d", res.StatusCode) } if res.Header.Get("X-RateLimit-Remaining") != "0" { t.Fatalf("expected remaining=0 when throttled") } } // TestRateLimitMiddlewareRedisErrorDegrades ensures rate limit middleware redis error // degrades behavior is handled correctly. func TestRateLimitMiddlewareRedisErrorDegrades(t *testing.T) { // Use a client pointing to an unreachable endpoint to force script.Run error. 
client := redisv9.NewClient(&redisv9.Options{Addr: "127.0.0.1:1"}) t.Cleanup(func() { _ = client.Close() }) app := fiber.New() app.Use(RateLimitMiddleware( client, gconfig.RateLimitConfig{ APIRequests: 5, Window: time.Minute, }, "/api/v1", nil, )) app.Get("/api/v1/x", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) }) req := httptest.NewRequest(http.MethodGet, "/api/v1/x", nil) res, err := app.Test(req) if err != nil { t.Fatalf("app.Test: %v", err) } defer res.Body.Close() if res.StatusCode != http.StatusOK { t.Fatalf("expected degraded mode to fail-open with 200, got %d", res.StatusCode) } if res.Header.Get("X-RateLimit-Policy") != "degraded" { t.Fatalf("expected degraded policy header") } } // TestIdentifyRequester ensures identify requester behavior is handled correctly. func TestIdentifyRequester(t *testing.T) { app := fiber.New() app.Get("/id", func(c fiber.Ctx) error { if c.Query("u") == "1" { c.Locals(string(zitadel.ContextKeyUserID), "user-1") } id := identifyRequester(c) if c.Query("u") == "1" && id != "user:user-1" { return c.Status(http.StatusInternalServerError).SendString(id) } if c.Query("u") != "1" && !strings.HasPrefix(id, "ip:") { return c.Status(http.StatusInternalServerError).SendString(id) } return c.SendStatus(http.StatusOK) }) req := httptest.NewRequest(http.MethodGet, "/id?u=1", nil) res, err := app.Test(req) if err != nil { t.Fatalf("app.Test: %v", err) } defer res.Body.Close() if res.StatusCode != http.StatusOK { t.Fatalf("expected 200 for user requester") } req = httptest.NewRequest(http.MethodGet, "/id", nil) res, err = app.Test(req) if err != nil { t.Fatalf("app.Test: %v", err) } defer res.Body.Close() if res.StatusCode != http.StatusOK { t.Fatalf("expected 200 for ip requester") } }