You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
430 lines
13 KiB
Go
430 lines
13 KiB
Go
package middleware
|
|
|
|
// middleware_test.go contains tests for the gateway HTTP middleware: CORS, security headers, request context / request IDs, and rate limiting.
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/alicebob/miniredis/v2"
|
|
"github.com/gofiber/fiber/v3"
|
|
redisv9 "github.com/redis/go-redis/v9"
|
|
|
|
gconfig "knowfoolery/backend/services/gateway-service/internal/infra/config"
|
|
"knowfoolery/backend/shared/infra/auth/zitadel"
|
|
"knowfoolery/backend/shared/infra/observability/logging"
|
|
)
|
|
|
|
// TestCORSMiddleware verifies expected behavior.
|
|
func TestCORSMiddleware(t *testing.T) {
|
|
app := fiber.New()
|
|
app.Use(CORS(gconfig.CORSConfig{
|
|
AllowedOrigins: []string{"http://localhost:5173"},
|
|
AllowedMethods: "GET,POST,OPTIONS",
|
|
AllowedHeaders: "Content-Type,Authorization",
|
|
AllowCredentials: true,
|
|
MaxAgeSeconds: 120,
|
|
}))
|
|
app.Get("/ok", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) })
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/ok", nil)
|
|
res, err := app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("app.Test: %v", err)
|
|
}
|
|
defer res.Body.Close()
|
|
if res.Header.Get("Access-Control-Allow-Origin") != "" {
|
|
t.Fatalf("expected no CORS header without origin")
|
|
}
|
|
|
|
req = httptest.NewRequest(http.MethodOptions, "/ok", nil)
|
|
req.Header.Set("Origin", "http://evil.local")
|
|
res, err = app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("app.Test: %v", err)
|
|
}
|
|
defer res.Body.Close()
|
|
if res.StatusCode != http.StatusForbidden {
|
|
t.Fatalf("expected forbidden preflight for disallowed origin, got %d", res.StatusCode)
|
|
}
|
|
|
|
req = httptest.NewRequest(http.MethodOptions, "/ok", nil)
|
|
req.Header.Set("Origin", "http://localhost:5173")
|
|
res, err = app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("app.Test: %v", err)
|
|
}
|
|
defer res.Body.Close()
|
|
if res.StatusCode != http.StatusNoContent {
|
|
t.Fatalf("expected no content preflight, got %d", res.StatusCode)
|
|
}
|
|
if res.Header.Get("Access-Control-Allow-Origin") != "http://localhost:5173" {
|
|
t.Fatalf("unexpected allow origin: %s", res.Header.Get("Access-Control-Allow-Origin"))
|
|
}
|
|
if res.Header.Get("Access-Control-Allow-Credentials") != "true" {
|
|
t.Fatalf("missing allow credentials")
|
|
}
|
|
}
|
|
|
|
// TestCORSAllowAll verifies expected behavior.
|
|
func TestCORSAllowAll(t *testing.T) {
|
|
app := fiber.New()
|
|
app.Use(CORS(gconfig.CORSConfig{AllowedOrigins: []string{"*"}}))
|
|
app.Get("/ok", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) })
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/ok", nil)
|
|
req.Header.Set("Origin", "http://anything.local")
|
|
res, err := app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("app.Test: %v", err)
|
|
}
|
|
defer res.Body.Close()
|
|
if res.Header.Get("Access-Control-Allow-Origin") != "*" {
|
|
t.Fatalf("expected wildcard origin header")
|
|
}
|
|
}
|
|
|
|
// TestSecurityHeadersMiddleware verifies expected behavior.
|
|
func TestSecurityHeadersMiddleware(t *testing.T) {
|
|
app := fiber.New()
|
|
app.Use(SecurityHeaders(gconfig.SecurityHeadersConfig{
|
|
ContentSecurityPolicy: "default-src 'self'",
|
|
EnableHSTS: true,
|
|
HSTSMaxAge: 31536000,
|
|
FrameOptions: "DENY",
|
|
ContentTypeOptions: true,
|
|
ReferrerPolicy: "strict-origin-when-cross-origin",
|
|
PermissionsPolicy: "geolocation=()",
|
|
}))
|
|
app.Get("/ok", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) })
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/ok", nil)
|
|
req.Header.Set("X-Forwarded-Proto", "https")
|
|
res, err := app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("app.Test: %v", err)
|
|
}
|
|
defer res.Body.Close()
|
|
|
|
if res.Header.Get("Content-Security-Policy") == "" {
|
|
t.Fatalf("missing CSP header")
|
|
}
|
|
if res.Header.Get("X-Frame-Options") != "DENY" {
|
|
t.Fatalf("missing frame options")
|
|
}
|
|
if res.Header.Get("X-Content-Type-Options") != "nosniff" {
|
|
t.Fatalf("missing nosniff")
|
|
}
|
|
if res.Header.Get("Strict-Transport-Security") == "" {
|
|
t.Fatalf("expected HSTS header for https request")
|
|
}
|
|
}
|
|
|
|
// TestSecurityHeadersNoHSTSOnHTTP verifies expected behavior.
|
|
func TestSecurityHeadersNoHSTSOnHTTP(t *testing.T) {
|
|
app := fiber.New()
|
|
app.Use(SecurityHeaders(gconfig.SecurityHeadersConfig{
|
|
EnableHSTS: true,
|
|
HSTSMaxAge: 31536000,
|
|
}))
|
|
app.Get("/ok", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) })
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/ok", nil)
|
|
res, err := app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("app.Test: %v", err)
|
|
}
|
|
defer res.Body.Close()
|
|
if res.Header.Get("Strict-Transport-Security") != "" {
|
|
t.Fatalf("did not expect HSTS header on non-https request")
|
|
}
|
|
}
|
|
|
|
// TestRequestContextMiddlewareAndRequestID verifies expected behavior.
|
|
func TestRequestContextMiddlewareAndRequestID(t *testing.T) {
|
|
app := fiber.New()
|
|
app.Use(RequestContext(nil))
|
|
app.Get("/id", func(c fiber.Ctx) error {
|
|
if RequestID(c) == "" {
|
|
return c.SendStatus(http.StatusInternalServerError)
|
|
}
|
|
return c.SendStatus(http.StatusOK)
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/id", nil)
|
|
res, err := app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("app.Test: %v", err)
|
|
}
|
|
defer res.Body.Close()
|
|
if res.StatusCode != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d", res.StatusCode)
|
|
}
|
|
if res.Header.Get("X-Request-ID") == "" {
|
|
t.Fatalf("expected generated request id")
|
|
}
|
|
|
|
req = httptest.NewRequest(http.MethodGet, "/id", nil)
|
|
req.Header.Set("X-Request-ID", "req-123")
|
|
res, err = app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("app.Test: %v", err)
|
|
}
|
|
defer res.Body.Close()
|
|
if res.Header.Get("X-Request-ID") != "req-123" {
|
|
t.Fatalf("expected propagated request id")
|
|
}
|
|
}
|
|
|
|
// TestRequestContextWithLoggerAndInvalidRequestIDLocal verifies expected behavior.
|
|
func TestRequestContextWithLoggerAndInvalidRequestIDLocal(t *testing.T) {
|
|
logger := logging.NewLogger(logging.DefaultConfig())
|
|
app := fiber.New()
|
|
app.Use(RequestContext(logger))
|
|
app.Get("/id", func(c fiber.Ctx) error {
|
|
c.Locals(requestIDKey, 123)
|
|
if RequestID(c) != "" {
|
|
return c.SendStatus(http.StatusInternalServerError)
|
|
}
|
|
return c.SendStatus(http.StatusOK)
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/id", nil)
|
|
res, err := app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("app.Test: %v", err)
|
|
}
|
|
defer res.Body.Close()
|
|
if res.StatusCode != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d", res.StatusCode)
|
|
}
|
|
}
|
|
|
|
// TestRateLimitMiddlewareDegradedModeAndHelpers verifies expected behavior.
|
|
func TestRateLimitMiddlewareDegradedModeAndHelpers(t *testing.T) {
|
|
app := fiber.New()
|
|
mw := RateLimitMiddleware(nil, gconfig.RateLimitConfig{
|
|
GeneralRequests: 7,
|
|
APIRequests: 7,
|
|
}, "/api/v1", nil)
|
|
app.Use(mw)
|
|
app.Get("/api/v1/x", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) })
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/x", nil)
|
|
res, err := app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("app.Test: %v", err)
|
|
}
|
|
defer res.Body.Close()
|
|
if res.StatusCode != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d", res.StatusCode)
|
|
}
|
|
if res.Header.Get("X-RateLimit-Policy") != "degraded" {
|
|
t.Fatalf("expected degraded policy")
|
|
}
|
|
if res.Header.Get("X-RateLimit-Limit") != "7" {
|
|
t.Fatalf("expected limit header")
|
|
}
|
|
|
|
app = fiber.New()
|
|
app.Use(mw)
|
|
app.Options("/api/v1/x", func(c fiber.Ctx) error { return c.SendStatus(http.StatusNoContent) })
|
|
req = httptest.NewRequest(http.MethodOptions, "/api/v1/x", nil)
|
|
res, err = app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("app.Test: %v", err)
|
|
}
|
|
defer res.Body.Close()
|
|
if res.StatusCode != http.StatusNoContent {
|
|
t.Fatalf("expected options pass-through, got %d", res.StatusCode)
|
|
}
|
|
|
|
tiers := map[string]rateTier{
|
|
"general": {Name: "general"},
|
|
"auth": {Name: "auth"},
|
|
"api": {Name: "api"},
|
|
"admin": {Name: "admin"},
|
|
}
|
|
if got := selectTier("/api/v1/admin/auth", "/api/v1", tiers); got.Name != "auth" {
|
|
t.Fatalf("expected auth tier, got %s", got.Name)
|
|
}
|
|
if got := selectTier("/api/v1/admin/x", "/api/v1", tiers); got.Name != "admin" {
|
|
t.Fatalf("expected admin tier, got %s", got.Name)
|
|
}
|
|
if got := selectTier("/api/v1/x", "/api/v1", tiers); got.Name != "api" {
|
|
t.Fatalf("expected api tier, got %s", got.Name)
|
|
}
|
|
if got := selectTier("/x", "/api/v1", tiers); got.Name != "general" {
|
|
t.Fatalf("expected general tier, got %s", got.Name)
|
|
}
|
|
|
|
if v, ok := toInt64(int64(1)); !ok || v != 1 {
|
|
t.Fatalf("toInt64 int64 failed")
|
|
}
|
|
if v, ok := toInt64(int(2)); !ok || v != 2 {
|
|
t.Fatalf("toInt64 int failed")
|
|
}
|
|
if v, ok := toInt64("3"); !ok || v != 3 {
|
|
t.Fatalf("toInt64 string failed")
|
|
}
|
|
if _, ok := toInt64("bad"); ok {
|
|
t.Fatalf("expected invalid string parse failure")
|
|
}
|
|
if _, ok := toInt64(struct{}{}); ok {
|
|
t.Fatalf("expected unsupported type failure")
|
|
}
|
|
|
|
if allow, count, reset, ok := parseLimiterResult([]interface{}{int64(1), int64(2),
|
|
int64(3)}); !ok || !allow || count != 2 || reset != 3 {
|
|
t.Fatalf("unexpected parseLimiterResult output")
|
|
}
|
|
if _, _, _, ok := parseLimiterResult("bad"); ok {
|
|
t.Fatalf("expected parse failure for non-slice")
|
|
}
|
|
if _, _, _, ok := parseLimiterResult([]interface{}{1, 2}); ok {
|
|
t.Fatalf("expected parse failure for short slice")
|
|
}
|
|
if _, _, _, ok := parseLimiterResult([]interface{}{1, 2, struct{}{}}); ok {
|
|
t.Fatalf("expected parse failure for invalid reset")
|
|
}
|
|
|
|
if d := positiveDuration(0, time.Second); d != time.Second {
|
|
t.Fatalf("expected fallback duration")
|
|
}
|
|
if d := positiveDuration(2*time.Second, time.Second); d != 2*time.Second {
|
|
t.Fatalf("expected provided duration")
|
|
}
|
|
if m := maxInt(0, 9); m != 9 {
|
|
t.Fatalf("expected fallback int")
|
|
}
|
|
if m := maxInt(10, 9); m != 10 {
|
|
t.Fatalf("expected provided int")
|
|
}
|
|
}
|
|
|
|
// TestRateLimitMiddlewareRedisAllowedAndBlocked verifies expected behavior.
|
|
func TestRateLimitMiddlewareRedisAllowedAndBlocked(t *testing.T) {
|
|
mr, err := miniredis.Run()
|
|
if err != nil {
|
|
t.Skipf("miniredis unavailable in this environment: %v", err)
|
|
}
|
|
t.Cleanup(func() { mr.Close() })
|
|
client := redisv9.NewClient(&redisv9.Options{Addr: mr.Addr()})
|
|
t.Cleanup(func() { _ = client.Close() })
|
|
|
|
app := fiber.New()
|
|
app.Use(func(c fiber.Ctx) error {
|
|
c.Locals(string(zitadel.ContextKeyUserID), "user-1")
|
|
return c.Next()
|
|
})
|
|
app.Use(RateLimitMiddleware(
|
|
client,
|
|
gconfig.RateLimitConfig{
|
|
APIRequests: 1,
|
|
Window: time.Minute,
|
|
},
|
|
"/api/v1",
|
|
nil,
|
|
))
|
|
app.Get("/api/v1/x", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) })
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/x", nil)
|
|
res, err := app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("app.Test first request: %v", err)
|
|
}
|
|
defer res.Body.Close()
|
|
if res.StatusCode != http.StatusOK {
|
|
t.Fatalf("expected first request to pass, got %d", res.StatusCode)
|
|
}
|
|
if res.Header.Get("X-RateLimit-Limit") != "1" {
|
|
t.Fatalf("expected rate-limit headers")
|
|
}
|
|
|
|
req = httptest.NewRequest(http.MethodGet, "/api/v1/x", nil)
|
|
res, err = app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("app.Test second request: %v", err)
|
|
}
|
|
defer res.Body.Close()
|
|
if res.StatusCode != http.StatusTooManyRequests {
|
|
t.Fatalf("expected second request to be throttled, got %d", res.StatusCode)
|
|
}
|
|
if res.Header.Get("X-RateLimit-Remaining") != "0" {
|
|
t.Fatalf("expected remaining=0 when throttled")
|
|
}
|
|
}
|
|
|
|
// TestRateLimitMiddlewareRedisErrorDegrades verifies expected behavior.
|
|
func TestRateLimitMiddlewareRedisErrorDegrades(t *testing.T) {
|
|
// Use a client pointing to an unreachable endpoint to force script.Run error.
|
|
client := redisv9.NewClient(&redisv9.Options{Addr: "127.0.0.1:1"})
|
|
t.Cleanup(func() { _ = client.Close() })
|
|
|
|
app := fiber.New()
|
|
app.Use(RateLimitMiddleware(
|
|
client,
|
|
gconfig.RateLimitConfig{
|
|
APIRequests: 5,
|
|
Window: time.Minute,
|
|
},
|
|
"/api/v1",
|
|
nil,
|
|
))
|
|
app.Get("/api/v1/x", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) })
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/x", nil)
|
|
res, err := app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("app.Test: %v", err)
|
|
}
|
|
defer res.Body.Close()
|
|
if res.StatusCode != http.StatusOK {
|
|
t.Fatalf("expected degraded mode to fail-open with 200, got %d", res.StatusCode)
|
|
}
|
|
if res.Header.Get("X-RateLimit-Policy") != "degraded" {
|
|
t.Fatalf("expected degraded policy header")
|
|
}
|
|
}
|
|
|
|
// TestIdentifyRequester verifies expected behavior.
|
|
func TestIdentifyRequester(t *testing.T) {
|
|
app := fiber.New()
|
|
app.Get("/id", func(c fiber.Ctx) error {
|
|
if c.Query("u") == "1" {
|
|
c.Locals(string(zitadel.ContextKeyUserID), "user-1")
|
|
}
|
|
id := identifyRequester(c)
|
|
if c.Query("u") == "1" && id != "user:user-1" {
|
|
return c.Status(http.StatusInternalServerError).SendString(id)
|
|
}
|
|
if c.Query("u") != "1" && !strings.HasPrefix(id, "ip:") {
|
|
return c.Status(http.StatusInternalServerError).SendString(id)
|
|
}
|
|
return c.SendStatus(http.StatusOK)
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/id?u=1", nil)
|
|
res, err := app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("app.Test: %v", err)
|
|
}
|
|
defer res.Body.Close()
|
|
if res.StatusCode != http.StatusOK {
|
|
t.Fatalf("expected 200 for user requester")
|
|
}
|
|
|
|
req = httptest.NewRequest(http.MethodGet, "/id", nil)
|
|
res, err = app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("app.Test: %v", err)
|
|
}
|
|
defer res.Body.Close()
|
|
if res.StatusCode != http.StatusOK {
|
|
t.Fatalf("expected 200 for ip requester")
|
|
}
|
|
}
|