# -*- coding: utf-8 -*-
""" Base handler
Module implementing the base handler classes for all handlers.
"""
import base64
import binascii
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import aiohttp
from luna3.client import Client
from lunavl.sdk.image_utils.geometry import Rect
from pydantic import BaseModel
from werkzeug.http import parse_accept_header
from yarl import URL
from app.app import BaseHandlersRequestHandler, HandlersRequest
from app.global_vars.context_vars import requestIdCtx
from app.handlers.available_content_types import (
MAP_BASE64_TYPE_TO_DATA_TYPE,
isAllowableContentType,
isAllowableRawContentType,
)
from classes.functions import loadDataFromJson
from classes.image_meta import ImageMeta, InputImageData
from classes.monitoring import HandlersMonitoringData
from classes.multipart_processing import ImageWithBB, ImageWithFaceBB
from classes.raw_descriptor_data import RawDescriptorData
from classes.schemas.base_schema import BaseSchema
from classes.schemas.policies import Policies
from classes.schemas.verifier import VerifierPoliciesModel as VerifierPolicies
from crutches_on_wheels.errors.errors import Error, ErrorInfo
from crutches_on_wheels.errors.exception import VLException
from crutches_on_wheels.utils.functions import convertDateTimeToCurrentFormatStr, currentDateTime, downloadImage
from img_utils.utils import convertToBytesIfNeed
from sdk.sdk_loop.models.image import ImageType, InputImage
class BaseHandler(BaseHandlersRequestHandler):
    """
    Base handler for other handlers.
    Attributes:
        luna3Client (luna3.client.Client): luna3 client
        dbContext (DBContext): db context
        redisContext (RedisContext): redis context
        accountId (Optional[str]): account id passed to image downloads and sample store requests;
            initialized to None here (presumably set by concrete handlers)
    """

    def __init__(self, request: HandlersRequest):
        super().__init__(request)
        # Publish the request id through the context variable so code reading `requestIdCtx` sees it.
        requestIdCtx.set(self.requestId)
        self.luna3Client: Client = request.luna3Client
        self.dbContext = request.dbContext
        self.redisContext = request.redisContext
        self.accountId: Optional[str] = None

    def checkLivenessEstimationLicensing(self, estimate: int):
        """
        Check liveness estimation licensing.
        Args:
            estimate: liveness estimation status; nothing is checked when it is falsy
        Raises:
            VLException(Error.LicenseProblem, 403, False) if license information is unavailable, the license is
            expired, the liveness feature is disabled or its value is not 2, the liveness balance is exceeded
            or liveness execution synchronization failed
        """
        if not estimate:
            return
        # Read the license state once so that every check below works on the same snapshot.
        licenseState = self.app.ctx.licenseChecker.licenseState
        if not licenseState:
            raise VLException(Error.LicenseProblem.format("Cannot get license information."), 403, False)
        if not licenseState.expirationTime.isAvailable:
            raise VLException(Error.LicenseProblem.format("License expired"), 403, False)
        if licenseState.liveness.value is None:
            raise VLException(Error.LicenseProblem.format("Liveness feature disabled"), 403, False)
        if licenseState.liveness.value != 2:
            raise VLException(Error.LicenseProblem.format("Liveness v.2 feature disabled"), 403, False)
        if licenseState.livenessBalance is None:
            # licensing by expiration, not by executions
            return
        if not licenseState.livenessBalance.isAvailable:
            raise VLException(Error.LicenseProblem.format("Liveness balance is exceeded"), 403, False)
        if not self.app.ctx.licenseRecorder.isLivenessSynchronized():
            raise VLException(Error.LicenseProblem.format("Feature execution synchronization failed"), 403, False)

    def checkPolicyLicensing(self, policies: Union[Policies, VerifierPolicies]):
        """
        Check handler policies licensing.
        Args:
            policies: handler policies
        Raises:
            VLException(Error.LicenseProblem, 403, False) if liveness estimation is requested by the detect
            policy but is not licensed (see `checkLivenessEstimationLicensing`)
        """
        self.checkLivenessEstimationLicensing(policies.detectPolicy.estimateLiveness.estimate)

    async def _downloadImage(self, url: Union[str, URL]) -> Tuple[bytes, str]:
        """
        Download image by external url.
        Args:
            url: url
        Returns:
            image body and its content type
        Raises:
            VLException(Error.BadContentTypeDownloadedImage.format(url), 400, isCriticalError=False):
                if a downloaded image content type is not allowable
            any exception raised by `downloadImage` is propagated unchanged
        """
        timeoutConfig = self.config.loadExternalImageTimeout
        clientTimeout = aiohttp.ClientTimeout(
            total=timeoutConfig.totalTimeout,
            connect=timeoutConfig.connectTimeout,
            sock_connect=timeoutConfig.sockConnectTimeout,
            sock_read=timeoutConfig.sockReadTimeout,
        )
        imageBody, contentType = await downloadImage(
            url=url, logger=self.logger, timeout=clientTimeout, accountId=self.accountId
        )
        if not isAllowableRawContentType(contentType):
            raise VLException(Error.BadContentTypeDownloadedImage.format(url), 400, isCriticalError=False)
        return imageBody, contentType

    def _getRawDataContainer(
        self,
        body: bytes,
        contentType: Optional[str],
        imageType: Union[ImageType, None],
        fileName: Optional[str] = None,
        faceBoundingBoxList: Optional[List[dict]] = None,
        bodyBoundingBoxList: Optional[List[dict]] = None,
        sampleId: Optional[str] = None,
        url: Optional[str] = None,
        detectTime: Optional[str] = None,
        detectTs: Optional[float] = None,
        imageOrigin: Optional[str] = None,
        error: Optional[ErrorInfo] = None,
    ) -> Union[InputImageData, RawDescriptorData]:
        """
        Get raw data container: detectable image object or raw descriptor data.
        Args:
            body: binary data
            contentType: expected data content type from request; may be None when `error` is set
            imageType: image type; ImageType.IMAGE is used when None
            fileName: filename; "raw image" is used when not set
            faceBoundingBoxList: list with face detection rectangles (at most one is allowed)
            bodyBoundingBoxList: list with body detection rectangles (at most one is allowed)
            sampleId: sample id for warped image
            url: image source
            detectTime: detection time in ISO format
            detectTs: user-defined timestamp relative to something, such as the start of a video
            imageOrigin: image origin
            error: image load error; when set, content type and bounding box validation is skipped and the
                body is used as is
        Returns:
            prepared raw data container: RawDescriptorData for raw descriptor content types,
            InputImageData otherwise
        Raises:
            VLException(Error.OnlyOneDetectionRectAvailable, 403, False) if there are more than 1 bounding box
            VLException(Error.BadContentType, 400, False) if image content type is not allowable
            VLException(Error.BoundingBoxNotAvailableForWarp, 400, False) if try to use bounding box for warp
        """
        if error:
            # The image failed to load: keep the (possibly empty) body so the error travels in the meta.
            data = body
        else:
            if not isAllowableContentType(contentType, allowRawDescriptors=True):
                raise VLException(Error.BadContentType, 400, isCriticalError=False)
            data, contentType = convertToBytesIfNeed(body, contentType)
            if contentType in ("application/x-sdk-descriptor", "application/x-vl-xpk"):
                allowedVersions = [self.config.defaultFaceDescriptorVersion, self.config.defaultHumanDescriptorVersion]
                return RawDescriptorData(data, mimetype=contentType, filename=fileName, allowedVersions=allowedVersions)
            if faceBoundingBoxList and len(faceBoundingBoxList) > 1:
                raise VLException(Error.OnlyOneDetectionRectAvailable, 403, isCriticalError=False)
            if bodyBoundingBoxList and len(bodyBoundingBoxList) > 1:
                raise VLException(Error.OnlyOneDetectionRectAvailable, 403, isCriticalError=False)
            if (faceBoundingBoxList or bodyBoundingBoxList) and imageType in (ImageType.FACE_WARP, ImageType.BODY_WARP):
                raise VLException(Error.BoundingBoxNotAvailableForWarp, 400, False)
        # Use truthiness (not `is not None`) so an empty list means "no box" instead of an IndexError.
        faceBoxes = [Rect(**faceBoundingBoxList[0])] if faceBoundingBoxList else None
        bodyBoxes = [Rect(**bodyBoundingBoxList[0])] if bodyBoundingBoxList else None
        return InputImageData(
            image=InputImage(
                filename=fileName or "raw image",
                body=data,
                imageType=imageType or ImageType.IMAGE,
                faceBoxes=faceBoxes,
                bodyBoxes=bodyBoxes,
            ),
            meta=ImageMeta(
                sampleId=sampleId,
                url=url,
                detectTime=detectTime,
                imageOrigin=imageOrigin,
                error=error,
                detectTs=detectTs,
            ),
        )

    async def _getImagesFromJson(
        self, inputJson: dict, imageType: ImageType, defaultDetectTime: str, allowRawDescriptors: bool
    ) -> List[Union[InputImageData, RawDescriptorData]]:
        """
        Get images from request json (a single base64-encoded "image" with its "mimetype").
        Args:
            inputJson: json from request
            imageType: image type
            defaultDetectTime: image detection time in ISO format, used when "detect_time" is absent
            allowRawDescriptors: whether raw descriptor mimetypes allowed or not
        Returns:
            one-element list with the prepared raw data container
        Raises:
            VLException(Error.BadContentType, 400, False) if the mimetype is not allowable
            VLException(Error.BadInputJson, 400, False) if failed decode descriptor
        """
        # NOTE(review): the try intentionally spans container construction too, in case downstream
        # conversion also raises binascii.Error — TODO confirm whether that can happen.
        try:
            contentType = inputJson["mimetype"]
            if not isAllowableContentType(contentType, allowRawDescriptors=allowRawDescriptors):
                raise VLException(Error.BadContentType, 400, isCriticalError=False)
            image = base64.b64decode(inputJson["image"])
            # Mimetypes listed in the map are replaced (presumably base64 type -> decoded data type);
            # any other mimetype is kept unchanged.
            contentType = MAP_BASE64_TYPE_TO_DATA_TYPE.get(contentType, contentType)
            return [
                self._getRawDataContainer(
                    body=image,
                    contentType=contentType,
                    imageType=imageType,
                    faceBoundingBoxList=inputJson.get("face_bounding_boxes"),
                    bodyBoundingBoxList=inputJson.get("body_bounding_boxes"),
                    detectTime=self.convertDetectionTimeToCurrentFormat(
                        inputJson.get("detect_time"), defaultDetectTime
                    ),
                    detectTs=inputJson.get("detect_ts"),
                    imageOrigin=inputJson.get("image_origin"),
                )
            ]
        except binascii.Error as exc:
            raise VLException(Error.BadInputJson.format("image", "Failed to decode descriptor"), 400, False) from exc

    async def _getImagesFromUrls(
        self, inputJson: dict, imageType: ImageType, defaultDetectTime: str
    ) -> List[InputImageData]:
        """
        Get images from request's urls (list of urls in json with optional detection rectangles).
        Images are downloaded one by one in input order. A VLException raised while downloading (including
        a disallowed content type) is not propagated: the image is returned with an empty body and the
        error stored in its meta.
        Args:
            inputJson: json from request
            imageType: image type
            defaultDetectTime: image detection time in ISO format, used when "detect_time" is absent
        Returns:
            list of prepared raw data containers, one per url
        """
        resultImages = []
        for row in inputJson["urls"]:
            url = row["url"]
            try:
                image, contentType = await self._downloadImage(url)
                loadError = None
            except VLException as e:
                image, contentType = b"", None
                loadError = e.error
            resultImages.append(
                self._getRawDataContainer(
                    body=image,
                    contentType=contentType,
                    imageType=imageType,
                    faceBoundingBoxList=row.get("face_bounding_boxes"),
                    bodyBoundingBoxList=row.get("body_bounding_boxes"),
                    fileName=url,
                    url=url,
                    detectTime=self.convertDetectionTimeToCurrentFormat(row.get("detect_time"), defaultDetectTime),
                    detectTs=row.get("detect_ts"),
                    imageOrigin=row.get("image_origin"),
                    error=loadError,
                )
            )
        return resultImages

    async def _getImagesFromSamples(
        self, inputJson: dict, imageType: ImageType, defaultDetectTime: str
    ) -> List[InputImageData]:
        """
        Get images from request's samples
        (list of sample ids to get from the face or body samples store and optional detection rectangles).
        Args:
            inputJson: json from request
            imageType: image type; must be ImageType.FACE_WARP or ImageType.BODY_WARP
            defaultDetectTime: image detection time in ISO format, used when "detect_time" is absent
        Returns:
            list of prepared raw data containers, one per sample (empty for an empty samples list)
        Raises:
            VLException(Error.BadWarpImage, 400, False) if image type is not a face or body warp
        """
        # Compare against the enum class, not the argument: `imageType` may be None.
        if imageType == ImageType.FACE_WARP:
            storeApiClient = self.luna3Client.lunaFaceSamplesStore
            bucketName = self.config.faceSamplesStorage.bucket
        elif imageType == ImageType.BODY_WARP:
            storeApiClient = self.luna3Client.lunaBodySamplesStore
            bucketName = self.config.bodySamplesStorage.bucket
        else:
            raise VLException(
                error=Error.BadWarpImage.format(
                    "Not supported image type for samples. Valid image type one of: face or body warp"
                ),
                statusCode=400,
                isCriticalError=False,
            )
        inputSamples = inputJson["samples"]
        responses, samples = [], []
        # The first element decides whether the list holds objects with metadata or plain sample ids
        # (presumably the validation model forbids mixing the two forms).
        if inputSamples and isinstance(inputSamples[0], dict):
            for sample in inputSamples:
                responses.append(
                    await storeApiClient.getImage(
                        imageId=sample["sample_id"], bucketName=bucketName, accountId=self.accountId, raiseError=True
                    )
                )
                detectTime = self.convertDetectionTimeToCurrentFormat(sample.get("detect_time"), defaultDetectTime)
                samples.append(
                    {
                        "sample_id": sample["sample_id"],
                        "detect_time": detectTime,
                        "detect_ts": sample.get("detect_ts"),
                        "image_origin": sample.get("image_origin"),
                    }
                )
        else:
            for sampleId in inputSamples:
                responses.append(
                    await storeApiClient.getImage(
                        imageId=sampleId, bucketName=bucketName, accountId=self.accountId, raiseError=True
                    )
                )
                samples.append(
                    {"sample_id": sampleId, "detect_time": defaultDetectTime, "detect_ts": None, "image_origin": None}
                )
        return [
            self._getRawDataContainer(
                body=response.body,
                contentType=response.headers["Content-Type"],
                imageType=imageType,
                sampleId=sample["sample_id"],
                fileName=sample["sample_id"],
                detectTime=sample["detect_time"],
                detectTs=sample["detect_ts"],
                imageOrigin=sample["image_origin"],
            )
            for sample, response in zip(samples, responses)
        ]

    @staticmethod
    def loadDataFromJson(data: dict, model: Type[BaseModel]) -> Any:
        """
        Load data from json with pydantic.
        Args:
            data: input data
            model: pydantic model
        Returns:
            initialized object
        """
        # Inside the method body this name resolves to the module-level helper imported from
        # classes.functions, not to this staticmethod (class scope is not searched).
        return loadDataFromJson(data, model)

    def handleMonitoringData(self, monitoringData: HandlersMonitoringData):
        """
        Handle monitoring data.
        Does nothing when monitoring data sending is disabled in config. Otherwise the request part is
        added to the request's monitoring data and SDK usage points, if any, are flushed right away.
        Args:
            monitoringData: monitoring data
        """
        if not self.config.monitoring.sendData:
            return
        self.request.dataForMonitoring += monitoringData.request
        if monitoringData.sdkUsages:
            self.app.ctx.monitoring.flushPoints([monitoringData.sdkUsages])

    def getResponseContentType(self) -> str:
        """
        Get response content type.
        Returns:
            the best match of the request's Accept header among "application/msgpack" and
            "application/json"; "application/json" when the header is absent or matches neither
        """
        acceptHeader = self.request.headers.get("Accept", "application/json")
        return parse_accept_header(acceptHeader).best_match(
            ("application/msgpack", "application/json"), default="application/json"
        )
