from datetime import datetime
from typing import Any, Awaitable, Iterable, Optional, TypeVar
from uuid import uuid4

import jwt
from accounts_tools.classes.enums import AccountType
from accounts_tools.classes.jwt import JWTProcessor
from accounts_tools.classes.validators import validateTokenAgainstAccount
from asyncpg import ForeignKeyViolationError, UniqueViolationError
from passlib.hash import pbkdf2_sha256
from sqlalchemy import and_, delete, func, insert, select, sql, update
from sqlalchemy.exc import IntegrityError
from vlutils.helpers import convertTimeToString

from classes.functions import getCurrentDatetime
from classes.schemas.account import Account, AccountForPatch
from classes.schemas.token import PermissionsInDB, Token
from configs.config import DB_CONNECT_TIMEOUT
from configs.configs.configs.settings.classes import DBSetting
from crutches_on_wheels.cow.db.base_context import BaseDBContext
from crutches_on_wheels.cow.enums.accounts import PermissionsTargets
from crutches_on_wheels.cow.errors.errors import Error
from crutches_on_wheels.cow.errors.exception import VLException
from crutches_on_wheels.cow.utils import mixins
from crutches_on_wheels.cow.utils.check_connection import checkConnectionToDB
from crutches_on_wheels.cow.utils.db_functions import dbExceptionWrap
from crutches_on_wheels.cow.utils.healthcheck import checkSql, checkSqlMigration
from db.accounts_db_tools.models import accounts_models as models
# Maps every account target name that can be requested to the `models.Account` column of the same name.
_ACCOUNT_BY_TARGET_GETTER_COLUMNS_MAP = {
    target: getattr(models.Account, target)
    for target in (
        "account_id",
        "login",
        "account_type",
        "description",
        "create_time",
        "last_update_time",
    )
}

# Generic type variables (not referenced in the visible part of this module).
X = TypeVar("X")
Y = TypeVar("Y")
[docs]
class DBContext(BaseDBContext, mixins.Initializable):
"""Accounts DB context."""
def __init__(
    self,
    dbSettings: DBSetting,
    storageTime: str,
    ecdsaKeyString: str | None = None,
    ecdsaKeyPassword: str | None = None,
):
    """
    Store the settings and prepare account output converters and the JWT processor.

    Args:
        dbSettings: database settings
        storageTime: storage time mode; it is compared with "UTC" when timestamps are converted to strings
        ecdsaKeyString: ECDSA key; when it is set the JWT processor uses ES256, otherwise HS256
        ecdsaKeyPassword: password for the ECDSA key
    """
    super().__init__()
    self.dbSettings = dbSettings
    self.storageTime = storageTime
    self.ecdsaKeyString = ecdsaKeyString
    self.ecdsaKeyPassword = ecdsaKeyPassword

    isUTC = self.storageTime == "UTC"

    def emptyIfNone(value):
        """Replace a missing db value with an empty string."""
        return "" if value is None else value

    def timeToString(timestamp):
        """Convert a db timestamp to a string according to the storage time mode."""
        return convertTimeToString(timestamp, isUTC)

    # cached value converters from db format to api format, keyed by target name
    self._accountProcessFunctions = {
        "account_id": emptyIfNone,
        "login": emptyIfNone,
        "account_type": emptyIfNone,
        "description": emptyIfNone,
        "create_time": timeToString,
        "last_update_time": timeToString,
    }

    if not ecdsaKeyString:
        processor = JWTProcessor(algorithm="HS256")
    else:
        processor = JWTProcessor(
            algorithm="ES256", ecdsaKeyString=ecdsaKeyString, ecdsaKeyPassword=ecdsaKeyPassword
        )
    self._jwtProcessor = processor
[docs]
async def initialize(self):
    """Call `initDBContext` on this instance's class with the stored db settings and storage time."""
    contextClass = self.__class__
    await contextClass.initDBContext(
        dbSettings=self.dbSettings,
        storageTime=self.storageTime,
        connectTimeout=DB_CONNECT_TIMEOUT,
    )
