from datetime import datetime
from typing import Any, Callable, Iterable, Optional, TypeVar
from uuid import uuid4

import jwt
from asyncpg import ForeignKeyViolationError, UniqueViolationError
from passlib.hash import pbkdf2_sha256
from sqlalchemy import and_, delete, func, insert, select, sql, update
from sqlalchemy.exc import IntegrityError
from vlutils.helpers import convertTimeToString

from classes.enums import AccountType
from classes.functions import getCurrentDatetime
from classes.jwt import JWTProcessor
from classes.schemas.account import Account, AccountForPatch
from classes.schemas.token import PermissionsInDB, Token
from configs.configs.configs.settings.classes import DBSetting
from crutches_on_wheels.cow.db.base_context import BaseDBContext
from crutches_on_wheels.cow.enums.accounts import PermissionsTargets
from crutches_on_wheels.cow.errors.errors import Error
from crutches_on_wheels.cow.errors.exception import VLException
from crutches_on_wheels.cow.utils.db_functions import dbExceptionWrap
from db.accounts_db_tools.models import accounts_models as models
# Maps every account target that can be requested to the `models.Account`
# column of the same name.
_ACCOUNT_BY_TARGET_GETTER_COLUMNS_MAP = {
    target: getattr(models.Account, target)
    for target in (
        "account_id",
        "login",
        "account_type",
        "description",
        "create_time",
        "last_update_time",
    )
}

# Generic placeholders used in the converter signature `Callable[[Y], X]`:
# Y is the value taken as input, X is the converted value.
X = TypeVar("X")
Y = TypeVar("Y")
[docs]
class DBContext(BaseDBContext):
    """
    Accounts DB context.

    Extends `BaseDBContext` with account and token operations. The class-level
    attributes `_accountProcessFunctions` and `_jwtProcessor` are assigned in
    `initDBContext`.
    """

    # cached value converters from db format to api format, keyed by account target name;
    # stays empty until `initDBContext` fills it
    _accountProcessFunctions: dict[str, Callable[[Y], X]] = {}
[docs]
@classmethod
async def initDBContext(
    cls,
    dbSettings: DBSetting,
    storageTime: str,
    ecdsaKeyString: str | None = None,
    ecdsaKeyPassword: str | None = None,
    **kwargs,
) -> None:
    """
    Initialize the class-level state of the context.

    Args:
        dbSettings: database settings, forwarded to the base context
        storageTime: storage time, forwarded to the base context
        ecdsaKeyString: ECDSA key string; when it is set tokens are processed with ES256,
            otherwise with HS256
        ecdsaKeyPassword: password for the ECDSA key
        **kwargs: extra arguments forwarded to the base context initializer
    """
    await super().initDBContext(dbSettings=dbSettings, storageTime=storageTime, **kwargs)

    # read after the base initializer has run, exactly as the converters need it
    isUTC = cls.storageTime == "UTC"

    def emptyIfNone(value):
        """Replace a missing value with an empty string."""
        return value if value is not None else ""

    def timeToString(value):
        """Render a timestamp according to the configured storage time."""
        return convertTimeToString(value, isUTC)

    converters = dict.fromkeys(("account_id", "login", "account_type", "description"), emptyIfNone)
    converters.update(dict.fromkeys(("create_time", "last_update_time"), timeToString))
    cls._accountProcessFunctions = converters

    if ecdsaKeyString:
        jwtSettings = dict(algorithm="ES256", ecdsaKeyString=ecdsaKeyString, ecdsaKeyPassword=ecdsaKeyPassword)
    else:
        jwtSettings = dict(algorithm="HS256")
    cls._jwtProcessor = JWTProcessor(**jwtSettings)
[docs]
@classmethod
def makeOutputAccounts(cls, rows: list[tuple], columns: Iterable[str]) -> list[dict[str, Any]]:
    """
    Make result accounts (from the database reply) proper for user.

    Args:
        rows: rows from db; every row must be convertible with `dict(row)`
        columns: selected columns
    Returns:
        accounts with converted fields, an empty list if there are no rows
    """
    if not rows:
        return []
    # only the selected columns that have a registered converter need processing
    converters = {
        target: cls._accountProcessFunctions[target]
        for target in set(cls._accountProcessFunctions) & set(columns)
    }
    accounts = []
    for row in rows:
        account = dict(row)
        for target, convert in converters.items():
            account[target] = convert(account[target])
        accounts.append(account)
    return accounts
