You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

296 lines
6.5 KiB
Go

package middleware
import (
"fmt"
"strings"
"time"
"github.com/gofiber/fiber/v3"
"github.com/golang-jwt/jwt/v5"
"github.com/knowfoolery/backend/services/gateway-service/config"
)
// AuthMiddleware authenticates requests carrying HMAC-signed JWTs and
// provides role- and scope-based authorization handlers on top of them.
// Construct it with NewAuthMiddleware.
type AuthMiddleware struct {
	config     config.AuthConfig // settings the middleware was built from
	signingKey []byte            // HMAC secret used to verify token signatures

	// Lookup sets built from config.AdminRoles and config.RequiredScopes;
	// a key is present (true) when the role/scope was configured.
	adminRoles, requiredScopes map[string]bool
}
// Claims is the JWT payload accepted by the gateway: application-specific
// identity fields plus the standard registered claims embedded from
// jwt.RegisteredClaims (ExpiresAt is read by validateToken).
type Claims struct {
	UserID string   `json:"user_id"` // copied into UserContext.UserID
	Email  string   `json:"email"`   // copied into UserContext.Email
	Roles  []string `json:"roles"`   // checked by RequireRole / RequireAdminRole
	Scopes []string `json:"scopes"`  // checked against the required scopes
	jwt.RegisteredClaims
}
// UserContext is the authenticated caller's identity, built from the claims
// of a validated JWT and stored in the request locals under "user".
type UserContext struct {
	UserID, Email string   // identity claims copied from the token
	Roles, Scopes []string // role and scope claims copied from the token
	Token         string   // raw bearer token the identity was parsed from
}
// NewAuthMiddleware builds an AuthMiddleware from cfg. The configured admin
// roles and required scopes are turned into lookup sets, and the JWT signing
// key is kept as raw bytes for HMAC verification.
func NewAuthMiddleware(cfg config.AuthConfig) *AuthMiddleware {
	am := &AuthMiddleware{
		config:         cfg,
		signingKey:     []byte(cfg.JWTSigningKey),
		adminRoles:     make(map[string]bool),
		requiredScopes: make(map[string]bool),
	}
	for _, r := range cfg.AdminRoles {
		am.adminRoles[r] = true
	}
	for _, s := range cfg.RequiredScopes {
		am.requiredScopes[s] = true
	}
	return am
}
// Handler returns the authentication middleware.
//
// Requests whose path is public (see isPublicPath) are passed on without a
// token. Every other request must send "Authorization: Bearer <jwt>" with a
// token that validateToken accepts and whose scopes include all configured
// required scopes. Otherwise the chain stops with 401 (missing or invalid
// token) or 403 (insufficient scopes).
//
// On success the identity is stored as a *UserContext in c.Locals("user").
// It is also written to the X-User-ID, X-User-Email and X-User-Roles
// headers of the request (so later handlers and any proxied upstream can
// read it) and of the response.
func (am *AuthMiddleware) Handler() fiber.Handler {
	return func(c fiber.Ctx) error {
		// Identity headers are only trustworthy when this middleware writes
		// them. Drop any client-supplied copies first so they cannot be used
		// to impersonate another user downstream, including on public paths.
		reqHeader := &c.Request().Header
		reqHeader.Del("X-User-ID")
		reqHeader.Del("X-User-Email")
		reqHeader.Del("X-User-Roles")

		if am.isPublicPath(c.Path()) {
			return c.Next()
		}

		token := am.extractToken(c)
		if token == "" {
			return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
				"error": "Missing or invalid authorization token",
			})
		}

		claims, err := am.validateToken(token)
		if err != nil {
			return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
				"error": fmt.Sprintf("Invalid token: %v", err),
			})
		}

		if err := am.validateScopes(claims.Scopes); err != nil {
			return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
				"error": fmt.Sprintf("Insufficient scopes: %v", err),
			})
		}

		userCtx := &UserContext{
			UserID: claims.UserID,
			Email:  claims.Email,
			Roles:  claims.Roles,
			Scopes: claims.Scopes,
			Token:  token,
		}
		c.Locals("user", userCtx)

		roles := strings.Join(userCtx.Roles, ",")
		// c.Set writes response headers only. Set the identity on the
		// request too, so later handlers and any proxied upstream see it.
		reqHeader.Set("X-User-ID", userCtx.UserID)
		reqHeader.Set("X-User-Email", userCtx.Email)
		reqHeader.Set("X-User-Roles", roles)
		// Response headers are kept as before.
		c.Set("X-User-ID", userCtx.UserID)
		c.Set("X-User-Email", userCtx.Email)
		c.Set("X-User-Roles", roles)

		return c.Next()
	}
}
// RequireAdminRole returns a handler that lets a request through only when
// the authenticated user holds at least one configured admin role.
// Requests without a user in the context get 401; authenticated users
// without an admin role get 403.
func (am *AuthMiddleware) RequireAdminRole() fiber.Handler {
	return func(c fiber.Ctx) error {
		user := am.getUserContext(c)
		switch {
		case user == nil:
			return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
				"error": "Authentication required",
			})
		case !am.hasAdminRole(user.Roles):
			return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
				"error": "Admin role required",
			})
		default:
			return c.Next()
		}
	}
}
// RequireRole returns a handler that lets a request through only when the
// authenticated user has exactly the role requiredRole. It responds 401 when
// no user is in the context and 403 when the role is absent.
func (am *AuthMiddleware) RequireRole(requiredRole string) fiber.Handler {
	return func(c fiber.Ctx) error {
		user := am.getUserContext(c)
		switch {
		case user == nil:
			return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
				"error": "Authentication required",
			})
		case !am.hasRole(user.Roles, requiredRole):
			return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
				"error": "Role '" + requiredRole + "' required",
			})
		default:
			return c.Next()
		}
	}
}
// RequireScopes returns a handler that lets a request through only when the
// authenticated user holds every scope in requiredScopes. It responds 401
// when no user is in the context and 403 when any scope is missing; the 403
// message lists all required scopes.
func (am *AuthMiddleware) RequireScopes(requiredScopes ...string) fiber.Handler {
	return func(c fiber.Ctx) error {
		user := am.getUserContext(c)
		switch {
		case user == nil:
			return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
				"error": "Authentication required",
			})
		case !am.hasScopes(user.Scopes, requiredScopes):
			wanted := strings.Join(requiredScopes, ", ")
			return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
				"error": "Required scopes: " + wanted,
			})
		default:
			return c.Next()
		}
	}
}
// extractToken returns the credentials from an "Authorization: Bearer <token>"
// request header. It returns "" when the header is missing, uses another
// scheme, or carries no token.
//
// HTTP auth-schemes are case-insensitive (RFC 9110 §11.1), so "bearer" and
// "BEARER" are accepted too. Surrounding whitespace is trimmed so that, for
// example, a double space after the scheme does not corrupt the token.
func (am *AuthMiddleware) extractToken(c fiber.Ctx) string {
	scheme, token, found := strings.Cut(c.Get("Authorization"), " ")
	if !found || !strings.EqualFold(scheme, "Bearer") {
		return ""
	}
	return strings.TrimSpace(token)
}
// validateToken parses tokenString as an HMAC-signed JWT, verifies its
// signature with am.signingKey and returns its claims.
//
// The token must carry an "exp" claim. jwt v5 treats "exp" as optional by
// default, which would leave claims.ExpiresAt nil; dereferencing it used to
// panic on any validly signed token without an expiry.
func (am *AuthMiddleware) validateToken(tokenString string) (*Claims, error) {
	token, err := jwt.ParseWithClaims(
		tokenString,
		&Claims{},
		func(t *jwt.Token) (interface{}, error) {
			// Only HMAC algorithms may be verified with the shared secret;
			// anything else (including "none") is rejected here.
			if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
				return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
			}
			return am.signingKey, nil
		},
		jwt.WithExpirationRequired(),
	)
	if err != nil {
		return nil, err
	}

	claims, ok := token.Claims.(*Claims)
	if !ok || !token.Valid {
		return nil, fmt.Errorf("invalid token")
	}
	// Defensive: the parser already enforces "exp", but never dereference a
	// nil ExpiresAt.
	if claims.ExpiresAt == nil {
		return nil, fmt.Errorf("token has no expiration")
	}
	if time.Now().After(claims.ExpiresAt.Time) {
		return nil, fmt.Errorf("token has expired")
	}
	if am.shouldRefreshToken(claims) {
		// TODO: Implement token refresh logic
	}
	return claims, nil
}
// validateScopes returns nil when userScopes contains every scope in
// am.requiredScopes. Otherwise it returns an error naming a missing scope.
//
// When several scopes are missing, the alphabetically first one is reported.
// Ranging over the map and returning the first miss would pick an arbitrary
// one, because Go randomizes map iteration order. That made the 403 message
// vary between identical requests.
func (am *AuthMiddleware) validateScopes(userScopes []string) error {
	granted := make(map[string]bool, len(userScopes))
	for _, scope := range userScopes {
		granted[scope] = true
	}

	missing, found := "", false
	for required := range am.requiredScopes {
		if !granted[required] && (!found || required < missing) {
			missing, found = required, true
		}
	}
	if found {
		return fmt.Errorf("missing required scope: %s", missing)
	}
	return nil
}
// shouldRefreshToken reports whether claims expire within
// am.config.RefreshThreshold from now. Claims without an expiry (or nil
// claims) are never reported as due for refresh, rather than panicking on
// a nil ExpiresAt.
func (am *AuthMiddleware) shouldRefreshToken(claims *Claims) bool {
	if claims == nil || claims.ExpiresAt == nil {
		return false
	}
	return time.Until(claims.ExpiresAt.Time) < am.config.RefreshThreshold
}
// hasAdminRole reports whether at least one of userRoles is among the
// configured admin roles.
func (am *AuthMiddleware) hasAdminRole(userRoles []string) bool {
	isAdmin := false
	for i := 0; i < len(userRoles) && !isAdmin; i++ {
		isAdmin = am.adminRoles[userRoles[i]]
	}
	return isAdmin
}
// hasRole reports whether requiredRole appears in userRoles (exact,
// case-sensitive match).
func (am *AuthMiddleware) hasRole(userRoles []string, requiredRole string) bool {
	for i := range userRoles {
		if userRoles[i] != requiredRole {
			continue
		}
		return true
	}
	return false
}
// hasScopes reports whether every entry of requiredScopes appears in
// userScopes. An empty requiredScopes is always satisfied.
func (am *AuthMiddleware) hasScopes(userScopes []string, requiredScopes []string) bool {
	granted := make(map[string]struct{}, len(userScopes))
	for _, s := range userScopes {
		granted[s] = struct{}{}
	}
	for _, want := range requiredScopes {
		if _, ok := granted[want]; !ok {
			return false
		}
	}
	return true
}
// getUserContext returns the *UserContext stored in c.Locals("user"), or nil
// when no value of that type is stored. It is the method form of
// GetUserFromContext, to which it delegates.
func (am *AuthMiddleware) getUserContext(c fiber.Ctx) *UserContext {
	return GetUserFromContext(c)
}
// isPublicPath reports whether path may be served without a token.
//
// Each public entry matches itself and anything beneath it at a "/"
// boundary. For example, "/health" covers "/health" and "/health/ready" but
// not "/healthz" or "/health-admin". A bare prefix match would let any
// route whose name merely starts with a public path skip authentication.
func (am *AuthMiddleware) isPublicPath(path string) bool {
	publicPaths := []string{
		"/health",
		"/metrics",
		"/api/v1/auth/login",
		"/api/v1/auth/register",
		"/api/v1/auth/refresh",
	}
	for _, public := range publicPaths {
		rest, ok := strings.CutPrefix(path, public)
		if ok && (rest == "" || rest[0] == '/') {
			return true
		}
	}
	return false
}
// GetUserFromContext returns the *UserContext stored in c.Locals("user") by
// the authentication middleware. It returns nil when nothing is stored or
// the stored value has a different type.
func GetUserFromContext(c fiber.Ctx) *UserContext {
	// The ok flag is not needed: a failed assertion (including on a nil
	// Locals value) already yields a nil *UserContext.
	userCtx, _ := c.Locals("user").(*UserContext)
	return userCtx
}
// RequireAuthentication is a fiber handler that stops the chain with 401
// unless an earlier middleware stored a *UserContext in c.Locals("user").
// Authenticated requests continue to the next handler.
func RequireAuthentication(c fiber.Ctx) error {
	if GetUserFromContext(c) == nil {
		return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
			"error": "Authentication required",
		})
	}
	// Returning nil here would end the fiber chain with an empty 200 and
	// the route handler would never run; hand control to the next handler.
	return c.Next()
}