"""
Module contains schemas for verifier
"""
import asyncio
from functools import cached_property
from typing import Optional, Union
from uuid import UUID
from lunavl.sdk.estimators.face_estimators.livenessv1 import LivenessPrediction
from lunavl.sdk.estimators.face_estimators.mask import MaskState
from pydantic import Field
from vlutils.structures.dataclasses import dataclass
from classes.image_meta import ProcessedImageData
from classes.raw_descriptor_data import RawDescriptorData
from luna3.client import Client
from app.api_sdk_adaptors.base import (
    createFaceDetectionJson,
    HandlerEstimations,
    LoopEstimationsAlwaysOn,
    buildImageFilteredDetections,
    executeSDKTask,
)
from app.api_sdk_adaptors.handler import APISDKHandlerAdaptor
from app.global_vars.constants import MAX_ANGLE
from classes.event import Event, EventMetadata
from classes.monitoring import HandlersMonitoringData
from classes.schemas import types
from classes.schemas.base_schema import BaseSchema
from classes.schemas.detect_policies import BaseDetectPolicy
from classes.schemas.extract_policies import BaseExtractPolicy
from classes.schemas.handler import FaceInputEstimationsModel
from classes.schemas.match_policy import BaseMatchPolicy, EventMatchResult
from classes.schemas.storage_policy import FaceSamplePolicy, AttributeStorePolicy
from crutches_on_wheels.enums.attributes import Liveness
from crutches_on_wheels.errors.errors import Error
from crutches_on_wheels.errors.exception import VLException
from crutches_on_wheels.monitoring.points import monitorTime
from sdk.sdk_loop.enums import LoopEstimations, MultifacePolicy
from sdk.sdk_loop.errors.errors import MultipleFaces
from sdk.sdk_loop.models.image import InputImage
from sdk.sdk_loop.tasks.filters import FaceDetectionFilters, Filters
from sdk.sdk_loop.tasks.task import TaskEstimationParams, LivenessV1Params, TaskParams
# Matching labels (``EventMatchResult.matchingLabel``) that decide which identifier a verification
# result reports: event id, face id, external id or attribute id respectively.
MATCH_BY_EVENT_ID_LABEL: str = "eventIds"
MATCH_BY_FACE_ID_LABEL: str = "faceIds"
MATCH_BY_EXTERNAL_ID_LABEL: str = "externalIds"
MATCH_BY_ATTRIBUTE_ID_LABEL: str = "attributeIds"
@dataclass(withSlots=True)
class VerifierConfig:
    """Verifier config that policies should apply."""

    # whether the sdk task should use exif info of input images (passed as TaskParams.useExifInfo)
    useExifInfo: bool
    # whether the sdk task should auto-rotate input images (passed as TaskParams.autoRotation)
    useAutoRotation: bool
    # face descriptor version requested from the sdk estimators
    faceDescriptorVersion: int
    # presumably the storage bucket for face samples — confirm against callers
    facesBucket: str
class VerifierMatchPolicy(BaseMatchPolicy):
    """Verifier match policy schema (all behavior is inherited from BaseMatchPolicy)."""
class VerifierDetectPolicy(BaseDetectPolicy):
    """Verifier detect policy (all behavior is inherited from BaseDetectPolicy)."""
class VerifierAttributeStorePolicy(BaseSchema):
    """Verifier attribute storage policy."""

    # whether to store attribute (0 or 1)
    storeAttribute: types.Int01 = 0

    async def execute(self, events: list[Event], luna3Client: Client) -> None:
        """
        Save attributes by delegating to the generic attribute storage policy.
        Args:
            events: events whose attributes should be saved
            luna3Client: luna3 client
        """
        await AttributeStorePolicy(store_attribute=self.storeAttribute).execute(events, luna3Client)
