""" SDK estimator handler"""
from typing import List
import msgpack
from sanic.response import HTTPResponse
from app.api_sdk_adaptors.base import LoopEstimationsAlwaysOn
from app.api_sdk_adaptors.orientation import handleImageOrientation
from app.api_sdk_adaptors.sdk_adaptor import APISDKAdaptor
from app.handlers.base_handler import BaseHandlerWithMultipart
from app.handlers.custom_query_getters import (
    multifacePolicyGetter,
    int0180Getter,
    maskStatesValidator,
    livenessStatesValidator,
)
from classes.image_meta import InputImageData
from classes.multipart_processing import SDKMultipartProcessor
from classes.schemas.sdk import SDKInputEstimationsModel
from configs.config import LIVENESS_V2_QUALITY_THRESHOLD
from crutches_on_wheels.errors.errors import ErrorInfo
from crutches_on_wheels.errors.exception import VLException
from crutches_on_wheels.monitoring.points import monitorTime
from crutches_on_wheels.web.query_getters import int01Getter, float01Getter, boolFrom01Getter
from sdk.sdk_loop.enums import LoopEstimations, MultifacePolicy
from sdk.sdk_loop.models.image import ImageType
from sdk.sdk_loop.task import HandlersTask
from sdk.sdk_loop.tasks.filters import FaceDetectionFilters, Filters
from sdk.sdk_loop.tasks.task import TaskParams, TaskEstimationParams, LivenessV1Params
class SDKHandler(BaseHandlerWithMultipart):
    """
    SDK estimator handler.

    Resource: "/{api_version}/sdk"
    """

    async def getDataFromMultipart(self, imageType: ImageType = ImageType.IMAGE) -> List[InputImageData]:
        """Description see :func:`~BaseHandlerWithMultipart.getDataFromMultipart`."""
        dataFromRequest = await SDKMultipartProcessor().getData(self.request)
        return self._getDataFromMultipart(dataFromRequest.images, imageType)

    def _getSdkEstimationTargets(self) -> set:
        """
        Collect estimation targets enabled by the request query flags.

        Every flag is read with ``int01Getter`` and defaults to 0; a truthy flag adds its estimations to the result.

        Returns:
            set of requested estimations
        """
        # Flags are read top to bottom, so the table order is also the order of query parameter processing.
        targetsByFlag = (
            ("detect_face", (LoopEstimations.faceDetection,)),
            # landmarks5 is the only flag that also requests face detection explicitly
            ("estimate_landmarks5", (LoopEstimations.faceDetection, LoopEstimationsAlwaysOn.faceLandmarks5)),
            ("estimate_landmarks68", (LoopEstimations.faceLandmarks68,)),
            ("estimate_liveness", (LoopEstimations.livenessV1,)),
            ("estimate_head_pose", (LoopEstimations.headPose,)),
            ("estimate_gaze", (LoopEstimations.gaze,)),
            ("estimate_eyes_attributes", (LoopEstimations.eyes,)),
            ("estimate_mouth_attributes", (LoopEstimations.mouthAttributes,)),
            ("estimate_emotions", (LoopEstimations.emotions,)),
            ("estimate_mask", (LoopEstimations.mask,)),
            ("estimate_glasses", (LoopEstimations.glasses,)),
            ("estimate_face_warp", (LoopEstimations.faceWarp,)),
            ("estimate_quality", (LoopEstimations.faceWarpQuality,)),
            ("estimate_basic_attributes", (LoopEstimations.basicAttributes,)),
            ("estimate_face_descriptor", (LoopEstimations.faceDescriptor,)),
            ("detect_body", (LoopEstimations.bodyDetection,)),
            ("estimate_body_warp", (LoopEstimations.bodyWarp,)),
            ("estimate_body_descriptor", (LoopEstimations.bodyDescriptor,)),
        )
        targets = set()
        for flag, estimations in targetsByFlag:
            if self.getQueryParam(flag, int01Getter, default=0):
                targets.update(estimations)
        return targets

    def _getSdkFaceDetectionFilters(self) -> FaceDetectionFilters:
        """
        Build face detection filters from the request query parameters.

        Returns:
            filters built from "pitch_threshold", "roll_threshold", "yaw_threshold", "score_threshold",
            "mask_states" and "liveness_states"
        """
        # Keyword arguments are evaluated left to right, which fixes the order the query parameters are read in.
        return FaceDetectionFilters(
            pitchThreshold=self.getQueryParam("pitch_threshold", int0180Getter),
            rollThreshold=self.getQueryParam("roll_threshold", int0180Getter),
            yawThreshold=self.getQueryParam("yaw_threshold", int0180Getter),
            gcThreshold=self.getQueryParam("score_threshold", float01Getter),
            maskStates=self.getQueryParam("mask_states", maskStatesValidator),
            livenessStates=self.getQueryParam("liveness_states", livenessStatesValidator),
        )

    async def post(self) -> HTTPResponse:
        """
        SDK estimations handler. See `spec sdk`_.

        .. _`spec sdk`:
            _static/api.html#operation/sdk

        Returns:
            Response with estimations
        Raises:
            VLException: with status code 400 if the executed task result contains an error
        """
        targets = self._getSdkEstimationTargets()
        faceFilters = self._getSdkFaceDetectionFilters()

        params = TaskParams(
            targets=targets,
            filters=Filters(faceDetection=faceFilters),
            estimatorsParams=TaskEstimationParams(
                # NOTE(review): liveness V1 params take the threshold named LIVENESS_V2_QUALITY_THRESHOLD - confirm
                livenessv1=LivenessV1Params(qualityThreshold=LIVENESS_V2_QUALITY_THRESHOLD),
                faceDescriptorVersion=self.config.defaultFaceDescriptorVersion,
                bodyDescriptorVersion=self.config.defaultHumanDescriptorVersion,
            ),
            multifacePolicy=self.getQueryParam(
                "multiface_policy", multifacePolicyGetter, default=MultifacePolicy.allowed
            ),
            useExifInfo=self.getQueryParam("use_exif_info", boolFrom01Getter, default=True),
            autoRotation=self.config.useAutoRotation,
            aggregate=self.getQueryParam("aggregate_attributes", int01Getter, default=0),
        )

        imageType = self.getQueryParam("image_type", lambda x: ImageType(int(x)), default=ImageType.IMAGE)
        with monitorTime(self.request.dataForMonitoring, "load_images_for_processing_time"):
            inputData = await self.getInputEstimationData(
                self.request, imageType=imageType, validationModel=SDKInputEstimationsModel,
            )

        # The licensing check covers both an explicit liveness estimation and filtration by liveness.
        self.checkLivenessEstimationLicensing(LoopEstimations.livenessV1 in targets or faceFilters.livenessFilter)

        task = HandlersTask(data=[metaImage.image for metaImage in inputData], params=params)
        await task.execute()
        if task.result.error:
            raise VLException(ErrorInfo.fromDict(task.result.error.asDict()), 400, isCriticalError=False)
        if self.config.useAutoRotation:
            handleImageOrientation(task.result.images)

        # msgpack responses are packed with use_bin_type=True, so bytes are not base64-encoded for them.
        useMsgpack = self.getResponseContentType() == "application/msgpack"
        sdkAdaptor = APISDKAdaptor(
            estimationTargets=targets, aggregationEnabled=params.aggregate, encodeBytesToBase64=not useMsgpack,
        )
        result, monitoringData = await sdkAdaptor.buildResult(
            task.result, meta=[metaImage.meta for metaImage in inputData]
        )
        self.countLivenessEstimationsPerformed(monitoringData.sdkUsages.livenessEstimator)
        self.handleMonitoringData(monitoringData)

        if useMsgpack:
            body = msgpack.packb(result, use_bin_type=True)
            return self.success(200, body=body, contentType="application/msgpack")
        return self.success(200, outputJson=result)