from datetime import datetime
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from uuid import uuid4
import ujson
from sqlalchemy import and_, delete, func, insert, select, update
from sqlalchemy.engine import RowProxy
from sqlalchemy.orm import aliased
from sqlalchemy.sql.base import ImmutableColumnCollection
from vlutils.helpers import convertTimeToString
from app.global_vars.enums import HandlerTarget
from crutches_on_wheels.cow.db.base_context import BaseDBContext
from crutches_on_wheels.cow.enums.handlers import HandlerType
from crutches_on_wheels.cow.errors.errors import Error
from crutches_on_wheels.cow.errors.exception import VLException
from crutches_on_wheels.cow.utils.db_functions import dbExceptionWrap as exceptionWrap
from db.handlers_db_tools.models import handlers_models as models
[docs]
class DBContext(BaseDBContext):
    """Handlers DB context."""
[docs]
    def loadHandlerFromRow(self, handlerRow: RowProxy, selectColumns: ImmutableColumnCollection) -> dict[str, Any]:
        """
        Convert a raw handler row into a dict keyed by selected column name.
        Each selected column key is interpreted as a `HandlerTarget` value and post-processed:
        a NULL/empty description becomes "", policies JSON text is decoded (NULL stays None),
        create/last-update times are rendered via `convertTimeToString`, and a NULL lambda id
        is omitted from the result. Any other column is copied as is.
        Args:
            handlerRow: full row from db
            selectColumns: columns from sql query
        Returns:
            handler as dict
        """
        handler: dict[str, Any] = {}
        for column in selectColumns:
            key = column.key
            # raises ValueError for a column that is not a known handler target
            target = HandlerTarget(key)
            value = handlerRow[key]
            if target in (HandlerTarget.createTime, HandlerTarget.lastUpdateTime):
                handler[key] = convertTimeToString(value, self.storageTime == "UTC")
            elif target == HandlerTarget.policies:
                handler[key] = None if value is None else ujson.loads(value)
            elif target == HandlerTarget.description:
                handler[key] = value or ""
            elif target == HandlerTarget.lambdaId and value is None:
                # a missing lambda id is dropped from the result instead of being reported as null
                continue
            else:
                handler[key] = value
        return handler
[docs]
    def loadVerifierFromRow(self, verifierRow: Tuple) -> Dict[str, Any]:
        """
        Convert a raw verifier row into a dict keyed by `models.Verifier` column names.
        Row values are paired positionally with `models.Verifier.getColumnNames()`. After that,
        a NULL description becomes "", policies JSON text is decoded (NULL stays None), and both
        timestamps are rendered via `convertTimeToString`.
        Args:
            verifierRow: full row from db
        Returns:
            verifier as dict
        """
        verifier = dict(zip(models.Verifier.getColumnNames(), verifierRow))
        if verifier["description"] is None:
            verifier["description"] = ""
        rawPolicies = verifier["policies"]
        verifier["policies"] = None if rawPolicies is None else ujson.loads(rawPolicies)
        isUTC = self.storageTime == "UTC"
        for timeKey in ("create_time", "last_update_time"):
            verifier[timeKey] = convertTimeToString(verifier[timeKey], isUTC)
        return verifier
[docs]
    @exceptionWrap
    async def createVerifier(self, policies: dict, accountId: str, description: str = "") -> str:
        """
        Insert a new verifier row and return its generated id.
        Args:
            policies: verifier policies (stored as JSON text)
            accountId: account id
            description: user verifier description
        Returns:
            verifier id (a freshly generated uuid4 string)
        """
        async with DBContext.adaptor.connection(self.logger) as connection:
            verifierId = str(uuid4())
            newRow = dict(
                description=description,
                account_id=accountId,
                policies=ujson.dumps(policies),
                verifier_id=verifierId,
            )
            await connection.execute(insert(models.Verifier).values(**newRow))
            return verifierId
