from time import time, time_ns
import httpx
from jose import jwk
from jose.utils import base64url_decode
from auth_middleware.logging import logger
from auth_middleware.providers.authn.cognito_authz_provider_settings import CognitoAuthzProviderSettings
from auth_middleware.providers.authn.jwt_provider import JWTProvider
from auth_middleware.providers.authz.groups_provider import GroupsProvider
from auth_middleware.providers.authz.permissions_provider import PermissionsProvider
from auth_middleware.providers.exceptions.aws_exception import AWSException
from auth_middleware.types.jwt import JWK, JWKS, JWTAuthorizationCredentials
from auth_middleware.types.user import User
class CognitoProvider(JWTProvider):
    """JWT authentication provider for AWS Cognito user pools.

    ``__new__`` makes the class a singleton: every ``CognitoProvider(...)``
    call returns the same object for a given class. Token signatures are
    verified against the user pool's JWKS, downloaded from
    ``settings.jwks_url_template``.
    """

    def __new__(
        cls,
        settings: "CognitoAuthzProviderSettings | None" = None,
        permissions_provider: "type[PermissionsProvider] | PermissionsProvider | None" = None,
        groups_provider: "type[GroupsProvider] | GroupsProvider | None" = None,
    ) -> "CognitoProvider":
        """Return the shared instance of ``cls``, creating it on first use.

        The arguments are accepted only so the constructor signature matches
        ``__init__``, which performs the actual configuration.
        """
        logger.debug("Creating CognitoProvider instance")
        # Look in cls.__dict__ rather than using hasattr(): hasattr() also sees
        # an inherited ``instance`` attribute, which would make a subclass hand
        # back its parent's singleton (an object of the wrong type).
        if "instance" not in cls.__dict__:
            cls.instance = super().__new__(cls)
        return cls.instance

    def __init__(
        self,
        settings: "CognitoAuthzProviderSettings | None" = None,
        permissions_provider: "type[PermissionsProvider] | PermissionsProvider | None" = None,
        groups_provider: "type[GroupsProvider] | GroupsProvider | None" = None,
    ) -> None:
        """Configure (or reconfigure) the shared provider instance.

        Args:
            settings: Cognito settings. Required while the provider is still
                unconfigured; passing settings again reconfigures the
                singleton.
            permissions_provider: A ``PermissionsProvider`` instance, or a
                ``PermissionsProvider`` subclass that is instantiated here with
                no arguments.
            groups_provider: A ``GroupsProvider`` instance, or a
                ``GroupsProvider`` subclass that is instantiated here with no
                arguments.

        Raises:
            ValueError: If no settings are given and the provider has not been
                configured yet, or if a provider argument has the wrong type.
        """
        logger.debug("Initializing CognitoProvider instance")
        # Python runs __init__ on every CognitoProvider(...) call even though
        # __new__ returns the shared object. The flag lives on that shared
        # instance; without new settings there is nothing to reconfigure, so
        # the already configured singleton is returned unchanged.
        if not settings and getattr(self, "_initialized", False):
            return
        if not settings:
            raise ValueError("Settings must be provided")

        # TODO: Refactor this (duplicated for groups_provider below)
        # Accept either a provider instance or a provider class to instantiate.
        final_permissions_provider: "PermissionsProvider | None" = None
        if permissions_provider:
            if isinstance(permissions_provider, type) and issubclass(
                permissions_provider, PermissionsProvider
            ):
                logger.debug("Initializing PermissionsProvider")
                final_permissions_provider = permissions_provider()
            elif isinstance(permissions_provider, PermissionsProvider):
                logger.debug("Setting PermissionsProvider")
                final_permissions_provider = permissions_provider
            else:
                raise ValueError(
                    "permissions_provider must be a PermissionsProvider "
                    "or a subclass thereof"
                )

        # Same resolution rules for the groups provider.
        final_groups_provider: "GroupsProvider | None" = None
        if groups_provider:
            if isinstance(groups_provider, type) and issubclass(
                groups_provider, GroupsProvider
            ):
                logger.debug("Initializing GroupsProvider")
                final_groups_provider = groups_provider()
            elif isinstance(groups_provider, GroupsProvider):
                logger.debug("Setting GroupsProvider")
                final_groups_provider = groups_provider
            else:
                raise ValueError(
                    "groups_provider must be a GroupsProvider or a subclass thereof"
                )

        super().__init__(
            settings=settings,
            permissions_provider=final_permissions_provider,
            groups_provider=final_groups_provider,
        )
        # Set only after a successful configuration, so a failed attempt can
        # be retried.
        self._initialized = True

    async def get_keys(self) -> list[JWK]:
        """Download the JWK set of the configured Cognito user pool.

        The URL is ``settings.jwks_url_template`` formatted positionally with
        ``user_pool_region`` and ``user_pool_id``.

        Returns:
            list[JWK]: The entries of the ``keys`` array of the JWKS document.

        Raises:
            ValueError: If the settings are not ``CognitoAuthzProviderSettings``
                or ``jwks_url_template`` is empty.
            httpx.HTTPStatusError: If the JWKS endpoint answers with a
                non-success status code.
        """
        # Validate configuration before opening any network connection.
        if not isinstance(self._settings, CognitoAuthzProviderSettings):
            raise ValueError("CognitoProvider requires CognitoAuthzProviderSettings")
        if not self._settings.jwks_url_template:
            raise ValueError("jwks_url_template is required in CognitoAuthzProviderSettings")
        jwks_url = self._settings.jwks_url_template.format(
            self._settings.user_pool_region,
            self._settings.user_pool_id,
        )
        async with httpx.AsyncClient() as client:
            response = await client.get(jwks_url)
            # Surface HTTP failures explicitly instead of letting an error body
            # fail later with an opaque KeyError / JSON decode error.
            response.raise_for_status()
            keys: list[JWK] = response.json()["keys"]
        return keys

    async def load_jwks(
        self,
    ) -> JWKS:
        """Fetch the user pool's JWKS and wrap it with cache metadata.

        Returns:
            JWKS: The downloaded keys, plus:

            - ``timestamp``: nanoseconds since the epoch, set
              ``jwks_cache_interval`` minutes (default 20) in the future.
              Presumably this is the cache expiry checked by the base
              provider; confirm there.
            - ``usage_counter``: ``jwks_cache_usages`` (default 1000).

        Raises:
            ValueError: If the settings are not ``CognitoAuthzProviderSettings``.
        """
        settings = self._settings
        if not isinstance(settings, CognitoAuthzProviderSettings):
            raise ValueError("CognitoProvider requires CognitoAuthzProviderSettings")
        keys: list[JWK] = await self.get_keys()
        # The interval is configured in minutes; the timestamp is in ns.
        cache_minutes = settings.jwks_cache_interval or 20
        timestamp: int = time_ns() + cache_minutes * 60 * 1_000_000_000
        usage_counter: int = settings.jwks_cache_usages or 1000
        return JWKS(keys=keys, timestamp=timestamp, usage_counter=usage_counter)

    async def verify_token(self, token: JWTAuthorizationCredentials) -> bool:
        """Verify a token's signature and expiry.

        Returns:
            bool: True if verification is disabled in the settings, or if the
            signature matches a key from the JWKS and the ``exp`` claim lies in
            the future. False otherwise, including for a token without an
            ``exp`` claim.

        Raises:
            AWSException: If no JWKS key matches the token.
        """
        # getattr() covers both a missing settings object and settings types
        # that do not define the flag.
        if getattr(self._settings, "jwt_token_verification_disabled", False):
            return True

        logger.debug("Verifying token through signature")
        # Key selection is done by the base class; despite the method name,
        # the result is a JWK entry used to build the verification key.
        key_data = await self._get_hmac_key(token)
        if not key_data:
            # TODO: Custom exception
            logger.error(
                "No public key found that matches the one present in the TOKEN!"
            )
            raise AWSException("No public key found!")

        public_key = jwk.construct(key_data)
        decoded_signature = base64url_decode(token.signature.encode())
        if not public_key.verify(token.message.encode(), decoded_signature):
            return False

        # The signature is valid; now check the expiry. A token without an
        # ``exp`` claim is rejected instead of raising KeyError.
        if "exp" not in token.claims:
            logger.debug("Token has no exp claim; rejecting it")
            return False
        return bool(token.claims["exp"] > time())

    async def create_user_from_token(self, token: JWTAuthorizationCredentials) -> User:
        """Build a domain ``User`` from the claims of a JWT.

        The user name comes from the ``username`` claim, falling back to
        ``cognito:username`` and finally to ``sub``. Groups are fetched from
        the configured groups provider, if any; otherwise the list is empty.

        Args:
            token (JWTAuthorizationCredentials): Parsed token whose ``claims``
                provide the user data. The ``sub`` claim must be present.

        Returns:
            User: Domain object.
        """
        claims = token.claims
        name_property: str = (
            "username" if "username" in claims else "cognito:username"
        )
        # Get groups directly using the groups provider
        groups: list[str] = []
        if self._groups_provider:
            groups = await self._groups_provider.fetch_groups(token)
        return User(
            token=str(token),
            groups_provider=self._groups_provider,
            permissions_provider=self._permissions_provider,
            id=claims["sub"],
            name=claims[name_property] if name_property in claims else claims["sub"],
            email=claims["email"] if "email" in claims else None,
            groups=groups,
        )