class VerifierFaceSampleStorePolicy(BaseSchema):
    """Verifier face sample storage policy."""

    # whether to store face sample (0 or 1)
    storeSample: types.Int01 = 0

    async def execute(
        self, events: list[Event], sources: list[ProcessedImageData], bucket: str, luna3Client: Client
    ) -> None:
        """
        Save face samples by delegating to the generic face sample storage policy.
        Args:
            events: events whose face samples should be saved
            sources: processed image sources the samples originate from
            bucket: bucket name
            luna3Client: luna3 client
        """
        await FaceSamplePolicy(store_sample=self.storeSample).execute(events, sources, bucket, luna3Client)
class VerifierStoragePolicy(BaseSchema):
    """Verifier storage policy: saves face samples and attributes of verification events."""

    # attribute storage policy (None disables it)
    attributePolicy: Optional[VerifierAttributeStorePolicy] = VerifierAttributeStorePolicy()
    # face sample storage policy (None disables it)
    faceSamplePolicy: Optional[VerifierFaceSampleStorePolicy] = VerifierFaceSampleStorePolicy()

    async def execute(
        self, sources: list[ProcessedImageData], events: list[Event], luna3Client: Client, facesBucket: str
    ) -> HandlersMonitoringData:
        """
        Execute storage policy - save objects.
        Face samples are saved first; attributes are saved afterwards because samples are updated
        by the face sample policy.
        Args:
            sources: origin image sources
            events: events
            luna3Client: luna3 client
            facesBucket: faces samples bucket
        Returns:
            monitoring data with the time spent in each executed sub-policy
        """
        monitoringData = HandlersMonitoringData()
        # both sub-policies are declared Optional: skip a policy explicitly set to None instead of
        # failing with AttributeError on `None.execute`
        if self.faceSamplePolicy is not None:
            with monitorTime(monitoringData.request, "face_sample_storage_policy_time"):
                await self.faceSamplePolicy.execute(events, sources, facesBucket, luna3Client)
        # save attributes only after the face sample policy (samples are updated there)
        if self.attributePolicy is not None:
            with monitorTime(monitoringData.request, "face_attribute_storage_policy_time"):
                await self.attributePolicy.execute(events, luna3Client)
        return monitoringData
[docs]class VerifierPoliciesModel(BaseSchema):
    """Verifier policies"""
    # detect policy
    detectPolicy: VerifierDetectPolicy = Field(default_factory=lambda: VerifierDetectPolicy())
    # extract policy
    extractPolicy: VerifierExtractPolicy = VerifierExtractPolicy()
    # storage policy
    storagePolicy: VerifierStoragePolicy = VerifierStoragePolicy()
    # verification threshold
    verificationThreshold: types.Float01 = 0.9
    @cached_property
    def sdkTargets(self) -> set[HandlerEstimations]:
        """
        Prepare sdk task targets
        Returns:
            sdk task targets
        """
        targets = {
            LoopEstimations.faceDetection,
            LoopEstimationsAlwaysOn.faceLandmarks5,
            LoopEstimations.faceDescriptor,
        }
        # if there are filters, estimations should be present in the result - so lets put them into targets
        isHeadFiltersEnabled = (
            self.detectPolicy.yawThreshold is not None
            or self.detectPolicy.pitchThreshold is not None
            or self.detectPolicy.rollThreshold is not None
        )
        isLivenessFiltersEnabled = bool(self.detectPolicy.livenessStates)
        isMaskFiltersEnabled = bool(self.detectPolicy.maskStates)
        if self.detectPolicy.extractExif:
            targets.add(LoopEstimations.exif)
        if self.detectPolicy.detectLandmarks68:
            targets.add(LoopEstimations.faceLandmarks68)
        if self.detectPolicy.estimateQuality:
            targets.add(LoopEstimations.faceWarpQuality)
        if self.detectPolicy.estimateMouthAttributes:
            targets.add(LoopEstimations.mouthAttributes)
        if self.detectPolicy.estimateGaze:
            targets.add(LoopEstimations.gaze)
        if self.detectPolicy.estimateEyesAttributes:
            targets.add(LoopEstimations.eyes)
        if self.detectPolicy.estimateEmotions:
            targets.add(LoopEstimations.emotions)
        if self.detectPolicy.estimateMask or isMaskFiltersEnabled:
            targets.add(LoopEstimations.mask)
        if self.detectPolicy.estimateHeadPose or isHeadFiltersEnabled:
            targets.add(LoopEstimations.headPose)
        if self.detectPolicy.estimateLiveness.estimate or isLivenessFiltersEnabled:
            targets.add(LoopEstimations.livenessV1)
        if self.extractPolicy.extractBasicAttributes:
            targets.add(LoopEstimations.basicAttributes)
        return targets
    @cached_property
    def sdkFilters(self) -> Filters:
        """
        Prepare sdk task filters
        Returns:
            sdk task filters
        """
        def suitFilter(x):
            """ Return useful thresholds. """
            if x != MAX_ANGLE:
                return x
            return None
        maskStates = None
        if self.detectPolicy.maskStates:
            maskStates = [MaskState(x) for x in self.detectPolicy.maskStates]
        livenessStates = None
        if self.detectPolicy.livenessStates is not None:
            livenessStates = [LivenessPrediction(Liveness(x).name) for x in self.detectPolicy.livenessStates]
        faceFilters = FaceDetectionFilters(
            yawThreshold=suitFilter(self.detectPolicy.yawThreshold),
            pitchThreshold=suitFilter(self.detectPolicy.pitchThreshold),
            rollThreshold=suitFilter(self.detectPolicy.rollThreshold),
            gcThreshold=self.extractPolicy.fdScoreThreshold or None,
            maskStates=maskStates,
            livenessStates=livenessStates,
        )
        filters = Filters(faceDetection=faceFilters)
        return filters
