from datetime import datetime
from typing import Tuple, Dict, Any, Optional, List, Iterable, Union
from uuid import uuid4

import ujson
from sqlalchemy import insert, and_, update, delete, func, select
from sqlalchemy.engine import RowProxy
from sqlalchemy.orm import aliased
from sqlalchemy.sql.base import ImmutableColumnCollection
from sqlalchemy.sql.expression import Select
from vlutils.helpers import convertTimeToString

from app.global_vars.enums import HandlerTarget
from crutches_on_wheels.db.base_context import BaseDBContext
from crutches_on_wheels.errors.errors import Error
from crutches_on_wheels.errors.exception import VLException
from crutches_on_wheels.utils.db_functions import dbExceptionWrap as exceptionWrap
from db import models

class DBContext(BaseDBContext):
    """Handlers DB context: storage operations for verifiers and handlers."""

    def loadHandlerFromRow(self, handlerRow: RowProxy, selectColumns: ImmutableColumnCollection) -> dict[str, Any]:
        """
        Load handler as dict from raw row.

        Only the columns listed in ``selectColumns`` are loaded; every column key must be a ``HandlerTarget`` value.

        Args:
            handlerRow: full row from db
            selectColumns: columns from sql query
        Returns:
            handler as dict keyed by column name
        """
        isUTC = self.storageTime == "UTC"
        result = {}
        for column in selectColumns:
            key = column.key
            target = HandlerTarget(key)
            value = handlerRow[key]
            if target == HandlerTarget.description:
                # an empty/NULL description is exposed as an empty string
                result[key] = value or ""
            elif target == HandlerTarget.policies:
                # policies are stored as a JSON string, NULL for dynamic handlers (see _dumpHandlerPolicies)
                result[key] = ujson.loads(value) if value is not None else None
            elif target in (HandlerTarget.createTime, HandlerTarget.lastUpdateTime):
                result[key] = convertTimeToString(value, isUTC)
            else:
                result[key] = value
        return result

    def loadVerifierFromRow(self, verifierRow: Tuple) -> Dict[str, Any]:
        """
        Load verifier as dict from raw row.

        Args:
            verifierRow: full row from db, values ordered as ``models.Verifier.getColumnNames()``
        Returns:
            verifier as dict keyed by column name
        """
        isUTC = self.storageTime == "UTC"
        verifier = dict(zip(models.Verifier.getColumnNames(), verifierRow))
        if verifier["description"] is None:
            verifier["description"] = ""
        # policies are stored as a JSON string (see createVerifier / putVerifier)
        verifier["policies"] = ujson.loads(verifier["policies"]) if verifier["policies"] is not None else None
        verifier["create_time"] = convertTimeToString(verifier["create_time"], isUTC)
        verifier["last_update_time"] = convertTimeToString(verifier["last_update_time"], isUTC)
        return verifier

    @exceptionWrap
    async def createVerifier(self, policies: dict, accountId: str, description: str = "") -> str:
        """
        Create new verifier.

        Args:
            policies: verifier policies (stored JSON-serialized)
            accountId: account id
            description: user verifier description
        Returns:
            id of the created verifier (a freshly generated uuid4 string)
        """
        verifierId = str(uuid4())
        insertSt = insert(models.Verifier).values(
            description=description, account_id=accountId, policies=ujson.dumps(policies), verifier_id=verifierId
        )
        async with DBContext.adaptor.connection(self.logger) as connection:
            await connection.execute(insertSt)
        return verifierId

    @exceptionWrap
    async def getVerifier(self, verifierId: str, accountId: Optional[str] = None) -> Dict:
        """
        Get verifier by id.

        Args:
            verifierId: verifier id
            accountId: verifier account id; when None the account is not checked
        Returns:
            deserialized dict with verifier
        Raises:
            VLException(Error.VerifierNotFound.format(verifierId), 404, isCriticalError=False): if verifier not found
        """
        query = select([models.Verifier]).where(
            and_(
                models.Verifier.verifier_id == verifierId,
                models.Verifier.account_id == accountId if accountId is not None else True,
            )
        )
        async with DBContext.adaptor.connection(self.logger) as connection:
            verifierRow = await connection.fetchone(query)
        if verifierRow is None:
            raise VLException(Error.VerifierNotFound.format(verifierId), 404, isCriticalError=False)
        return self.loadVerifierFromRow(verifierRow)

    @exceptionWrap
    async def putVerifier(
        self, verifierId: str, policies: dict, accountId: str, description: str = ""
    ) -> Optional[int]:
        """
        Replace verifier by id.

        Description and policies are overwritten and the verifier version is incremented by one.

        Args:
            verifierId: verifier id
            policies: verifier policies
            accountId: account id
            description: user verifier description
        Warnings:
            function does not create verifier!
        Returns:
            verifier new version if verifier replaced; None is expected when no verifier matched
        """
        updateSt = (
            update(models.Verifier)
            .where(and_(models.Verifier.verifier_id == verifierId, models.Verifier.account_id == accountId))
            .values(description=description, policies=ujson.dumps(policies), version=models.Verifier.version + 1)
            .returning(models.Verifier.version)
        )
        async with DBContext.adaptor.connection(self.logger) as connection:
            return await connection.scalar(updateSt)

    @exceptionWrap
    async def deleteVerifier(self, verifierId: str, accountId: Optional[str] = None) -> bool:
        """
        Delete verifier by id.

        Args:
            verifierId: verifier id
            accountId: account id; when None the account is not checked
        Returns:
            True if verifier deleted, otherwise False
        """
        deleteSt = delete(models.Verifier).where(
            and_(
                models.Verifier.verifier_id == verifierId,
                models.Verifier.account_id == accountId if accountId is not None else True,
            )
        )
        async with DBContext.adaptor.connection(self.logger) as connection:
            return bool(await connection.execute(deleteSt))

    @exceptionWrap
    async def checkVerifier(self, verifierId: str, accountId: Optional[str] = None) -> bool:
        """
        Check verifier existence.

        Args:
            verifierId: verifier id
            accountId: account id; when None the account is not checked
        Returns:
            True - if verifier exists, otherwise False
        """
        query = select([models.Verifier.verifier_id]).where(
            and_(
                models.Verifier.verifier_id == verifierId,
                models.Verifier.account_id == accountId if accountId is not None else True,
            )
        )
        async with DBContext.adaptor.connection(self.logger) as connection:
            verifierRow = await connection.fetchone(query)
        return verifierRow is not None

    @staticmethod
    def _genGetVerifiersQuery(accountId: Optional[str] = None, description: Optional[str] = None) -> Select:
        """
        Generate "select" statement for all verifier columns, filtered by the given (optional) filters.

        Args:
            accountId: verifier account id
            description: substring to look for in the verifier description
        Returns:
            generated "select" statement
        """
        # NOTE(review): "%" and "_" inside ``description`` are not escaped, so they act as LIKE wildcards;
        # confirm this is intended (ColumnOperators.contains(..., autoescape=True) would escape them).
        return select([models.Verifier]).where(
            and_(
                models.Verifier.description.like("%{}%".format(description)) if description is not None else True,
                models.Verifier.account_id == accountId if accountId is not None else True,
            )
        )

    @exceptionWrap
    async def getVerifierCount(self, accountId: Optional[str] = None, description: Optional[str] = None) -> int:
        """
        Get verifier count.

        Args:
            accountId: account id
            description: substring to look for in the verifier description
        Returns:
            verifier count
        """
        filteredSt = self._genGetVerifiersQuery(accountId=accountId, description=description)
        countSt = select([func.count()]).select_from(aliased(filteredSt))
        async with DBContext.adaptor.connection(self.logger) as connection:
            return await connection.scalar(countSt)

    @exceptionWrap
    async def getVerifiers(
        self, accountId: Optional[str] = None, description: Optional[str] = None, page: int = 1, pageSize: int = 100,
    ) -> List[dict]:
        """
        Get verifiers by filters, newest first.

        Args:
            accountId: account id
            description: substring to look for in the verifier description
            page: page (1-based)
            pageSize: page size
        Returns:
            list of verifiers
        """
        selectSt = (
            self._genGetVerifiersQuery(accountId=accountId, description=description)
            .order_by(models.Verifier.create_time.desc())
            .offset((page - 1) * pageSize)
            .limit(pageSize)
        )
        async with DBContext.adaptor.connection(self.logger) as connection:
            verifierRows = await connection.fetchall(selectSt)
        return [self.loadVerifierFromRow(verifierRow) for verifierRow in verifierRows]

    @staticmethod
    def _dumpHandlerPolicies(policies: dict, isDynamic: bool) -> Optional[str]:
        """
        Serialize handler policies for storage.

        Args:
            policies: handler policies
            isDynamic: dynamic handler flag
        Returns:
            None for dynamic handlers (their policies are not stored), otherwise JSON-serialized policies
        """
        return None if isDynamic else ujson.dumps(policies)

    @exceptionWrap
    async def createHandler(
        self, policies: dict, accountId: str, description: str = "", isDynamic: bool = False
    ) -> str:
        """
        Create new handler.

        Args:
            policies: set handler policies (not stored for dynamic handlers)
            accountId: account id
            description: user handler description
            isDynamic: dynamic handler flag
        Returns:
            id of the created handler (a freshly generated uuid4 string)
        """
        handlerId = str(uuid4())
        insertSt = insert(models.Handler).values(
            description=description,
            account_id=accountId,
            policies=self._dumpHandlerPolicies(policies, isDynamic),
            handler_id=handlerId,
            is_dynamic=isDynamic,
        )
        async with DBContext.adaptor.connection(self.logger) as connection:
            await connection.execute(insertSt)
        return handlerId

    @exceptionWrap
    async def getHandler(
        self,
        handlerId: str,
        targets: Iterable[HandlerTarget] = tuple(HandlerTarget),
        accountId: Optional[str] = None,
    ) -> dict:
        """
        Get handler by id.

        Args:
            handlerId: handler id
            targets: handler targets (columns) to load
            accountId: handler account id; when None the account is not checked
        Returns:
            deserialized dict with handler
        Raises:
            VLException(Error.HandlerNotFound.format(handlerId), 404, isCriticalError=False): if handler not found
        """
        selectSt = self._genGetHandlersQuery(targets=targets, handlerIds=[handlerId], accountId=accountId)
        async with DBContext.adaptor.connection(self.logger) as connection:
            handlerRow = await connection.fetchone(selectSt)
        if handlerRow is None:
            raise VLException(Error.HandlerNotFound.format(handlerId), 404, isCriticalError=False)
        return self.loadHandlerFromRow(handlerRow, selectSt.columns)

    @exceptionWrap
    async def getAbsentHandlersIds(self, handlerIds: List[str]) -> List[str]:
        """
        Given list of handlerIds, return those handlerIds that are absent from the database.

        Used for removal of items from cache by cache invalidator.

        Args:
            handlerIds: handler ids
        Returns:
            handler ids missing from the database, without duplicates, in input order
        """
        if not handlerIds:
            # nothing to look up; skip the round trip with an empty IN clause
            return []
        selectSt = self._genGetHandlersQuery(targets=[HandlerTarget.handlerId], handlerIds=handlerIds)
        async with DBContext.adaptor.connection(self.logger) as connection:
            rows = await connection.fetchall(selectSt)
        existingIds = {row["handler_id"] for row in rows}
        # dict.fromkeys de-duplicates while keeping the caller's order (a set difference would not)
        return list(dict.fromkeys(handlerId for handlerId in handlerIds if handlerId not in existingIds))

    @exceptionWrap
    async def getUpdatedHandlers(
        self, handlerIds: List[str], lastUpdateTimeGte: datetime, targets: Optional[list[HandlerTarget]] = None
    ) -> List[Dict]:
        """
        Get updated handlers by ids. Used for cache invalidation.

        Args:
            handlerIds: handler ids
            lastUpdateTimeGte: lower bound (inclusive) of handler update time
            targets: handler targets; all targets when None
        Returns:
            List with handlers
        """
        if handlerIds is not None and not handlerIds:
            # an empty IN clause can never match; skip the round trip
            return []
        selectSt = self._genGetHandlersQuery(
            handlerIds=handlerIds, lastUpdateTimeGte=lastUpdateTimeGte, targets=targets
        )
        async with DBContext.adaptor.connection(self.logger) as connection:
            handlerRows = await connection.fetchall(selectSt)
        return [self.loadHandlerFromRow(handlerRow, selectSt.columns) for handlerRow in handlerRows]

    @exceptionWrap
    async def putHandler(
        self, handlerId: str, policies: dict, accountId: str, description: str = "", isDynamic: bool = False
    ) -> bool:
        """
        Put handler instead old handler.

        Also refreshes the handler last update time.

        Args:
            handlerId: handler id
            policies: set handler policies (not stored for dynamic handlers)
            accountId: account id
            description: user handler description
            isDynamic: dynamic handler flag
        Warnings:
            function does not create handler!
        Returns:
            True if handler exist otherwise false
        """
        updateSt = (
            update(models.Handler)
            .where(and_(models.Handler.handler_id == handlerId, models.Handler.account_id == accountId))
            .values(
                description=description,
                policies=self._dumpHandlerPolicies(policies, isDynamic),
                is_dynamic=isDynamic,
                last_update_time=self.currentDBTimestamp,
            )
        )
        async with DBContext.adaptor.connection(self.logger) as connection:
            return bool(await connection.execute(updateSt))

    @exceptionWrap
    async def deleteHandler(self, handlerId: str, accountId: Optional[str] = None) -> bool:
        """
        Delete handler by id.

        Args:
            handlerId: handler id
            accountId: account id of the handler; when None the account is not checked
        Returns:
            True if handler exist otherwise false
        """
        deleteSt = delete(models.Handler).where(
            and_(
                models.Handler.handler_id == handlerId,
                models.Handler.account_id == accountId if accountId is not None else True,
            )
        )
        async with DBContext.adaptor.connection(self.logger) as connection:
            return bool(await connection.execute(deleteSt))

    @exceptionWrap
    async def doesHandlerExist(self, handlerId: str, accountId: Optional[str] = None) -> bool:
        """
        Check existence of a handler with id=handlerId (optionally within an account).

        Args:
            handlerId: handler id
            accountId: handler account id; when None the account is not checked
        Returns:
            True - if handler exists otherwise False
        """
        query = select([models.Handler.handler_id]).where(
            and_(
                models.Handler.handler_id == handlerId,
                models.Handler.account_id == accountId if accountId is not None else True,
            )
        )
        async with DBContext.adaptor.connection(self.logger) as connection:
            handlerRow = await connection.fetchone(query)
        return handlerRow is not None

    @staticmethod
    def _genGetHandlersQuery(
        targets: Optional[Iterable[HandlerTarget]] = tuple(HandlerTarget),
        accountId: Optional[str] = None,
        description: Optional[str] = None,
        isDynamic: Optional[bool] = None,
        handlerIds: Optional[list[str]] = None,
        lastUpdateTimeGte: Optional[datetime] = None,
    ) -> Select:
        """
        Generate "select" statement for handlers matching the given (optional) filters.

        Args:
            targets: handler targets (columns) to select; all targets when None
            accountId: handler account id
            description: substring to look for in the handler description
            isDynamic: whether to get only dynamic (non-dynamic) handlers
            handlerIds: handlers ids
            lastUpdateTimeGte: lower bound (inclusive) of handler update time
        Returns:
            generated "select" statement
        """
        if targets is None:
            # callers such as getUpdatedHandlers pass None to mean "all targets"
            targets = tuple(HandlerTarget)
        columns = [getattr(models.Handler, target.value) for target in targets]
        # NOTE(review): "%" and "_" inside ``description`` are not escaped, so they act as LIKE wildcards;
        # confirm this is intended (ColumnOperators.contains(..., autoescape=True) would escape them).
        return select(columns).where(
            and_(
                models.Handler.handler_id.in_(handlerIds) if handlerIds is not None else True,
                models.Handler.description.like("%{}%".format(description)) if description is not None else True,
                models.Handler.account_id == accountId if accountId is not None else True,
                models.Handler.is_dynamic == isDynamic if isDynamic is not None else True,
                models.Handler.last_update_time >= lastUpdateTimeGte if lastUpdateTimeGte is not None else True,
            )
        )

    @exceptionWrap
    async def getHandlers(
        self,
        accountId: Optional[str] = None,
        description: Optional[str] = None,
        isDynamic: Optional[bool] = None,
        page: int = 1,
        pageSize: int = 100,
    ) -> List[dict]:
        """
        Get handlers by filters, newest first.

        Args:
            accountId: handler account id
            description: substring to look for in the handler description
            isDynamic: whether to get only dynamic (non-dynamic) handlers
            page: page (1-based)
            pageSize: page size
        Returns:
            list of deserialized handlers
        """
        selectSt = (
            self._genGetHandlersQuery(accountId=accountId, description=description, isDynamic=isDynamic)
            .order_by(models.Handler.create_time.desc())
            .offset((page - 1) * pageSize)
            .limit(pageSize)
        )
        async with DBContext.adaptor.connection(self.logger) as connection:
            handlerRows = await connection.fetchall(selectSt)
        return [self.loadHandlerFromRow(handlerRow, selectSt.columns) for handlerRow in handlerRows]

    @exceptionWrap
    async def getHandlerCount(
        self, accountId: Optional[str] = None, description: Optional[str] = None, isDynamic: Optional[bool] = None
    ) -> int:
        """
        Get handler count.

        Args:
            accountId: handler account id
            description: substring to look for in the handler description
            isDynamic: whether to count only dynamic (non-dynamic) handlers
        Returns:
            handler count
        """
        filteredSt = self._genGetHandlersQuery(accountId=accountId, description=description, isDynamic=isDynamic)
        countSt = select([func.count()]).select_from(aliased(filteredSt))
        async with DBContext.adaptor.connection(self.logger) as connection:
            return await connection.scalar(countSt)