You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

189 lines
6.5 KiB
Python

from typing import Optional, Dict, Any
import os
import requests
from jose import jwt, JWTError
from datetime import datetime, timedelta
import json
class OAuthHandler:
    """Authenticate administrators via Auth0 and issue local admin session tokens.

    Flow (see ``authenticate_admin``):

    1. An Auth0-issued RS256 JWT is verified against the tenant's JWKS,
       fetched from ``https://<AUTH0_DOMAIN>/.well-known/jwks.json`` and
       cached for ``JWKS_CACHE_TTL``.
    2. The token's ``email`` claim is matched (case-insensitively) against
       the ``ADMIN_EMAILS`` allow-list.
    3. A local HS256 session token carrying ``role="admin"`` and
       ``ADMIN_PERMISSIONS`` is minted, valid for ``ADMIN_SESSION_TTL``.

    Configuration is read from environment variables when an instance is
    constructed: ``AUTH0_DOMAIN``, ``AUTH0_CLIENT_ID``,
    ``AUTH0_CLIENT_SECRET``, ``AUTH0_AUDIENCE``, ``JWT_SECRET_KEY`` and
    ``ADMIN_EMAILS`` (comma-separated).
    """

    # How long a fetched JWKS is trusted before it is re-fetched.
    JWKS_CACHE_TTL = timedelta(hours=1)
    # Lifetime of a locally minted admin session token.
    ADMIN_SESSION_TTL = timedelta(hours=8)
    # Timeout, in seconds, for the JWKS HTTP request.
    JWKS_REQUEST_TIMEOUT = 10
    # Permissions written into every admin session token.
    ADMIN_PERMISSIONS = (
        "read:questions",
        "write:questions",
        "delete:questions",
        "read:themes",
        "write:themes",
        "read:analytics",
        "read:players",
        "admin:system",
    )

    def __init__(self):
        # OAuth configuration (Auth0 is the identity provider).
        self.auth0_domain = os.getenv("AUTH0_DOMAIN", "")
        self.auth0_client_id = os.getenv("AUTH0_CLIENT_ID", "")
        self.auth0_client_secret = os.getenv("AUTH0_CLIENT_SECRET", "")
        self.auth0_audience = os.getenv("AUTH0_AUDIENCE", "")

        # Signing settings for the locally issued admin session tokens.
        # NOTE(review): the fallback secret is a public string, so anyone who
        # knows it can forge admin tokens; deployments must set JWT_SECRET_KEY.
        self.secret_key = os.getenv("JWT_SECRET_KEY", "change-this-in-production")
        self.algorithm = "HS256"

        # Allow-list of admin e-mail addresses; blank entries are ignored.
        self.admin_emails = {
            email.strip()
            for email in os.getenv("ADMIN_EMAILS", "").split(",")
            if email.strip()
        }

        # JWKS cache: key id ("kid") -> JWK dict, plus its expiry as a
        # timezone-aware UTC datetime (None until the first successful fetch).
        self.jwks_cache = {}
        self.jwks_cache_expiry = None

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current time as a timezone-aware UTC datetime."""
        # Local import keeps the class self-contained. Naive
        # ``datetime.utcnow()`` must not be used: its ``.timestamp()`` is
        # computed as if the value were local time, which is wrong whenever
        # the server is not running in UTC.
        from datetime import timezone

        return datetime.now(timezone.utc)

    def verify_auth0_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Verify an Auth0-issued RS256 JWT and return its claims.

        The signing key is selected by the token header's ``kid`` from the
        cached JWKS, which is refreshed first if it is empty or expired.
        ``jwt.decode`` then checks the signature, the audience
        (``AUTH0_AUDIENCE``) and the issuer (``https://<AUTH0_DOMAIN>/``).

        Args:
            token: The encoded JWT received from the client.

        Returns:
            The decoded claims, or ``None`` if the token cannot be verified
            for any reason (the error is printed, not raised).
        """
        try:
            if not self._is_jwks_cache_valid():
                self._refresh_jwks_cache()

            # The header is read unverified only to choose the key; the
            # signature itself is verified by jwt.decode below.
            header = jwt.get_unverified_header(token)
            key_id = header.get("kid")
            if not key_id or key_id not in self.jwks_cache:
                return None

            return jwt.decode(
                token,
                self.jwks_cache[key_id],
                # Pin the algorithm so the token cannot pick a weaker one.
                algorithms=["RS256"],
                audience=self.auth0_audience,
                issuer=f"https://{self.auth0_domain}/",
            )
        except JWTError as e:
            print(f"JWT verification error: {e}")
            return None
        except Exception as e:
            print(f"Token verification error: {e}")
            return None

    def _is_jwks_cache_valid(self) -> bool:
        """Return True if the JWKS cache is non-empty and not yet expired."""
        if not self.jwks_cache or not self.jwks_cache_expiry:
            return False
        return self._utcnow() < self.jwks_cache_expiry

    @staticmethod
    def _parse_jwks(jwks: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
        """Map each key id in a JWKS document to its RSA public-key fields.

        Keys without a ``kid`` are skipped, because tokens select their
        verification key by ``kid``.
        """
        return {
            key["kid"]: {
                "kty": key.get("kty"),
                "kid": key.get("kid"),
                "use": key.get("use"),
                "n": key.get("n"),
                "e": key.get("e"),
            }
            for key in jwks.get("keys", [])
            if key.get("kid")
        }

    def _refresh_jwks_cache(self):
        """Fetch the tenant's JWKS and replace the cache (best effort).

        On any failure the error is printed and the previous cache and expiry
        are left untouched, so a later call will retry the fetch.
        """
        try:
            jwks_url = f"https://{self.auth0_domain}/.well-known/jwks.json"
            response = requests.get(jwks_url, timeout=self.JWKS_REQUEST_TIMEOUT)
            response.raise_for_status()
            keys = self._parse_jwks(response.json())
        except Exception as e:
            print(f"Error refreshing JWKS cache: {e}")
            return
        # Swap in the new key set only after it has been fully parsed, so a
        # malformed response never leaves a half-built cache behind.
        self.jwks_cache = keys
        self.jwks_cache_expiry = self._utcnow() + self.JWKS_CACHE_TTL

    def create_admin_token(self, user_data: Dict[str, Any]) -> str:
        """Mint a local admin session token for an already-verified user.

        Args:
            user_data: Verified identity claims. ``sub`` (falling back to
                ``email``), ``email`` and ``name`` are copied into the token.

        Returns:
            A JWT signed with ``self.secret_key`` / ``self.algorithm`` that
            carries ``role="admin"`` and ``ADMIN_PERMISSIONS``, and expires
            ``ADMIN_SESSION_TTL`` after issue.
        """
        # Read the clock once so iat and exp are exactly one TTL apart.
        issued_at = self._utcnow()
        payload = {
            "sub": user_data.get("sub") or user_data.get("email"),
            "email": user_data.get("email"),
            "name": user_data.get("name", "Unknown"),
            "role": "admin",
            "permissions": list(self.ADMIN_PERMISSIONS),
            "iat": issued_at,
            "exp": issued_at + self.ADMIN_SESSION_TTL,
        }
        return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)

    def verify_admin_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Verify a locally issued admin session token.

        Returns:
            The token's claims if the signature is valid, the token has not
            expired and it carries ``role == "admin"``; otherwise ``None``.
        """
        try:
            payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
        except JWTError:
            return None

        # Explicit expiry safeguard. It compares against a true UTC epoch:
        # ``datetime.utcnow().timestamp()`` reinterprets UTC wall-clock time as
        # local time, which expired sessions early on servers west of UTC.
        exp = payload.get("exp")
        if exp and self._utcnow().timestamp() > exp:
            return None

        if payload.get("role") != "admin":
            return None
        return payload

    def is_admin_user(self, email: Optional[str]) -> bool:
        """Return True if *email* is on the admin allow-list (case-insensitive).

        An empty or missing e-mail is never an admin.
        """
        if not email:
            return False
        normalized = email.lower()
        return any(admin.lower() == normalized for admin in self.admin_emails)

    def authenticate_admin(self, auth0_token: str) -> Optional[str]:
        """Run the complete admin login flow and return a session token.

        1. Verify the Auth0 token.
        2. Check that the user's e-mail is an admin address.
        3. Create an admin session token.

        Returns:
            The admin session token, or ``None`` if the Auth0 token is invalid
            or an unexpected error occurs (the error is printed).

        Raises:
            PermissionError: The token is valid, but its e-mail is missing,
                is explicitly marked unverified, or is not on the allow-list.
        """
        try:
            user_data = self.verify_auth0_token(auth0_token)
            if not user_data:
                return None

            user_email = user_data.get("email")
            # Admin rights are granted purely by e-mail match, so an address
            # the identity provider reports as unverified must not qualify.
            # The same message is used in every case to avoid revealing which
            # addresses are admins.
            if (
                not user_email
                or user_data.get("email_verified") is False
                or not self.is_admin_user(user_email)
            ):
                raise PermissionError("User is not authorized as admin")

            return self.create_admin_token(user_data)
        except PermissionError:
            raise
        except Exception as e:
            print(f"Admin authentication error: {e}")
            return None

    def get_user_permissions(self, token: str) -> list:
        """Return the permission list from a valid admin token, else ``[]``."""
        payload = self.verify_admin_token(token)
        return payload.get("permissions", []) if payload else []

    def has_permission(self, token: str, permission: str) -> bool:
        """Return True if the admin token is valid and grants *permission*."""
        return permission in self.get_user_permissions(token)
# Shared, module-level OAuthHandler instance. Because it is constructed here,
# the environment variables read in OAuthHandler.__init__ are captured once,
# when this module is first imported.
oauth_handler = OAuthHandler()