Source code for luna_faces.app.handlers.base_handler

from enum import Enum
from typing import List, NamedTuple, Optional, Set, Union

import msgpack
import ujson as json
from sanic.request import File
from sanic.response import HTTPResponse
from vlutils.descriptors.containers import sdkDescriptorDecode, sdkDescriptorEncode
from vlutils.descriptors.data import DescriptorsEnum
from vlutils.descriptors.xpk_reader import readXPKFromBinary
from vlutils.helpers import bytesToBase64, convertTimeToString

from app.app import FacesApp, FacesRequest
from app.handlers.helpers import base64ToBytes
from app.handlers.schemas import CREATE_FACE_SCHEMA, INPUT_ATTRIBUTE_SCHEMA, INPUT_XPK_ATTRIBUTE_SCHEMA
from app.handlers.validators import validate_url
from attributes_db.model import BasicAttributes, Descriptor, TemporaryAttributes
from configs.configs.configs.services import SettingsFaces
from crutches_on_wheels.enums.attributes import TemporaryAttributeTargets
from crutches_on_wheels.errors.errors import Error
from crutches_on_wheels.errors.exception import VLException
from crutches_on_wheels.web.handlers import BaseHandler

# Names of descriptor-related fields: the descriptor itself plus its samples, obtaining method and version.
descriptorDependencies = {f"descriptor{suffix}" for suffix in ("", "_samples", "_obtaining_method", "_version")}

# Names of basic-attribute fields: age, gender and ethnicity, each with its obtaining method and version,
# plus the shared basic attributes samples field.
basicDependencies = {
    f"{attributeName}{suffix}"
    for attributeName in ("age", "gender", "ethnicity")
    for suffix in ("", "_obtaining_method", "_version")
}
basicDependencies.add("basic_attributes_samples")


class NewFace(NamedTuple):
    """
    Data for a new face, as collected from a "create face" request.
    """

    # face fields taken from the request body (user_data, external_id, etc.)
    data: dict
    # ids of lists the new face should be linked to; None when no lists were requested
    listIds: Optional[Set[str]] = None
    # temporary attribute to attach to the face; None when the request carries no attribute
    attribute: Optional[TemporaryAttributes] = None


class MimeTypes(Enum):
    """
    MIME types known to the handlers (request bodies and multipart parts).

    The member definition order is significant for iteration over the enum and is kept stable.
    """

    # raw binary descriptor payload
    BinaryDescriptor = "application/x-binary-descriptor"
    # FlatBuffers payload
    FlatBuffers = "application/x-flatbuf"
    # JSON payload (also the expected type of the multipart "meta" part)
    JSON = "application/json"
    # multipart request body
    Multipart = "multipart/form-data"
    # XPK container (expected type of the multipart "xpk_file" part)
    XPK = "application/x-vl-xpk"
    # msgpack payload
    Msgpack = "application/msgpack"