[docs]
    @exceptionWrap
    async def getVerifier(self, verifierId: str, accountId: Optional[str] = None) -> Dict:
        """
        Fetch a single verifier by id, optionally restricted to one account.
        Args:
            verifierId: verifier id
            accountId: verifier account id; when None the account is not checked
        Returns:
            deserialized dict with verifier
        Raises:
            VLException(Error.VerifierNotFound.format(verifierId), 404, isCriticalError=False): if verifier not found
        """
        async with DBContext.adaptor.connection(self.logger) as connection:
            idFilter = models.Verifier.verifier_id == verifierId
            accountFilter = models.Verifier.account_id == accountId if accountId is not None else True
            query = select([models.Verifier]).where(and_(idFilter, accountFilter))
            verifierRow = await connection.fetchone(query)
            if verifierRow is not None:
                return self.loadVerifierFromRow(verifierRow)
            raise VLException(Error.VerifierNotFound.format(verifierId), 404, isCriticalError=False)
[docs]
    @exceptionWrap
    async def putVerifier(
        self, verifierId: str, policies: dict, accountId: str, description: str = ""
    ) -> Optional[int]:
        """
        Replace description and policies of an existing verifier and bump its version.
        Args:
            verifierId: verifier id
            policies: verifier policies (stored as JSON text)
            accountId: account id; only a verifier of this account is updated
            description: user verifier description
        Warnings:
            function does not create verifier!
        Return:
            the new verifier version returned by the UPDATE ... RETURNING; presumably None when no row matched
        """
        async with DBContext.adaptor.connection(self.logger) as connection:
            targetVerifier = and_(models.Verifier.verifier_id == verifierId, models.Verifier.account_id == accountId)
            newValues = dict(
                description=description,
                policies=ujson.dumps(policies),
                version=models.Verifier.version + 1,
            )
            statement = (
                update(models.Verifier)
                .where(targetVerifier)
                .values(**newValues)
                .returning(models.Verifier.version)
            )
            return await connection.scalar(statement)
[docs]
    @exceptionWrap
    async def deleteVerifier(self, verifierId: str, accountId: Optional[str] = None) -> bool:
        """
        Delete a verifier by id, optionally restricted to one account.
        Args:
            verifierId: verifier id
            accountId: account id; when None the account is not checked
        Return:
            True if the adaptor's execute result is truthy (presumably a non-zero deleted row count), otherwise False
        """
        async with DBContext.adaptor.connection(self.logger) as connection:
            conditions = [
                models.Verifier.verifier_id == verifierId,
                models.Verifier.account_id == accountId if accountId is not None else True,
            ]
            executeResult = await connection.execute(delete(models.Verifier).where(and_(*conditions)))
            return bool(executeResult)
[docs]
    @exceptionWrap
    async def checkVerifier(self, verifierId: str, accountId: Optional[str] = None) -> bool:
        """
        Tell whether a verifier with the given id exists.
        Args:
            verifierId: verifier id
            accountId: account id; when None the account is not checked
        Returns:
            True - if verifier exists, otherwise False
        """
        async with DBContext.adaptor.connection(self.logger) as connection:
            accountFilter = models.Verifier.account_id == accountId if accountId is not None else True
            query = select([models.Verifier.verifier_id]).where(
                and_(models.Verifier.verifier_id == verifierId, accountFilter)
            )
            return await connection.fetchone(query) is not None
[docs]
    @exceptionWrap
    async def getVerifierCount(self, accountId: Optional[str] = None, description: Optional[str] = None):
        """
        Count verifiers matching the optional filters.
        Args:
            accountId: account id; exact match when given
            description: substring searched in the verifier description via SQL LIKE
                (note: `%` / `_` in the value act as LIKE wildcards)
        Returns:
            verifier count
        """
        async with DBContext.adaptor.connection(self.logger) as connection:
            descriptionFilter = (
                models.Verifier.description.like(f"%{description}%") if description is not None else True
            )
            accountFilter = models.Verifier.account_id == accountId if accountId is not None else True
            filteredSt = select([models.Verifier]).where(and_(descriptionFilter, accountFilter))
            countSt = select([func.count()]).select_from(aliased(filteredSt))
            return await connection.scalar(countSt)