[docs]
@dbExceptionWrap
async def createAccount(self, account: Account, accountId: Optional[str] = None) -> str:
    """
    Store a new account.

    Args:
        account: account to create
        accountId: account id to create; a random uuid4 is used when it is not given
    Returns:
        id of the created account
    Raises:
        VLException(Error.AccountWithLoginAlreadyExists.format(account.login), 409, False) if account with
            the same login already exists
        VLException(Error.AccountWithIdAlreadyExists.format(accountId), 409, False) if the requested account
            id already exists
    """
    newAccountId = accountId or str(uuid4())
    async with DBContext.adaptor.connection(self.logger) as connection:
        sameLoginSt = select([models.Account.account_id]).where(models.Account.login == account.login)
        loginIsTaken = await connection.scalar(sameLoginSt)
        if loginIsTaken:
            raise VLException(Error.AccountWithLoginAlreadyExists.format(account.login), 409, False)

        # an id generated here is not looked up, only an id supplied by the caller
        if accountId is not None:
            sameIdSt = select([models.Account.account_id]).where(models.Account.account_id == newAccountId)
            if await connection.scalar(sameIdSt):
                raise VLException(Error.AccountWithIdAlreadyExists.format(newAccountId), 409, False)

        await connection.execute(
            insert(models.Account).values(
                account_id=newAccountId,
                login=account.login,
                password=pbkdf2_sha256.hash(account.password),
                account_type=account.accountType.value,
                description=account.description,
            )
        )
        return newAccountId
[docs]
@dbExceptionWrap
async def getAccountsCount(self) -> int:
    """
    Count all accounts.

    Returns:
        number of rows in the account table
    """
    totalSt = select([func.count()]).select_from(models.Account)
    async with DBContext.adaptor.connection(self.logger) as connection:
        total = await connection.scalar(totalSt)
    return total
[docs]
@dbExceptionWrap
async def getAccounts(
    self,
    targets: list[str],
    page: int = 1,
    pageSize: int = 1,
    login: str | None = None,
    accountType: AccountType | None = None,
    createTimeLt: datetime | None = None,
    createTimeGte: datetime | None = None,
    accountId: str | None = None,
    getCount: bool = True,
) -> tuple[list[dict], int] | list[dict]:
    """
    Get one page of accounts, newest first.

    Args:
        targets: targets to get
        page: page number, starting from 1
        pageSize: page size
        login: account login
        accountType: account type
        createTimeLt: upper bound (exclusive) of account create time
        createTimeGte: lower bound (inclusive) of account create time
        accountId: account id
        getCount: whether to get account count
    Returns:
        accounts if `getCount` is False, otherwise a tuple of accounts and count
    """
    selectedColumns = [_ACCOUNT_BY_TARGET_GETTER_COLUMNS_MAP[target] for target in targets]
    async with DBContext.adaptor.connection(self.logger) as connection:
        # collect only the filters that were actually requested
        conditions = []
        if accountType is not None:
            conditions.append(models.Account.account_type == accountType.value)
        if login is not None:
            conditions.append(models.Account.login == login)
        if createTimeGte is not None:
            conditions.append(models.Account.create_time >= createTimeGte)
        if createTimeLt is not None:
            conditions.append(models.Account.create_time < createTimeLt)
        if accountId is not None:
            conditions.append(models.Account.account_id == accountId)
        # the leading True keeps the clause valid when no filter is requested
        filters = and_(True, *conditions)

        skip = (page - 1) * pageSize
        pageSt = select(selectedColumns).where(filters)
        pageSt = pageSt.order_by(models.Account.create_time.desc())
        pageSt = pageSt.offset(skip).limit(pageSize)

        rows = await connection.fetchall(pageSt)
        accounts = self.makeOutputAccounts(rows, targets)
        if not getCount:
            return accounts

        # NOTE(review): the count is taken over the whole table, the filters above are not
        # applied to it - confirm this is intended
        totalSt = select([func.count()]).select_from(models.Account)
        total = await connection.scalar(totalSt)
        return accounts, total
