# -*- coding: utf-8 -*-
""" Base handler
Module implements the base class for all handlers.
"""
import base64
import binascii
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import aiohttp
from luna3.client import Client
from lunavl.sdk.image_utils.geometry import Rect
from pydantic import BaseModel
from werkzeug.http import parse_accept_header
from yarl import URL
from app.app import BaseHandlersRequestHandler, HandlersRequest
from app.global_vars.context_vars import requestIdCtx
from app.handlers.available_content_types import (
    MAP_BASE64_TYPE_TO_DATA_TYPE,
    isAllowableContentType,
    isAllowableRawContentType,
)
from classes.functions import loadDataFromJson
from classes.image_meta import ImageMeta, InputImageData
from classes.monitoring import HandlersMonitoringData
from classes.multipart_processing import ImageWithBB, ImageWithFaceBB
from classes.raw_descriptor_data import RawDescriptorData
from classes.schemas.base_schema import BaseSchema
from classes.schemas.policies import Policies
from classes.schemas.verifier import VerifierPoliciesModel as VerifierPolicies
from crutches_on_wheels.errors.errors import Error, ErrorInfo
from crutches_on_wheels.errors.exception import VLException
from crutches_on_wheels.utils.functions import convertDateTimeToCurrentFormatStr, currentDateTime, downloadImage
from img_utils.utils import convertToBytesIfNeed
from sdk.sdk_loop.models.image import ImageType, InputImage
class BaseHandler(BaseHandlersRequestHandler):
    """
    Base handler for other handlers.

    Attributes:
        luna3Client (luna3.client.Client): luna3 client
        dbContext (DBContext): db context
        redisContext (RedisContext): redis context
        accountId (Optional[str]): account id forwarded to image downloads and samples store requests;
            initialized to None
    """

    def __init__(self, request: HandlersRequest):
        super().__init__(request)
        # expose the request id through the `requestIdCtx` context variable
        requestIdCtx.set(self.requestId)
        self.luna3Client: Client = request.luna3Client
        self.dbContext = request.dbContext
        self.redisContext = request.redisContext
        self.accountId: Optional[str] = None

    def checkLivenessEstimationLicensing(self, estimate: int):
        """
        Check liveness estimation licensing.

        Args:
            estimate: liveness estimation status; a falsy value skips all checks
        Raises:
            VLException(Error.LicenseProblem, 403, False): if license information is unavailable, the license
                is expired, the liveness (v.2) feature is disabled, the liveness balance is exceeded or feature
                execution synchronization failed
        """
        if not estimate:
            return
        # read the license state once so every check below sees the same snapshot
        licenseState = self.app.ctx.licenseChecker.licenseState
        if not licenseState:
            raise VLException(Error.LicenseProblem.format("Cannot get license information."), 403, False)
        if not licenseState.expirationTime.isAvailable:
            raise VLException(Error.LicenseProblem.format("License expired"), 403, False)
        if licenseState.liveness.value is None:
            raise VLException(Error.LicenseProblem.format("Liveness feature disabled"), 403, False)
        if licenseState.liveness.value != 2:
            raise VLException(Error.LicenseProblem.format("Liveness v.2 feature disabled"), 403, False)
        if licenseState.livenessBalance is None:
            # licensing by expiration, not by executions
            return
        if not licenseState.livenessBalance.isAvailable:
            raise VLException(Error.LicenseProblem.format("Liveness balance is exceeded"), 403, False)
        if not self.app.ctx.licenseRecorder.isLivenessSynchronized():
            raise VLException(Error.LicenseProblem.format("Feature execution synchronization failed"), 403, False)

    def checkPolicyLicensing(self, policies: Union[Policies, VerifierPolicies]):
        """
        Check handler policies licensing.

        Args:
            policies: handler policies
        Raises:
            VLException(Error.LicenseProblem, 403, False): if liveness estimation is requested but not licensed
        """
        self.checkLivenessEstimationLicensing(policies.detectPolicy.estimateLiveness.estimate)

    async def _downloadImage(self, url: Union[str, URL]) -> Tuple[bytes, str]:
        """
        Download an image by an external url using the configured external-image timeouts.

        Args:
            url: url
        Returns:
            image body and its content type
        Raises:
            VLException(Error.BadContentTypeDownloadedImage.format(url), 400, isCriticalError=False):
                if a downloaded image content type is not allowable
            errors raised by `downloadImage` are propagated unchanged
        """
        timeoutSettings = self.config.loadExternalImageTimeout
        clientTimeout = aiohttp.ClientTimeout(
            total=timeoutSettings.totalTimeout,
            connect=timeoutSettings.connectTimeout,
            sock_connect=timeoutSettings.sockConnectTimeout,
            sock_read=timeoutSettings.sockReadTimeout,
        )
        imageBody, contentType = await downloadImage(
            url=url, logger=self.logger, timeout=clientTimeout, accountId=self.accountId
        )
        if not isAllowableRawContentType(contentType):
            raise VLException(Error.BadContentTypeDownloadedImage.format(url), 400, isCriticalError=False)
        return imageBody, contentType

    def _getRawDataContainer(
        self,
        body: bytes,
        contentType: str,
        imageType: Union[ImageType, None],
        fileName: Optional[str] = None,
        faceBoundingBoxList: Optional[List[dict]] = None,
        bodyBoundingBoxList: Optional[List[dict]] = None,
        sampleId: Optional[str] = None,
        url: Optional[str] = None,
        detectTime: Optional[str] = None,
        detectTs: Optional[float] = None,
        imageOrigin: Optional[str] = None,
        error: Optional[ErrorInfo] = None,
    ) -> Union[InputImageData, RawDescriptorData]:
        """
        Get raw data container: detectable image object or raw descriptor data.

        Args:
            body: binary data
            contentType: expected data content type from request
            imageType: image type (None means ImageType.IMAGE)
            fileName: filename (defaults to "raw image")
            faceBoundingBoxList: list with face detection rectangles (at most one is allowed)
            bodyBoundingBoxList: list with body detection rectangles (at most one is allowed)
            sampleId: sample id for warped image
            url: image source
            detectTime: detection time in ISO format
            detectTs: user-defined timestamp relative to something, such as the start of a video
            imageOrigin: image origin
            error: image load error; when set, content validation and conversion are skipped
        Returns:
            prepared raw data container
        Raises:
            VLException(Error.OnlyOneDetectionRectAvailable, 403, False) if there are more than 1 bounding box
            VLException(Error.BadContentType, 400, False) if image content type is not allowable
            VLException(Error.BoundingBoxNotAvailableForWarp, 400, False) if try to use bounding box for warp
        """
        if not error:
            if not isAllowableContentType(contentType, allowRawDescriptors=True):
                raise VLException(Error.BadContentType, 400, isCriticalError=False)
            data, contentType = convertToBytesIfNeed(body, contentType)
            if contentType in ("application/x-sdk-descriptor", "application/x-vl-xpk"):
                allowedVersions = [self.config.defaultFaceDescriptorVersion, self.config.defaultHumanDescriptorVersion]
                return RawDescriptorData(data, mimetype=contentType, filename=fileName, allowedVersions=allowedVersions)
            if faceBoundingBoxList and len(faceBoundingBoxList) > 1:
                raise VLException(Error.OnlyOneDetectionRectAvailable, 403, isCriticalError=False)
            if bodyBoundingBoxList and len(bodyBoundingBoxList) > 1:
                raise VLException(Error.OnlyOneDetectionRectAvailable, 403, isCriticalError=False)
            if (faceBoundingBoxList or bodyBoundingBoxList) and imageType in (ImageType.FACE_WARP, ImageType.BODY_WARP):
                raise VLException(Error.BoundingBoxNotAvailableForWarp, 400, False)
        else:
            # the body is passed through as-is: it is neither validated nor converted
            data = body
        # an empty list is treated like a missing one (as in the checks above) instead of failing on [0];
        # only the first rectangle is used
        faceBoxes = [Rect(**faceBoundingBoxList[0])] if faceBoundingBoxList else None
        bodyBoxes = [Rect(**bodyBoundingBoxList[0])] if bodyBoundingBoxList else None
        return InputImageData(
            image=InputImage(
                filename=fileName or "raw image",
                body=data,
                imageType=imageType or ImageType.IMAGE,
                faceBoxes=faceBoxes,
                bodyBoxes=bodyBoxes,
            ),
            meta=ImageMeta(
                sampleId=sampleId,
                url=url,
                detectTime=detectTime,
                imageOrigin=imageOrigin,
                error=error,
                detectTs=detectTs,
            ),
        )

    async def _getImagesFromJson(
        self, inputJson: dict, imageType: ImageType, defaultDetectTime: str, allowRawDescriptors: bool
    ) -> List[Union[InputImageData, RawDescriptorData]]:
        """
        Get images from request json (a single base64 encoded image or descriptor).

        Args:
            inputJson: json from request
            imageType: image type
            defaultDetectTime: image detection time in ISO format
            allowRawDescriptors: whether raw descriptor mimetypes allowed or not
        Returns:
            list with one prepared image or raw descriptor data container
        Raises:
            VLException(Error.BadContentType, 400, False) if the mimetype is not allowable
            VLException(Error.BadInputJson, 400, False) if failed decode descriptor
        """
        try:
            contentType = inputJson["mimetype"]
            if not isAllowableContentType(contentType, allowRawDescriptors=allowRawDescriptors):
                raise VLException(Error.BadContentType, 400, isCriticalError=False)
            image = base64.b64decode(inputJson["image"])
            # translate the request mimetype via MAP_BASE64_TYPE_TO_DATA_TYPE; unknown types are kept as-is
            contentType = MAP_BASE64_TYPE_TO_DATA_TYPE.get(contentType, contentType)
            return [
                self._getRawDataContainer(
                    body=image,
                    contentType=contentType,
                    imageType=imageType,
                    faceBoundingBoxList=inputJson.get("face_bounding_boxes"),
                    bodyBoundingBoxList=inputJson.get("body_bounding_boxes"),
                    detectTime=self.convertDetectionTimeToCurrentFormat(
                        inputJson.get("detect_time"), defaultDetectTime
                    ),
                    detectTs=inputJson.get("detect_ts"),
                    imageOrigin=inputJson.get("image_origin"),
                )
            ]
        except binascii.Error as exc:
            raise VLException(Error.BadInputJson.format("image", "Failed to decode descriptor"), 400, False) from exc

    async def _getImagesFromUrls(
        self, inputJson: dict, imageType: ImageType, defaultDetectTime: str
    ) -> List[Union[InputImageData, RawDescriptorData]]:
        """
        Get images from request's urls (list of urls in json with optional detection rectangles).

        A VLException raised while downloading an image does not abort the request: the image is
        returned with an empty body and the error stored in its meta.

        Args:
            inputJson: json from request
            imageType: image type
            defaultDetectTime: image detection time in ISO format
        Returns:
            list of prepared containers, one per url, in request order
        """
        resultImages = []
        for row in inputJson["urls"]:
            url = row["url"]
            try:
                image, contentType = await self._downloadImage(url)
                loadError = None
            except VLException as e:
                image, contentType = b"", None
                loadError = e.error
            resultImages.append(
                self._getRawDataContainer(
                    body=image,
                    contentType=contentType,
                    imageType=imageType,
                    faceBoundingBoxList=row.get("face_bounding_boxes"),
                    bodyBoundingBoxList=row.get("body_bounding_boxes"),
                    fileName=url,
                    url=url,
                    detectTime=self.convertDetectionTimeToCurrentFormat(row.get("detect_time"), defaultDetectTime),
                    detectTs=row.get("detect_ts"),
                    imageOrigin=row.get("image_origin"),
                    error=loadError,
                )
            )
        return resultImages

    async def _getImagesFromSamples(
        self, inputJson: dict, imageType: ImageType, defaultDetectTime: str
    ) -> List[Union[InputImageData, RawDescriptorData]]:
        """
        Get images from request's samples
        (list of sample ids, or sample objects with optional meta, fetched from the samples store).

        Args:
            inputJson: json from request
            imageType: image type; only face and body warps are supported
            defaultDetectTime: image detection time in ISO format
        Returns:
            list of prepared containers, one per sample, in request order
        Raises:
            VLException(Error.BadWarpImage, 400, False): if image type is not a face or body warp
        """
        # compare against the enum class, not the argument: `imageType` may be None
        if imageType == ImageType.FACE_WARP:
            storeApiClient = self.luna3Client.lunaFaceSamplesStore
            bucketName = self.config.faceSamplesStorage.bucket
        elif imageType == ImageType.BODY_WARP:
            storeApiClient = self.luna3Client.lunaBodySamplesStore
            bucketName = self.config.bodySamplesStorage.bucket
        else:
            raise VLException(
                error=Error.BadWarpImage.format(
                    "Not supported image type for samples. Valid image type one of: face or body warp"
                ),
                statusCode=400,
                isCriticalError=False,
            )
        inputSamples = inputJson["samples"]
        responses, samples = [], []
        # samples are given either as objects with optional meta or as bare sample ids;
        # an empty list simply yields no images
        if inputSamples and isinstance(inputSamples[0], dict):
            for sample in inputSamples:
                sampleId = sample["sample_id"]
                responses.append(
                    await storeApiClient.getImage(
                        imageId=sampleId, bucketName=bucketName, accountId=self.accountId, raiseError=True
                    )
                )
                samples.append(
                    {
                        "sample_id": sampleId,
                        "detect_time": self.convertDetectionTimeToCurrentFormat(
                            sample.get("detect_time"), defaultDetectTime
                        ),
                        "detect_ts": sample.get("detect_ts"),
                        "image_origin": sample.get("image_origin"),
                    }
                )
        else:
            for sampleId in inputSamples:
                responses.append(
                    await storeApiClient.getImage(
                        imageId=sampleId, bucketName=bucketName, accountId=self.accountId, raiseError=True
                    )
                )
                samples.append({"sample_id": sampleId, "detect_time": defaultDetectTime, "image_origin": None})
        return [
            self._getRawDataContainer(
                body=response.body,
                contentType=response.headers["Content-Type"],
                imageType=imageType,
                sampleId=sample["sample_id"],
                fileName=sample["sample_id"],
                detectTime=sample["detect_time"],
                detectTs=sample.get("detect_ts"),
                imageOrigin=sample["image_origin"],
            )
            for sample, response in zip(samples, responses)
        ]

    @staticmethod
    def loadDataFromJson(data: dict, model: Type[BaseModel]) -> Any:
        """
        Load data from json with pydantic.

        Args:
            data: input data
            model: pydantic model
        Returns:
            initialized object
        """
        return loadDataFromJson(data, model)

    def handleMonitoringData(self, monitoringData: HandlersMonitoringData):
        """
        Handle monitoring data: attach request data to the request and flush sdk usages, if any.
        Does nothing when sending monitoring data is disabled in the config.

        Args:
            monitoringData: monitoring data
        """
        if not self.config.monitoring.sendData:
            return
        self.request.dataForMonitoring += monitoringData.request
        if monitoringData.sdkUsages:
            self.app.ctx.monitoring.flushPoints([monitoringData.sdkUsages])

    def getResponseContentType(self) -> str:
        """
        Get response content type: the best match of the request "Accept" header among
        msgpack and json; json when the header is missing or nothing matches.

        Returns:
            response content type
        """
        acceptHeader = self.request.headers.get("Accept", "application/json")
        return parse_accept_header(acceptHeader).best_match(
            ("application/msgpack", "application/json"), default="application/json"
        )