[docs]
def makeOutputAccounts(self, rows: list[tuple], columns: Iterable[str]) -> list[dict[str, Any]]:
    """
    Make result accounts (from the database reply) proper for user.

    Args:
        rows: rows from db; each row is passed to `dict()`, so it must be convertible to a mapping
        columns: selected columns
    Returns:
        accounts as dicts; every selected column that has a converter in `_accountProcessFunctions`
        is passed through that converter, other columns are returned as is
    """
    if not rows:
        return []
    # only columns that were selected AND have a registered converter need post-processing
    processableTargets = set(self._accountProcessFunctions) & set(columns)
    accounts = []
    for row in rows:
        account = dict(row)
        for target in processableTargets:
            account[target] = self._accountProcessFunctions[target](account[target])
        accounts.append(account)
    return accounts
[docs]
@dbExceptionWrap
async def createAccount(self, account: Account, accountId: Optional[str] = None) -> str:
    """
    Create new account.

    Args:
        account: account to create
        accountId: account id to create; a random UUID4 is generated when it is empty or not given
    Returns:
        unique account id
    Raises:
        VLException(Error.AccountWithLoginAlreadyExists.format(account.login), 409, False) if account with the
            same login already exists
        VLException(Error.AccountWithIdAlreadyExists.format(accountId), 409, False) if an account id was given
            and an account with this id already exists
    """
    newAccountId = accountId or str(uuid4())
    loginTakenSt = select([models.Account.account_id]).where(models.Account.login == account.login)
    async with DBContext.adaptor.connection(self.logger) as connection:
        if await connection.scalar(loginTakenSt):
            raise VLException(Error.AccountWithLoginAlreadyExists.format(account.login), 409, False)

        # the id can only collide when the caller supplied it
        if accountId is not None:
            idTakenSt = select([models.Account.account_id]).where(models.Account.account_id == newAccountId)
            if await connection.scalar(idTakenSt):
                raise VLException(Error.AccountWithIdAlreadyExists.format(newAccountId), 409, False)

        # the password is stored only as a pbkdf2_sha256 hash
        await connection.execute(
            insert(models.Account).values(
                account_id=newAccountId,
                login=account.login,
                password=pbkdf2_sha256.hash(account.password),
                account_type=account.accountType.value,
                description=account.description,
            )
        )
        return newAccountId
[docs]
@dbExceptionWrap
async def getAccountsCount(self) -> int:
    """
    Count all rows of the accounts table.

    Returns:
        accounts count
    """
    totalSt = select([func.count()]).select_from(models.Account)
    async with DBContext.adaptor.connection(self.logger) as connection:
        total = await connection.scalar(totalSt)
        return total
