"""
Module contains schemas for policies
"""
import asyncio
from functools import cached_property
from typing import Any, Type, TypeVar, Union
from uuid import UUID
from luna3.client import Client
from luna3.common.exceptions import LunaApiException
from lunavl.sdk.estimators.face_estimators.livenessv1 import LivenessPrediction
from lunavl.sdk.estimators.face_estimators.mask import MaskState
from pydantic import Field, ValidationError, root_validator
from pydantic.error_wrappers import ErrorWrapper
from vlutils.jobs.async_runner import AsyncRunner
from vlutils.structures.dataclasses import dataclass
from app.api_sdk_adaptors.base import (
    HandlerEstimations,
    LoopEstimationsAlwaysOn,
    buildImageFilteredDetections,
    executeSDKTask,
)
from app.api_sdk_adaptors.handler import APISDKHandlerAdaptor
from app.global_vars.constants import MAX_ANGLE
from app.global_vars.context_vars import requestIdCtx
from classes.event import Event, EventMetadata
from classes.image_meta import ProcessedImageData
from classes.monitoring import HandlersMonitoringData
from classes.raw_descriptor_data import RawDescriptorData
from classes.schemas.base_schema import BaseSchema
from classes.schemas.conditional_tags_policy import ConditionalTagsPolicy
from classes.schemas.detect_policies import HandlerDetectPolicy
from classes.schemas.extract_policies import HandlerExtractPolicy
from classes.schemas.filters import AttributesFilters, MatchFilter
from classes.schemas.match_policy import MatchPolicy
from classes.schemas.storage_policy import StoragePolicy, StorePolicyConfig
from classes.schemas.types import MAX_POLICY_LIST_LENGTH
from configs.config import PLUGINS_PUBLISHING_CONCURRENCY
from crutches_on_wheels.enums.attributes import Liveness
from crutches_on_wheels.errors.errors import Error
from crutches_on_wheels.errors.exception import VLException
from crutches_on_wheels.monitoring.points import monitorTime
from crutches_on_wheels.plugins.manager import PluginManager
from redis_db.redis_context import RedisContext
from sdk.sdk_loop.enums import LoopEstimations, MultifacePolicy
from sdk.sdk_loop.errors.errors import MultipleFaces
from sdk.sdk_loop.models.image import InputImage
from sdk.sdk_loop.tasks.filters import FaceDetectionFilters, Filters
from sdk.sdk_loop.tasks.task import LivenessV1Params, TaskEstimationParams, TaskParams
# type of the objects collected by `getObjectRecursively`
T = TypeVar("T")


def getObjectRecursively(data: Any, expectedType: Type[T]) -> list[T]:
    """
    Recursively get objects of the expected type.

    Walks ``data`` depth-first: lists are traversed element by element and ``BaseSchema`` instances are traversed
    through the values of their ``__dict__``. A matching object is collected first and, if it is itself a list or
    a schema, traversed further as well (so nested matches are collected too).

    Args:
        data: arbitrary data (plain object, list or schema) to search in
        expectedType: type of the objects to collect
    Returns:
        found objects in depth-first pre-order
    """
    found: list[T] = []

    def collectObjects(dataPart: Any) -> None:
        """Append `dataPart` to `found` if it matches, then descend into lists / schemas."""
        if isinstance(dataPart, expectedType):
            found.append(dataPart)
        if isinstance(dataPart, list):
            for row in dataPart:
                collectObjects(row)
        if isinstance(dataPart, BaseSchema):
            for value in dataPart.__dict__.values():
                collectObjects(value)

    collectObjects(data)
    return found
@dataclass(withSlots=True)
class HandlerConfig(StorePolicyConfig):
    """
    Handler config that policies should apply.

    Besides the storage related settings inherited from ``StorePolicyConfig``, it carries the parameters that are
    forwarded to the sdk task built for the handler.
    """

    # forwarded to the sdk task parameters as `aggregate`
    aggregate: bool
    # forwarded to the sdk task parameters as `useExifInfo`
    useExifInfo: bool
    # forwarded to the sdk task parameters as `autoRotation` and to `executeSDKTask` as `useAutoRotation`
    useAutoRotation: bool
    # face descriptor version passed to the sdk task estimation parameters
    faceDescriptorVersion: int
    # body descriptor version passed to the sdk task estimation parameters
    bodyDescriptorVersion: int
