Source code for luna_accounts.db.context

from typing import Optional
from uuid import uuid4

import jwt
from asyncpg import ForeignKeyViolationError, UniqueViolationError
from passlib.hash import pbkdf2_sha256
from sqlalchemy import and_, delete, func, insert, select, sql, update
from sqlalchemy.exc import IntegrityError
from vlutils.helpers import convertTimeToString

from classes.enums import AccountType
from classes.functions import getCurrentDatetime
from classes.schemas.account import Account, AccountForPatch
from classes.schemas.token import Permissions, Token
from configs.config import JWT_SECRET_STRING
from crutches_on_wheels.cow.db.base_context import BaseDBContext
from crutches_on_wheels.cow.errors.errors import Error
from crutches_on_wheels.cow.errors.exception import VLException
from crutches_on_wheels.cow.utils.db_functions import dbExceptionWrap
from db.accounts_db_tools.models import accounts_models as models


[docs] class DBContext(BaseDBContext): """Accounts DB context: database queries for accounts and their tokens."""
@dbExceptionWrap
async def createAccount(self, account: Account, accountId: Optional[str] = None) -> str:
    """
    Create new account

    Args:
        account: account to create
        accountId: id to assign to the account; a random uuid4 is generated when it is not given
    Returns:
        id of the created account
    Raises:
        VLException(Error.AccountWithLoginAlreadyExists.format(account.login), 409, False) if account with the same
            login already exists
        VLException(Error.AccountWithIdAlreadyExists.format(accountId), 409, False) if `accountId` was given and
            an account with this id already exists
    """
    newAccountId = accountId if accountId else str(uuid4())
    async with DBContext.adaptor.connection(self.logger) as connection:
        loginOwnerSt = select([models.Account.account_id]).where(models.Account.login == account.login)
        loginOwner = await connection.scalar(loginOwnerSt)
        if loginOwner:
            raise VLException(Error.AccountWithLoginAlreadyExists.format(account.login), 409, False)

        # only an explicitly requested id is checked for a collision, a generated one is not
        if accountId is not None:
            sameIdSt = select([models.Account.account_id]).where(models.Account.account_id == newAccountId)
            if await connection.scalar(sameIdSt):
                raise VLException(Error.AccountWithIdAlreadyExists.format(newAccountId), 409, False)

        # the password is stored as a pbkdf2_sha256 hash, never as plain text
        accountRow = {
            "account_id": newAccountId,
            "login": account.login,
            "password": pbkdf2_sha256.hash(account.password),
            "account_type": account.accountType.value,
            "description": account.description,
        }
        await connection.execute(insert(models.Account).values(**accountRow))
        return newAccountId
@dbExceptionWrap
async def getAccountsCount(self) -> int:
    """
    Get accounts count

    Returns:
        number of rows in the accounts table
    """
    totalSt = select([func.count()]).select_from(models.Account)
    async with DBContext.adaptor.connection(self.logger) as connection:
        total = await connection.scalar(totalSt)
        return total
@dbExceptionWrap
async def getAccounts(
    self, page: int, pageSize: int, login: Optional[str] = None, accountType: Optional[AccountType] = None
) -> tuple[list[dict], int]:
    """
    Get accounts with pagination

    Args:
        page: page number (the offset is `(page - 1) * pageSize`)
        pageSize: maximum number of accounts on the page
        login: when given, only the account with this login is selected
        accountType: when given, only accounts of this type are selected
    Returns:
        tuple of (accounts on the requested page, accounts count)
    """
    # an absent filter is replaced with `True`, so it does not restrict the selection
    typeCriterion = True if accountType is None else models.Account.account_type == accountType.value
    loginCriterion = True if login is None else models.Account.login == login

    async with DBContext.adaptor.connection(self.logger) as connection:
        columns = [
            models.Account.account_id,
            models.Account.login,
            models.Account.account_type,
            models.Account.description,
        ]
        # NOTE(review): no ORDER BY is applied, so the page content relies on the database default row order
        pageSt = select(columns).where(and_(typeCriterion, loginCriterion))
        pageSt = pageSt.offset((page - 1) * pageSize).limit(pageSize)
        # NOTE(review): the count ignores the login/account type filters - confirm this is the intended total
        countSt = select([func.count()]).select_from(models.Account)

        rows = await connection.fetchall(pageSt)
        count = await connection.scalar(countSt)

        accounts = []
        for row in rows:
            accounts.append(
                {"account_id": row[0], "login": row[1], "account_type": row[2], "description": row[3] or ""}
            )
        return accounts, count