[docs]
@dbExceptionWrap
async def getAccounts(
    self,
    targets: list[str],
    page: int = 1,
    pageSize: int = 1,
    login: str | None = None,
    accountType: AccountType | None = None,
    createTimeLt: datetime | None = None,
    createTimeGte: datetime | None = None,
    accountId: str | None = None,
    getCount: bool = True,
) -> tuple[list[dict], int] | list[dict]:
    """
    Get accounts with pagination, newest first.

    Args:
        targets: targets to get; every target must be a key of `_ACCOUNT_BY_TARGET_GETTER_COLUMNS_MAP`
        page: page
        pageSize: page size
        login: account login
        accountType: account type
        createTimeLt: upper bound of account create time (exclusive)
        createTimeGte: lower bound of account create time (inclusive)
        accountId: account id
        getCount: whether to get account count
    Returns:
        accounts, or a tuple of accounts and count when `getCount` is set.
        The count is taken over the whole accounts table: the filters above are not applied to it.
    """
    columns2Get = [_ACCOUNT_BY_TARGET_GETTER_COLUMNS_MAP[target] for target in targets]

    # only filters that were actually supplied take part in the WHERE clause
    conditions = []
    if accountType is not None:
        conditions.append(models.Account.account_type == accountType.value)
    if login is not None:
        conditions.append(models.Account.login == login)
    if createTimeGte is not None:
        conditions.append(models.Account.create_time >= createTimeGte)
    if createTimeLt is not None:
        conditions.append(models.Account.create_time < createTimeLt)
    if accountId is not None:
        conditions.append(models.Account.account_id == accountId)

    async with DBContext.adaptor.connection(self.logger) as connection:
        # the leading True keeps the clause valid when no filter is supplied
        pageSt = select(columns2Get).where(and_(True, *conditions))
        pageSt = pageSt.order_by(models.Account.create_time.desc())
        pageSt = pageSt.offset((page - 1) * pageSize).limit(pageSize)
        accountsRes = self.makeOutputAccounts(await connection.fetchall(pageSt), targets)
        if not getCount:
            return accountsRes
        # NOTE(review): the count ignores the filters - confirm this is the intended total
        count = await connection.scalar(select([func.count()]).select_from(models.Account))
        return accountsRes, count
[docs]
@dbExceptionWrap
async def patchAccount(self, accountId: str, accountOverride: AccountForPatch) -> bool:
    """
    Patch account by account id.

    Args:
        accountId: id of account
        accountOverride: account for patch model; its `password` attribute, when set, is replaced in place
            with its pbkdf2_sha256 hash
    Returns:
        True if account was updated otherwise False
    Raises:
        VLException(Error.AccountWithLoginAlreadyExists.format(login), 409, False) if another account already
            uses the new login
        VLException(Error.AccountTypeChangeForbidden.format(accountId), 403, False) on an attempt to change the
            type of an admin account to a non-admin type
    """
    async with DBContext.adaptor.connection(self.logger) as connection:
        # As with `password` below, a None login is treated as "not provided". Comparing the column with
        # None would produce an `IS NULL` test, which is unrelated to duplicate detection, so the check
        # is done only when a login is actually supplied.
        if accountOverride.login is not None:
            checkDuplicateLoginSt = select([models.Account.account_id]).where(
                and_(models.Account.login == accountOverride.login, models.Account.account_id != accountId)
            )
            if await connection.scalar(checkDuplicateLoginSt):
                raise VLException(Error.AccountWithLoginAlreadyExists.format(accountOverride.login), 409, False)

        # an admin account must stay an admin account
        if accountOverride.accountType is not None and accountOverride.accountType != AccountType.admin:
            selectSt = select([models.Account.account_id]).where(
                and_(models.Account.account_id == accountId, models.Account.account_type == AccountType.admin.value)
            )
            if await connection.scalar(selectSt):
                raise VLException(Error.AccountTypeChangeForbidden.format(accountId), 403, False)

        # NOTE(review): this mutates the caller's object; passing the same object twice would hash the hash
        if accountOverride.password is not None:
            accountOverride.password = pbkdf2_sha256.hash(accountOverride.password)

        updateSt = (
            update(models.Account)
            .where(models.Account.account_id == accountId)
            .values(last_update_time=self.currentDBTimestamp, **accountOverride.asDict())
        )
        return bool(await connection.execute(updateSt))
[docs]
@dbExceptionWrap
async def deleteAccount(self, accountId: str) -> bool:
    """
    Delete account by account id.

    Args:
        accountId: id of account
    Returns:
        True if account was deleted otherwise False
    """
    removeSt = delete(models.Account).where(models.Account.account_id == accountId)
    async with DBContext.adaptor.connection(self.logger) as connection:
        executionResult = await connection.execute(removeSt)
        return bool(executionResult)