[docs]
    @exceptionWrap
    async def getVerifiers(
        self,
        accountId: Optional[str] = None,
        description: Optional[str] = None,
        page: int = 1,
        pageSize: int = 100,
    ) -> List[dict]:
        """
        List verifiers matching the optional filters, newest first, one page at a time.
        Args:
            accountId: account id; exact match when given
            description: substring searched in the verifier description via SQL LIKE
            page: 1-based page number
            pageSize: page size
        Returns:
            list of verifiers
        """
        async with DBContext.adaptor.connection(self.logger) as connection:
            filters = and_(
                models.Verifier.description.like(f"%{description}%") if description is not None else True,
                models.Verifier.account_id == accountId if accountId is not None else True,
            )
            offset = (page - 1) * pageSize
            query = (
                select([models.Verifier])
                .where(filters)
                .order_by(models.Verifier.create_time.desc())
                .offset(offset)
                .limit(pageSize)
            )
            verifierRows = await connection.fetchall(query)
            return [self.loadVerifierFromRow(verifierRow) for verifierRow in verifierRows]
[docs]
    @exceptionWrap
    async def createHandler(
        self,
        policies: dict,
        accountId: str,
        description: str = "",
        handlerType: HandlerType = HandlerType.static,
        lambdaId: str | None = None,
    ) -> str:
        """
        Insert a new handler row and return its generated id.
        The policies column is written as NULL for dynamic and lambda handlers; for every
        other handler type the policies are stored as JSON text.
        Args:
            policies: set handler policies
            accountId: account id
            description: user handler description
            handlerType: handler type
            lambdaId: id of lambda handler (stored as a string when given)
        Returns:
            handler id (a freshly generated uuid4 string)
        """
        storesPolicies = handlerType not in (HandlerType.dynamic, HandlerType.lambdaHandler)
        policiesJson = ujson.dumps(policies) if storesPolicies else None
        async with DBContext.adaptor.connection(self.logger) as connection:
            handlerId = str(uuid4())
            await connection.execute(
                insert(models.Handler).values(
                    description=description,
                    account_id=accountId,
                    policies=policiesJson,
                    handler_id=handlerId,
                    handler_type=handlerType.value,
                    lambda_id=None if lambdaId is None else str(lambdaId),
                )
            )
            return handlerId
[docs]
    @exceptionWrap
    async def getHandler(
        self,
        handlerId: str,
        targets: Iterable[HandlerTarget] = tuple(HandlerTarget),
        accountId: Optional[str] = None,
    ) -> dict:
        """
        Get handler by id.
        Args:
            handlerId: handler id
            targets: handler fields (columns) to select; all fields by default
            accountId: handler account id; when None the account is not checked
        Returns:
            deserialized dict with the requested handler fields
        Raises:
            VLException(Error.HandlerNotFound.format(handlerId), 404, isCriticalError=False): if handler not found
        """
        selectSt = self._genGetHandlersQuery(targets=targets, handlerIds=[handlerId], accountId=accountId)
        async with DBContext.adaptor.connection(self.logger) as connection:
            handlerRow = await connection.fetchone(selectSt)
            if handlerRow is None:
                raise VLException(Error.HandlerNotFound.format(handlerId), 404, isCriticalError=False)
            # selectSt.columns tells the loader which fields were actually selected
            return self.loadHandlerFromRow(handlerRow, selectSt.columns)
