Source code for luna_accounts.db.context

from datetime import datetime
from typing import Awaitable, Iterable, Optional, TypeVar
from uuid import uuid4

import jwt
from accounts_tools.classes.enums import AccountType
from accounts_tools.classes.jwt import JWTProcessor
from accounts_tools.classes.validators import validateTokenAgainstAccount
from asyncpg import ForeignKeyViolationError, UniqueViolationError
from passlib.hash import pbkdf2_sha256
from sqlalchemy import and_, delete, func, insert, select, sql, update
from sqlalchemy.exc import IntegrityError
from vlutils.helpers import convertTimeToString

from classes.functions import getCurrentDatetime
from classes.schemas.account import Account, AccountForPatch
from classes.schemas.token import PermissionsInDB, Token
from configs.config import DB_CONNECT_TIMEOUT
from configs.configs.configs.settings.classes import DBSetting
from crutches_on_wheels.cow.db.base_context import BaseDBContext
from crutches_on_wheels.cow.enums.accounts import PermissionsTargets
from crutches_on_wheels.cow.errors.errors import Error
from crutches_on_wheels.cow.errors.exception import VLException
from crutches_on_wheels.cow.utils import mixins
from crutches_on_wheels.cow.utils.check_connection import checkConnectionToDB
from crutches_on_wheels.cow.utils.db_functions import dbExceptionWrap
from crutches_on_wheels.cow.utils.healthcheck import checkSql, checkSqlMigration
from db.accounts_db_tools.models import accounts_models as models

_ACCOUNT_BY_TARGET_GETTER_COLUMNS_MAP = {
    "account_id": models.Account.account_id,
    "login": models.Account.login,
    "account_type": models.Account.account_type,
    "description": models.Account.description,
    "create_time": models.Account.create_time,
    "last_update_time": models.Account.last_update_time,
}

X = TypeVar("X")
Y = TypeVar("Y")