[docs]    def prepareSDKTaskParams(self, config: VerifierConfig):
        """
        Prepare sdk task parameters
        Returns:
            sdk task parameters
        """
        return TaskParams(
            targets=self.sdkTargets | self.detectPolicy.faceQualityTargets,
            filters=self.sdkFilters,
            autoRotation=config.useAutoRotation,
            estimatorsParams=TaskEstimationParams(
                faceDescriptorVersion=config.faceDescriptorVersion,
                livenessv1=LivenessV1Params(
                    scoreThreshold=self.detectPolicy.estimateLiveness.livenessThreshold,
                    qualityThreshold=self.detectPolicy.estimateLiveness.qualityThreshold,
                ),
            ),
            multifacePolicy=MultifacePolicy(self.detectPolicy.multifacePolicy),
            useExifInfo=config.useExifInfo,
            aggregate=False,
        ) 
[docs]    async def execute(
        self,
        inputData: list[Union[RawDescriptorData, InputImage]],
        config: VerifierConfig,
        matchPolicies: list[VerifierMatchPolicy],
        accountId: str,
        luna3Client: Client,
        facesBucket: str,
    ) -> tuple[dict, HandlersMonitoringData]:
        """
        Executes given policies against provided data.
        Args:
            inputData: input data (images / raw descriptors)
            config: handler configuration parameters
            matchPolicies: MatchingPolicy instances
            accountId: A str, account id
            luna3Client: A luna3 client instance
            facesBucket: faces samples bucket
        Returns:
            * estimations in api format
            * monitoring data
        """
        processedSources, aggregatedSample, monitoringData = await executeSDKTask(
            params=self.prepareSDKTaskParams(config),
            inputData=inputData,
            useAutoRotation=config.useAutoRotation,
            sdkTargets=self.sdkTargets,
        )
        result = {"images": []}
        for source in processedSources:
            if isinstance(source, RawDescriptorData):
                sourceRes = {
                    "filename": source.filename,
                    "status": int(source.error == Error.Success),
                    "error": source.error.asDict(),
                    "detections": {"face_detections": {}, "filtered_detections": {"face_detections": []}},
                }
            else:
                image = source.image
                if isinstance(image.error, MultipleFaces):
                    raise VLException(image.error.error, 400, isCriticalError=False)
                if self.detectPolicy.isFaceQualityChecksEnabled():
                    self.detectPolicy.faceQuality.processSource(source)
                    monitoringData.sdkUsages.faceQualityEstimator = monitoringData.sdkUsages.faceDetector
                sourceRes = {
                    "filename": image.origin.filename,
                    "status": int(not image.error),
                    "error": (source.meta.error or image.error or Error.Success).asDict(),
                    "detections": {"face_detections": {}, "filtered_detections": {"face_detections": []}},
                }
                if LoopEstimations.exif in self.sdkTargets and image.exif is not None:
                    sourceRes["exif"] = image.exif
                sourceRes["detections"]["filtered_detections"]["face_detections"] = buildImageFilteredDetections(
                    image=image, estimationTargets=self.sdkTargets
                )
            eventMeta = EventMetadata(accountId=accountId, handlerId=None, createEventTime=None, endEventTime=None)
            events = APISDKHandlerAdaptor.createEvents([source], aggregatedSample, eventMeta, self.sdkTargets)
            # matching
            with monitorTime(monitoringData.request, "match_policy_time"):
                await asyncio.gather(
                    *[matchByListPolicy.execute(events, luna3Client) for matchByListPolicy in matchPolicies]
                )
            # storage
            imageSources = [source for source in processedSources if isinstance(source, ProcessedImageData)]
            await self.storagePolicy.execute(
                sources=imageSources, events=events, facesBucket=facesBucket, luna3Client=luna3Client,
            )
            sourceRes["detections"]["face_detections"] = self._buildFaceDetections(events)
            result["images"].append(sourceRes)
        return result, monitoringData 
    def _buildFaceDetections(self, events: list[Event]) -> list[dict]:
        """
        Build face detections from events.
        Args:
            events: events
        Returns:
            face detections in api format
        """
        faceDetections = []
        for event in events:
            faceDetection = {
                "sample": None,
                "face_attributes": event.faceAttributes.asDict() if event.faceAttributes else None,
                "verifications": self.buildVerifications(event.matches),
            }
            if event.sdkEstimations:
                # There should be only one detection since aggregation is disabled.
                sample = event.sdkEstimations[0].detection.face
                faceDetection["sample"] = {
                    "face": {
                        "sample_id": sample.sampleId,
                        "url": sample.url,
                        **createFaceDetectionJson(sample.sdkEstimation, self.sdkTargets),
                    }
                }
            faceDetections.append(faceDetection)
        return faceDetections
