You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

165 lines
4.6 KiB
Go

// Package zitadel provides Zitadel authentication client for the KnowFoolery application.
package zitadel
import (
"context"
"strings"
"github.com/gofiber/fiber/v3"
)
// ContextKey is the named key type for authentication values. Because it is a
// distinct named type, a ContextKey never compares equal to a plain string
// with the same text.
type ContextKey string

// Keys under which authentication details for the current request are stored.
const (
	// ContextKeyUserID is the key for the authenticated user's ID.
	ContextKeyUserID = ContextKey("user_id")
	// ContextKeyUserEmail is the key for the authenticated user's email.
	ContextKeyUserEmail = ContextKey("user_email")
	// ContextKeyUserName is the key for the authenticated user's name.
	ContextKeyUserName = ContextKey("user_name")
	// ContextKeyUserRoles is the key for the authenticated user's roles.
	ContextKeyUserRoles = ContextKey("user_roles")
	// ContextKeyMFAVerified is the key for the multi-factor verification flag.
	ContextKeyMFAVerified = ContextKey("mfa_verified")
)
// JWTMiddlewareConfig holds configuration for the JWT middleware.
// Only Client, AdminEndpoints and SkipPaths are read by JWTMiddleware.
type JWTMiddlewareConfig struct {
// Client validates bearer tokens; JWTMiddleware calls Client.ValidateToken
// for every request it does not skip, without a nil check.
Client *Client
// Issuer is not read by JWTMiddleware.
// NOTE(review): confirm whether issuer validation happens inside Client or
// whether this field is currently unused.
Issuer string
// Audience is not read by JWTMiddleware.
// NOTE(review): confirm whether audience validation happens inside Client or
// whether this field is currently unused.
Audience string
// RequiredClaims is not read by JWTMiddleware.
// NOTE(review): presumably meant to list claims a token must carry — confirm
// where (if anywhere) this is enforced.
RequiredClaims []string
// AdminEndpoints lists path prefixes that, after token validation, also
// require the "admin" role and a verified MFA claim.
AdminEndpoints []string
// SkipPaths lists path prefixes for which JWTMiddleware bypasses
// authentication entirely.
SkipPaths []string
}
// JWTMiddleware creates a Fiber middleware that authenticates requests with a
// bearer JWT validated by config.Client.
//
// For each request the handler:
//   - passes straight to the next handler when the path is covered by one of
//     config.SkipPaths (see pathIsSkipped);
//   - responds 401 when the Authorization header is missing, is not of the
//     form "Bearer <token>" (the scheme is compared case-insensitively, as
//     RFC 7235 specifies for auth-schemes), or carries an empty token;
//   - responds 401 when config.Client.ValidateToken rejects the token;
//   - stores the token's subject, email, name, roles and MFA flag in c.Locals
//     under the string form of the ContextKey* constants;
//   - for paths matched by isAdminEndpoint, responds 403 unless the roles
//     include "admin" and the MFA flag is set.
//
// config.Client is used without a nil check.
func JWTMiddleware(config JWTMiddlewareConfig) fiber.Handler {
	return func(c fiber.Ctx) error {
		path := c.Path()
		if pathIsSkipped(path, config.SkipPaths) {
			return c.Next()
		}

		authHeader := c.Get("Authorization")
		if authHeader == "" {
			return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
				"error":   true,
				"message": "Authorization header required",
			})
		}

		tokenString, ok := bearerTokenFromHeader(authHeader)
		if !ok {
			return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
				"error":   true,
				"message": "Invalid authorization header format",
			})
		}

		ctx := c.Context()
		claims, err := config.Client.ValidateToken(ctx, tokenString)
		if err != nil {
			// NOTE(review): echoing err.Error() to an unauthenticated client can
			// disclose token-validation internals. It is kept so the response
			// shape stays unchanged; consider logging it server-side instead and
			// dropping "details".
			return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
				"error":   true,
				"message": "Invalid token",
				"details": err.Error(),
			})
		}

		c.Locals(string(ContextKeyUserID), claims.Subject)
		c.Locals(string(ContextKeyUserEmail), claims.Email)
		c.Locals(string(ContextKeyUserName), claims.Name)
		c.Locals(string(ContextKeyUserRoles), claims.Roles)
		c.Locals(string(ContextKeyMFAVerified), claims.MFAVerified)

		// Admin routes require both the admin role and a verified second factor.
		if isAdminEndpoint(path, config.AdminEndpoints) {
			if !hasAdminRole(claims.Roles) {
				return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
					"error":   true,
					"message": "Admin access required",
				})
			}
			if !claims.MFAVerified {
				return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
					"error":   true,
					"message": "MFA verification required for admin access",
				})
			}
		}

		return c.Next()
	}
}

