Source code for auth_middleware.providers.authn.cognito_provider

from time import time, time_ns
from typing import Any

import httpx
from jose import jwk
from jose.utils import base64url_decode

from auth_middleware.logging import logger
from auth_middleware.providers.authn.cognito_authz_provider_settings import (
    CognitoAuthzProviderSettings,
)
from auth_middleware.providers.authn.jwt_provider import JWTProvider
from auth_middleware.providers.authz.groups_provider import GroupsProvider
from auth_middleware.providers.authz.permissions_provider import PermissionsProvider
from auth_middleware.providers.exceptions.aws_exception import AWSException
from auth_middleware.services.m2m_detector import M2MTokenDetector
from auth_middleware.types.jwt import JWK, JWKS, JWTAuthorizationCredentials
from auth_middleware.types.user import User


class CognitoProvider(JWTProvider):
    """JWT authentication provider backed by an AWS Cognito user pool.

    Tokens are verified against the user pool's published JWKS, and domain
    ``User`` objects are built from their claims.

    The provider is a per-class singleton: ``__new__`` returns the same object
    for every construction of a given class. Only the first construction
    initialises it, so it must receive ``settings``; later constructions
    return the already-configured instance and ignore their arguments.
    """

    # One shared instance per concrete class (subclasses get their own entry).
    _instances: dict[type, "CognitoProvider"] = {}

    def __new__(
        cls,
        settings: CognitoAuthzProviderSettings | None = None,
        permissions_provider: type[PermissionsProvider]
        | PermissionsProvider
        | None = None,
        groups_provider: type[GroupsProvider] | GroupsProvider | None = None,
    ) -> "CognitoProvider":
        """Return the singleton instance for ``cls``, creating it on first use.

        The arguments are accepted only so the signature matches ``__init__``;
        they are consumed there, not here.
        """
        logger.debug("Creating CognitoProvider instance")
        if cls not in cls._instances:
            cls._instances[cls] = super().__new__(cls)
        return cls._instances[cls]

    def __init__(
        self,
        settings: CognitoAuthzProviderSettings | None = None,
        permissions_provider: type[PermissionsProvider]
        | PermissionsProvider
        | None = None,
        groups_provider: type[GroupsProvider] | GroupsProvider | None = None,
    ) -> None:
        """Configure the shared instance the first time it is constructed.

        Args:
            settings: Cognito provider settings; required on first construction.
            permissions_provider: a ``PermissionsProvider`` instance, or a
                subclass to instantiate with no arguments, or ``None``.
            groups_provider: a ``GroupsProvider`` instance, or a subclass to
                instantiate with no arguments, or ``None``.

        Raises:
            ValueError: if ``settings`` is missing on first construction, or a
                provider argument is neither an instance nor a subclass of its
                expected base class.
        """
        logger.debug("Initializing CognitoProvider instance")
        # __new__ hands back the same object every time, and Python then calls
        # __init__ on it again; the flag lives on the instance so the shared
        # provider is configured exactly once.
        if getattr(self, "_initialized", False):
            return

        if not settings:
            raise ValueError("Settings must be provided")

        final_permissions_provider: PermissionsProvider | None = (
            self._coerce_authz_provider(
                permissions_provider,
                PermissionsProvider,
                "permissions_provider must be a PermissionsProvider "
                "or a subclass thereof",
            )
        )
        final_groups_provider: GroupsProvider | None = self._coerce_authz_provider(
            groups_provider,
            GroupsProvider,
            "groups_provider must be a GroupsProvider or a subclass thereof",
        )

        super().__init__(
            settings=settings,
            permissions_provider=final_permissions_provider,
            groups_provider=final_groups_provider,
        )
        self._initialized = True

    @staticmethod
    def _coerce_authz_provider(
        provider: Any,
        base: type,
        error_message: str,
    ) -> Any:
        """Turn a provider argument into an instance of ``base`` (or ``None``).

        A subclass of ``base`` is instantiated with no arguments, an existing
        instance of ``base`` is returned unchanged, and a falsy value yields
        ``None``.

        Raises:
            ValueError: with ``error_message`` for any other value.
        """
        if not provider:
            return None
        if isinstance(provider, type) and issubclass(provider, base):
            logger.debug(f"Initializing {base.__name__}")
            return provider()
        if isinstance(provider, base):
            logger.debug(f"Setting {base.__name__}")
            return provider
        raise ValueError(error_message)

    async def get_keys(self) -> list[JWK]:
        """Fetch the JSON Web Keys published for the configured user pool.

        The URL is ``jwks_url_template`` formatted positionally with the user
        pool region and then the user pool id.

        Returns:
            list[JWK]: the ``keys`` entry of the JWKS document.

        Raises:
            ValueError: if the settings are not ``CognitoAuthzProviderSettings``
                or ``jwks_url_template`` is empty.
            httpx.HTTPStatusError: if the JWKS endpoint does not answer with a
                success status.
        """
        settings = self._settings
        # Validate before opening a connection: a misconfiguration needs no I/O.
        if not isinstance(settings, CognitoAuthzProviderSettings):
            raise ValueError("CognitoProvider requires CognitoAuthzProviderSettings")
        if not settings.jwks_url_template:
            raise ValueError(
                "jwks_url_template is required in CognitoAuthzProviderSettings"
            )

        jwks_url = settings.jwks_url_template.format(
            settings.user_pool_region,
            settings.user_pool_id,
        )
        async with httpx.AsyncClient() as client:
            response = await client.get(jwks_url)
            # Surface an HTTP failure explicitly rather than as a KeyError or
            # JSON decoding error from parsing the error body.
            response.raise_for_status()
            keys: list[JWK] = response.json()["keys"]
        return keys

    async def load_jwks(
        self,
    ) -> JWKS:
        """Fetch the user pool's keys and wrap them with cache metadata.

        ``timestamp`` is an absolute expiry in nanoseconds since the epoch:
        the fetch time plus ``jwks_cache_interval`` minutes (20 when unset or
        zero). ``usage_counter`` is ``jwks_cache_usages`` (1000 when unset or
        zero).

        Returns:
            JWKS: the fetched keys together with their cache metadata.

        Raises:
            ValueError: if the settings are not ``CognitoAuthzProviderSettings``.
        """
        settings = self._settings
        if not isinstance(settings, CognitoAuthzProviderSettings):
            raise ValueError("CognitoProvider requires CognitoAuthzProviderSettings")

        keys: list[JWK] = await self.get_keys()

        nanoseconds_per_minute = 60 * 1_000_000_000
        cache_minutes = settings.jwks_cache_interval or 20
        timestamp: int = time_ns() + cache_minutes * nanoseconds_per_minute
        usage_counter: int = settings.jwks_cache_usages or 1000
        return JWKS(keys=keys, timestamp=timestamp, usage_counter=usage_counter)

    async def verify_token(self, token: JWTAuthorizationCredentials) -> bool:
        """Verify the token's signature and check that it has not expired.

        Verification is skipped entirely (returns ``True``) when the settings
        have a truthy ``jwt_token_verification_disabled``.

        Returns:
            bool: ``True`` when the signature matches a key from the JWKS and
            the ``exp`` claim lies in the future; ``False`` otherwise,
            including when the token carries no ``exp`` claim.

        Raises:
            AWSException: if no key in the JWKS matches the token.
        """
        if self._settings and getattr(
            self._settings, "jwt_token_verification_disabled", False
        ):
            return True

        logger.debug("Verifying token through signature")
        # Despite the inherited helper's name, Cognito signs with asymmetric
        # keys: this is the public JWK matching the token header.
        key_data = await self._get_hmac_key(token)
        if not key_data:
            # TODO: Custom exception
            logger.error(
                "No public key found that matches the one present in the TOKEN!"
            )
            raise AWSException("No public key found!")

        public_key = jwk.construct(key_data)
        decoded_signature = base64url_decode(token.signature.encode())
        if not public_key.verify(token.message.encode(), decoded_signature):
            return False

        # Only a correctly signed token gets its expiry checked; a missing
        # ``exp`` is treated as invalid rather than crashing with KeyError.
        expires_at = token.claims.get("exp")
        if expires_at is None:
            logger.warning("Token has no 'exp' claim; rejecting it")
            return False
        return bool(expires_at > time())

    async def create_user_from_token(self, token: JWTAuthorizationCredentials) -> User:
        """Build a domain ``User`` from the claims of a verified JWT.

        The display name comes from ``username`` when present, otherwise from
        ``cognito:username``, and falls back to ``sub``. Groups are fetched
        from the groups provider only for non-M2M tokens.

        Args:
            token (JWTAuthorizationCredentials): the decoded bearer token.

        Returns:
            User: Domain object.

        Raises:
            KeyError: if the token has no ``sub`` claim.
        """
        is_m2m = M2MTokenDetector.is_m2m_token(token)
        client_id = M2MTokenDetector.get_client_id(token) if is_m2m else None
        claims = token.claims

        name_property: str = (
            "username" if "username" in claims else "cognito:username"
        )

        groups: list[str] = []
        if self._groups_provider and not is_m2m:
            # M2M tokens typically don't have groups
            groups = await self._groups_provider.fetch_groups(token)

        subject = claims["sub"]
        return User(
            token=str(token),
            jwt_credentials=token,
            groups_provider=self._groups_provider,
            permissions_provider=self._permissions_provider,
            id=subject,
            name=claims.get(name_property, subject),
            email=claims.get("email"),
            groups=groups,
            is_m2m=is_m2m,
            client_id=client_id,
        )