[docs]
@dbExceptionWrap
async def patchAccount(self, accountId: str, accountOverride: AccountForPatch) -> bool:
    """
    Patch account by account id.

    Args:
        accountId: id of account
        accountOverride: account for patch model; if its password is set, it is replaced in place
            with the password hash before the update
    Returns:
        True if account was updated otherwise False
    Raises:
        VLException(Error.AccountWithLoginAlreadyExists.format(login), 409, False) if another account
            already uses the new login
        VLException(Error.AccountTypeChangeForbidden.format(accountId), 403, False) on an attempt to
            change the type of an admin account to a non-admin one
    """
    async with DBContext.adaptor.connection(self.logger) as connection:
        # the uniqueness check is only meaningful when the login is being changed; comparing
        # the column with None would become a `login IS NULL` lookup
        if accountOverride.login is not None:
            checkDuplicateLoginSt = select([models.Account.account_id]).where(
                and_(models.Account.login == accountOverride.login, models.Account.account_id != accountId)
            )
            if await connection.scalar(checkDuplicateLoginSt):
                raise VLException(Error.AccountWithLoginAlreadyExists.format(accountOverride.login), 409, False)

        if accountOverride.accountType is not None and accountOverride.accountType != AccountType.admin:
            # an admin account must keep its type
            selectSt = select([models.Account.account_id]).where(
                and_(models.Account.account_id == accountId, models.Account.account_type == AccountType.admin.value)
            )
            if await connection.scalar(selectSt):
                raise VLException(Error.AccountTypeChangeForbidden.format(accountId), 403, False)

        if accountOverride.password is not None:
            # only the hash is stored
            accountOverride.password = pbkdf2_sha256.hash(accountOverride.password)

        updateSt = (
            update(models.Account)
            .where(models.Account.account_id == accountId)
            .values(last_update_time=self.currentDBTimestamp, **accountOverride.asDict())
        )
        return bool(await connection.execute(updateSt))
[docs]
@dbExceptionWrap
async def deleteAccount(self, accountId: str) -> bool:
    """
    Remove the account with the given id.

    Args:
        accountId: id of account
    Returns:
        True if account was deleted otherwise False
    """
    removeSt = delete(models.Account).where(models.Account.account_id == accountId)
    async with DBContext.adaptor.connection(self.logger) as connection:
        executed = await connection.execute(removeSt)
        return bool(executed)
[docs]
@dbExceptionWrap
async def verifyAccountByLoginPassword(self, login: str, password: str) -> tuple[str, str]:
    """
    Check login/password pair against the stored password hash.

    Args:
        login: login
        password: password
    Returns:
        account type and account id
    Raises:
        VLException(Error.AccountLoginPasswordIncorrect, 401, isCriticalError=False) if account not found or
            password doesn't match
    """
    # the same error is used for both failures so the reply does not reveal which one happened
    authError = VLException(Error.AccountLoginPasswordIncorrect, 401, isCriticalError=False)
    async with DBContext.adaptor.connection(self.logger) as connection:
        credentialsSt = select(
            [models.Account.password, models.Account.account_type, models.Account.account_id]
        ).where(models.Account.login == login)
        row = await connection.fetchone(credentialsSt)
        if row:
            storedHash, accountType, accountId = row
            if pbkdf2_sha256.verify(password, storedHash):
                return accountType, accountId
        raise authError
[docs]
@dbExceptionWrap
async def verifyAccountByAccountId(self, accountId: str) -> str:
    """
    Check that the account exists.

    Args:
        accountId: accountId
    Returns:
        account type
    Raises:
        VLException(Error.AccountNotFound.format(accountId), 401, isCriticalError=False) if no account type
            was found for the id
    """
    accountTypeSt = select([models.Account.account_type]).where(models.Account.account_id == accountId)
    async with DBContext.adaptor.connection(self.logger) as connection:
        accountType = await connection.scalar(accountTypeSt)
        if accountType:
            return accountType
        raise VLException(Error.AccountNotFound.format(accountId), 401, isCriticalError=False)
