# Source code for luna_handlers.app.handlers.base_handler

# -*- coding: utf-8 -*-
""" Base handler

This module implements the base class for all handlers.
"""
import base64
import binascii
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import aiohttp
from luna3.client import Client
from lunavl.sdk.image_utils.geometry import Rect
from pydantic import BaseModel
from werkzeug.http import parse_accept_header
from yarl import URL

from app.app import BaseHandlersRequestHandler, HandlersRequest
from app.global_vars.context_vars import requestIdCtx
from app.handlers.available_content_types import (
    MAP_BASE64_TYPE_TO_DATA_TYPE,
    isAllowableContentType,
    isAllowableRawContentType,
)
from classes.functions import loadDataFromJson
from classes.image_meta import ImageMeta, InputImageData
from classes.monitoring import HandlersMonitoringData
from classes.multipart_processing import ImageWithBB, ImageWithFaceBB
from classes.raw_descriptor_data import RawDescriptorData
from classes.schemas.base_schema import BaseSchema
from crutches_on_wheels.errors.errors import Error, ErrorInfo
from crutches_on_wheels.errors.exception import VLException
from crutches_on_wheels.utils.functions import convertDateTimeToCurrentFormatStr, currentDateTime, downloadImage
from img_utils.utils import convertToBytesIfNeed
from sdk.sdk_loop.models.image import ImageType, InputImage