[docs]
@dbExceptionWrap
async def verifyAccountByLoginPassword(self, login: str, password: str) -> tuple[str, str]:
    """
    Verify account by login and password.

    Args:
        login: login
        password: password
    Returns:
        account type and account id
    Raises:
        VLException(Error.AccountLoginPasswordIncorrect, 401, isCriticalError=False) if account not found or
            password doesn't match
    """
    credentialsSt = select(
        [models.Account.password, models.Account.account_type, models.Account.account_id]
    ).where(models.Account.login == login)
    async with DBContext.adaptor.connection(self.logger) as connection:
        accountRow = await connection.fetchone(credentialsSt)
        if accountRow:
            hashedPassword, accountType, accountId = accountRow
            if pbkdf2_sha256.verify(password, hashedPassword):
                return accountType, accountId
        # an unknown login and a wrong password are reported with the same error
        raise VLException(Error.AccountLoginPasswordIncorrect, 401, isCriticalError=False)
[docs]
@dbExceptionWrap
async def verifyAccountByAccountId(self, accountId: str) -> str:
    """
    Verify account by account id.

    Args:
        accountId: accountId
    Returns:
        account type
    Raises:
        VLException(Error.AccountNotFound.format(accountId), 401, isCriticalError=False) if no account type
            was found for the id
    """
    accountTypeSt = select([models.Account.account_type]).where(models.Account.account_id == accountId)
    async with DBContext.adaptor.connection(self.logger) as connection:
        accountType = await connection.scalar(accountTypeSt)
        if accountType:
            return accountType
        raise VLException(Error.AccountNotFound.format(accountId), 401, isCriticalError=False)
@staticmethod
def _filterPermissions(permissionsString: str, permissionTargets: PermissionsTargets | None = None) -> dict:
    """
    Get permissions from string and filter by targets.

    Args:
        permissionsString: permissions as stored in db
        permissionTargets: which permissions to return:
            None or standard - everything except the "custom" section;
            custom - only the content of the "custom" section (empty dict if there is none);
            all - the "custom" section content merged into the top level;
            any other value - permissions are returned unchanged
    Returns:
        filtered permissions
    """
    permissions = PermissionsInDB.fromStr(permissionsString).asDict()
    if permissionTargets is None or permissionTargets == PermissionsTargets.standard:
        permissions.pop("custom", None)
        return permissions
    if permissionTargets == PermissionsTargets.custom:
        return permissions.pop("custom", {})
    if permissionTargets == PermissionsTargets.all:
        customPermissions = permissions.pop("custom", {})
        permissions.update(customPermissions)
    return permissions
[docs]
@dbExceptionWrap
async def verifyToken(self, token: str, permissionTargets: int | None = None) -> Optional[tuple[str, dict]]:
    """
    Verify account by JWT token.

    Args:
        token: JWT token
        permissionTargets: target for permissions to return in response
    Returns:
        tuple with account type and permissions
    Raises:
        VLException(Error.CorruptedToken, 401, False) if token cannot be decoded or its payload keys are not
            exactly accountId, expirationTime, tokenId and visibilityArea
        VLException(Error.JWTTokenNotFound, 401, isCriticalError=False) if there is no such token for the account
        VLException(Error.TokenExpired, 401, False) if token expired
    """
    try:
        tokenData = self._jwtProcessor.decode(token)
    except jwt.PyJWTError as error:
        # keep the decoding failure as the cause so that it is not lost in logs
        raise VLException(Error.CorruptedToken, 401, False) from error
    if set(tokenData) != {"accountId", "expirationTime", "tokenId", "visibilityArea"}:
        raise VLException(Error.CorruptedToken, 401, False)

    async with DBContext.adaptor.connection(self.logger) as connection:
        # the token is looked up by its id together with the account id taken from the payload
        selectTokenSt = select(
            [
                models.Account.account_type,
                models.Token.expiration_time,
                models.Token.permissions,
            ]
        ).where(
            and_(
                models.Token.token_id == tokenData["tokenId"],
                models.Token.account_id == tokenData["accountId"],
                models.Account.account_id == models.Token.account_id,
            )
        )
        dbReply = await connection.fetchone(selectTokenSt)
        if not dbReply:
            raise VLException(Error.JWTTokenNotFound, 401, isCriticalError=False)
        accountType, expirationTime, tokenPermissions = dbReply
        # a token without expiration time never expires
        # NOTE(review): the argument is a constant True while other methods depend on `storageTime` - confirm
        if expirationTime and expirationTime <= getCurrentDatetime(True):
            raise VLException(Error.TokenExpired, 401, False)
        return accountType, self._filterPermissions(
            permissionsString=tokenPermissions, permissionTargets=permissionTargets
        )
