You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

451 lines
11 KiB
Go

// Package zitadel provides a Zitadel authentication client for the KnowFoolery application.
package zitadel
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/MicahParks/keyfunc/v3"
	"github.com/golang-jwt/jwt/v5"
)
// Config holds the configuration for the Zitadel client.
type Config struct {
// BaseURL is the root URL of the Zitadel instance. The discovery and
// userinfo request URLs are built by appending fixed paths to it.
BaseURL string
// ProjectID identifies the Zitadel project.
// NOTE(review): no Client method appears to read this field — confirm it
// is still needed.
ProjectID string
// AdminToken is a credential for administrative calls.
// NOTE(review): no Client method appears to read this field — confirm it
// is still needed.
AdminToken string
// ClientID is sent as the client_id form field on token refresh and
// token revocation requests.
ClientID string
// ClientSecret is sent as the client_secret form field on token refresh
// and token revocation requests.
ClientSecret string
// Issuer is the expected "iss" claim used by ValidateToken when
// ValidationOptions.Issuer is empty.
Issuer string
// Audience is the expected "aud" entry used by ValidateToken when
// ValidationOptions.Audience is empty.
Audience string
// Timeout is handed to the http.Client created by NewClient.
Timeout time.Duration
}
// DefaultConfig builds a Config in which only Timeout is populated (ten
// seconds). Every other field is left at its zero value for the caller to
// fill in.
func DefaultConfig() Config {
	var cfg Config
	cfg.Timeout = 10 * time.Second
	return cfg
}
// Client provides access to Zitadel authentication services.
//
// It contains mutexes, so it must be shared by pointer; NewClient returns a
// *Client for that reason.
type Client struct {
// config is the configuration passed to NewClient.
config Config
// httpClient issues the discovery, token, userinfo and revocation
// requests.
httpClient *http.Client
// jwksCache is created by NewClient.
// NOTE(review): no method appears to read it; token validation uses the
// jwks field below instead — confirm whether it can be removed.
jwksCache *JWKSCache
// discoveryMu guards discoveryCached and discoveryFetched.
discoveryMu sync.Mutex
// discoveryCached is the most recently fetched discovery document.
discoveryCached discoveryDocument
// discoveryFetched records when discoveryCached was stored.
discoveryFetched time.Time
// jwksMu guards jwksURL and jwks.
jwksMu sync.Mutex
// jwksURL is the JWKS URL that the jwks value below was created for.
jwksURL string
// jwks resolves verification keys during token parsing; it is created
// lazily by getJWKS.
jwks keyfunc.Keyfunc
}
// JWKSCache caches the JSON Web Key Set for token validation.
//
// NOTE(review): apart from NewJWKSCache, nothing appears to read or write
// these fields; token validation goes through the keyfunc value held on
// Client. The commented-out mu and expiry fields suggest an unfinished or
// retired implementation — confirm whether this type is still needed.
type JWKSCache struct {
//mu sync.RWMutex
// keys is initialised to an empty map by NewJWKSCache.
keys map[string]interface{}
//expiry time.Time
// duration is the cache duration supplied to NewJWKSCache.
duration time.Duration
}
// NewJWKSCache allocates a JWKSCache with an empty key map and records
// cacheDuration as its duration.
func NewJWKSCache(cacheDuration time.Duration) *JWKSCache {
	cache := new(JWKSCache)
	cache.keys = map[string]interface{}{}
	cache.duration = cacheDuration
	return cache
}
// NewClient creates a new Zitadel client from config.
//
// A zero or negative config.Timeout would give the underlying http.Client no
// timeout at all, so such values are replaced with the timeout from
// DefaultConfig. This keeps a Config built as a plain struct literal from
// producing requests that can hang indefinitely.
//
// The returned client also carries a JWKSCache with a five minute duration.
func NewClient(config Config) *Client {
	if config.Timeout <= 0 {
		config.Timeout = DefaultConfig().Timeout
	}
	return &Client{
		config:     config,
		httpClient: &http.Client{Timeout: config.Timeout},
		jwksCache:  NewJWKSCache(5 * time.Minute),
	}
}
// ValidationOptions controls token validation checks.
type ValidationOptions struct {
// Issuer is the expected "iss" claim. When empty, ValidateToken falls
// back to Config.Issuer and then to the issuer from the discovery
// document.
Issuer string
// Audience is a value that must appear in the "aud" claim. When empty,
// ValidateToken falls back to Config.Audience; if that is empty too, the
// audience is not checked.
Audience string
// RequiredClaims lists claim names that must be present and non-empty.
RequiredClaims []string
}
// AuthClaims represents the claims extracted from a validated JWT.
//
// ValidateToken fills the struct field by field from the parsed claim map;
// it does not decode the token into this struct via the JSON tags.
type AuthClaims struct {
// Subject is the "sub" claim.
Subject string `json:"sub"`
// Email is the "email" claim.
Email string `json:"email"`
// Name is the "name" claim.
Name string `json:"name"`
// Roles comes from the "urn:zitadel:iam:org:project:roles" claim. When
// that claim is a JSON object, the object's keys are the roles.
Roles []string `json:"urn:zitadel:iam:org:project:roles"`
// Audience is the "aud" claim, normalised to a slice.
Audience []string `json:"aud"`
// Issuer is the "iss" claim.
Issuer string `json:"iss"`
// IssuedAt is the "iat" claim converted to int64.
// NOTE(review): presumably Unix seconds, per JWT convention — confirm.
IssuedAt int64 `json:"iat"`
// ExpiresAt is the "exp" claim converted to int64.
// NOTE(review): presumably Unix seconds, per JWT convention — confirm.
ExpiresAt int64 `json:"exp"`
// AMR is the "amr" claim, normalised to a slice.
AMR []string `json:"amr"`
// MFAVerified is derived rather than read from the token: ValidateToken
// sets it when AMR contains "mfa" or "otp".
MFAVerified bool `json:"mfa_verified"`
}
// TokenResponse represents a token response from Zitadel.
//
// RefreshToken decodes the token endpoint's JSON body into this struct.
type TokenResponse struct {
// AccessToken is the "access_token" field of the response.
AccessToken string `json:"access_token"`
// RefreshToken is the "refresh_token" field of the response.
RefreshToken string `json:"refresh_token"`
// TokenType is the "token_type" field of the response.
TokenType string `json:"token_type"`
// ExpiresIn is the "expires_in" field of the response.
// NOTE(review): presumably the access token lifetime in seconds — confirm.
ExpiresIn int `json:"expires_in"`
}
// UserInfo represents user information from Zitadel.
//
// GetUserInfo decodes the userinfo endpoint's JSON body into this struct.
type UserInfo struct {
// ID is the "sub" field of the response.
ID string `json:"sub"`
// Email is the "email" field of the response.
Email string `json:"email"`
// Name is the "name" field of the response.
Name string `json:"name"`
// Verified is the "email_verified" field of the response.
Verified bool `json:"email_verified"`
// Roles is the "roles" field of the response.
// NOTE(review): AuthClaims reads roles from the
// "urn:zitadel:iam:org:project:roles" claim instead — confirm that the
// userinfo response really carries a plain "roles" array.
Roles []string `json:"roles"`
}
// discoveryDocument holds the fields this package reads from the
// "/.well-known/openid-configuration" response.
type discoveryDocument struct {
// Issuer is the last-resort expected issuer in ValidateToken. A
// non-empty value is also what marks a cached document as usable.
Issuer string `json:"issuer"`
// JWKSURI is the URL the verification key set is loaded from.
JWKSURI string `json:"jwks_uri"`
// TokenEndpoint is the URL RefreshToken posts to.
TokenEndpoint string `json:"token_endpoint"`
// RevocationEndpoint is the URL RevokeToken posts to.
RevocationEndpoint string `json:"revocation_endpoint"`
}
// ValidateToken validates a JWT and returns the claims extracted from it.
//
// The steps are, in order:
//  1. Load the discovery document and obtain the key set for its jwks_uri.
//     A document without a jwks_uri is rejected up front, mirroring the
//     endpoint checks in RefreshToken and RevokeToken, so that an empty URL
//     is never handed to the key set loader and cached.
//  2. Parse the token, resolving the verification key through the key set.
//  3. Check the "iss" claim against opts.Issuer, falling back to
//     Config.Issuer and then to the discovery issuer. The check is skipped
//     only if all three are empty.
//  4. Check that the "aud" claim contains opts.Audience, falling back to
//     Config.Audience. The check is skipped if both are empty.
//  5. Check that every claim named in opts.RequiredClaims is present and
//     non-empty.
//
// It returns an error if any step fails.
//
// NOTE(review): no parser options are passed, so whether a token lacking an
// "exp" claim is rejected depends on the jwt library's defaults — confirm
// this matches the intended policy.
func (c *Client) ValidateToken(ctx context.Context, token string, opts ValidationOptions) (*AuthClaims, error) {
	discovery, err := c.getDiscovery(ctx)
	if err != nil {
		return nil, err
	}
	if discovery.JWKSURI == "" {
		return nil, errors.New("jwks uri not available")
	}

	jwks, err := c.getJWKS(discovery.JWKSURI)
	if err != nil {
		return nil, err
	}

	parsed, err := jwt.ParseWithClaims(token, jwt.MapClaims{}, jwks.KeyfuncCtx(ctx))
	if err != nil {
		return nil, fmt.Errorf("token parse failed: %w", err)
	}
	if !parsed.Valid {
		return nil, errors.New("token is invalid")
	}
	claims, ok := parsed.Claims.(jwt.MapClaims)
	if !ok {
		return nil, errors.New("invalid token claims")
	}

	// Expected issuer: per-call option, then client config, then discovery.
	issuer := opts.Issuer
	if issuer == "" {
		issuer = c.config.Issuer
	}
	if issuer == "" {
		issuer = discovery.Issuer
	}
	if issuer != "" {
		// A missing or non-string "iss" claim yields "" and therefore fails.
		if claimIssuer, _ := claims["iss"].(string); claimIssuer != issuer {
			return nil, fmt.Errorf("invalid issuer: %s", claimIssuer)
		}
	}

	// Expected audience: per-call option, then client config.
	audience := opts.Audience
	if audience == "" {
		audience = c.config.Audience
	}
	if audience != "" && !audienceContains(claims["aud"], audience) {
		return nil, fmt.Errorf("invalid audience: %s", audience)
	}

	for _, claimName := range opts.RequiredClaims {
		if !hasNonEmptyClaim(claims, claimName) {
			return nil, fmt.Errorf("missing required claim: %s", claimName)
		}
	}

	authClaims := &AuthClaims{
		Subject:   getStringClaim(claims, "sub"),
		Email:     getStringClaim(claims, "email"),
		Name:      getStringClaim(claims, "name"),
		Roles:     parseRoles(claims["urn:zitadel:iam:org:project:roles"]),
		Audience:  parseStringSlice(claims["aud"]),
		Issuer:    getStringClaim(claims, "iss"),
		IssuedAt:  getInt64Claim(claims, "iat"),
		ExpiresAt: getInt64Claim(claims, "exp"),
		AMR:       parseStringSlice(claims["amr"]),
	}
	// MFA status is derived from the authentication methods, not read from
	// a dedicated claim.
	authClaims.MFAVerified = containsAny(authClaims.AMR, "mfa", "otp")
	return authClaims, nil
}
// RefreshToken exchanges a refresh token for a new token set.
//
// The token endpoint comes from the discovery document. The request is a
// form-encoded POST carrying grant_type=refresh_token, the refresh token and
// the configured client credentials. Any status other than 200 OK is
// reported as an error; otherwise the JSON body is decoded into a
// TokenResponse.
func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (*TokenResponse, error) {
	doc, err := c.getDiscovery(ctx)
	if err != nil {
		return nil, err
	}
	endpoint := doc.TokenEndpoint
	if endpoint == "" {
		return nil, errors.New("token endpoint not available")
	}

	payload := url.Values{
		"grant_type":    {"refresh_token"},
		"refresh_token": {refreshToken},
		"client_id":     {c.config.ClientID},
		"client_secret": {c.config.ClientSecret},
	}.Encode()

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(payload))
	if err != nil {
		return nil, fmt.Errorf("failed to create token refresh request: %w", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to refresh token: %w", err)
	}
	defer resp.Body.Close()

	if status := resp.StatusCode; status != http.StatusOK {
		return nil, fmt.Errorf("failed to refresh token: status %d", status)
	}

	tokens := new(TokenResponse)
	if err := json.NewDecoder(resp.Body).Decode(tokens); err != nil {
		return nil, fmt.Errorf("failed to decode token response: %w", err)
	}
	return tokens, nil
}
// GetUserInfo retrieves user information using an access token.
//
// It sends a GET request to "<BaseURL>/oidc/v1/userinfo" with the access
// token as a Bearer credential and decodes the JSON response into a
// UserInfo. Any status other than 200 OK is reported as an error.
//
// Trailing slashes on Config.BaseURL are removed before the path is
// appended, so "https://host/" and "https://host" produce the same request
// URL instead of one containing "//oidc".
func (c *Client) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo, error) {
	// Named endpoint rather than url so the net/url package is not shadowed.
	endpoint := strings.TrimRight(c.config.BaseURL, "/") + "/oidc/v1/userinfo"

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+accessToken)

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to get user info: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("failed to get user info: status %d", resp.StatusCode)
	}

	var userInfo UserInfo
	if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
		return nil, fmt.Errorf("failed to decode user info: %w", err)
	}
	return &userInfo, nil
}
// RevokeToken asks the identity provider to revoke token.
//
// The revocation endpoint comes from the discovery document. The request is
// a form-encoded POST carrying the token and the configured client
// credentials. Both 200 OK and 204 No Content count as success; any other
// status is returned as an error.
func (c *Client) RevokeToken(ctx context.Context, token string) error {
	doc, err := c.getDiscovery(ctx)
	if err != nil {
		return err
	}
	endpoint := doc.RevocationEndpoint
	if endpoint == "" {
		return errors.New("revocation endpoint not available")
	}

	payload := url.Values{
		"token":         {token},
		"client_id":     {c.config.ClientID},
		"client_secret": {c.config.ClientSecret},
	}.Encode()

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(payload))
	if err != nil {
		return fmt.Errorf("failed to create revocation request: %w", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("failed to revoke token: %w", err)
	}
	defer resp.Body.Close()

	switch resp.StatusCode {
	case http.StatusOK, http.StatusNoContent:
		return nil
	default:
		return fmt.Errorf("failed to revoke token: status %d", resp.StatusCode)
	}
}
// getDiscovery returns the OpenID Connect discovery document, fetching it
// from "<BaseURL>/.well-known/openid-configuration" when necessary.
//
// A previously fetched document is reused while it is younger than ten
// minutes and has a non-empty issuer; a document with an empty issuer is
// therefore fetched again on every call. The mutex is held for the whole
// call, including the HTTP request, so concurrent callers wait for a single
// fetch rather than issuing their own.
//
// Trailing slashes on Config.BaseURL are removed before the well-known path
// is appended, so a base URL ending in "/" does not produce a "//" path.
// On any failure an empty document and an error are returned, and the cache
// is left untouched.
func (c *Client) getDiscovery(ctx context.Context) (discoveryDocument, error) {
	c.discoveryMu.Lock()
	defer c.discoveryMu.Unlock()

	if c.discoveryCached.Issuer != "" && time.Since(c.discoveryFetched) < 10*time.Minute {
		return c.discoveryCached, nil
	}

	// Named endpoint rather than url so the net/url package is not shadowed.
	endpoint := strings.TrimRight(c.config.BaseURL, "/") + "/.well-known/openid-configuration"

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return discoveryDocument{}, fmt.Errorf("failed to create discovery request: %w", err)
	}

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return discoveryDocument{}, fmt.Errorf("failed to fetch discovery document: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return discoveryDocument{}, fmt.Errorf("failed to fetch discovery document: status %d", resp.StatusCode)
	}

	var doc discoveryDocument
	if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil {
		return discoveryDocument{}, fmt.Errorf("failed to decode discovery document: %w", err)
	}

	c.discoveryCached = doc
	c.discoveryFetched = time.Now()
	return doc, nil
}
// getJWKS returns the key resolver for jwksURL, creating it on first use.
//
// The resolver is kept on the client and reused while the URL stays the
// same; a different URL causes a new resolver to be created and stored in
// place of the old one. Access is serialised by jwksMu.
//
// The resolver is created with context.Background() rather than a request
// context.
// NOTE(review): presumably so that the resolver outlives the request that
// triggered its creation — confirm against the keyfunc documentation.
func (c *Client) getJWKS(jwksURL string) (keyfunc.Keyfunc, error) {
	c.jwksMu.Lock()
	defer c.jwksMu.Unlock()

	reusable := c.jwks != nil && c.jwksURL == jwksURL
	if !reusable {
		fresh, err := keyfunc.NewDefaultCtx(context.Background(), []string{jwksURL})
		if err != nil {
			return nil, fmt.Errorf("failed to create jwks: %w", err)
		}
		c.jwks, c.jwksURL = fresh, jwksURL
	}
	return c.jwks, nil
}
// getStringClaim returns the claim stored under key when it is a string, and
// the empty string when the claim is absent or has any other type.
func getStringClaim(claims jwt.MapClaims, key string) string {
	// A failed assertion leaves text at its zero value, "".
	text, _ := claims[key].(string)
	return text
}
// getInt64Claim returns the claim stored under key as an int64.
//
// float64 values are converted with int64(), int64 values are returned
// unchanged, and json.Number values are converted with Int64. A missing
// claim or any other type yields 0.
func getInt64Claim(claims jwt.MapClaims, key string) int64 {
	raw := claims[key]
	if f, ok := raw.(float64); ok {
		return int64(f)
	}
	if i, ok := raw.(int64); ok {
		return i
	}
	if num, ok := raw.(json.Number); ok {
		// A conversion error is deliberately not reported: whatever Int64
		// yields alongside its error is returned unchanged.
		n, _ := num.Int64()
		return n
	}
	return 0
}
// hasNonEmptyClaim reports whether the claim stored under key is present and
// carries a value.
//
// A missing or nil claim is empty. Strings must be non-empty, and slices and
// maps must contain at least one element. Values of any other type (numbers,
// booleans, ...) count as non-empty merely by being present.
func hasNonEmptyClaim(claims jwt.MapClaims, key string) bool {
	// Indexing a map with an absent key yields nil, so one nil check covers
	// both the "missing" and the "explicitly null" case.
	raw := claims[key]
	if raw == nil {
		return false
	}
	if text, ok := raw.(string); ok {
		return text != ""
	}
	if items, ok := raw.([]interface{}); ok {
		return len(items) != 0
	}
	if items, ok := raw.([]string); ok {
		return len(items) != 0
	}
	if fields, ok := raw.(map[string]interface{}); ok {
		return len(fields) != 0
	}
	return true
}
// parseStringSlice normalises a loosely typed claim value to a []string.
//
//   - A []string is returned as is.
//   - A []interface{} yields its string elements in order; elements of any
//     other type are dropped. The result is non-nil even when empty.
//   - A non-empty string yields a one-element slice.
//   - nil, the empty string and every other type yield nil.
func parseStringSlice(value interface{}) []string {
	if list, ok := value.([]string); ok {
		return list
	}
	if items, ok := value.([]interface{}); ok {
		out := make([]string, 0, len(items))
		for i := range items {
			if text, isString := items[i].(string); isString {
				out = append(out, text)
			}
		}
		return out
	}
	if single, ok := value.(string); ok && single != "" {
		return []string{single}
	}
	return nil
}
// parseRoles extracts role names from a roles claim.
//
// When the claim is a JSON object, its keys are the role names and the
// values are ignored. The names are returned in ascending order: map
// iteration order is random, so without sorting the same token would yield
// a differently ordered slice on every call. Any other shape is handed to
// parseStringSlice unchanged.
func parseRoles(value interface{}) []string {
	granted, ok := value.(map[string]interface{})
	if !ok {
		return parseStringSlice(value)
	}
	roles := make([]string, 0, len(granted))
	for role := range granted {
		roles = append(roles, role)
	}
	sort.Strings(roles)
	return roles
}
// audienceContains reports whether expected appears in an "aud" claim value.
// The claim may have any shape parseStringSlice understands. An empty
// expected value always matches.
func audienceContains(value interface{}, expected string) bool {
	if expected == "" {
		return true
	}
	candidates := parseStringSlice(value)
	for i := 0; i < len(candidates); i++ {
		if candidates[i] == expected {
			return true
		}
	}
	return false
}
// containsAny reports whether at least one of targets occurs in values.
// It returns false when either list is empty.
func containsAny(values []string, targets ...string) bool {
	for t := 0; t < len(targets); t++ {
		wanted := targets[t]
		for v := 0; v < len(values); v++ {
			if values[v] == wanted {
				return true
			}
		}
	}
	return false
}