[docs] class DBContext(BaseDBContext, mixins.Initializable): """Accounts DB context.""" def __init__( self, dbSettings: DBSetting, storageTime: str, ecdsaKeyString: str | None = None, ecdsaKeyPassword: str | None = None, ): super().__init__() self.dbSettings = dbSettings self.storageTime = storageTime self.ecdsaKeyString = ecdsaKeyString self.ecdsaKeyPassword = ecdsaKeyPassword storageTimeIsUTC = self.storageTime == "UTC" # cached value converters from db format to face api format self._accountProcessFunctions = { "account_id": lambda s: s if s is not None else "", "login": lambda s: s if s is not None else "", "account_type": lambda s: s if s is not None else "", "description": lambda s: s if s is not None else "", "create_time": lambda t: convertTimeToString(t, storageTimeIsUTC), "last_update_time": lambda t: convertTimeToString(t, storageTimeIsUTC), } if ecdsaKeyString: self._jwtProcessor = JWTProcessor( algorithm="ES256", ecdsaKeyString=ecdsaKeyString, ecdsaKeyPassword=ecdsaKeyPassword ) else: self._jwtProcessor = JWTProcessor(algorithm="HS256")
[docs] async def initialize(self): await self.__class__.initDBContext( dbSettings=self.dbSettings, storageTime=self.storageTime, connectTimeout=DB_CONNECT_TIMEOUT )
[docs] def makeOutputAccounts(self, rows: list[tuple], columns: Iterable[str]) -> list[dict[str, any]]: """ Make result accounts (from the database reply) proper for user. Args: rows: list from db columns: selected columns Returns: faces with changed fields """ if not rows: return [] processableTargets = set(self._accountProcessFunctions) & set(columns) if processableTargets: accounts = [] for row in rows: account = dict(row) for target in processableTargets: account[target] = self._accountProcessFunctions[target](account[target]) accounts.append(account) else: accounts = [dict(row) for row in rows] return accounts
[docs] @dbExceptionWrap async def createAccount(self, account: Account, accountId: Optional[str] = None) -> str: """ Create new account Args: account: account to create accountId: account id to create Returns: unique account id Raises: VLException(Error.AccountAlreadyExist.format(account.login), 409, False) if account with the same login already exists """ accountIdToCreate = accountId or str(uuid4()) async with DBContext.adaptor.connection(self.logger) as connection: checkDuplicateLoginSt = select([models.Account.account_id]).where(models.Account.login == account.login) if await connection.scalar(checkDuplicateLoginSt): raise VLException(Error.AccountWithLoginAlreadyExists.format(account.login), 409, False) if accountId is not None: selectSt = select([models.Account.account_id]).where(models.Account.account_id == accountIdToCreate) isAccountExist = await connection.scalar(selectSt) if isAccountExist: raise VLException(Error.AccountWithIdAlreadyExists.format(accountIdToCreate), 409, False) insertSt = insert(models.Account).values( account_id=accountIdToCreate, login=account.login, password=pbkdf2_sha256.hash(account.password), account_type=account.accountType.value, description=account.description, ) await connection.execute(insertSt) return accountIdToCreate
[docs] @dbExceptionWrap async def getAccountsCount(self) -> int: """ Get accounts count Returns: accounts count """ async with DBContext.adaptor.connection(self.logger) as connection: countSt = select([func.count()]).select_from(models.Account) return await connection.scalar(countSt)
[docs] @dbExceptionWrap async def getAccounts( self, targets: list[str], page: int = 1, pageSize: int = 1, login: str | None = None, accountType: AccountType | None = None, createTimeLt: datetime | None = None, createTimeGte: datetime | None = None, accountId: str | None = None, getCount: bool = True, ) -> tuple[list[dict], int] | list[dict]: """ Get accounts with pagination Args: targets: targets to get page: page pageSize: page size login: account login accountType: account type createTimeLt: upper bound of account create time createTimeGte: lower bound of account create time accountId: account id getCount: whether to get account count Returns: accounts or accounts and count """ columns2Get = [_ACCOUNT_BY_TARGET_GETTER_COLUMNS_MAP[target] for target in targets] async with DBContext.adaptor.connection(self.logger) as connection: filters = and_( models.Account.account_type == accountType.value if accountType is not None else True, models.Account.login == login if login is not None else True, models.Account.create_time >= createTimeGte if createTimeGte is not None else True, models.Account.create_time < createTimeLt if createTimeLt is not None else True, models.Account.account_id == accountId if accountId is not None else True, ) selectSt = ( select(columns2Get) .where(filters) .order_by(models.Account.create_time.desc()) .offset((page - 1) * pageSize) .limit(pageSize) ) accountsRows = await connection.fetchall(selectSt) accountsRes = self.makeOutputAccounts(accountsRows, targets) if getCount: countSt = select([func.count()]).select_from(models.Account) count = await connection.scalar(countSt) return accountsRes, count return accountsRes
[docs] @dbExceptionWrap async def patchAccount(self, accountId: str, accountOverride: AccountForPatch) -> bool: """ Patch account by account id Args: accountId: id of account accountOverride: account for patch model Returns: True if account was updated otherwise False """ async with DBContext.adaptor.connection(self.logger) as connection: checkDuplicateLoginSt = select([models.Account.account_id]).where( and_(models.Account.login == accountOverride.login, models.Account.account_id != accountId) ) if await connection.scalar(checkDuplicateLoginSt): raise VLException(Error.AccountWithLoginAlreadyExists.format(accountOverride.login), 409, False) if accountOverride.accountType is not None and accountOverride.accountType != AccountType.admin: selectSt = select([models.Account.account_id]).where( and_(models.Account.account_id == accountId, models.Account.account_type == AccountType.admin.value) ) if await connection.scalar(selectSt): raise VLException(Error.AccountTypeChangeForbidden.format(accountId), 403, False) if accountOverride.password is not None: accountOverride.password = pbkdf2_sha256.hash(accountOverride.password) updateSt = ( update(models.Account) .where(models.Account.account_id == accountId) .values(last_update_time=self.currentDBTimestamp, **accountOverride.asDict()) ) return bool(await connection.execute(updateSt))
[docs] @dbExceptionWrap async def deleteAccount(self, accountId: str) -> bool: """ Delete account by account id Args: accountId: id of account Returns: True if account was deleted otherwise False """ async with DBContext.adaptor.connection(self.logger) as connection: deleteSt = delete(models.Account).where(models.Account.account_id == accountId) return bool(await connection.execute(deleteSt))
[docs] @dbExceptionWrap async def verifyAccountByLoginPassword(self, login: str, password: str) -> tuple[str, str]: """ Verify account by login and password Args: login: login password: password Returns: account type and account id Raises: VLException(Error.AccountLoginPasswordIncorrect, 400, isCriticalError=False) if account not found or password doesn't match """ exception = VLException(Error.AccountLoginPasswordIncorrect, 401, isCriticalError=False) async with DBContext.adaptor.connection(self.logger) as connection: selectSt = select([models.Account.password, models.Account.account_type, models.Account.account_id]).where( models.Account.login == login ) dbReply = await connection.fetchone(selectSt) if not dbReply: raise exception hashedPassword, accountType, accountId = dbReply if not pbkdf2_sha256.verify(password, hashedPassword): raise exception return accountType, accountId
[docs] @dbExceptionWrap async def verifyAccountByAccountId(self, accountId: str) -> str: """ Verify account by account id Args: accountId: accountId Returns: account type if account found otherwise None """ async with DBContext.adaptor.connection(self.logger) as connection: selectSt = select([models.Account.account_type]).where(models.Account.account_id == accountId) dbReply = await connection.scalar(selectSt) if not dbReply: raise VLException(Error.AccountNotFound.format(accountId), 401, isCriticalError=False) return dbReply
@staticmethod def _filterPermissions(permissionsString: str, permissionTargets: PermissionsTargets | None = None) -> dict: """Get permissions from string and filter by targets""" permissions = PermissionsInDB.fromStr(permissionsString).asDict() if permissionTargets is None or permissionTargets == PermissionsTargets.standard and "custom" in permissions: permissions.pop("custom", None) elif permissionTargets == PermissionsTargets.custom: permissions = permissions.pop("custom", {}) elif permissionTargets == PermissionsTargets.all and "custom" in permissions: permissions.update(permissions.pop("custom", {})) return permissions
[docs] @dbExceptionWrap async def verifyToken(self, token: str, permissionTargets: int | None = None) -> Optional[tuple[str, dict]]: """ Verify account by JWT token Args: token: JWT token permissionTargets: target for permissions to return in response Returns: tuple with account type and permissions if account found otherwise None Raises: VLException(Error.JWTTokenNotFound, 400, isCriticalError=False) if token not found or unexpected format VLException(Error.TokenExpired, 400, False) if token expired """ try: tokenData = self._jwtProcessor.decode(token) except jwt.PyJWTError: raise VLException(Error.CorruptedToken, 401, False) if sorted(list(tokenData)) != ["accountId", "expirationTime", "tokenId", "visibilityArea"]: raise VLException(Error.CorruptedToken, 401, False) async with DBContext.adaptor.connection(self.logger) as connection: selectTokenSt = select( [ models.Account.account_type, models.Token.expiration_time, models.Token.permissions, ] ).where( and_( models.Token.token_id == tokenData["tokenId"], models.Token.account_id == tokenData["accountId"], models.Account.account_id == models.Token.account_id, ) ) dbReply = await connection.fetchone(selectTokenSt) if not dbReply: raise VLException(Error.JWTTokenNotFound, 401, isCriticalError=False) accountType, expirationTime, tokenPermissions = dbReply if expirationTime and expirationTime <= getCurrentDatetime(True): raise VLException(Error.TokenExpired, 401, False) return accountType, self._filterPermissions( permissionsString=tokenPermissions, permissionTargets=permissionTargets )
[docs] @dbExceptionWrap async def createToken(self, token: Token, accountId: str) -> tuple[str, str]: """ Create new token Args: token: token to create accountId: account id Returns: tuple with unique token and its' id Raises: VLException(Error.AccountNotFound.format(accountId), 400, False) if specified account id not found """ async with DBContext.adaptor.connection(self.logger) as connection: accountTypeSelectSt = select([models.Account.account_type]).where(models.Account.account_id == accountId) accountType = await connection.scalar(accountTypeSelectSt) if not validateTokenAgainstAccount(accountType, token.visibilityArea.value): raise VLException( Error.IncorrectAccountTokenPermissions.format("visibility_area", AccountType.user.value), 400, False ) tokenId = str(uuid4()) insertSt = insert(models.Token).values( token_id=tokenId, account_id=accountId, permissions=token.permissions.asStr(), expiration_time=token.expirationTime, description=token.description, visibility_area=token.visibilityArea.value, ) try: await connection.execute(insertSt) except (IntegrityError, UniqueViolationError, ForeignKeyViolationError): raise VLException(Error.AccountNotFound.format(accountId), 400, False) expirationTime = ( convertTimeToString(token.expirationTime, self.storageTime == "UTC") if token.expirationTime is not None else None ) return tokenId, self._jwtProcessor.encode( tokenId=tokenId, expirationTime=expirationTime, accountId=accountId, visibilityArea=token.visibilityArea.value, )
[docs] def makeOutputToken(self, rowFromDB: tuple, permissionTargets: PermissionsTargets | None = None) -> dict: """ Make dict with token from row from db """ permissions = self._filterPermissions(permissionsString=rowFromDB[1], permissionTargets=permissionTargets) expirationTime = ( convertTimeToString(rowFromDB[3], self.storageTime == "UTC") if rowFromDB[3] is not None else None ) return { "token_id": rowFromDB[0], "token": self._jwtProcessor.encode( tokenId=rowFromDB[0], expirationTime=expirationTime, accountId=rowFromDB[2], visibilityArea=rowFromDB[5], ), "permissions": permissions, "expiration_time": expirationTime, "description": rowFromDB[4] or "", "account_id": rowFromDB[2], "visibility_area": rowFromDB[5], "create_time": convertTimeToString(rowFromDB[6], DBContext.storageTime == "UTC"), "last_update_time": convertTimeToString(rowFromDB[7], DBContext.storageTime == "UTC"), }
[docs] @dbExceptionWrap async def getTokens( self, page: int, pageSize: int, permissionTargets: PermissionsTargets | None = None, accountId: str | None = None, createTimeLt: datetime | None = None, createTimeGte: datetime | None = None, ) -> list[dict]: """ Get tokens Args: accountId: account id page: page pageSize: page size permissionTargets: permission targets createTimeLt: upper bound of token create time createTimeGte: lower bound of token create time Returns: tokens list """ async with DBContext.adaptor.connection(self.logger) as connection: filters = and_( models.Token.account_id == accountId if accountId is not None else True, models.Token.create_time >= createTimeGte if createTimeGte is not None else True, models.Token.create_time < createTimeLt if createTimeLt is not None else True, ) selectSt = ( select( [ models.Token.token_id, models.Token.permissions, models.Token.account_id, models.Token.expiration_time, models.Token.description, models.Token.visibility_area, models.Token.create_time, models.Token.last_update_time, ] ) .where(filters) .order_by(models.Token.create_time.desc()) .offset((page - 1) * pageSize) .limit(pageSize) ) dbReply = await connection.fetchall(selectSt) return [self.makeOutputToken(row, permissionTargets=permissionTargets) for row in dbReply]
@staticmethod def _getTokenFilters(tokenId: str, accountId: Optional[str] = None) -> sql.elements.BinaryExpression: """Get filters for token request""" return and_( models.Token.account_id == accountId if accountId is not None else True, models.Token.token_id == tokenId, )
[docs] @dbExceptionWrap async def getToken( self, tokenId: str, permissionTargets: int = PermissionsTargets.standard.value, accountId: str | None = None ) -> dict | None: """ Get token Args: tokenId: token id permissionTargets: permission targets accountId: account id Returns: token if exist otherwise None """ async with DBContext.adaptor.connection(self.logger) as connection: selectSt = select( [ models.Token.token_id, models.Token.permissions, models.Token.account_id, models.Token.expiration_time, models.Token.description, models.Token.visibility_area, models.Token.create_time, models.Token.last_update_time, ] ).where(self._getTokenFilters(tokenId=tokenId, accountId=accountId)) dbReply = await connection.fetchone(selectSt) if not dbReply: raise VLException(Error.TokenNotFoundById.format(tokenId), 404, False) return self.makeOutputToken(dbReply, permissionTargets=permissionTargets)
[docs] @dbExceptionWrap async def deleteToken(self, tokenId: str, accountId: str) -> bool: """ Delete token by id Args: tokenId: token id accountId: account id Returns: True if token was removed, False if token was not found """ async with DBContext.adaptor.connection(self.logger) as connection: filters = self._getTokenFilters(tokenId=tokenId, accountId=accountId) # required because oracle does not support multiple-table criteria within delete if self.dbType == "oracle": selectSt = select([models.Token.token_id]).where(filters) if not (await connection.scalar(selectSt)): return False deleteSt = delete(models.Token).where(models.Token.token_id == tokenId) else: deleteSt = delete(models.Token).where(filters) return bool(await connection.execute(deleteSt))
[docs] @dbExceptionWrap async def replaceToken(self, tokenId: str, token: Token, accountId: Optional[str] = None) -> str: """ Replace existing token using specified token id Args: tokenId: token id accountId: account id token: token Returns: token Raises: VLException(Error.TokenNotFoundById.format(tokenId), 404, False) if token not found """ async with DBContext.adaptor.connection(self.logger) as connection: if accountId is not None: accountTypeSelectSt = select([models.Account.account_type]).where( models.Account.account_id == accountId ) accountType = await connection.scalar(accountTypeSelectSt) else: accountTypeSelectSt = select([models.Account.account_id, models.Account.account_type]).where( and_(models.Account.account_id == models.Token.account_id, models.Token.token_id == tokenId) ) dbReply = await connection.fetchone(accountTypeSelectSt) if not dbReply: raise VLException(Error.TokenNotFoundById.format(tokenId), 404, False) accountId, accountType = dbReply if accountType == AccountType.user.value and token.visibilityArea.value == "all": raise VLException( Error.IncorrectAccountTokenPermissions.format("visibility_area", AccountType.user.value), 400, False ) values = dict( token_id=tokenId, permissions=token.permissions.asStr(), expiration_time=token.expirationTime, description=token.description, visibility_area=token.visibilityArea.value, last_update_time=self.currentDBTimestamp, ) updateSt = ( update(models.Token) .values(**values) .where(and_(models.Token.token_id == tokenId, models.Token.account_id == accountId)) ) dbReply = await connection.execute(updateSt) if not dbReply: raise VLException(Error.TokenNotFoundById.format(tokenId), 404, False) return self._jwtProcessor.encode( tokenId=tokenId, expirationTime=( convertTimeToString(token.expirationTime, self.storageTime == "UTC") if token.expirationTime is not None else None ), accountId=accountId, visibilityArea=token.visibilityArea.value, )
[docs] async def probe(self) -> bool: """Ensure provided config is valid. Create new connection. Can be used without initialization""" checkState = await checkConnectionToDB(dbSetting=self.dbSettings, postfix="accounts", asyncCheck=True) return checkState
[docs] def getRuntimeChecks(self, _) -> list[tuple[str, Awaitable]]: """Checks for healthcheck.""" return [("accounts_db", checkSql(self.adaptor)), ("accounts_db_migration", checkSqlMigration(self.adaptor))]