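"""Example FastAPI service that authorizes requests by Microsoft Entra ID group membership.

Bearer tokens are verified against the tenant's published signing keys, and
the caller must belong to a configured group to reach the protected route.
"""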
import os

import jwt
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jwt import PyJWKClient


app = FastAPI()

# Extracts "Authorization: Bearer <token>" credentials for dependencies.
security = HTTPBearer()

# Entra ID configuration. Each value may be supplied through the environment;
# the literal is a placeholder used only when the variable is unset.
TENANT_ID = os.environ.get("ENTRA_TENANT_ID", "your-tenant-id")
CLIENT_ID = os.environ.get("ENTRA_CLIENT_ID", "your-client-id")

# Tenant key-set (JWKS) endpoint that PyJWKClient reads signing keys from.
JWKS_URL = f"https://login.microsoftonline.com/{TENANT_ID}/discovery/v2.0/keys"

# One shared client for all requests, so any key caching PyJWKClient does
# is reused instead of starting fresh per request.
jwks_client = PyJWKClient(JWKS_URL)


class EntraGroupChecker:
    """FastAPI dependency that admits only members of one Entra ID group.

    Pass an instance to ``Depends``. FastAPI calls it with the bearer
    credentials extracted by the module-level ``security`` scheme; on success
    the decoded token payload is injected into the route.
    """

    def __init__(self, required_group_id: str):
        """Create a checker for a single group.

        Args:
            required_group_id: The Object ID of the group from Azure Portal.
        """
        self.required_group_id = required_group_id

    def __call__(self, credentials: HTTPAuthorizationCredentials = Depends(security)) -> dict:
        """Verify the bearer token and enforce group membership.

        Returns:
            The decoded token payload (the full user context).

        Raises:
            HTTPException: 401 if the token cannot be verified (expired,
                malformed, bad signature/audience/issuer, or no matching
                signing key); 403 if its ``groups`` claim does not contain
                ``required_group_id``.
        """
        payload = self._decode(credentials.credentials)

        groups = payload.get("groups")
        # Only a list is a usable claim; with a bare string, ``in`` below
        # would silently become a substring test.
        if not isinstance(groups, list):
            groups = []

        # NOTE(review): when a user is in too many groups, Entra ID omits the
        # ``groups`` claim ("group overage", signalled by ``_claim_names`` /
        # ``_claim_sources``). Such users are rejected here unless membership
        # is looked up separately, e.g. via Microsoft Graph.
        if self.required_group_id not in groups:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="User is not a member of the required AD group.",
            )

        return payload  # Return the full user context

    @classmethod
    def _decode(cls, token: str) -> dict:
        """Verify ``token`` against the tenant's signing keys and return its claims.

        Raises:
            HTTPException: 401 for any token that fails verification.
        """
        # Entra ID stamps v1.0 and v2.0 access tokens with different ``iss``
        # values, and v1.0 tokens usually carry the App ID URI rather than
        # the bare client ID as ``aud``. Accept either form, but only for
        # this tenant and this application.
        accepted_issuers = (
            f"https://sts.windows.net/{TENANT_ID}/",
            f"https://login.microsoftonline.com/{TENANT_ID}/v2.0",
        )
        accepted_audiences = [CLIENT_ID, f"api://{CLIENT_ID}"]

        try:
            # Fetch the signing key and validate the JWT.
            signing_key = jwks_client.get_signing_key_from_jwt(token)
            payload = jwt.decode(
                token,
                signing_key.key,
                algorithms=["RS256"],
                audience=accepted_audiences,
                # ``iss`` is matched against several values below, so PyJWT is
                # only asked to insist that the claim is present.
                options={"require": ["iss"]},
            )
            if payload["iss"] not in accepted_issuers:
                raise jwt.InvalidIssuerError("Invalid issuer")
        except jwt.ExpiredSignatureError as exc:
            # Must precede InvalidTokenError, of which it is a subclass.
            raise cls._unauthorized("Token expired") from exc
        except jwt.InvalidTokenError as exc:
            raise cls._unauthorized(f"Invalid token: {exc}") from exc
        except jwt.PyJWKClientError as exc:
            # Raised when no published key matches the token's ``kid`` (or the
            # key set cannot be fetched). It is not an InvalidTokenError, so
            # without this clause it would surface as a 500.
            raise cls._unauthorized(f"Invalid token: {exc}") from exc

        return payload

    @staticmethod
    def _unauthorized(detail: str) -> HTTPException:
        """Build a 401 error that advertises the Bearer scheme (RFC 6750)."""
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            headers={"WWW-Authenticate": "Bearer"},
        )


# --- Example usage -----------------------------------------------------------
# Swap this placeholder for the Object ID of your own Azure (Entra ID) group.
ADMIN_GROUP_ID = "b82a3f...-....-..."


@app.get("/secure-data")
def get_secure_data(user=Depends(EntraGroupChecker(ADMIN_GROUP_ID))):
    # Only reached once the group checker has accepted the caller; ``user`` is
    # the payload it returned. The ``oid`` value (None when absent) is echoed.
    # (Deliberately no docstring: FastAPI would publish it into OpenAPI.)
    caller_oid = user.get("oid")
    return dict(
        message="Access granted via Entra ID Group membership",
        user_id=caller_oid,
    )