from typing import Optional
from uuid import uuid4
import jwt
from asyncpg import ForeignKeyViolationError, UniqueViolationError
from passlib.hash import pbkdf2_sha256
from sqlalchemy import and_, delete, func, insert, select, sql, update
from sqlalchemy.exc import IntegrityError
from vlutils.helpers import convertTimeToString
from classes.enums import AccountType
from classes.functions import getCurrentDatetime
from classes.schemas.account import Account, AccountForPatch
from classes.schemas.token import Permissions, Token
from configs.config import JWT_SECRET_STRING
from crutches_on_wheels.cow.db.base_context import BaseDBContext
from crutches_on_wheels.cow.errors.errors import Error
from crutches_on_wheels.cow.errors.exception import VLException
from crutches_on_wheels.cow.utils.db_functions import dbExceptionWrap
from db.accounts_db_tools.models import accounts_models as models
[docs]
class DBContext(BaseDBContext):
"""Accounts DB context."""
[docs]
@dbExceptionWrap
async def createAccount(self, account: Account, accountId: Optional[str] = None) -> str:
    """
    Create a new account.

    Args:
        account: account to create (its login, password, account type and description are stored)
        accountId: id for the new account; a random uuid4 string is generated when it is not given
    Returns:
        id of the created account
    Raises:
        VLException(Error.AccountWithLoginAlreadyExists.format(account.login), 409, False) if an account with
            the same login already exists
        VLException(Error.AccountWithIdAlreadyExists.format(accountId), 409, False) if an account with the
            requested id already exists
    """
    newAccountId = accountId or str(uuid4())
    async with DBContext.adaptor.connection(self.logger) as connection:
        loginOwnerQuery = select([models.Account.account_id]).where(models.Account.login == account.login)
        loginOwner = await connection.scalar(loginOwnerQuery)
        if loginOwner:
            raise VLException(Error.AccountWithLoginAlreadyExists.format(account.login), 409, False)

        # the id lookup is performed only when the caller passed an id explicitly
        if accountId is not None:
            idOwnerQuery = select([models.Account.account_id]).where(models.Account.account_id == newAccountId)
            if await connection.scalar(idOwnerQuery):
                raise VLException(Error.AccountWithIdAlreadyExists.format(newAccountId), 409, False)

        # only the pbkdf2_sha256 hash of the password goes to the database
        newRow = {
            "account_id": newAccountId,
            "login": account.login,
            "password": pbkdf2_sha256.hash(account.password),
            "account_type": account.accountType.value,
            "description": account.description,
        }
        await connection.execute(insert(models.Account).values(**newRow))
        return newAccountId
[docs]
@dbExceptionWrap
async def getAccountsCount(self) -> int:
    """
    Count all stored accounts.

    Returns:
        total number of accounts
    """
    totalQuery = select([func.count()]).select_from(models.Account)
    async with DBContext.adaptor.connection(self.logger) as connection:
        total = await connection.scalar(totalQuery)
        return total
[docs]
@dbExceptionWrap
async def getAccounts(
    self, page: int, pageSize: int, login: Optional[str] = None, accountType: Optional[AccountType] = None
) -> tuple[list[dict], int]:
    """
    Get one page of accounts, optionally filtered by login and/or account type.

    Args:
        page: page number, the first page is 1
        pageSize: maximum number of accounts on a page
        login: if given, only the account with this login is selected
        accountType: if given, only accounts of this type are selected
    Returns:
        tuple of (accounts of the requested page, accounts count); every account is a dict with
        "account_id", "login", "account_type" and "description" keys (missing description becomes "").
        The count is taken over the whole accounts table, the login/type filters are not applied to it.
    """
    async with DBContext.adaptor.connection(self.logger) as connection:
        # an absent filter is replaced with a constant True so that it does not restrict the selection
        filters = and_(
            models.Account.account_type == accountType.value if accountType is not None else True,
            models.Account.login == login if login is not None else True,
        )
        selectSt = (
            select(
                [
                    models.Account.account_id,
                    models.Account.login,
                    models.Account.account_type,
                    models.Account.description,
                ]
            )
            .where(filters)
            # OFFSET/LIMIT without ORDER BY gives no guarantee that consecutive pages do not overlap or
            # skip rows, so a stable order is enforced
            # NOTE(review): account_id is assumed to be unique - confirm against the model
            .order_by(models.Account.account_id)
            .offset((page - 1) * pageSize)
            .limit(pageSize)
        )
        # NOTE(review): the count ignores `filters`; kept as is because callers may rely on the total
        countSt = select([func.count()]).select_from(models.Account)
        dbReply = await connection.fetchall(selectSt)
        count = await connection.scalar(countSt)
        accounts = [
            dict(account_id=row[0], login=row[1], account_type=row[2], description=row[3] or "") for row in dbReply
        ]
        return accounts, count