@dbExceptionWrap
async def getAccount(self, accountId: str) -> Optional[dict]:
    """
    Get account by account id

    Args:
        accountId: id of account
    Returns:
        dict with `account_id`, `login`, `account_type` and `description` (empty string instead of an empty
        value) if the account exists, otherwise None
    """
    async with DBContext.adaptor.connection(self.logger) as connection:
        columns = [models.Account.login, models.Account.account_type, models.Account.description]
        accountSt = select(columns).where(models.Account.account_id == accountId)
        row = await connection.fetchone(accountSt)
        if row:
            return {
                "account_id": accountId,
                "login": row[0],
                "account_type": row[1],
                "description": row[2] or "",
            }
        return None
@dbExceptionWrap
async def patchAccount(self, accountId: str, accountOverride: AccountForPatch) -> bool:
    """
    Patch account by account id

    Args:
        accountId: id of account
        accountOverride: account for patch model
    Returns:
        True if account was updated otherwise False
    Raises:
        VLException(Error.AccountWithLoginAlreadyExists.format(accountOverride.login), 409, False) if another
            account already uses the new login
        VLException(Error.AccountTypeChangeForbidden.format(accountId), 403, False) on an attempt to change
            the type of an admin account to a non-admin one
    """
    async with DBContext.adaptor.connection(self.logger) as connection:
        # Look for a login clash only when the patch really carries a login: comparing the column with
        # None would be rendered as `login IS NULL`, which is not a duplicate-login check at all.
        if accountOverride.login is not None:
            checkDuplicateLoginSt = select([models.Account.account_id]).where(
                and_(models.Account.login == accountOverride.login, models.Account.account_id != accountId)
            )
            if await connection.scalar(checkDuplicateLoginSt):
                raise VLException(Error.AccountWithLoginAlreadyExists.format(accountOverride.login), 409, False)

        # an admin account must stay admin
        if accountOverride.accountType is not None and accountOverride.accountType != AccountType.admin:
            selectSt = select([models.Account.account_id]).where(
                and_(models.Account.account_id == accountId, models.Account.account_type == AccountType.admin.value)
            )
            if await connection.scalar(selectSt):
                raise VLException(Error.AccountTypeChangeForbidden.format(accountId), 403, False)

        # NOTE: the plain-text password on the passed-in model is replaced with its hash (visible to the caller)
        if accountOverride.password is not None:
            accountOverride.password = pbkdf2_sha256.hash(accountOverride.password)

        updateSt = (
            update(models.Account).where(models.Account.account_id == accountId).values(**accountOverride.asDict())
        )
        return bool(await connection.execute(updateSt))
@dbExceptionWrap
async def deleteAccount(self, accountId: str) -> bool:
    """
    Delete account by account id

    Args:
        accountId: id of account
    Returns:
        True if account was deleted otherwise False
    """
    removeSt = delete(models.Account).where(models.Account.account_id == accountId)
    async with DBContext.adaptor.connection(self.logger) as connection:
        executionResult = await connection.execute(removeSt)
        return bool(executionResult)
@dbExceptionWrap
async def verifyAccountByLoginPassword(self, login: str, password: str) -> tuple[str, str]:
    """
    Verify account by login and password

    Args:
        login: login
        password: plain-text password to check against the stored hash
    Returns:
        account type and account id
    Raises:
        VLException(Error.AccountLoginPasswordIncorrect, 401, isCriticalError=False) if account not found or
            password doesn't match
    """
    # the same error is used for both failures, so a response does not tell which of them happened
    authError = VLException(Error.AccountLoginPasswordIncorrect, 401, isCriticalError=False)
    async with DBContext.adaptor.connection(self.logger) as connection:
        columns = [models.Account.password, models.Account.account_type, models.Account.account_id]
        credentialsSt = select(columns).where(models.Account.login == login)
        row = await connection.fetchone(credentialsSt)
        if row:
            storedHash, accountType, accountId = row
            if pbkdf2_sha256.verify(password, storedHash):
                return accountType, accountId
        raise authError
