import ujson
from datetime import datetime
from sqlalchemy.engine import RowProxy
from sqlalchemy.sql.base import ImmutableColumnCollection
from typing import Tuple, Dict, Any, Optional, List, Iterable, Union
from uuid import uuid4
from sqlalchemy import insert, and_, update, delete, func, select
from sqlalchemy.orm import aliased
from app.global_vars.enums import HandlerTarget
from vlutils.helpers import convertTimeToString
from crutches_on_wheels.db.base_context import BaseDBContext
from crutches_on_wheels.errors.errors import Error
from crutches_on_wheels.errors.exception import VLException
from crutches_on_wheels.utils.db_functions import dbExceptionWrap as exceptionWrap
from db import models
class DBContext(BaseDBContext):
    """Handlers DB context: CRUD operations over verifiers and handlers."""

    def loadHandlerFromRow(self, handlerRow: RowProxy, selectColumns: ImmutableColumnCollection) -> dict[str, Any]:
        """
        Load handler as dict from raw row.

        Only the columns listed in `selectColumns` are loaded, so a partial select (see `targets` of
        `_genGetHandlersQuery`) produces a partial dict.

        Args:
            handlerRow: row fetched from db
            selectColumns: columns of the sql query the row was fetched with
        Returns:
            handler as dict keyed by column name
        """
        result = {}
        for column in selectColumns:
            snakeCaseTarget = column.key
            target = HandlerTarget(snakeCaseTarget)
            value = handlerRow[snakeCaseTarget]
            if target == HandlerTarget.description:
                # a NULL description is exposed as an empty string
                result[snakeCaseTarget] = value or ""
            elif target == HandlerTarget.policies:
                # policies are stored as a json string; dynamic handlers are stored with NULL policies
                result[snakeCaseTarget] = ujson.loads(value) if value is not None else None
            elif target in (HandlerTarget.createTime, HandlerTarget.lastUpdateTime):
                result[snakeCaseTarget] = convertTimeToString(value, self.storageTime == "UTC")
            else:
                result[snakeCaseTarget] = value
        return result

    def loadVerifierFromRow(self, verifierRow: Tuple) -> Dict[str, Any]:
        """
        Load verifier as dict from raw row.

        Args:
            verifierRow: full row from db; values must be ordered as `models.Verifier.getColumnNames()`
        Returns:
            verifier as dict keyed by column name
        """
        verifier = dict(zip(models.Verifier.getColumnNames(), verifierRow))
        if verifier["description"] is None:
            verifier["description"] = ""
        # policies are stored as a json string
        verifier["policies"] = ujson.loads(verifier["policies"]) if verifier["policies"] is not None else None
        verifier["create_time"] = convertTimeToString(verifier["create_time"], self.storageTime == "UTC")
        verifier["last_update_time"] = convertTimeToString(verifier["last_update_time"], self.storageTime == "UTC")
        return verifier

    @staticmethod
    def _genGetVerifiersQuery(accountId: Optional[str] = None, description: Optional[str] = None) -> select:
        """
        Generate a "select" statement over all verifier columns, filtered by the given (optional) filters.

        Args:
            accountId: verifier account id; no account filter if None
            description: substring of the verifier description; no description filter if None
        Returns:
            generated "select" statement
        """
        # NOTE(review): `description` is embedded into a LIKE pattern without escaping, so `%` and `_`
        # in it act as wildcards. Confirm this is intended for user-supplied input.
        return select([models.Verifier]).where(
            and_(
                models.Verifier.description.like("%{}%".format(description)) if description is not None else True,
                models.Verifier.account_id == accountId if accountId is not None else True,
            )
        )

    @exceptionWrap
    async def createVerifier(self, policies: dict, accountId: str, description: str = "") -> str:
        """
        Create new verifier.

        Args:
            policies: verifier policies (stored json-serialized)
            accountId: account id
            description: user verifier description
        Returns:
            id (uuid4 string) of the created verifier
        """
        async with DBContext.adaptor.connection(self.logger) as connection:
            verifierId = str(uuid4())
            insertSt = insert(models.Verifier).values(
                description=description, account_id=accountId, policies=ujson.dumps(policies), verifier_id=verifierId
            )
            await connection.execute(insertSt)
            return verifierId

    @exceptionWrap
    async def getVerifier(self, verifierId: str, accountId: Optional[str] = None) -> Dict:
        """
        Get verifier by id.

        Args:
            verifierId: verifier id
            accountId: verifier account id; the verifier is looked up in any account if None
        Returns:
            deserialized dict with verifier
        Raises:
            VLException(Error.VerifierNotFound.format(verifierId), 404, isCriticalError=False): if verifier not found
        """
        async with DBContext.adaptor.connection(self.logger) as connection:
            query = select([models.Verifier]).where(
                and_(
                    models.Verifier.verifier_id == verifierId,
                    models.Verifier.account_id == accountId if accountId is not None else True,
                )
            )
            verifierRow = await connection.fetchone(query)
            if verifierRow is None:
                raise VLException(Error.VerifierNotFound.format(verifierId), 404, isCriticalError=False)
            return self.loadVerifierFromRow(verifierRow)

    @exceptionWrap
    async def putVerifier(
        self, verifierId: str, policies: dict, accountId: str, description: str = ""
    ) -> Optional[int]:
        """
        Replace verifier by id and increment its version.

        Args:
            verifierId: verifier id
            policies: verifier policies (stored json-serialized)
            accountId: account id
            description: user verifier description
        Warnings:
            function does not create verifier!
        Returns:
            verifier current (incremented) version if verifier replaced; presumably None if no verifier matched
            (depends on the adaptor's `scalar` for an empty result)
        """
        async with DBContext.adaptor.connection(self.logger) as connection:
            updateSt = (
                update(models.Verifier)
                .where(and_(models.Verifier.verifier_id == verifierId, models.Verifier.account_id == accountId))
                .values(description=description, policies=ujson.dumps(policies), version=models.Verifier.version + 1)
            ).returning(models.Verifier.version)
            return await connection.scalar(updateSt)

    @exceptionWrap
    async def deleteVerifier(self, verifierId: str, accountId: Optional[str] = None) -> bool:
        """
        Delete verifier by id.

        Args:
            verifierId: verifier id
            accountId: account id; the verifier is deleted from any account if None
        Returns:
            True if verifier deleted, otherwise False (truthiness of the adaptor's execute result)
        """
        async with DBContext.adaptor.connection(self.logger) as connection:
            deleteSt = delete(models.Verifier).where(
                and_(
                    models.Verifier.verifier_id == verifierId,
                    models.Verifier.account_id == accountId if accountId is not None else True,
                )
            )
            return bool(await connection.execute(deleteSt))

    @exceptionWrap
    async def checkVerifier(self, verifierId: str, accountId: Optional[str] = None) -> bool:
        """
        Check verifier existence.

        Args:
            verifierId: verifier id
            accountId: account id; the verifier is looked up in any account if None
        Returns:
            True - if verifier exists, otherwise False
        """
        async with DBContext.adaptor.connection(self.logger) as connection:
            query = select([models.Verifier.verifier_id]).where(
                and_(
                    models.Verifier.verifier_id == verifierId,
                    models.Verifier.account_id == accountId if accountId is not None else True,
                )
            )
            verifierRow = await connection.fetchone(query)
            return verifierRow is not None

    @exceptionWrap
    async def getVerifierCount(self, accountId: Optional[str] = None, description: Optional[str] = None) -> int:
        """
        Get verifier count.

        Args:
            accountId: account id
            description: substring of the verifier description
        Returns:
            number of verifiers matching the filters
        """
        filteredSt = self._genGetVerifiersQuery(accountId=accountId, description=description)
        countSt = select([func.count()]).select_from(aliased(filteredSt))
        async with DBContext.adaptor.connection(self.logger) as connection:
            return await connection.scalar(countSt)

    @exceptionWrap
    async def getVerifiers(
        self, accountId: Optional[str] = None, description: Optional[str] = None, page: int = 1, pageSize: int = 100,
    ) -> List[dict]:
        """
        Get verifiers by filters, newest first.

        Args:
            accountId: account id
            description: substring of the verifier description
            page: 1-based page number
            pageSize: page size
        Returns:
            list of verifiers
        """
        selectSt = (
            self._genGetVerifiersQuery(accountId=accountId, description=description)
            .order_by(models.Verifier.create_time.desc())
            .offset((page - 1) * pageSize)
            .limit(pageSize)
        )
        async with DBContext.adaptor.connection(self.logger) as connection:
            verifierRows = await connection.fetchall(selectSt)
            return [self.loadVerifierFromRow(verifierRow) for verifierRow in verifierRows]

    @exceptionWrap
    async def createHandler(
        self, policies: dict, accountId: str, description: str = "", isDynamic: bool = False
    ) -> str:
        """
        Create new handler.

        Args:
            policies: handler policies (stored json-serialized; ignored for dynamic handlers)
            accountId: account id
            description: user handler description
            isDynamic: dynamic handler flag
        Returns:
            id (uuid4 string) of the created handler
        """
        # dynamic handlers are stored without policies
        rulesStr = None if isDynamic else ujson.dumps(policies)
        async with DBContext.adaptor.connection(self.logger) as connection:
            handlerId = str(uuid4())
            insertSt = insert(models.Handler).values(
                description=description,
                account_id=accountId,
                policies=rulesStr,
                handler_id=handlerId,
                is_dynamic=isDynamic,
            )
            await connection.execute(insertSt)
            return handlerId

    @exceptionWrap
    async def getHandler(
        self,
        handlerId: str,
        targets: Iterable[HandlerTarget] = tuple(HandlerTarget),
        accountId: Optional[str] = None,
    ) -> dict:
        """
        Get handler by id.

        Args:
            handlerId: handler id
            targets: handler targets (columns) to load; all by default
            accountId: handler account id; the handler is looked up in any account if None
        Returns:
            deserialized dict with handler (only the requested targets)
        Raises:
            VLException(Error.HandlerNotFound.format(handlerId), 404, isCriticalError=False): if handler not found
        """
        selectSt = self._genGetHandlersQuery(targets=targets, handlerIds=[handlerId], accountId=accountId)
        async with DBContext.adaptor.connection(self.logger) as connection:
            handlerRow = await connection.fetchone(selectSt)
            if handlerRow is None:
                raise VLException(Error.HandlerNotFound.format(handlerId), 404, isCriticalError=False)
            return self.loadHandlerFromRow(handlerRow, selectSt.columns)

    @exceptionWrap
    async def getAbsentHandlersIds(self, handlerIds: List[str]) -> List[str]:
        """
        Given list of handlerIds, return those handlerIds that are absent from the database.

        Used for removal of items from cache by cache invalidator.

        Args:
            handlerIds: handler ids
        Returns:
            handler ids (unordered, deduplicated) which were not found in the database
        """
        if not handlerIds:
            # nothing to check; avoid a db round-trip with an empty IN clause
            return []
        selectSt = self._genGetHandlersQuery(targets=[HandlerTarget.handlerId], handlerIds=handlerIds)
        async with DBContext.adaptor.connection(self.logger) as connection:
            result = await connection.fetchall(selectSt)
            existingIds = {row["handler_id"] for row in result}
            return list(set(handlerIds) - existingIds)

    @exceptionWrap
    async def getUpdatedHandlers(
        self, handlerIds: List[str], lastUpdateTimeGte: datetime, targets: Optional[list[HandlerTarget]] = None
    ) -> List[Dict]:
        """
        Get updated handlers by ids. Used for cache invalidation.

        Args:
            handlerIds: handler ids
            lastUpdateTimeGte: lower bound (inclusive) of handler update time
            targets: handler targets (columns) to load; all targets if None
        Returns:
            list with handlers updated at or after `lastUpdateTimeGte`
        """
        if targets is None:
            # `_genGetHandlersQuery` iterates over targets, so None must be expanded to "all targets"
            targets = tuple(HandlerTarget)
        selectSt = self._genGetHandlersQuery(
            handlerIds=handlerIds, lastUpdateTimeGte=lastUpdateTimeGte, targets=targets
        )
        async with DBContext.adaptor.connection(self.logger) as connection:
            handlerRows = await connection.fetchall(selectSt)
            return [self.loadHandlerFromRow(handlerRow, selectSt.columns) for handlerRow in handlerRows]

    @exceptionWrap
    async def putHandler(
        self, handlerId: str, policies: dict, accountId: str, description: str = "", isDynamic: bool = False
    ) -> bool:
        """
        Replace an existing handler and refresh its last update time.

        Args:
            handlerId: handler id
            policies: handler policies (stored json-serialized; ignored for dynamic handlers)
            accountId: account id
            description: user handler description
            isDynamic: dynamic handler flag
        Warnings:
            function does not create handler!
        Returns:
            True if handler exists (and was replaced), otherwise False
        """
        # dynamic handlers are stored without policies
        rulesStr = None if isDynamic else ujson.dumps(policies)
        async with DBContext.adaptor.connection(self.logger) as connection:
            updateSt = (
                update(models.Handler)
                .where(and_(models.Handler.handler_id == handlerId, models.Handler.account_id == accountId))
                .values(
                    description=description,
                    policies=rulesStr,
                    is_dynamic=isDynamic,
                    last_update_time=self.currentDBTimestamp,
                )
            )
            return bool(await connection.execute(updateSt))

    @exceptionWrap
    async def deleteHandler(self, handlerId: str, accountId: Optional[str] = None) -> bool:
        """
        Delete handler by id.

        Args:
            handlerId: handler id
            accountId: account id of the handler; the handler is deleted from any account if None
        Returns:
            True if handler existed (and was deleted), otherwise False
        """
        async with DBContext.adaptor.connection(self.logger) as connection:
            deleteSt = delete(models.Handler).where(
                and_(
                    models.Handler.handler_id == handlerId,
                    models.Handler.account_id == accountId if accountId is not None else True,
                )
            )
            return bool(await connection.execute(deleteSt))

    @exceptionWrap
    async def doesHandlerExist(self, handlerId: str, accountId: Optional[str] = None) -> bool:
        """
        Check existence of a handler with id=handlerId (optionally within an account).

        Args:
            handlerId: handler id
            accountId: handler account id; the handler is looked up in any account if None
        Returns:
            True if handler exists, otherwise False
        """
        async with DBContext.adaptor.connection(self.logger) as connection:
            query = select([models.Handler.handler_id]).where(
                and_(
                    models.Handler.handler_id == handlerId,
                    models.Handler.account_id == accountId if accountId is not None else True,
                )
            )
            handlerRow = await connection.fetchone(query)
            return handlerRow is not None

    @staticmethod
    def _genGetHandlersQuery(
        targets: Iterable[HandlerTarget] = tuple(HandlerTarget),
        accountId: Optional[str] = None,
        description: Optional[str] = None,
        isDynamic: Optional[bool] = None,
        handlerIds: Optional[list[str]] = None,
        lastUpdateTimeGte: Optional[datetime] = None,
    ) -> select:
        """
        Generate a "select" statement for handlers; every filter left as None is not applied.

        Args:
            targets: handler targets (columns) to select; must not be None
            accountId: handler account id
            description: substring of the handler description
            isDynamic: whether to get only dynamic (non-dynamic) handlers
            handlerIds: handlers ids
            lastUpdateTimeGte: lower bound (inclusive) of handler update time
        Returns:
            generated "select" statement
        """
        # NOTE(review): `description` is embedded into a LIKE pattern without escaping, so `%` and `_`
        # in it act as wildcards. Confirm this is intended for user-supplied input.
        selectSt = select([getattr(models.Handler, target.value) for target in targets]).where(
            and_(
                models.Handler.handler_id.in_(handlerIds) if handlerIds is not None else True,
                models.Handler.description.like("%{}%".format(description)) if description is not None else True,
                models.Handler.account_id == accountId if accountId is not None else True,
                models.Handler.is_dynamic == isDynamic if isDynamic is not None else True,
                models.Handler.last_update_time >= lastUpdateTimeGte if lastUpdateTimeGte is not None else True,
            )
        )
        return selectSt

    @exceptionWrap
    async def getHandlers(
        self,
        accountId: Optional[str] = None,
        description: Optional[str] = None,
        isDynamic: Optional[bool] = None,
        page: int = 1,
        pageSize: int = 100,
    ) -> List[dict]:
        """
        Get handlers by filters, newest first.

        Args:
            accountId: handler account id
            description: substring of the handler description
            isDynamic: whether to get only dynamic (non-dynamic) handlers
            page: 1-based page number
            pageSize: page size
        Returns:
            list of deserialized handlers
        """
        selectSt = self._genGetHandlersQuery(accountId=accountId, description=description, isDynamic=isDynamic)
        selectSt = selectSt.order_by(models.Handler.create_time.desc()).offset((page - 1) * pageSize).limit(pageSize)
        async with DBContext.adaptor.connection(self.logger) as connection:
            handlerRows = await connection.fetchall(selectSt)
            return [self.loadHandlerFromRow(handlerRow, selectSt.columns) for handlerRow in handlerRows]

    @exceptionWrap
    async def getHandlerCount(
        self, accountId: Optional[str] = None, description: Optional[str] = None, isDynamic: Optional[bool] = None
    ) -> int:
        """
        Get handler count.

        Args:
            accountId: handler account id
            description: substring of the handler description
            isDynamic: whether to count only dynamic (non-dynamic) handlers
        Returns:
            number of handlers matching the filters
        """
        filteredSt = self._genGetHandlersQuery(accountId=accountId, description=description, isDynamic=isDynamic)
        countSt = select([func.count()]).select_from(aliased(filteredSt))
        async with DBContext.adaptor.connection(self.logger) as connection:
            # same approach as getVerifierCount: a COUNT query always yields exactly one scalar
            return await connection.scalar(countSt)