from enum import Enum
from typing import List, NamedTuple, Optional, Set, Union
import msgpack
import ujson as json
from sanic.request import File
from sanic.response import HTTPResponse
from vlutils.descriptors.containers import sdkDescriptorDecode, sdkDescriptorEncode
from vlutils.descriptors.data import DescriptorsEnum
from vlutils.descriptors.xpk_reader import readXPKFromBinary
from vlutils.helpers import bytesToBase64, convertTimeToString
from app.app import FacesApp, FacesRequest
from app.handlers.helpers import base64ToBytes
from app.handlers.schemas import CREATE_FACE_SCHEMA, INPUT_ATTRIBUTE_SCHEMA, INPUT_XPK_ATTRIBUTE_SCHEMA
from app.handlers.validators import validate_url
from attributes_db.model import BasicAttributes, Descriptor, TemporaryAttributes
from configs.configs.configs.services import SettingsFaces
from crutches_on_wheels.cow.enums.attributes import TemporaryAttributeTargets
from crutches_on_wheels.cow.errors.errors import Error
from crutches_on_wheels.cow.errors.exception import VLException
from crutches_on_wheels.cow.web.handlers import BaseHandler
# Field names that belong to face-descriptor data.
# NOTE(review): neither set is referenced in this part of the file; presumably
# consumers elsewhere use them to tell descriptor fields from basic-attribute fields.
descriptorDependencies = {
    "descriptor",
    "descriptor_samples",
    "descriptor_obtaining_method",
    "descriptor_version",
}
# Field names that belong to basic-attribute data: the value, obtaining method and
# version of each of age / gender / ethnicity, plus the shared samples field.
basicDependencies = set().union(
    ("age", "age_obtaining_method", "age_version"),
    ("gender", "gender_obtaining_method", "gender_version"),
    ("ethnicity", "ethnicity_obtaining_method", "ethnicity_version"),
    ("basic_attributes_samples",),
)
class NewFace(NamedTuple):
    """
    Everything needed to create a face: its own fields, the lists to link it to and an optional attribute.
    """

    #: dict with face data (e.g. user_data, external_id, account_id)
    data: dict
    #: ids of the lists the new face should be linked to
    listIds: Optional[Set[str]] = None
    #: temporary attribute to attach to the new face
    attribute: Optional[TemporaryAttributes] = None
class MimeTypes(Enum):
    """
    Content types (MIME types) known to the faces handlers; each member's value is the MIME string.
    """

    #: binary descriptor payload
    BinaryDescriptor = "application/x-binary-descriptor"
    #: FlatBuffers payload
    FlatBuffers = "application/x-flatbuf"
    #: JSON document
    JSON = "application/json"
    #: multipart form upload
    Multipart = "multipart/form-data"
    #: XPK container payload
    XPK = "application/x-vl-xpk"
    #: MessagePack payload
    Msgpack = "application/msgpack"