[docs]
@dbExceptionWrap
async def getAccount(self, accountId: str) -> Optional[dict]:
    """
    Find an account by its id.

    Args:
        accountId: id of the account
    Returns:
        dict with "account_id", "login", "account_type" and "description" keys (missing description
        becomes ""), or None if there is no such account
    """
    accountQuery = select([models.Account.login, models.Account.account_type, models.Account.description]).where(
        models.Account.account_id == accountId
    )
    async with DBContext.adaptor.connection(self.logger) as connection:
        accountRow = await connection.fetchone(accountQuery)
        if accountRow:
            return {
                "account_id": accountId,
                "login": accountRow[0],
                "account_type": accountRow[1],
                "description": accountRow[2] or "",
            }
        return None
[docs]
@dbExceptionWrap
async def patchAccount(self, accountId: str, accountOverride: AccountForPatch) -> bool:
    """
    Patch an account by its id.

    Args:
        accountId: id of the account
        accountOverride: fields to override; a given password is replaced in place (on this very object)
            with its pbkdf2_sha256 hash before the update
    Returns:
        True if the update statement reported a change, otherwise False
    Raises:
        VLException(Error.AccountWithLoginAlreadyExists.format(login), 409, False) if the new login is
            already used by another account
        VLException(Error.AccountTypeChangeForbidden.format(accountId), 403, False) on an attempt to change
            the type of an admin account to a non-admin one
    """
    async with DBContext.adaptor.connection(self.logger) as connection:
        # the uniqueness check is needed only when a new login is supplied; comparing the column with
        # None would turn into a pointless "login IS NULL" lookup
        if accountOverride.login is not None:
            checkDuplicateLoginSt = select([models.Account.account_id]).where(
                and_(models.Account.login == accountOverride.login, models.Account.account_id != accountId)
            )
            if await connection.scalar(checkDuplicateLoginSt):
                raise VLException(Error.AccountWithLoginAlreadyExists.format(accountOverride.login), 409, False)

        # an admin account must stay admin
        if accountOverride.accountType is not None and accountOverride.accountType != AccountType.admin:
            selectSt = select([models.Account.account_id]).where(
                and_(models.Account.account_id == accountId, models.Account.account_type == AccountType.admin.value)
            )
            if await connection.scalar(selectSt):
                raise VLException(Error.AccountTypeChangeForbidden.format(accountId), 403, False)

        if accountOverride.password is not None:
            accountOverride.password = pbkdf2_sha256.hash(accountOverride.password)

        updateSt = (
            update(models.Account).where(models.Account.account_id == accountId).values(**accountOverride.asDict())
        )
        return bool(await connection.execute(updateSt))
[docs]
@dbExceptionWrap
async def deleteAccount(self, accountId: str) -> bool:
    """
    Remove an account by its id.

    Args:
        accountId: id of the account
    Returns:
        True if the account was deleted, otherwise False
    """
    removeQuery = delete(models.Account).where(models.Account.account_id == accountId)
    async with DBContext.adaptor.connection(self.logger) as connection:
        removed = await connection.execute(removeQuery)
        return bool(removed)
[docs]
@dbExceptionWrap
async def verifyAccountByLoginPassword(self, login: str, password: str) -> tuple[str, str]:
    """
    Check a login/password pair and identify the account it belongs to.

    Args:
        login: account login
        password: plain password, verified against the stored pbkdf2_sha256 hash
    Returns:
        tuple of (account type, account id)
    Raises:
        VLException(Error.AccountLoginPasswordIncorrect, 401, isCriticalError=False) if there is no account
            with such login or the password does not match
    """
    credentialsQuery = select(
        [models.Account.password, models.Account.account_type, models.Account.account_id]
    ).where(models.Account.login == login)
    async with DBContext.adaptor.connection(self.logger) as connection:
        accountRow = await connection.fetchone(credentialsQuery)
        if accountRow:
            storedHash, accountType, accountId = accountRow
            if pbkdf2_sha256.verify(password, storedHash):
                return accountType, accountId
        # unknown login and wrong password are reported identically
        raise VLException(Error.AccountLoginPasswordIncorrect, 401, isCriticalError=False)