@staticmethod
def _filterPermissions(permissionsString: str, permissionTargets: PermissionsTargets | None = None) -> dict:
    """
    Parse permissions from their string form and shape them according to the targets.

    Args:
        permissionsString: permissions as stored in the database
        permissionTargets: which part of the permissions to return
    Returns:
        no targets / standard: permissions without the "custom" entry
        custom: only the content of the "custom" entry (empty dict if absent)
        all: permissions with the content of the "custom" entry merged into the top level
        anything else: permissions as parsed
    """
    permissions = PermissionsInDB.fromStr(permissionsString).asDict()
    if permissionTargets is None or permissionTargets == PermissionsTargets.standard:
        # dropping a missing key is a no-op, so no membership check is needed
        permissions.pop("custom", None)
    elif permissionTargets == PermissionsTargets.custom:
        permissions = permissions.pop("custom", {})
    elif permissionTargets == PermissionsTargets.all:
        customPermissions = permissions.pop("custom", {})
        permissions.update(customPermissions)
    return permissions
[docs]
@dbExceptionWrap
async def verifyToken(self, token: str, permissionTargets: int | None = None) -> Optional[tuple[str, dict]]:
    """
    Verify a JWT token against the stored tokens.

    Args:
        token: JWT token
        permissionTargets: target for permissions to return in response
    Returns:
        tuple with account type and permissions
    Raises:
        VLException(Error.CorruptedToken, 401, False) if token cannot be decoded or has unexpected fields
        VLException(Error.JWTTokenNotFound, 401, isCriticalError=False) if token not found
        VLException(Error.TokenExpired, 401, False) if token expired
    """
    try:
        payload = self._jwtProcessor.decode(token)
    except jwt.PyJWTError:
        raise VLException(Error.CorruptedToken, 401, False)

    # the payload must consist of exactly these fields
    expectedFields = ["accountId", "expirationTime", "tokenId", "visibilityArea"]
    if sorted(payload) != expectedFields:
        raise VLException(Error.CorruptedToken, 401, False)

    async with DBContext.adaptor.connection(self.logger) as connection:
        tokenMatches = and_(
            models.Token.token_id == payload["tokenId"],
            models.Token.account_id == payload["accountId"],
            models.Account.account_id == models.Token.account_id,
        )
        tokenSt = select(
            [models.Account.account_type, models.Token.expiration_time, models.Token.permissions]
        ).where(tokenMatches)
        row = await connection.fetchone(tokenSt)
        if not row:
            raise VLException(Error.JWTTokenNotFound, 401, isCriticalError=False)

        accountType, expirationTime, tokenPermissions = row
        # a token without expiration time never expires
        if expirationTime and expirationTime <= getCurrentDatetime(True):
            raise VLException(Error.TokenExpired, 401, False)

        permissions = self._filterPermissions(
            permissionsString=tokenPermissions, permissionTargets=permissionTargets
        )
        return accountType, permissions
[docs]
@dbExceptionWrap
async def createToken(self, token: Token, accountId: str) -> tuple[str, str]:
    """
    Store a new token and encode it as JWT.

    Args:
        token: token to create
        accountId: account id
    Returns:
        tuple with token id and encoded token
    Raises:
        VLException(Error.IncorrectAccountTokenPermissions.format("visibility_area", "user" account type), 400,
            False) if an account of the user type requests the "all" visibility area
        VLException(Error.AccountNotFound.format(accountId), 400, False) if the token insert fails with
            an integrity error
    """
    async with DBContext.adaptor.connection(self.logger) as connection:
        accountType = await connection.scalar(
            select([models.Account.account_type]).where(models.Account.account_id == accountId)
        )
        userType = AccountType.user.value
        if accountType == userType and token.visibilityArea == "all":
            raise VLException(Error.IncorrectAccountTokenPermissions.format("visibility_area", userType), 400, False)

        tokenId = str(uuid4())
        newTokenSt = insert(models.Token).values(
            token_id=tokenId,
            account_id=accountId,
            permissions=token.permissions.asStr(),
            expiration_time=token.expirationTime,
            description=token.description,
            visibility_area=token.visibilityArea,
        )
        try:
            await connection.execute(newTokenSt)
        except (IntegrityError, UniqueViolationError, ForeignKeyViolationError):
            raise VLException(Error.AccountNotFound.format(accountId), 400, False)

        expirationTime = None
        if token.expirationTime is not None:
            expirationTime = convertTimeToString(token.expirationTime, self.storageTime == "UTC")

        # the order of the payload keys is kept as is
        payload = {
            "tokenId": tokenId,
            "expirationTime": expirationTime,
            "accountId": accountId,
            "visibilityArea": token.visibilityArea,
        }
        return tokenId, self._jwtProcessor.encode(payload)
