Files
css-test/app/api/auth.py
T
2026-06-07 03:35:28 +00:00

102 lines
4.1 KiB
Python

import jwt
import httpx
from urllib.parse import quote
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import RedirectResponse
from jwt import PyJWKClient
# --- Entra ID (Azure AD) settings -------------------------------------------
# Placeholder values; replace with the real tenant / app registration values.
TENANT_ID = "your-tenant-id"
CLIENT_ID = "your-client-id"
# Must match the redirect URI configured on the Entra ID app registration.
REDIRECT_URI = "https://your-app.com/callback"
# Tenant-specific endpoint publishing the public keys used to sign tokens.
JWKS_URL = (
    f"https://login.microsoftonline.com/{TENANT_ID}/discovery/v2.0/keys"
)

# --- Shared objects ----------------------------------------------------------
app = FastAPI()

# auto_error=False: a missing Authorization header does not make FastAPI
# reject the request on its own (403/401); the dependency receives no
# credentials instead, which lets the caller redirect to the login page.
security = HTTPBearer(auto_error=False)

# Client that looks up the signing key for a given token from JWKS_URL.
jwks_client = PyJWKClient(JWKS_URL)
class EntraGroupChecker:
    """FastAPI dependency that gates a route on Entra ID group membership.

    An instance is callable and is meant to be used as
    ``Depends(EntraGroupChecker(group_id))``. When called it returns either

    * a ``RedirectResponse`` to the Entra ID login page (no bearer token, or
      the token failed validation), or
    * the decoded token claims (``dict``) with an extra ``"resolved_groups"``
      key holding the group IDs that were checked.

    It raises ``HTTPException(403)`` when the token is valid but the user is
    not a member of ``required_group_id``.
    """

    def __init__(self, required_group_id: str):
        """Store the object ID of the group a caller must belong to."""
        self.required_group_id = required_group_id

    def _get_redirect_to_login(self) -> RedirectResponse:
        """Construct the Entra ID authorization URL and return a redirect to it."""
        # safe="" so that ":" and "/" inside the values are percent-encoded as
        # well; quote() leaves "/" untouched by default, which is wrong for a
        # URL embedded in a query-string parameter.
        scopes = quote("openid profile User.Read", safe="")
        encoded_redirect = quote(REDIRECT_URI, safe="")
        entra_login_url = (
            f"https://login.microsoftonline.com/{TENANT_ID}/oauth2/v2.0/authorize"
            f"?client_id={CLIENT_ID}"
            f"&response_type=token"  # Or 'code' if using Authorization Code Flow
            f"&redirect_uri={encoded_redirect}"
            f"&scope={scopes}"
            f"&response_mode=fragment"
        )
        return RedirectResponse(url=entra_login_url)

    async def _fetch_groups_from_graph(self, token: str) -> list[str]:
        """Return the IDs of all groups the token's user transitively belongs to.

        Follows ``@odata.nextLink`` so that every page of results is read.
        This is best-effort and fails closed: any transport error, non-200
        response or unparsable body yields an empty list (which the caller
        treats as "not a member").

        NOTE(review): ``token`` is the same bearer token that was validated
        against ``audience=CLIENT_ID``; Microsoft Graph normally expects a
        token issued for Graph itself (e.g. obtained via the on-behalf-of
        flow). Confirm this call can succeed with the tokens in use.
        """
        url = "https://graph.microsoft.com/v1.0/me/transitiveMemberOf?$select=id"
        headers = {"Authorization": f"Bearer {token}"}
        group_ids: list[str] = []
        async with httpx.AsyncClient() as client:
            try:
                while url:
                    response = await client.get(url, headers=headers)
                    if response.status_code != 200:
                        return []
                    body = response.json()
                    group_ids.extend(
                        item["id"] for item in body.get("value", []) if "id" in item
                    )
                    # Absent on the last page, which ends the loop.
                    url = body.get("@odata.nextLink")
            except (httpx.RequestError, ValueError):
                # ValueError covers a response body that is not valid JSON.
                return []
        return group_ids

    async def __call__(self, credentials: HTTPAuthorizationCredentials = Depends(security)):
        """Validate the bearer token and enforce membership of the required group.

        Returns a ``RedirectResponse`` to the login page when the token is
        missing or invalid, otherwise the decoded claims with
        ``"resolved_groups"`` added.

        Raises:
            HTTPException: 403 when the token is valid but the user is not in
                ``required_group_id``.
        """
        # 1. Handle Missing Token: Trigger redirect
        if not credentials:
            return self._get_redirect_to_login()

        token = credentials.credentials
        try:
            # NOTE(review): this PyJWKClient lookup is a synchronous call and
            # may fetch the key set over the network, blocking the event loop.
            signing_key = jwks_client.get_signing_key_from_jwt(token)
            # NOTE(review): the issuer below is the v1.0 form while JWKS_URL is
            # the v2.0 keys endpoint; v2.0 tokens carry
            # "https://login.microsoftonline.com/<tenant>/v2.0" -- confirm
            # which token version the app registration issues.
            payload = jwt.decode(
                token,
                signing_key.key,
                algorithms=["RS256"],
                audience=CLIENT_ID,
                issuer=f"https://sts.windows.net/{TENANT_ID}/",
            )
        # 2. Handle Expired or Malformed Tokens: Trigger redirect
        # InvalidTokenError already covers ExpiredSignatureError and decode
        # errors; PyJWKClientError is a separate branch of the hierarchy raised
        # when no signing key matches the token (or the key set can't be read).
        except (jwt.InvalidTokenError, jwt.PyJWKClientError):
            return self._get_redirect_to_login()

        user_groups = payload.get("groups", [])

        # Group overage: when the user is in too many groups the token omits
        # "groups" and instead carries either `_claim_names: {"groups": ...}`
        # or, for implicit-flow tokens, `hasgroups: true`.
        claim_names = payload.get("_claim_names")
        has_overage = (
            isinstance(claim_names, dict) and "groups" in claim_names
        ) or bool(payload.get("hasgroups"))
        if not user_groups and has_overage:
            user_groups = await self._fetch_groups_from_graph(token)

        # If token is valid but they lack permissions, we still raise a 403 Forbidden
        # (Don't redirect here, or they will get stuck in an infinite login loop)
        if self.required_group_id not in user_groups:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="User is not a member of the required AD group.",
            )

        payload["resolved_groups"] = user_groups
        return payload
# Usage example: ID of the group whose members may access the route below.
# NOTE(review): the value looks like a truncated placeholder, not a full
# group object ID -- replace before use.
ADMIN_GROUP_ID = "b82a3f...-....-..."
@app.get("/secure-dashboard")
async def get_secure_dashboard(user_or_redirect=Depends(EntraGroupChecker(ADMIN_GROUP_ID))):
    # The dependency yields either the decoded token claims or a
    # RedirectResponse pointing at the login page.
    if not isinstance(user_or_redirect, RedirectResponse):
        # Authenticated and authorised: expose the claims to the caller.
        granted = {
            "message": "Access granted to dashboard",
            "user_claims": user_or_redirect,
        }
        return granted
    # Not authenticated: hand the redirect straight back to the browser.
    return user_or_redirect