[docs]    def buildVerifications(self, matches: list[EventMatchResult]) -> list[dict]:
        """
        Build verification result in api format.
        Args:
            matches: matches result
        Returns:
            verification result in api format
        """
        if not matches:
            return []
        verifications = []
        for match in matches:
            label = match.matchingLabel
            for candidate in match.candidates:
                similarity = candidate["similarity"]
                verification = {"similarity": similarity, "status": similarity >= self.verificationThreshold}
                if label == MATCH_BY_FACE_ID_LABEL:
                    verification["face"] = {"face_id": candidate["face"]["face_id"]}
                elif label == MATCH_BY_EVENT_ID_LABEL:
                    verification["event"] = {"event_id": candidate["event"]["event_id"]}
                elif label == MATCH_BY_EXTERNAL_ID_LABEL:
                    verification["face"] = {"external_id": candidate["face"]["external_id"]}
                elif label == MATCH_BY_ATTRIBUTE_ID_LABEL:
                    verification["attribute"] = {"attribute_id": candidate["attribute"]["attribute_id"]}
                verifications.append(verification)
        return verifications  
class VerifierModel(BaseSchema):
    """Verifier"""

    # verifier description
    description: types.Str128 = ""
    # verifier policies (a fresh default instance per model)
    policies: VerifierPoliciesModel = Field(default_factory=VerifierPoliciesModel)
    # verifier account id
    accountId: UUID