class BaseHandler(BaseHandlersRequestHandler):
    """
    Base handler for other handlers.

    Attributes:
        luna3Client (luna3.client.Client): luna3 client
        dbContext (DBContext): db context
        redisContext (RedisContext): redis context
        accountId (Optional[str]): account id used for image downloads and sample-store requests;
            initialized to None here (presumably filled in by concrete handlers)
    """

    def __init__(self, request: HandlersRequest):
        super().__init__(request)
        # publish the request id through the requestIdCtx context variable
        requestIdCtx.set(self.requestId)
        self.luna3Client: Client = request.luna3Client
        self.dbContext = request.dbContext
        self.redisContext = request.redisContext
        self.accountId: Optional[str] = None

    def countLivenessEstimationsPerformed(self, count: int):
        """
        Count liveness estimations performed.

        Args:
            count: estimation count
        """
        self.app.ctx.licenseRecorder.countExecutionsPerformed(liveness=count)

    def checkLivenessEstimationLicensing(self):
        """
        Check liveness estimation licensing.

        Raises:
            VLException(Error.LicenseProblem, 403): if license information is unavailable, the license is expired,
                the liveness feature is disabled or is not version 2, the liveness balance is exceeded,
                or liveness execution synchronization failed
        """
        licenseState = self.app.ctx.licenseChecker.licenseState
        if not licenseState:
            raise VLException(Error.LicenseProblem.format("Cannot get license information."), 403, False)
        if not licenseState.expirationTime.isAvailable:
            raise VLException(Error.LicenseProblem.format("License expired"), 403, False)
        if licenseState.liveness.value is None:
            raise VLException(Error.LicenseProblem.format("Liveness feature disabled"), 403, False)
        if licenseState.liveness.value != 2:
            raise VLException(Error.LicenseProblem.format("Liveness v.2 feature disabled"), 403, False)
        if licenseState.livenessBalance is None:
            # licensing by expiration, not by executions
            return
        if not licenseState.livenessBalance.isAvailable:
            raise VLException(Error.LicenseProblem.format("Liveness balance is exceeded"), 403, False)
        if not self.app.ctx.licenseRecorder.isLivenessSynchronized():
            raise VLException(Error.LicenseProblem.format("Feature execution synchronization failed"), 403, False)

    def assertLicence(self, iso: bool = False, liveness: bool = False, bodyAttributes: bool = False):
        """
        Assert that the requested estimations are allowed by the license.

        Args:
            iso: check or not iso feature
            liveness: check or not liveness feature
            bodyAttributes: check or not basic attributes feature

        Raises:
            VLException(Error.LicenseProblem, 403): if the license is expired or a requested feature is disabled
        """
        licenseChecker = self.app.ctx.licenseChecker
        if not licenseChecker.checkExpirationTime():
            raise VLException(Error.LicenseProblem.format("License expired"), 403, isCriticalError=False)
        # licenseState is read only after the expiration check, as in the original flow
        if iso and (not licenseChecker.licenseState.iso or not licenseChecker.licenseState.iso.isAvailable):
            raise VLException(
                Error.LicenseProblem.format("ISO license feature is disabled."),
                statusCode=403,
                isCriticalError=False,
            )
        if bodyAttributes and (
            not licenseChecker.licenseState.bodyAttributes or not licenseChecker.licenseState.bodyAttributes.isAvailable
        ):
            raise VLException(
                Error.LicenseProblem.format("Body attributes license feature is disabled."),
                statusCode=403,
                isCriticalError=False,
            )
        if liveness:
            self.checkLivenessEstimationLicensing()

    async def _downloadImage(self, url: Union[str, URL]) -> Tuple[bytes, str]:
        """
        Download image by external url.

        Args:
            url: url

        Returns:
            image and content type

        Raises:
            VLException(Error.BadContentTypeDownloadedImage.format(url), 400, isCriticalError=False):
                if a downloaded image content type is not allowable
        """
        timeouts = self.config.loadExternalImageTimeout
        clientTimeout = aiohttp.ClientTimeout(
            total=timeouts.totalTimeout,
            connect=timeouts.connectTimeout,
            sock_connect=timeouts.sockConnectTimeout,
            sock_read=timeouts.sockReadTimeout,
        )
        imageBody, contentType = await downloadImage(
            url=url, logger=self.logger, timeout=clientTimeout, accountId=self.accountId
        )
        if not isAllowableRawContentType(contentType):
            raise VLException(Error.BadContentTypeDownloadedImage.format(url), 400, isCriticalError=False)
        return imageBody, contentType

    def convertDetectionTimeToCurrentFormat(self, detectTime: Union[str, None], defaultDetectTime: str) -> str:
        """
        Convert detection time from UTC or LOCAL to string, considering current STORAGE_TIME.

        Args:
            detectTime: detection time in ISO format (None means "use the default")
            defaultDetectTime: default time in UTC or LOCAL time zones, returned as is when detectTime is None

        Returns:
            detection time in current format
        """
        if detectTime is None:
            return defaultDetectTime
        return convertDateTimeToCurrentFormatStr(detectTime, self.config.storageTime)

    def _getRawDataContainer(
        self,
        body: bytes,
        contentType: Optional[str],
        imageType: Union[ImageType, None],
        fileName: Optional[str] = None,
        faceBoundingBoxList: Optional[List[dict]] = None,
        bodyBoundingBoxList: Optional[List[dict]] = None,
        sampleId: Optional[str] = None,
        url: Optional[str] = None,
        detectTime: Optional[str] = None,
        detectTs: Optional[float] = None,
        imageOrigin: Optional[str] = None,
        error: Optional[ErrorInfo] = None,
    ) -> Union[InputImageData, RawDescriptorData]:
        """
        Get raw data container: detectable image object or raw descriptor data.

        Args:
            body: binary data
            contentType: expected data content type from request (None when the image failed to load)
            imageType: image type (None falls back to ImageType.IMAGE)
            fileName: filename
            faceBoundingBoxList: list with face detection rectangles (at most one; empty means "no rectangle")
            bodyBoundingBoxList: list with body detection rectangles (at most one; empty means "no rectangle")
            sampleId: sample id for warped image
            url: image source
            detectTime: detection time in ISO format
            detectTs: user-defined timestamp relative to something, such as the start of a video
            imageOrigin: image origin
            error: image load error; when set, content validation is skipped and the body is used as is

        Returns:
            prepared raw data container

        Raises:
            VLException(Error.OnlyOneDetectionRectAvailable, 403, False) if there are more than 1 bounding box
            VLException(Error.BadContentType, 400, False) if image content type is not allowable
            VLException(Error.BoundingBoxNotAvailableForWarp, 400, False) if try to use bounding box for warp
        """
        if not error:
            if not isAllowableContentType(contentType, allowRawDescriptors=True):
                raise VLException(Error.BadContentType, 400, isCriticalError=False)
            data, contentType = convertToBytesIfNeed(body, contentType)
            if contentType in ("application/x-sdk-descriptor", "application/x-vl-xpk"):
                allowedVersions = [self.config.defaultFaceDescriptorVersion, self.config.defaultHumanDescriptorVersion]
                return RawDescriptorData(data, mimetype=contentType, filename=fileName, allowedVersions=allowedVersions)
            if faceBoundingBoxList and len(faceBoundingBoxList) > 1:
                raise VLException(Error.OnlyOneDetectionRectAvailable, 403, isCriticalError=False)
            if bodyBoundingBoxList and len(bodyBoundingBoxList) > 1:
                raise VLException(Error.OnlyOneDetectionRectAvailable, 403, isCriticalError=False)
            if (faceBoundingBoxList or bodyBoundingBoxList) and imageType in (ImageType.FACE_WARP, ImageType.BODY_WARP):
                raise VLException(Error.BoundingBoxNotAvailableForWarp, 400, False)
        else:
            # loading failed: keep the raw body, the error is carried in the image meta
            data = body

        # Truthiness (not `is not None`) so an empty list means "no rectangle", matching the checks above;
        # indexing [0] on an empty list would otherwise raise IndexError.
        faceBoxes = [Rect(**faceBoundingBoxList[0])] if faceBoundingBoxList else None
        bodyBoxes = [Rect(**bodyBoundingBoxList[0])] if bodyBoundingBoxList else None
        return InputImageData(
            image=InputImage(
                filename=fileName or "raw image",
                body=data,
                imageType=imageType or ImageType.IMAGE,
                faceBoxes=faceBoxes,
                bodyBoxes=bodyBoxes,
            ),
            meta=ImageMeta(
                sampleId=sampleId,
                url=url,
                detectTime=detectTime,
                imageOrigin=imageOrigin,
                error=error,
                detectTs=detectTs,
            ),
        )

    async def _getImagesFromJson(
        self, inputJson: dict, imageType: ImageType, defaultDetectTime: str, allowRawDescriptors: bool
    ) -> List[InputImageData]:
        """
        Get images from request json (a single base64-encoded "image" with its "mimetype").

        Args:
            inputJson: json from request
            imageType: image type
            defaultDetectTime: image detection time in ISO format
            allowRawDescriptors: whether raw descriptor mimetypes allowed or not

        Returns:
            list of prepared SDKDetectableImage or FaceWarp or HumanWarp

        Raises:
            VLException(Error.BadContentType, 400, False) if the mimetype is not allowable
            VLException(Error.BadInputJson, 400, False) if failed to decode base64 data
        """
        try:
            contentType = inputJson["mimetype"]
            if not isAllowableContentType(contentType, allowRawDescriptors=allowRawDescriptors):
                raise VLException(Error.BadContentType, 400, isCriticalError=False)
            image = base64.b64decode(inputJson["image"])
            contentType = MAP_BASE64_TYPE_TO_DATA_TYPE.get(contentType, contentType)
            # kept inside the try: the container conversion may decode data as well
            return [
                self._getRawDataContainer(
                    body=image,
                    contentType=contentType,
                    imageType=imageType,
                    faceBoundingBoxList=inputJson.get("face_bounding_boxes"),
                    bodyBoundingBoxList=inputJson.get("body_bounding_boxes"),
                    detectTime=self.convertDetectionTimeToCurrentFormat(
                        inputJson.get("detect_time"), defaultDetectTime
                    ),
                    detectTs=inputJson.get("detect_ts"),
                    imageOrigin=inputJson.get("image_origin"),
                )
            ]
        except binascii.Error as exc:
            raise VLException(Error.BadInputJson.format("image", "Failed to decode descriptor"), 400, False) from exc

    async def _getImagesFromUrls(
        self, inputJson: dict, imageType: ImageType, defaultDetectTime: str
    ) -> List[InputImageData]:
        """
        Get images from request's urls (list of urls in json with optional detection rectangles).

        Images are downloaded one after another. A download that fails with VLException does not abort the
        request: its error is stored in the image meta and an empty body is used.

        Args:
            inputJson: json from request
            imageType: image type
            defaultDetectTime: image detection time in ISO format

        Returns:
            list of prepared SDKDetectableImage or FaceWarp or HumanWarp
        """
        resultImages = []
        for row in inputJson["urls"]:
            url = row["url"]
            try:
                image, contentType = await self._downloadImage(url)
                loadError = None
            except VLException as e:
                image, contentType = b"", None
                loadError = e.error
            resultImages.append(
                self._getRawDataContainer(
                    body=image,
                    contentType=contentType,
                    imageType=imageType,
                    faceBoundingBoxList=row.get("face_bounding_boxes"),
                    bodyBoundingBoxList=row.get("body_bounding_boxes"),
                    fileName=url,
                    url=url,
                    detectTime=self.convertDetectionTimeToCurrentFormat(row.get("detect_time"), defaultDetectTime),
                    detectTs=row.get("detect_ts"),
                    imageOrigin=row.get("image_origin"),
                    error=loadError,
                )
            )
        return resultImages

    async def _getImagesFromSamples(
        self, inputJson: dict, imageType: ImageType, defaultDetectTime: str
    ) -> List[InputImageData]:
        """
        Get images from request's samples (sample ids to fetch from the samples store).

        "samples" is either a list of dicts with "sample_id" and optional "detect_time", "detect_ts",
        "image_origin", or a list of plain sample ids; the form is decided by the first element.

        Args:
            inputJson: json from request
            imageType: image type; only ImageType.FACE_WARP and ImageType.BODY_WARP are supported
                (None or any other value is rejected)
            defaultDetectTime: image detection time in ISO format

        Returns:
            list of prepared SDKDetectableImage or FaceWarp or HumanWarp

        Raises:
            VLException(Error.BadWarpImage, 400, False) if imageType is not a face or body warp
        """
        # Compare against the enum class, not the argument: imageType may be None.
        if imageType == ImageType.FACE_WARP:
            storeApiClient = self.luna3Client.lunaFaceSamplesStore
            bucketName = self.config.faceSamplesStorage.bucket
        elif imageType == ImageType.BODY_WARP:
            storeApiClient = self.luna3Client.lunaBodySamplesStore
            bucketName = self.config.bodySamplesStorage.bucket
        else:
            raise VLException(
                error=Error.BadWarpImage.format(
                    "Not supported image type for samples. Valid image type one of: face or body warp"
                ),
                statusCode=400,
                isCriticalError=False,
            )

        responses, samples = [], []
        if isinstance(inputJson["samples"][0], dict):
            for sample in inputJson["samples"]:
                responses.append(
                    await storeApiClient.getImage(
                        imageId=sample["sample_id"], bucketName=bucketName, accountId=self.accountId, raiseError=True
                    )
                )
                detectTime = self.convertDetectionTimeToCurrentFormat(sample.get("detect_time"), defaultDetectTime)
                samples.append(
                    {
                        "sample_id": sample["sample_id"],
                        "detect_time": detectTime,
                        "detect_ts": sample.get("detect_ts"),
                        "image_origin": sample.get("image_origin"),
                    }
                )
        else:
            for sampleId in inputJson["samples"]:
                responses.append(
                    await storeApiClient.getImage(
                        imageId=sampleId, bucketName=bucketName, accountId=self.accountId, raiseError=True
                    )
                )
                samples.append({"sample_id": sampleId, "detect_time": defaultDetectTime, "image_origin": None})

        return [
            self._getRawDataContainer(
                body=response.body,
                contentType=response.headers["Content-Type"],
                imageType=imageType,
                sampleId=sample["sample_id"],
                fileName=sample["sample_id"],
                detectTime=sample["detect_time"],
                detectTs=sample.get("detect_ts"),
                imageOrigin=sample["image_origin"],
            )
            for sample, response in zip(samples, responses)
        ]

    @staticmethod
    def loadDataFromJson(data: dict, model: Type[BaseModel]) -> Any:
        """
        Load data from json with pydantic.

        Args:
            data: input data
            model: pydantic model

        Returns:
            initialized object
        """
        return loadDataFromJson(data, model)

    def handleMonitoringData(self, monitoringData: HandlersMonitoringData):
        """
        Handle monitoring data: attach request data to the request and flush SDK usage points.

        Does nothing when sending monitoring data is disabled in the config.

        Args:
            monitoringData: monitoring data
        """
        if not self.config.monitoring.sendData:
            return
        self.request.dataForMonitoring += monitoringData.request
        if monitoringData.sdkUsages:
            self.app.ctx.monitoring.flushPoints([monitoringData.sdkUsages])

    def getResponseContentType(self) -> str:
        """
        Get response content type from the Accept header.

        Returns:
            "application/msgpack" or "application/json", whichever matches the Accept header best
            (json when the header is missing or matches neither)
        """
        acceptHeader = self.request.headers.get("Accept", "application/json")
        return parse_accept_header(acceptHeader).best_match(
            ("application/msgpack", "application/json"), default="application/json"
        )