[docs]
@dbExceptionWrap
async def verifyAccountByAccountId(self, accountId: str) -> str:
    """
    Make sure an account with the given id exists.

    Args:
        accountId: id of the account
    Returns:
        account type
    Raises:
        VLException(Error.AccountNotFound.format(accountId), 401, isCriticalError=False) if the account
            is not found
    """
    typeQuery = select([models.Account.account_type]).where(models.Account.account_id == accountId)
    async with DBContext.adaptor.connection(self.logger) as connection:
        accountType = await connection.scalar(typeQuery)
        if accountType:
            return accountType
        raise VLException(Error.AccountNotFound.format(accountId), 401, isCriticalError=False)
[docs]
@dbExceptionWrap
async def verifyToken(self, token: str) -> Optional[tuple[str, dict]]:
    """
    Verify a JWT token and resolve the account it was issued for.

    Args:
        token: encoded JWT token
    Returns:
        tuple of (account type, token permissions as a dict)
    Raises:
        VLException(Error.CorruptedToken, 401, False) if the token cannot be decoded or its payload does not
            consist of exactly the expected keys
        VLException(Error.JWTTokenNotFound, 401, isCriticalError=False) if the token/account pair from the
            payload is not stored
        VLException(Error.TokenExpired, 401, False) if the stored expiration time has passed
    """
    try:
        payload = jwt.decode(token, key=JWT_SECRET_STRING, algorithms=["HS256"])
    except jwt.PyJWTError:
        raise VLException(Error.CorruptedToken, 401, False)

    expectedKeys = ["accountId", "expirationTime", "tokenId", "visibilityArea"]
    if sorted(payload) != expectedKeys:
        raise VLException(Error.CorruptedToken, 401, False)

    # the token row is matched with its owner account to get the account type in the same query
    tokenCondition = and_(
        models.Token.token_id == payload["tokenId"],
        models.Token.account_id == payload["accountId"],
        models.Account.account_id == models.Token.account_id,
    )
    tokenQuery = select(
        [
            models.Account.account_type,
            models.Token.expiration_time,
            models.Token.permissions,
        ]
    ).where(tokenCondition)

    async with DBContext.adaptor.connection(self.logger) as connection:
        tokenRow = await connection.fetchone(tokenQuery)
        if not tokenRow:
            raise VLException(Error.JWTTokenNotFound, 401, isCriticalError=False)
        accountType, expirationTime, rawPermissions = tokenRow
        # the expiration stored in the database is what counts; an empty value means "never expires"
        if expirationTime:
            now = getCurrentDatetime(self.storageTime == "UTC")
            if expirationTime <= now:
                raise VLException(Error.TokenExpired, 401, False)
        return accountType, Permissions.fromStr(rawPermissions).asDict()
[docs]
@dbExceptionWrap
async def createToken(self, token: Token, accountId: str) -> tuple[str, str]:
    """
    Create a new token for an account.

    Args:
        token: token to create
        accountId: id of the account the token belongs to
    Returns:
        tuple of (token id, encoded JWT token)
    Raises:
        VLException(Error.AccountNotFound.format(accountId), 400, False) if the specified account is not found
        VLException(Error.IncorrectAccountTokenPermissions.format("visibility_area", "user"), 400, False)
            if a token with visibility area "all" is requested for an account of the "user" type
    """
    async with DBContext.adaptor.connection(self.logger) as connection:
        accountTypeSelectSt = select([models.Account.account_type]).where(models.Account.account_id == accountId)
        accountType = await connection.scalar(accountTypeSelectSt)
        if accountType is None:
            # report a missing account right away instead of relying on the insert below to violate
            # a database constraint
            raise VLException(Error.AccountNotFound.format(accountId), 400, False)
        if accountType == AccountType.user.value and token.visibilityArea == "all":
            raise VLException(
                Error.IncorrectAccountTokenPermissions.format("visibility_area", AccountType.user.value), 400, False
            )
        tokenId = str(uuid4())
        insertSt = insert(models.Token).values(
            token_id=tokenId,
            account_id=accountId,
            permissions=token.permissions.asStr(),
            expiration_time=token.expirationTime,
            description=token.description,
            visibility_area=token.visibilityArea,
        )
        try:
            await connection.execute(insertSt)
        except (IntegrityError, UniqueViolationError, ForeignKeyViolationError):
            # still possible if the account disappears between the check above and the insert
            raise VLException(Error.AccountNotFound.format(accountId), 400, False)
        expirationTime = (
            convertTimeToString(token.expirationTime, self.storageTime == "UTC")
            if token.expirationTime is not None
            else None
        )
        # the payload carries exactly the keys that verifyToken expects
        return tokenId, jwt.encode(
            dict(
                tokenId=tokenId,
                expirationTime=expirationTime,
                accountId=accountId,
                visibilityArea=token.visibilityArea,
            ),
            key=JWT_SECRET_STRING,
        )
