From 3d22f73613908300adca4ea5cc96c3a0e0c63974 Mon Sep 17 00:00:00 2001 From: oabrivard Date: Thu, 5 Feb 2026 23:07:36 +0100 Subject: [PATCH] Finished 'Task 5: Auth (Zitadel + RBAC)' --- backend/go.work.sum | 12 +- backend/shared/go.mod | 4 + backend/shared/go.sum | 12 + backend/shared/infra/auth/zitadel/client.go | 343 +++++++++++++++++- .../shared/infra/auth/zitadel/client_test.go | 254 ++++++++++++- .../shared/infra/auth/zitadel/middleware.go | 8 +- .../infra/auth/zitadel/middleware_test.go | 2 +- 7 files changed, 603 insertions(+), 32 deletions(-) diff --git a/backend/go.work.sum b/backend/go.work.sum index 2662ebc..ecbb44e 100644 --- a/backend/go.work.sum +++ b/backend/go.work.sum @@ -1,18 +1,22 @@ +github.com/MicahParks/jwkset v0.11.0 h1:yc0zG+jCvZpWgFDFmvs8/8jqqVBG9oyIbmBtmjOhoyQ= +github.com/MicahParks/jwkset v0.11.0/go.mod h1:U2oRhRaLgDCLjtpGL2GseNKGmZtLs/3O7p+OZaL5vo0= +github.com/MicahParks/keyfunc/v3 v3.7.0 h1:pdafUNyq+p3ZlvjJX1HWFP7MA3+cLpDtg69U3kITJGM= +github.com/MicahParks/keyfunc/v3 v3.7.0/go.mod h1:z66bkCviwqfg2YUp+Jcc/xRE9IXLcMq6DrgV/+Htru0= github.com/alecthomas/kingpin/v2 v2.4.0/go.mod h1:0gyi0zQnjuFk8xrkNKamJoyUo382HRL7ATRpFZCw6tE= github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137/go.mod h1:OMCwj8VM1Kc9e19TLln2VL61YJF0x1XFtfdL4JdbSyE= github.com/go-kit/log v0.2.1/go.mod h1:NwTd00d/i8cPZ3xOwwiv2PO5MOcx78fFErGNcVmBjv0= github.com/go-logfmt/logfmt v0.5.1/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= -github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/philhofer/fwd v1.1.2/go.mod h1:qkPdfjR2SIEbspLqpe1tO4n5yICnr2DY7mqEx2tUTP0= -github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/tinylib/msgp v1.1.8/go.mod h1:qkpG+2ldGg4xRFmx+jfTvZPxfGFhi64BcnL9vkCm/Tw= github.com/xhit/go-str2duration/v2 v2.1.0/go.mod h1:ohY8p+0f07DiV6Em5LKB0s2YpLtXVyJfNt1+BlmyAsU= @@ -21,8 +25,8 @@ golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbht golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek= +golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY= +golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= diff --git a/backend/shared/go.mod b/backend/shared/go.mod index a72e62c..372060d 100644 --- a/backend/shared/go.mod +++ b/backend/shared/go.mod @@ -3,8 +3,10 @@ module knowfoolery/backend/shared go 1.25.5 require ( + github.com/MicahParks/keyfunc/v3 v3.7.0 github.com/go-playground/validator/v10 v10.25.0 github.com/gofiber/fiber/v3 v3.0.0-beta.3 + github.com/golang-jwt/jwt/v5 v5.2.2 github.com/google/uuid v1.6.0 github.com/prometheus/client_golang v1.20.5 github.com/rs/zerolog v1.33.0 @@ -12,6 +14,7 @@ require ( ) require ( + github.com/MicahParks/jwkset v0.11.0 // indirect github.com/andybalholm/brotli v1.1.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect @@ -37,6 +40,7 @@ require ( golang.org/x/net v0.34.0 // indirect golang.org/x/sys v0.29.0 // indirect golang.org/x/text v0.21.0 // indirect + golang.org/x/time v0.9.0 // indirect google.golang.org/protobuf v1.34.2 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/backend/shared/go.sum b/backend/shared/go.sum index 47d6059..7269615 100644 --- a/backend/shared/go.sum +++ b/backend/shared/go.sum @@ -1,3 +1,7 @@ +github.com/MicahParks/jwkset v0.11.0 h1:yc0zG+jCvZpWgFDFmvs8/8jqqVBG9oyIbmBtmjOhoyQ= +github.com/MicahParks/jwkset v0.11.0/go.mod h1:U2oRhRaLgDCLjtpGL2GseNKGmZtLs/3O7p+OZaL5vo0= +github.com/MicahParks/keyfunc/v3 v3.7.0 h1:pdafUNyq+p3ZlvjJX1HWFP7MA3+cLpDtg69U3kITJGM= +github.com/MicahParks/keyfunc/v3 v3.7.0/go.mod h1:z66bkCviwqfg2YUp+Jcc/xRE9IXLcMq6DrgV/+Htru0= github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= @@ -23,6 +27,8 @@ github.com/gofiber/fiber/v3 v3.0.0-beta.3 h1:7Q2I+HsIqnIEEDB+9oe7Gadpakh6ZLhXpTY github.com/gofiber/fiber/v3 v3.0.0-beta.3/go.mod h1:kcMur0Dxqk91R7p4vxEpJfDWZ9u5IfvrtQc8Bvv/JmY= github.com/gofiber/utils/v2 v2.0.0-beta.4 h1:1gjbVFFwVwUb9arPcqiB6iEjHBwo7cHsyS41NeIW3co= github.com/gofiber/utils/v2 v2.0.0-beta.4/go.mod h1:sdRsPU1FXX6YiDGGxd+q2aPJRMzpsxdzCXo9dz+xtOY= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -30,6 +36,7 @@ github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+ github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= @@ -54,6 +61,7 @@ github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65 github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= @@ -76,8 +84,12 @@ golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY= +golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/backend/shared/infra/auth/zitadel/client.go b/backend/shared/infra/auth/zitadel/client.go index 45f24e2..1c6bd44 100644 --- a/backend/shared/infra/auth/zitadel/client.go +++ b/backend/shared/infra/auth/zitadel/client.go @@ -4,17 +4,28 @@ package zitadel import ( "context" "encoding/json" + "errors" "fmt" "net/http" + "net/url" + "strings" + "sync" "time" + + "github.com/MicahParks/keyfunc/v3" + "github.com/golang-jwt/jwt/v5" ) // Config holds the configuration for the Zitadel client. type Config struct { - BaseURL string - ProjectID string - AdminToken string - Timeout time.Duration + BaseURL string + ProjectID string + AdminToken string + ClientID string + ClientSecret string + Issuer string + Audience string + Timeout time.Duration } // DefaultConfig returns a default configuration. @@ -29,6 +40,13 @@ type Client struct { config Config httpClient *http.Client jwksCache *JWKSCache + + discoveryMu sync.Mutex + discoveryCached discoveryDocument + discoveryFetched time.Time + jwksMu sync.Mutex + jwksURL string + jwks keyfunc.Keyfunc } // JWKSCache caches the JSON Web Key Set for token validation. @@ -58,6 +76,13 @@ func NewClient(config Config) *Client { } } +// ValidationOptions controls token validation checks. +type ValidationOptions struct { + Issuer string + Audience string + RequiredClaims []string +} + // AuthClaims represents the claims extracted from a validated JWT. type AuthClaims struct { Subject string `json:"sub"` @@ -68,7 +93,8 @@ type AuthClaims struct { Issuer string `json:"iss"` IssuedAt int64 `json:"iat"` ExpiresAt int64 `json:"exp"` - MFAVerified bool `json:"amr"` + AMR []string `json:"amr"` + MFAVerified bool `json:"mfa_verified"` } // TokenResponse represents a token response from Zitadel. @@ -88,22 +114,122 @@ type UserInfo struct { Roles []string `json:"roles"` } +type discoveryDocument struct { + Issuer string `json:"issuer"` + JWKSURI string `json:"jwks_uri"` + TokenEndpoint string `json:"token_endpoint"` + RevocationEndpoint string `json:"revocation_endpoint"` +} + // ValidateToken validates a JWT token and returns the claims. -// This is a placeholder implementation that should be replaced with actual JWT validation. -func (c *Client) ValidateToken(ctx context.Context, token string) (*AuthClaims, error) { - // TODO: Implement actual JWT validation with JWKS - // This is a placeholder that should: - // 1. Fetch JWKS from Zitadel - // 2. Parse and validate the JWT - // 3. Verify signature, expiration, issuer, audience - // 4. Return the claims - return nil, fmt.Errorf("not implemented: token validation requires JWKS integration") +func (c *Client) ValidateToken(ctx context.Context, token string, opts ValidationOptions) (*AuthClaims, error) { + discovery, err := c.getDiscovery(ctx) + if err != nil { + return nil, err + } + + jwks, err := c.getJWKS(discovery.JWKSURI) + if err != nil { + return nil, err + } + + parsed, err := jwt.ParseWithClaims(token, jwt.MapClaims{}, jwks.KeyfuncCtx(ctx)) + if err != nil { + return nil, fmt.Errorf("token parse failed: %w", err) + } + + if !parsed.Valid { + return nil, errors.New("token is invalid") + } + + claims, ok := parsed.Claims.(jwt.MapClaims) + if !ok { + return nil, errors.New("invalid token claims") + } + + issuer := opts.Issuer + if issuer == "" { + issuer = c.config.Issuer + } + if issuer == "" { + issuer = discovery.Issuer + } + if issuer != "" { + if claimIssuer, _ := claims["iss"].(string); claimIssuer != issuer { + return nil, fmt.Errorf("invalid issuer: %s", claimIssuer) + } + } + + audience := opts.Audience + if audience == "" { + audience = c.config.Audience + } + if audience != "" { + if !audienceContains(claims["aud"], audience) { + return nil, fmt.Errorf("invalid audience: %s", audience) + } + } + + for _, claimName := range opts.RequiredClaims { + if !hasNonEmptyClaim(claims, claimName) { + return nil, fmt.Errorf("missing required claim: %s", claimName) + } + } + + authClaims := &AuthClaims{ + Subject: getStringClaim(claims, "sub"), + Email: getStringClaim(claims, "email"), + Name: getStringClaim(claims, "name"), + Roles: parseRoles(claims["urn:zitadel:iam:org:project:roles"]), + Audience: parseStringSlice(claims["aud"]), + Issuer: getStringClaim(claims, "iss"), + IssuedAt: getInt64Claim(claims, "iat"), + ExpiresAt: getInt64Claim(claims, "exp"), + AMR: parseStringSlice(claims["amr"]), + } + authClaims.MFAVerified = containsAny(authClaims.AMR, "mfa", "otp") + + return authClaims, nil } // RefreshToken refreshes an access token using a refresh token. func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (*TokenResponse, error) { - // TODO: Implement token refresh - return nil, fmt.Errorf("not implemented: token refresh") + discovery, err := c.getDiscovery(ctx) + if err != nil { + return nil, err + } + if discovery.TokenEndpoint == "" { + return nil, errors.New("token endpoint not available") + } + + form := url.Values{} + form.Set("grant_type", "refresh_token") + form.Set("refresh_token", refreshToken) + form.Set("client_id", c.config.ClientID) + form.Set("client_secret", c.config.ClientSecret) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, discovery.TokenEndpoint, strings.NewReader(form.Encode())) + if err != nil { + return nil, fmt.Errorf("failed to create token refresh request: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to refresh token: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to refresh token: status %d", resp.StatusCode) + } + + var tokenResponse TokenResponse + if err := json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil { + return nil, fmt.Errorf("failed to decode token response: %w", err) + } + + return &tokenResponse, nil } // GetUserInfo retrieves user information using an access token. @@ -137,6 +263,187 @@ func (c *Client) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo // RevokeToken revokes a token. func (c *Client) RevokeToken(ctx context.Context, token string) error { - // TODO: Implement token revocation - return fmt.Errorf("not implemented: token revocation") + discovery, err := c.getDiscovery(ctx) + if err != nil { + return err + } + if discovery.RevocationEndpoint == "" { + return errors.New("revocation endpoint not available") + } + + form := url.Values{} + form.Set("token", token) + form.Set("client_id", c.config.ClientID) + form.Set("client_secret", c.config.ClientSecret) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, discovery.RevocationEndpoint, strings.NewReader(form.Encode())) + if err != nil { + return fmt.Errorf("failed to create revocation request: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to revoke token: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent { + return fmt.Errorf("failed to revoke token: status %d", resp.StatusCode) + } + + return nil +} + +func (c *Client) getDiscovery(ctx context.Context) (discoveryDocument, error) { + c.discoveryMu.Lock() + defer c.discoveryMu.Unlock() + + if c.discoveryCached.Issuer != "" && time.Since(c.discoveryFetched) < 10*time.Minute { + return c.discoveryCached, nil + } + + url := fmt.Sprintf("%s/.well-known/openid-configuration", c.config.BaseURL) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return discoveryDocument{}, fmt.Errorf("failed to create discovery request: %w", err) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return discoveryDocument{}, fmt.Errorf("failed to fetch discovery document: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return discoveryDocument{}, fmt.Errorf("failed to fetch discovery document: status %d", resp.StatusCode) + } + + var doc discoveryDocument + if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil { + return discoveryDocument{}, fmt.Errorf("failed to decode discovery document: %w", err) + } + + c.discoveryCached = doc + c.discoveryFetched = time.Now() + return doc, nil +} + +func (c *Client) getJWKS(jwksURL string) (keyfunc.Keyfunc, error) { + c.jwksMu.Lock() + defer c.jwksMu.Unlock() + + if c.jwks != nil && c.jwksURL == jwksURL { + return c.jwks, nil + } + + jwks, err := keyfunc.NewDefaultCtx(context.Background(), []string{jwksURL}) + if err != nil { + return nil, fmt.Errorf("failed to create jwks: %w", err) + } + + c.jwks = jwks + c.jwksURL = jwksURL + return jwks, nil +} + +func getStringClaim(claims jwt.MapClaims, key string) string { + if value, ok := claims[key].(string); ok { + return value + } + return "" +} + +func getInt64Claim(claims jwt.MapClaims, key string) int64 { + switch value := claims[key].(type) { + case float64: + return int64(value) + case int64: + return value + case json.Number: + parsed, _ := value.Int64() + return parsed + default: + return 0 + } +} + +func hasNonEmptyClaim(claims jwt.MapClaims, key string) bool { + value, ok := claims[key] + if !ok || value == nil { + return false + } + switch typed := value.(type) { + case string: + return typed != "" + case []interface{}: + return len(typed) > 0 + case []string: + return len(typed) > 0 + case map[string]interface{}: + return len(typed) > 0 + default: + return true + } +} + +func parseStringSlice(value interface{}) []string { + if value == nil { + return nil + } + switch typed := value.(type) { + case []string: + return typed + case []interface{}: + result := make([]string, 0, len(typed)) + for _, item := range typed { + if s, ok := item.(string); ok { + result = append(result, s) + } + } + return result + case string: + if typed == "" { + return nil + } + return []string{typed} + default: + return nil + } +} + +func parseRoles(value interface{}) []string { + switch typed := value.(type) { + case map[string]interface{}: + roles := make([]string, 0, len(typed)) + for role := range typed { + roles = append(roles, role) + } + return roles + default: + return parseStringSlice(value) + } +} + +func audienceContains(value interface{}, expected string) bool { + if expected == "" { + return true + } + for _, aud := range parseStringSlice(value) { + if aud == expected { + return true + } + } + return false +} + +func containsAny(values []string, targets ...string) bool { + for _, value := range values { + for _, target := range targets { + if value == target { + return true + } + } + } + return false } diff --git a/backend/shared/infra/auth/zitadel/client_test.go b/backend/shared/infra/auth/zitadel/client_test.go index 673e829..033c85b 100644 --- a/backend/shared/infra/auth/zitadel/client_test.go +++ b/backend/shared/infra/auth/zitadel/client_test.go @@ -1,18 +1,112 @@ package zitadel -// Tests for Zitadel client user info calls and placeholder methods. +// Tests for Zitadel client user info calls, JWT validation, and token operations. import ( "context" + "crypto/rand" + "crypto/rsa" + "encoding/base64" "encoding/json" + "io" + "math/big" "net/http" "net/http/httptest" + "net/url" "testing" "time" + "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/require" ) +type jwksKey struct { + Kty string `json:"kty"` + Kid string `json:"kid"` + Use string `json:"use"` + Alg string `json:"alg"` + N string `json:"n"` + E string `json:"e"` +} + +type jwks struct { + Keys []jwksKey `json:"keys"` +} + +func generateJWKS(t *testing.T) (*rsa.PrivateKey, jwks, string) { + t.Helper() + key, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + kid := "kid-1" + n := base64.RawURLEncoding.EncodeToString(key.N.Bytes()) + e := base64.RawURLEncoding.EncodeToString(big.NewInt(int64(key.E)).Bytes()) + + return key, jwks{ + Keys: []jwksKey{ + { + Kty: "RSA", + Kid: kid, + Use: "sig", + Alg: "RS256", + N: n, + E: e, + }, + }, + }, kid +} + +func signToken(t *testing.T, key *rsa.PrivateKey, kid string, claims jwt.MapClaims) string { + t.Helper() + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["kid"] = kid + signed, err := token.SignedString(key) + require.NoError(t, err) + return signed +} + +func newOIDCServer(t *testing.T, jwksDoc jwks) *httptest.Server { + t.Helper() + var baseURL string + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/.well-known/openid-configuration": + _ = json.NewEncoder(w).Encode(map[string]string{ + "issuer": baseURL, + "jwks_uri": baseURL + "/jwks", + "token_endpoint": baseURL + "/token", + "revocation_endpoint": baseURL + "/revoke", + }) + case "/jwks": + _ = json.NewEncoder(w).Encode(jwksDoc) + case "/token": + if err := r.ParseForm(); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + if r.FormValue("grant_type") != "refresh_token" { + w.WriteHeader(http.StatusBadRequest) + return + } + _ = json.NewEncoder(w).Encode(TokenResponse{ + AccessToken: "new-access", + RefreshToken: "new-refresh", + TokenType: "Bearer", + ExpiresIn: 3600, + }) + case "/revoke": + w.WriteHeader(http.StatusOK) + default: + w.WriteHeader(http.StatusNotFound) + } + }) + + server := httptest.NewServer(handler) + baseURL = server.URL + return server +} + // TestGetUserInfo_Success verifies user info retrieval on a 200 response. func TestGetUserInfo_Success(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -46,15 +140,161 @@ func TestGetUserInfo_NonOK(t *testing.T) { require.Error(t, err) } -// TestClient_NotImplementedMethods verifies placeholder methods return errors. -func TestClient_NotImplementedMethods(t *testing.T) { - client := NewClient(DefaultConfig()) +// TestValidateToken_Success verifies JWT validation using JWKS and discovery. +func TestValidateToken_Success(t *testing.T) { + key, jwksDoc, kid := generateJWKS(t) + server := newOIDCServer(t, jwksDoc) + defer server.Close() - _, err := client.ValidateToken(context.Background(), "token") + claims := jwt.MapClaims{ + "sub": "user-1", + "email": "a@b.com", + "name": "Alice", + "aud": "client-1", + "iss": server.URL, + "iat": time.Now().Unix(), + "exp": time.Now().Add(10 * time.Minute).Unix(), + "urn:zitadel:iam:org:project:roles": map[string]interface{}{ + "admin": map[string]interface{}{}, + }, + "amr": []string{"otp"}, + } + token := signToken(t, key, kid, claims) + + client := NewClient(Config{BaseURL: server.URL, Timeout: 2 * time.Second}) + parsed, err := client.ValidateToken(context.Background(), token, ValidationOptions{ + Issuer: server.URL, + Audience: "client-1", + }) + require.NoError(t, err) + require.Equal(t, "user-1", parsed.Subject) + require.True(t, parsed.MFAVerified) + require.Contains(t, parsed.Roles, "admin") +} + +// TestValidateToken_InvalidAudience verifies audience validation. +func TestValidateToken_InvalidAudience(t *testing.T) { + key, jwksDoc, kid := generateJWKS(t) + server := newOIDCServer(t, jwksDoc) + defer server.Close() + + token := signToken(t, key, kid, jwt.MapClaims{ + "sub": "user-1", + "aud": "other", + "iss": server.URL, + "exp": time.Now().Add(5 * time.Minute).Unix(), + }) + + client := NewClient(Config{BaseURL: server.URL, Timeout: 2 * time.Second}) + _, err := client.ValidateToken(context.Background(), token, ValidationOptions{ + Issuer: server.URL, + Audience: "client-1", + }) require.Error(t, err) +} + +// TestValidateToken_MissingRequiredClaim verifies required claim enforcement. +func TestValidateToken_MissingRequiredClaim(t *testing.T) { + key, jwksDoc, kid := generateJWKS(t) + server := newOIDCServer(t, jwksDoc) + defer server.Close() + + token := signToken(t, key, kid, jwt.MapClaims{ + "sub": "user-1", + "aud": "client-1", + "iss": server.URL, + "exp": time.Now().Add(5 * time.Minute).Unix(), + }) - _, err = client.RefreshToken(context.Background(), "refresh") + client := NewClient(Config{BaseURL: server.URL, Timeout: 2 * time.Second}) + _, err := client.ValidateToken(context.Background(), token, ValidationOptions{ + Issuer: server.URL, + Audience: "client-1", + RequiredClaims: []string{"email"}, + }) require.Error(t, err) +} + +// TestRefreshToken_Success verifies token refresh against discovery endpoint. +func TestRefreshToken_Success(t *testing.T) { + key, jwksDoc, _ := generateJWKS(t) + _ = key + server := newOIDCServer(t, jwksDoc) + defer server.Close() + + client := NewClient(Config{ + BaseURL: server.URL, + ClientID: "client", + ClientSecret: "secret", + Timeout: 2 * time.Second, + }) + + resp, err := client.RefreshToken(context.Background(), "refresh-token") + require.NoError(t, err) + require.Equal(t, "new-access", resp.AccessToken) +} + +// TestRevokeToken_Success verifies token revocation against discovery endpoint. +func TestRevokeToken_Success(t *testing.T) { + key, jwksDoc, _ := generateJWKS(t) + _ = key + server := newOIDCServer(t, jwksDoc) + defer server.Close() + + client := NewClient(Config{ + BaseURL: server.URL, + ClientID: "client", + ClientSecret: "secret", + Timeout: 2 * time.Second, + }) + + err := client.RevokeToken(context.Background(), "token") + require.NoError(t, err) +} + +// TestRefreshToken_UsesForm ensures refresh uses form-encoded requests. +func TestRefreshToken_UsesForm(t *testing.T) { + key, jwksDoc, _ := generateJWKS(t) + _ = key + + var captured string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/.well-known/openid-configuration" { + _ = json.NewEncoder(w).Encode(map[string]string{ + "issuer": "test", + "jwks_uri": serverURL(r), + "token_endpoint": serverURL(r) + "/token", + }) + return + } + if r.URL.Path == "/jwks" { + _ = json.NewEncoder(w).Encode(jwksDoc) + return + } + if r.URL.Path == "/token" { + body, _ := io.ReadAll(r.Body) + captured = string(body) + _ = json.NewEncoder(w).Encode(TokenResponse{ + AccessToken: "access", + }) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + client := NewClient(Config{ + BaseURL: server.URL, + ClientID: "client", + ClientSecret: "secret", + Timeout: 2 * time.Second, + }) + _, err := client.RefreshToken(context.Background(), "refresh-token") + require.NoError(t, err) + require.Contains(t, captured, url.QueryEscape("refresh-token")) + require.Contains(t, captured, "client_id=client") +} - require.Error(t, client.RevokeToken(context.Background(), "token")) +func serverURL(r *http.Request) string { + return "http://" + r.Host } diff --git a/backend/shared/infra/auth/zitadel/middleware.go b/backend/shared/infra/auth/zitadel/middleware.go index 67ed0ab..f160e12 100644 --- a/backend/shared/infra/auth/zitadel/middleware.go +++ b/backend/shared/infra/auth/zitadel/middleware.go @@ -36,7 +36,7 @@ type JWTMiddlewareConfig struct { // TokenValidator defines the interface for validating JWT tokens. type TokenValidator interface { - ValidateToken(ctx context.Context, token string) (*AuthClaims, error) + ValidateToken(ctx context.Context, token string, opts ValidationOptions) (*AuthClaims, error) } // JWTMiddleware creates a Fiber middleware for JWT validation. @@ -69,7 +69,11 @@ func JWTMiddleware(config JWTMiddlewareConfig) fiber.Handler { // Validate token ctx := c.Context() - claims, err := config.Client.ValidateToken(ctx, tokenString) + claims, err := config.Client.ValidateToken(ctx, tokenString, ValidationOptions{ + Issuer: config.Issuer, + Audience: config.Audience, + RequiredClaims: config.RequiredClaims, + }) if err != nil { return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{ "error": true, diff --git a/backend/shared/infra/auth/zitadel/middleware_test.go b/backend/shared/infra/auth/zitadel/middleware_test.go index f535272..cb519fe 100644 --- a/backend/shared/infra/auth/zitadel/middleware_test.go +++ b/backend/shared/infra/auth/zitadel/middleware_test.go @@ -19,7 +19,7 @@ type fakeValidator struct { called int } -func (f *fakeValidator) ValidateToken(ctx context.Context, token string) (*AuthClaims, error) { +func (f *fakeValidator) ValidateToken(ctx context.Context, token string, _ ValidationOptions) (*AuthClaims, error) { f.called++ return f.claims, f.err }