[docs]
@dbExceptionWrap
async def createToken(self, token: Token, accountId: str) -> tuple[str, str]:
    """
    Create new token.

    Args:
        token: token to create
        accountId: account id
    Returns:
        tuple with token id and the encoded token
    Raises:
        VLException(Error.AccountNotFound.format(accountId), 400, False) if specified account id not found
        VLException(Error.IncorrectAccountTokenPermissions.format("visibility_area", AccountType.user.value),
            400, False) if the token visibility area is not allowed for the account type
    """
    async with DBContext.adaptor.connection(self.logger) as connection:
        accountTypeSelectSt = select([models.Account.account_type]).where(models.Account.account_id == accountId)
        accountType = await connection.scalar(accountTypeSelectSt)
        # Without an account there is no account type to validate against: report the missing account
        # explicitly instead of handing None to the validator.
        if accountType is None:
            raise VLException(Error.AccountNotFound.format(accountId), 400, False)
        if not validateTokenAgainstAccount(accountType, token.visibilityArea.value):
            raise VLException(
                Error.IncorrectAccountTokenPermissions.format("visibility_area", AccountType.user.value), 400, False
            )
        tokenId = str(uuid4())
        insertSt = insert(models.Token).values(
            token_id=tokenId,
            account_id=accountId,
            permissions=token.permissions.asStr(),
            expiration_time=token.expirationTime,
            description=token.description,
            visibility_area=token.visibilityArea.value,
        )
        try:
            await connection.execute(insertSt)
        except (IntegrityError, UniqueViolationError, ForeignKeyViolationError) as error:
            # the account may have been removed between the lookup above and the insert
            raise VLException(Error.AccountNotFound.format(accountId), 400, False) from error
        expirationTime = (
            convertTimeToString(token.expirationTime, self.storageTime == "UTC")
            if token.expirationTime is not None
            else None
        )
        return tokenId, self._jwtProcessor.encode(
            tokenId=tokenId,
            expirationTime=expirationTime,
            accountId=accountId,
            visibilityArea=token.visibilityArea.value,
        )
[docs]
def makeOutputToken(self, rowFromDB: tuple, permissionTargets: PermissionsTargets | None = None) -> dict:
    """
    Make dict with token from row from db.

    Args:
        rowFromDB: token row; columns are read by position: token id, permissions, account id,
            expiration time, description, visibility area, create time, last update time
        permissionTargets: which permissions to put into the result
    Returns:
        token as a dict, including the token encoded by the JWT processor
    """
    # every timestamp of one token is converted with the same (instance) storage time setting
    isUTC = self.storageTime == "UTC"
    tokenId = rowFromDB[0]
    accountId = rowFromDB[2]
    visibilityArea = rowFromDB[5]
    permissions = self._filterPermissions(permissionsString=rowFromDB[1], permissionTargets=permissionTargets)
    # a token may have no expiration time
    expirationTime = convertTimeToString(rowFromDB[3], isUTC) if rowFromDB[3] is not None else None
    return {
        "token_id": tokenId,
        "token": self._jwtProcessor.encode(
            tokenId=tokenId,
            expirationTime=expirationTime,
            accountId=accountId,
            visibilityArea=visibilityArea,
        ),
        "permissions": permissions,
        "expiration_time": expirationTime,
        "description": rowFromDB[4] or "",
        "account_id": accountId,
        "visibility_area": visibilityArea,
        "create_time": convertTimeToString(rowFromDB[6], isUTC),
        "last_update_time": convertTimeToString(rowFromDB[7], isUTC),
    }