[docs]
def makeOutputToken(self, rowFromDB: tuple) -> dict:
    """
    Build the output token dict from a database row.

    Args:
        rowFromDB: row with token id, permissions string, account id, expiration time, description and
            visibility area at positions 0..5
    Returns:
        dict with "token_id", "token" (encoded JWT), "permissions", "expiration_time", "description",
        "account_id" and "visibility_area" keys
    """
    tokenId = rowFromDB[0]
    permissions = Permissions.fromStr(rowFromDB[1]).asDict()
    accountId = rowFromDB[2]
    visibilityArea = rowFromDB[5]

    expirationTime = None
    if rowFromDB[3] is not None:
        expirationTime = convertTimeToString(rowFromDB[3], self.storageTime == "UTC")

    # the JWT is not stored, it is re-created from the row on every call
    jwtPayload = {
        "tokenId": tokenId,
        "expirationTime": expirationTime,
        "accountId": accountId,
        "visibilityArea": visibilityArea,
    }
    encodedToken = jwt.encode(jwtPayload, key=JWT_SECRET_STRING)

    return {
        "token_id": tokenId,
        "token": encodedToken,
        "permissions": permissions,
        "expiration_time": expirationTime,
        "description": rowFromDB[4] or "",
        "account_id": accountId,
        "visibility_area": visibilityArea,
    }
[docs]
@dbExceptionWrap
async def getTokens(self, page: int, pageSize: int, accountId: Optional[str] = None) -> list[dict]:
    """
    Get one page of tokens, optionally restricted to a single account.

    Args:
        page: page number, the first page is 1
        pageSize: maximum number of tokens on a page
        accountId: if given, only tokens of this account are selected
    Returns:
        list of tokens in the format produced by makeOutputToken
    """
    async with DBContext.adaptor.connection(self.logger) as connection:
        selectSt = (
            select(
                [
                    models.Token.token_id,
                    models.Token.permissions,
                    models.Token.account_id,
                    models.Token.expiration_time,
                    models.Token.description,
                    models.Token.visibility_area,
                ]
            )
            # without an account id the constant True leaves the selection unrestricted
            .where(models.Token.account_id == accountId if accountId is not None else True)
            # OFFSET/LIMIT without ORDER BY gives no guarantee that consecutive pages do not overlap or
            # skip rows, so a stable order is enforced
            # NOTE(review): token_id is assumed to be unique - confirm against the model
            .order_by(models.Token.token_id)
            .offset((page - 1) * pageSize)
            .limit(pageSize)
        )
        dbReply = await connection.fetchall(selectSt)
        return [self.makeOutputToken(row) for row in dbReply]
@staticmethod
def _getTokenFilters(tokenId: str, accountId: Optional[str] = None) -> sql.elements.BinaryExpression:
    """
    Build the filter selecting a token by its id and, when an account id is given, by its owner.

    Args:
        tokenId: token id
        accountId: account id; no owner restriction is applied when it is None
    Returns:
        conjunction of the owner condition and the token id condition
    """
    ownerCondition = True if accountId is None else models.Token.account_id == accountId
    tokenCondition = models.Token.token_id == tokenId
    return and_(ownerCondition, tokenCondition)