@dbExceptionWrap
async def verifyAccountByAccountId(self, accountId: str) -> str:
    """
    Verify account by account id

    Args:
        accountId: accountId
    Returns:
        account type of the found account
    Raises:
        VLException(Error.AccountNotFound.format(accountId), 401, isCriticalError=False) if account not found
    """
    accountTypeSt = select([models.Account.account_type]).where(models.Account.account_id == accountId)
    async with DBContext.adaptor.connection(self.logger) as connection:
        accountType = await connection.scalar(accountTypeSt)
        if accountType:
            return accountType
        raise VLException(Error.AccountNotFound.format(accountId), 401, isCriticalError=False)
@dbExceptionWrap
async def verifyToken(self, token: str) -> Optional[tuple[str, dict]]:
    """
    Verify account by JWT token

    Args:
        token: JWT token
    Returns:
        tuple with account type and token permissions (as dict)
    Raises:
        VLException(Error.CorruptedToken, 401, False) if token cannot be decoded or has an unexpected set of keys
        VLException(Error.JWTTokenNotFound, 401, isCriticalError=False) if token not found for the account
        VLException(Error.TokenExpired, 401, False) if token expired
    """
    expectedKeys = ["accountId", "expirationTime", "tokenId", "visibilityArea"]
    try:
        payload = jwt.decode(token, key=JWT_SECRET_STRING, algorithms=["HS256"])
    except jwt.PyJWTError:
        raise VLException(Error.CorruptedToken, 401, False)
    # the payload must contain exactly the keys this service puts into its tokens
    if sorted(payload) != expectedKeys:
        raise VLException(Error.CorruptedToken, 401, False)

    async with DBContext.adaptor.connection(self.logger) as connection:
        columns = [models.Account.account_type, models.Token.expiration_time, models.Token.permissions]
        tokenCriteria = and_(
            models.Token.token_id == payload["tokenId"],
            models.Token.account_id == payload["accountId"],
            models.Account.account_id == models.Token.account_id,
        )
        row = await connection.fetchone(select(columns).where(tokenCriteria))
        if not row:
            raise VLException(Error.JWTTokenNotFound, 401, isCriticalError=False)

        accountType, expiresAt, permissionsStr = row
        # a token without expiration time never expires
        if expiresAt:
            if expiresAt <= getCurrentDatetime(self.storageTime == "UTC"):
                raise VLException(Error.TokenExpired, 401, False)
        return accountType, Permissions.fromStr(permissionsStr).asDict()
@dbExceptionWrap
async def createToken(self, token: Token, accountId: str) -> tuple[str, str]:
    """
    Create new token

    Args:
        token: token to create
        accountId: account id
    Returns:
        tuple with id of the created token and the encoded JWT token
    Raises:
        VLException(Error.IncorrectAccountTokenPermissions.format("visibility_area", <user type>), 400, False)
            if an account of the user type requests the "all" visibility area
        VLException(Error.AccountNotFound.format(accountId), 400, False) if specified account id not found
    """
    async with DBContext.adaptor.connection(self.logger) as connection:
        ownerTypeSt = select([models.Account.account_type]).where(models.Account.account_id == accountId)
        ownerType = await connection.scalar(ownerTypeSt)
        isUserAccount = ownerType == AccountType.user.value
        if isUserAccount and token.visibilityArea == "all":
            raise VLException(
                Error.IncorrectAccountTokenPermissions.format("visibility_area", AccountType.user.value), 400, False
            )

        tokenId = str(uuid4())
        tokenRow = {
            "token_id": tokenId,
            "account_id": accountId,
            "permissions": token.permissions.asStr(),
            "expiration_time": token.expirationTime,
            "description": token.description,
            "visibility_area": token.visibilityArea,
        }
        try:
            await connection.execute(insert(models.Token).values(**tokenRow))
        except (IntegrityError, UniqueViolationError, ForeignKeyViolationError):
            # a constraint failure on insert is reported as a missing account
            raise VLException(Error.AccountNotFound.format(accountId), 400, False)

        if token.expirationTime is None:
            expirationTime = None
        else:
            expirationTime = convertTimeToString(token.expirationTime, self.storageTime == "UTC")

        # the key order of the payload is kept as is, the encoded token string depends on it
        payload = {
            "tokenId": tokenId,
            "expirationTime": expirationTime,
            "accountId": accountId,
            "visibilityArea": token.visibilityArea,
        }
        return tokenId, jwt.encode(payload, key=JWT_SECRET_STRING)
