from typing import Optional, Dict, Any
import os
import logging
import requests
from jose import jwt, JWTError
from datetime import datetime, timedelta, timezone
import json

logger = logging.getLogger(__name__)

# Fallback used when JWT_SECRET_KEY is unset. Anyone who can read this source
# can forge admin tokens signed with it, so a warning is logged at startup.
_INSECURE_DEFAULT_SECRET = "change-this-in-production"


def _utcnow() -> datetime:
    """Return the current time as a timezone-aware UTC ``datetime``.

    Used instead of ``datetime.utcnow()``, which returns a *naive* value
    (deprecated since Python 3.12) that ``datetime.timestamp()`` would
    misread as local time.
    """
    return datetime.now(timezone.utc)


class OAuthHandler:
    """Authenticate administrators via Auth0 and issue local admin session tokens.

    Flow: an Auth0-issued RS256 JWT is verified against the tenant's JWKS,
    its ``email`` claim is checked against the ``ADMIN_EMAILS`` allow-list,
    and on success a locally signed HS256 session token carrying
    ``role="admin"`` and a fixed permission list is issued.

    Configuration is read from environment variables at construction time:
    ``AUTH0_DOMAIN``, ``AUTH0_CLIENT_ID``, ``AUTH0_CLIENT_SECRET``,
    ``AUTH0_AUDIENCE``, ``JWT_SECRET_KEY`` and ``ADMIN_EMAILS``
    (comma-separated).
    """

    def __init__(self):
        # OAuth Configuration (using Auth0 as example)
        self.auth0_domain = os.getenv("AUTH0_DOMAIN", "")
        self.auth0_client_id = os.getenv("AUTH0_CLIENT_ID", "")
        self.auth0_client_secret = os.getenv("AUTH0_CLIENT_SECRET", "")
        self.auth0_audience = os.getenv("AUTH0_AUDIENCE", "")

        # JWT settings for the locally issued admin session tokens
        self.secret_key = os.getenv("JWT_SECRET_KEY", _INSECURE_DEFAULT_SECRET)
        self.algorithm = "HS256"
        if not self.secret_key or self.secret_key == _INSECURE_DEFAULT_SECRET:
            logger.warning(
                "JWT_SECRET_KEY is not set; admin session tokens are signed "
                "with an insecure key and can be forged."
            )

        # Admin allow-list; blank entries (e.g. from a trailing comma) are dropped
        self.admin_emails = {
            email.strip()
            for email in os.getenv("ADMIN_EMAILS", "").split(",")
            if email.strip()
        }

        # Cache for JWKS (JSON Web Key Set): kid -> public JWK dict, plus an
        # aware UTC expiry (None until the first successful fetch).
        self.jwks_cache: Dict[str, Dict[str, Any]] = {}
        self.jwks_cache_expiry: Optional[datetime] = None

    def verify_auth0_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Verify an Auth0-issued RS256 JWT and return its claims.

        The signing key is selected by the token header's ``kid`` from the
        cached JWKS, which is refreshed when empty or expired. Signature,
        audience (``AUTH0_AUDIENCE``) and issuer (``https://<AUTH0_DOMAIN>/``)
        are validated by ``jwt.decode``.

        Args:
            token: The raw encoded JWT.

        Returns:
            The decoded claims, or ``None`` if the token is malformed, signed
            with an unknown ``kid``, or fails validation. Errors are logged,
            never raised.
        """
        try:
            if not self._is_jwks_cache_valid():
                self._refresh_jwks_cache()

            # The header is read unverified only to choose the key; the
            # signature itself is checked by jwt.decode below.
            header = jwt.get_unverified_header(token)
            key_id = header.get("kid")
            if not key_id or key_id not in self.jwks_cache:
                return None

            rsa_key = self.jwks_cache[key_id]

            # The algorithm list is pinned so the token cannot select its own
            # (e.g. HS256 keyed with the public key).
            return jwt.decode(
                token,
                rsa_key,
                algorithms=["RS256"],
                audience=self.auth0_audience,
                issuer=f"https://{self.auth0_domain}/",
            )
        except JWTError as e:
            logger.warning("JWT verification error: %s", e)
            return None
        except Exception as e:
            logger.exception("Token verification error: %s", e)
            return None

    def _is_jwks_cache_valid(self) -> bool:
        """Return True if the JWKS cache is non-empty and not yet expired."""
        if not self.jwks_cache or not self.jwks_cache_expiry:
            return False
        return _utcnow() < self.jwks_cache_expiry

    def _refresh_jwks_cache(self) -> None:
        """Download the tenant JWKS and rebuild the ``kid`` -> key cache.

        On failure the error is logged and the existing cache and expiry are
        left untouched, so keys already known keep working; the download is
        attempted again on the next verification while the cache is expired.
        """
        if not self.auth0_domain:
            logger.error("Error refreshing JWKS cache: AUTH0_DOMAIN is not configured")
            return
        try:
            jwks_url = f"https://{self.auth0_domain}/.well-known/jwks.json"
            response = requests.get(jwks_url, timeout=10)
            response.raise_for_status()
            jwks = response.json()

            # Build the replacement off to the side and swap it in only after
            # every key has been processed, so a failure cannot leave a
            # half-populated cache behind.
            new_cache: Dict[str, Dict[str, Any]] = {}
            for key in jwks.get("keys", []):
                key_id = key.get("kid")
                if key_id:
                    # Keep only the public RSA JWK fields needed for verification
                    new_cache[key_id] = {
                        "kty": key.get("kty"),
                        "kid": key_id,
                        "use": key.get("use"),
                        "n": key.get("n"),
                        "e": key.get("e"),
                    }

            self.jwks_cache = new_cache
            # Set cache expiry (1 hour)
            self.jwks_cache_expiry = _utcnow() + timedelta(hours=1)
        except Exception as e:
            logger.error("Error refreshing JWKS cache: %s", e)

    def create_admin_token(self, user_data: Dict[str, Any]) -> str:
        """Issue a locally signed admin session token.

        Args:
            user_data: Verified identity claims (e.g. as returned by
                ``verify_auth0_token``). ``sub`` (falling back to ``email``),
                ``email`` and ``name`` are copied into the token.

        Returns:
            An encoded JWT signed with ``self.secret_key`` using
            ``self.algorithm``, carrying ``role="admin"``, the full admin
            permission list, and an expiry 8 hours after issue.
        """
        now = _utcnow()  # one timestamp so iat and exp are exactly 8h apart
        payload = {
            "sub": user_data.get("sub") or user_data.get("email"),
            "email": user_data.get("email"),
            "name": user_data.get("name", "Unknown"),
            "role": "admin",
            "permissions": [
                "read:questions",
                "write:questions",
                "delete:questions",
                "read:themes",
                "write:themes",
                "read:analytics",
                "read:players",
                "admin:system",
            ],
            "iat": now,
            "exp": now + timedelta(hours=8),  # 8-hour session
        }
        return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)

    def verify_admin_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Verify a token issued by ``create_admin_token``.

        Args:
            token: The encoded admin session token.

        Returns:
            The decoded payload if the signature is valid, the token has not
            expired and its ``role`` is ``"admin"``; otherwise ``None``.
        """
        try:
            payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
        except JWTError:
            return None

        # Defensive expiry re-check. "now" must be an aware UTC timestamp: a
        # naive utcnow().timestamp() is read as local time and would skew the
        # comparison by the host's UTC offset.
        exp = payload.get("exp")
        if exp and _utcnow().timestamp() > exp:
            return None

        # Check if user has admin role
        if payload.get("role") != "admin":
            return None

        return payload

    def is_admin_user(self, email: str) -> bool:
        """Return True if *email* is in the admin allow-list (case-insensitive)."""
        wanted = email.lower()
        return any(admin.lower() == wanted for admin in self.admin_emails)

    def authenticate_admin(self, auth0_token: str) -> Optional[str]:
        """
        Complete admin authentication flow:
        1. Verify Auth0 token
        2. Check that the user's email is allow-listed and not flagged unverified
        3. Create admin session token

        Args:
            auth0_token: The raw Auth0-issued JWT.

        Returns:
            An admin session token, or ``None`` if the Auth0 token is invalid
            or an unexpected error occurs (logged).

        Raises:
            PermissionError: The token is valid but the user is not an
                authorized admin.
        """
        try:
            user_data = self.verify_auth0_token(auth0_token)
            if not user_data:
                return None

            user_email = user_data.get("email")
            if not user_email or not self.is_admin_user(user_email):
                raise PermissionError("User is not authorized as admin")

            # Refuse addresses the identity provider explicitly marks as
            # unverified; otherwise anyone who registers an allow-listed
            # address without proving ownership would gain admin. Tokens that
            # omit the claim are still accepted.
            if user_data.get("email_verified") is False:
                raise PermissionError("Admin email address is not verified")

            return self.create_admin_token(user_data)
        except PermissionError:
            raise
        except Exception as e:
            logger.exception("Admin authentication error: %s", e)
            return None

    def get_user_permissions(self, token: str) -> list:
        """Return the ``permissions`` list of a valid admin token, else ``[]``."""
        payload = self.verify_admin_token(token)
        if payload:
            return payload.get("permissions", [])
        return []

    def has_permission(self, token: str, permission: str) -> bool:
        """Return True if a valid admin token grants *permission*."""
        return permission in self.get_user_permissions(token)


# Global OAuth handler instance
oauth_handler = OAuthHandler()