[docs]
@dbExceptionWrap
async def getToken(self, tokenId: str, accountId: Optional[str] = None) -> Optional[dict]:
    """
    Get a token by its id.

    Args:
        tokenId: token id
        accountId: if given, the token must belong to this account
    Returns:
        token in the format produced by makeOutputToken
    Raises:
        VLException(Error.TokenNotFoundById.format(tokenId), 404, False) if the token is not found
    """
    tokenColumns = [
        models.Token.token_id,
        models.Token.permissions,
        models.Token.account_id,
        models.Token.expiration_time,
        models.Token.description,
        models.Token.visibility_area,
    ]
    tokenQuery = select(tokenColumns).where(self._getTokenFilters(tokenId=tokenId, accountId=accountId))
    async with DBContext.adaptor.connection(self.logger) as connection:
        tokenRow = await connection.fetchone(tokenQuery)
        if tokenRow:
            return self.makeOutputToken(tokenRow)
        raise VLException(Error.TokenNotFoundById.format(tokenId), 404, False)
[docs]
@dbExceptionWrap
async def deleteToken(self, tokenId: str, accountId: str) -> bool:
    """
    Remove a token by its id.

    Args:
        tokenId: token id
        accountId: id of the account that must own the token
    Returns:
        True if the token was removed, False if the token was not found
    """
    async with DBContext.adaptor.connection(self.logger) as connection:
        tokenFilters = self._getTokenFilters(tokenId=tokenId, accountId=accountId)
        if self.dbType != "oracle":
            removeQuery = delete(models.Token).where(tokenFilters)
        else:
            # required because oracle does not support multiple-table criteria within delete:
            # the token is looked up with the full filter first and then deleted by its id alone
            lookupQuery = select([models.Token.token_id]).where(tokenFilters)
            foundTokenId = await connection.scalar(lookupQuery)
            if not foundTokenId:
                return False
            removeQuery = delete(models.Token).where(models.Token.token_id == tokenId)
        removed = await connection.execute(removeQuery)
        return bool(removed)
[docs]
@dbExceptionWrap
async def replaceToken(self, tokenId: str, token: Token, accountId: Optional[str] = None) -> str:
    """
    Replace the contents of an existing token, keeping its id.

    Args:
        tokenId: id of the token to replace
        token: new token data
        accountId: id of the account owning the token; when it is not given the owner is looked up by
            the token id
    Returns:
        encoded JWT token built from the new data
    Raises:
        VLException(Error.TokenNotFoundById.format(tokenId), 404, False) if the token is not found
        VLException(Error.IncorrectAccountTokenPermissions.format("visibility_area", "user"), 400, False)
            if visibility area "all" is requested for a token of an account of the "user" type
    """
    async with DBContext.adaptor.connection(self.logger) as connection:
        if accountId is None:
            # owner is unknown: resolve both the owner id and its type through the token row
            ownerQuery = select([models.Account.account_id, models.Account.account_type]).where(
                and_(models.Account.account_id == models.Token.account_id, models.Token.token_id == tokenId)
            )
            ownerRow = await connection.fetchone(ownerQuery)
            if not ownerRow:
                raise VLException(Error.TokenNotFoundById.format(tokenId), 404, False)
            accountId, accountType = ownerRow
        else:
            ownerTypeQuery = select([models.Account.account_type]).where(models.Account.account_id == accountId)
            accountType = await connection.scalar(ownerTypeQuery)

        isUserAccount = accountType == AccountType.user.value
        if isUserAccount and token.visibilityArea == "all":
            raise VLException(
                Error.IncorrectAccountTokenPermissions.format("visibility_area", AccountType.user.value), 400, False
            )

        newValues = {
            "token_id": tokenId,
            "permissions": token.permissions.asStr(),
            "expiration_time": token.expirationTime,
            "description": token.description,
            "visibility_area": token.visibilityArea,
        }
        ownedTokenCondition = and_(models.Token.token_id == tokenId, models.Token.account_id == accountId)
        replaceQuery = update(models.Token).values(**newValues).where(ownedTokenCondition)
        replaced = await connection.execute(replaceQuery)
        if not replaced:
            raise VLException(Error.TokenNotFoundById.format(tokenId), 404, False)

        expirationTime = None
        if token.expirationTime is not None:
            expirationTime = convertTimeToString(token.expirationTime, self.storageTime == "UTC")
        jwtPayload = {
            "tokenId": tokenId,
            "expirationTime": expirationTime,
            "accountId": accountId,
            "visibilityArea": token.visibilityArea,
        }
        return jwt.encode(jwtPayload, key=JWT_SECRET_STRING)