class BaseHandlerWithMultipart(BaseHandler):
    """
    Base handler class for resources whose requests may also be multipart.
    """

    @abstractmethod
    async def getDataFromMultipart(
        self, imageType: ImageType = ImageType.IMAGE
    ) -> Tuple[List[InputImageData], Optional[dict]]:
        """
        Get data from multipart request.
        Args:
            imageType: image type
        Returns:
            list of Images or list warps and optionally dict with policies (for multipart request with policies)
        """

    def _getDataFromMultipart(
        self,
        multipartData: Dict[str, Union[ImageWithBB, ImageWithFaceBB]],
        imageType: ImageType,
        allowRawDescriptors: bool = False,
    ) -> List[Union[InputImageData, RawDescriptorData]]:
        """
        Get data from multipart request (list of images and optional detection rectangles, or raw descriptors).
        Args:
            multipartData: validated images from multipart
            imageType: image type
            allowRawDescriptors: whether raw descriptor mimetypes allowed or not
        Returns:
            list of prepared raw data containers (images or raw descriptors), one per multipart part
        Raises:
            VLException(Error.BadContentTypeInMultipartImage, 400, isCriticalError=False): if content type of a part
            of multipart request is wrong
        """
        defaultDetectTime = currentDateTime(self.config.storageTime)
        resultImages = []
        for image in multipartData.values():
            if not isAllowableContentType(image.contentType, allowRawDescriptors=allowRawDescriptors):
                raise VLException(Error.BadContentTypeInMultipartImage, 400, isCriticalError=False)
            # Empty bounding box lists are passed on as None ("no box").
            faceBoundingBoxList = image.faceBoundingBoxes or None
            # Only ImageWithBB parts carry body bounding boxes.
            if isinstance(image, ImageWithBB) and image.bodyBoundingBoxes:
                bodyBoundingBoxList = image.bodyBoundingBoxes
            else:
                bodyBoundingBoxList = None
            resultImages.append(
                self._getRawDataContainer(
                    body=image.body,
                    contentType=image.contentType,
                    imageType=imageType,
                    fileName=image.filename,
                    faceBoundingBoxList=faceBoundingBoxList,
                    bodyBoundingBoxList=bodyBoundingBoxList,
                    detectTime=self.convertDetectionTimeToCurrentFormat(image.detectTime, defaultDetectTime),
                    detectTs=image.detectTs,
                    imageOrigin=image.imageOrigin,
                )
            )
        return resultImages

    async def getDataFromRequest(
        self,
        request: HandlersRequest,
        validationModel: Type[BaseSchema],
        imageType: Union[ImageType, None],
        allowRawDescriptors: bool = False,
    ) -> List[Union[InputImageData, RawDescriptorData]]:
        """
        Get images (or raw descriptors) from a non-multipart request body.
        A body with an allowable raw content type is used directly. A json body is validated against
        `validationModel` and then read from its "image", "urls" or "samples" field.
        Args:
            request: request
            validationModel: validation model for a json body
            imageType: image type
            allowRawDescriptors: whether raw descriptor mimetypes allowed or not
        Returns:
            list of prepared raw data containers
        Raises:
            VLException(Error.BadContentType, 400, isCriticalError=False): if content type of request is wrong
            VLException(Error.BadInputJson, 400, False): if the json "image" field cannot be decoded
            RuntimeError: if the validated json contains none of "image", "urls" or "samples"
        """
        contentType = request.content_type
        defaultDetectTime = currentDateTime(self.config.storageTime)
        if isAllowableContentType(contentType, allowRawDescriptors=allowRawDescriptors):
            return [
                self._getRawDataContainer(
                    body=request.body,
                    contentType=contentType,
                    imageType=imageType,
                    detectTime=defaultDetectTime,
                )
            ]
        if contentType != "application/json":
            raise VLException(Error.BadContentType, 400, isCriticalError=False)
        inputJson = request.json
        # Validation only: the loaded model is not used, the raw json is read below.
        self.loadDataFromJson(inputJson, validationModel)
        if "image" in inputJson:
            return await self._getImagesFromJson(
                inputJson=inputJson,
                imageType=imageType,
                defaultDetectTime=defaultDetectTime,
                allowRawDescriptors=allowRawDescriptors,
            )
        if "urls" in inputJson:
            return await self._getImagesFromUrls(
                inputJson=inputJson, imageType=imageType, defaultDetectTime=defaultDetectTime
            )
        if "samples" in inputJson:
            return await self._getImagesFromSamples(
                inputJson=inputJson, imageType=imageType, defaultDetectTime=defaultDetectTime
            )
        raise RuntimeError(f"bad input json {inputJson}")