"""FastAPI route protection that validates Entra ID bearer tokens and checks group membership."""

from urllib.parse import quote, urlencode

import httpx
import jwt
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import RedirectResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jwt import PyJWKClient

app = FastAPI()

# auto_error=False makes HTTPBearer return None (instead of rejecting the request
# with an error) when the bearer token is missing, so we can redirect to login.
security = HTTPBearer(auto_error=False)

TENANT_ID = "your-tenant-id"
CLIENT_ID = "your-client-id"
JWKS_URL = f"https://login.microsoftonline.com/{TENANT_ID}/discovery/v2.0/keys"
REDIRECT_URI = "https://your-app.com/callback"  # Must match Entra ID registration

# Scopes requested at login. The access token must be issued for THIS API: a token
# for a Graph scope such as "User.Read" has Microsoft Graph as its audience, so it
# would never validate against CLIENT_ID below and every request would loop back
# to login. Assumes the default Application ID URI (api://<client-id>) - TODO confirm.
LOGIN_SCOPES = f"openid profile api://{CLIENT_ID}/.default"

# Entra issues v1.0 or v2.0 access tokens depending on the app registration;
# the two versions use different issuers and usually different audience forms,
# so accept both.
ACCEPTED_ISSUERS = (
    f"https://sts.windows.net/{TENANT_ID}/",  # v1.0 tokens
    f"https://login.microsoftonline.com/{TENANT_ID}/v2.0",  # v2.0 tokens
)
ACCEPTED_AUDIENCES = [CLIENT_ID, f"api://{CLIENT_ID}"]

GRAPH_MEMBER_OF_URL = "https://graph.microsoft.com/v1.0/me/transitiveMemberOf?$select=id"

jwks_client = PyJWKClient(JWKS_URL)


