Update app/api/auth.py

This commit is contained in:
2026-06-07 03:35:28 +00:00
parent b9414f074b
commit c71d929b45
+61 -20
View File
@@ -1,29 +1,62 @@
import jwt import jwt
import httpx
from urllib.parse import quote
from fastapi import FastAPI, Depends, HTTPException, status from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import RedirectResponse
from jwt import PyJWKClient from jwt import PyJWKClient
app = FastAPI() app = FastAPI()
security = HTTPBearer()
# Entra ID Configuration # Note: We use HTTPBearer(auto_error=False) so FastAPI doesn't automatically
# crash with a 403/401 when the token is missing, allowing us to redirect instead.
security = HTTPBearer(auto_error=False)
TENANT_ID = "your-tenant-id" TENANT_ID = "your-tenant-id"
CLIENT_ID = "your-client-id" CLIENT_ID = "your-client-id"
JWKS_URL = f"https://login.microsoftonline.com/{TENANT_ID}/discovery/v2.0/keys" JWKS_URL = f"https://login.microsoftonline.com/{TENANT_ID}/discovery/v2.0/keys"
REDIRECT_URI = "https://your-app.com/callback" # Must match Entra ID registration
jwks_client = PyJWKClient(JWKS_URL) jwks_client = PyJWKClient(JWKS_URL)
class EntraGroupChecker: class EntraGroupChecker:
def __init__(self, required_group_id: str): def __init__(self, required_group_id: str):
"""
required_group_id: The Object ID of the group from Azure Portal
"""
self.required_group_id = required_group_id self.required_group_id = required_group_id
def __call__(self, credentials: HTTPAuthorizationCredentials = Depends(security)): def _get_redirect_to_login(self) -> RedirectResponse:
"""Constructs the Entra ID authorization URL and returns a redirect."""
scopes = quote("openid profile User.Read")
encoded_redirect = quote(REDIRECT_URI)
entra_login_url = (
f"https://login.microsoftonline.com/{TENANT_ID}/oauth2/v2.0/authorize"
f"?client_id={CLIENT_ID}"
f"&response_type=token" # Or 'code' if using Authorization Code Flow
f"&redirect_uri={encoded_redirect}"
f"&scope={scopes}"
f"&response_mode=fragment"
)
return RedirectResponse(url=entra_login_url)
async def _fetch_groups_from_graph(self, token: str) -> list[str]:
url = "https://graph.microsoft.com/v1.0/me/transitiveMemberOf?$select=id"
headers = {"Authorization": f"Bearer {token}"}
async with httpx.AsyncClient() as client:
try:
response = await client.get(url, headers=headers)
if response.status_code == 200:
return [item["id"] for item in response.json().get("value", []) if "id" in item]
return []
except httpx.RequestError:
return []
async def __call__(self, credentials: HTTPAuthorizationCredentials = Depends(security)):
# 1. Handle Missing Token: Trigger redirect
if not credentials:
return self._get_redirect_to_login()
token = credentials.credentials token = credentials.credentials
try: try:
# Fetch the signing key and validate the JWT
signing_key = jwks_client.get_signing_key_from_jwt(token) signing_key = jwks_client.get_signing_key_from_jwt(token)
payload = jwt.decode( payload = jwt.decode(
token, token,
@@ -33,29 +66,37 @@ class EntraGroupChecker:
issuer=f"https://sts.windows.net/{TENANT_ID}/" issuer=f"https://sts.windows.net/{TENANT_ID}/"
) )
# Extract groups from the token
user_groups = payload.get("groups", []) user_groups = payload.get("groups", [])
has_overage = "_claim_names" in payload and "groups" in payload["_claim_names"].get("groups", "")
if not user_groups and has_overage:
user_groups = await self._fetch_groups_from_graph(token)
# If token is valid but they lack permissions, we still raise a 403 Forbidden
# (Don't redirect here, or they will get stuck in an infinite login loop)
if self.required_group_id not in user_groups: if self.required_group_id not in user_groups:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="User is not a member of the required AD group." detail="User is not a member of the required AD group."
) )
return payload # Return the full user context payload["resolved_groups"] = user_groups
return payload
except jwt.ExpiredSignatureError: # 2. Handle Expired or Malformed Tokens: Trigger redirect
raise HTTPException(status_code=401, detail="Token expired") except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
except jwt.InvalidTokenError as e: return self._get_redirect_to_login()
raise HTTPException(status_code=401, detail=f"Invalid token: {str(e)}")
# Example Usage # Usage
# Replace with your actual Azure Group Object ID
ADMIN_GROUP_ID = "b82a3f...-....-..." ADMIN_GROUP_ID = "b82a3f...-....-..."
@app.get("/secure-data") @app.get("/secure-dashboard")
def get_secure_data(user=Depends(EntraGroupChecker(ADMIN_GROUP_ID))): async def get_secure_dashboard(user_or_redirect=Depends(EntraGroupChecker(ADMIN_GROUP_ID))):
# If the dependency returned a RedirectResponse, pass it back to the browser
if isinstance(user_or_redirect, RedirectResponse):
return user_or_redirect
return { return {
"message": "Access granted via Entra ID Group membership", "message": "Access granted to dashboard",
"user_id": user.get("oid") "user_claims": user_or_redirect
} }