[docs]
@dbExceptionWrap
async def getTokens(
    self,
    page: int,
    pageSize: int,
    permissionTargets: PermissionsTargets | None = None,
    accountId: str | None = None,
    createTimeLt: datetime | None = None,
    createTimeGte: datetime | None = None,
) -> list[dict]:
    """
    Get tokens with pagination, newest first.

    Args:
        accountId: account id
        page: page
        pageSize: page size
        permissionTargets: permission targets
        createTimeLt: upper bound of token create time (exclusive)
        createTimeGte: lower bound of token create time (inclusive)
    Returns:
        tokens list
    """
    # only filters that were actually supplied take part in the WHERE clause
    conditions = []
    if accountId is not None:
        conditions.append(models.Token.account_id == accountId)
    if createTimeGte is not None:
        conditions.append(models.Token.create_time >= createTimeGte)
    if createTimeLt is not None:
        conditions.append(models.Token.create_time < createTimeLt)

    # the column order is the positional layout `makeOutputToken` reads
    tokenColumns = [
        models.Token.token_id,
        models.Token.permissions,
        models.Token.account_id,
        models.Token.expiration_time,
        models.Token.description,
        models.Token.visibility_area,
        models.Token.create_time,
        models.Token.last_update_time,
    ]
    async with DBContext.adaptor.connection(self.logger) as connection:
        # the leading True keeps the clause valid when no filter is supplied
        pageSt = select(tokenColumns).where(and_(True, *conditions))
        pageSt = pageSt.order_by(models.Token.create_time.desc())
        pageSt = pageSt.offset((page - 1) * pageSize).limit(pageSize)
        tokenRows = await connection.fetchall(pageSt)
        tokens = []
        for tokenRow in tokenRows:
            tokens.append(self.makeOutputToken(tokenRow, permissionTargets=permissionTargets))
        return tokens
@staticmethod
def _getTokenFilters(tokenId: str, accountId: Optional[str] = None) -> sql.elements.BinaryExpression:
    """
    Get filters for token request.

    Args:
        tokenId: token id, always part of the filter
        accountId: account id; the token owner is checked only when it is given
    Returns:
        conjunction of the conditions
    """
    ownerCondition = models.Token.account_id == accountId if accountId is not None else True
    tokenCondition = models.Token.token_id == tokenId
    return and_(ownerCondition, tokenCondition)
[docs]
@dbExceptionWrap
async def getToken(
    self, tokenId: str, permissionTargets: int = PermissionsTargets.standard.value, accountId: str | None = None
) -> dict | None:
    """
    Get token.

    Args:
        tokenId: token id
        permissionTargets: permission targets
        accountId: account id; when given, the token must belong to this account
    Returns:
        token
    Raises:
        VLException(Error.TokenNotFoundById.format(tokenId), 404, False) if token not found
    """
    # the column order is the positional layout `makeOutputToken` reads
    tokenColumns = [
        models.Token.token_id,
        models.Token.permissions,
        models.Token.account_id,
        models.Token.expiration_time,
        models.Token.description,
        models.Token.visibility_area,
        models.Token.create_time,
        models.Token.last_update_time,
    ]
    async with DBContext.adaptor.connection(self.logger) as connection:
        tokenSt = select(tokenColumns).where(self._getTokenFilters(tokenId=tokenId, accountId=accountId))
        tokenRow = await connection.fetchone(tokenSt)
        if tokenRow:
            return self.makeOutputToken(tokenRow, permissionTargets=permissionTargets)
        raise VLException(Error.TokenNotFoundById.format(tokenId), 404, False)
