You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

143 lines
4.7 KiB
Go

package zitadel
// Tests for JWT middleware authentication, admin checks, and bypass paths.
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v3"
"github.com/stretchr/testify/require"
)
// fakeValidator is a test double for the middleware's token-validation client.
// Every call returns the same canned claims/error pair, and the number of
// calls is tallied so tests can check whether validation was reached.
type fakeValidator struct {
	claims *AuthClaims // returned from every ValidateToken call
	err    error       // returned from every ValidateToken call
	called int         // number of ValidateToken invocations observed
}

// ValidateToken bumps the call counter and replays the configured outcome.
// The context, raw token and validation options are deliberately ignored.
func (f *fakeValidator) ValidateToken(_ context.Context, _ string, _ ValidationOptions) (*AuthClaims, error) {
	f.called++
	claims, err := f.claims, f.err
	return claims, err
}
// TestJWTMiddleware_Success verifies that a request with a bearer token accepted
// by the validator reaches the handler. It also checks that the values exposed via
// GetUserID, GetUserEmail and GetUserRoles match the claims the validator returned.
func TestJWTMiddleware_Success(t *testing.T) {
	validator := &fakeValidator{claims: &AuthClaims{
		Subject:     "user-1",
		Email:       "a@b.com",
		Name:        "Alice",
		Roles:       []string{"player"},
		MFAVerified: true,
	}}
	app := fiber.New()
	app.Use(JWTMiddleware(JWTMiddlewareConfig{Client: validator}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.JSON(fiber.Map{
			"user_id": GetUserID(c),
			"email":   GetUserEmail(c),
			"roles":   GetUserRoles(c),
		})
	})
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("Authorization", "Bearer token")
	resp, err := app.Test(req)
	require.NoError(t, err)
	defer resp.Body.Close()
	require.Equal(t, http.StatusOK, resp.StatusCode)
	// A single request should trigger exactly one validation.
	require.Equal(t, 1, validator.called)

	var body map[string]any
	require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
	// Assert every field the handler echoes, so a broken accessor cannot slip through.
	require.Equal(t, "user-1", body["user_id"])
	require.Equal(t, "a@b.com", body["email"])
	// encoding/json decodes JSON arrays into []any when the target is map[string]any.
	require.Equal(t, []any{"player"}, body["roles"])
}
// TestJWTMiddleware_MissingHeader verifies that a request without an
// Authorization header is rejected with 401 and never reaches the validator.
func TestJWTMiddleware_MissingHeader(t *testing.T) {
	validator := &fakeValidator{}
	app := fiber.New()
	app.Use(JWTMiddleware(JWTMiddlewareConfig{Client: validator}))
	app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) })
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	resp, err := app.Test(req)
	require.NoError(t, err)
	defer resp.Body.Close()
	require.Equal(t, http.StatusUnauthorized, resp.StatusCode)
	// The zero-value fake would return (nil, nil) if called. Pin that the
	// rejection happened before validation rather than relying on the status alone.
	require.Equal(t, 0, validator.called)
}
// TestJWTMiddleware_InvalidHeaderFormat verifies that an Authorization header
// using a non-Bearer scheme is rejected with 401 without consulting the validator.
func TestJWTMiddleware_InvalidHeaderFormat(t *testing.T) {
	validator := &fakeValidator{}
	app := fiber.New()
	app.Use(JWTMiddleware(JWTMiddlewareConfig{Client: validator}))
	app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) })
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("Authorization", "Token abc")
	resp, err := app.Test(req)
	require.NoError(t, err)
	defer resp.Body.Close()
	require.Equal(t, http.StatusUnauthorized, resp.StatusCode)
	// A malformed header should be refused during parsing, so the token is never validated.
	require.Equal(t, 0, validator.called)
}
// TestJWTMiddleware_AdminRoleRequired ensures jwt middleware admin role required behavior is handled correctly.
func TestJWTMiddleware_AdminRoleRequired(t *testing.T) {
validator := &fakeValidator{claims: &AuthClaims{
Subject: "user-1",
Roles: []string{"player"},
MFAVerified: true,
}}
app := fiber.New()
app.Use(JWTMiddleware(JWTMiddlewareConfig{Client: validator, AdminEndpoints: []string{"/admin"}}))
app.Get("/admin/stats", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) })
req := httptest.NewRequest(http.MethodGet, "/admin/stats", nil)
req.Header.Set("Authorization", "Bearer token")
resp, err := app.Test(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusForbidden, resp.StatusCode)
}
// TestJWTMiddleware_MFARequiredForAdmin ensures jwt middleware mfa required for admin behavior is handled correctly.
func TestJWTMiddleware_MFARequiredForAdmin(t *testing.T) {
validator := &fakeValidator{claims: &AuthClaims{
Subject: "user-1",
Roles: []string{"admin"},
MFAVerified: false,
}}
app := fiber.New()
app.Use(JWTMiddleware(JWTMiddlewareConfig{Client: validator, AdminEndpoints: []string{"/admin"}}))
app.Get("/admin/stats", func(c fiber.Ctx) error { return c.SendStatus(http.StatusOK) })
req := httptest.NewRequest(http.MethodGet, "/admin/stats", nil)
req.Header.Set("Authorization", "Bearer token")
resp, err := app.Test(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusForbidden, resp.StatusCode)
}
// TestJWTMiddleware_SkipPath checks that a request to a path under a configured
// skip prefix succeeds without any Authorization header, and that the validator
// is never called for it.
func TestJWTMiddleware_SkipPath(t *testing.T) {
	// Primed with an error so that an accidental validation would not look like success.
	v := &fakeValidator{err: fiber.ErrUnauthorized}
	cfg := JWTMiddlewareConfig{Client: v, SkipPaths: []string{"/public"}}

	app := fiber.New()
	app.Use(JWTMiddleware(cfg))
	app.Get("/public/health", func(c fiber.Ctx) error {
		return c.SendStatus(http.StatusOK)
	})

	res, testErr := app.Test(httptest.NewRequest(http.MethodGet, "/public/health", nil))
	require.NoError(t, testErr)
	defer res.Body.Close()

	require.Equal(t, http.StatusOK, res.StatusCode)
	require.Zero(t, v.called, "validator must not run for skipped paths")
}