def makeOutputToken(self, rowFromDB: tuple) -> dict:
    """
    Make dict with token from row from db

    Args:
        rowFromDB: token row, read by position: token id, permissions string, account id, expiration time,
            description, visibility area
    Returns:
        dict with the token data including the encoded JWT token
    """
    tokenId = rowFromDB[0]
    accountId = rowFromDB[2]
    visibilityArea = rowFromDB[5]

    permissions = Permissions.fromStr(rowFromDB[1]).asDict()
    if rowFromDB[3] is None:
        expirationTime = None
    else:
        expirationTime = convertTimeToString(rowFromDB[3], self.storageTime == "UTC")

    # the JWT is rebuilt from the stored fields; the key order of the payload is significant
    payload = {
        "tokenId": tokenId,
        "expirationTime": expirationTime,
        "accountId": accountId,
        "visibilityArea": visibilityArea,
    }
    encodedToken = jwt.encode(payload, key=JWT_SECRET_STRING)
    return dict(
        token_id=tokenId,
        token=encodedToken,
        permissions=permissions,
        expiration_time=expirationTime,
        description=rowFromDB[4] or "",
        account_id=accountId,
        visibility_area=visibilityArea,
    )
@dbExceptionWrap
async def getTokens(self, page: int, pageSize: int, accountId: Optional[str] = None) -> list[dict]:
    """
    Get tokens

    Args:
        page: page number (the offset is `(page - 1) * pageSize`)
        pageSize: maximum number of tokens on the page
        accountId: when given, only tokens of this account are selected
    Returns:
        tokens list
    """
    # an absent account id is replaced with `True`, so it does not restrict the selection
    ownerCriterion = True if accountId is None else models.Token.account_id == accountId
    columns = [
        models.Token.token_id,
        models.Token.permissions,
        models.Token.account_id,
        models.Token.expiration_time,
        models.Token.description,
        models.Token.visibility_area,
    ]
    async with DBContext.adaptor.connection(self.logger) as connection:
        # NOTE(review): no ORDER BY is applied, so the page content relies on the database default row order
        pageSt = select(columns).where(ownerCriterion).offset((page - 1) * pageSize).limit(pageSize)
        rows = await connection.fetchall(pageSt)
        tokens = []
        for row in rows:
            tokens.append(self.makeOutputToken(row))
        return tokens
@staticmethod
def _getTokenFilters(tokenId: str, accountId: Optional[str] = None) -> sql.elements.ColumnElement:
    """
    Get filters for token request

    Args:
        tokenId: id the token must have
        accountId: when given, the token must also belong to this account
    Returns:
        boolean SQL expression for a `where` clause: a single comparison when only the token id is used,
        a conjunction of two comparisons when the account id is given too
    """
    tokenCriterion = models.Token.token_id == tokenId
    if accountId is None:
        # nothing to combine with, no need to pass a python `True` placeholder into `and_`
        return tokenCriterion
    return and_(models.Token.account_id == accountId, tokenCriterion)
