// Package zitadel provides Zitadel authentication client for the KnowFoolery application. package zitadel import ( "context" "encoding/json" "errors" "fmt" "net/http" "net/url" "strings" "sync" "time" "github.com/MicahParks/keyfunc/v3" "github.com/golang-jwt/jwt/v5" ) // Config holds the configuration for the Zitadel client. type Config struct { BaseURL string ProjectID string AdminToken string ClientID string ClientSecret string Issuer string Audience string Timeout time.Duration } // DefaultConfig returns a default configuration. func DefaultConfig() Config { return Config{ Timeout: 10 * time.Second, } } // Client provides access to Zitadel authentication services. type Client struct { config Config httpClient *http.Client jwksCache *JWKSCache discoveryMu sync.Mutex discoveryCached discoveryDocument discoveryFetched time.Time jwksMu sync.Mutex jwksURL string jwks keyfunc.Keyfunc } // JWKSCache caches the JSON Web Key Set for token validation. type JWKSCache struct { //mu sync.RWMutex keys map[string]interface{} //expiry time.Time duration time.Duration } // NewJWKSCache creates a new JWKS cache. func NewJWKSCache(cacheDuration time.Duration) *JWKSCache { return &JWKSCache{ keys: make(map[string]interface{}), duration: cacheDuration, } } // NewClient creates a new Zitadel client. func NewClient(config Config) *Client { return &Client{ config: config, httpClient: &http.Client{ Timeout: config.Timeout, }, jwksCache: NewJWKSCache(5 * time.Minute), } } // ValidationOptions controls token validation checks. type ValidationOptions struct { Issuer string Audience string RequiredClaims []string } // AuthClaims represents the claims extracted from a validated JWT. 
type AuthClaims struct {
	// Identity of the authenticated user.
	Subject string `json:"sub"`
	Email   string `json:"email"`
	Name    string `json:"name"`

	// Roles holds the role keys taken from Zitadel's project roles claim.
	Roles []string `json:"urn:zitadel:iam:org:project:roles"`

	// Registered JWT claims. IssuedAt and ExpiresAt are NumericDate values
	// (seconds since the Unix epoch, RFC 7519).
	Audience  []string `json:"aud"`
	Issuer    string   `json:"iss"`
	IssuedAt  int64    `json:"iat"`
	ExpiresAt int64    `json:"exp"`

	// AMR is the authentication methods reference; ValidateToken derives
	// MFAVerified from it.
	AMR         []string `json:"amr"`
	MFAVerified bool     `json:"mfa_verified"`
}

type (
	// TokenResponse is the JSON body returned by Zitadel's token endpoint.
	TokenResponse struct {
		AccessToken  string `json:"access_token"`
		RefreshToken string `json:"refresh_token"`
		TokenType    string `json:"token_type"`
		ExpiresIn    int    `json:"expires_in"`
	}

	// UserInfo is the JSON body returned by Zitadel's userinfo endpoint.
	UserInfo struct {
		ID       string   `json:"sub"`
		Email    string   `json:"email"`
		Name     string   `json:"name"`
		Verified bool     `json:"email_verified"`
		Roles    []string `json:"roles"`
	}

	// discoveryDocument is the subset of the OpenID Connect discovery
	// document (.well-known/openid-configuration) that this client reads.
	discoveryDocument struct {
		Issuer             string `json:"issuer"`
		JWKSURI            string `json:"jwks_uri"`
		TokenEndpoint      string `json:"token_endpoint"`
		RevocationEndpoint string `json:"revocation_endpoint"`
	}
)

