Finished 'Task 5: Auth (Zitadel + RBAC)'

master
oabrivard 1 month ago
parent 689be9228c
commit 3d22f73613

@ -1,18 +1,22 @@
github.com/MicahParks/jwkset v0.11.0 h1:yc0zG+jCvZpWgFDFmvs8/8jqqVBG9oyIbmBtmjOhoyQ=
github.com/MicahParks/jwkset v0.11.0/go.mod h1:U2oRhRaLgDCLjtpGL2GseNKGmZtLs/3O7p+OZaL5vo0=
github.com/MicahParks/keyfunc/v3 v3.7.0 h1:pdafUNyq+p3ZlvjJX1HWFP7MA3+cLpDtg69U3kITJGM=
github.com/MicahParks/keyfunc/v3 v3.7.0/go.mod h1:z66bkCviwqfg2YUp+Jcc/xRE9IXLcMq6DrgV/+Htru0=
github.com/alecthomas/kingpin/v2 v2.4.0/go.mod h1:0gyi0zQnjuFk8xrkNKamJoyUo382HRL7ATRpFZCw6tE=
github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137/go.mod h1:OMCwj8VM1Kc9e19TLln2VL61YJF0x1XFtfdL4JdbSyE=
github.com/go-kit/log v0.2.1/go.mod h1:NwTd00d/i8cPZ3xOwwiv2PO5MOcx78fFErGNcVmBjv0=
github.com/go-logfmt/logfmt v0.5.1/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/philhofer/fwd v1.1.2/go.mod h1:qkPdfjR2SIEbspLqpe1tO4n5yICnr2DY7mqEx2tUTP0=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/tinylib/msgp v1.1.8/go.mod h1:qkpG+2ldGg4xRFmx+jfTvZPxfGFhi64BcnL9vkCm/Tw=
github.com/xhit/go-str2duration/v2 v2.1.0/go.mod h1:ohY8p+0f07DiV6Em5LKB0s2YpLtXVyJfNt1+BlmyAsU=
@ -21,8 +25,8 @@ golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbht
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek=
golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=

