"""Cross-matching handler."""
import os
from abc import ABC, abstractmethod
from typing import Union
import psutil
from sanic.compat import CancelledErrors
from sanic.response import HTTPResponse
from vlutils.descriptors.data import DescriptorType
from app_common.handlers.base_handler import CommonBaseHandler
from app_common.handlers.schemas import (
BodyCrossMatch,
CrossMatchAttributesFilters,
CrossMatchEventsFilters,
CrossMatchFacesFilters,
FaceCrossMatch,
)
from classes.cross_match import CrossMatcher
from classes.cross_match_helpers import packCrossMatchResult
from classes.enums import MatchSource
from classes.filters import AttributeFilters, EventFilters, FaceFilters
from classes.match_reply_helpers import getPreparedFilters
from configs.config_common import DEFAULT_CROSS_MATCH_LIMIT
from crutches_on_wheels.errors.errors import Error
from crutches_on_wheels.errors.exception import VLException
from crutches_on_wheels.web.query_getters import uuidGetter, validateIntAsBool
# match origin (`MatchSource` value) -> filters class used to parse the corresponding part of a request
FILTERS_MAP = {
    source.value: filtersClass
    for source, filtersClass in (
        (MatchSource.faces, FaceFilters),
        (MatchSource.events, EventFilters),
        (MatchSource.attributes, AttributeFilters),
    )
}
class CrossMatcherBaseHandler(CommonBaseHandler, ABC):
    """
    Base cross-matching handler.

    Subclasses provide the descriptor type and version. This class validates the request body, builds
    candidate/reference filters, runs `CrossMatcher` and streams the msgpack-packed result.
    """

    @property
    @abstractmethod
    def descriptorType(self) -> DescriptorType:
        """
        The tool descriptor type. Must be overridden.
        """

    @property
    def configLunaEventsUsage(self) -> bool:
        """
        Get luna-events service usage.

        Returns:
            `True` if luna-events service is used
        """
        return self.config.additionalServicesUsage.lunaEvents

    @property
    @abstractmethod
    def descriptorVersion(self) -> int:
        """
        The tool descriptor version. Must be overridden.
        """

    @staticmethod
    def _prefixDetailPath(detail: str, prefix: str) -> str:
        """
        Insert `prefix` followed by "." right after the opening quote of "Path: '<path>'" in an error detail.

        Args:
            detail: validation error detail
            prefix: key to prepend to the path ("candidates" or "references")
        Returns:
            updated detail, or `detail` unchanged if it has no "Path: " marker
        """
        markerIndex = detail.find("Path: ")
        if markerIndex == -1:
            return detail
        # skip the marker and the opening quote of the path
        pathStartIndex = markerIndex + len("Path: '")
        return f"{detail[:pathStartIndex]}{prefix}.{detail[pathStartIndex:]}"

    def validateCrossMatchJson(self, inputJson: dict) -> None:
        """
        Validate input json.

        If the failing path of the main validation points into "candidates" or "references", that sub-object
        is validated separately against the filters schema of its "origin" ("faces", "events" or "attributes").
        This produces a more precise error, and the sub-object key is prefixed to the path in its detail.
        Otherwise the main validation error is raised.

        Args:
            inputJson: inputJson from request
        Raises:
            VLException(Error.BadInputJson, 400, False) if fail to validate input json
        """
        if self.descriptorType == DescriptorType.face:
            mainSchema = FaceCrossMatch.schema
        else:
            mainSchema = BodyCrossMatch.schema
        try:
            self.validateJson(inputJson, mainSchema, useJsonSchema=False)
        except VLException as mainValidationException:
            # the detail is expected to contain "Path: ..."; without it there is nothing to refine
            detailParts = mainValidationException.error.detail.split("Path:")
            if len(detailParts) < 2:
                raise mainValidationException
            failedPath = detailParts[1].split(",")[0]
            if "candidates" in failedPath:
                errorKey = "candidates"
            elif "references" in failedPath:
                errorKey = "references"
            else:
                raise mainValidationException

            subJson = inputJson.get(errorKey)
            if isinstance(subJson, dict):
                originSchemas = {
                    "faces": CrossMatchFacesFilters.schema,
                    "events": CrossMatchEventsFilters.schema,
                    "attributes": CrossMatchAttributesFilters.schema,
                }
                origin = subJson.get("origin")
                # the json is already known to be invalid, so "origin" may be of any (even unhashable) type
                originSchema = originSchemas.get(origin) if isinstance(origin, str) else None
                if originSchema is not None:
                    try:
                        self.validateJson(subJson, originSchema, useJsonSchema=False)
                    except VLException as deepException:
                        if deepException.error == Error.BadInputJson:
                            deepException.error.detail = self._prefixDetailPath(
                                deepException.error.detail, errorKey
                            )
                        raise
            # the sub-object could not be checked separately or passed its own schema - report the original error
            raise mainValidationException

    def getFiltersFromJson(
        self, inputJson: dict, objectType: str
    ) -> Union[FaceFilters, EventFilters, AttributeFilters]:
        """
        Build filters for candidates or references from the input json.

        The "account_id" query parameter, if given, must agree with the account id in the filters and is
        written into them.

        Args:
            inputJson: input json
            objectType: candidates or references
        Returns:
            filters for the requested object type (class chosen by its "origin" via `FILTERS_MAP`)
        Raises:
            VLException(Error.LunaEventsIsDisabled, 403, False) if events origin is requested while
                luna-events usage is disabled
            VLException(Error.DifferentAccounts, 400, False) if query and json account ids differ
        """
        accountId = self.getQueryParam("account_id", uuidGetter, default=None)
        objectJson = inputJson[objectType]
        filters = FILTERS_MAP[objectJson["origin"]]().initFromRequest(objectJson)
        if not self.configLunaEventsUsage and filters.origin == MatchSource.events.value:
            raise VLException(Error.LunaEventsIsDisabled, 403, isCriticalError=False)
        if accountId is not None:
            if filters.accountId is not None and accountId != filters.accountId:
                raise VLException(Error.DifferentAccounts.format(objectType), 400, False)
            filters.accountId = accountId
        return filters

    async def getPreparedFiltersAndCheckListsExistence(
        self, filters: Union[FaceFilters, EventFilters, AttributeFilters]
    ) -> dict:
        """
        Check that the list referenced by face filters exists and get filters prepared for reply.

        Args:
            filters: filters
        Returns:
            dict with filters prepared for reply
        Raises:
            VLException(Error.ListNotFound, 400, False) if face filters reference a list that is not found
                for the filters account
        """
        if isinstance(filters, FaceFilters) and filters.listId is not None:
            if not await self.facesDBContext.getList(listId=filters.listId, accountId=filters.accountId):
                raise VLException(Error.ListNotFound.format(filters.listId), 400, False)
        return getPreparedFilters(filters)

    def _logMemoryUsage(self) -> None:
        """Log (debug level) the resident set size of the current process, in GiB."""
        self.logger.debug(f"Memory usages: {round(psutil.Process(os.getpid()).memory_info().rss / 2**30, 6)} Gb")

    async def post(self) -> HTTPResponse:
        """
        Search for faces/events by given filters and matching them with each other.
        To work with face descriptors, see `crossmatch_faces`_ or `crossmatch_bodies`_ for body descriptors.

        .. _crossmatch_faces:
            _static/api.html#operation/faceCrossMatching

        .. _crossmatch_bodies:
            _static/api.html#operation/bodyCrossMatching

        The result is streamed as "application/msgpack" chunks produced by `packCrossMatchResult`.
        Streaming is best-effort: cancellation and streaming errors are logged, not propagated.

        Returns:
            response with cross matching results
        """
        self._logMemoryUsage()
        inputJson: dict = self.request.json
        self.validateCrossMatchJson(inputJson)

        candidateFilters = self.getFiltersFromJson(inputJson=inputJson, objectType="candidates")
        referenceFilters = self.getFiltersFromJson(inputJson=inputJson, objectType="references")
        preparedCandidateFilters = await self.getPreparedFiltersAndCheckListsExistence(candidateFilters)
        preparedReferenceFilters = await self.getPreparedFiltersAndCheckListsExistence(referenceFilters)

        limit = inputJson.get("limit", DEFAULT_CROSS_MATCH_LIMIT)
        threshold = inputJson.get("threshold", 0.0)
        sortOn = self.getQueryParam("sorting", validateIntAsBool, default=True)

        crossMatchProcessor = CrossMatcher(
            facesDBContext=self.facesDBContext,
            eventsDBContext=self.eventsDBContext,
            attributesDBContext=self.attributesDBContext,
            candidateFilters=candidateFilters,
            referenceFilters=referenceFilters,
            limit=limit,
            threshold=threshold,
            sortOn=sortOn,
            logger=self.logger,
            descriptorType=self.descriptorType,
            descriptorVersion=self.descriptorVersion,
        )
        successMatchResult, errorMatchResult = await crossMatchProcessor.process()
        preparedRequestFilters = dict(
            candidates=preparedCandidateFilters, references=preparedReferenceFilters, limit=limit, threshold=threshold
        )

        self.request.headers["Accept-Encoding"] = "identity"
        try:
            async with await self.request.respond(status=200, content_type="application/msgpack") as sender:
                for chunk in packCrossMatchResult(preparedRequestFilters, successMatchResult, errorMatchResult):
                    await sender(chunk)
        except CancelledErrors:
            self.logger.info("Response streaming cancelled")
        except Exception as e:
            # best-effort streaming: the failure is only logged
            self.logger.error(f"Response streaming failed: {e}")
        else:
            self.logger.info("Response streaming completed")
        self._logMemoryUsage()