// ValidateToken validates a JWT token and returns the claims.
func (c *Client) ValidateToken(ctx context.Context, token string, opts ValidationOptions) (*AuthClaims, error) { discovery, err := c.getDiscovery(ctx) if err != nil { return nil, err } jwks, err := c.getJWKS(discovery.JWKSURI) if err != nil { return nil, err } parsed, err := jwt.ParseWithClaims(token, jwt.MapClaims{}, jwks.KeyfuncCtx(ctx)) if err != nil { return nil, fmt.Errorf("token parse failed: %w", err) } if !parsed.Valid { return nil, errors.New("token is invalid") } claims, ok := parsed.Claims.(jwt.MapClaims) if !ok { return nil, errors.New("invalid token claims") } issuer := opts.Issuer if issuer == "" { issuer = c.config.Issuer } if issuer == "" { issuer = discovery.Issuer } if issuer != "" { if claimIssuer, _ := claims["iss"].(string); claimIssuer != issuer { return nil, fmt.Errorf("invalid issuer: %s", claimIssuer) } } audience := opts.Audience if audience == "" { audience = c.config.Audience } if audience != "" { if !audienceContains(claims["aud"], audience) { return nil, fmt.Errorf("invalid audience: %s", audience) } } for _, claimName := range opts.RequiredClaims { if !hasNonEmptyClaim(claims, claimName) { return nil, fmt.Errorf("missing required claim: %s", claimName) } } authClaims := &AuthClaims{ Subject: getStringClaim(claims, "sub"), Email: getStringClaim(claims, "email"), Name: getStringClaim(claims, "name"), Roles: parseRoles(claims["urn:zitadel:iam:org:project:roles"]), Audience: parseStringSlice(claims["aud"]), Issuer: getStringClaim(claims, "iss"), IssuedAt: getInt64Claim(claims, "iat"), ExpiresAt: getInt64Claim(claims, "exp"), AMR: parseStringSlice(claims["amr"]), } authClaims.MFAVerified = containsAny(authClaims.AMR, "mfa", "otp") return authClaims, nil } // RefreshToken refreshes an access token using a refresh token. 
func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (*TokenResponse, error) { discovery, err := c.getDiscovery(ctx) if err != nil { return nil, err } if discovery.TokenEndpoint == "" { return nil, errors.New("token endpoint not available") } form := url.Values{} form.Set("grant_type", "refresh_token") form.Set("refresh_token", refreshToken) form.Set("client_id", c.config.ClientID) form.Set("client_secret", c.config.ClientSecret) req, err := http.NewRequestWithContext(ctx, http.MethodPost, discovery.TokenEndpoint, strings.NewReader(form.Encode())) if err != nil { return nil, fmt.Errorf("failed to create token refresh request: %w", err) } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") resp, err := c.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("failed to refresh token: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("failed to refresh token: status %d", resp.StatusCode) } var tokenResponse TokenResponse if err := json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil { return nil, fmt.Errorf("failed to decode token response: %w", err) } return &tokenResponse, nil } // GetUserInfo retrieves user information using an access token. 
func (c *Client) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo, error) { url := fmt.Sprintf("%s/oidc/v1/userinfo", c.config.BaseURL) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) resp, err := c.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("failed to get user info: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("failed to get user info: status %d", resp.StatusCode) } var userInfo UserInfo if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil { return nil, fmt.Errorf("failed to decode user info: %w", err) } return &userInfo, nil } // RevokeToken revokes a token. func (c *Client) RevokeToken(ctx context.Context, token string) error { discovery, err := c.getDiscovery(ctx) if err != nil { return err } if discovery.RevocationEndpoint == "" { return errors.New("revocation endpoint not available") } form := url.Values{} form.Set("token", token) form.Set("client_id", c.config.ClientID) form.Set("client_secret", c.config.ClientSecret) req, err := http.NewRequestWithContext(ctx, http.MethodPost, discovery.RevocationEndpoint, strings.NewReader(form.Encode())) if err != nil { return fmt.Errorf("failed to create revocation request: %w", err) } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") resp, err := c.httpClient.Do(req) if err != nil { return fmt.Errorf("failed to revoke token: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent { return fmt.Errorf("failed to revoke token: status %d", resp.StatusCode) } return nil } func (c *Client) getDiscovery(ctx context.Context) (discoveryDocument, error) { c.discoveryMu.Lock() defer c.discoveryMu.Unlock() if c.discoveryCached.Issuer != "" && time.Since(c.discoveryFetched) < 10*time.Minute { 
return c.discoveryCached, nil } url := fmt.Sprintf("%s/.well-known/openid-configuration", c.config.BaseURL) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return discoveryDocument{}, fmt.Errorf("failed to create discovery request: %w", err) } resp, err := c.httpClient.Do(req) if err != nil { return discoveryDocument{}, fmt.Errorf("failed to fetch discovery document: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return discoveryDocument{}, fmt.Errorf("failed to fetch discovery document: status %d", resp.StatusCode) } var doc discoveryDocument if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil { return discoveryDocument{}, fmt.Errorf("failed to decode discovery document: %w", err) } c.discoveryCached = doc c.discoveryFetched = time.Now() return doc, nil } func (c *Client) getJWKS(jwksURL string) (keyfunc.Keyfunc, error) { c.jwksMu.Lock() defer c.jwksMu.Unlock() if c.jwks != nil && c.jwksURL == jwksURL { return c.jwks, nil } jwks, err := keyfunc.NewDefaultCtx(context.Background(), []string{jwksURL}) if err != nil { return nil, fmt.Errorf("failed to create jwks: %w", err) } c.jwks = jwks c.jwksURL = jwksURL return jwks, nil } func getStringClaim(claims jwt.MapClaims, key string) string { if value, ok := claims[key].(string); ok { return value } return "" } func getInt64Claim(claims jwt.MapClaims, key string) int64 { switch value := claims[key].(type) { case float64: return int64(value) case int64: return value case json.Number: parsed, _ := value.Int64() return parsed default: return 0 } } func hasNonEmptyClaim(claims jwt.MapClaims, key string) bool { value, ok := claims[key] if !ok || value == nil { return false } switch typed := value.(type) { case string: return typed != "" case []interface{}: return len(typed) > 0 case []string: return len(typed) > 0 case map[string]interface{}: return len(typed) > 0 default: return true } } func parseStringSlice(value interface{}) []string { if value == 
nil { return nil } switch typed := value.(type) { case []string: return typed case []interface{}: result := make([]string, 0, len(typed)) for _, item := range typed { if s, ok := item.(string); ok { result = append(result, s) } } return result case string: if typed == "" { return nil } return []string{typed} default: return nil } } func parseRoles(value interface{}) []string { switch typed := value.(type) { case map[string]interface{}: roles := make([]string, 0, len(typed)) for role := range typed { roles = append(roles, role) } return roles default: return parseStringSlice(value) } } func audienceContains(value interface{}, expected string) bool { if expected == "" { return true } for _, aud := range parseStringSlice(value) { if aud == expected { return true } } return false } func containsAny(values []string, targets ...string) bool { for _, value := range values { for _, target := range targets { if value == target { return true } } } return false }