package zitadel

// Tests for JWT middleware authentication, admin checks, and bypass paths.

import (
	"context"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/gofiber/fiber/v3"
	"github.com/stretchr/testify/require"
)

// fakeValidator is a test double for the middleware's token validator. Every
// call returns the configured claims and error verbatim and bumps a counter so
// tests can assert whether validation happened at all.
type fakeValidator struct {
	claims *AuthClaims // claims returned from every ValidateToken call
	err    error       // error returned from every ValidateToken call
	called int         // number of ValidateToken invocations (unsynchronized; each test sends one request)
}

// ValidateToken records the invocation and returns the canned claims and
// error. The ctx and token arguments are ignored.
func (f *fakeValidator) ValidateToken(ctx context.Context, token string) (*AuthClaims, error) {
	f.called++
	return f.claims, f.err
}

// TestJWTMiddleware_Success verifies valid tokens populate context and allow requests.
func TestJWTMiddleware_Success(t *testing.T) {
	validator := &fakeValidator{claims: &AuthClaims{
		Subject:     "user-1",
		Email:       "a@b.com",
		Name:        "Alice",
		Roles:       []string{"player"},
		MFAVerified: true,
	}}
	app := fiber.New()
	app.Use(JWTMiddleware(JWTMiddlewareConfig{Client: validator}))
	app.Get("/", func(c fiber.Ctx) error {
		// Echo back what the middleware stored so the test can inspect it.
		return c.JSON(fiber.Map{
			"user_id": GetUserID(c),
			"email":   GetUserEmail(c),
			"roles":   GetUserRoles(c),
		})
	})

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("Authorization", "Bearer token")
	resp, err := app.Test(req)
	require.NoError(t, err)
	defer resp.Body.Close()

	require.Equal(t, http.StatusOK, resp.StatusCode)
	var body map[string]any
	require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
	require.Equal(t, "user-1", body["user_id"])
}

// TestJWTMiddleware_MissingHeader verifies a request without an Authorization header returns 401.
func TestJWTMiddleware_MissingHeader(t *testing.T) {
	validator := &fakeValidator{}
	app := fiber.New()
	app.Use(JWTMiddleware(JWTMiddlewareConfig{Client: validator}))
	app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) })

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	resp, err := app.Test(req)
	require.NoError(t, err)
	defer resp.Body.Close()

	require.Equal(t, http.StatusUnauthorized, resp.StatusCode)
}

// TestJWTMiddleware_InvalidHeaderFormat verifies a malformed Authorization header returns 401.
func TestJWTMiddleware_InvalidHeaderFormat(t *testing.T) { validator := &fakeValidator{} app := fiber.New() app.Use(JWTMiddleware(JWTMiddlewareConfig{Client: validator})) app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) }) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set("Authorization", "Token abc") resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, http.StatusUnauthorized, resp.StatusCode) } // TestJWTMiddleware_AdminRoleRequired ensures admin paths reject non-admin roles. func TestJWTMiddleware_AdminRoleRequired(t *testing.T) { validator := &fakeValidator{claims: &AuthClaims{ Subject: "user-1", Roles: []string{"player"}, MFAVerified: true, }} app := fiber.New() app.Use(JWTMiddleware(JWTMiddlewareConfig{Client: validator, AdminEndpoints: []string{"/admin"}})) app.Get("/admin/stats", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) }) req := httptest.NewRequest(http.MethodGet, "/admin/stats", nil) req.Header.Set("Authorization", "Bearer token") resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, http.StatusForbidden, resp.StatusCode) } // TestJWTMiddleware_MFARequiredForAdmin ensures admin paths require MFA verification. func TestJWTMiddleware_MFARequiredForAdmin(t *testing.T) { validator := &fakeValidator{claims: &AuthClaims{ Subject: "user-1", Roles: []string{"admin"}, MFAVerified: false, }} app := fiber.New() app.Use(JWTMiddleware(JWTMiddlewareConfig{Client: validator, AdminEndpoints: []string{"/admin"}})) app.Get("/admin/stats", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) }) req := httptest.NewRequest(http.MethodGet, "/admin/stats", nil) req.Header.Set("Authorization", "Bearer token") resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, http.StatusForbidden, resp.StatusCode) } // TestJWTMiddleware_SkipPath verifies skip paths bypass token validation. 
func TestJWTMiddleware_SkipPath(t *testing.T) { validator := &fakeValidator{err: fiber.ErrUnauthorized} app := fiber.New() app.Use(JWTMiddleware(JWTMiddlewareConfig{Client: validator, SkipPaths: []string{"/public"}})) app.Get("/public/health", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) }) req := httptest.NewRequest(http.MethodGet, "/public/health", nil) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) require.Equal(t, 0, validator.called) }