You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

430 lines
13 KiB
Go

package middleware
// middleware_test.go contains backend tests for package behavior, error paths, and regressions.
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/gofiber/fiber/v3"
redisv9 "github.com/redis/go-redis/v9"
gconfig "knowfoolery/backend/services/gateway-service/internal/infra/config"
"knowfoolery/backend/shared/infra/auth/zitadel"
"knowfoolery/backend/shared/infra/observability/logging"
)
// TestCORSMiddleware verifies three CORS scenarios against an allow-list of one origin:
// a request without an Origin header receives no CORS headers, a preflight from an
// origin outside the allow-list is answered with 403, and a preflight from the allowed
// origin is answered with 204 plus the configured Allow-Origin and Allow-Credentials headers.
func TestCORSMiddleware(t *testing.T) {
	const allowedOrigin = "http://localhost:5173"

	app := fiber.New()
	app.Use(CORS(gconfig.CORSConfig{
		AllowedOrigins:   []string{allowedOrigin},
		AllowedMethods:   "GET,POST,OPTIONS",
		AllowedHeaders:   "Content-Type,Authorization",
		AllowCredentials: true,
		MaxAgeSeconds:    120,
	}))
	app.Get("/ok", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) })

	// send issues a request to /ok, setting Origin only when non-empty, and
	// schedules the response body to be closed when the test ends.
	send := func(method, origin string) *http.Response {
		t.Helper()
		req := httptest.NewRequest(method, "/ok", nil)
		if origin != "" {
			req.Header.Set("Origin", origin)
		}
		res, err := app.Test(req)
		if err != nil {
			t.Fatalf("app.Test: %v", err)
		}
		t.Cleanup(func() { _ = res.Body.Close() })
		return res
	}

	// Without an Origin header the middleware must stay silent.
	if got := send(http.MethodGet, "").Header.Get("Access-Control-Allow-Origin"); got != "" {
		t.Fatalf("expected no CORS header without origin")
	}

	// A preflight from an origin outside the allow-list is refused.
	if denied := send(http.MethodOptions, "http://evil.local"); denied.StatusCode != http.StatusForbidden {
		t.Fatalf("expected forbidden preflight for disallowed origin, got %d", denied.StatusCode)
	}

	// A preflight from the allowed origin succeeds and echoes that origin.
	allowed := send(http.MethodOptions, allowedOrigin)
	switch {
	case allowed.StatusCode != http.StatusNoContent:
		t.Fatalf("expected no content preflight, got %d", allowed.StatusCode)
	case allowed.Header.Get("Access-Control-Allow-Origin") != allowedOrigin:
		t.Fatalf("unexpected allow origin: %s", allowed.Header.Get("Access-Control-Allow-Origin"))
	case allowed.Header.Get("Access-Control-Allow-Credentials") != "true":
		t.Fatalf("missing allow credentials")
	}
}
// TestCORSAllowAll ensures cors allow all behavior is handled correctly.
func TestCORSAllowAll(t *testing.T) {
app := fiber.New()
app.Use(CORS(gconfig.CORSConfig{AllowedOrigins: []string{"*"}}))
app.Get("/ok", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) })
req := httptest.NewRequest(http.MethodGet, "/ok", nil)
req.Header.Set("Origin", "http://anything.local")
res, err := app.Test(req)
if err != nil {
t.Fatalf("app.Test: %v", err)
}
defer res.Body.Close()
if res.Header.Get("Access-Control-Allow-Origin") != "*" {
t.Fatalf("expected wildcard origin header")
}
}
// TestSecurityHeadersMiddleware ensures security headers middleware behavior is handled correctly.
func TestSecurityHeadersMiddleware(t *testing.T) {
app := fiber.New()
app.Use(SecurityHeaders(gconfig.SecurityHeadersConfig{
ContentSecurityPolicy: "default-src 'self'",
EnableHSTS: true,
HSTSMaxAge: 31536000,
FrameOptions: "DENY",
ContentTypeOptions: true,
ReferrerPolicy: "strict-origin-when-cross-origin",
PermissionsPolicy: "geolocation=()",
}))
app.Get("/ok", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) })
req := httptest.NewRequest(http.MethodGet, "/ok", nil)
req.Header.Set("X-Forwarded-Proto", "https")
res, err := app.Test(req)
if err != nil {
t.Fatalf("app.Test: %v", err)
}
defer res.Body.Close()
if res.Header.Get("Content-Security-Policy") == "" {
t.Fatalf("missing CSP header")
}
if res.Header.Get("X-Frame-Options") != "DENY" {
t.Fatalf("missing frame options")
}
if res.Header.Get("X-Content-Type-Options") != "nosniff" {
t.Fatalf("missing nosniff")
}
if res.Header.Get("Strict-Transport-Security") == "" {
t.Fatalf("expected HSTS header for https request")
}
}
// TestSecurityHeadersNoHSTSOnHTTP ensures security headers no hsts on http behavior is handled correctly.
func TestSecurityHeadersNoHSTSOnHTTP(t *testing.T) {
app := fiber.New()
app.Use(SecurityHeaders(gconfig.SecurityHeadersConfig{
EnableHSTS: true,
HSTSMaxAge: 31536000,
}))
app.Get("/ok", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) })
req := httptest.NewRequest(http.MethodGet, "/ok", nil)
res, err := app.Test(req)
if err != nil {
t.Fatalf("app.Test: %v", err)
}
defer res.Body.Close()
if res.Header.Get("Strict-Transport-Security") != "" {
t.Fatalf("did not expect HSTS header on non-https request")
}
}
// TestRequestContextMiddlewareAndRequestID verifies two behaviors of RequestContext:
// it generates an X-Request-ID that handlers can read through RequestID, and it
// propagates a caller-supplied X-Request-ID back on the response unchanged.
func TestRequestContextMiddlewareAndRequestID(t *testing.T) {
	app := fiber.New()
	app.Use(RequestContext(nil))
	app.Get("/id", func(c fiber.Ctx) error {
		// The handler reports 500 when no request id is visible in the context.
		status := http.StatusOK
		if RequestID(c) == "" {
			status = http.StatusInternalServerError
		}
		return c.SendStatus(status)
	})

	// get calls /id, setting X-Request-ID only when requestID is non-empty, and
	// schedules the response body to be closed when the test ends.
	get := func(requestID string) *http.Response {
		t.Helper()
		req := httptest.NewRequest(http.MethodGet, "/id", nil)
		if requestID != "" {
			req.Header.Set("X-Request-ID", requestID)
		}
		res, err := app.Test(req)
		if err != nil {
			t.Fatalf("app.Test: %v", err)
		}
		t.Cleanup(func() { _ = res.Body.Close() })
		return res
	}

	generated := get("")
	if generated.StatusCode != http.StatusOK {
		t.Fatalf("expected 200, got %d", generated.StatusCode)
	}
	if generated.Header.Get("X-Request-ID") == "" {
		t.Fatalf("expected generated request id")
	}

	if echoed := get("req-123").Header.Get("X-Request-ID"); echoed != "req-123" {
		t.Fatalf("expected propagated request id")
	}
}
// TestRequestContextWithLoggerAndInvalidRequestIDLocal ensures request context with logger and invalid request id local behavior is handled correctly.
func TestRequestContextWithLoggerAndInvalidRequestIDLocal(t *testing.T) {
logger := logging.NewLogger(logging.DefaultConfig())
app := fiber.New()
app.Use(RequestContext(logger))
app.Get("/id", func(c fiber.Ctx) error {
c.Locals(requestIDKey, 123)
if RequestID(c) != "" {
return c.SendStatus(http.StatusInternalServerError)
}
return c.SendStatus(http.StatusOK)
})
req := httptest.NewRequest(http.MethodGet, "/id", nil)
res, err := app.Test(req)
if err != nil {
t.Fatalf("app.Test: %v", err)
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", res.StatusCode)
}
}
// TestRateLimitMiddlewareDegradedModeAndHelpers ensures rate limit middleware degraded mode and helpers behavior is handled correctly.
func TestRateLimitMiddlewareDegradedModeAndHelpers(t *testing.T) {
app := fiber.New()
mw := RateLimitMiddleware(nil, gconfig.RateLimitConfig{
GeneralRequests: 7,
APIRequests: 7,
}, "/api/v1", nil)
app.Use(mw)
app.Get("/api/v1/x", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) })
req := httptest.NewRequest(http.MethodGet, "/api/v1/x", nil)
res, err := app.Test(req)
if err != nil {
t.Fatalf("app.Test: %v", err)
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", res.StatusCode)
}
if res.Header.Get("X-RateLimit-Policy") != "degraded" {
t.Fatalf("expected degraded policy")
}
if res.Header.Get("X-RateLimit-Limit") != "7" {
t.Fatalf("expected limit header")
}
app = fiber.New()
app.Use(mw)
app.Options("/api/v1/x", func(c fiber.Ctx) error { return c.SendStatus(http.StatusNoContent) })
req = httptest.NewRequest(http.MethodOptions, "/api/v1/x", nil)
res, err = app.Test(req)
if err != nil {
t.Fatalf("app.Test: %v", err)
}
defer res.Body.Close()
if res.StatusCode != http.StatusNoContent {
t.Fatalf("expected options pass-through, got %d", res.StatusCode)
}
tiers := map[string]rateTier{
"general": {Name: "general"},
"auth": {Name: "auth"},
"api": {Name: "api"},
"admin": {Name: "admin"},
}
if got := selectTier("/api/v1/admin/auth", "/api/v1", tiers); got.Name != "auth" {
t.Fatalf("expected auth tier, got %s", got.Name)
}
if got := selectTier("/api/v1/admin/x", "/api/v1", tiers); got.Name != "admin" {
t.Fatalf("expected admin tier, got %s", got.Name)
}
if got := selectTier("/api/v1/x", "/api/v1", tiers); got.Name != "api" {
t.Fatalf("expected api tier, got %s", got.Name)
}
if got := selectTier("/x", "/api/v1", tiers); got.Name != "general" {
t.Fatalf("expected general tier, got %s", got.Name)
}
if v, ok := toInt64(int64(1)); !ok || v != 1 {
t.Fatalf("toInt64 int64 failed")
}
if v, ok := toInt64(int(2)); !ok || v != 2 {
t.Fatalf("toInt64 int failed")
}
if v, ok := toInt64("3"); !ok || v != 3 {
t.Fatalf("toInt64 string failed")
}
if _, ok := toInt64("bad"); ok {
t.Fatalf("expected invalid string parse failure")
}
if _, ok := toInt64(struct{}{}); ok {
t.Fatalf("expected unsupported type failure")
}
if allow, count, reset, ok := parseLimiterResult([]interface{}{int64(1), int64(2),
int64(3)}); !ok || !allow || count != 2 || reset != 3 {
t.Fatalf("unexpected parseLimiterResult output")
}
if _, _, _, ok := parseLimiterResult("bad"); ok {
t.Fatalf("expected parse failure for non-slice")
}
if _, _, _, ok := parseLimiterResult([]interface{}{1, 2}); ok {
t.Fatalf("expected parse failure for short slice")
}
if _, _, _, ok := parseLimiterResult([]interface{}{1, 2, struct{}{}}); ok {
t.Fatalf("expected parse failure for invalid reset")
}
if d := positiveDuration(0, time.Second); d != time.Second {
t.Fatalf("expected fallback duration")
}
if d := positiveDuration(2*time.Second, time.Second); d != 2*time.Second {
t.Fatalf("expected provided duration")
}
if m := maxInt(0, 9); m != 9 {
t.Fatalf("expected fallback int")
}
if m := maxInt(10, 9); m != 10 {
t.Fatalf("expected provided int")
}
}
// TestRateLimitMiddlewareRedisAllowedAndBlocked verifies Redis-backed limiting
// against an in-memory miniredis server, with an API budget of one request per
// minute. The first request passes and carries X-RateLimit-Limit: 1; the second
// is rejected with 429 and X-RateLimit-Remaining: 0. The test is skipped if
// miniredis cannot start.
func TestRateLimitMiddlewareRedisAllowedAndBlocked(t *testing.T) {
	mr, err := miniredis.Run()
	if err != nil {
		t.Skipf("miniredis unavailable in this environment: %v", err)
	}
	t.Cleanup(func() { mr.Close() })

	client := redisv9.NewClient(&redisv9.Options{Addr: mr.Addr()})
	t.Cleanup(func() { _ = client.Close() })

	limits := gconfig.RateLimitConfig{APIRequests: 1, Window: time.Minute}

	app := fiber.New()
	// Attach a fixed user id so both requests come from the same requester.
	app.Use(func(c fiber.Ctx) error {
		c.Locals(string(zitadel.ContextKeyUserID), "user-1")
		return c.Next()
	})
	app.Use(RateLimitMiddleware(client, limits, "/api/v1", nil))
	app.Get("/api/v1/x", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) })

	// hit performs one GET /api/v1/x; on success, the body is closed at test end.
	hit := func() (*http.Response, error) {
		res, err := app.Test(httptest.NewRequest(http.MethodGet, "/api/v1/x", nil))
		if err == nil {
			t.Cleanup(func() { _ = res.Body.Close() })
		}
		return res, err
	}

	first, err := hit()
	if err != nil {
		t.Fatalf("app.Test first request: %v", err)
	}
	switch {
	case first.StatusCode != http.StatusOK:
		t.Fatalf("expected first request to pass, got %d", first.StatusCode)
	case first.Header.Get("X-RateLimit-Limit") != "1":
		t.Fatalf("expected rate-limit headers")
	}

	second, err := hit()
	if err != nil {
		t.Fatalf("app.Test second request: %v", err)
	}
	switch {
	case second.StatusCode != http.StatusTooManyRequests:
		t.Fatalf("expected second request to be throttled, got %d", second.StatusCode)
	case second.Header.Get("X-RateLimit-Remaining") != "0":
		t.Fatalf("expected remaining=0 when throttled")
	}
}
// TestRateLimitMiddlewareRedisErrorDegrades verifies fail-open behavior when Redis
// cannot be reached. The request is still served with 200 and the response carries
// X-RateLimit-Policy: degraded.
func TestRateLimitMiddlewareRedisErrorDegrades(t *testing.T) {
	// Point the client at an address nothing listens on so the limiter's Redis
	// call (script.Run) fails. Retries are disabled and the dial timeout is
	// bounded, so the error surfaces well within app.Test's default timeout,
	// even on hosts that silently drop the connection attempt instead of
	// refusing it. With the go-redis defaults (5s dial timeout, 3 retries),
	// app.Test itself could time out first.
	client := redisv9.NewClient(&redisv9.Options{
		Addr:        "127.0.0.1:1",
		DialTimeout: 200 * time.Millisecond,
		MaxRetries:  -1, // -1 disables retries in go-redis; 0 would mean "use the default".
	})
	t.Cleanup(func() { _ = client.Close() })

	app := fiber.New()
	app.Use(RateLimitMiddleware(
		client,
		gconfig.RateLimitConfig{
			APIRequests: 5,
			Window:      time.Minute,
		},
		"/api/v1",
		nil,
	))
	app.Get("/api/v1/x", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) })

	res, err := app.Test(httptest.NewRequest(http.MethodGet, "/api/v1/x", nil))
	if err != nil {
		t.Fatalf("app.Test: %v", err)
	}
	defer res.Body.Close()

	if res.StatusCode != http.StatusOK {
		t.Fatalf("expected degraded mode to fail-open with 200, got %d", res.StatusCode)
	}
	if got := res.Header.Get("X-RateLimit-Policy"); got != "degraded" {
		t.Fatalf("expected degraded policy header, got %q", got)
	}
}
// TestIdentifyRequester verifies the two identity forms returned by identifyRequester.
// When a user id is in the context locals, the result is "user:<id>"; otherwise it
// is an "ip:"-prefixed client address. The handler answers 500 with the computed
// id in the body when the result does not match the expected form.
func TestIdentifyRequester(t *testing.T) {
	app := fiber.New()
	app.Get("/id", func(c fiber.Ctx) error {
		// ?u=1 simulates an authenticated request by planting a user id.
		asUser := c.Query("u") == "1"
		if asUser {
			c.Locals(string(zitadel.ContextKeyUserID), "user-1")
		}

		id := identifyRequester(c)

		var matches bool
		if asUser {
			matches = id == "user:user-1"
		} else {
			matches = strings.HasPrefix(id, "ip:")
		}
		if !matches {
			return c.Status(http.StatusInternalServerError).SendString(id)
		}
		return c.SendStatus(http.StatusOK)
	})

	scenarios := []struct {
		target  string
		failMsg string
	}{
		{"/id?u=1", "expected 200 for user requester"},
		{"/id", "expected 200 for ip requester"},
	}
	for _, sc := range scenarios {
		res, err := app.Test(httptest.NewRequest(http.MethodGet, sc.target, nil))
		if err != nil {
			t.Fatalf("app.Test: %v", err)
		}
		t.Cleanup(func() { _ = res.Body.Close() })

		if res.StatusCode != http.StatusOK {
			t.Fatal(sc.failMsg)
		}
	}
}