Improved backend test coverage
parent
7106e22d19
commit
6cadc15448
@ -0,0 +1,134 @@
|
|||||||
|
package http
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v3"
|
||||||
|
|
||||||
|
"knowfoolery/backend/shared/infra/utils/validation"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUnauthorizedBranches(t *testing.T) {
|
||||||
|
h := NewHandler(nil, validation.NewValidator(), nil, nil)
|
||||||
|
app := fiber.New()
|
||||||
|
app.Post("/sessions/start", h.StartSession)
|
||||||
|
app.Post("/sessions/end", h.EndSession)
|
||||||
|
app.Post("/sessions/:id/hint", h.RequestHint)
|
||||||
|
app.Get("/sessions/:id", h.GetSession)
|
||||||
|
app.Get("/sessions/:id/question", h.GetCurrentQuestion)
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
method string
|
||||||
|
path string
|
||||||
|
}{
|
||||||
|
{http.MethodPost, "/sessions/start"},
|
||||||
|
{http.MethodPost, "/sessions/end"},
|
||||||
|
{http.MethodPost, "/sessions/s1/hint"},
|
||||||
|
{http.MethodGet, "/sessions/s1"},
|
||||||
|
{http.MethodGet, "/sessions/s1/question"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
req := httptest.NewRequest(tc.method, tc.path, nil)
|
||||||
|
res, err := app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test(%s %s): %v", tc.method, tc.path, err)
|
||||||
|
}
|
||||||
|
if res.StatusCode != http.StatusUnauthorized {
|
||||||
|
_ = res.Body.Close()
|
||||||
|
t.Fatalf("expected unauthorized for %s %s, got %d", tc.method, tc.path, res.StatusCode)
|
||||||
|
}
|
||||||
|
_ = res.Body.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartAndEndValidationBranches(t *testing.T) {
|
||||||
|
h := NewHandler(nil, validation.NewValidator(), nil, nil)
|
||||||
|
app := fiber.New()
|
||||||
|
app.Use(func(c fiber.Ctx) error {
|
||||||
|
c.Locals("user_id", "user-1")
|
||||||
|
c.Locals("user_roles", []string{"player"})
|
||||||
|
return c.Next()
|
||||||
|
})
|
||||||
|
app.Post("/sessions/start", h.StartSession)
|
||||||
|
app.Post("/sessions/end", h.EndSession)
|
||||||
|
app.Post("/sessions/:id/answer", h.SubmitAnswer)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/sessions/start", bytes.NewReader([]byte("{")))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
res, err := app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test start malformed: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
if res.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected bad request for malformed start body, got %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodPost, "/sessions/start", bytes.NewReader([]byte(`{"preferred_theme":"a"}`)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
res, err = app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test start validation: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
if res.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected bad request for start validation, got %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodPost, "/sessions/end", bytes.NewReader([]byte("{")))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
res, err = app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test end malformed: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
if res.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected bad request for malformed end body, got %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodPost, "/sessions/end", bytes.NewReader([]byte(`{"session_id":""}`)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
res, err = app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test end validation: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
if res.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected bad request for end validation, got %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBearerTokenAndClaimsHelpers(t *testing.T) {
|
||||||
|
app := fiber.New()
|
||||||
|
app.Use(func(c fiber.Ctx) error {
|
||||||
|
c.Locals("user_id", "admin-1")
|
||||||
|
c.Locals("user_roles", []string{"admin"})
|
||||||
|
return c.Next()
|
||||||
|
})
|
||||||
|
app.Get("/token", func(c fiber.Ctx) error {
|
||||||
|
if got := bearerToken(c); got != "abc" {
|
||||||
|
return c.SendStatus(http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
claims := authClaimsFromContext(c)
|
||||||
|
if !claims.IsAdmin || claims.UserID != "admin-1" {
|
||||||
|
return c.SendStatus(http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return c.SendStatus(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/token", nil)
|
||||||
|
req.Header.Set("Authorization", "Bearer abc")
|
||||||
|
res, err := app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test helper route: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
if res.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected helper route success, got %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -0,0 +1,417 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/alicebob/miniredis/v2"
|
||||||
|
"github.com/gofiber/fiber/v3"
|
||||||
|
redisv9 "github.com/redis/go-redis/v9"
|
||||||
|
|
||||||
|
gconfig "knowfoolery/backend/services/gateway-service/internal/infra/config"
|
||||||
|
"knowfoolery/backend/shared/infra/auth/zitadel"
|
||||||
|
"knowfoolery/backend/shared/infra/observability/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCORSMiddleware(t *testing.T) {
|
||||||
|
app := fiber.New()
|
||||||
|
app.Use(CORS(gconfig.CORSConfig{
|
||||||
|
AllowedOrigins: []string{"http://localhost:5173"},
|
||||||
|
AllowedMethods: "GET,POST,OPTIONS",
|
||||||
|
AllowedHeaders: "Content-Type,Authorization",
|
||||||
|
AllowCredentials: true,
|
||||||
|
MaxAgeSeconds: 120,
|
||||||
|
}))
|
||||||
|
app.Get("/ok", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) })
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/ok", nil)
|
||||||
|
res, err := app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
if res.Header.Get("Access-Control-Allow-Origin") != "" {
|
||||||
|
t.Fatalf("expected no CORS header without origin")
|
||||||
|
}
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodOptions, "/ok", nil)
|
||||||
|
req.Header.Set("Origin", "http://evil.local")
|
||||||
|
res, err = app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
if res.StatusCode != http.StatusForbidden {
|
||||||
|
t.Fatalf("expected forbidden preflight for disallowed origin, got %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodOptions, "/ok", nil)
|
||||||
|
req.Header.Set("Origin", "http://localhost:5173")
|
||||||
|
res, err = app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
if res.StatusCode != http.StatusNoContent {
|
||||||
|
t.Fatalf("expected no content preflight, got %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
if res.Header.Get("Access-Control-Allow-Origin") != "http://localhost:5173" {
|
||||||
|
t.Fatalf("unexpected allow origin: %s", res.Header.Get("Access-Control-Allow-Origin"))
|
||||||
|
}
|
||||||
|
if res.Header.Get("Access-Control-Allow-Credentials") != "true" {
|
||||||
|
t.Fatalf("missing allow credentials")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCORSAllowAll(t *testing.T) {
|
||||||
|
app := fiber.New()
|
||||||
|
app.Use(CORS(gconfig.CORSConfig{AllowedOrigins: []string{"*"}}))
|
||||||
|
app.Get("/ok", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) })
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/ok", nil)
|
||||||
|
req.Header.Set("Origin", "http://anything.local")
|
||||||
|
res, err := app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
if res.Header.Get("Access-Control-Allow-Origin") != "*" {
|
||||||
|
t.Fatalf("expected wildcard origin header")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSecurityHeadersMiddleware(t *testing.T) {
|
||||||
|
app := fiber.New()
|
||||||
|
app.Use(SecurityHeaders(gconfig.SecurityHeadersConfig{
|
||||||
|
ContentSecurityPolicy: "default-src 'self'",
|
||||||
|
EnableHSTS: true,
|
||||||
|
HSTSMaxAge: 31536000,
|
||||||
|
FrameOptions: "DENY",
|
||||||
|
ContentTypeOptions: true,
|
||||||
|
ReferrerPolicy: "strict-origin-when-cross-origin",
|
||||||
|
PermissionsPolicy: "geolocation=()",
|
||||||
|
}))
|
||||||
|
app.Get("/ok", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) })
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/ok", nil)
|
||||||
|
req.Header.Set("X-Forwarded-Proto", "https")
|
||||||
|
res, err := app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
if res.Header.Get("Content-Security-Policy") == "" {
|
||||||
|
t.Fatalf("missing CSP header")
|
||||||
|
}
|
||||||
|
if res.Header.Get("X-Frame-Options") != "DENY" {
|
||||||
|
t.Fatalf("missing frame options")
|
||||||
|
}
|
||||||
|
if res.Header.Get("X-Content-Type-Options") != "nosniff" {
|
||||||
|
t.Fatalf("missing nosniff")
|
||||||
|
}
|
||||||
|
if res.Header.Get("Strict-Transport-Security") == "" {
|
||||||
|
t.Fatalf("expected HSTS header for https request")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSecurityHeadersNoHSTSOnHTTP(t *testing.T) {
|
||||||
|
app := fiber.New()
|
||||||
|
app.Use(SecurityHeaders(gconfig.SecurityHeadersConfig{
|
||||||
|
EnableHSTS: true,
|
||||||
|
HSTSMaxAge: 31536000,
|
||||||
|
}))
|
||||||
|
app.Get("/ok", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) })
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/ok", nil)
|
||||||
|
res, err := app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
if res.Header.Get("Strict-Transport-Security") != "" {
|
||||||
|
t.Fatalf("did not expect HSTS header on non-https request")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestContextMiddlewareAndRequestID(t *testing.T) {
|
||||||
|
app := fiber.New()
|
||||||
|
app.Use(RequestContext(nil))
|
||||||
|
app.Get("/id", func(c fiber.Ctx) error {
|
||||||
|
if RequestID(c) == "" {
|
||||||
|
return c.SendStatus(http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return c.SendStatus(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/id", nil)
|
||||||
|
res, err := app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
if res.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
if res.Header.Get("X-Request-ID") == "" {
|
||||||
|
t.Fatalf("expected generated request id")
|
||||||
|
}
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/id", nil)
|
||||||
|
req.Header.Set("X-Request-ID", "req-123")
|
||||||
|
res, err = app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
if res.Header.Get("X-Request-ID") != "req-123" {
|
||||||
|
t.Fatalf("expected propagated request id")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestContextWithLoggerAndInvalidRequestIDLocal(t *testing.T) {
|
||||||
|
logger := logging.NewLogger(logging.DefaultConfig())
|
||||||
|
app := fiber.New()
|
||||||
|
app.Use(RequestContext(logger))
|
||||||
|
app.Get("/id", func(c fiber.Ctx) error {
|
||||||
|
c.Locals(requestIDKey, 123)
|
||||||
|
if RequestID(c) != "" {
|
||||||
|
return c.SendStatus(http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return c.SendStatus(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/id", nil)
|
||||||
|
res, err := app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
if res.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimitMiddlewareDegradedModeAndHelpers(t *testing.T) {
|
||||||
|
app := fiber.New()
|
||||||
|
mw := RateLimitMiddleware(nil, gconfig.RateLimitConfig{
|
||||||
|
GeneralRequests: 7,
|
||||||
|
APIRequests: 7,
|
||||||
|
}, "/api/v1", nil)
|
||||||
|
app.Use(mw)
|
||||||
|
app.Get("/api/v1/x", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) })
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/x", nil)
|
||||||
|
res, err := app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
if res.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
if res.Header.Get("X-RateLimit-Policy") != "degraded" {
|
||||||
|
t.Fatalf("expected degraded policy")
|
||||||
|
}
|
||||||
|
if res.Header.Get("X-RateLimit-Limit") != "7" {
|
||||||
|
t.Fatalf("expected limit header")
|
||||||
|
}
|
||||||
|
|
||||||
|
app = fiber.New()
|
||||||
|
app.Use(mw)
|
||||||
|
app.Options("/api/v1/x", func(c fiber.Ctx) error { return c.SendStatus(http.StatusNoContent) })
|
||||||
|
req = httptest.NewRequest(http.MethodOptions, "/api/v1/x", nil)
|
||||||
|
res, err = app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
if res.StatusCode != http.StatusNoContent {
|
||||||
|
t.Fatalf("expected options pass-through, got %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
tiers := map[string]rateTier{
|
||||||
|
"general": {Name: "general"},
|
||||||
|
"auth": {Name: "auth"},
|
||||||
|
"api": {Name: "api"},
|
||||||
|
"admin": {Name: "admin"},
|
||||||
|
}
|
||||||
|
if got := selectTier("/api/v1/admin/auth", "/api/v1", tiers); got.Name != "auth" {
|
||||||
|
t.Fatalf("expected auth tier, got %s", got.Name)
|
||||||
|
}
|
||||||
|
if got := selectTier("/api/v1/admin/x", "/api/v1", tiers); got.Name != "admin" {
|
||||||
|
t.Fatalf("expected admin tier, got %s", got.Name)
|
||||||
|
}
|
||||||
|
if got := selectTier("/api/v1/x", "/api/v1", tiers); got.Name != "api" {
|
||||||
|
t.Fatalf("expected api tier, got %s", got.Name)
|
||||||
|
}
|
||||||
|
if got := selectTier("/x", "/api/v1", tiers); got.Name != "general" {
|
||||||
|
t.Fatalf("expected general tier, got %s", got.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if v, ok := toInt64(int64(1)); !ok || v != 1 {
|
||||||
|
t.Fatalf("toInt64 int64 failed")
|
||||||
|
}
|
||||||
|
if v, ok := toInt64(int(2)); !ok || v != 2 {
|
||||||
|
t.Fatalf("toInt64 int failed")
|
||||||
|
}
|
||||||
|
if v, ok := toInt64("3"); !ok || v != 3 {
|
||||||
|
t.Fatalf("toInt64 string failed")
|
||||||
|
}
|
||||||
|
if _, ok := toInt64("bad"); ok {
|
||||||
|
t.Fatalf("expected invalid string parse failure")
|
||||||
|
}
|
||||||
|
if _, ok := toInt64(struct{}{}); ok {
|
||||||
|
t.Fatalf("expected unsupported type failure")
|
||||||
|
}
|
||||||
|
|
||||||
|
if allow, count, reset, ok := parseLimiterResult([]interface{}{int64(1), int64(2),
|
||||||
|
int64(3)}); !ok || !allow || count != 2 || reset != 3 {
|
||||||
|
t.Fatalf("unexpected parseLimiterResult output")
|
||||||
|
}
|
||||||
|
if _, _, _, ok := parseLimiterResult("bad"); ok {
|
||||||
|
t.Fatalf("expected parse failure for non-slice")
|
||||||
|
}
|
||||||
|
if _, _, _, ok := parseLimiterResult([]interface{}{1, 2}); ok {
|
||||||
|
t.Fatalf("expected parse failure for short slice")
|
||||||
|
}
|
||||||
|
if _, _, _, ok := parseLimiterResult([]interface{}{1, 2, struct{}{}}); ok {
|
||||||
|
t.Fatalf("expected parse failure for invalid reset")
|
||||||
|
}
|
||||||
|
|
||||||
|
if d := positiveDuration(0, time.Second); d != time.Second {
|
||||||
|
t.Fatalf("expected fallback duration")
|
||||||
|
}
|
||||||
|
if d := positiveDuration(2*time.Second, time.Second); d != 2*time.Second {
|
||||||
|
t.Fatalf("expected provided duration")
|
||||||
|
}
|
||||||
|
if m := maxInt(0, 9); m != 9 {
|
||||||
|
t.Fatalf("expected fallback int")
|
||||||
|
}
|
||||||
|
if m := maxInt(10, 9); m != 10 {
|
||||||
|
t.Fatalf("expected provided int")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimitMiddlewareRedisAllowedAndBlocked(t *testing.T) {
|
||||||
|
mr, err := miniredis.Run()
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("miniredis unavailable in this environment: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { mr.Close() })
|
||||||
|
client := redisv9.NewClient(&redisv9.Options{Addr: mr.Addr()})
|
||||||
|
t.Cleanup(func() { _ = client.Close() })
|
||||||
|
|
||||||
|
app := fiber.New()
|
||||||
|
app.Use(func(c fiber.Ctx) error {
|
||||||
|
c.Locals(string(zitadel.ContextKeyUserID), "user-1")
|
||||||
|
return c.Next()
|
||||||
|
})
|
||||||
|
app.Use(RateLimitMiddleware(
|
||||||
|
client,
|
||||||
|
gconfig.RateLimitConfig{
|
||||||
|
APIRequests: 1,
|
||||||
|
Window: time.Minute,
|
||||||
|
},
|
||||||
|
"/api/v1",
|
||||||
|
nil,
|
||||||
|
))
|
||||||
|
app.Get("/api/v1/x", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) })
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/x", nil)
|
||||||
|
res, err := app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test first request: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
if res.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected first request to pass, got %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
if res.Header.Get("X-RateLimit-Limit") != "1" {
|
||||||
|
t.Fatalf("expected rate-limit headers")
|
||||||
|
}
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/api/v1/x", nil)
|
||||||
|
res, err = app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test second request: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
if res.StatusCode != http.StatusTooManyRequests {
|
||||||
|
t.Fatalf("expected second request to be throttled, got %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
if res.Header.Get("X-RateLimit-Remaining") != "0" {
|
||||||
|
t.Fatalf("expected remaining=0 when throttled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimitMiddlewareRedisErrorDegrades(t *testing.T) {
|
||||||
|
// Use a client pointing to an unreachable endpoint to force script.Run error.
|
||||||
|
client := redisv9.NewClient(&redisv9.Options{Addr: "127.0.0.1:1"})
|
||||||
|
t.Cleanup(func() { _ = client.Close() })
|
||||||
|
|
||||||
|
app := fiber.New()
|
||||||
|
app.Use(RateLimitMiddleware(
|
||||||
|
client,
|
||||||
|
gconfig.RateLimitConfig{
|
||||||
|
APIRequests: 5,
|
||||||
|
Window: time.Minute,
|
||||||
|
},
|
||||||
|
"/api/v1",
|
||||||
|
nil,
|
||||||
|
))
|
||||||
|
app.Get("/api/v1/x", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) })
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/x", nil)
|
||||||
|
res, err := app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
if res.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected degraded mode to fail-open with 200, got %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
if res.Header.Get("X-RateLimit-Policy") != "degraded" {
|
||||||
|
t.Fatalf("expected degraded policy header")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIdentifyRequester(t *testing.T) {
|
||||||
|
app := fiber.New()
|
||||||
|
app.Get("/id", func(c fiber.Ctx) error {
|
||||||
|
if c.Query("u") == "1" {
|
||||||
|
c.Locals(string(zitadel.ContextKeyUserID), "user-1")
|
||||||
|
}
|
||||||
|
id := identifyRequester(c)
|
||||||
|
if c.Query("u") == "1" && id != "user:user-1" {
|
||||||
|
return c.Status(http.StatusInternalServerError).SendString(id)
|
||||||
|
}
|
||||||
|
if c.Query("u") != "1" && !strings.HasPrefix(id, "ip:") {
|
||||||
|
return c.Status(http.StatusInternalServerError).SendString(id)
|
||||||
|
}
|
||||||
|
return c.SendStatus(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/id?u=1", nil)
|
||||||
|
res, err := app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
if res.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200 for user requester")
|
||||||
|
}
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/id", nil)
|
||||||
|
res, err = app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
if res.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200 for ip requester")
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -0,0 +1,120 @@
|
|||||||
|
package http
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v3"
|
||||||
|
|
||||||
|
"knowfoolery/backend/shared/infra/utils/validation"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUpdateAuthAndValidationBranches(t *testing.T) {
|
||||||
|
h := NewHandler(nil, validation.NewValidator(), nil, nil, true, 20, 100)
|
||||||
|
app := fiber.New()
|
||||||
|
app.Post("/leaderboard/update", h.Update)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/leaderboard/update", bytes.NewReader([]byte(`{}`)))
|
||||||
|
res, err := app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
if res.StatusCode != http.StatusForbidden {
|
||||||
|
t.Fatalf("expected forbidden without auth claims, got %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
app = fiber.New()
|
||||||
|
app.Use(func(c fiber.Ctx) error {
|
||||||
|
c.Locals("user_id", "user-1")
|
||||||
|
c.Locals("user_roles", []string{"player"})
|
||||||
|
return c.Next()
|
||||||
|
})
|
||||||
|
app.Post("/leaderboard/update", h.Update)
|
||||||
|
req = httptest.NewRequest(http.MethodPost, "/leaderboard/update", bytes.NewReader([]byte(`{}`)))
|
||||||
|
res, err = app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
if res.StatusCode != http.StatusForbidden {
|
||||||
|
t.Fatalf("expected forbidden for non-service/non-admin, got %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
app = fiber.New()
|
||||||
|
app.Use(func(c fiber.Ctx) error {
|
||||||
|
c.Locals("user_id", "svc-1")
|
||||||
|
c.Locals("user_roles", []string{"service"})
|
||||||
|
return c.Next()
|
||||||
|
})
|
||||||
|
app.Post("/leaderboard/update", h.Update)
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodPost, "/leaderboard/update", bytes.NewReader([]byte("{")))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
res, err = app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
if res.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected bad request for malformed json, got %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
validShape := `{
|
||||||
|
"session_id":"s1",
|
||||||
|
"player_id":"u1",
|
||||||
|
"player_name":"P",
|
||||||
|
"total_score":1,
|
||||||
|
"questions_asked":1,
|
||||||
|
"questions_correct":1,
|
||||||
|
"hints_used":0,
|
||||||
|
"duration_seconds":10,
|
||||||
|
"completed_at":"not-rfc3339",
|
||||||
|
"completion_type":"completed"
|
||||||
|
}`
|
||||||
|
req = httptest.NewRequest(http.MethodPost, "/leaderboard/update", bytes.NewReader([]byte(validShape)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
res, err = app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
if res.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected bad request for completed_at format, got %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPlayerRankingForbiddenBranch(t *testing.T) {
|
||||||
|
h := NewHandler(nil, validation.NewValidator(), nil, nil, false, 20, 100)
|
||||||
|
app := fiber.New()
|
||||||
|
app.Use(func(c fiber.Ctx) error {
|
||||||
|
c.Locals("user_id", "user-2")
|
||||||
|
c.Locals("user_roles", []string{"player"})
|
||||||
|
return c.Next()
|
||||||
|
})
|
||||||
|
app.Get("/leaderboard/players/:id", h.GetPlayerRanking)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/leaderboard/players/user-1?page=oops&page_size=-1", nil)
|
||||||
|
res, err := app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
if res.StatusCode != http.StatusForbidden {
|
||||||
|
t.Fatalf("expected forbidden for non-owner non-admin, got %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHelperFunctions(t *testing.T) {
|
||||||
|
if got := atoiWithDefault("", 3); got != 3 {
|
||||||
|
t.Fatalf("expected default for empty input")
|
||||||
|
}
|
||||||
|
if got := atoiWithDefault("bad", 4); got != 4 {
|
||||||
|
t.Fatalf("expected default for invalid input")
|
||||||
|
}
|
||||||
|
if got := atoiWithDefault("7", 4); got != 7 {
|
||||||
|
t.Fatalf("expected parsed int")
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue