You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
296 lines
6.5 KiB
Go
296 lines
6.5 KiB
Go
package middleware
|
|
|
|
import (
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gofiber/fiber/v3"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
|
|
"github.com/knowfoolery/backend/services/gateway-service/config"
|
|
)
|
|
|
|
type AuthMiddleware struct {
|
|
config config.AuthConfig
|
|
signingKey []byte
|
|
adminRoles map[string]bool
|
|
requiredScopes map[string]bool
|
|
}
|
|
|
|
type Claims struct {
|
|
UserID string `json:"user_id"`
|
|
Email string `json:"email"`
|
|
Roles []string `json:"roles"`
|
|
Scopes []string `json:"scopes"`
|
|
jwt.RegisteredClaims
|
|
}
|
|
|
|
// UserContext is the authenticated caller's identity as stored on the request
// (under the "user" local) once its token has been validated.
type UserContext struct {
	// UserID and Email are copied from the token's user_id and email claims.
	UserID, Email string
	// Roles and Scopes are copied from the token's roles and scopes claims.
	Roles, Scopes []string
	// Token is the raw bearer token the request was authenticated with.
	Token string
}
|
|
|
|
func NewAuthMiddleware(cfg config.AuthConfig) *AuthMiddleware {
|
|
adminRoles := make(map[string]bool)
|
|
for _, role := range cfg.AdminRoles {
|
|
adminRoles[role] = true
|
|
}
|
|
|
|
requiredScopes := make(map[string]bool)
|
|
for _, scope := range cfg.RequiredScopes {
|
|
requiredScopes[scope] = true
|
|
}
|
|
|
|
return &AuthMiddleware{
|
|
config: cfg,
|
|
signingKey: []byte(cfg.JWTSigningKey),
|
|
adminRoles: adminRoles,
|
|
requiredScopes: requiredScopes,
|
|
}
|
|
}
|
|
|
|
func (am *AuthMiddleware) Handler() fiber.Handler {
|
|
return func(c fiber.Ctx) error {
|
|
path := c.Path()
|
|
|
|
if am.isPublicPath(path) {
|
|
return c.Next()
|
|
}
|
|
|
|
token := am.extractToken(c)
|
|
if token == "" {
|
|
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
|
|
"error": "Missing or invalid authorization token",
|
|
})
|
|
}
|
|
|
|
claims, err := am.validateToken(token)
|
|
if err != nil {
|
|
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
|
|
"error": fmt.Sprintf("Invalid token: %v", err),
|
|
})
|
|
}
|
|
|
|
if err := am.validateScopes(claims.Scopes); err != nil {
|
|
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
|
|
"error": fmt.Sprintf("Insufficient scopes: %v", err),
|
|
})
|
|
}
|
|
|
|
userCtx := &UserContext{
|
|
UserID: claims.UserID,
|
|
Email: claims.Email,
|
|
Roles: claims.Roles,
|
|
Scopes: claims.Scopes,
|
|
Token: token,
|
|
}
|
|
|
|
c.Locals("user", userCtx)
|
|
|
|
c.Set("X-User-ID", userCtx.UserID)
|
|
c.Set("X-User-Email", userCtx.Email)
|
|
c.Set("X-User-Roles", strings.Join(userCtx.Roles, ","))
|
|
|
|
return c.Next()
|
|
}
|
|
}
|
|
|
|
func (am *AuthMiddleware) RequireAdminRole() fiber.Handler {
|
|
return func(c fiber.Ctx) error {
|
|
userCtx := am.getUserContext(c)
|
|
if userCtx == nil {
|
|
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
|
|
"error": "Authentication required",
|
|
})
|
|
}
|
|
|
|
if !am.hasAdminRole(userCtx.Roles) {
|
|
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
|
|
"error": "Admin role required",
|
|
})
|
|
}
|
|
|
|
return c.Next()
|
|
}
|
|
}
|
|
|
|
func (am *AuthMiddleware) RequireRole(requiredRole string) fiber.Handler {
|
|
return func(c fiber.Ctx) error {
|
|
userCtx := am.getUserContext(c)
|
|
if userCtx == nil {
|
|
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
|
|
"error": "Authentication required",
|
|
})
|
|
}
|
|
|
|
if !am.hasRole(userCtx.Roles, requiredRole) {
|
|
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
|
|
"error": fmt.Sprintf("Role '%s' required", requiredRole),
|
|
})
|
|
}
|
|
|
|
return c.Next()
|
|
}
|
|
}
|
|
|
|
func (am *AuthMiddleware) RequireScopes(requiredScopes ...string) fiber.Handler {
|
|
return func(c fiber.Ctx) error {
|
|
userCtx := am.getUserContext(c)
|
|
if userCtx == nil {
|
|
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
|
|
"error": "Authentication required",
|
|
})
|
|
}
|
|
|
|
if !am.hasScopes(userCtx.Scopes, requiredScopes) {
|
|
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
|
|
"error": fmt.Sprintf("Required scopes: %s", strings.Join(requiredScopes, ", ")),
|
|
})
|
|
}
|
|
|
|
return c.Next()
|
|
}
|
|
}
|
|
|
|
func (am *AuthMiddleware) extractToken(c fiber.Ctx) string {
|
|
authHeader := c.Get("Authorization")
|
|
if authHeader == "" {
|
|
return ""
|
|
}
|
|
|
|
parts := strings.SplitN(authHeader, " ", 2)
|
|
if len(parts) != 2 || parts[0] != "Bearer" {
|
|
return ""
|
|
}
|
|
|
|
return parts[1]
|
|
}
|
|
|
|
func (am *AuthMiddleware) validateToken(tokenString string) (*Claims, error) {
|
|
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
}
|
|
return am.signingKey, nil
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
|
|
if time.Now().After(claims.ExpiresAt.Time) {
|
|
return nil, fmt.Errorf("token has expired")
|
|
}
|
|
|
|
if am.shouldRefreshToken(claims) {
|
|
// TODO: Implement token refresh logic
|
|
}
|
|
|
|
return claims, nil
|
|
}
|
|
|
|
return nil, fmt.Errorf("invalid token")
|
|
}
|
|
|
|
func (am *AuthMiddleware) validateScopes(userScopes []string) error {
|
|
userScopeMap := make(map[string]bool)
|
|
for _, scope := range userScopes {
|
|
userScopeMap[scope] = true
|
|
}
|
|
|
|
for requiredScope := range am.requiredScopes {
|
|
if !userScopeMap[requiredScope] {
|
|
return fmt.Errorf("missing required scope: %s", requiredScope)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (am *AuthMiddleware) shouldRefreshToken(claims *Claims) bool {
|
|
return time.Until(claims.ExpiresAt.Time) < am.config.RefreshThreshold
|
|
}
|
|
|
|
func (am *AuthMiddleware) hasAdminRole(userRoles []string) bool {
|
|
for _, role := range userRoles {
|
|
if am.adminRoles[role] {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (am *AuthMiddleware) hasRole(userRoles []string, requiredRole string) bool {
|
|
for _, role := range userRoles {
|
|
if role == requiredRole {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (am *AuthMiddleware) hasScopes(userScopes []string, requiredScopes []string) bool {
|
|
userScopeMap := make(map[string]bool)
|
|
for _, scope := range userScopes {
|
|
userScopeMap[scope] = true
|
|
}
|
|
|
|
for _, requiredScope := range requiredScopes {
|
|
if !userScopeMap[requiredScope] {
|
|
return false
|
|
}
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
func (am *AuthMiddleware) getUserContext(c fiber.Ctx) *UserContext {
|
|
if user := c.Locals("user"); user != nil {
|
|
if userCtx, ok := user.(*UserContext); ok {
|
|
return userCtx
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (am *AuthMiddleware) isPublicPath(path string) bool {
|
|
publicPaths := []string{
|
|
"/health",
|
|
"/metrics",
|
|
"/api/v1/auth/login",
|
|
"/api/v1/auth/register",
|
|
"/api/v1/auth/refresh",
|
|
}
|
|
|
|
for _, publicPath := range publicPaths {
|
|
if strings.HasPrefix(path, publicPath) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func GetUserFromContext(c fiber.Ctx) *UserContext {
|
|
if user := c.Locals("user"); user != nil {
|
|
if userCtx, ok := user.(*UserContext); ok {
|
|
return userCtx
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func RequireAuthentication(c fiber.Ctx) error {
|
|
userCtx := GetUserFromContext(c)
|
|
if userCtx == nil {
|
|
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
|
|
"error": "Authentication required",
|
|
})
|
|
}
|
|
return nil
|
|
} |