@ -3,8 +3,10 @@ module knowfoolery/backend/shared
go 1.25.5
require (
github.com/MicahParks/keyfunc/v3 v3.7.0
github.com/go-playground/validator/v10 v10.25.0
github.com/gofiber/fiber/v3 v3.0.0-beta.3
github.com/golang-jwt/jwt/v5 v5.2.2
github.com/google/uuid v1.6.0
github.com/prometheus/client_golang v1.20.5
github.com/rs/zerolog v1.33.0
@ -12,6 +14,7 @@ require (
)
require (
github.com/MicahParks/jwkset v0.11.0 // indirect
github.com/andybalholm/brotli v1.1.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
@ -37,6 +40,7 @@ require (
golang.org/x/net v0.34.0 // indirect
golang.org/x/sys v0.29.0 // indirect
golang.org/x/text v0.21.0 // indirect
golang.org/x/time v0.9.0 // indirect
google.golang.org/protobuf v1.34.2 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

@ -1,3 +1,7 @@
github.com/MicahParks/jwkset v0.11.0 h1:yc0zG+jCvZpWgFDFmvs8/8jqqVBG9oyIbmBtmjOhoyQ=
github.com/MicahParks/jwkset v0.11.0/go.mod h1:U2oRhRaLgDCLjtpGL2GseNKGmZtLs/3O7p+OZaL5vo0=
github.com/MicahParks/keyfunc/v3 v3.7.0 h1:pdafUNyq+p3ZlvjJX1HWFP7MA3+cLpDtg69U3kITJGM=
github.com/MicahParks/keyfunc/v3 v3.7.0/go.mod h1:z66bkCviwqfg2YUp+Jcc/xRE9IXLcMq6DrgV/+Htru0=
github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M=
github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
@ -23,6 +27,8 @@ github.com/gofiber/fiber/v3 v3.0.0-beta.3 h1:7Q2I+HsIqnIEEDB+9oe7Gadpakh6ZLhXpTY
github.com/gofiber/fiber/v3 v3.0.0-beta.3/go.mod h1:kcMur0Dxqk91R7p4vxEpJfDWZ9u5IfvrtQc8Bvv/JmY=
github.com/gofiber/utils/v2 v2.0.0-beta.4 h1:1gjbVFFwVwUb9arPcqiB6iEjHBwo7cHsyS41NeIW3co=
github.com/gofiber/utils/v2 v2.0.0-beta.4/go.mod h1:sdRsPU1FXX6YiDGGxd+q2aPJRMzpsxdzCXo9dz+xtOY=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
@ -30,6 +36,7 @@ github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
@ -54,6 +61,7 @@ github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
@ -76,8 +84,12 @@ golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

@ -4,17 +4,28 @@ package zitadel
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/MicahParks/keyfunc/v3"
"github.com/golang-jwt/jwt/v5"
)
// Config holds the configuration for the Zitadel client.
type Config struct {
BaseURL string
ProjectID string
AdminToken string
Timeout time.Duration
BaseURL string
ProjectID string
AdminToken string
ClientID string
ClientSecret string
Issuer string
Audience string
Timeout time.Duration
}
// DefaultConfig returns a default configuration.
@ -29,6 +40,13 @@ type Client struct {
config Config
httpClient *http.Client
jwksCache *JWKSCache
discoveryMu sync.Mutex
discoveryCached discoveryDocument
discoveryFetched time.Time
jwksMu sync.Mutex
jwksURL string
jwks keyfunc.Keyfunc
}
// JWKSCache caches the JSON Web Key Set for token validation.
@ -58,6 +76,13 @@ func NewClient(config Config) *Client {
}
}
// ValidationOptions controls token validation checks.
type ValidationOptions struct {
Issuer string
Audience string
RequiredClaims []string
}
// AuthClaims represents the claims extracted from a validated JWT.
type AuthClaims struct {
Subject string `json:"sub"`
@ -68,7 +93,8 @@ type AuthClaims struct {
Issuer string `json:"iss"`
IssuedAt int64 `json:"iat"`
ExpiresAt int64 `json:"exp"`
MFAVerified bool `json:"amr"`
AMR []string `json:"amr"`
MFAVerified bool `json:"mfa_verified"`
}
// TokenResponse represents a token response from Zitadel.
@ -88,22 +114,122 @@ type UserInfo struct {
Roles []string `json:"roles"`
}
type discoveryDocument struct {
Issuer string `json:"issuer"`
JWKSURI string `json:"jwks_uri"`
TokenEndpoint string `json:"token_endpoint"`
RevocationEndpoint string `json:"revocation_endpoint"`
}
// ValidateToken validates a JWT token and returns the claims.
// This is a placeholder implementation that should be replaced with actual JWT validation.
func (c *Client) ValidateToken(ctx context.Context, token string) (*AuthClaims, error) {
// TODO: Implement actual JWT validation with JWKS
// This is a placeholder that should:
// 1. Fetch JWKS from Zitadel
// 2. Parse and validate the JWT
// 3. Verify signature, expiration, issuer, audience
// 4. Return the claims
return nil, fmt.Errorf("not implemented: token validation requires JWKS integration")
func (c *Client) ValidateToken(ctx context.Context, token string, opts ValidationOptions) (*AuthClaims, error) {
discovery, err := c.getDiscovery(ctx)
if err != nil {
return nil, err
}
jwks, err := c.getJWKS(discovery.JWKSURI)
if err != nil {
return nil, err
}
parsed, err := jwt.ParseWithClaims(token, jwt.MapClaims{}, jwks.KeyfuncCtx(ctx))
if err != nil {
return nil, fmt.Errorf("token parse failed: %w", err)
}
if !parsed.Valid {
return nil, errors.New("token is invalid")
}
claims, ok := parsed.Claims.(jwt.MapClaims)
if !ok {
return nil, errors.New("invalid token claims")
}
issuer := opts.Issuer
if issuer == "" {
issuer = c.config.Issuer
}
if issuer == "" {
issuer = discovery.Issuer
}
if issuer != "" {
if claimIssuer, _ := claims["iss"].(string); claimIssuer != issuer {
return nil, fmt.Errorf("invalid issuer: %s", claimIssuer)
}
}
audience := opts.Audience
if audience == "" {
audience = c.config.Audience
}
if audience != "" {
if !audienceContains(claims["aud"], audience) {
return nil, fmt.Errorf("invalid audience: %s", audience)
}
}
for _, claimName := range opts.RequiredClaims {
if !hasNonEmptyClaim(claims, claimName) {
return nil, fmt.Errorf("missing required claim: %s", claimName)
}
}
authClaims := &AuthClaims{
Subject: getStringClaim(claims, "sub"),
Email: getStringClaim(claims, "email"),
Name: getStringClaim(claims, "name"),
Roles: parseRoles(claims["urn:zitadel:iam:org:project:roles"]),
Audience: parseStringSlice(claims["aud"]),
Issuer: getStringClaim(claims, "iss"),
IssuedAt: getInt64Claim(claims, "iat"),
ExpiresAt: getInt64Claim(claims, "exp"),
AMR: parseStringSlice(claims["amr"]),
}
authClaims.MFAVerified = containsAny(authClaims.AMR, "mfa", "otp")
return authClaims, nil
}
// RefreshToken refreshes an access token using a refresh token.
func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (*TokenResponse, error) {
// TODO: Implement token refresh
return nil, fmt.Errorf("not implemented: token refresh")
discovery, err := c.getDiscovery(ctx)
if err != nil {
return nil, err
}
if discovery.TokenEndpoint == "" {
return nil, errors.New("token endpoint not available")
}
form := url.Values{}
form.Set("grant_type", "refresh_token")
form.Set("refresh_token", refreshToken)
form.Set("client_id", c.config.ClientID)
form.Set("client_secret", c.config.ClientSecret)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, discovery.TokenEndpoint, strings.NewReader(form.Encode()))
if err != nil {
return nil, fmt.Errorf("failed to create token refresh request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to refresh token: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to refresh token: status %d", resp.StatusCode)
}
var tokenResponse TokenResponse
if err := json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil {
return nil, fmt.Errorf("failed to decode token response: %w", err)
}
return &tokenResponse, nil
}
// GetUserInfo retrieves user information using an access token.
@ -137,6 +263,187 @@ func (c *Client) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo
// RevokeToken revokes a token.
func (c *Client) RevokeToken(ctx context.Context, token string) error {
// TODO: Implement token revocation
return fmt.Errorf("not implemented: token revocation")
discovery, err := c.getDiscovery(ctx)
if err != nil {
return err
}
if discovery.RevocationEndpoint == "" {
return errors.New("revocation endpoint not available")
}
form := url.Values{}
form.Set("token", token)
form.Set("client_id", c.config.ClientID)
form.Set("client_secret", c.config.ClientSecret)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, discovery.RevocationEndpoint, strings.NewReader(form.Encode()))
if err != nil {
return fmt.Errorf("failed to create revocation request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := c.httpClient.Do(req)
if err != nil {
return fmt.Errorf("failed to revoke token: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent {
return fmt.Errorf("failed to revoke token: status %d", resp.StatusCode)
}
return nil
}
func (c *Client) getDiscovery(ctx context.Context) (discoveryDocument, error) {
c.discoveryMu.Lock()
defer c.discoveryMu.Unlock()
if c.discoveryCached.Issuer != "" && time.Since(c.discoveryFetched) < 10*time.Minute {
return c.discoveryCached, nil
}
url := fmt.Sprintf("%s/.well-known/openid-configuration", c.config.BaseURL)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return discoveryDocument{}, fmt.Errorf("failed to create discovery request: %w", err)
}
resp, err := c.httpClient.Do(req)
if err != nil {
return discoveryDocument{}, fmt.Errorf("failed to fetch discovery document: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return discoveryDocument{}, fmt.Errorf("failed to fetch discovery document: status %d", resp.StatusCode)
}
var doc discoveryDocument
if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil {
return discoveryDocument{}, fmt.Errorf("failed to decode discovery document: %w", err)
}
c.discoveryCached = doc
c.discoveryFetched = time.Now()
return doc, nil
}
func (c *Client) getJWKS(jwksURL string) (keyfunc.Keyfunc, error) {
c.jwksMu.Lock()
defer c.jwksMu.Unlock()
if c.jwks != nil && c.jwksURL == jwksURL {
return c.jwks, nil
}
jwks, err := keyfunc.NewDefaultCtx(context.Background(), []string{jwksURL})
if err != nil {
return nil, fmt.Errorf("failed to create jwks: %w", err)
}
c.jwks = jwks
c.jwksURL = jwksURL
return jwks, nil
}
func getStringClaim(claims jwt.MapClaims, key string) string {
if value, ok := claims[key].(string); ok {
return value
}
return ""
}
func getInt64Claim(claims jwt.MapClaims, key string) int64 {
switch value := claims[key].(type) {
case float64:
return int64(value)
case int64:
return value
case json.Number:
parsed, _ := value.Int64()
return parsed
default:
return 0
}
}
func hasNonEmptyClaim(claims jwt.MapClaims, key string) bool {
value, ok := claims[key]
if !ok || value == nil {
return false
}
switch typed := value.(type) {
case string:
return typed != ""
case []interface{}:
return len(typed) > 0
case []string:
return len(typed) > 0
case map[string]interface{}:
return len(typed) > 0
default:
return true
}
}
func parseStringSlice(value interface{}) []string {
if value == nil {
return nil
}
switch typed := value.(type) {
case []string:
return typed
case []interface{}:
result := make([]string, 0, len(typed))
for _, item := range typed {
if s, ok := item.(string); ok {
result = append(result, s)
}
}
return result
case string:
if typed == "" {
return nil
}
return []string{typed}
default:
return nil
}
}
func parseRoles(value interface{}) []string {
switch typed := value.(type) {
case map[string]interface{}:
roles := make([]string, 0, len(typed))
for role := range typed {
roles = append(roles, role)
}
return roles
default:
return parseStringSlice(value)
}
}
func audienceContains(value interface{}, expected string) bool {
if expected == "" {
return true
}
for _, aud := range parseStringSlice(value) {
if aud == expected {
return true
}
}
return false
}
func containsAny(values []string, targets ...string) bool {
for _, value := range values {
for _, target := range targets {
if value == target {
return true
}
}
}
return false
}

@ -1,18 +1,112 @@
package zitadel
// Tests for Zitadel client user info calls and placeholder methods.
// Tests for Zitadel client user info calls, JWT validation, and token operations.
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"io"
"math/big"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/require"
)
type jwksKey struct {
Kty string `json:"kty"`
Kid string `json:"kid"`
Use string `json:"use"`
Alg string `json:"alg"`
N string `json:"n"`
E string `json:"e"`
}
type jwks struct {
Keys []jwksKey `json:"keys"`
}
func generateJWKS(t *testing.T) (*rsa.PrivateKey, jwks, string) {
t.Helper()
key, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
kid := "kid-1"
n := base64.RawURLEncoding.EncodeToString(key.N.Bytes())
e := base64.RawURLEncoding.EncodeToString(big.NewInt(int64(key.E)).Bytes())
return key, jwks{
Keys: []jwksKey{
{
Kty: "RSA",
Kid: kid,
Use: "sig",
Alg: "RS256",
N: n,
E: e,
},
},
}, kid
}
func signToken(t *testing.T, key *rsa.PrivateKey, kid string, claims jwt.MapClaims) string {
t.Helper()
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header["kid"] = kid
signed, err := token.SignedString(key)
require.NoError(t, err)
return signed
}
func newOIDCServer(t *testing.T, jwksDoc jwks) *httptest.Server {
t.Helper()
var baseURL string
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/.well-known/openid-configuration":
_ = json.NewEncoder(w).Encode(map[string]string{
"issuer": baseURL,
"jwks_uri": baseURL + "/jwks",
"token_endpoint": baseURL + "/token",
"revocation_endpoint": baseURL + "/revoke",
})
case "/jwks":
_ = json.NewEncoder(w).Encode(jwksDoc)
case "/token":
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
if r.FormValue("grant_type") != "refresh_token" {
w.WriteHeader(http.StatusBadRequest)
return
}
_ = json.NewEncoder(w).Encode(TokenResponse{
AccessToken: "new-access",
RefreshToken: "new-refresh",
TokenType: "Bearer",
ExpiresIn: 3600,
})
case "/revoke":
w.WriteHeader(http.StatusOK)
default:
w.WriteHeader(http.StatusNotFound)
}
})
server := httptest.NewServer(handler)
baseURL = server.URL
return server
}
// TestGetUserInfo_Success verifies user info retrieval on a 200 response.
func TestGetUserInfo_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -46,15 +140,161 @@ func TestGetUserInfo_NonOK(t *testing.T) {
require.Error(t, err)
}
// TestClient_NotImplementedMethods verifies placeholder methods return errors.
func TestClient_NotImplementedMethods(t *testing.T) {
client := NewClient(DefaultConfig())
// TestValidateToken_Success verifies JWT validation using JWKS and discovery.
func TestValidateToken_Success(t *testing.T) {
key, jwksDoc, kid := generateJWKS(t)
server := newOIDCServer(t, jwksDoc)
defer server.Close()
_, err := client.ValidateToken(context.Background(), "token")
claims := jwt.MapClaims{
"sub": "user-1",
"email": "a@b.com",
"name": "Alice",
"aud": "client-1",
"iss": server.URL,
"iat": time.Now().Unix(),
"exp": time.Now().Add(10 * time.Minute).Unix(),
"urn:zitadel:iam:org:project:roles": map[string]interface{}{
"admin": map[string]interface{}{},
},
"amr": []string{"otp"},
}
token := signToken(t, key, kid, claims)
client := NewClient(Config{BaseURL: server.URL, Timeout: 2 * time.Second})
parsed, err := client.ValidateToken(context.Background(), token, ValidationOptions{
Issuer: server.URL,
Audience: "client-1",
})
require.NoError(t, err)
require.Equal(t, "user-1", parsed.Subject)
require.True(t, parsed.MFAVerified)
require.Contains(t, parsed.Roles, "admin")
}
// TestValidateToken_InvalidAudience verifies audience validation.
func TestValidateToken_InvalidAudience(t *testing.T) {
key, jwksDoc, kid := generateJWKS(t)
server := newOIDCServer(t, jwksDoc)
defer server.Close()
token := signToken(t, key, kid, jwt.MapClaims{
"sub": "user-1",
"aud": "other",
"iss": server.URL,
"exp": time.Now().Add(5 * time.Minute).Unix(),
})
client := NewClient(Config{BaseURL: server.URL, Timeout: 2 * time.Second})
_, err := client.ValidateToken(context.Background(), token, ValidationOptions{
Issuer: server.URL,
Audience: "client-1",
})
require.Error(t, err)
}
// TestValidateToken_MissingRequiredClaim verifies required claim enforcement.
func TestValidateToken_MissingRequiredClaim(t *testing.T) {
key, jwksDoc, kid := generateJWKS(t)
server := newOIDCServer(t, jwksDoc)
defer server.Close()
token := signToken(t, key, kid, jwt.MapClaims{
"sub": "user-1",
"aud": "client-1",
"iss": server.URL,
"exp": time.Now().Add(5 * time.Minute).Unix(),
})
_, err = client.RefreshToken(context.Background(), "refresh")
client := NewClient(Config{BaseURL: server.URL, Timeout: 2 * time.Second})
_, err := client.ValidateToken(context.Background(), token, ValidationOptions{
Issuer: server.URL,
Audience: "client-1",
RequiredClaims: []string{"email"},
})
require.Error(t, err)
}
// TestRefreshToken_Success verifies token refresh against discovery endpoint.
func TestRefreshToken_Success(t *testing.T) {
key, jwksDoc, _ := generateJWKS(t)
_ = key
server := newOIDCServer(t, jwksDoc)
defer server.Close()
client := NewClient(Config{
BaseURL: server.URL,
ClientID: "client",
ClientSecret: "secret",
Timeout: 2 * time.Second,
})
resp, err := client.RefreshToken(context.Background(), "refresh-token")
require.NoError(t, err)
require.Equal(t, "new-access", resp.AccessToken)
}
// TestRevokeToken_Success verifies token revocation against discovery endpoint.
func TestRevokeToken_Success(t *testing.T) {
key, jwksDoc, _ := generateJWKS(t)
_ = key
server := newOIDCServer(t, jwksDoc)
defer server.Close()
client := NewClient(Config{
BaseURL: server.URL,
ClientID: "client",
ClientSecret: "secret",
Timeout: 2 * time.Second,
})
err := client.RevokeToken(context.Background(), "token")
require.NoError(t, err)
}
// TestRefreshToken_UsesForm ensures refresh uses form-encoded requests.
func TestRefreshToken_UsesForm(t *testing.T) {
key, jwksDoc, _ := generateJWKS(t)
_ = key
var captured string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/.well-known/openid-configuration" {
_ = json.NewEncoder(w).Encode(map[string]string{
"issuer": "test",
"jwks_uri": serverURL(r),
"token_endpoint": serverURL(r) + "/token",
})
return
}
if r.URL.Path == "/jwks" {
_ = json.NewEncoder(w).Encode(jwksDoc)
return
}
if r.URL.Path == "/token" {
body, _ := io.ReadAll(r.Body)
captured = string(body)
_ = json.NewEncoder(w).Encode(TokenResponse{
AccessToken: "access",
})
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer server.Close()
client := NewClient(Config{
BaseURL: server.URL,
ClientID: "client",
ClientSecret: "secret",
Timeout: 2 * time.Second,
})
_, err := client.RefreshToken(context.Background(), "refresh-token")
require.NoError(t, err)
require.Contains(t, captured, url.QueryEscape("refresh-token"))
require.Contains(t, captured, "client_id=client")
}
require.Error(t, client.RevokeToken(context.Background(), "token"))
func serverURL(r *http.Request) string {
return "http://" + r.Host
}

@ -36,7 +36,7 @@ type JWTMiddlewareConfig struct {
// TokenValidator defines the interface for validating JWT tokens.
type TokenValidator interface {
ValidateToken(ctx context.Context, token string) (*AuthClaims, error)
ValidateToken(ctx context.Context, token string, opts ValidationOptions) (*AuthClaims, error)
}
// JWTMiddleware creates a Fiber middleware for JWT validation.
@ -69,7 +69,11 @@ func JWTMiddleware(config JWTMiddlewareConfig) fiber.Handler {
// Validate token
ctx := c.Context()
claims, err := config.Client.ValidateToken(ctx, tokenString)
claims, err := config.Client.ValidateToken(ctx, tokenString, ValidationOptions{
Issuer: config.Issuer,
Audience: config.Audience,
RequiredClaims: config.RequiredClaims,
})
if err != nil {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
"error": true,

@ -19,7 +19,7 @@ type fakeValidator struct {
called int
}
func (f *fakeValidator) ValidateToken(ctx context.Context, token string) (*AuthClaims, error) {
func (f *fakeValidator) ValidateToken(ctx context.Context, token string, _ ValidationOptions) (*AuthClaims, error) {
f.called++
return f.claims, f.err
}

Loading…
Cancel
Save