class BaseRequestHandler(BaseHandler):
    """
    Base handler for other handlers.

    Exposes the request's faces database context, temporary attributes context and license checker
    as handler attributes.
    """

    def __init__(self, request: FacesRequest):
        """
        Args:
            request: incoming faces request
        """
        super().__init__(request)
        self.facesContext = request.dbContext
        self.attributesContext = request.attributesContext
        self.licenseChecker = request.licenseChecker

    def success(
        self,
        statusCode: int = 200,
        body: Optional[Union[bytes, str]] = None,
        outputJson: Optional[Union[dict, list]] = None,
        contentType: Optional[str] = None,
        extraHeaders: Optional[dict] = None,
    ) -> HTTPResponse:
        """
        Finish success request.

        Generate correct reply with request id header, correct Content Type header.
        Supports the "application/msgpack" "Accept" header: when the client accepts msgpack, ``outputJson``
        is packed with msgpack and sent as the body with the msgpack content type; otherwise it is passed on
        to the parent implementation as json.

        Args:
            statusCode: response status code, range(200, 300), default 200
            body: pure body
            outputJson: json as object
            contentType: body content type
            extraHeaders: extra headers that will be added to the response (default headers in case of
                overlapping will be replaced)
        Returns:
            sanic HTTPResponse object
        Raises:
            ValueError: if both ``body`` and ``outputJson`` are given. The parent ``success`` may raise
                further errors (presumably when a body has no determinable content type — confirm against
                BaseHandler).
        """
        data = None
        if outputJson is not None:
            if body is not None:
                raise ValueError("Must be only one data source.")
            acceptHeaders = self.request.headers.get("Accept", "application/json")
            # NOTE: plain substring match — quality factors (q=...) in the Accept header are not considered
            if MimeTypes.Msgpack.value in acceptHeaders:
                body = msgpack.packb(outputJson, use_bin_type=True)
                contentType = MimeTypes.Msgpack.value
            else:
                data = outputJson
        response = super().success(
            statusCode=statusCode,
            body=body,
            outputJson=data,
            contentType=contentType,
            extraHeaders=extraHeaders,
        )
        return response

    def checkLicense(self):
        """
        Check that the license is available and the limit of faces with attributes is not exceeded.

        Raises:
            VLException(Error.LicenseProblem, 403): if the license has expired or the faces limit is exceeded
        """
        if not self.licenseChecker.checkExpirationTime():
            raise VLException(Error.LicenseProblem.format("License expired"), statusCode=403, isCriticalError=False)
        if not self.licenseChecker.checkFacesLimit():
            raise VLException(
                Error.LicenseProblem.format(
                    "License limit exceeded. " "Please contact VisionLabs for license upgrade or delete redundant faces."
                ),
                statusCode=403,
                isCriticalError=False,
            )

    @property
    def app(self) -> FacesApp:
        """
        Get running app

        Returns:
            app
        """
        return self.request.app

    @property
    def config(self) -> SettingsFaces:
        """
        Get app config

        Returns:
            app config (``app.ctx.serviceConfig``)
        """
        return self.app.ctx.serviceConfig


class FaceBaseRequestHandler(BaseRequestHandler):
    """
    Base handler for handlers that operate on faces and face attributes.
    """

    async def getAttributeById(self, attributeId: str, accountId: str) -> TemporaryAttributes:
        """
        Load a temporary attribute container by its id.

        Args:
            attributeId: attribute id
            accountId: account id
        Returns:
            temporary attribute container
        Raises:
            VLException(Error.AttributesNotFound, 400): if the attribute was not found
        """
        try:
            return await self.attributesContext.get(attributeId=attributeId, accountId=accountId)
        except VLException as exc:
            # report a missing attribute as a bad request (400) instead of the status set by the context
            if exc.error == Error.AttributesNotFound:
                exc.statusCode = 400
            raise

    @staticmethod
    def getAttributeByData(attribute: dict) -> TemporaryAttributes:
        """
        Build and validate a temporary attribute container from raw attribute data.

        The given dict is modified in place: "account_id" is set to None and "face_descriptors" (if present)
        is replaced with converted Descriptor instances.

        Args:
            attribute: attribute data
        Returns:
            temporary attribute container
        """
        # the attribute will belong to a particular face, so it does not have an account
        attribute["account_id"] = None
        if "face_descriptors" in attribute:
            attribute["face_descriptors"] = AttributesBaseRequestHandler.convertInputDescriptors(
                attribute["face_descriptors"]
            )
        container = AttributesBaseRequestHandler.createTemporaryAttributes(attribute)
        AttributesBaseRequestHandler.validateAttribute(container)
        return container

    async def checkListIds(self, listIds: Set[str], accountId: str) -> None:
        """
        Ensure that every requested list exists.

        Args:
            listIds: list ids to check
            accountId: account id
        Raises:
            VLException(Error.ListsNotFound.format(<missing list id>), 400): if some list was not found
        """
        missingListId = await self.facesContext.getNonexistentListId(requiredListIds=listIds, accountId=accountId)
        if missingListId is None:
            return
        raise VLException(Error.ListsNotFound.format(missingListId), 400, False)

    async def getDataForNewFace(self) -> NewFace:
        """
        Extract the data for a new face from the request body.

        Validates the body, fills "user_data"/"external_id" with empty strings when absent, and pulls out
        the optional "lists" and "attribute" fields. An attribute is either loaded by "attribute_id" or
        built from the inline data; in both cases the license is checked first.

        Returns:
            new face structure
        """
        faceData = self.request.json
        self.validateJson(faceData, CREATE_FACE_SCHEMA, False)
        validate_url(faceData.get("avatar"))
        for optionalKey in ("user_data", "external_id"):
            faceData.setdefault(optionalKey, "")

        rawListIds = faceData.pop("lists", None)
        listIds = None if rawListIds is None else set(rawListIds)

        attribute = faceData.pop("attribute", None)
        if attribute is not None:
            # creating a face with an attribute requires an available license
            self.checkLicense()
            if "attribute_id" in attribute:
                attribute = await self.getAttributeById(
                    attributeId=attribute["attribute_id"], accountId=faceData["account_id"]
                )
            else:
                attribute = self.getAttributeByData(attribute)
        return NewFace(faceData, listIds, attribute)