class FaceCrossMatcherHandler(CrossMatcherBaseHandler):
    """
    Cross-matching handler working with face descriptors.

    Resource: "/{api_version}/crossmatcher/faces"
    """

    # kind of descriptors this handler matches
    descriptorType = DescriptorType.face

    @property
    def descriptorVersion(self) -> int:
        """Face descriptor version, read from the `defaultFaceDescriptorVersion` setting."""
        serviceConfig = self.config
        return serviceConfig.defaultFaceDescriptorVersion
class BodyCrossMatcherHandler(CrossMatcherBaseHandler):
    """
    Cross-matching handler working with body (human) descriptors.

    Resource: "/{api_version}/crossmatcher/bodies"
    """

    # kind of descriptors this handler matches
    descriptorType = DescriptorType.body

    @property
    def descriptorVersion(self) -> int:
        """Human descriptor version, read from the `defaultHumanDescriptorVersion` setting."""
        serviceConfig = self.config
        return serviceConfig.defaultHumanDescriptorVersion
class UnwantedCrossMatcherHandler(FaceCrossMatcherHandler):
    """
    Deprecated handler for face cross-matching.

    Resource: "/{api_version}/crossmatcher" - deprecated alias for "/{api_version}/crossmatcher/faces"
    """

    async def post(self) -> HTTPResponse:
        """
        Log a deprecation warning, then handle the request like the face cross-matching resource.

        Returns:
            result of `FaceCrossMatcherHandler.post`
        """
        self.logger.warning(
            f"Resource `/{self.app.ctx.apiVersion}/crossmatcher` is deprecated. "
            f"Use `/{self.app.ctx.apiVersion}/crossmatcher/faces` instead."
        )
        return await super().post()