Add app/api/auth.py
This commit is contained in:
@@ -0,0 +1,61 @@
|
||||
import jwt
|
||||
from fastapi import FastAPI, Depends, HTTPException, status
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from jwt import PyJWKClient
|
||||
|
||||
app = FastAPI()
|
||||
security = HTTPBearer()
|
||||
|
||||
# Entra ID Configuration
|
||||
TENANT_ID = "your-tenant-id"
|
||||
CLIENT_ID = "your-client-id"
|
||||
JWKS_URL = f"https://login.microsoftonline.com/{TENANT_ID}/discovery/v2.0/keys"
|
||||
|
||||
jwks_client = PyJWKClient(JWKS_URL)
|
||||
|
||||
class EntraGroupChecker:
|
||||
def __init__(self, required_group_id: str):
|
||||
"""
|
||||
required_group_id: The Object ID of the group from Azure Portal
|
||||
"""
|
||||
self.required_group_id = required_group_id
|
||||
|
||||
def __call__(self, credentials: HTTPAuthorizationCredentials = Depends(security)):
|
||||
token = credentials.credentials
|
||||
try:
|
||||
# Fetch the signing key and validate the JWT
|
||||
signing_key = jwks_client.get_signing_key_from_jwt(token)
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
signing_key.key,
|
||||
algorithms=["RS256"],
|
||||
audience=CLIENT_ID,
|
||||
issuer=f"https://sts.windows.net/{TENANT_ID}/"
|
||||
)
|
||||
|
||||
# Extract groups from the token
|
||||
user_groups = payload.get("groups", [])
|
||||
|
||||
if self.required_group_id not in user_groups:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="User is not a member of the required AD group."
|
||||
)
|
||||
|
||||
return payload # Return the full user context
|
||||
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise HTTPException(status_code=401, detail="Token expired")
|
||||
except jwt.InvalidTokenError as e:
|
||||
raise HTTPException(status_code=401, detail=f"Invalid token: {str(e)}")
|
||||
|
||||
# Example Usage
|
||||
# Replace with your actual Azure Group Object ID
|
||||
ADMIN_GROUP_ID = "b82a3f...-....-..."
|
||||
|
||||
@app.get("/secure-data")
|
||||
def get_secure_data(user=Depends(EntraGroupChecker(ADMIN_GROUP_ID))):
|
||||
return {
|
||||
"message": "Access granted via Entra ID Group membership",
|
||||
"user_id": user.get("oid")
|
||||
}
|
||||
Reference in New Issue
Block a user