class AttributesBaseRequestHandler(BaseRequestHandler):
    """
    Base handler for handlers which work with attributes.
    """

    @staticmethod
    def createTemporaryAttributes(inputAttribute: dict, attributeId: Optional[str] = None) -> TemporaryAttributes:
        """
        Create temporary attribute

        Args:
            inputAttribute: input json with converted descriptors; must contain "account_id"
            attributeId: attribute id
        Returns:
            new temporary attribute
        Raises:
            KeyError: if "account_id" (or a basic attribute field, when basic attributes are given) is missing
        """
        basicAttributesDict = inputAttribute.get("basic_attributes")
        if basicAttributesDict:
            basicAttributes = BasicAttributes(
                age=basicAttributesDict["age"],
                gender=basicAttributesDict["gender"],
                ethnicity=basicAttributesDict["ethnicity"],
            )
        else:
            # normalize absent/empty basic attributes to None: makeOutputAttribute checks `is not None`
            # and would otherwise try to read `.age` from an empty dict
            basicAttributes = None
        basicAttributesSamples = inputAttribute.get("basic_attributes_samples", ())
        faceDescriptorsSamples = inputAttribute.get("face_descriptor_samples", ())
        accountId = inputAttribute["account_id"]
        return TemporaryAttributes(
            accountId=accountId,
            attributeId=attributeId,
            descriptors=inputAttribute.get("face_descriptors"),
            basicAttributes=basicAttributes,
            basicAttributesSamples=basicAttributesSamples,
            descriptorSamples=faceDescriptorsSamples,
        )

    def getEncodedDescriptor(self, descriptorVersion: int, descriptor: bytes) -> Union[bytes, str]:
        """
        Encode given descriptor as an SDK descriptor.

        If the client accepts msgpack, the encoded bytes are returned as-is; otherwise they are base64-encoded.

        Args:
            descriptorVersion: descriptor version
            descriptor: raw descriptor bytes
        Returns:
            Encoded descriptor.
        """
        encodedDescriptor = sdkDescriptorEncode(descriptorVersion, descriptor)
        acceptHeaders = self.request.headers.get("Accept", "application/json")
        if MimeTypes.Msgpack.value in acceptHeaders:
            return encodedDescriptor
        return bytesToBase64(encodedDescriptor)

    def makeOutputAttribute(
        self, attribute: TemporaryAttributes, targets: List[str], descriptorVersion: Optional[int] = None
    ) -> dict:
        """
        Make output attribute

        Args:
            attribute: temporary attribute
            targets: list of targets
            descriptorVersion: descriptor version (the config's default face descriptor version if None)
        Returns:
            dict with all requested fields; "face_descriptor" is None when the attribute has no descriptor
            of the requested version
        """
        meta = {
            "create_time": convertTimeToString(attribute.createTime, self.config.storageTime == "UTC"),
            "attribute_id": attribute.attributeId,
            "account_id": attribute.accountId,
            "face_descriptor_samples": attribute.descriptorSamples,
            "basic_attributes_samples": attribute.basicAttributesSamples,
        }
        res = {target: value for target, value in meta.items() if target in targets}

        if descriptorVersion is None:
            descriptorVersion = self.config.defaultFaceDescriptorVersion

        if TemporaryAttributeTargets.faceDescriptor.value in targets:
            # `or ()` guards against attributes created without descriptors (descriptors=None)
            for descriptor in attribute.descriptors or ():
                if descriptor.version == descriptorVersion:
                    res["face_descriptor"] = self.getEncodedDescriptor(descriptor.version, descriptor.descriptor)
                    break
            else:
                res["face_descriptor"] = None

        if TemporaryAttributeTargets.basicAttributes.value in targets:
            if attribute.basicAttributes is not None:
                res["basic_attributes"] = {
                    "age": attribute.basicAttributes.age,
                    "gender": attribute.basicAttributes.gender,
                    "ethnicity": attribute.basicAttributes.ethnicity,
                }
            else:
                res["basic_attributes"] = None
        return res

    @staticmethod
    def convertInputDescriptors(inputDescriptors: list) -> List[Descriptor]:
        """
        Convert list of descriptors to Descriptor instances.

        Input descriptors have one of the following formats:
            1. {"descriptor": b"descr", "version": 51} - from msgpack
            2. {"descriptor": "base64 descr", "version": 51} - from json
            3. b"SDK descr" - from msgpack
            4. "base64 SDK descr" - from json

        Args:
            inputDescriptors: descriptors from input json/msgpack
        Returns:
            descriptors list
        Raises:
            VLException(Error.BadSdkDescriptor, 400): if an SDK descriptor cannot be decoded
            RuntimeError: if a descriptor has an unsupported type
        """
        descriptors = []
        for inputDescriptor in inputDescriptors:
            if isinstance(inputDescriptor, dict):
                binaryDescriptor = inputDescriptor["descriptor"]
                if isinstance(binaryDescriptor, str):
                    binaryDescriptor = base64ToBytes(binaryDescriptor)
                descriptor = Descriptor(version=inputDescriptor["version"], descriptor=binaryDescriptor)
            elif isinstance(inputDescriptor, (str, bytes)):
                sdkDescriptor = (
                    base64ToBytes(inputDescriptor) if isinstance(inputDescriptor, str) else inputDescriptor
                )
                try:
                    version, binaryDescriptor = sdkDescriptorDecode(sdkDescriptor)
                except (ValueError, SyntaxError) as exc:
                    raise VLException(Error.BadSdkDescriptor, 400, isCriticalError=False) from exc
                descriptor = Descriptor(version=version, descriptor=binaryDescriptor)
            else:
                raise RuntimeError(f"bad type of descriptor {type(inputDescriptor)}")
            descriptors.append(descriptor)
        return descriptors

    @staticmethod
    def validateAttributeDescriptors(attribute: TemporaryAttributes):
        """
        Validate length and version of input descriptors.

        Also checks for duplicated descriptor versions. An attribute without descriptors passes.

        Args:
            attribute: input attribute
        Raises:
            VLException(Error.InvalidDescriptorLength): if descriptor has incorrect length
            VLException(Error.UnknownDescriptorVersion): if descriptor has unknown version
            VLException(Error.AttributeWithDescriptorsIdenticalVersion): if there are two or more descriptors
                with the same version
        """
        # descriptors may be None for attributes carrying only basic attributes
        descriptors = attribute.descriptors or ()
        for descriptor in descriptors:
            for descriptorType in DescriptorsEnum:
                if descriptorType.value.version == descriptor.version:
                    if descriptorType.value.length != len(descriptor.descriptor):
                        raise VLException(
                            Error.InvalidDescriptorLength.format(len(descriptor.descriptor)), 400, isCriticalError=False
                        )
                    break
            else:
                raise VLException(Error.UnknownDescriptorVersion.format(descriptor.version), 400, isCriticalError=False)

        seenVersions = set()
        duplicated = set()
        for descriptor in descriptors:
            if descriptor.version in seenVersions:
                duplicated.add(descriptor.version)
            seenVersions.add(descriptor.version)
        if duplicated:
            # sorted for a deterministic error message
            raise VLException(
                Error.AttributeWithDescriptorsIdenticalVersion.format(sorted(duplicated)), 400, isCriticalError=False
            )

    async def getInputAttributeFromJson(self) -> dict:
        """
        Get input attribute from request with json (or msgpack) body

        Returns:
            attribute as dict with "face_descriptors" (if present) converted to Descriptor instances
        """
        attribute = self.request.json
        self.validateJson(attribute, INPUT_ATTRIBUTE_SCHEMA, useJsonSchema=False)
        if "face_descriptors" in attribute:
            attribute["face_descriptors"] = AttributesBaseRequestHandler.convertInputDescriptors(
                attribute["face_descriptors"]
            )
        return attribute

    async def getInputAttributeFromMultipart(self) -> dict:
        """
        Get input attribute from request with multipart body

        Accepted parts: "xpk_file" (XPK container; its descriptor becomes "face_descriptors") and
        "meta" (json merged into the result). Each part may appear at most once.

        Returns:
            attribute as dict
        Raises:
            VLException(Error.UnknownMultipartName, 400): if an unexpected part name is present
            VLException(Error.DuplicateMultipartName, 400): if a part is repeated
            VLException(Error.BadMultipartContentType, 400): if a part has a wrong content type
            VLException(Error.BadInputXpk, 400): if the XPK file cannot be read
        """
        attribute = {}
        files = self.request.files
        unknownNames = set(files.keys()) - {"meta", "xpk_file"}
        if unknownNames:
            raise VLException(Error.UnknownMultipartName.format(unknownNames.pop()), 400, isCriticalError=False)

        if "xpk_file" in files:
            xpkFiles = files["xpk_file"]
            if len(xpkFiles) > 1:
                raise VLException(Error.DuplicateMultipartName.format("xpk_file"), 400, isCriticalError=False)
            xpkFilePart: File = xpkFiles[0]
            if not xpkFilePart.type.startswith(MimeTypes.XPK.value):
                raise VLException(Error.BadMultipartContentType.format("xpk_file"), 400, isCriticalError=False)
            try:
                xpk = readXPKFromBinary(xpkFilePart.body)
            except ValueError:
                raise VLException(Error.BadInputXpk, 400, isCriticalError=False)
            if "Descriptor" in xpk:
                descriptor = Descriptor(
                    version=xpk["Descriptor"]["version"], descriptor=xpk["Descriptor"]["raw_descriptor"]
                )
                attribute["face_descriptors"] = [descriptor]

        if "meta" in files:
            metas = files["meta"]
            if len(metas) > 1:
                raise VLException(Error.DuplicateMultipartName.format("meta"), 400, isCriticalError=False)
            if metas:
                metaPart: File = metas[0]
                if not metaPart.type.startswith(MimeTypes.JSON.value):
                    raise VLException(Error.BadMultipartContentType.format("meta"), 400, isCriticalError=False)
                # NOTE(review): malformed JSON or non-UTF-8 bytes raise ValueError here, which is not converted
                # into a 400 VLException — confirm how the caller reports it
                meta = json.loads(metaPart.body.decode("utf-8"))
                self.validateJson(meta, INPUT_XPK_ATTRIBUTE_SCHEMA, useJsonSchema=False)
                attribute.update(meta)
        return attribute

    async def getInputAttribute(self, attributeId: str) -> TemporaryAttributes:
        """
        Get input attribute from either json/msgpack or multipart body.

        Args:
            attributeId: id for the new attribute
        Returns:
            validated temporary attribute
        Raises:
            VLException(Error.BadContentType, 400): if the request content type is not supported
        """
        # compare only the media type: headers such as "multipart/form-data; boundary=..." or
        # "application/json; charset=utf-8" carry parameters, and media types are case-insensitive
        contentType = self.request.content_type or ""
        mimeType = contentType.split(";", 1)[0].strip().lower()
        if mimeType in (MimeTypes.JSON.value, MimeTypes.Msgpack.value):
            attribute = await self.getInputAttributeFromJson()
        elif mimeType == MimeTypes.Multipart.value:
            attribute = await self.getInputAttributeFromMultipart()
        else:
            raise VLException(Error.BadContentType, 400, isCriticalError=False)

        attribute = AttributesBaseRequestHandler.createTemporaryAttributes(attribute, attributeId)
        AttributesBaseRequestHandler.validateAttribute(attribute)
        return attribute

    @staticmethod
    def validateAttribute(attribute: TemporaryAttributes):
        """
        Validate compatibility attribute data

        Args:
            attribute: attribute
        Raises:
            VLException(Error.AttributeDoesNotContainAnyData): if attribute does not contain basic attributes and
                descriptors
            VLException(Error.AttributeContainsSamplesWithoutData): if attribute contains samples without data
            VLException errors raised by validateAttributeDescriptors
        """
        if not attribute.descriptors and not attribute.basicAttributes:
            raise VLException(Error.AttributeDoesNotContainAnyData, 400, isCriticalError=False)
        if attribute.descriptorSamples and not attribute.descriptors:
            raise VLException(
                Error.AttributeContainsSamplesWithoutData.format("descriptors"), 400, isCriticalError=False
            )
        if attribute.basicAttributesSamples and not attribute.basicAttributes:
            raise VLException(
                Error.AttributeContainsSamplesWithoutData.format("basic_attributes"), 400, isCriticalError=False
            )
        AttributesBaseRequestHandler.validateAttributeDescriptors(attribute)