Source code for auth_middleware.providers.authn.jwt_provider

from abc import ABCMeta, abstractmethod
from time import time_ns
from typing import TYPE_CHECKING

from auth_middleware.logging import logger
from auth_middleware.types.jwt import JWK, JWKS, JWTAuthorizationCredentials
from auth_middleware.types.user import User

if TYPE_CHECKING:
    from auth_middleware.providers.authn.jwt_provider_settings import JWTProviderSettings
    from auth_middleware.providers.authz.groups_provider import GroupsProvider
    from auth_middleware.providers.authz.permissions_provider import PermissionsProvider


class JWTProvider(metaclass=ABCMeta):
    """Base interface for a JWT authentication provider.

    Concrete providers implement :meth:`load_jwks`, :meth:`verify_token` and
    :meth:`create_user_from_token`. This base class caches the key set
    returned by :meth:`load_jwks` and looks up a signing key by the ``kid``
    carried in a token's header.

    Args:
        settings: Provider configuration, stored as ``self._settings``.
        permissions_provider: Optional permissions provider, stored as
            ``self._permissions_provider``.
        groups_provider: Optional groups provider, stored as
            ``self._groups_provider``.
    """

    _settings: "JWTProviderSettings | None"
    _permissions_provider: "PermissionsProvider | None"
    _groups_provider: "GroupsProvider | None"
    # Cached key set. Declared without a value on purpose: the attribute only
    # exists after the first successful load in _get_jwks.
    jks: "JWKS"

    def __init__(
        self,
        settings: "JWTProviderSettings | None" = None,
        permissions_provider: "PermissionsProvider | None" = None,
        groups_provider: "GroupsProvider | None" = None,
    ) -> None:
        self._settings = settings
        self._permissions_provider = permissions_provider
        self._groups_provider = groups_provider

    @staticmethod
    def _is_jwks_stale(jwks: "JWKS | None") -> bool:
        """Return True when the cached key set must be reloaded.

        The cache is stale when it is missing, when its ``timestamp`` is unset
        or earlier than the current ``time_ns()`` (so ``timestamp`` acts as an
        expiry time in nanoseconds), or when its ``usage_counter`` is unset or
        has dropped to zero or below.
        """
        if jwks is None:
            return True
        return (
            jwks.timestamp is None
            or jwks.timestamp < time_ns()
            or jwks.usage_counter is None
            or jwks.usage_counter <= 0
        )

    async def _get_jwks(self) -> "JWKS | None":
        """Return the cached key set, reloading it through load_jwks if stale.

        Each call that returns a key set consumes one use from its
        ``usage_counter`` (when the counter is not None), so the cache is
        refreshed after a settable time or number of reads.

        Returns:
            The cached key set, or None if ``load_jwks`` raised ``KeyError``.
        """
        # getattr with a default: self.jks does not exist before the first load.
        reload_cache = self._is_jwks_stale(getattr(self, "jks", None))
        try:
            if reload_cache:
                self.jks = await self.load_jwks()
                logger.debug("JWKS loaded")
            # Always consume one use after accessing the key set.
            if self.jks.usage_counter is not None:
                self.jks.usage_counter -= 1
        except KeyError:
            return None
        return self.jks

    async def _get_hmac_key(self, token: "JWTAuthorizationCredentials") -> "JWK | None":
        """Return the key whose ``kid`` matches the ``kid`` in the token header.

        Args:
            token: Credentials whose ``header`` mapping may carry a ``kid``.

        Returns:
            The matching key, or None when the key set is unavailable or
            empty, the token header has no ``kid``, or no key matches. Keys
            lacking a ``kid`` entry are skipped rather than raising.
        """
        jwks = await self._get_jwks()
        if jwks is None or not jwks.keys:
            return None
        kid = token.header.get("kid")
        if kid is None:
            # A malformed header cannot match any key; report "not found"
            # instead of letting a KeyError escape.
            return None
        return next((key for key in jwks.keys if key.get("kid") == kid), None)

    @abstractmethod
    async def load_jwks(
        self,
    ) -> "JWKS":
        """Load the provider's key set; the result is cached by _get_jwks."""

    @abstractmethod
    async def verify_token(
        self,
        token: "JWTAuthorizationCredentials",
    ) -> bool:
        """Return True if *token* passes this provider's verification."""

    @abstractmethod
    async def create_user_from_token(
        self,
        token: "JWTAuthorizationCredentials",
    ) -> "User":
        """Build a User from *token*."""