[docs]
def makeOutputToken(self, rowFromDB: tuple, permissionTargets: PermissionsTargets | None = None) -> dict:
    """
    Make dict with token from row from db.

    Args:
        rowFromDB: token row, read by position: token id, permissions string, account id,
            expiration time, description, visibility area, create time, last update time
        permissionTargets: which part of the permissions to put into the result
    Returns:
        token as a dict, including the encoded JWT under the "token" key
    """
    tokenId = rowFromDB[0]
    accountId = rowFromDB[2]
    rawExpirationTime = rowFromDB[3]
    visibilityArea = rowFromDB[5]
    # one source for the time format of every timestamp in the reply
    isUTC = self.storageTime == "UTC"

    permissions = self._filterPermissions(permissionsString=rowFromDB[1], permissionTargets=permissionTargets)
    expirationTime = convertTimeToString(rawExpirationTime, isUTC) if rawExpirationTime is not None else None
    return {
        "token_id": tokenId,
        "token": self._jwtProcessor.encode(
            dict(
                tokenId=tokenId,
                expirationTime=expirationTime,
                accountId=accountId,
                visibilityArea=visibilityArea,
            ),
        ),
        "permissions": permissions,
        "expiration_time": expirationTime,
        "description": rowFromDB[4] or "",
        "account_id": accountId,
        "visibility_area": visibilityArea,
        "create_time": convertTimeToString(rowFromDB[6], isUTC),
        "last_update_time": convertTimeToString(rowFromDB[7], isUTC),
    }
[docs]
@dbExceptionWrap
async def getTokens(
    self,
    page: int,
    pageSize: int,
    permissionTargets: PermissionsTargets | None = None,
    accountId: str | None = None,
    createTimeLt: datetime | None = None,
    createTimeGte: datetime | None = None,
) -> list[dict]:
    """
    Get one page of tokens, newest first.

    Args:
        page: page number, starting from 1
        pageSize: page size
        permissionTargets: permission targets
        accountId: account id
        createTimeLt: upper bound (exclusive) of token create time
        createTimeGte: lower bound (inclusive) of token create time
    Returns:
        tokens list
    """
    # column order matters: `makeOutputToken` reads the row by position
    tokenColumns = [
        models.Token.token_id,
        models.Token.permissions,
        models.Token.account_id,
        models.Token.expiration_time,
        models.Token.description,
        models.Token.visibility_area,
        models.Token.create_time,
        models.Token.last_update_time,
    ]
    async with DBContext.adaptor.connection(self.logger) as connection:
        conditions = []
        if accountId is not None:
            conditions.append(models.Token.account_id == accountId)
        if createTimeGte is not None:
            conditions.append(models.Token.create_time >= createTimeGte)
        if createTimeLt is not None:
            conditions.append(models.Token.create_time < createTimeLt)
        # the leading True keeps the clause valid when no filter is requested
        filters = and_(True, *conditions)

        skip = (page - 1) * pageSize
        pageSt = select(tokenColumns).where(filters)
        pageSt = pageSt.order_by(models.Token.create_time.desc())
        pageSt = pageSt.offset(skip).limit(pageSize)

        rows = await connection.fetchall(pageSt)
        return [self.makeOutputToken(row, permissionTargets=permissionTargets) for row in rows]
@staticmethod
def _getTokenFilters(tokenId: str, accountId: Optional[str] = None) -> sql.elements.BinaryExpression:
    """
    Build the filter selecting one token, optionally restricted to an account.

    Args:
        tokenId: token id
        accountId: account id; the token is not restricted to an account when it is None
    Returns:
        filter expression for a token request
    """
    byToken = models.Token.token_id == tokenId
    byAccount = True if accountId is None else models.Token.account_id == accountId
    return and_(byAccount, byToken)