class BaseHandlerWithMultipart(BaseHandler):
    """
    Base class for handlers whose resources also accept multipart/form-data requests.
    """

    @abstractmethod
    async def getDataFromMultipart(
        self, imageType: ImageType = ImageType.IMAGE
    ) -> Tuple[Union[List[InputImageData]], Optional[Union[dict]]]:
        """
        Extract input data from a multipart request; concrete handlers implement this.

        Args:
            imageType: image type

        Returns:
            list of images or warps and, optionally, a dict with policies
            (for multipart requests that carry policies)
        """

    def _getDataFromMultipart(
        self,
        multipartData: Dict[str, Union[ImageWithBB, ImageWithFaceBB]],
        imageType: ImageType,
        allowRawDescriptors: bool = False,
    ) -> List[Union[InputImageData, RawDescriptorData]]:
        """
        Turn validated multipart parts into raw data containers (images with optional rectangles, or raw descriptors).

        Args:
            multipartData: validated images from multipart
            imageType: image type
            allowRawDescriptors: whether raw descriptor mimetypes allowed or not

        Returns:
            list of prepared SDKDetectableImage or FaceWarp

        Raises:
            VLException(Error.BadContentTypeInMultipartImage, 400, isCriticalError=False):
                if content type of a part of multipart request is wrong
        """
        fallbackDetectTime = currentDateTime(self.config.storageTime)
        containers = []
        for part in multipartData.values():
            if not isAllowableContentType(part.contentType, allowRawDescriptors=allowRawDescriptors):
                raise VLException(Error.BadContentTypeInMultipartImage, 400, isCriticalError=False)
            # empty/missing rectangle lists are normalized to None
            faceBoxes = part.faceBoundingBoxes or None
            # only ImageWithBB parts may carry body rectangles
            bodyBoxes = part.bodyBoundingBoxes if isinstance(part, ImageWithBB) and part.bodyBoundingBoxes else None
            partDetectTime = self.convertDetectionTimeToCurrentFormat(part.detectTime, fallbackDetectTime)
            containers.append(
                self._getRawDataContainer(
                    body=part.body,
                    contentType=part.contentType,
                    imageType=imageType,
                    fileName=part.filename,
                    faceBoundingBoxList=faceBoxes,
                    bodyBoundingBoxList=bodyBoxes,
                    detectTime=partDetectTime,
                    detectTs=part.detectTs,
                    imageOrigin=part.imageOrigin,
                )
            )
        return containers

    async def getDataFromRequest(
        self,
        request: HandlersRequest,
        validationModel: Type[BaseSchema],
        imageType: Union[ImageType, None],
        allowRawDescriptors: bool = False,
    ) -> List[Union[InputImageData, RawDescriptorData]]:
        """
        Collect input data from a non-multipart request body.

        A body whose content type is itself allowable is used as one image (or raw descriptor). A JSON body
        is validated against ``validationModel`` and then dispatched on its "image", "urls" or "samples" key.

        Args:
            request: request
            validationModel: validation model
            imageType: imageType
            allowRawDescriptors: whether raw descriptor mimetypes allowed or not

        Returns:
            list of Images or list warps

        Raises:
            VLException(Error.BadContentType, 400, isCriticalError=False): if content type of request is wrong
            RuntimeError: if validated json has none of the supported keys
        """
        contentType = request.content_type
        defaultDetectTime = currentDateTime(self.config.storageTime)

        if isAllowableContentType(contentType, allowRawDescriptors=allowRawDescriptors):
            rawBody = request.body
            return [
                self._getRawDataContainer(
                    body=rawBody,
                    contentType=contentType,
                    imageType=imageType,
                    detectTime=defaultDetectTime,
                )
            ]
        if contentType != "application/json":
            raise VLException(Error.BadContentType, 400, isCriticalError=False)

        inputJson = request.json
        self.loadDataFromJson(inputJson, validationModel)
        if "image" in inputJson:
            return await self._getImagesFromJson(
                inputJson=inputJson,
                imageType=imageType,
                defaultDetectTime=defaultDetectTime,
                allowRawDescriptors=allowRawDescriptors,
            )
        if "urls" in inputJson:
            return await self._getImagesFromUrls(
                inputJson=inputJson, imageType=imageType, defaultDetectTime=defaultDetectTime
            )
        if "samples" in inputJson:
            return await self._getImagesFromSamples(
                inputJson=inputJson, imageType=imageType, defaultDetectTime=defaultDetectTime
            )
        raise RuntimeError(f"bad input json {inputJson}")

    async def getInputEstimationData(
        self,
        request,
        validationModel: Type[BaseSchema],
        imageType: ImageType = ImageType.IMAGE,
        allowRawDescriptors: bool = False,
    ) -> List[Union[InputImageData, RawDescriptorData]]:
        """
        Collect input estimation data from either a multipart or a regular request body.

        NOTE(review): the multipart check reads ``self.request`` while the other path uses the ``request``
        argument — presumably they are the same object; confirm with callers.

        Args:
            request: request
            validationModel: validation model
            imageType: imageType
            allowRawDescriptors: whether raw descriptor mimetypes allowed or not

        Returns:
            list of Images or list warps.
        """
        if not self.request.content_type.startswith("multipart/form-data"):
            return await self.getDataFromRequest(
                request, validationModel, imageType, allowRawDescriptors=allowRawDescriptors
            )
        return await self.getDataFromMultipart(imageType)