"""Cross-matching handler."""
import os
from abc import ABC, abstractmethod
from typing import Union
import psutil
from sanic.compat import CancelledErrors
from sanic.response import HTTPResponse
from vlutils.descriptors.data import DescriptorType
from app_common.handlers.base_handler import CommonBaseHandler
from app_common.handlers.schemas import (
    BodyCrossMatch,
    CrossMatchAttributesFilters,
    CrossMatchEventsFilters,
    CrossMatchFacesFilters,
    FaceCrossMatch,
)
from classes.cross_match import CrossMatcher
from classes.cross_match_helpers import packCrossMatchResult
from classes.enums import MatchSource
from classes.filters import AttributeFilters, EventFilters, FaceFilters
from classes.match_reply_helpers import getPreparedFilters
from configs.config_common import DEFAULT_CROSS_MATCH_LIMIT
from crutches_on_wheels.cow.errors.errors import Error
from crutches_on_wheels.cow.errors.exception import VLException
from crutches_on_wheels.cow.web.query_getters import uuidGetter, validateIntAsBool
# Lookup table: value of the `origin` field of a candidates/references block -> filters class used to parse it
FILTERS_MAP = dict(
    (matchSource.value, filtersClass)
    for matchSource, filtersClass in (
        (MatchSource.faces, FaceFilters),
        (MatchSource.events, EventFilters),
        (MatchSource.attributes, AttributeFilters),
    )
)
class CrossMatcherBaseHandler(CommonBaseHandler, ABC):
    """
    Base cross-matching handler.

    Subclasses provide the descriptor type and version. This class validates the request json,
    builds candidate/reference filters, runs `CrossMatcher` and streams the packed result as msgpack.
    """

    @property
    @abstractmethod
    def descriptorType(self) -> DescriptorType:
        """
        The tool descriptor type. Must be overridden.
        """

    @property
    def configLunaEventsUsage(self) -> bool:
        """
        Get luna-events service usages
        Returns:
            `True` if luna-events service is used
        """
        return self.config.additionalServicesUsage.lunaEvents

    @property
    @abstractmethod
    def descriptorVersion(self) -> int:
        """
        The tool descriptor version. Must be overridden.
        """

    def _logMemoryUsage(self) -> None:
        """Log the resident memory (RSS) of the current process, in GiB, at debug level."""
        self.logger.debug(f"Memory usages: {round(psutil.Process(os.getpid()).memory_info().rss / 2**30, 6)} Gb")

    @staticmethod
    def _getFailedFiltersKey(validationException: VLException) -> Union[str, None]:
        """
        Find which filters part ("candidates" or "references") the validation error points to.
        Args:
            validationException: exception raised by the main schema validation
        Returns:
            "candidates", "references" or None if the error detail has no path into either part
        """
        detail = validationException.error.detail or ""
        if "Path:" not in detail:
            # no path info -> cannot refine the error; the caller re-raises it unchanged
            return None
        failedPath = detail.split("Path:")[1].split(",")[0]
        if "candidates" in failedPath:
            return "candidates"
        if "references" in failedPath:
            return "references"
        return None

    def _validateFiltersPart(self, inputJson: dict, errorKey: str) -> None:
        """
        Validate `inputJson[errorKey]` against the schema matching its `origin`.
        Returns normally if that part validates (or cannot be validated separately); the caller then
        re-raises the main validation error.
        Args:
            inputJson: input json from request
            errorKey: "candidates" or "references"
        Raises:
            VLException(Error.BadInputJson, ...) with the detail path prefixed by `errorKey`
        """
        filtersPart = inputJson[errorKey]
        if not isinstance(filtersPart, dict):
            return
        origin = filtersPart.get("origin")
        if origin == "faces":
            partSchema = CrossMatchFacesFilters.schema
        elif origin == "events":
            partSchema = CrossMatchEventsFilters.schema
        elif origin == "attributes":
            partSchema = CrossMatchAttributesFilters.schema
        else:
            return
        try:
            self.validateJson(filtersPart, partSchema, useJsonSchema=False)
        except VLException as deepException:
            # non-BadInputJson errors are dropped here; the caller re-raises the main error instead
            if deepException.error == Error.BadInputJson:
                oldDetail = deepException.error.detail or ""
                if "Path: " in oldDetail:
                    # the path in the sub-schema error is relative to the part; make it relative to the request root
                    pathStartIndex = oldDetail.index("Path: ") + len("Path: '")
                    deepException.error.detail = oldDetail[:pathStartIndex] + f"{errorKey}." + oldDetail[pathStartIndex:]
                raise deepException

    def validateCrossMatchJson(self, inputJson: dict) -> None:
        """
        Validate input json. If the error is inside candidates/references, validate that part separately
        and raise an error with a more precise detail.
        Args:
            inputJson: inputJson from request
        Raises:
            VLException(Error.BadInputJson, 400, False) if fail to validate input json
        """
        mainSchema = FaceCrossMatch.schema if self.descriptorType == DescriptorType.face else BodyCrossMatch.schema
        try:
            self.validateJson(inputJson, mainSchema, useJsonSchema=False)
        except VLException as mainValidationException:
            errorKey = self._getFailedFiltersKey(mainValidationException)
            if errorKey is not None:
                self._validateFiltersPart(inputJson, errorKey)
            raise mainValidationException

    def getFiltersFromJson(
        self, inputJson: dict, objectType: str
    ) -> Union[FaceFilters, EventFilters, AttributeFilters]:
        """
        Build filters for one part of the request, checking account ids and luna-events service usage.
        Args:
            inputJson: input json
            objectType: "candidates" or "references"
        Returns:
            filters object for `inputJson[objectType]`; its class is chosen by the part's `origin`
        Raises:
            VLException(Error.LunaEventsIsDisabled, 403, False) if events origin is used while luna-events is disabled
            VLException(Error.DifferentAccounts, 400, False) if the query and json account ids differ
        """
        accountId = self.getQueryParam("account_id", uuidGetter, default=None)
        filters = FILTERS_MAP[inputJson[objectType]["origin"]]().initFromRequest(inputJson[objectType])
        if not self.configLunaEventsUsage and filters.origin == MatchSource.events.value:
            raise VLException(Error.LunaEventsIsDisabled, 403, isCriticalError=False)
        if accountId is not None:
            if filters.accountId is not None and accountId != filters.accountId:
                raise VLException(Error.DifferentAccounts.format(objectType), 400, False)
            # the query account id takes precedence (it equals the json one if both are set)
            filters.accountId = accountId
        return filters

    async def getPreparedFiltersAndCheckListsExistence(
        self, filters: Union[FaceFilters, EventFilters, AttributeFilters]
    ) -> dict:
        """
        Get filters prepared for reply, checking that a list referenced by face filters exists.
        Args:
            filters: filters
        Returns:
            dict with filters prepared for reply
        Raises:
            VLException(Error.ListNotFound, 400, False) if face filters reference a missing list
        """
        if isinstance(filters, FaceFilters) and filters.listId is not None:
            if not await self.facesDBContext.getList(listId=filters.listId, accountId=filters.accountId):
                raise VLException(Error.ListNotFound.format(filters.listId), 400, False)
        return getPreparedFilters(filters)

    async def post(self) -> HTTPResponse:
        """
        Search for faces/events by given filters and matching them with each other.
        To work with face descriptors, see `crossmatch_faces`_ or `crossmatch_bodies`_ for body descriptors.
        .. _crossmatch_faces:
            _static/api.html#operation/faceCrossMatching
        .. _crossmatch_bodies:
            _static/api.html#operation/bodyCrossMatching
        Returns:
            response with cross matching results
        """
        self._logMemoryUsage()
        inputJson: dict = self.request.json
        self.validateCrossMatchJson(inputJson)
        candidateFilters = self.getFiltersFromJson(inputJson=inputJson, objectType="candidates")
        referenceFilters = self.getFiltersFromJson(inputJson=inputJson, objectType="references")
        # done before matching so a missing list is reported before any matching work starts
        preparedCandidateFilters = await self.getPreparedFiltersAndCheckListsExistence(candidateFilters)
        preparedReferenceFilters = await self.getPreparedFiltersAndCheckListsExistence(referenceFilters)
        limit = inputJson.get("limit", DEFAULT_CROSS_MATCH_LIMIT)
        threshold = inputJson.get("threshold", 0.0)
        sortOn = self.getQueryParam("sorting", validateIntAsBool, default=True)
        crossMatchProcessor = CrossMatcher(
            facesDBContext=self.facesDBContext,
            eventsDBContext=self.eventsDBContext,
            attributesDBContext=self.attributesDBContext,
            candidateFilters=candidateFilters,
            referenceFilters=referenceFilters,
            limit=limit,
            threshold=threshold,
            sortOn=sortOn,
            logger=self.logger,
            descriptorType=self.descriptorType,
            descriptorVersion=self.descriptorVersion,
        )
        successMatchResult, errorMatchResult = await crossMatchProcessor.process()
        preparedRequestFilters = dict(
            candidates=preparedCandidateFilters, references=preparedReferenceFilters, limit=limit, threshold=threshold
        )
        # NOTE(review): presumably forces an uncompressed streamed body (no compression middleware) — confirm
        self.request.headers["Accept-Encoding"] = "identity"
        try:
            async with await self.request.respond(status=200, content_type="application/msgpack") as sender:
                for chunk in packCrossMatchResult(preparedRequestFilters, successMatchResult, errorMatchResult):
                    await sender(chunk)
        except CancelledErrors:
            self.logger.info("Response streaming cancelled")
        except Exception as e:
            # streaming may already have started, so the failure is only logged, not turned into an error reply
            self.logger.error(f"Response streaming failed: {e}")
        else:
            self.logger.info("Response streaming completed")
        self._logMemoryUsage()