class BaseRequestHandler(BaseHandler):
    """
    Base handler for other handlers.

    Keeps the faces db context, the attributes context and the license checker of the request
    and builds success responses that honour a msgpack "Accept" header.
    """

    def __init__(self, request: FacesRequest):
        super().__init__(request)
        self.facesContext = request.dbContext
        self.attributesContext = request.attributesContext
        self.licenseChecker = request.licenseChecker

    def success(
        self,
        statusCode: int = 200,
        body: Optional[Union[bytes, str]] = None,
        outputJson: Optional[Union[dict, list]] = None,
        contentType: Optional[str] = None,
        extraHeaders: Optional[dict] = None,
    ) -> HTTPResponse:
        """
        Finish success request.

        If ``outputJson`` is given and the "Accept" header contains "application/msgpack", the payload is
        serialized with msgpack and sent as the body with the "application/msgpack" content type; otherwise
        it is passed on as JSON. The response itself is built by ``BaseHandler.success`` (per the original
        contract it adds the request id and Content-Type headers).

        Args:
            statusCode: response status code, range(200, 300), default 200
            body: pure body
            outputJson: json as object
            contentType: body content type
            extraHeaders: extra headers that will be added to the response
                (default headers in case of overlapping will be replaced)
        Returns:
            sanic ``HTTPResponse`` produced by ``BaseHandler.success``
        Raises:
            ValueError: if both ``body`` and ``outputJson`` are given.
                NOTE(review): the previous docstring also described a ValueError for a body whose content type
                cannot be determined; presumably that one comes from ``BaseHandler.success`` — confirm.
        """
        jsonPayload = None
        if outputJson is not None:
            if body is not None:
                raise ValueError("Must be only one data source.")
            acceptHeaders = self.request.headers.get("Accept", MimeTypes.JSON.value)
            if MimeTypes.Msgpack.value in acceptHeaders:
                # serialize here so the base handler sends it as an opaque binary body
                body = msgpack.packb(outputJson, use_bin_type=True)
                contentType = MimeTypes.Msgpack.value
            else:
                jsonPayload = outputJson
        return super().success(
            statusCode=statusCode,
            body=body,
            outputJson=jsonPayload,
            contentType=contentType,
            extraHeaders=extraHeaders,
        )

    def checkLicense(self):
        """
        Check that the license has not expired and the faces limit is not exceeded.

        Raises:
            VLException(Error.LicenseProblem, 403): if the license has expired or the faces limit is exceeded
        """
        if not self.licenseChecker.checkExpirationTime():
            raise VLException(Error.LicenseProblem.format("License expired"), statusCode=403, isCriticalError=False)
        if not self.licenseChecker.checkFacesLimit():
            raise VLException(
                Error.LicenseProblem.format(
                    "License limit exceeded. "
                    "Please contact VisionLabs for license upgrade or delete redundant faces."
                ),
                statusCode=403,
                isCriticalError=False,
            )

    @property
    def app(self) -> FacesApp:
        """
        Get running app
        Returns:
            app
        """
        return self.request.app

    @property
    def config(self) -> SettingsFaces:
        """
        Get app config
        Returns:
            app config (``app.ctx.serviceConfig``)
        """
        return self.app.ctx.serviceConfig
class FaceBaseRequestHandler(BaseRequestHandler):
    """
    Base handler for handlers which work with faces and face attributes.
    """

    async def getAttributeById(self, attributeId: str, accountId: str) -> TemporaryAttributes:
        """
        Get temporary attribute container by attribute id.
        Args:
            attributeId: attribute id
            accountId: account id
        Returns:
            temporary attribute container
        Raises:
            VLException(Error.AttributesNotFound, 400): if attribute was not found
            VLException: any other error from the attributes context is re-raised unchanged
        """
        try:
            return await self.attributesContext.get(attributeId=attributeId, accountId=accountId)
        except VLException as e:
            # the attribute id comes from the request body, so a missing attribute is a client error (400)
            if e.error == Error.AttributesNotFound:
                e.statusCode = 400
            raise

    @staticmethod
    def getAttributeByData(attribute: dict) -> TemporaryAttributes:
        """
        Generate temporary attribute container by attribute data.

        The given dict is not modified: a shallow copy gets ``account_id`` set to None and its
        "face_descriptors" (if any) converted before the container is built and validated.

        Args:
            attribute: attribute data
        Returns:
            temporary attribute container
        Raises:
            VLException: if the attribute fails ``AttributesBaseRequestHandler.validateAttribute``
        """
        # shallow copy so the caller's dict is left untouched
        attributeData = dict(attribute)
        # it will belong to the particular face so it does not have an account
        attributeData["account_id"] = None
        if "face_descriptors" in attributeData:
            attributeData["face_descriptors"] = AttributesBaseRequestHandler.convertInputDescriptors(
                attributeData["face_descriptors"]
            )
        attributeContainer = AttributesBaseRequestHandler.createTemporaryAttributes(attributeData)
        AttributesBaseRequestHandler.validateAttribute(attributeContainer)
        return attributeContainer

    async def checkListIds(self, listIds: Set[str], accountId: str) -> None:
        """
        Check list ids existence.
        Args:
            listIds: list ids to check
            accountId: account id
        Raises:
            VLException(Error.ListsNotFound.format(nonExistListId), 400): if some list was not found
        """
        nonExistListId = await self.facesContext.getNonexistentListId(requiredListIds=listIds, accountId=accountId)
        if nonExistListId is not None:
            raise VLException(Error.ListsNotFound.format(nonExistListId), 400, isCriticalError=False)

    async def getDataForNewFace(self) -> NewFace:
        """
        Get data for new face from request.

        Validates the request json against ``CREATE_FACE_SCHEMA``, stamps it with the account id from the
        headers, fills empty "user_data"/"external_id" defaults and splits off "lists" and "attribute".
        An attribute is resolved either by its "attribute_id" or built from the inline attribute data;
        in both cases the license is checked first.

        Returns:
            new face structure
        Raises:
            VLException(Error.LicenseProblem, 403): if an attribute is given and the license check fails
            VLException(Error.AttributesNotFound, 400): if the referenced attribute does not exist
            NOTE(review): ``validateJson``, ``getAccountIdFromHeader`` and ``validate_url`` may raise their own
            errors — confirm against their implementations.
        """
        data = self.request.json
        self.validateJson(data, CREATE_FACE_SCHEMA, False)
        accountId = self.getAccountIdFromHeader()
        data["account_id"] = accountId
        # NOTE(review): presumably validate_url accepts None when no avatar is given — confirm
        validate_url(data.get("avatar"))
        data.setdefault("user_data", "")
        data.setdefault("external_id", "")
        listIds = data.pop("lists", None)
        if listIds is not None:
            listIds = set(listIds)
        attribute = data.pop("attribute", None)
        if attribute is not None:
            # check the license is available
            self.checkLicense()
            if "attribute_id" in attribute:
                attribute = await self.getAttributeById(attributeId=attribute["attribute_id"], accountId=accountId)
            else:
                attribute = self.getAttributeByData(attribute)
        return NewFace(data, listIds, attribute)
