"""Source code for luna_handlers.db.context: database context for handlers and verifiers."""

from datetime import datetime
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from uuid import uuid4

import ujson
from sqlalchemy import and_, delete, func, insert, select, update
from sqlalchemy.engine import RowProxy
from sqlalchemy.orm import aliased
from sqlalchemy.sql.base import ImmutableColumnCollection
from vlutils.helpers import convertTimeToString

from app.global_vars.enums import HandlerTarget
from crutches_on_wheels.cow.db.base_context import BaseDBContext
from crutches_on_wheels.cow.enums.handlers import HandlerType
from crutches_on_wheels.cow.errors.errors import Error
from crutches_on_wheels.cow.errors.exception import VLException
from crutches_on_wheels.cow.utils.db_functions import dbExceptionWrap as exceptionWrap
from db.handlers_db_tools.models import handlers_models as models


class DBContext(BaseDBContext):
    """
    Handlers DB context.

    Data-access layer for handlers (``models.Handler``) and verifiers (``models.Verifier``). Query methods open a
    connection via ``DBContext.adaptor.connection(self.logger)`` and are decorated with ``exceptionWrap``
    (``dbExceptionWrap`` from ``crutches_on_wheels``).

    Optional filters are combined with ``and_(...)``; a filter that is not set is replaced by the literal ``True``
    so that it does not restrict the query.
    """

    def loadHandlerFromRow(self, handlerRow: RowProxy, selectColumns: ImmutableColumnCollection) -> dict[str, Any]:
        """
        Load handler as dict from raw row.

        Only the columns present in ``selectColumns`` are loaded, so the result contains exactly the requested
        targets (keys are the snake_case column names).

        Args:
            handlerRow: full row from db
            selectColumns: columns from sql query

        Returns:
            handler as dict:
            - ``description`` NULL is returned as ""
            - ``policies`` is JSON-decoded (NULL stays None)
            - ``create_time`` / ``last_update_time`` are converted to strings
            - ``lambda_id`` is omitted entirely when it is NULL
        """
        result = {}
        for column in selectColumns:
            snakeCaseTarget = column.key
            target = HandlerTarget(snakeCaseTarget)
            value = handlerRow[snakeCaseTarget]
            if target == HandlerTarget.description:
                result[snakeCaseTarget] = value or ""
            elif target == HandlerTarget.policies:
                result[snakeCaseTarget] = ujson.loads(value) if value is not None else None
            elif target in (HandlerTarget.createTime, HandlerTarget.lastUpdateTime):
                result[snakeCaseTarget] = convertTimeToString(value, self.storageTime == "UTC")
            elif target == HandlerTarget.lambdaId:
                # handlers without a lambda do not expose the key at all
                if value is None:
                    continue
                result[snakeCaseTarget] = value
            else:
                result[snakeCaseTarget] = value
        return result

    def loadVerifierFromRow(self, verifierRow: Tuple) -> Dict[str, Any]:
        """
        Load verifier as dict from raw row.

        Args:
            verifierRow: full row from db

        Returns:
            verifier as dict:
            - ``description`` NULL is returned as ""
            - ``policies`` is JSON-decoded (NULL stays None)
            - time columns are converted to strings
        """
        # NOTE(review): assumes the row values come in models.Verifier.getColumnNames() order,
        # i.e. the row was produced by select([models.Verifier]) - TODO confirm for new callers.
        verifier = dict(zip(models.Verifier.getColumnNames(), verifierRow))
        if verifier["description"] is None:
            verifier["description"] = ""
        verifier["policies"] = ujson.loads(verifier["policies"]) if verifier["policies"] is not None else None
        verifier["create_time"] = convertTimeToString(verifier["create_time"], self.storageTime == "UTC")
        verifier["last_update_time"] = convertTimeToString(verifier["last_update_time"], self.storageTime == "UTC")
        return verifier

    @exceptionWrap
    async def createVerifier(self, policies: dict, accountId: str, description: str = "") -> str:
        """
        Create new verifier.

        Args:
            policies: verifier policies (stored JSON-encoded)
            accountId: account id
            description: user verifier description

        Returns:
            verifier id (a freshly generated uuid4 string)
        """
        async with DBContext.adaptor.connection(self.logger) as connection:
            verifierId = str(uuid4())
            insertSt = insert(models.Verifier).values(
                description=description, account_id=accountId, policies=ujson.dumps(policies), verifier_id=verifierId
            )
            await connection.execute(insertSt)
            return verifierId

    @exceptionWrap
    async def getVerifier(self, verifierId: str, accountId: Optional[str] = None) -> Dict:
        """
        Get verifier by id.

        Args:
            verifierId: verifier id
            accountId: verifier account id; None disables the account filter

        Returns:
            deserialized dict with verifier (see ``loadVerifierFromRow``)

        Raises:
            VLException(Error.VerifierNotFound.format(verifierId), 404, isCriticalError=False): if verifier not found
        """
        async with DBContext.adaptor.connection(self.logger) as connection:
            query = select([models.Verifier]).where(
                and_(
                    models.Verifier.verifier_id == verifierId,
                    models.Verifier.account_id == accountId if accountId is not None else True,
                )
            )
            verifierRow = await connection.fetchone(query)
            if verifierRow is None:
                raise VLException(Error.VerifierNotFound.format(verifierId), 404, isCriticalError=False)
            return self.loadVerifierFromRow(verifierRow)

    @exceptionWrap
    async def putVerifier(
        self, verifierId: str, policies: dict, accountId: str, description: str = ""
    ) -> Optional[int]:
        """
        Replace verifier by id.

        Overwrites description and policies and increments the stored ``version``.

        Args:
            verifierId: verifier id
            policies: verifier policies (stored JSON-encoded)
            accountId: account id (must match the stored one)
            description: user verifier description

        Warnings:
            function does not create verifier!

        Returns:
            the new verifier version if the verifier was replaced (presumably None when no row matched -
            depends on the adaptor's ``scalar``)
        """
        async with DBContext.adaptor.connection(self.logger) as connection:
            updateSt = (
                update(models.Verifier)
                .where(and_(models.Verifier.verifier_id == verifierId, models.Verifier.account_id == accountId))
                .values(description=description, policies=ujson.dumps(policies), version=models.Verifier.version + 1)
            ).returning(models.Verifier.version)
            return await connection.scalar(updateSt)

    @exceptionWrap
    async def deleteVerifier(self, verifierId: str, accountId: Optional[str] = None) -> bool:
        """
        Delete verifier by id.

        Args:
            verifierId: verifier id
            accountId: account id; None disables the account filter

        Returns:
            True if verifier deleted, otherwise False
        """
        async with DBContext.adaptor.connection(self.logger) as connection:
            deleteSt = delete(models.Verifier).where(
                and_(
                    models.Verifier.verifier_id == verifierId,
                    models.Verifier.account_id == accountId if accountId is not None else True,
                )
            )
            # assumes the adaptor's execute() returns the affected row count - TODO confirm
            return bool(await connection.execute(deleteSt))

    @exceptionWrap
    async def checkVerifier(self, verifierId: str, accountId: Optional[str] = None) -> bool:
        """
        Check verifier existence.

        Args:
            verifierId: verifier id
            accountId: account id; None disables the account filter

        Returns:
            True - if verifier exists, otherwise False
        """
        async with DBContext.adaptor.connection(self.logger) as connection:
            query = select([models.Verifier.verifier_id]).where(
                and_(
                    models.Verifier.verifier_id == verifierId,
                    models.Verifier.account_id == accountId if accountId is not None else True,
                )
            )
            verifierRow = await connection.fetchone(query)
            return verifierRow is not None

    @exceptionWrap
    async def getVerifierCount(self, accountId: Optional[str] = None, description: Optional[str] = None) -> int:
        """
        Get verifier count.

        Args:
            accountId: account id; None disables the account filter
            description: substring of the verifier description (matched literally, LIKE wildcards are escaped);
                None disables the description filter

        Returns:
            verifier count
        """
        async with DBContext.adaptor.connection(self.logger) as connection:
            selectSt = select([models.Verifier]).where(
                and_(
                    self._descriptionFilter(models.Verifier.description, description),
                    models.Verifier.account_id == accountId if accountId is not None else True,
                )
            )
            selectSt = select([func.count()]).select_from(aliased(selectSt))
            count = await connection.scalar(selectSt)
            return count

    @exceptionWrap
    async def getVerifiers(
        self,
        accountId: Optional[str] = None,
        description: Optional[str] = None,
        page: int = 1,
        pageSize: int = 100,
    ) -> List[dict]:
        """
        Get verifiers by filters.

        Args:
            accountId: account id; None disables the account filter
            description: substring of the verifier description (matched literally, LIKE wildcards are escaped);
                None disables the description filter
            page: page number, 1-based
            pageSize: page size

        Returns:
            list of verifiers, newest (by create_time) first
        """
        async with DBContext.adaptor.connection(self.logger) as connection:
            selectSt = select([models.Verifier]).where(
                and_(
                    self._descriptionFilter(models.Verifier.description, description),
                    models.Verifier.account_id == accountId if accountId is not None else True,
                )
            )
            selectSt = (
                selectSt.order_by(models.Verifier.create_time.desc()).offset((page - 1) * pageSize).limit(pageSize)
            )
            verifierRows = await connection.fetchall(selectSt)
            return [self.loadVerifierFromRow(verifierRow) for verifierRow in verifierRows]

    @exceptionWrap
    async def createHandler(
        self,
        policies: dict,
        accountId: str,
        description: str = "",
        handlerType: HandlerType = HandlerType.static,
        lambdaId: str | None = None,
    ) -> str:
        """
        Create new handler.

        Args:
            policies: set handler policies; ignored (stored as NULL) for dynamic and lambda handlers
            accountId: account id
            description: user handler description
            handlerType: handler type
            lambdaId: id of lambda handler

        Returns:
            handler id (a freshly generated uuid4 string)
        """
        # dynamic and lambda handlers do not keep policies in the database
        if handlerType in (HandlerType.dynamic, HandlerType.lambdaHandler):
            rulesStr = None
        else:
            rulesStr = ujson.dumps(policies)
        async with DBContext.adaptor.connection(self.logger) as connection:
            handlerId = str(uuid4())
            insertSt = insert(models.Handler).values(
                description=description,
                account_id=accountId,
                policies=rulesStr,
                handler_id=handlerId,
                handler_type=handlerType.value,
                lambda_id=str(lambdaId) if lambdaId is not None else None,
            )
            await connection.execute(insertSt)
            return handlerId

    @exceptionWrap
    async def getHandler(
        self,
        handlerId: str,
        targets: Iterable[HandlerTarget] = tuple(HandlerTarget),
        accountId: Optional[str] = None,
    ) -> dict:
        """
        Get handler by id.

        Args:
            handlerId: handler id
            targets: handler targets (columns) to load; all targets by default
            accountId: handler account id; None disables the account filter

        Returns:
            deserialized dict with handler (see ``loadHandlerFromRow``)

        Raises:
            VLException(Error.HandlerNotFound.format(handlerId), 404, isCriticalError=False): if handler not found
        """
        selectSt = self._genGetHandlersQuery(targets=targets, handlerIds=[handlerId], accountId=accountId)
        async with DBContext.adaptor.connection(self.logger) as connection:
            handlerRow = await connection.fetchone(selectSt)
            if handlerRow is None:
                raise VLException(Error.HandlerNotFound.format(handlerId), 404, isCriticalError=False)
            return self.loadHandlerFromRow(handlerRow, selectSt.columns)

    @exceptionWrap
    async def getAbsentHandlersIds(self, handlerIds: List[str]) -> List[str]:
        """
        Given list of handlerIds, return those handlerIds that are not present in the database.

        Used for removal of items from cache by cache invalidator.

        Args:
            handlerIds: handler ids

        Returns:
            List with handler ids that were removed (no particular order, duplicates collapsed).
        """
        selectSt = self._genGetHandlersQuery(targets=[HandlerTarget.handlerId], handlerIds=handlerIds)
        async with DBContext.adaptor.connection(self.logger) as connection:
            result = await connection.fetchall(selectSt)
            existingIds = {row["handler_id"] for row in result}
            return list(set(handlerIds) - existingIds)

    @exceptionWrap
    async def getUpdatedHandlers(
        self, handlerIds: List[str], lastUpdateTimeGte: datetime, targets: Optional[list[HandlerTarget]] = None
    ) -> List[Dict]:
        """
        Get updated handlers by ids.

        Used for cache invalidation.

        Args:
            handlerIds: handler ids
            lastUpdateTimeGte: lower bound (inclusive) of handler update time
            targets: handler targets (columns) to load; None means all targets

        Returns:
            List with handlers (see ``loadHandlerFromRow``)
        """
        selectSt = self._genGetHandlersQuery(
            handlerIds=handlerIds,
            lastUpdateTimeGte=lastUpdateTimeGte,
            # _genGetHandlersQuery iterates over targets, so None must be replaced by "all targets";
            # passing None through raised TypeError.
            targets=targets if targets is not None else tuple(HandlerTarget),
        )
        async with DBContext.adaptor.connection(self.logger) as connection:
            handlerRows = await connection.fetchall(selectSt)
            return [self.loadHandlerFromRow(handlerRow, selectSt.columns) for handlerRow in handlerRows]

    @exceptionWrap
    async def putHandler(
        self,
        handlerId: str,
        policies: dict,
        accountId: str,
        description: str = "",
        handlerType: int = HandlerType.static.value,
        lambdaId: str | None = None,
    ) -> bool:
        """
        Put handler instead old handler.

        Also sets ``last_update_time`` to the current DB timestamp.

        Args:
            handlerId: handler id
            policies: set handler policies; ignored (stored as NULL) for dynamic and lambda handlers
            accountId: account id (must match the stored one)
            description: user handler description
            handlerType: type of handler (HandlerType value)
            lambdaId: id of lambda handler

        Warnings:
            function does not create handler!

        Returns:
            True if handler exist otherwise false
        """
        # dynamic and lambda handlers do not keep policies in the database
        if handlerType in (HandlerType.dynamic.value, HandlerType.lambdaHandler.value):
            rulesStr = None
        else:
            rulesStr = ujson.dumps(policies)
        async with DBContext.adaptor.connection(self.logger) as connection:
            updateSt = (
                update(models.Handler)
                .where(and_(models.Handler.handler_id == handlerId, models.Handler.account_id == accountId))
                .values(
                    description=description,
                    policies=rulesStr,
                    handler_type=handlerType,
                    # stored the same way createHandler stores it
                    lambda_id=str(lambdaId) if lambdaId is not None else None,
                    last_update_time=self.currentDBTimestamp,
                )
            )
            # assumes the adaptor's execute() returns the affected row count - TODO confirm
            return bool(await connection.execute(updateSt))

    @exceptionWrap
    async def deleteHandler(self, handlerId: str, accountId: Optional[str] = None) -> bool:
        """
        Delete handler by id.

        Args:
            handlerId: handler id
            accountId: account id of the handler; None disables the account filter

        Returns:
            True if handler exist otherwise false
        """
        async with DBContext.adaptor.connection(self.logger) as connection:
            deleteSt = delete(models.Handler).where(
                and_(
                    models.Handler.handler_id == handlerId,
                    models.Handler.account_id == accountId if accountId is not None else True,
                )
            )
            return bool(await connection.execute(deleteSt))

    @exceptionWrap
    async def doesHandlerExist(self, handlerId: str, accountId: Optional[str] = None) -> bool:
        """
        Check existence of a handler with id=handlerId (optionally within an account).

        Args:
            handlerId: handler id
            accountId: handler account id; None disables the account filter

        Returns:
            true - if handler is exist otherwise false
        """
        async with DBContext.adaptor.connection(self.logger) as connection:
            query = select([models.Handler.handler_id]).where(
                and_(
                    models.Handler.handler_id == handlerId,
                    models.Handler.account_id == accountId if accountId is not None else True,
                )
            )
            handlerRow = await connection.fetchone(query)
            return handlerRow is not None

    @staticmethod
    def _descriptionFilter(column, description: Optional[str]):
        """
        Build a "description contains substring" filter clause.

        ``%``, ``_`` and the escape character inside ``description`` are escaped, so the value is matched as a
        literal substring rather than being interpreted as LIKE wildcards.

        Args:
            column: description column to filter on
            description: substring to search for; None disables the filter

        Returns:
            filter clause, or True (a no-op inside ``and_``) if description is None
        """
        if description is None:
            return True
        return column.contains(description, autoescape=True)

    @staticmethod
    def _genGetHandlersQuery(
        targets: Iterable[HandlerTarget] = tuple(HandlerTarget),
        accountId: Optional[str] = None,
        description: Optional[str] = None,
        handlerType: int | None = None,
        handlerIds: Optional[list[str]] = None,
        lastUpdateTimeGte: Optional[datetime] = None,
    ) -> select:
        """
        Generate a "select handlers" statement by filters.

        Every filter set to None is disabled.

        Args:
            targets: handler targets (selected columns); must be iterable, all targets by default
            accountId: handler account id
            description: substring of the handler description (matched literally, LIKE wildcards are escaped)
            handlerType: type of handler
            handlerIds: handlers ids
            lastUpdateTimeGte: lower bound (inclusive) of handler update time

        Returns:
            generated "select" statement
        """
        selectSt = select([getattr(models.Handler, target.value) for target in targets]).where(
            and_(
                models.Handler.handler_id.in_(handlerIds) if handlerIds is not None else True,
                DBContext._descriptionFilter(models.Handler.description, description),
                models.Handler.account_id == accountId if accountId is not None else True,
                models.Handler.handler_type == handlerType if handlerType is not None else True,
                models.Handler.last_update_time >= lastUpdateTimeGte if lastUpdateTimeGte is not None else True,
            )
        )
        return selectSt

    @exceptionWrap
    async def getHandlers(
        self,
        accountId: Optional[str] = None,
        description: Optional[str] = None,
        handlerType: int | None = None,
        page: int = 1,
        pageSize: int = 100,
    ) -> List[dict]:
        """
        Get handlers by filters.

        Args:
            accountId: handler account id; None disables the account filter
            description: substring of the handler description (matched literally); None disables the filter
            handlerType: handler type; None disables the filter
            page: page number, 1-based
            pageSize: page size

        Returns:
            list of deserialized handlers, newest (by create_time) first
        """
        selectSt = self._genGetHandlersQuery(accountId=accountId, description=description, handlerType=handlerType)
        selectSt = selectSt.order_by(models.Handler.create_time.desc()).offset((page - 1) * pageSize).limit(pageSize)
        async with DBContext.adaptor.connection(self.logger) as connection:
            handlerRows = await connection.fetchall(selectSt)
            return [self.loadHandlerFromRow(handlerRow, selectSt.columns) for handlerRow in handlerRows]

    @exceptionWrap
    async def getHandlerCount(
        self, accountId: str | None = None, description: str | None = None, handlerType: int | None = None
    ) -> int:
        """
        Get handler count.

        Args:
            accountId: handler account id; None disables the account filter
            description: substring of the handler description (matched literally); None disables the filter
            handlerType: handler type; None disables the filter

        Returns:
            handler count
        """
        async with DBContext.adaptor.connection(self.logger) as connection:
            selectSt = self._genGetHandlersQuery(accountId=accountId, description=description, handlerType=handlerType)
            selectSt = select([func.count()]).select_from(aliased(selectSt))
            count = (await connection.fetchone(selectSt))[0]
            return count