// bearerTokenFromHeader extracts the token from an Authorization header of the
// form "Bearer <token>". The scheme is matched case-insensitively and
// whitespace around the token is trimmed. ok is false when the scheme is not
// Bearer or the token is empty.
func bearerTokenFromHeader(header string) (token string, ok bool) {
	scheme, rest, found := strings.Cut(header, " ")
	if !found || !strings.EqualFold(scheme, "Bearer") {
		return "", false
	}
	token = strings.TrimSpace(rest)
	return token, token != ""
}

// pathIsSkipped reports whether path is covered by one of skipPaths.
//
// A skip prefix covers the path itself and anything below it on a "/" segment
// boundary. For example, "/public" covers "/public" and "/public/docs" but not
// "/publicadmin"; a prefix ending in "/" covers everything beneath it. Empty
// entries are ignored so that a blank configuration value cannot disable
// authentication for every route.
func pathIsSkipped(path string, skipPaths []string) bool {
	for _, prefix := range skipPaths {
		if prefix == "" || !strings.HasPrefix(path, prefix) {
			continue
		}
		if len(path) == len(prefix) || strings.HasSuffix(prefix, "/") || path[len(prefix)] == '/' {
			return true
		}
	}
	return false
}
// isAdminEndpoint reports whether path starts with any entry of adminEndpoints.
// It is a plain string-prefix test: "/admin" also covers "/administrator", and
// an empty entry matches every path.
func isAdminEndpoint(path string, adminEndpoints []string) bool {
	for i := 0; i < len(adminEndpoints); i++ {
		prefix := adminEndpoints[i]
		if len(path) >= len(prefix) && path[:len(prefix)] == prefix {
			return true
		}
	}
	return false
}
// hasAdminRole reports whether roles contains the exact, case-sensitive role
// name "admin".
func hasAdminRole(roles []string) bool {
	const adminRole = "admin"
	for i := range roles {
		if roles[i] == adminRole {
			return true
		}
	}
	return false
}
// GetUserID returns the user ID that JWTMiddleware stored in the request
// locals. It returns "" when no value is present or the stored value is not a
// string, instead of panicking on a failed type assertion.
func GetUserID(c fiber.Ctx) string {
	userID, _ := c.Locals(string(ContextKeyUserID)).(string)
	return userID
}
// GetUserEmail returns the user email that JWTMiddleware stored in the request
// locals. It returns "" when no value is present or the stored value is not a
// string, instead of panicking on a failed type assertion.
func GetUserEmail(c fiber.Ctx) string {
	email, _ := c.Locals(string(ContextKeyUserEmail)).(string)
	return email
}
// GetUserRoles returns the user roles that JWTMiddleware stored in the request
// locals. It returns nil when no value is present or the stored value is not a
// []string, instead of panicking on a failed type assertion. The returned
// slice is the stored value itself, not a copy.
func GetUserRoles(c fiber.Ctx) []string {
	roles, _ := c.Locals(string(ContextKeyUserRoles)).([]string)
	return roles
}
// IsMFAVerified reports the MFA flag that JWTMiddleware stored in the request
// locals. It returns false when no value is present or the stored value is not
// a bool, instead of panicking on a failed type assertion.
func IsMFAVerified(c fiber.Ctx) bool {
	verified, _ := c.Locals(string(ContextKeyMFAVerified)).(bool)
	return verified
}
// GetUserFromContext reads user information from a standard context.Context
// using the ContextKey* keys. Any value that is missing or has an unexpected
// type comes back as its zero value. ok reports whether a non-empty user ID
// was found.
//
// NOTE(review): JWTMiddleware stores these values with c.Locals under plain
// string keys, not in a context.Context under ContextKey keys. ctx therefore
// has to be populated by other code for this to report ok — confirm the
// intended wiring.
func GetUserFromContext(ctx context.Context) (userID, email, name string, roles []string, ok bool) {
	id, _ := ctx.Value(ContextKeyUserID).(string)
	mail, _ := ctx.Value(ContextKeyUserEmail).(string)
	displayName, _ := ctx.Value(ContextKeyUserName).(string)
	roleList, _ := ctx.Value(ContextKeyUserRoles).([]string)
	return id, mail, displayName, roleList, id != ""
}