[docs]class AttributesBaseRequestHandler(BaseRequestHandler):
"""
Base handler for handlers which work with attributes.
"""
@staticmethod
def createTemporaryAttributes(inputAttribute: dict, attributeId: Optional[str] = None) -> TemporaryAttributes:
    """
    Create temporary attribute.

    Args:
        inputAttribute: input json with converted descriptors; "account_id" is required, while
            "basic_attributes", "face_descriptors", "basic_attributes_samples" and
            "face_descriptor_samples" are optional
        attributeId: attribute id
    Returns:
        new temporary attribute; ``basicAttributes`` is None when "basic_attributes" is absent or empty
    """
    basicAttributesDict = inputAttribute.get("basic_attributes")
    # a missing, null or empty "basic_attributes" all mean "no basic attributes": normalize to None rather
    # than leaking a falsy dict into a field that consumers treat as BasicAttributes-or-None
    basicAttributes = None
    if basicAttributesDict:
        basicAttributes = BasicAttributes(
            age=basicAttributesDict["age"],
            gender=basicAttributesDict["gender"],
            ethnicity=basicAttributesDict["ethnicity"],
        )
    return TemporaryAttributes(
        accountId=inputAttribute["account_id"],
        attributeId=attributeId,
        descriptors=inputAttribute.get("face_descriptors"),
        basicAttributes=basicAttributes,
        basicAttributesSamples=inputAttribute.get("basic_attributes_samples", ()),
        descriptorSamples=inputAttribute.get("face_descriptor_samples", ()),
    )
def getEncodedDescriptor(self, descriptorVersion: int, descriptor: bytes) -> Union[bytes, str]:
    """
    Wrap a raw descriptor with ``sdkDescriptorEncode`` and adapt it to the response format.

    Args:
        descriptorVersion: descriptor version
        descriptor: raw descriptor bytes
    Returns:
        the encoded bytes when the "Accept" header contains "application/msgpack" (msgpack can carry
        binary data), otherwise the encoded bytes converted with ``bytesToBase64``
    """
    packed = sdkDescriptorEncode(descriptorVersion, descriptor)
    accept = self.request.headers.get("Accept", "application/json")
    wantsMsgpack = "application/msgpack" in accept
    return packed if wantsMsgpack else bytesToBase64(packed)