class EntraGroupChecker:
    """FastAPI dependency that admits only callers in a given Entra ID group.

    On success the dependency returns the decoded token claims, with the group
    IDs that were checked stored under ``"resolved_groups"``. A missing or
    invalid token produces a redirect to the Entra login page; a valid token
    without the required group produces 403.
    """

    def __init__(self, required_group_id: str):
        self.required_group_id = required_group_id

    def _build_login_url(self) -> str:
        """Return the Entra ID authorize URL for this application."""
        # NOTE(review): no `state` value is sent, so the callback cannot detect
        # forged login responses; consider generating and verifying one.
        params = {
            "client_id": CLIENT_ID,
            "response_type": "token",  # Or 'code' if using Authorization Code Flow
            "redirect_uri": REDIRECT_URI,
            "scope": LOGIN_SCOPES,
            "response_mode": "fragment",
        }
        # quote (not the default quote_plus) encodes spaces in the scope as %20.
        query = urlencode(params, quote_via=quote)
        return f"https://login.microsoftonline.com/{TENANT_ID}/oauth2/v2.0/authorize?{query}"

    def _get_redirect_to_login(self) -> RedirectResponse:
        """Return a redirect response pointing at the Entra ID login page."""
        return RedirectResponse(url=self._build_login_url())

    def _login_redirect_error(self) -> HTTPException:
        """Build an exception that makes FastAPI answer with a 307 to the login page.

        Raising (rather than returning a RedirectResponse) stops the route body
        from running at all, so a route cannot accidentally treat the redirect
        as user claims.
        """
        return HTTPException(
            status_code=status.HTTP_307_TEMPORARY_REDIRECT,
            headers={"Location": self._build_login_url()},
        )

    @staticmethod
    def _has_group_overage(payload: dict) -> bool:
        """Return True when the token says its groups claim was omitted for size.

        JWT overage is signalled by a "groups" key in ``_claim_names``; the
        implicit flow may instead emit ``hasgroups: true``.
        """
        claim_names = payload.get("_claim_names")
        if isinstance(claim_names, dict) and "groups" in claim_names:
            return True
        return bool(payload.get("hasgroups"))

    async def _decode_token(self, token: str) -> dict:
        """Verify signature, expiry, audience and issuer, and return the claims.

        Raises:
            jwt.InvalidTokenError: (or a subclass) if the token fails validation.
            jwt.PyJWKClientError: if no matching signing key could be obtained.
        """
        # PyJWKClient may download the JWKS with blocking I/O; run it in a
        # worker thread so it does not stall the event loop.
        signing_key = await run_in_threadpool(jwks_client.get_signing_key_from_jwt, token)
        payload = jwt.decode(
            token,
            signing_key.key,
            algorithms=["RS256"],
            audience=ACCEPTED_AUDIENCES,
            options={"require": ["exp", "iss", "aud"]},
        )
        # Issuer is checked by hand so both issuer formats are accepted.
        if payload["iss"] not in ACCEPTED_ISSUERS:
            raise jwt.InvalidIssuerError("Invalid issuer")
        return payload

    async def _fetch_groups_from_graph(self, token: str) -> list[str]:
        """Return the IDs of every directory object the user transitively belongs to.

        Follows ``@odata.nextLink`` so that every result page is read (an overage
        user has many groups, which may span several pages). Returns ``[]`` on
        any network error, non-200 response, or undecodable body, so the caller
        fails closed with 403.

        NOTE(review): Graph only accepts tokens whose audience is Microsoft Graph.
        The token validated by ``_decode_token`` is issued for this API, so this
        call is expected to be rejected unless a Graph token is obtained first
        (e.g. via the On-Behalf-Of flow) - TODO confirm and wire that up.
        """
        headers = {"Authorization": f"Bearer {token}"}
        group_ids: list[str] = []
        url = GRAPH_MEMBER_OF_URL
        async with httpx.AsyncClient() as client:
            try:
                while url:
                    response = await client.get(url, headers=headers)
                    if response.status_code != 200:
                        return []
                    page = response.json()
                    group_ids.extend(
                        item["id"] for item in page.get("value", []) if "id" in item
                    )
                    url = page.get("@odata.nextLink")
            except (httpx.RequestError, ValueError):
                # ValueError covers a response body that is not valid JSON.
                return []
        return group_ids

    async def __call__(
        self, credentials: HTTPAuthorizationCredentials = Depends(security)
    ) -> dict:
        """Validate the bearer token and enforce group membership.

        Returns:
            The decoded token claims plus ``"resolved_groups"``.

        Raises:
            HTTPException: 307 redirect to login when the token is missing or
                invalid; 403 when the user lacks the required group; 503 when
                the signing keys cannot be retrieved.
        """
        # 1. Missing token: send the browser to login.
        if not credentials:
            raise self._login_redirect_error()

        token = credentials.credentials
        try:
            payload = await self._decode_token(token)
        except jwt.InvalidTokenError:
            # 2. Expired, malformed, or otherwise invalid token
            # (ExpiredSignatureError is a subclass): send the browser to login.
            raise self._login_redirect_error() from None
        except jwt.PyJWKClientError:
            # Signing key unavailable (JWKS fetch failure or unknown key ID).
            # Redirecting here could loop during an outage, so fail explicitly.
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Unable to retrieve token signing keys.",
            ) from None

        user_groups = payload.get("groups") or []
        if not user_groups and self._has_group_overage(payload):
            user_groups = await self._fetch_groups_from_graph(token)

        # A valid token without the group gets 403, not a redirect: logging in
        # again would return the same groups and loop forever.
        if self.required_group_id not in user_groups:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="User is not a member of the required AD group.",
            )

        payload["resolved_groups"] = user_groups
        return payload


# Usage
ADMIN_GROUP_ID = "b82a3f...-....-..."


@app.get("/secure-dashboard")
async def get_secure_dashboard(claims: dict = Depends(EntraGroupChecker(ADMIN_GROUP_ID))):
    """Return the dashboard payload for members of the admin group."""
    # The dependency raises (redirect / 403 / 503) on failure, so reaching this
    # point means `claims` holds validated token claims.
    return {"message": "Access granted to dashboard", "user_claims": claims}