diff --git a/app/api/auth.py b/app/api/auth.py new file mode 100644 index 0000000..2b71853 --- /dev/null +++ b/app/api/auth.py @@ -0,0 +1,61 @@ +import jwt +from fastapi import FastAPI, Depends, HTTPException, status +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from jwt import PyJWKClient + +app = FastAPI() +security = HTTPBearer() + +# Entra ID Configuration +TENANT_ID = "your-tenant-id" +CLIENT_ID = "your-client-id" +JWKS_URL = f"https://login.microsoftonline.com/{TENANT_ID}/discovery/v2.0/keys" + +jwks_client = PyJWKClient(JWKS_URL) + +class EntraGroupChecker: + def __init__(self, required_group_id: str): + """ + required_group_id: The Object ID of the group from Azure Portal + """ + self.required_group_id = required_group_id + + def __call__(self, credentials: HTTPAuthorizationCredentials = Depends(security)): + token = credentials.credentials + try: + # Fetch the signing key and validate the JWT + signing_key = jwks_client.get_signing_key_from_jwt(token) + payload = jwt.decode( + token, + signing_key.key, + algorithms=["RS256"], + audience=CLIENT_ID, + issuer=f"https://sts.windows.net/{TENANT_ID}/" + ) + + # Extract groups from the token + user_groups = payload.get("groups", []) + + if self.required_group_id not in user_groups: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="User is not a member of the required AD group." + ) + + return payload # Return the full user context + + except jwt.ExpiredSignatureError: + raise HTTPException(status_code=401, detail="Token expired") + except jwt.InvalidTokenError as e: + raise HTTPException(status_code=401, detail=f"Invalid token: {str(e)}") + +# Example Usage +# Replace with your actual Azure Group Object ID +ADMIN_GROUP_ID = "b82a3f...-....-..." + +@app.get("/secure-data") +def get_secure_data(user=Depends(EntraGroupChecker(ADMIN_GROUP_ID))): + return { + "message": "Access granted via Entra ID Group membership", + "user_id": user.get("oid") + }