class FaceCrossMatcherHandler(CrossMatcherBaseHandler):
    """
    Cross-matching handler working with face descriptors.
    Resource: "/{api_version}/crossmatcher/faces"
    """

    descriptorType: DescriptorType = DescriptorType.face

    @property
    def descriptorVersion(self) -> int:
        """Face descriptor version, taken from `config.defaultFaceDescriptorVersion`."""
        faceDescriptorVersion = self.config.defaultFaceDescriptorVersion
        return faceDescriptorVersion
class BodyCrossMatcherHandler(CrossMatcherBaseHandler):
    """
    Cross-matching handler working with body (human) descriptors.
    Resource: "/{api_version}/crossmatcher/bodies"
    """

    descriptorType: DescriptorType = DescriptorType.body

    @property
    def descriptorVersion(self) -> int:
        """Human descriptor version, taken from `config.defaultHumanDescriptorVersion`."""
        humanDescriptorVersion = self.config.defaultHumanDescriptorVersion
        return humanDescriptorVersion
class UnwantedCrossMatcherHandler(FaceCrossMatcherHandler):
    """
    Deprecated handler for face cross-matching.
    Resource: "/{api_version}/crossmatcher" alias for "/{api_version}/crossmatcher/faces"
    """

    async def post(self) -> HTTPResponse:
        """
        Log a deprecation warning, then perform face cross-matching exactly like the parent handler.
        """
        self.logger.warning(
            f"Resource `/{self.app.ctx.apiVersion}/crossmatcher` is deprecated. "
            f"Use `/{self.app.ctx.apiVersion}/crossmatcher/faces` instead."
        )
        return await super().post()