[docs]
@dbExceptionWrap
async def getToken(
    self, tokenId: str, permissionTargets: int = PermissionsTargets.standard.value, accountId: str | None = None
) -> dict | None:
    """
    Get a single token by its id.

    Args:
        tokenId: token id
        permissionTargets: permission targets
        accountId: account id
    Returns:
        token
    Raises:
        VLException(Error.TokenNotFoundById.format(tokenId), 404, False) if token not found
    """
    # column order matters: `makeOutputToken` reads the row by position
    tokenColumns = [
        models.Token.token_id,
        models.Token.permissions,
        models.Token.account_id,
        models.Token.expiration_time,
        models.Token.description,
        models.Token.visibility_area,
        models.Token.create_time,
        models.Token.last_update_time,
    ]
    async with DBContext.adaptor.connection(self.logger) as connection:
        filters = self._getTokenFilters(tokenId=tokenId, accountId=accountId)
        row = await connection.fetchone(select(tokenColumns).where(filters))
        if row:
            return self.makeOutputToken(row, permissionTargets=permissionTargets)
        raise VLException(Error.TokenNotFoundById.format(tokenId), 404, False)
[docs]
@dbExceptionWrap
async def deleteToken(self, tokenId: str, accountId: str) -> bool:
    """
    Remove a token.

    Args:
        tokenId: token id
        accountId: account id
    Returns:
        True if token was removed, False if token was not found
    """
    async with DBContext.adaptor.connection(self.logger) as connection:
        filters = self._getTokenFilters(tokenId=tokenId, accountId=accountId)
        if self.dbType != "oracle":
            removeSt = delete(models.Token).where(filters)
        else:
            # required because oracle does not support multiple-table criteria within delete:
            # look the token up with the full filters first, then delete by token id only
            tokenExists = await connection.scalar(select([models.Token.token_id]).where(filters))
            if not tokenExists:
                return False
            removeSt = delete(models.Token).where(models.Token.token_id == tokenId)
        executed = await connection.execute(removeSt)
        return bool(executed)
[docs]
@dbExceptionWrap
async def replaceToken(self, tokenId: str, token: Token, accountId: Optional[str] = None) -> str:
    """
    Overwrite the token stored under the given id and encode the new one as JWT.

    Args:
        tokenId: token id
        token: token
        accountId: account id; when it is not given, the account is looked up by the token id
    Returns:
        encoded token
    Raises:
        VLException(Error.TokenNotFoundById.format(tokenId), 404, False) if token not found
        VLException(Error.IncorrectAccountTokenPermissions.format("visibility_area", "user" account type), 400,
            False) if an account of the user type requests the "all" visibility area
    """
    async with DBContext.adaptor.connection(self.logger) as connection:
        if accountId is None:
            # find the owner of the token together with the owner's type
            ownerSt = select([models.Account.account_id, models.Account.account_type]).where(
                and_(models.Account.account_id == models.Token.account_id, models.Token.token_id == tokenId)
            )
            owner = await connection.fetchone(ownerSt)
            if not owner:
                raise VLException(Error.TokenNotFoundById.format(tokenId), 404, False)
            accountId, accountType = owner
        else:
            accountType = await connection.scalar(
                select([models.Account.account_type]).where(models.Account.account_id == accountId)
            )

        userType = AccountType.user.value
        if accountType == userType and token.visibilityArea == "all":
            raise VLException(Error.IncorrectAccountTokenPermissions.format("visibility_area", userType), 400, False)

        replaceSt = (
            update(models.Token)
            .where(and_(models.Token.token_id == tokenId, models.Token.account_id == accountId))
            .values(
                token_id=tokenId,
                permissions=token.permissions.asStr(),
                expiration_time=token.expirationTime,
                description=token.description,
                visibility_area=token.visibilityArea,
                last_update_time=self.currentDBTimestamp,
            )
        )
        updated = await connection.execute(replaceSt)
        if not updated:
            raise VLException(Error.TokenNotFoundById.format(tokenId), 404, False)

        expirationTime = None
        if token.expirationTime is not None:
            expirationTime = convertTimeToString(token.expirationTime, self.storageTime == "UTC")

        # the order of the payload keys is kept as is
        payload = {
            "tokenId": tokenId,
            "expirationTime": expirationTime,
            "accountId": accountId,
            "visibilityArea": token.visibilityArea,
        }
        return self._jwtProcessor.encode(payload)