"""FastAPI endpoint protection based on Microsoft Entra ID group membership."""
import jwt
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jwt import PyJWKClient

app = FastAPI()
security = HTTPBearer()

# Entra ID Configuration
TENANT_ID = "your-tenant-id"
CLIENT_ID = "your-client-id"
JWKS_URL = f"https://login.microsoftonline.com/{TENANT_ID}/discovery/v2.0/keys"

# v1.0 and v2.0 access tokens carry different "iss" values. Which one this API
# receives depends on the app registration's accessTokenAcceptedVersion, so
# accept both for this tenant.
ACCEPTED_ISSUERS = frozenset({
    f"https://sts.windows.net/{TENANT_ID}/",
    f"https://login.microsoftonline.com/{TENANT_ID}/v2.0",
})
# v2.0 tokens use the client ID as "aud". v1.0 tokens may instead carry the
# App ID URI, which defaults to ``api://<client-id>`` (adjust if customized).
ACCEPTED_AUDIENCES = [CLIENT_ID, f"api://{CLIENT_ID}"]

jwks_client = PyJWKClient(JWKS_URL)


def _unauthorized(detail: str) -> HTTPException:
    """Build a 401 error carrying the RFC 6750 ``WWW-Authenticate`` challenge."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


def _decode_token(token: str) -> dict:
    """Verify *token*'s signature and standard claims, then return its payload.

    Raises:
        HTTPException: 401 if the token is expired or malformed, or if no
            matching signing key can be obtained. Also 401 if its
            audience/issuer is not one of the accepted values.
    """
    try:
        signing_key = jwks_client.get_signing_key_from_jwt(token)
        payload = jwt.decode(
            token,
            signing_key.key,
            algorithms=["RS256"],
            audience=ACCEPTED_AUDIENCES,
            options={"require": ["exp", "iss", "aud"]},
        )
        # Checked here rather than via jwt.decode(issuer=...) so that several
        # issuers can be accepted regardless of the installed PyJWT 2.x version.
        if payload["iss"] not in ACCEPTED_ISSUERS:
            raise jwt.InvalidIssuerError("Invalid issuer")
    except jwt.ExpiredSignatureError as e:
        # Must precede InvalidTokenError, which is its base class.
        raise _unauthorized("Token expired") from e
    except (jwt.InvalidTokenError, jwt.PyJWKClientError) as e:
        # PyJWKClientError (e.g. no key matching the token's "kid", or a failed
        # JWKS fetch) is not an InvalidTokenError. Without catching it here, it
        # would surface as a 500.
        raise _unauthorized(f"Invalid token: {str(e)}") from e
    return payload


class EntraGroupChecker:
    """FastAPI dependency that admits only members of one Entra ID group.

    The bearer token is verified first. Then ``required_group_id`` must appear
    in the token's ``groups`` claim. On success the decoded token payload is
    returned to the route as the user context.
    """

    def __init__(self, required_group_id: str):
        """
        Args:
            required_group_id: The Object ID of the group from Azure Portal.
        """
        self.required_group_id = required_group_id

    def __call__(
        self, credentials: HTTPAuthorizationCredentials = Depends(security)
    ) -> dict:
        """Validate the request's bearer token and enforce group membership.

        Returns:
            The decoded token payload.

        Raises:
            HTTPException: 401 for an invalid token, 403 when the user is not
                in the required group.
        """
        payload = _decode_token(credentials.credentials)

        # `or []` also tolerates a claim that is present but null.
        user_groups = payload.get("groups") or []
        if self.required_group_id not in user_groups:
            detail = "User is not a member of the required AD group."
            # Group overage: for users in too many groups, Entra ID omits
            # "groups" and lists it under "_claim_names" (to be fetched from
            # Microsoft Graph). Report that explicitly instead of a plain deny.
            if "groups" in (payload.get("_claim_names") or {}):
                detail = (
                    "Group membership is not included in the token (group "
                    "overage); it must be resolved via Microsoft Graph."
                )
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)

        return payload  # Return the full user context


# Example Usage
# Replace with your actual Azure Group Object ID
ADMIN_GROUP_ID = "b82a3f...-....-..."


@app.get("/secure-data")
def get_secure_data(user=Depends(EntraGroupChecker(ADMIN_GROUP_ID))):
    return {
        "message": "Access granted via Entra ID Group membership",
        "user_id": user.get("oid"),
    }