from __future__ import annotations
from typing import Union
from ksuid import Ksuid
from sqlalchemy import String, select
from sqlalchemy.orm import Mapped, mapped_column
from auth_middleware.logging import logger
from auth_middleware.providers.authz.groups_provider import GroupsProvider
from auth_middleware.types.jwt import JWTAuthorizationCredentials
from .async_database import AsyncDatabase
from .sql_base_model import Base
def _new_group_id() -> str:
    """Return a freshly generated KSUID rendered as a string."""
    return str(Ksuid())


class GroupsModel(Base):  # type: ignore[misc]
    """ORM mapping for the ``authz_groups`` table.

    Every row pairs one ``username`` with one ``group`` name; both columns
    are mandatory.
    """

    __tablename__ = "authz_groups"

    # Primary key: 27-character string, filled in by ``_new_group_id`` when
    # the caller does not supply a value.
    id: Mapped[str] = mapped_column(
        String(27),
        default=_new_group_id,
        primary_key=True,
        index=True,
    )

    # Name of the user the row belongs to (up to 500 characters).
    username: Mapped[str] = mapped_column(String(500), nullable=False)

    # Name of a group assigned to that user (up to 100 characters).
    group: Mapped[str] = mapped_column(String(100), nullable=False)
class SqlGroupsProvider(GroupsProvider):
    """Groups provider that reads a user's groups from the SQL database.

    Lookups are made against the ``authz_groups`` table (``GroupsModel``),
    filtered by username.
    """

    async def fetch_groups(
        self,
        token: Union[str, JWTAuthorizationCredentials],
    ) -> list[str]:
        """Return the groups stored for the user identified by ``token``.

        Args:
            token: Either the username itself (``str``) or JWT credentials
                whose ``"username"`` claim holds it.

        Returns:
            The group names found for that username; empty when none exist.

        Raises:
            KeyError: Presumably, when JWT credentials carry no ``"username"``
                claim (assumes ``token.claims`` is dict-like -- confirm).
            Exception: Whatever ``get_groups_from_db`` re-raises on a
                database failure.
        """
        # Resolve the username: take it from the claims for JWT credentials,
        # otherwise the token is already the username.
        if isinstance(token, JWTAuthorizationCredentials):
            username: str = token.claims["username"]
        else:
            username = token

        # TODO: add a cache lookup here; for now every call hits the database.
        return await self.get_groups_from_db(username=username)

    async def get_groups_from_db(
        self,
        *,
        username: str,
    ) -> list[str]:
        """Fetch the group names stored for ``username``.

        Args:
            username: Username to look up (keyword-only).

        Returns:
            One entry per matching row, in the order the database returns
            them; empty when the user has no rows.

        Raises:
            Exception: Any error raised while executing the query is logged
                and then re-raised unchanged.
        """
        logger.debug("Username: {}", username)

        # Only the group name is needed, so select that single column rather
        # than loading complete ORM entities.
        query = select(GroupsModel.group).where(GroupsModel.username == username)

        # TODO: exception capture on init -- a failure while acquiring the
        # session is not covered by the try block below.
        async with AsyncDatabase.get_session() as session:
            try:
                result = await session.execute(query)
                return list(result.scalars().all())
            except Exception:
                logger.exception("AsyncDatabase error")
                # Bare raise re-raises the active exception as-is.
                raise