def makeOutputAttribute(
    self, attribute: TemporaryAttributes, targets: List[str], descriptorVersion: Optional[int] = None
) -> dict:
    """
    Make output attribute.

    Args:
        attribute: temporary attribute
        targets: names of the fields to put into the output
        descriptorVersion: version of the face descriptor to output; defaults to the
            ``defaultFaceDescriptorVersion`` config setting
    Returns:
        dict with only the requested fields; "face_descriptor" is None when the attribute has no descriptor
        of the requested version (or no descriptors at all) and "basic_attributes" is None when the
        attribute has no basic attributes
    """
    meta = {
        "create_time": convertTimeToString(attribute.createTime, self.config.storageTime == "UTC"),
        "attribute_id": attribute.attributeId,
        "account_id": attribute.accountId,
        "face_descriptor_samples": attribute.descriptorSamples,
        "basic_attributes_samples": attribute.basicAttributesSamples,
    }
    res = {target: value for target, value in meta.items() if target in targets}
    if descriptorVersion is None:
        descriptorVersion = self.config.defaultFaceDescriptorVersion

    if TemporaryAttributeTargets.faceDescriptor.value in targets:
        # ``descriptors`` may be None (createTemporaryAttributes passes .get("face_descriptors")),
        # which simply means no descriptor of any version is available
        matched = next(
            (item for item in (attribute.descriptors or ()) if item.version == descriptorVersion),
            None,
        )
        res["face_descriptor"] = (
            None if matched is None else self.getEncodedDescriptor(matched.version, matched.descriptor)
        )

    if TemporaryAttributeTargets.basicAttributes.value in targets:
        basic = attribute.basicAttributes
        res["basic_attributes"] = (
            None
            if basic is None
            else {"age": basic.age, "gender": basic.gender, "ethnicity": basic.ethnicity}
        )
    return res
@staticmethod
def validateAttributeDescriptors(attribute: TemporaryAttributes):
    """
    Check each descriptor of the attribute against the known descriptor types and reject repeated versions.

    The type of a descriptor is the first ``DescriptorsEnum`` member whose version equals the descriptor's
    version; the descriptor's byte length must equal that type's length.

    Args:
        attribute: input attribute
    Raises:
        VLException(Error.UnknownDescriptorVersion, 400): if a descriptor version matches no known type
        VLException(Error.InvalidDescriptorLength, 400): if a descriptor has incorrect length
        VLException(Error.AttributeWithDescriptorsIdenticalVersion, 400): if two or more descriptors
            share the same version
    """
    for descriptor in attribute.descriptors:
        knownType = next(
            (candidate for candidate in DescriptorsEnum if candidate.value.version == descriptor.version),
            None,
        )
        if knownType is None:
            raise VLException(Error.UnknownDescriptorVersion.format(descriptor.version), 400, isCriticalError=False)
        actualLength = len(descriptor.descriptor)
        if knownType.value.length != actualLength:
            raise VLException(Error.InvalidDescriptorLength.format(actualLength), 400, isCriticalError=False)

    allVersions = [descriptor.version for descriptor in attribute.descriptors]
    repeatedVersions = {version for version in allVersions if allVersions.count(version) > 1}
    if repeatedVersions:
        raise VLException(
            Error.AttributeWithDescriptorsIdenticalVersion.format(list(repeatedVersions)), 400, isCriticalError=False
        )
@staticmethod
def validateAttribute(attribute: TemporaryAttributes):
    """
    Check that the attribute carries data and that every kind of samples comes with its data,
    then validate its descriptors with ``validateAttributeDescriptors``.

    Args:
        attribute: attribute
    Raises:
        VLException(Error.AttributeDoesNotContainAnyData, 400): if attribute has neither descriptors nor
            basic attributes
        VLException(Error.AttributeContainsSamplesWithoutData, 400): if descriptor or basic-attribute samples
            are given without the corresponding data
        VLException: any error raised by ``validateAttributeDescriptors``
    """
    if not (attribute.descriptors or attribute.basicAttributes):
        raise VLException(Error.AttributeDoesNotContainAnyData, 400, isCriticalError=False)

    samplesWithData = (
        (attribute.descriptorSamples, attribute.descriptors, "descriptors"),
        (attribute.basicAttributesSamples, attribute.basicAttributes, "basic_attributes"),
    )
    for samples, data, dataName in samplesWithData:
        if samples and not data:
            raise VLException(
                Error.AttributeContainsSamplesWithoutData.format(dataName), 400, isCriticalError=False
            )

    AttributesBaseRequestHandler.validateAttributeDescriptors(attribute)