from datetime import datetime
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from uuid import uuid4

import ujson
from sqlalchemy import and_, delete, func, insert, select, update
from sqlalchemy.engine import RowProxy
from sqlalchemy.orm import aliased
from sqlalchemy.sql.base import ImmutableColumnCollection
from sqlalchemy.sql.expression import Select
from vlutils.helpers import convertTimeToString

from app.global_vars.enums import HandlerTarget
from crutches_on_wheels.cow.db.base_context import BaseDBContext
from crutches_on_wheels.cow.enums.handlers import HandlerType
from crutches_on_wheels.cow.errors.errors import Error
from crutches_on_wheels.cow.errors.exception import VLException
from crutches_on_wheels.cow.utils.db_functions import dbExceptionWrap as exceptionWrap
from db.handlers_db_tools.models import handlers_models as models
class DBContext(BaseDBContext):
    """Handlers DB context: persistence operations for handlers and verifiers."""

    def loadHandlerFromRow(self, handlerRow: RowProxy, selectColumns: ImmutableColumnCollection) -> dict[str, Any]:
        """
        Load handler as dict from raw row.

        Only the columns listed in `selectColumns` are loaded, so the result contains exactly the requested
        targets, except `lambda_id`, which is omitted when it is NULL.

        Args:
            handlerRow: row fetched from db
            selectColumns: columns of the sql query that produced the row
        Returns:
            handler as dict keyed by column names
        """
        result = {}
        for column in selectColumns:
            key = column.key
            target = HandlerTarget(key)
            value = handlerRow[key]
            if target == HandlerTarget.description:
                result[key] = value or ""
            elif target == HandlerTarget.policies:
                # policies are stored as a json string; NULL (e.g. dynamic / lambda handlers) stays None
                result[key] = ujson.loads(value) if value is not None else None
            elif target in (HandlerTarget.createTime, HandlerTarget.lastUpdateTime):
                result[key] = convertTimeToString(value, self.storageTime == "UTC")
            elif target == HandlerTarget.lambdaId:
                # lambda id is reported only for handlers that actually have one
                if value is not None:
                    result[key] = value
            else:
                result[key] = value
        return result

    def loadVerifierFromRow(self, verifierRow: Tuple) -> Dict[str, Any]:
        """
        Load verifier as dict from raw row.

        Args:
            verifierRow: full row from db (all columns of the verifier table)
        Returns:
            verifier as dict with deserialized policies and string timestamps
        """
        # NOTE(review): relies on the row's column order matching models.Verifier.getColumnNames();
        # rows passed here come from `select([models.Verifier])`.
        verifier = dict(zip(models.Verifier.getColumnNames(), verifierRow))
        if verifier["description"] is None:
            verifier["description"] = ""
        verifier["policies"] = ujson.loads(verifier["policies"]) if verifier["policies"] is not None else None
        isUTC = self.storageTime == "UTC"
        verifier["create_time"] = convertTimeToString(verifier["create_time"], isUTC)
        verifier["last_update_time"] = convertTimeToString(verifier["last_update_time"], isUTC)
        return verifier

    @exceptionWrap
    async def createVerifier(self, policies: dict, accountId: str, description: str = "") -> str:
        """
        Create new verifier.

        Args:
            policies: verifier policies (stored json-serialized)
            accountId: account id
            description: user verifier description
        Returns:
            id of the created verifier
        """
        async with DBContext.adaptor.connection(self.logger) as connection:
            verifierId = str(uuid4())
            insertSt = insert(models.Verifier).values(
                description=description, account_id=accountId, policies=ujson.dumps(policies), verifier_id=verifierId
            )
            await connection.execute(insertSt)
            return verifierId

    @exceptionWrap
    async def getVerifier(self, verifierId: str, accountId: Optional[str] = None) -> Dict:
        """
        Get verifier by id.

        Args:
            verifierId: verifier id
            accountId: verifier account id; if None, the verifier is searched across all accounts
        Returns:
            deserialized dict with verifier
        Raises:
            VLException(Error.VerifierNotFound.format(verifierId), 404, isCriticalError=False): if verifier not found
        """
        async with DBContext.adaptor.connection(self.logger) as connection:
            query = select([models.Verifier]).where(
                and_(
                    models.Verifier.verifier_id == verifierId,
                    models.Verifier.account_id == accountId if accountId is not None else True,
                )
            )
            verifierRow = await connection.fetchone(query)
            if verifierRow is None:
                raise VLException(Error.VerifierNotFound.format(verifierId), 404, isCriticalError=False)
            return self.loadVerifierFromRow(verifierRow)

    @exceptionWrap
    async def putVerifier(
        self, verifierId: str, policies: dict, accountId: str, description: str = ""
    ) -> Optional[int]:
        """
        Replace verifier by id and bump its version.

        Args:
            verifierId: verifier id
            policies: verifier policies
            accountId: account id
            description: user verifier description
        Warnings:
            function does not create verifier!
        Returns:
            verifier current version if verifier replaced, otherwise None
        """
        # NOTE(review): unlike putHandler, last_update_time is not set explicitly here - confirm the model
        # updates it on its own (e.g. via an onupdate default).
        async with DBContext.adaptor.connection(self.logger) as connection:
            updateSt = (
                update(models.Verifier)
                .where(and_(models.Verifier.verifier_id == verifierId, models.Verifier.account_id == accountId))
                .values(description=description, policies=ujson.dumps(policies), version=models.Verifier.version + 1)
                .returning(models.Verifier.version)
            )
            return await connection.scalar(updateSt)

    @exceptionWrap
    async def deleteVerifier(self, verifierId: str, accountId: Optional[str] = None) -> bool:
        """
        Delete verifier by id.

        Args:
            verifierId: verifier id
            accountId: account id; if None, the account is not checked
        Returns:
            True if verifier deleted, otherwise False
        """
        async with DBContext.adaptor.connection(self.logger) as connection:
            deleteSt = delete(models.Verifier).where(
                and_(
                    models.Verifier.verifier_id == verifierId,
                    models.Verifier.account_id == accountId if accountId is not None else True,
                )
            )
            return bool(await connection.execute(deleteSt))

    @exceptionWrap
    async def checkVerifier(self, verifierId: str, accountId: Optional[str] = None) -> bool:
        """
        Check verifier existence.

        Args:
            verifierId: verifier id
            accountId: account id; if None, the account is not checked
        Returns:
            True if verifier exists, otherwise False
        """
        async with DBContext.adaptor.connection(self.logger) as connection:
            query = select([models.Verifier.verifier_id]).where(
                and_(
                    models.Verifier.verifier_id == verifierId,
                    models.Verifier.account_id == accountId if accountId is not None else True,
                )
            )
            verifierRow = await connection.fetchone(query)
            return verifierRow is not None

    @exceptionWrap
    async def getVerifierCount(self, accountId: Optional[str] = None, description: Optional[str] = None) -> int:
        """
        Get verifier count.

        Args:
            accountId: account id filter
            description: substring of verifier description to filter by
        Returns:
            verifier count
        """
        # NOTE(review): `description` is inserted into a LIKE pattern unescaped, so "%" and "_" inside it
        # act as wildcards.
        filteredSt = select([models.Verifier.verifier_id]).where(
            and_(
                models.Verifier.description.like("%{}%".format(description)) if description is not None else True,
                models.Verifier.account_id == accountId if accountId is not None else True,
            )
        )
        countSt = select([func.count()]).select_from(aliased(filteredSt))
        async with DBContext.adaptor.connection(self.logger) as connection:
            return await connection.scalar(countSt)

    @exceptionWrap
    async def getVerifiers(
        self,
        accountId: Optional[str] = None,
        description: Optional[str] = None,
        page: int = 1,
        pageSize: int = 100,
    ) -> List[dict]:
        """
        Get verifiers by filters, newest first.

        Args:
            accountId: account id filter
            description: substring of verifier description to filter by (LIKE wildcards are not escaped)
            page: 1-based page number
            pageSize: page size
        Returns:
            list of verifiers
        """
        selectSt = select([models.Verifier]).where(
            and_(
                models.Verifier.description.like("%{}%".format(description)) if description is not None else True,
                models.Verifier.account_id == accountId if accountId is not None else True,
            )
        )
        selectSt = (
            # verifier_id breaks create_time ties so that OFFSET/LIMIT pagination is stable
            selectSt.order_by(models.Verifier.create_time.desc(), models.Verifier.verifier_id.desc())
            .offset((page - 1) * pageSize)
            .limit(pageSize)
        )
        async with DBContext.adaptor.connection(self.logger) as connection:
            verifierRows = await connection.fetchall(selectSt)
            return [self.loadVerifierFromRow(verifierRow) for verifierRow in verifierRows]

    @exceptionWrap
    async def createHandler(
        self,
        policies: dict,
        accountId: str,
        description: str = "",
        handlerType: HandlerType = HandlerType.static,
        lambdaId: str | None = None,
    ) -> str:
        """
        Create new handler.

        Args:
            policies: handler policies; ignored (stored as NULL) for dynamic and lambda handlers
            accountId: account id
            description: user handler description
            handlerType: handler type
            lambdaId: id of lambda handler
        Returns:
            id of the created handler
        """
        if handlerType in (HandlerType.dynamic, HandlerType.lambdaHandler):
            rulesStr = None
        else:
            rulesStr = ujson.dumps(policies)
        async with DBContext.adaptor.connection(self.logger) as connection:
            handlerId = str(uuid4())
            insertSt = insert(models.Handler).values(
                description=description,
                account_id=accountId,
                policies=rulesStr,
                handler_id=handlerId,
                handler_type=handlerType.value,
                lambda_id=str(lambdaId) if lambdaId is not None else None,
            )
            await connection.execute(insertSt)
            return handlerId

    @exceptionWrap
    async def getHandler(
        self,
        handlerId: str,
        targets: Iterable[HandlerTarget] = tuple(HandlerTarget),
        accountId: Optional[str] = None,
    ) -> dict:
        """
        Get handler by id.

        Args:
            handlerId: handler id
            targets: handler targets (columns) to load
            accountId: handler account id; if None, the account is not checked
        Returns:
            deserialized dict with handler
        Raises:
            VLException(Error.HandlerNotFound.format(handlerId), 404, isCriticalError=False): if handler not found
        """
        selectSt = self._genGetHandlersQuery(targets=targets, handlerIds=[handlerId], accountId=accountId)
        async with DBContext.adaptor.connection(self.logger) as connection:
            handlerRow = await connection.fetchone(selectSt)
            if handlerRow is None:
                raise VLException(Error.HandlerNotFound.format(handlerId), 404, isCriticalError=False)
            return self.loadHandlerFromRow(handlerRow, selectSt.columns)

    @exceptionWrap
    async def getAbsentHandlersIds(self, handlerIds: List[str]) -> List[str]:
        """
        Given a list of handler ids, return those that are no longer present in the database.

        Used by the cache invalidator to drop removed handlers from cache.

        Args:
            handlerIds: handler ids
        Returns:
            de-duplicated list of absent handler ids, in the order they were given
        """
        if not handlerIds:
            # nothing to check - skip the database round trip
            return []
        selectSt = self._genGetHandlersQuery(targets=[HandlerTarget.handlerId], handlerIds=handlerIds)
        async with DBContext.adaptor.connection(self.logger) as connection:
            rows = await connection.fetchall(selectSt)
            presentIds = {row["handler_id"] for row in rows}
        # dict.fromkeys drops duplicates while keeping the caller's order
        return [handlerId for handlerId in dict.fromkeys(handlerIds) if handlerId not in presentIds]

    @exceptionWrap
    async def getUpdatedHandlers(
        self, handlerIds: List[str], lastUpdateTimeGte: datetime, targets: Optional[list[HandlerTarget]] = None
    ) -> List[Dict]:
        """
        Get handlers (by ids) updated at or after the given time. Used for cache invalidation.

        Args:
            handlerIds: handler ids
            lastUpdateTimeGte: lower bound (inclusive) of handler update time
            targets: handler targets (columns) to load; None loads all targets
        Returns:
            list with handlers
        """
        # _genGetHandlersQuery resolves targets=None to all targets
        selectSt = self._genGetHandlersQuery(
            handlerIds=handlerIds, lastUpdateTimeGte=lastUpdateTimeGte, targets=targets
        )
        async with DBContext.adaptor.connection(self.logger) as connection:
            handlerRows = await connection.fetchall(selectSt)
            return [self.loadHandlerFromRow(handlerRow, selectSt.columns) for handlerRow in handlerRows]

    @exceptionWrap
    async def putHandler(
        self,
        handlerId: str,
        policies: dict,
        accountId: str,
        description: str = "",
        handlerType: int = HandlerType.static.value,
        lambdaId: str | None = None,
    ) -> bool:
        """
        Replace an existing handler with new data.

        Args:
            handlerId: handler id
            policies: handler policies; ignored (stored as NULL) for dynamic and lambda handlers
            accountId: account id
            description: user handler description
            handlerType: type of handler (HandlerType value)
            lambdaId: id of lambda handler
        Warnings:
            function does not create handler!
        Returns:
            True if handler exists otherwise False
        """
        if handlerType in (HandlerType.dynamic.value, HandlerType.lambdaHandler.value):
            rulesStr = None
        else:
            rulesStr = ujson.dumps(policies)
        async with DBContext.adaptor.connection(self.logger) as connection:
            updateSt = (
                update(models.Handler)
                .where(and_(models.Handler.handler_id == handlerId, models.Handler.account_id == accountId))
                .values(
                    description=description,
                    policies=rulesStr,
                    handler_type=handlerType,
                    # normalized to str exactly like createHandler does
                    lambda_id=str(lambdaId) if lambdaId is not None else None,
                    last_update_time=self.currentDBTimestamp,
                )
            )
            return bool(await connection.execute(updateSt))

    @exceptionWrap
    async def deleteHandler(self, handlerId: str, accountId: Optional[str] = None) -> bool:
        """
        Delete handler by id.

        Args:
            handlerId: handler id
            accountId: account id of the handler; if None, the account is not checked
        Returns:
            True if handler exists otherwise False
        """
        async with DBContext.adaptor.connection(self.logger) as connection:
            deleteSt = delete(models.Handler).where(
                and_(
                    models.Handler.handler_id == handlerId,
                    models.Handler.account_id == accountId if accountId is not None else True,
                )
            )
            return bool(await connection.execute(deleteSt))

    @exceptionWrap
    async def doesHandlerExist(self, handlerId: str, accountId: Optional[str] = None) -> bool:
        """
        Check whether a handler with id=handlerId exists (optionally within an account).

        Args:
            handlerId: handler id
            accountId: handler account id; if None, the account is not checked
        Returns:
            True if handler exists otherwise False
        """
        async with DBContext.adaptor.connection(self.logger) as connection:
            query = select([models.Handler.handler_id]).where(
                and_(
                    models.Handler.handler_id == handlerId,
                    models.Handler.account_id == accountId if accountId is not None else True,
                )
            )
            handlerRow = await connection.fetchone(query)
            return handlerRow is not None

    @staticmethod
    def _genGetHandlersQuery(
        targets: Optional[Iterable[HandlerTarget]] = None,
        accountId: Optional[str] = None,
        description: Optional[str] = None,
        handlerType: int | None = None,
        handlerIds: Optional[list[str]] = None,
        lastUpdateTimeGte: Optional[datetime] = None,
    ) -> Select:
        """
        Build a "select" statement for handlers; every filter left as None is not applied.

        Args:
            targets: handler targets (columns) to select; None selects all targets
            accountId: handler account id
            description: substring of handler description (LIKE wildcards are not escaped)
            handlerType: type of handler
            handlerIds: handler ids
            lastUpdateTimeGte: lower bound (inclusive) of handler update time
        Returns:
            generated "select" statement
        """
        if targets is None:
            targets = tuple(HandlerTarget)
        selectSt = select([getattr(models.Handler, target.value) for target in targets]).where(
            and_(
                models.Handler.handler_id.in_(handlerIds) if handlerIds is not None else True,
                models.Handler.description.like("%{}%".format(description)) if description is not None else True,
                models.Handler.account_id == accountId if accountId is not None else True,
                models.Handler.handler_type == handlerType if handlerType is not None else True,
                models.Handler.last_update_time >= lastUpdateTimeGte if lastUpdateTimeGte is not None else True,
            )
        )
        return selectSt

    @exceptionWrap
    async def getHandlers(
        self,
        accountId: Optional[str] = None,
        description: Optional[str] = None,
        handlerType: int | None = None,
        page: int = 1,
        pageSize: int = 100,
    ) -> List[dict]:
        """
        Get handlers by filters, newest first.

        Args:
            accountId: handler account id
            description: substring of handler description
            handlerType: handler type
            page: 1-based page number
            pageSize: page size
        Returns:
            list of deserialized handlers
        """
        selectSt = self._genGetHandlersQuery(accountId=accountId, description=description, handlerType=handlerType)
        selectSt = (
            # handler_id breaks create_time ties so that OFFSET/LIMIT pagination is stable
            selectSt.order_by(models.Handler.create_time.desc(), models.Handler.handler_id.desc())
            .offset((page - 1) * pageSize)
            .limit(pageSize)
        )
        async with DBContext.adaptor.connection(self.logger) as connection:
            handlerRows = await connection.fetchall(selectSt)
            return [self.loadHandlerFromRow(handlerRow, selectSt.columns) for handlerRow in handlerRows]

    @exceptionWrap
    async def getHandlerCount(
        self, accountId: str | None = None, description: str | None = None, handlerType: int | None = None
    ) -> int:
        """
        Get handler count.

        Args:
            accountId: handler account id
            description: substring of handler description
            handlerType: handler type
        Returns:
            handler count
        """
        # only the id column is needed to count matching rows
        filteredSt = self._genGetHandlersQuery(
            targets=[HandlerTarget.handlerId], accountId=accountId, description=description, handlerType=handlerType
        )
        countSt = select([func.count()]).select_from(aliased(filteredSt))
        async with DBContext.adaptor.connection(self.logger) as connection:
            return await connection.scalar(countSt)