Update app/api/auth.py
This commit is contained in:
+60
-19
@@ -1,29 +1,62 @@
|
|||||||
import jwt
|
import jwt
|
||||||
|
import httpx
|
||||||
|
from urllib.parse import quote
|
||||||
from fastapi import FastAPI, Depends, HTTPException, status
|
from fastapi import FastAPI, Depends, HTTPException, status
|
||||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||||
|
from fastapi.responses import RedirectResponse
|
||||||
from jwt import PyJWKClient
|
from jwt import PyJWKClient
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
security = HTTPBearer()
|
|
||||||
|
|
||||||
# Entra ID Configuration
|
# Note: We use HTTPBearer(auto_error=False) so FastAPI doesn't automatically
|
||||||
|
# crash with a 403/401 when the token is missing, allowing us to redirect instead.
|
||||||
|
security = HTTPBearer(auto_error=False)
|
||||||
|
|
||||||
TENANT_ID = "your-tenant-id"
|
TENANT_ID = "your-tenant-id"
|
||||||
CLIENT_ID = "your-client-id"
|
CLIENT_ID = "your-client-id"
|
||||||
JWKS_URL = f"https://login.microsoftonline.com/{TENANT_ID}/discovery/v2.0/keys"
|
JWKS_URL = f"https://login.microsoftonline.com/{TENANT_ID}/discovery/v2.0/keys"
|
||||||
|
REDIRECT_URI = "https://your-app.com/callback" # Must match Entra ID registration
|
||||||
|
|
||||||
jwks_client = PyJWKClient(JWKS_URL)
|
jwks_client = PyJWKClient(JWKS_URL)
|
||||||
|
|
||||||
class EntraGroupChecker:
|
class EntraGroupChecker:
|
||||||
def __init__(self, required_group_id: str):
|
def __init__(self, required_group_id: str):
|
||||||
"""
|
|
||||||
required_group_id: The Object ID of the group from Azure Portal
|
|
||||||
"""
|
|
||||||
self.required_group_id = required_group_id
|
self.required_group_id = required_group_id
|
||||||
|
|
||||||
def __call__(self, credentials: HTTPAuthorizationCredentials = Depends(security)):
|
def _get_redirect_to_login(self) -> RedirectResponse:
|
||||||
|
"""Constructs the Entra ID authorization URL and returns a redirect."""
|
||||||
|
scopes = quote("openid profile User.Read")
|
||||||
|
encoded_redirect = quote(REDIRECT_URI)
|
||||||
|
|
||||||
|
entra_login_url = (
|
||||||
|
f"https://login.microsoftonline.com/{TENANT_ID}/oauth2/v2.0/authorize"
|
||||||
|
f"?client_id={CLIENT_ID}"
|
||||||
|
f"&response_type=token" # Or 'code' if using Authorization Code Flow
|
||||||
|
f"&redirect_uri={encoded_redirect}"
|
||||||
|
f"&scope={scopes}"
|
||||||
|
f"&response_mode=fragment"
|
||||||
|
)
|
||||||
|
return RedirectResponse(url=entra_login_url)
|
||||||
|
|
||||||
|
async def _fetch_groups_from_graph(self, token: str) -> list[str]:
|
||||||
|
url = "https://graph.microsoft.com/v1.0/me/transitiveMemberOf?$select=id"
|
||||||
|
headers = {"Authorization": f"Bearer {token}"}
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
try:
|
||||||
|
response = await client.get(url, headers=headers)
|
||||||
|
if response.status_code == 200:
|
||||||
|
return [item["id"] for item in response.json().get("value", []) if "id" in item]
|
||||||
|
return []
|
||||||
|
except httpx.RequestError:
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def __call__(self, credentials: HTTPAuthorizationCredentials = Depends(security)):
|
||||||
|
# 1. Handle Missing Token: Trigger redirect
|
||||||
|
if not credentials:
|
||||||
|
return self._get_redirect_to_login()
|
||||||
|
|
||||||
token = credentials.credentials
|
token = credentials.credentials
|
||||||
try:
|
try:
|
||||||
# Fetch the signing key and validate the JWT
|
|
||||||
signing_key = jwks_client.get_signing_key_from_jwt(token)
|
signing_key = jwks_client.get_signing_key_from_jwt(token)
|
||||||
payload = jwt.decode(
|
payload = jwt.decode(
|
||||||
token,
|
token,
|
||||||
@@ -33,29 +66,37 @@ class EntraGroupChecker:
|
|||||||
issuer=f"https://sts.windows.net/{TENANT_ID}/"
|
issuer=f"https://sts.windows.net/{TENANT_ID}/"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract groups from the token
|
|
||||||
user_groups = payload.get("groups", [])
|
user_groups = payload.get("groups", [])
|
||||||
|
has_overage = "_claim_names" in payload and "groups" in payload["_claim_names"].get("groups", "")
|
||||||
|
|
||||||
|
if not user_groups and has_overage:
|
||||||
|
user_groups = await self._fetch_groups_from_graph(token)
|
||||||
|
|
||||||
|
# If token is valid but they lack permissions, we still raise a 403 Forbidden
|
||||||
|
# (Don't redirect here, or they will get stuck in an infinite login loop)
|
||||||
if self.required_group_id not in user_groups:
|
if self.required_group_id not in user_groups:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail="User is not a member of the required AD group."
|
detail="User is not a member of the required AD group."
|
||||||
)
|
)
|
||||||
|
|
||||||
return payload # Return the full user context
|
payload["resolved_groups"] = user_groups
|
||||||
|
return payload
|
||||||
|
|
||||||
except jwt.ExpiredSignatureError:
|
# 2. Handle Expired or Malformed Tokens: Trigger redirect
|
||||||
raise HTTPException(status_code=401, detail="Token expired")
|
except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
|
||||||
except jwt.InvalidTokenError as e:
|
return self._get_redirect_to_login()
|
||||||
raise HTTPException(status_code=401, detail=f"Invalid token: {str(e)}")
|
|
||||||
|
|
||||||
# Example Usage
|
# Usage
|
||||||
# Replace with your actual Azure Group Object ID
|
|
||||||
ADMIN_GROUP_ID = "b82a3f...-....-..."
|
ADMIN_GROUP_ID = "b82a3f...-....-..."
|
||||||
|
|
||||||
@app.get("/secure-data")
|
@app.get("/secure-dashboard")
|
||||||
def get_secure_data(user=Depends(EntraGroupChecker(ADMIN_GROUP_ID))):
|
async def get_secure_dashboard(user_or_redirect=Depends(EntraGroupChecker(ADMIN_GROUP_ID))):
|
||||||
|
# If the dependency returned a RedirectResponse, pass it back to the browser
|
||||||
|
if isinstance(user_or_redirect, RedirectResponse):
|
||||||
|
return user_or_redirect
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"message": "Access granted via Entra ID Group membership",
|
"message": "Access granted to dashboard",
|
||||||
"user_id": user.get("oid")
|
"user_claims": user_or_redirect
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user