package middleware import ( "fmt" "strings" "time" "github.com/gofiber/fiber/v3" "github.com/golang-jwt/jwt/v5" "github.com/knowfoolery/backend/services/gateway-service/config" ) type AuthMiddleware struct { config config.AuthConfig signingKey []byte adminRoles map[string]bool requiredScopes map[string]bool } type Claims struct { UserID string `json:"user_id"` Email string `json:"email"` Roles []string `json:"roles"` Scopes []string `json:"scopes"` jwt.RegisteredClaims } type UserContext struct { UserID string Email string Roles []string Scopes []string Token string } func NewAuthMiddleware(cfg config.AuthConfig) *AuthMiddleware { adminRoles := make(map[string]bool) for _, role := range cfg.AdminRoles { adminRoles[role] = true } requiredScopes := make(map[string]bool) for _, scope := range cfg.RequiredScopes { requiredScopes[scope] = true } return &AuthMiddleware{ config: cfg, signingKey: []byte(cfg.JWTSigningKey), adminRoles: adminRoles, requiredScopes: requiredScopes, } } func (am *AuthMiddleware) Handler() fiber.Handler { return func(c fiber.Ctx) error { path := c.Path() if am.isPublicPath(path) { return c.Next() } token := am.extractToken(c) if token == "" { return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{ "error": "Missing or invalid authorization token", }) } claims, err := am.validateToken(token) if err != nil { return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{ "error": fmt.Sprintf("Invalid token: %v", err), }) } if err := am.validateScopes(claims.Scopes); err != nil { return c.Status(fiber.StatusForbidden).JSON(fiber.Map{ "error": fmt.Sprintf("Insufficient scopes: %v", err), }) } userCtx := &UserContext{ UserID: claims.UserID, Email: claims.Email, Roles: claims.Roles, Scopes: claims.Scopes, Token: token, } c.Locals("user", userCtx) c.Set("X-User-ID", userCtx.UserID) c.Set("X-User-Email", userCtx.Email) c.Set("X-User-Roles", strings.Join(userCtx.Roles, ",")) return c.Next() } } func (am *AuthMiddleware) RequireAdminRole() fiber.Handler { return func(c fiber.Ctx) error { userCtx := am.getUserContext(c) if userCtx == nil { return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{ "error": "Authentication required", }) } if !am.hasAdminRole(userCtx.Roles) { return c.Status(fiber.StatusForbidden).JSON(fiber.Map{ "error": "Admin role required", }) } return c.Next() } } func (am *AuthMiddleware) RequireRole(requiredRole string) fiber.Handler { return func(c fiber.Ctx) error { userCtx := am.getUserContext(c) if userCtx == nil { return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{ "error": "Authentication required", }) } if !am.hasRole(userCtx.Roles, requiredRole) { return c.Status(fiber.StatusForbidden).JSON(fiber.Map{ "error": fmt.Sprintf("Role '%s' required", requiredRole), }) } return c.Next() } } func (am *AuthMiddleware) RequireScopes(requiredScopes ...string) fiber.Handler { return func(c fiber.Ctx) error { userCtx := am.getUserContext(c) if userCtx == nil { return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{ "error": "Authentication required", }) } if !am.hasScopes(userCtx.Scopes, requiredScopes) { return c.Status(fiber.StatusForbidden).JSON(fiber.Map{ "error": fmt.Sprintf("Required scopes: %s", strings.Join(requiredScopes, ", ")), }) } return c.Next() } } func (am *AuthMiddleware) extractToken(c fiber.Ctx) string { authHeader := c.Get("Authorization") if authHeader == "" { return "" } parts := strings.SplitN(authHeader, " ", 2) if len(parts) != 2 || parts[0] != "Bearer" { return "" } return parts[1] } func (am *AuthMiddleware) validateToken(tokenString string) (*Claims, error) { token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } return am.signingKey, nil }) if err != nil { return nil, err } if claims, ok := token.Claims.(*Claims); ok && token.Valid { if time.Now().After(claims.ExpiresAt.Time) { return nil, fmt.Errorf("token has expired") } if am.shouldRefreshToken(claims) { // TODO: Implement token refresh logic } return claims, nil } return nil, fmt.Errorf("invalid token") } func (am *AuthMiddleware) validateScopes(userScopes []string) error { userScopeMap := make(map[string]bool) for _, scope := range userScopes { userScopeMap[scope] = true } for requiredScope := range am.requiredScopes { if !userScopeMap[requiredScope] { return fmt.Errorf("missing required scope: %s", requiredScope) } } return nil } func (am *AuthMiddleware) shouldRefreshToken(claims *Claims) bool { return time.Until(claims.ExpiresAt.Time) < am.config.RefreshThreshold } func (am *AuthMiddleware) hasAdminRole(userRoles []string) bool { for _, role := range userRoles { if am.adminRoles[role] { return true } } return false } func (am *AuthMiddleware) hasRole(userRoles []string, requiredRole string) bool { for _, role := range userRoles { if role == requiredRole { return true } } return false } func (am *AuthMiddleware) hasScopes(userScopes []string, requiredScopes []string) bool { userScopeMap := make(map[string]bool) for _, scope := range userScopes { userScopeMap[scope] = true } for _, requiredScope := range requiredScopes { if !userScopeMap[requiredScope] { return false } } return true } func (am *AuthMiddleware) getUserContext(c fiber.Ctx) *UserContext { if user := c.Locals("user"); user != nil { if userCtx, ok := user.(*UserContext); ok { return userCtx } } return nil } func (am *AuthMiddleware) isPublicPath(path string) bool { publicPaths := []string{ "/health", "/metrics", "/api/v1/auth/login", "/api/v1/auth/register", "/api/v1/auth/refresh", } for _, publicPath := range publicPaths { if strings.HasPrefix(path, publicPath) { return true } } return false } func GetUserFromContext(c fiber.Ctx) *UserContext { if user := c.Locals("user"); user != nil { if userCtx, ok := user.(*UserContext); ok { return userCtx } } return nil } func RequireAuthentication(c fiber.Ctx) error { userCtx := GetUserFromContext(c) if userCtx == nil { return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{ "error": "Authentication required", }) } return nil }