You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
450 lines
11 KiB
Go
450 lines
11 KiB
Go
// Package zitadel provides a Zitadel authentication client for the KnowFoolery application.
|
|
package zitadel
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/MicahParks/keyfunc/v3"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
)
|
|
|
|
// Config holds the settings the Client uses to reach a Zitadel instance and
// to check the tokens it issues.
type Config struct {
	// BaseURL is the root URL of the Zitadel instance. The discovery
	// (/.well-known/openid-configuration) and userinfo (/oidc/v1/userinfo)
	// paths are appended to it.
	BaseURL string
	// ProjectID identifies the Zitadel project.
	// NOTE(review): not read by the code in this file; presumably consumed
	// elsewhere — confirm.
	ProjectID string
	// AdminToken is an administrative credential.
	// NOTE(review): not read by the code in this file — confirm its consumer.
	AdminToken string
	// ClientID and ClientSecret are the OAuth client credentials sent in the
	// form body to the token and revocation endpoints.
	ClientID     string
	ClientSecret string
	// Issuer is the expected "iss" claim used by ValidateToken when
	// ValidationOptions.Issuer is empty; if this is empty too, the issuer
	// from the discovery document is used.
	Issuer string
	// Audience is the value that must appear in a token's "aud" claim when
	// ValidationOptions.Audience is empty. Empty disables the check.
	Audience string
	// Timeout is used as the underlying http.Client timeout.
	Timeout time.Duration
}
|
|
|
|
// DefaultConfig returns a default configuration.
|
|
func DefaultConfig() Config {
|
|
return Config{
|
|
Timeout: 10 * time.Second,
|
|
}
|
|
}
|
|
|
|
// Client provides access to Zitadel authentication services.
|
|
type Client struct {
|
|
config Config
|
|
httpClient *http.Client
|
|
jwksCache *JWKSCache
|
|
|
|
discoveryMu sync.Mutex
|
|
discoveryCached discoveryDocument
|
|
discoveryFetched time.Time
|
|
jwksMu sync.Mutex
|
|
jwksURL string
|
|
jwks keyfunc.Keyfunc
|
|
}
|
|
|
|
// JWKSCache holds a key map and the intended lifetime of cached keys.
//
// NOTE(review): nothing in this file reads keys or duration. ValidateToken
// obtains signing keys through the keyfunc-based Client.jwks instead. Either
// wire this cache up or remove it.
type JWKSCache struct {
	keys     map[string]interface{}
	duration time.Duration
}
|
|
|
|
// NewJWKSCache creates a new JWKS cache.
|
|
func NewJWKSCache(cacheDuration time.Duration) *JWKSCache {
|
|
return &JWKSCache{
|
|
keys: make(map[string]interface{}),
|
|
duration: cacheDuration,
|
|
}
|
|
}
|
|
|
|
// NewClient creates a new Zitadel client.
|
|
func NewClient(config Config) *Client {
|
|
return &Client{
|
|
config: config,
|
|
httpClient: &http.Client{
|
|
Timeout: config.Timeout,
|
|
},
|
|
jwksCache: NewJWKSCache(5 * time.Minute),
|
|
}
|
|
}
|
|
|
|
// ValidationOptions tunes the claim checks ValidateToken performs after
// signature verification. An empty Issuer or Audience falls back to the
// Client's Config (and, for Issuer, then to the discovery document).
type ValidationOptions struct {
	// Issuer is the expected "iss" claim for this call.
	Issuer string
	// Audience must appear in the token's "aud" claim for this call.
	Audience string
	// RequiredClaims names claims that must be present and non-empty.
	RequiredClaims []string
}
|
|
|
|
// AuthClaims is the normalized set of claims ValidateToken extracts from a
// verified JWT. The json tags mirror the JWT claim names. MFAVerified is
// derived from AMR rather than read from the token.
type AuthClaims struct {
	Subject     string   `json:"sub"`
	Email       string   `json:"email"`
	Name        string   `json:"name"`
	Roles       []string `json:"urn:zitadel:iam:org:project:roles"` // Zitadel project role names
	Audience    []string `json:"aud"`
	Issuer      string   `json:"iss"`
	IssuedAt    int64    `json:"iat"`          // Unix seconds
	ExpiresAt   int64    `json:"exp"`          // Unix seconds
	AMR         []string `json:"amr"`          // authentication methods references
	MFAVerified bool     `json:"mfa_verified"` // true when AMR contains "mfa" or "otp"
}
|
|
|
|
// TokenResponse is the JSON body returned by the Zitadel token endpoint.
type TokenResponse struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`
	TokenType    string `json:"token_type"`
	// ExpiresIn is the access-token lifetime in seconds (RFC 6749 §5.1).
	ExpiresIn int `json:"expires_in"`
}
|
|
|
|
// UserInfo is the subset of the OIDC userinfo response this package uses.
type UserInfo struct {
	ID       string `json:"sub"`
	Email    string `json:"email"`
	Name     string `json:"name"`
	Verified bool   `json:"email_verified"`
	// Roles is decoded from a plain "roles" member of the response.
	Roles []string `json:"roles"`
}
|
|
|
|
// discoveryDocument holds the OIDC discovery metadata fields this client
// consumes from /.well-known/openid-configuration.
type discoveryDocument struct {
	Issuer             string `json:"issuer"`
	JWKSURI            string `json:"jwks_uri"`
	TokenEndpoint      string `json:"token_endpoint"`
	RevocationEndpoint string `json:"revocation_endpoint"`
}
|
|
|
|
// ValidateToken validates a JWT token and returns the claims.
|
|
func (c *Client) ValidateToken(ctx context.Context, token string, opts ValidationOptions) (*AuthClaims, error) {
|
|
discovery, err := c.getDiscovery(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
jwks, err := c.getJWKS(discovery.JWKSURI)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
parsed, err := jwt.ParseWithClaims(token, jwt.MapClaims{}, jwks.KeyfuncCtx(ctx))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("token parse failed: %w", err)
|
|
}
|
|
|
|
if !parsed.Valid {
|
|
return nil, errors.New("token is invalid")
|
|
}
|
|
|
|
claims, ok := parsed.Claims.(jwt.MapClaims)
|
|
if !ok {
|
|
return nil, errors.New("invalid token claims")
|
|
}
|
|
|
|
issuer := opts.Issuer
|
|
if issuer == "" {
|
|
issuer = c.config.Issuer
|
|
}
|
|
if issuer == "" {
|
|
issuer = discovery.Issuer
|
|
}
|
|
if issuer != "" {
|
|
if claimIssuer, _ := claims["iss"].(string); claimIssuer != issuer {
|
|
return nil, fmt.Errorf("invalid issuer: %s", claimIssuer)
|
|
}
|
|
}
|
|
|
|
audience := opts.Audience
|
|
if audience == "" {
|
|
audience = c.config.Audience
|
|
}
|
|
if audience != "" {
|
|
if !audienceContains(claims["aud"], audience) {
|
|
return nil, fmt.Errorf("invalid audience: %s", audience)
|
|
}
|
|
}
|
|
|
|
for _, claimName := range opts.RequiredClaims {
|
|
if !hasNonEmptyClaim(claims, claimName) {
|
|
return nil, fmt.Errorf("missing required claim: %s", claimName)
|
|
}
|
|
}
|
|
|
|
authClaims := &AuthClaims{
|
|
Subject: getStringClaim(claims, "sub"),
|
|
Email: getStringClaim(claims, "email"),
|
|
Name: getStringClaim(claims, "name"),
|
|
Roles: parseRoles(claims["urn:zitadel:iam:org:project:roles"]),
|
|
Audience: parseStringSlice(claims["aud"]),
|
|
Issuer: getStringClaim(claims, "iss"),
|
|
IssuedAt: getInt64Claim(claims, "iat"),
|
|
ExpiresAt: getInt64Claim(claims, "exp"),
|
|
AMR: parseStringSlice(claims["amr"]),
|
|
}
|
|
authClaims.MFAVerified = containsAny(authClaims.AMR, "mfa", "otp")
|
|
|
|
return authClaims, nil
|
|
}
|
|
|
|
// RefreshToken refreshes an access token using a refresh token.
|
|
func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (*TokenResponse, error) {
|
|
discovery, err := c.getDiscovery(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if discovery.TokenEndpoint == "" {
|
|
return nil, errors.New("token endpoint not available")
|
|
}
|
|
|
|
form := url.Values{}
|
|
form.Set("grant_type", "refresh_token")
|
|
form.Set("refresh_token", refreshToken)
|
|
form.Set("client_id", c.config.ClientID)
|
|
form.Set("client_secret", c.config.ClientSecret)
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, discovery.TokenEndpoint, strings.NewReader(form.Encode()))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create token refresh request: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
|
|
resp, err := c.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to refresh token: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("failed to refresh token: status %d", resp.StatusCode)
|
|
}
|
|
|
|
var tokenResponse TokenResponse
|
|
if err := json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil {
|
|
return nil, fmt.Errorf("failed to decode token response: %w", err)
|
|
}
|
|
|
|
return &tokenResponse, nil
|
|
}
|
|
|
|
// GetUserInfo retrieves user information using an access token.
|
|
func (c *Client) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo, error) {
|
|
url := fmt.Sprintf("%s/oidc/v1/userinfo", c.config.BaseURL)
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
|
|
|
|
resp, err := c.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get user info: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("failed to get user info: status %d", resp.StatusCode)
|
|
}
|
|
|
|
var userInfo UserInfo
|
|
if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
|
|
return nil, fmt.Errorf("failed to decode user info: %w", err)
|
|
}
|
|
|
|
return &userInfo, nil
|
|
}
|
|
|
|
// RevokeToken revokes a token.
|
|
func (c *Client) RevokeToken(ctx context.Context, token string) error {
|
|
discovery, err := c.getDiscovery(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if discovery.RevocationEndpoint == "" {
|
|
return errors.New("revocation endpoint not available")
|
|
}
|
|
|
|
form := url.Values{}
|
|
form.Set("token", token)
|
|
form.Set("client_id", c.config.ClientID)
|
|
form.Set("client_secret", c.config.ClientSecret)
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, discovery.RevocationEndpoint, strings.NewReader(form.Encode()))
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create revocation request: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
|
|
resp, err := c.httpClient.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to revoke token: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent {
|
|
return fmt.Errorf("failed to revoke token: status %d", resp.StatusCode)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Client) getDiscovery(ctx context.Context) (discoveryDocument, error) {
|
|
c.discoveryMu.Lock()
|
|
defer c.discoveryMu.Unlock()
|
|
|
|
if c.discoveryCached.Issuer != "" && time.Since(c.discoveryFetched) < 10*time.Minute {
|
|
return c.discoveryCached, nil
|
|
}
|
|
|
|
url := fmt.Sprintf("%s/.well-known/openid-configuration", c.config.BaseURL)
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
|
if err != nil {
|
|
return discoveryDocument{}, fmt.Errorf("failed to create discovery request: %w", err)
|
|
}
|
|
|
|
resp, err := c.httpClient.Do(req)
|
|
if err != nil {
|
|
return discoveryDocument{}, fmt.Errorf("failed to fetch discovery document: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return discoveryDocument{}, fmt.Errorf("failed to fetch discovery document: status %d", resp.StatusCode)
|
|
}
|
|
|
|
var doc discoveryDocument
|
|
if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil {
|
|
return discoveryDocument{}, fmt.Errorf("failed to decode discovery document: %w", err)
|
|
}
|
|
|
|
c.discoveryCached = doc
|
|
c.discoveryFetched = time.Now()
|
|
return doc, nil
|
|
}
|
|
|
|
func (c *Client) getJWKS(jwksURL string) (keyfunc.Keyfunc, error) {
|
|
c.jwksMu.Lock()
|
|
defer c.jwksMu.Unlock()
|
|
|
|
if c.jwks != nil && c.jwksURL == jwksURL {
|
|
return c.jwks, nil
|
|
}
|
|
|
|
jwks, err := keyfunc.NewDefaultCtx(context.Background(), []string{jwksURL})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create jwks: %w", err)
|
|
}
|
|
|
|
c.jwks = jwks
|
|
c.jwksURL = jwksURL
|
|
return jwks, nil
|
|
}
|
|
|
|
func getStringClaim(claims jwt.MapClaims, key string) string {
|
|
if value, ok := claims[key].(string); ok {
|
|
return value
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func getInt64Claim(claims jwt.MapClaims, key string) int64 {
|
|
switch value := claims[key].(type) {
|
|
case float64:
|
|
return int64(value)
|
|
case int64:
|
|
return value
|
|
case json.Number:
|
|
parsed, _ := value.Int64()
|
|
return parsed
|
|
default:
|
|
return 0
|
|
}
|
|
}
|
|
|
|
func hasNonEmptyClaim(claims jwt.MapClaims, key string) bool {
|
|
value, ok := claims[key]
|
|
if !ok || value == nil {
|
|
return false
|
|
}
|
|
switch typed := value.(type) {
|
|
case string:
|
|
return typed != ""
|
|
case []interface{}:
|
|
return len(typed) > 0
|
|
case []string:
|
|
return len(typed) > 0
|
|
case map[string]interface{}:
|
|
return len(typed) > 0
|
|
default:
|
|
return true
|
|
}
|
|
}
|
|
|
|
// parseStringSlice normalizes a decoded claim value into a []string:
//   - A []string is returned as-is (not copied).
//   - A []interface{} yields its string elements in order. Other elements
//     are dropped, and the result is non-nil even when empty.
//   - A non-empty string yields a one-element slice.
//   - nil, "", or any other type yields nil.
func parseStringSlice(value interface{}) []string {
	switch v := value.(type) {
	case []string:
		return v
	case string:
		if v != "" {
			return []string{v}
		}
		return nil
	case []interface{}:
		out := make([]string, 0, len(v))
		for i := range v {
			if s, isString := v[i].(string); isString {
				out = append(out, s)
			}
		}
		return out
	}
	return nil
}
|
|
|
|
func parseRoles(value interface{}) []string {
|
|
switch typed := value.(type) {
|
|
case map[string]interface{}:
|
|
roles := make([]string, 0, len(typed))
|
|
for role := range typed {
|
|
roles = append(roles, role)
|
|
}
|
|
return roles
|
|
default:
|
|
return parseStringSlice(value)
|
|
}
|
|
}
|
|
|
|
func audienceContains(value interface{}, expected string) bool {
|
|
if expected == "" {
|
|
return true
|
|
}
|
|
for _, aud := range parseStringSlice(value) {
|
|
if aud == expected {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// containsAny reports whether any element of values equals any of targets.
// It is false when either list is empty.
func containsAny(values []string, targets ...string) bool {
	for _, want := range targets {
		for _, have := range values {
			if have == want {
				return true
			}
		}
	}
	return false
}
|