class BaseHandlerWithMultipart(BaseHandler):
    """
    Base handler class for resources that also accept multipart requests.
    """

    @abstractmethod
    async def getDataFromMultipart(
        self, imageType: ImageType = ImageType.IMAGE
    ) -> Tuple[List[InputImageData], Optional[dict]]:
        """
        Get data from multipart request.

        Args:
            imageType: image type
        Returns:
            list of images or warps and, optionally, a dict with policies (for multipart requests with policies)
        """

    def _getDataFromMultipart(
        self,
        multipartData: Dict[str, Union[ImageWithBB, ImageWithFaceBB]],
        imageType: ImageType,
        allowRawDescriptors: bool = False,
    ) -> List[Union[InputImageData, RawDescriptorData]]:
        """
        Get data from multipart request (list of images and optional detection rectangles, or raw descriptors).

        Args:
            multipartData: validated images from multipart
            imageType: image type
            allowRawDescriptors: whether raw descriptor mimetypes allowed or not
        Returns:
            list of prepared containers, one per multipart image
        Raises:
            VLException(Error.BadContentTypeInMultipartImage, 400, isCriticalError=False): if content type of a part of
                multipart request is wrong
        """
        resultImages = []
        # one default detection time shared by every part of the request
        defaultDetectTime = currentDateTime(self.config.storageTime)
        for image in multipartData.values():
            if not isAllowableContentType(image.contentType, allowRawDescriptors=allowRawDescriptors):
                raise VLException(Error.BadContentTypeInMultipartImage, 400, isCriticalError=False)
            # empty box lists are passed on as None; only ImageWithBB parts carry body boxes
            faceBoundingBoxList = image.faceBoundingBoxes or None
            bodyBoundingBoxList = (image.bodyBoundingBoxes or None) if isinstance(image, ImageWithBB) else None
            resultImages.append(
                self._getRawDataContainer(
                    body=image.body,
                    contentType=image.contentType,
                    imageType=imageType,
                    fileName=image.filename,
                    faceBoundingBoxList=faceBoundingBoxList,
                    bodyBoundingBoxList=bodyBoundingBoxList,
                    detectTime=self.convertDetectionTimeToCurrentFormat(image.detectTime, defaultDetectTime),
                    detectTs=image.detectTs,
                    imageOrigin=image.imageOrigin,
                )
            )
        return resultImages

    async def getDataFromRequest(
        self,
        request: HandlersRequest,
        validationModel: Type[BaseSchema],
        imageType: Union[ImageType, None],
        allowRawDescriptors: bool = False,
    ) -> List[Union[InputImageData, RawDescriptorData]]:
        """
        Get images (or raw descriptors) from a raw-data or json request body.

        A body whose content type is itself allowable is treated as a single image/descriptor.
        A json body is validated with `validationModel` and must contain one of
        "image", "urls" or "samples".

        Args:
            request: request
            validationModel: validation model for json requests
            imageType: image type
            allowRawDescriptors: whether raw descriptor mimetypes allowed or not
        Returns:
            list of prepared containers
        Raises:
            VLException(Error.BadContentType, 400, isCriticalError=False): if content type of request is wrong
            RuntimeError: if the validated json contains none of "image", "urls" or "samples"
            errors from json validation and from the image loading helpers are propagated
        """
        contentType = request.content_type
        defaultDetectTime = currentDateTime(self.config.storageTime)
        if isAllowableContentType(contentType, allowRawDescriptors=allowRawDescriptors):
            # the whole request body is a single image or raw descriptor
            return [
                self._getRawDataContainer(
                    body=request.body,
                    contentType=contentType,
                    imageType=imageType,
                    detectTime=defaultDetectTime,
                )
            ]
        if contentType != "application/json":
            raise VLException(Error.BadContentType, 400, isCriticalError=False)
        inputJson = request.json
        # load the json with the validation model before reading any field; the loaded object is not used
        self.loadDataFromJson(inputJson, validationModel)
        if "image" in inputJson:
            return await self._getImagesFromJson(
                inputJson=inputJson,
                imageType=imageType,
                defaultDetectTime=defaultDetectTime,
                allowRawDescriptors=allowRawDescriptors,
            )
        if "urls" in inputJson:
            return await self._getImagesFromUrls(
                inputJson=inputJson, imageType=imageType, defaultDetectTime=defaultDetectTime
            )
        if "samples" in inputJson:
            return await self._getImagesFromSamples(
                inputJson=inputJson, imageType=imageType, defaultDetectTime=defaultDetectTime
            )
        raise RuntimeError(f"bad input json {inputJson}")