@dbExceptionWrap
async def getToken(self, tokenId: str, accountId: Optional[str] = None) -> Optional[dict]:
    """
    Get token

    Args:
        tokenId: token id
        accountId: account id; when given, the token must belong to this account
    Returns:
        dict with the token data
    Raises:
        VLException(Error.TokenNotFoundById.format(tokenId), 404, False) if token not found
    """
    columns = [
        models.Token.token_id,
        models.Token.permissions,
        models.Token.account_id,
        models.Token.expiration_time,
        models.Token.description,
        models.Token.visibility_area,
    ]
    async with DBContext.adaptor.connection(self.logger) as connection:
        tokenCriteria = self._getTokenFilters(tokenId=tokenId, accountId=accountId)
        row = await connection.fetchone(select(columns).where(tokenCriteria))
        if row:
            return self.makeOutputToken(row)
        raise VLException(Error.TokenNotFoundById.format(tokenId), 404, False)
@dbExceptionWrap
async def deleteToken(self, tokenId: str, accountId: str) -> bool:
    """
    Delete token by id

    Args:
        tokenId: token id
        accountId: account id
    Returns:
        True if token was removed, False if token was not found
    """
    async with DBContext.adaptor.connection(self.logger) as connection:
        tokenCriteria = self._getTokenFilters(tokenId=tokenId, accountId=accountId)
        if self.dbType != "oracle":
            removeSt = delete(models.Token).where(tokenCriteria)
            return bool(await connection.execute(removeSt))

        # required because oracle does not support multiple-table criteria within delete:
        # the token is looked up with the full criteria first and then removed by its id only
        foundTokenId = await connection.scalar(select([models.Token.token_id]).where(tokenCriteria))
        if not foundTokenId:
            return False
        removeSt = delete(models.Token).where(models.Token.token_id == tokenId)
        return bool(await connection.execute(removeSt))
@dbExceptionWrap
async def replaceToken(self, tokenId: str, token: Token, accountId: Optional[str] = None) -> str:
    """
    Replace existing token using specified token id

    Args:
        tokenId: token id
        token: new token data
        accountId: account id; when not given, the owner of the token is looked up by the token id
    Returns:
        encoded JWT token built from the new data
    Raises:
        VLException(Error.TokenNotFoundById.format(tokenId), 404, False) if token not found
        VLException(Error.IncorrectAccountTokenPermissions.format("visibility_area", <user type>), 400, False)
            if an account of the user type requests the "all" visibility area
    """
    async with DBContext.adaptor.connection(self.logger) as connection:
        if accountId is None:
            # the owner is unknown: find both its id and its type through the token
            ownerSt = select([models.Account.account_id, models.Account.account_type]).where(
                and_(models.Account.account_id == models.Token.account_id, models.Token.token_id == tokenId)
            )
            owner = await connection.fetchone(ownerSt)
            if not owner:
                raise VLException(Error.TokenNotFoundById.format(tokenId), 404, False)
            accountId, accountType = owner
        else:
            ownerTypeSt = select([models.Account.account_type]).where(models.Account.account_id == accountId)
            accountType = await connection.scalar(ownerTypeSt)

        isUserAccount = accountType == AccountType.user.value
        if isUserAccount and token.visibilityArea == "all":
            raise VLException(
                Error.IncorrectAccountTokenPermissions.format("visibility_area", AccountType.user.value), 400, False
            )

        newValues = {
            "token_id": tokenId,
            "permissions": token.permissions.asStr(),
            "expiration_time": token.expirationTime,
            "description": token.description,
            "visibility_area": token.visibilityArea,
        }
        ownTokenCriteria = and_(models.Token.token_id == tokenId, models.Token.account_id == accountId)
        updateSt = update(models.Token).values(**newValues).where(ownTokenCriteria)
        isUpdated = await connection.execute(updateSt)
        if not isUpdated:
            raise VLException(Error.TokenNotFoundById.format(tokenId), 404, False)

        if token.expirationTime is None:
            expirationTime = None
        else:
            expirationTime = convertTimeToString(token.expirationTime, self.storageTime == "UTC")

        # the key order of the payload is kept as is, the encoded token string depends on it
        payload = {
            "tokenId": tokenId,
            "expirationTime": expirationTime,
            "accountId": accountId,
            "visibilityArea": token.visibilityArea,
        }
        return jwt.encode(payload, key=JWT_SECRET_STRING)