# Source code for auth_middleware.providers.azure.entra_id_provider

from time import time_ns
from typing import Any

import httpx
from fastapi import status
from joserfc import jwt as joserfc_jwt
from joserfc.errors import JoseError
from joserfc.jwk import import_key
from joserfc.jwt import JWTClaimsRegistry

from auth_middleware.contracts.groups_provider import GroupsProvider
from auth_middleware.contracts.jwt_provider import JWTProvider
from auth_middleware.contracts.permissions_provider import PermissionsProvider
from auth_middleware.exceptions.invalid_token_exception import InvalidTokenException
from auth_middleware.logging import logger
from auth_middleware.providers.azure.azure_exception import AzureException
from auth_middleware.providers.azure.settings import settings
from auth_middleware.types.jwt import JWKS, JWTAuthorizationCredentials
from auth_middleware.types.user import User


class EntraIDProvider(JWTProvider):
    """JWT provider that validates tokens issued by Azure Entra ID.

    The class behaves as a singleton: ``__new__`` always hands back the same
    instance and ``__init__`` only initializes it the first time, so the
    providers passed on later constructions are ignored.
    """

    def __new__(
        cls,
        permissions_provider: PermissionsProvider | None = None,
        groups_provider: GroupsProvider | None = None,
    ) -> "EntraIDProvider":
        """Return the shared instance, creating it on first use.

        The return annotation is quoted because the class name is not bound
        yet while the class body is being executed.
        """
        if not hasattr(cls, "instance"):
            cls.instance = super().__new__(cls)
        return cls.instance

    def __init__(
        self,
        permissions_provider: PermissionsProvider | None = None,
        groups_provider: GroupsProvider | None = None,
    ) -> None:
        """Initialize the shared instance once.

        Args:
            permissions_provider: Optional provider forwarded to the base
                class.
            groups_provider: Optional provider forwarded to the base class.
        """
        # __init__ runs on every EntraIDProvider(...) call even though
        # __new__ returns the same object, so guard against reinitialization.
        if not hasattr(self, "_initialized"):
            super().__init__(
                permissions_provider=permissions_provider,
                groups_provider=groups_provider,
            )
            self._initialized = True

    # TODO: implement correct types
    async def get_keys(self, jwks_uri: str) -> Any:
        """Download the signing keys published at ``jwks_uri``.

        Args:
            jwks_uri: URL of the JWKS document.

        Returns:
            The value of the ``"keys"`` entry of the JWKS document.

        Raises:
            httpx.HTTPStatusError: If the endpoint answers with an error
                status.
            KeyError: If the document has no ``"keys"`` entry.
        """
        async with httpx.AsyncClient() as client:
            response = await client.get(jwks_uri)
            # Fail loudly on an HTTP error instead of trying to read "keys"
            # out of an error payload.
            response.raise_for_status()
            keys_data: list[dict[str, Any]] = response.json()["keys"]
            return keys_data

    async def get_openid_config(self) -> dict[str, str]:
        """Fetch the OpenID configuration document for the configured tenant.

        The URL is built from
        ``settings.AUTH_PROVIDER_AZURE_ENTRA_ID_JWKS_URL_TEMPLATE`` formatted
        with ``settings.AUTH_PROVIDER_AZURE_ENTRA_ID_TENANT_ID``.

        Returns:
            The decoded JSON document, or an empty dict when the request,
            the status check or the JSON decoding fails (the error is
            logged, not raised).
        """
        async with httpx.AsyncClient() as client:
            try:
                response = await client.get(
                    settings.AUTH_PROVIDER_AZURE_ENTRA_ID_JWKS_URL_TEMPLATE.format(
                        settings.AUTH_PROVIDER_AZURE_ENTRA_ID_TENANT_ID,
                    )
                )
                # Treat an HTTP error as a failed lookup rather than
                # returning the error body as if it were the configuration.
                response.raise_for_status()
                config_data: dict[str, str] = response.json()
            except Exception as e:
                # Deliberately best-effort: callers receive {} and decide.
                logger.error("Error in get_openid_config: {}", str(e))
                return {}
            return config_data

    async def load_jwks(
        self,
    ) -> JWKS:
        """Load the JWKS from the remote identity provider.

        Returns:
            JWKS: The downloaded keys together with the cache expiry
            timestamp (nanoseconds) and the usage counter read from
            settings, defaulting to 20 minutes and 1000 usages.

        Raises:
            AzureException: If the OpenID configuration could not be
                retrieved or does not advertise a ``jwks_uri``.
        """
        openid_config = await self.get_openid_config()

        # get_openid_config returns {} on failure; report that explicitly
        # instead of letting a bare KeyError escape.
        jwks_uri = openid_config.get("jwks_uri")
        if not jwks_uri:
            logger.error("OpenID configuration does not contain a jwks_uri")
            raise AzureException("OpenID configuration does not contain a jwks_uri")

        keys = await self.get_keys(jwks_uri)

        # Convert 'x5c' field in each key from list to string
        for key in keys:
            if "x5c" in key and isinstance(key["x5c"], list):
                key["x5c"] = "".join(key["x5c"])

        # Expiry is expressed in nanoseconds to match time_ns().
        cache_minutes = getattr(
            settings, "AUTH_MIDDLEWARE_JWKS_CACHE_INTERVAL_MINUTES", 20
        )
        timestamp: int = time_ns() + cache_minutes * 60 * 1_000_000_000
        usage_counter: int = getattr(
            settings, "AUTH_MIDDLEWARE_JWKS_CACHE_USAGES", 1000
        )
        return JWKS(keys=keys, timestamp=timestamp, usage_counter=usage_counter)

    async def verify_token(self, token: JWTAuthorizationCredentials) -> bool:
        """Verify the signature and registered claims of a token.

        Args:
            token (JWTAuthorizationCredentials): Credentials whose
                ``jwt_token`` is checked.

        Raises:
            InvalidTokenException: If no key matching the token is found.
            AzureException: If an unexpected (non-JOSE) error occurs while
                verifying.

        Returns:
            bool: True when the token verifies and carries a non-empty
            ``sub`` claim; False when joserfc rejects the signature or the
            claims.
        """
        # _get_hmac_key is defined outside this class body (presumably on
        # JWTProvider); despite its name the key is used as an RSA key here.
        key_candidate = await self._get_hmac_key(token)

        if not key_candidate:
            logger.error(
                "No public key found that matches the one present in the TOKEN!"
            )
            raise InvalidTokenException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="No public key found!",
            )

        try:
            rsa_key_dict = {
                "kty": key_candidate["kty"],
                "kid": key_candidate["kid"],
                "use": key_candidate["use"],
                "n": key_candidate["n"],
                "e": key_candidate["e"],
            }
            key = import_key(rsa_key_dict)

            audience = settings.AUTH_PROVIDER_AZURE_ENTRA_ID_AUDIENCE_ID

            token_obj = joserfc_jwt.decode(
                token.jwt_token,
                key,
                algorithms=["RS256"],
            )

            # Always run the claims registry so the registered time claims
            # are checked even when no audience is configured; the audience
            # requirement is added only when one is set.
            claim_options: dict[str, Any] = {}
            if audience:
                claim_options["aud"] = {"essential": True, "value": audience}
            JWTClaimsRegistry(**claim_options).validate(token_obj.claims)

            return bool(token_obj.claims.get("sub"))

        except JoseError as je:
            # Signature or claim problems mean "not valid", not a failure.
            logger.error("Error in EntraIDClient: {}", str(je))
            return False
        except Exception as e:
            logger.error("Error in JWTBearerManager: {}", str(e))
            raise AzureException("Error in JWTBearerManager") from e

    async def create_user_from_token(self, token: JWTAuthorizationCredentials) -> User:
        """Build a domain User from the claims of a JWT.

        The user name is taken from the ``username`` claim, then
        ``preferred_username``, and finally falls back to ``sub``. Groups are
        fetched from the groups provider when one is configured.

        Args:
            token (JWTAuthorizationCredentials): Credentials holding the
                decoded claims.

        Returns:
            User: Domain object.
        """
        claims = token.claims
        name_property: str = (
            "username" if "username" in claims else "preferred_username"
        )

        groups: list[str] = (
            await self._groups_provider.fetch_groups(token)
            if self._groups_provider
            else []
        )

        return User(
            id=claims["sub"],
            name=(
                claims[name_property]
                if name_property in claims
                else claims["sub"]
            ),
            groups=groups,
            email=claims["email"] if "email" in claims else None,
        )