class Policies(BaseSchema):
    """
    Policies schema.

    Groups the handler policies (detect, extract, match, conditional tags and storage), validates their mutual
    compatibility when the schema is validated and executes them for a batch of input data.
    """

    # detect policy
    detectPolicy: HandlerDetectPolicy = Field(default_factory=lambda: HandlerDetectPolicy())
    # extract policy
    extractPolicy: HandlerExtractPolicy = HandlerExtractPolicy()
    # matching policy list
    matchPolicy: list[MatchPolicy] = Field([], max_items=MAX_POLICY_LIST_LENGTH)
    # conditional tags policy list
    conditionalTagsPolicy: list[ConditionalTagsPolicy] = Field([], max_items=MAX_POLICY_LIST_LENGTH)
    # storage policy
    storagePolicy: StoragePolicy = Field(default_factory=lambda: StoragePolicy())

    @staticmethod
    def validateMatchAndExtractCompatibility(matchPolicies: list[MatchPolicy], extractPolicy: HandlerExtractPolicy):
        """
        Validate match and extract policies compatibility.

        Args:
            matchPolicies: match policies
            extractPolicy: extract policy
        Raises:
            ValueError: if at least one match policy is set while face descriptor extraction is disabled
        """
        if matchPolicies and not extractPolicy.extractFaceDescriptor:
            raise ValueError("extract_face_descriptor should be equal to 1 for using matching policy")

    @staticmethod
    def validateMatchLabelsCompatibility(
        matchPolicies: list[MatchPolicy],
        conditionalTagsPolicies: list[ConditionalTagsPolicy],
        storagePolicy: StoragePolicy,
    ):
        """
        Validate matching label compatibility.

        Every match filter found (at any nesting depth) in the conditional tags and storage policies must refer to
        a label declared by one of the match policies.

        Args:
            matchPolicies: match policies
            conditionalTagsPolicies: conditional tags policies
            storagePolicy: storage policy
        Raises:
            ValueError: if a match filter label is not declared by any match policy
        """
        matchPolicyMatchingLabels = {matchPolicy.label for matchPolicy in matchPolicies}
        matchFiltersMatchingLabels = {
            matchFilter.label
            for matchFilter in getObjectRecursively([conditionalTagsPolicies, storagePolicy], MatchFilter)
        }
        for matchFilterLabel in matchFiltersMatchingLabels:
            if matchFilterLabel not in matchPolicyMatchingLabels:
                raise ValueError(
                    f'"{matchFilterLabel}" should be in match policy for filtration based on a matching by this label'
                )

    @staticmethod
    def validateGeneratedAttributesFilters(
        detectPolicy: HandlerDetectPolicy,
        extractPolicy: HandlerExtractPolicy,
        matchPolicies: list[MatchPolicy],
        conditionalTagsPolicies: list[ConditionalTagsPolicy],
        storagePolicy: StoragePolicy,
    ):
        """
        Validate attributes filters and detect/extract policy compatibility.

        Attribute filters found anywhere in the match, conditional tags and storage policies may only use
        attributes that are actually estimated.

        Args:
            detectPolicy: detect policy
            extractPolicy: extract policy
            matchPolicies: match policies
            conditionalTagsPolicies: conditional tags policies
            storagePolicy: storage policy
        Raises:
            ValueError: if a filter uses basic attributes while they are not extracted, or liveness while it is
                not estimated
        """
        attributeFilters: list[AttributesFilters] = getObjectRecursively(
            data=[matchPolicies, conditionalTagsPolicies, storagePolicy], expectedType=AttributesFilters
        )
        if not extractPolicy.extractBasicAttributes:
            basicAttributesNeeded = any(
                any(
                    (
                        filters.ethnicities is not None,
                        filters.ageLt is not None,
                        filters.ageGte is not None,
                        filters.gender is not None,
                    )
                )
                for filters in attributeFilters
            )
            if basicAttributesNeeded:
                raise ValueError(
                    "extract_basic_attributes should be equal to 1 for filtration based on basic attributes"
                )
        if not detectPolicy.estimateLiveness.estimate:
            livenessNeeded = any((filters.liveness is not None) for filters in attributeFilters)
            if livenessNeeded:
                raise ValueError("estimate_liveness.estimate should be equal to 1 for filtration based on liveness")

    @staticmethod
    def validateMatchPolicyUniqueLabels(matchPolicies: list[MatchPolicy]):
        """
        Validate match policy matching label uniqueness.

        Args:
            matchPolicies: match policies
        Raises:
            ValidationError: located at ``match_policy`` if two match policies share a label
        """
        labels = [matchPolicy.label for matchPolicy in matchPolicies]
        if len(labels) != len(set(labels)):
            error = ValueError("Matching allowed only by unique labels")
            # `loc` is a location tuple; note the trailing comma (a bare parenthesized string is not a tuple)
            raise ValidationError([ErrorWrapper(exc=error, loc=("match_policy",))], Policies)

    @staticmethod
    def validateDetectPolicyNotEmpty(detectPolicy: HandlerDetectPolicy):
        """
        Validate non-empty detect policy.

        Args:
            detectPolicy: detect policy
        Raises:
            ValueError: if neither face nor body detection is enabled, or if a non-"allowed" multiface policy is
                set without face detection
        """
        if not detectPolicy.detectFace and not detectPolicy.detectBody:
            raise ValueError("At least one of *detect_face* or *detect_body* should be equal to 1")
        if not detectPolicy.detectFace and MultifacePolicy(detectPolicy.multifacePolicy) is not MultifacePolicy.allowed:
            raise ValueError("*detect_face* should be equal to 1 to set *multiface_policy* to 0 or 2")

    @staticmethod
    def validateDetectAndExtractCompatibility(detectPolicy: HandlerDetectPolicy, extractPolicy: HandlerExtractPolicy):
        """
        Validate detect and extract policies compatibility.

        Args:
            detectPolicy: detect policy
            extractPolicy: extract policy
        Raises:
            ValueError: if a face extraction is enabled without face detection, or a body extraction without body
                detection
        """
        if extractPolicy.extractFaceDescriptor:
            if not detectPolicy.detectFace:
                raise ValueError("*detect_face* should be equal to 1 to enable *extract_face_descriptor*")
        if extractPolicy.extractBasicAttributes:
            if not detectPolicy.detectFace:
                raise ValueError("*detect_face* should be equal to 1 to enable *extract_basic_attributes*")
        if extractPolicy.extractBodyDescriptor:
            if not detectPolicy.detectBody:
                raise ValueError("*detect_body* should be equal to 1 to enable *extract_body_descriptor*")

    @root_validator(skip_on_failure=True)
    def validatePolicies(cls, values):
        """
        Execute all compatibility validators.

        Args:
            values: validated field values
        Returns:
            the unchanged field values
        """
        detectPolicy = values["detectPolicy"]
        matchPolicies = values["matchPolicy"]
        extractPolicy = values["extractPolicy"]
        conditionalTagsPolicies = values["conditionalTagsPolicy"]
        storagePolicy = values["storagePolicy"]
        cls.validateDetectPolicyNotEmpty(detectPolicy=detectPolicy)
        cls.validateDetectAndExtractCompatibility(detectPolicy=detectPolicy, extractPolicy=extractPolicy)
        cls.validateMatchPolicyUniqueLabels(matchPolicies=matchPolicies)
        cls.validateMatchAndExtractCompatibility(matchPolicies=matchPolicies, extractPolicy=extractPolicy)
        cls.validateMatchLabelsCompatibility(
            matchPolicies=matchPolicies, conditionalTagsPolicies=conditionalTagsPolicies, storagePolicy=storagePolicy
        )
        cls.validateGeneratedAttributesFilters(
            detectPolicy=detectPolicy,
            extractPolicy=extractPolicy,
            matchPolicies=matchPolicies,
            conditionalTagsPolicies=conditionalTagsPolicies,
            storagePolicy=storagePolicy,
        )
        return values

    def getListIdsFromPolicies(self) -> list[UUID]:
        """
        Get list ids from matching and link to lists policies.

        Returns:
            list ids of the face "link to lists" policies followed by the non-empty candidate list ids of the match
            policies (duplicates are kept)
        """
        candidateLists = list(
            filter(None, (match.candidates.listId for match in self.matchPolicy if hasattr(match.candidates, "listId")))
        )
        return [linkPolicy.listId for linkPolicy in self.storagePolicy.facePolicy.linkToListsPolicy] + candidateLists

    async def checkListsAvailability(self, luna3Client: Client, accountId: str) -> None:
        """
        Check availability of lists from matching and link to list policies.

        Each distinct list id is checked once, in the order returned by `getListIdsFromPolicies`.

        Args:
            luna3Client: luna platform client
            accountId: account id
        Raises:
            VLException(Error.ListNotFound.format(listId), 400, False), if some list is not found
            LunaApiException: any other error returned by the faces service
        """
        # dict.fromkeys drops repeated ids while preserving order, so a list referenced by several policies is
        # requested only once
        for listId in dict.fromkeys(self.getListIdsFromPolicies()):
            try:
                await luna3Client.lunaFaces.checkList(listId=str(listId), accountId=accountId, raiseError=True)
            except LunaApiException as e:
                if e.statusCode == 404:
                    raise VLException(Error.ListNotFound.format(listId), 400, isCriticalError=False) from e
                raise

    @classmethod
    async def onStartup(cls) -> None:
        """Init Policies: create the class-level runner used to publish events to plugins."""
        cls.pluginsAsyncRunner = AsyncRunner(PLUGINS_PUBLISHING_CONCURRENCY, closeTimeout=1)

    @classmethod
    async def onShutdown(cls) -> None:
        """Stop Policies: close the plugins publishing runner created by `onStartup`."""
        await cls.pluginsAsyncRunner.close()

    @cached_property
    def sdkTargets(self) -> set[HandlerEstimations]:
        """
        Prepare sdk task targets.

        Face related targets are added only when face detection is enabled. Estimations that are needed by the
        detect policy filters (head pose, liveness, mask) are added even if they are not requested explicitly.

        Returns:
            sdk task targets
        """
        targets = set()
        if self.detectPolicy.extractExif:
            targets.add(LoopEstimations.exif)
        if self.detectPolicy.detectFace:
            targets.add(LoopEstimationsAlwaysOn.faceLandmarks5)
            targets.add(LoopEstimations.faceDetection)
            # if there are filters, estimations should be present in the result - so lets put them into targets
            isHeadFiltersEnabled = (
                self.detectPolicy.yawThreshold is not None
                or self.detectPolicy.pitchThreshold is not None
                or self.detectPolicy.rollThreshold is not None
            )
            isLivenessFiltersEnabled = bool(self.detectPolicy.livenessStates)
            isMaskFiltersEnabled = bool(self.detectPolicy.maskStates)
            if self.detectPolicy.detectLandmarks68:
                targets.add(LoopEstimations.faceLandmarks68)
            if self.detectPolicy.estimateQuality:
                targets.add(LoopEstimations.faceWarpQuality)
            if self.detectPolicy.estimateMouthAttributes:
                targets.add(LoopEstimations.mouthAttributes)
            if self.detectPolicy.estimateGaze:
                targets.add(LoopEstimations.gaze)
            if self.detectPolicy.estimateEyesAttributes:
                targets.add(LoopEstimations.eyes)
            if self.detectPolicy.estimateEmotions:
                targets.add(LoopEstimations.emotions)
            if self.detectPolicy.estimateMask or isMaskFiltersEnabled:
                targets.add(LoopEstimations.mask)
            if self.detectPolicy.estimateHeadPose or isHeadFiltersEnabled:
                targets.add(LoopEstimations.headPose)
            if self.detectPolicy.estimateLiveness.estimate or isLivenessFiltersEnabled:
                targets.add(LoopEstimations.livenessV1)
        if self.extractPolicy.extractBasicAttributes:
            targets.add(LoopEstimations.basicAttributes)
        if self.extractPolicy.extractFaceDescriptor:
            targets.add(LoopEstimations.faceDescriptor)
        if self.detectPolicy.detectBody:
            targets.add(LoopEstimations.bodyDetection)
        if self.extractPolicy.extractBodyDescriptor:
            targets.add(LoopEstimations.bodyDescriptor)
        return targets

    @cached_property
    def sdkFilters(self) -> Filters:
        """
        Prepare sdk task filters.

        Face detection filters are built only when face detection is enabled; otherwise default (empty) face
        detection filters are used.

        Returns:
            sdk task filters
        """

        def suitFilter(x):
            """Return the threshold, or None when it equals MAX_ANGLE (such a threshold is not passed to the sdk)."""
            if x != MAX_ANGLE:
                return x
            return None

        maskStates = None
        if self.detectPolicy.maskStates:
            maskStates = [MaskState(x) for x in self.detectPolicy.maskStates]
        livenessStates = None
        # truthiness (not `is not None`), consistently with the mask states above and with `sdkTargets`, which
        # requests liveness estimation for filtration only for a non-empty state list
        if self.detectPolicy.livenessStates:
            livenessStates = [LivenessPrediction(Liveness(x).name) for x in self.detectPolicy.livenessStates]
        faceFilters = (
            FaceDetectionFilters(
                yawThreshold=suitFilter(self.detectPolicy.yawThreshold),
                pitchThreshold=suitFilter(self.detectPolicy.pitchThreshold),
                rollThreshold=suitFilter(self.detectPolicy.rollThreshold),
                livenessStates=livenessStates,
                gcThreshold=self.extractPolicy.fdScoreThreshold or None,  # ignore 0.0
                maskStates=maskStates,
            )
            if self.detectPolicy.detectFace
            else FaceDetectionFilters()
        )
        filters = Filters(faceDetection=faceFilters)
        return filters

    def publishEventsToPlugins(self, events: list[Event], plugins: PluginManager) -> None:
        """
        Publish events to other services.

        The publishing coroutine is handed to `pluginsAsyncRunner.runNoWait` (presumably fire-and-forget, so this
        method does not wait for the plugins), together with the current request id.

        Args:
            events: list of events
            plugins: plugin manager
        """
        futures = [plugins.sendEventToPlugins("sending_event", events, requestIdCtx.get())]
        self.pluginsAsyncRunner.runNoWait(futures)

    def prepareSDKTaskParams(self, config: HandlerConfig):
        """
        Prepare sdk task parameters.

        Args:
            config: handler configuration parameters (descriptor versions, exif / auto rotation / aggregation flags)
        Returns:
            sdk task parameters
        """
        return TaskParams(
            targets=self.sdkTargets | self.detectPolicy.faceQualityTargets,
            filters=self.sdkFilters,
            estimatorsParams=TaskEstimationParams(
                faceDescriptorVersion=config.faceDescriptorVersion,
                bodyDescriptorVersion=config.bodyDescriptorVersion,
                livenessv1=LivenessV1Params(
                    scoreThreshold=self.detectPolicy.estimateLiveness.livenessThreshold,
                    qualityThreshold=self.detectPolicy.estimateLiveness.qualityThreshold,
                ),
            ),
            multifacePolicy=MultifacePolicy(self.detectPolicy.multifacePolicy),
            useExifInfo=config.useExifInfo,
            autoRotation=config.useAutoRotation,
            aggregate=config.aggregate,
        )

    async def execute(
        self,
        inputData: list[Union[RawDescriptorData, InputImage]],
        eventMetadata: EventMetadata,
        config: HandlerConfig,
        luna3Client: Client,
        redisContext: RedisContext,
        plugins: PluginManager,
    ) -> tuple[dict, HandlersMonitoringData]:
        """
        Execute all policies for handler.

        Runs the sdk task, builds per-source results and events, then applies the match, conditional tags and
        storage policies to the events and finally publishes them to plugins.

        Args:
            inputData: input data (images / raw descriptors)
            eventMetadata: user defined event metadata
            config: handler configuration parameters
            luna3Client: luna platform client
            redisContext: redis context
            plugins: plugin manager
        Returns:
            * estimations in api format
            * monitoring data
        Raises:
            VLException: (400) if an image failed with a multiple faces error
        """
        processedSources, aggregatedSample, monitoringData = await executeSDKTask(
            params=self.prepareSDKTaskParams(config),
            inputData=inputData,
            useAutoRotation=config.useAutoRotation,
            sdkTargets=self.sdkTargets,
        )
        result = {"images": [], "events": [], "filtered_detections": {"face_detections": []}}
        for source in processedSources:
            if isinstance(source, RawDescriptorData):
                imageRes = {
                    "filename": source.filename,
                    "status": int(source.error == Error.Success),
                    "error": source.error.asDict(),
                }
            else:
                image = source.image
                if isinstance(image.error, MultipleFaces):
                    raise VLException(image.error, 400, isCriticalError=False)
                if self.detectPolicy.isFaceQualityChecksEnabled():
                    self.detectPolicy.faceQuality.processSource(source)
                    monitoringData.sdkUsages.faceQualityEstimator = monitoringData.sdkUsages.faceDetector
                imageRes = {
                    "filename": image.origin.filename,
                    "status": int(not image.error),
                    "error": (source.meta.error or image.error or Error.Success).asDict(),
                }
                if image.exif is not None:
                    imageRes["exif"] = image.exif
                result["filtered_detections"]["face_detections"].extend(
                    buildImageFilteredDetections(image=image, estimationTargets=self.sdkTargets)
                )
            result["images"].append(imageRes)
        events = APISDKHandlerAdaptor.createEvents(processedSources, aggregatedSample, eventMetadata, self.sdkTargets)
        # matching: all match policies run concurrently over the same events
        with monitorTime(monitoringData.request, "match_policy_time"):
            await asyncio.gather(
                *[matchByListPolicy.execute(events, luna3Client) for matchByListPolicy in self.matchPolicy]
            )
        # tags
        for policy in self.conditionalTagsPolicy:
            policy.execute(events)
        # storage: only image sources are passed (raw descriptors are excluded)
        imageSources = [source for source in processedSources if isinstance(source, ProcessedImageData)]
        monitoringData += await self.storagePolicy.execute(
            sources=imageSources,
            events=events,
            config=config,
            luna3Client=luna3Client,
            redisContext=redisContext,
        )
        self.publishEventsToPlugins(events=events, plugins=plugins)
        result["events"] = [event.asDict() for event in events]
        return result, monitoringData