[docs]
    @exceptionWrap
    async def getAbsentHandlersIds(self, handlerIds: List[str]) -> List[str]:
        """
        Given list of handlerIds, return those handlerIds that are no longer in the database.
        Used for removal of items from cache by cache invalidator.
        Args:
            handlerIds: handler ids
        Returns:
            handler ids from `handlerIds` that were not found, de-duplicated, in the caller's order
        """
        if not handlerIds:
            # nothing to look up: skip the DB round trip and the empty IN () predicate
            return []
        selectSt = self._genGetHandlersQuery(targets=[HandlerTarget.handlerId], handlerIds=handlerIds)
        async with DBContext.adaptor.connection(self.logger) as connection:
            rows = await connection.fetchall(selectSt)
        presentIds = {row["handler_id"] for row in rows}
        # dict.fromkeys de-duplicates while keeping the input order deterministic
        return [handlerId for handlerId in dict.fromkeys(handlerIds) if handlerId not in presentIds]
[docs]
    @exceptionWrap
    async def getUpdatedHandlers(
        self, handlerIds: List[str], lastUpdateTimeGte: datetime, targets: Optional[list[HandlerTarget]] = None
    ) -> List[Dict]:
        """
        Get the handlers among `handlerIds` whose last update time is at or after `lastUpdateTimeGte`.
        Used for cache invalidation.
        Args:
            handlerIds: handler ids
            lastUpdateTimeGte: inclusive lower bound of handler update time
            targets: handler fields (columns) to select; None selects all fields
        Returns:
            List with handlers
        """
        if targets is None:
            # the query builder iterates `targets`; passing None through would raise TypeError
            targets = tuple(HandlerTarget)
        selectSt = self._genGetHandlersQuery(
            handlerIds=handlerIds, lastUpdateTimeGte=lastUpdateTimeGte, targets=targets
        )
        async with DBContext.adaptor.connection(self.logger) as connection:
            handlerRows = await connection.fetchall(selectSt)
            return [self.loadHandlerFromRow(handlerRow, selectSt.columns) for handlerRow in handlerRows]
[docs]
    @exceptionWrap
    async def putHandler(
        self,
        handlerId: str,
        policies: dict,
        accountId: str,
        description: str = "",
        handlerType: int = HandlerType.static.value,
        lambdaId: str | None = None,
    ) -> bool:
        """
        Replace an existing handler's content in place of the old one.
        The policies column is written as NULL for dynamic and lambda handlers; for every
        other handler type the policies are stored as JSON text. The last update time is
        set to the DB's current timestamp.
        Args:
            handlerId: handler id
            policies: set handler policies
            accountId: account id; only a handler of this account is updated
            description: user handler description
            handlerType: type of handler (HandlerType value)
            lambdaId: id of lambda handler (stored as a string when given)
        Warnings:
            function does not create handler!
        Returns:
            True if the adaptor's execute result is truthy (presumably the handler existed and was updated),
            otherwise False
        """
        if handlerType in (HandlerType.dynamic.value, HandlerType.lambdaHandler.value):
            rulesStr = None
        else:
            rulesStr = ujson.dumps(policies)
        # normalize the lambda id exactly as createHandler does so both write paths store the same representation
        lambdaIdStr = str(lambdaId) if lambdaId is not None else None
        async with DBContext.adaptor.connection(self.logger) as connection:
            updateSt = (
                update(models.Handler)
                .where(and_(models.Handler.handler_id == handlerId, models.Handler.account_id == accountId))
                .values(
                    description=description,
                    policies=rulesStr,
                    handler_type=handlerType,
                    lambda_id=lambdaIdStr,
                    last_update_time=self.currentDBTimestamp,
                )
            )
            return bool(await connection.execute(updateSt))
[docs]
    @exceptionWrap
    async def deleteHandler(self, handlerId: str, accountId: Optional[str] = None) -> bool:
        """
        Delete a handler by id, optionally restricted to one account.
        Args:
            handlerId: handler id
            accountId: account id of the handler; when None the account is not checked
        Returns:
            True if the adaptor's execute result is truthy (presumably the handler existed), otherwise False
        """
        async with DBContext.adaptor.connection(self.logger) as connection:
            conditions = [
                models.Handler.handler_id == handlerId,
                models.Handler.account_id == accountId if accountId is not None else True,
            ]
            executeResult = await connection.execute(delete(models.Handler).where(and_(*conditions)))
            return bool(executeResult)