[docs]
@dbExceptionWrap
async def deleteToken(self, tokenId: str, accountId: str) -> bool:
    """
    Delete token by id.

    Args:
        tokenId: token id
        accountId: account id
    Returns:
        True if token was removed, False if token was not found
    """
    async with DBContext.adaptor.connection(self.logger) as connection:
        filters = self._getTokenFilters(tokenId=tokenId, accountId=accountId)
        if self.dbType != "oracle":
            deleteCriteria = filters
        else:
            # required because oracle does not support multiple-table criteria within delete:
            # the filters are checked with a select first, then the token is deleted by id only
            foundTokenId = await connection.scalar(select([models.Token.token_id]).where(filters))
            if not foundTokenId:
                return False
            deleteCriteria = models.Token.token_id == tokenId
        return bool(await connection.execute(delete(models.Token).where(deleteCriteria)))
[docs]
@dbExceptionWrap
async def replaceToken(self, tokenId: str, token: Token, accountId: Optional[str] = None) -> str:
    """
    Replace existing token using specified token id.

    Args:
        tokenId: token id
        accountId: account id; when it is not given the owner account is looked up by the token id
        token: token
    Returns:
        encoded token
    Raises:
        VLException(Error.TokenNotFoundById.format(tokenId), 404, False) if token not found
        VLException(Error.IncorrectAccountTokenPermissions.format("visibility_area", AccountType.user.value),
            400, False) if the account type is user and the token visibility area is "all"
    """
    async with DBContext.adaptor.connection(self.logger) as connection:
        if accountId is None:
            # no account supplied: resolve the owner account and its type through the token
            ownerSt = select([models.Account.account_id, models.Account.account_type]).where(
                and_(models.Account.account_id == models.Token.account_id, models.Token.token_id == tokenId)
            )
            owner = await connection.fetchone(ownerSt)
            if not owner:
                raise VLException(Error.TokenNotFoundById.format(tokenId), 404, False)
            accountId, accountType = owner
        else:
            accountType = await connection.scalar(
                select([models.Account.account_type]).where(models.Account.account_id == accountId)
            )

        if accountType == AccountType.user.value and token.visibilityArea.value == "all":
            raise VLException(
                Error.IncorrectAccountTokenPermissions.format("visibility_area", AccountType.user.value), 400, False
            )

        updateSt = (
            update(models.Token)
            .values(
                token_id=tokenId,
                permissions=token.permissions.asStr(),
                expiration_time=token.expirationTime,
                description=token.description,
                visibility_area=token.visibilityArea.value,
                last_update_time=self.currentDBTimestamp,
            )
            .where(and_(models.Token.token_id == tokenId, models.Token.account_id == accountId))
        )
        # a falsy execution result means that no token matched the token id and account id
        if not await connection.execute(updateSt):
            raise VLException(Error.TokenNotFoundById.format(tokenId), 404, False)

        if token.expirationTime is None:
            expirationTime = None
        else:
            expirationTime = convertTimeToString(token.expirationTime, self.storageTime == "UTC")
        return self._jwtProcessor.encode(
            tokenId=tokenId,
            expirationTime=expirationTime,
            accountId=accountId,
            visibilityArea=token.visibilityArea.value,
        )
[docs]
async def probe(self) -> bool:
    """
    Ensure provided config is valid. Create new connection. Can be used without initialization.

    Returns:
        result of the connection check done with the stored db settings
    """
    return await checkConnectionToDB(dbSetting=self.dbSettings, postfix="accounts", asyncCheck=True)
[docs]
def getRuntimeChecks(self, _) -> list[tuple[str, Awaitable]]:
    """
    Checks for healthcheck.

    Returns:
        list of (check name, check) pairs: the sql check and the sql migration check of the db adaptor
    """
    sqlCheck = ("accounts_db", checkSql(self.adaptor))
    migrationCheck = ("accounts_db_migration", checkSqlMigration(self.adaptor))
    return [sqlCheck, migrationCheck]