You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
189 lines
6.5 KiB
Python
189 lines
6.5 KiB
Python
import json
import os
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any

import requests
from jose import jwt, JWTError


class OAuthHandler:
    """Admin authentication on top of Auth0-issued identity tokens.

    Flow (see ``authenticate_admin``): an RS256 JWT from Auth0 is verified
    against the tenant's JWKS, the caller's email is checked against the
    ``ADMIN_EMAILS`` allow-list, and on success a locally signed HS256 admin
    session token is minted. All configuration is read from environment
    variables once, in ``__init__``.
    """

    # How long a fetched JWKS is trusted before it is re-fetched.
    _JWKS_CACHE_TTL = timedelta(hours=1)
    # Lifetime of an admin session token issued by create_admin_token.
    _ADMIN_SESSION_TTL = timedelta(hours=8)
    # Permissions embedded in every admin session token.
    ADMIN_PERMISSIONS = (
        "read:questions",
        "write:questions",
        "delete:questions",
        "read:themes",
        "write:themes",
        "read:analytics",
        "read:players",
        "admin:system",
    )

    def __init__(self):
        # OAuth Configuration (using Auth0 as example)
        self.auth0_domain = os.getenv("AUTH0_DOMAIN", "")
        self.auth0_client_id = os.getenv("AUTH0_CLIENT_ID", "")
        self.auth0_client_secret = os.getenv("AUTH0_CLIENT_SECRET", "")
        self.auth0_audience = os.getenv("AUTH0_AUDIENCE", "")

        # JWT settings for the locally issued admin session tokens.
        # NOTE(review): the fallback secret is a hard-coded placeholder; if
        # JWT_SECRET_KEY is unset, anyone who has read this source can forge
        # admin tokens. Consider failing fast instead.
        self.secret_key = os.getenv("JWT_SECRET_KEY", "change-this-in-production")
        self.algorithm = "HS256"

        # Comma-separated allow-list; surrounding whitespace and empty entries
        # are dropped.
        self.admin_emails = {
            email.strip()
            for email in os.getenv("ADMIN_EMAILS", "").split(",")
            if email.strip()
        }

        # JWKS (JSON Web Key Set) cache: key id ("kid") -> JWK dict.
        self.jwks_cache = {}
        # Timezone-aware UTC instant after which the cache must be re-fetched;
        # None until the first successful fetch.
        self.jwks_cache_expiry = None

    def verify_auth0_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Verify an Auth0-issued RS256 JWT and return its claims.

        The signing key is selected by the token header's ``kid`` from the
        cached JWKS (re-fetched when stale). Signature, audience
        (``AUTH0_AUDIENCE``) and issuer (``https://<AUTH0_DOMAIN>/``) are
        checked by ``jwt.decode``.

        Args:
            token: The encoded JWT.

        Returns:
            The decoded claims, or ``None`` if the token is malformed, names
            an unknown key, or fails verification. Errors are printed, not
            raised.
        """
        try:
            if not self._is_jwks_cache_valid():
                self._refresh_jwks_cache()

            # The header is read unverified only to choose the key; the
            # signature itself is verified by jwt.decode below.
            key_id = jwt.get_unverified_header(token).get("kid")
            if not key_id or key_id not in self.jwks_cache:
                return None

            return jwt.decode(
                token,
                self.jwks_cache[key_id],
                algorithms=["RS256"],
                audience=self.auth0_audience,
                issuer=f"https://{self.auth0_domain}/",
            )
        except JWTError as e:
            print(f"JWT verification error: {e}")
            return None
        except Exception as e:
            print(f"Token verification error: {e}")
            return None

    def _is_jwks_cache_valid(self) -> bool:
        """Return True if the JWKS cache is non-empty and not yet expired."""
        if not self.jwks_cache or self.jwks_cache_expiry is None:
            return False
        return datetime.now(timezone.utc) < self.jwks_cache_expiry

    def _refresh_jwks_cache(self):
        """Re-fetch the tenant's JWKS and replace the key cache.

        The new key map is built in a local dict and swapped in only after the
        whole response has been processed. A failed request or a malformed
        document therefore leaves the previous cache (and its expiry) intact
        instead of wiping it or leaving it half-filled. Errors are printed,
        not raised.
        """
        try:
            jwks_url = f"https://{self.auth0_domain}/.well-known/jwks.json"
            response = requests.get(jwks_url, timeout=10)
            response.raise_for_status()
            jwks = response.json()

            # Each key is stored as a JWK dict restricted to these fields;
            # verify_auth0_token hands that dict straight to jwt.decode.
            fresh_keys = {
                key["kid"]: {
                    field: key.get(field)
                    for field in ("kty", "kid", "use", "n", "e")
                }
                for key in jwks.get("keys", [])
                if key.get("kid")
            }
        except Exception as e:
            print(f"Error refreshing JWKS cache: {e}")
            return

        self.jwks_cache = fresh_keys
        self.jwks_cache_expiry = datetime.now(timezone.utc) + self._JWKS_CACHE_TTL

    def create_admin_token(self, user_data: Dict[str, Any]) -> str:
        """Mint a locally signed admin session token.

        Args:
            user_data: Identity claims (e.g. from verify_auth0_token). ``sub``
                falls back to ``email``; ``name`` falls back to ``"Unknown"``.

        Returns:
            A JWT signed with ``self.secret_key`` using ``self.algorithm``. It
            carries the admin role, ADMIN_PERMISSIONS, ``iat``, and an ``exp``
            that is _ADMIN_SESSION_TTL (8 hours) later.
        """
        # One clock read so iat and exp are exactly one TTL apart.
        issued_at = datetime.now(timezone.utc)
        payload = {
            "sub": user_data.get("sub") or user_data.get("email"),
            "email": user_data.get("email"),
            "name": user_data.get("name", "Unknown"),
            "role": "admin",
            "permissions": list(self.ADMIN_PERMISSIONS),
            "iat": issued_at,
            "exp": issued_at + self._ADMIN_SESSION_TTL,
        }
        return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)

    def verify_admin_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Verify a session token produced by create_admin_token.

        Returns:
            The claims if the signature is valid, the token is not expired,
            and its ``role`` is ``"admin"``; otherwise ``None``.
        """
        try:
            payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
        except JWTError:
            return None

        # Defensive expiry check (jwt.decode also validates ``exp``). ``exp``
        # is epoch seconds, so compare against an aware UTC clock: a naive
        # datetime.utcnow().timestamp() is interpreted as *local* time and
        # skews the comparison by the host's UTC offset.
        exp = payload.get("exp")
        if exp is not None and datetime.now(timezone.utc).timestamp() > exp:
            return None

        if payload.get("role") != "admin":
            return None
        return payload

    def is_admin_user(self, email: str) -> bool:
        """Return True if *email* is on the admin allow-list.

        The comparison is case-insensitive. An empty or missing email is
        never an admin.
        """
        if not email:
            return False
        wanted = email.lower()
        return any(wanted == admin.lower() for admin in self.admin_emails)

    def authenticate_admin(self, auth0_token: str) -> Optional[str]:
        """
        Complete admin authentication flow:
        1. Verify Auth0 token
        2. Check if user is admin
        3. Create admin session token

        Returns:
            The admin session token, or ``None`` if the Auth0 token does not
            verify or an unexpected error occurs (the error is printed).

        Raises:
            PermissionError: The Auth0 token is valid but carries no email, or
                its email is not on the admin allow-list.
        """
        try:
            user_data = self.verify_auth0_token(auth0_token)
            if not user_data:
                return None

            user_email = user_data.get("email")
            if not user_email or not self.is_admin_user(user_email):
                raise PermissionError("User is not authorized as admin")

            return self.create_admin_token(user_data)
        except PermissionError:
            # Authorization failures reach the caller. Other errors (below)
            # degrade to None.
            raise
        except Exception as e:
            print(f"Admin authentication error: {e}")
            return None

    def get_user_permissions(self, token: str) -> list:
        """Return the ``permissions`` claim of a valid admin token, else ``[]``."""
        payload = self.verify_admin_token(token)
        if payload:
            return payload.get("permissions", [])
        return []

    def has_permission(self, token: str, permission: str) -> bool:
        """Return True if *token* is a valid admin token granting *permission*."""
        return permission in self.get_user_permissions(token)


# Shared module-level handler instance. It is created at import time, so the
# AUTH0_*, JWT_SECRET_KEY and ADMIN_EMAILS environment variables are read once,
# when this module is first imported.
oauth_handler = OAuthHandler()