[docs]
    @exceptionWrap
    async def doesHandlerExist(self, handlerId: str, accountId: Optional[str] = None) -> bool:
        """
        Tell whether a handler with id=handlerId exists (within the given account, if any).
        Args:
            handlerId: handler id
            accountId: handler account id; when None the account is not checked
        Returns:
            True if the handler exists, otherwise False
        """
        async with DBContext.adaptor.connection(self.logger) as connection:
            accountFilter = models.Handler.account_id == accountId if accountId is not None else True
            query = select([models.Handler.handler_id]).where(
                and_(models.Handler.handler_id == handlerId, accountFilter)
            )
            return await connection.fetchone(query) is not None
    @staticmethod
    def _genGetHandlersQuery(
        targets: Optional[Iterable[HandlerTarget]] = tuple(HandlerTarget),
        accountId: Optional[str] = None,
        description: Optional[str] = None,
        handlerType: int | None = None,
        handlerIds: Optional[list[str]] = None,
        lastUpdateTimeGte: Optional[datetime] = None,
    ) -> select:
        """
        Build a "select" statement over handlers with optional filters.
        Every filter argument that is None is skipped (it contributes a literal True to the AND clause).
        Args:
            targets: handler fields (columns) to select; None selects all fields
            accountId: handler account id (exact match)
            description: substring searched in the handler description via SQL LIKE
            handlerType: type of handler (exact match)
            handlerIds: handlers ids (IN filter)
            lastUpdateTimeGte: inclusive lower bound of handler update time
        Returns:
            generated "select" statement
        """
        if targets is None:
            # treat None like the default so callers forwarding an optional argument don't crash on iteration
            targets = tuple(HandlerTarget)
        columns = [getattr(models.Handler, target.value) for target in targets]
        return select(columns).where(
            and_(
                models.Handler.handler_id.in_(handlerIds) if handlerIds is not None else True,
                models.Handler.description.like("%{}%".format(description)) if description is not None else True,
                models.Handler.account_id == accountId if accountId is not None else True,
                models.Handler.handler_type == handlerType if handlerType is not None else True,
                models.Handler.last_update_time >= lastUpdateTimeGte if lastUpdateTimeGte is not None else True,
            )
        )
[docs]
    @exceptionWrap
    async def getHandlers(
        self,
        accountId: Optional[str] = None,
        description: Optional[str] = None,
        handlerType: int | None = None,
        page: int = 1,
        pageSize: int = 100,
    ) -> List[dict]:
        """
        List handlers matching the optional filters, newest first, one page at a time.
        Args:
            accountId: handler account id
            description: substring searched in the handler description
            handlerType: handler type
            page: 1-based page number
            pageSize: page size
        Returns:
            list of deserialized handlers
        """
        offset = (page - 1) * pageSize
        query = (
            self._genGetHandlersQuery(accountId=accountId, description=description, handlerType=handlerType)
            .order_by(models.Handler.create_time.desc())
            .offset(offset)
            .limit(pageSize)
        )
        async with DBContext.adaptor.connection(self.logger) as connection:
            rows = await connection.fetchall(query)
            return [self.loadHandlerFromRow(row, query.columns) for row in rows]
[docs]
    @exceptionWrap
    async def getHandlerCount(
        self, accountId: str | None = None, description: str | None = None, handlerType: int | None = None
    ) -> int:
        """
        Count handlers matching the optional filters.
        Args:
            accountId: handler account id
            description: substring searched in the handler description
            handlerType: handler type
        Returns:
            handler count
        """
        async with DBContext.adaptor.connection(self.logger) as connection:
            filteredSt = self._genGetHandlersQuery(accountId=accountId, description=description, handlerType=handlerType)
            countSt = select([func.count()]).select_from(aliased(filteredSt))
            # same adaptor call getVerifierCount uses for its single-value count query
            return await connection.scalar(countSt)