From c71d929b45b0999e43ceacab0e7316ba101e0d98 Mon Sep 17 00:00:00 2001 From: Paul Atkin Date: Sun, 7 Jun 2026 03:35:28 +0000 Subject: [PATCH] Update app/api/auth.py --- app/api/auth.py | 81 +++++++++++++++++++++++++++++++++++++------------ 1 file changed, 61 insertions(+), 20 deletions(-) diff --git a/app/api/auth.py b/app/api/auth.py index 2b71853..08f1d38 100644 --- a/app/api/auth.py +++ b/app/api/auth.py @@ -1,29 +1,62 @@ import jwt +import httpx +from urllib.parse import quote from fastapi import FastAPI, Depends, HTTPException, status from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from fastapi.responses import RedirectResponse from jwt import PyJWKClient app = FastAPI() -security = HTTPBearer() -# Entra ID Configuration +# Note: We use HTTPBearer(auto_error=False) so FastAPI doesn't automatically +# crash with a 403/401 when the token is missing, allowing us to redirect instead. +security = HTTPBearer(auto_error=False) + TENANT_ID = "your-tenant-id" CLIENT_ID = "your-client-id" JWKS_URL = f"https://login.microsoftonline.com/{TENANT_ID}/discovery/v2.0/keys" +REDIRECT_URI = "https://your-app.com/callback" # Must match Entra ID registration jwks_client = PyJWKClient(JWKS_URL) class EntraGroupChecker: def __init__(self, required_group_id: str): - """ - required_group_id: The Object ID of the group from Azure Portal - """ self.required_group_id = required_group_id - def __call__(self, credentials: HTTPAuthorizationCredentials = Depends(security)): + def _get_redirect_to_login(self) -> RedirectResponse: + """Constructs the Entra ID authorization URL and returns a redirect.""" + scopes = quote("openid profile User.Read") + encoded_redirect = quote(REDIRECT_URI) + + entra_login_url = ( + f"https://login.microsoftonline.com/{TENANT_ID}/oauth2/v2.0/authorize" + f"?client_id={CLIENT_ID}" + f"&response_type=token" # Or 'code' if using Authorization Code Flow + f"&redirect_uri={encoded_redirect}" + f"&scope={scopes}" + f"&response_mode=fragment" + ) + return RedirectResponse(url=entra_login_url) + + async def _fetch_groups_from_graph(self, token: str) -> list[str]: + url = "https://graph.microsoft.com/v1.0/me/transitiveMemberOf?$select=id" + headers = {"Authorization": f"Bearer {token}"} + async with httpx.AsyncClient() as client: + try: + response = await client.get(url, headers=headers) + if response.status_code == 200: + return [item["id"] for item in response.json().get("value", []) if "id" in item] + return [] + except httpx.RequestError: + return [] + + async def __call__(self, credentials: HTTPAuthorizationCredentials = Depends(security)): + # 1. Handle Missing Token: Trigger redirect + if not credentials: + return self._get_redirect_to_login() + token = credentials.credentials try: - # Fetch the signing key and validate the JWT signing_key = jwks_client.get_signing_key_from_jwt(token) payload = jwt.decode( token, @@ -33,29 +66,37 @@ class EntraGroupChecker: issuer=f"https://sts.windows.net/{TENANT_ID}/" ) - # Extract groups from the token user_groups = payload.get("groups", []) + has_overage = "_claim_names" in payload and "groups" in payload["_claim_names"].get("groups", "") + if not user_groups and has_overage: + user_groups = await self._fetch_groups_from_graph(token) + + # If token is valid but they lack permissions, we still raise a 403 Forbidden + # (Don't redirect here, or they will get stuck in an infinite login loop) if self.required_group_id not in user_groups: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="User is not a member of the required AD group." ) - return payload # Return the full user context + payload["resolved_groups"] = user_groups + return payload - except jwt.ExpiredSignatureError: - raise HTTPException(status_code=401, detail="Token expired") - except jwt.InvalidTokenError as e: - raise HTTPException(status_code=401, detail=f"Invalid token: {str(e)}") + # 2. Handle Expired or Malformed Tokens: Trigger redirect + except (jwt.ExpiredSignatureError, jwt.InvalidTokenError): + return self._get_redirect_to_login() -# Example Usage -# Replace with your actual Azure Group Object ID +# Usage ADMIN_GROUP_ID = "b82a3f...-....-..." -@app.get("/secure-data") -def get_secure_data(user=Depends(EntraGroupChecker(ADMIN_GROUP_ID))): +@app.get("/secure-dashboard") +async def get_secure_dashboard(user_or_redirect=Depends(EntraGroupChecker(ADMIN_GROUP_ID))): + # If the dependency returned a RedirectResponse, pass it back to the browser + if isinstance(user_or_redirect, RedirectResponse): + return user_or_redirect + return { - "message": "Access granted via Entra ID Group membership", - "user_id": user.get("oid") - } + "message": "Access granted to dashboard", + "user_claims": user_or_redirect + } \ No newline at end of file