# Source code for luna_python_matcher.app.handlers.cross_matcher_handler

"""Cross-matching handler."""
import os
from abc import ABC, abstractmethod
from typing import Union

import psutil
from sanic.compat import CancelledErrors
from sanic.response import HTTPResponse
from vlutils.descriptors.data import DescriptorType

from app_common.handlers.base_handler import CommonBaseHandler
from app_common.handlers.schemas import (
    BodyCrossMatch,
    CrossMatchAttributesFilters,
    CrossMatchEventsFilters,
    CrossMatchFacesFilters,
    FaceCrossMatch,
)
from classes.cross_match import CrossMatcher
from classes.cross_match_helpers import packCrossMatchResult
from classes.enums import MatchSource
from classes.filters import AttributeFilters, EventFilters, FaceFilters
from classes.match_reply_helpers import getPreparedFilters
from configs.config_common import DEFAULT_CROSS_MATCH_LIMIT
from crutches_on_wheels.cow.errors.errors import Error
from crutches_on_wheels.cow.errors.exception import VLException
from crutches_on_wheels.cow.web.query_getters import uuidGetter, validateIntAsBool

# Map from a match-source origin value ("faces", "events", "attributes") to the filters class that parses it
FILTERS_MAP = {
    source.value: filtersClass
    for source, filtersClass in (
        (MatchSource.faces, FaceFilters),
        (MatchSource.events, EventFilters),
        (MatchSource.attributes, AttributeFilters),
    )
}


class CrossMatcherBaseHandler(CommonBaseHandler, ABC):
    """
    Base cross-matching handler.

    Subclasses choose the descriptor type and descriptor version used for matching.
    """

    @property
    @abstractmethod
    def descriptorType(self) -> DescriptorType:
        """
        The tool descriptor type. Must be overridden.
        """

    @property
    def configLunaEventsUsage(self) -> bool:
        """
        Get luna-events service usages

        Returns:
            `True` if luna-events service is used
        """
        return self.config.additionalServicesUsage.lunaEvents

    @property
    @abstractmethod
    def descriptorVersion(self) -> int:
        """
        The tool descriptor version. Must be overridden.
        """

    def validateCrossMatchJson(self, inputJson: dict) -> None:
        """
        Validate input json.

        The json is validated against the face or body cross-match schema, depending on
        `descriptorType`. If the error path points into "candidates" or "references",
        that object is re-validated against its origin-specific filters schema. This
        gives a more precise error, whose path is prefixed with the object key.

        Args:
            inputJson: inputJson from request
        Raises:
            VLException(Error.BadInputJson, 400, False) if fail to validate input json
        """
        try:
            if self.descriptorType == DescriptorType.face:
                self.validateJson(inputJson, FaceCrossMatch.schema, useJsonSchema=False)
            else:
                self.validateJson(inputJson, BodyCrossMatch.schema, useJsonSchema=False)
        except VLException as mainValidationException:
            detailParts = (mainValidationException.error.detail or "").split("Path:")
            if len(detailParts) < 2:
                # No path in the error detail: it cannot be refined, so report the original error
                # (previously this raised IndexError and surfaced as an internal error)
                raise mainValidationException
            partFromDetailWithPath = detailParts[1].split(",")[0]

            if "candidates" in partFromDetailWithPath:
                errorKey = "candidates"
            elif "references" in partFromDetailWithPath:
                errorKey = "references"
            else:
                raise mainValidationException

            # origin value -> schema of the filters object for that origin
            originSchemas = {
                "faces": CrossMatchFacesFilters.schema,
                "events": CrossMatchEventsFilters.schema,
                "attributes": CrossMatchAttributesFilters.schema,
            }
            objectJson = inputJson.get(errorKey) if isinstance(inputJson, dict) else None
            try:
                if isinstance(objectJson, dict):
                    schema = originSchemas.get(objectJson.get("origin"))
                    if schema is not None:
                        self.validateJson(objectJson, schema, useJsonSchema=False)
            except VLException as deepException:
                if deepException.error == Error.BadInputJson:
                    oldDetail = deepException.error.detail or ""
                    pathStart = oldDetail.find("Path: ")
                    if pathStart != -1:
                        # insert the object key right after the opening quote of the path
                        insertAt = pathStart + len("Path: '")
                        deepException.error.detail = oldDetail[:insertAt] + f"{errorKey}." + oldDetail[insertAt:]
                    raise deepException
            # the nested object is valid on its own (or has an unknown origin): report the original error
            raise mainValidationException

    def getFiltersFromJson(
        self, inputJson: dict, objectType: str
    ) -> Union[FaceFilters, EventFilters, AttributeFilters]:
        """
        Build filters for one side of the cross-match.

        The filters are built from `inputJson[objectType]`. The function also checks
        that luna-events is enabled when events are requested, and that the query
        "account_id" agrees with the filters account id.

        Args:
            inputJson: input json
            objectType: candidates or references
        Returns:
            filters for the given object type (face, event or attribute filters, chosen by "origin")
        Raises:
            VLException(Error.LunaEventsIsDisabled, 403) if events are requested but luna-events is disabled
            VLException(Error.DifferentAccounts, 400) if query and filters account ids differ
        """
        accountId = self.getQueryParam("account_id", uuidGetter, default=None)
        filters = FILTERS_MAP[inputJson[objectType]["origin"]]().initFromRequest(inputJson[objectType])
        if not self.configLunaEventsUsage and filters.origin == MatchSource.events.value:
            raise VLException(Error.LunaEventsIsDisabled, 403, isCriticalError=False)
        if accountId is not None and filters.accountId is not None and accountId != filters.accountId:
            raise VLException(Error.DifferentAccounts.format(objectType), 400, False)
        if accountId is not None:
            # the query account id takes precedence when the filters did not specify one
            filters.accountId = accountId
        return filters

    async def getPreparedFiltersAndCheckListsExistence(
        self, filters: Union[FaceFilters, EventFilters, AttributeFilters]
    ) -> dict:
        """
        Get filters prepared for reply.

        For face filters with a list id, the list is first checked to exist (for the filters' account).

        Args:
            filters: filters
        Returns:
            dict with filters prepared for reply
        Raises:
            VLException(Error.ListNotFound, 400) if the referenced list does not exist
        """
        if isinstance(filters, FaceFilters) and filters.listId is not None:
            if not await self.facesDBContext.getList(listId=filters.listId, accountId=filters.accountId):
                raise VLException(Error.ListNotFound.format(filters.listId), 400, False)
        return getPreparedFilters(filters)

    def _logMemoryUsage(self) -> None:
        """Log the resident set size of the current process (in GiB) at debug level."""
        rssGb = round(psutil.Process(os.getpid()).memory_info().rss / 2**30, 6)
        self.logger.debug(f"Memory usages: {rssGb} Gb")

    async def post(self) -> HTTPResponse:
        """
        Search for faces/events by given filters and matching them with each other.

        To work with face descriptors, see `crossmatch_faces`_ or `crossmatch_bodies`_ for body descriptors.

        .. _crossmatch_faces:
            _static/api.html#operation/faceCrossMatching
        .. _crossmatch_bodies:
            _static/api.html#operation/bodyCrossMatching

        The result is streamed as msgpack chunks through ``self.request.respond``; the method itself
        has no return statement. Errors raised while streaming are logged rather than re-raised.

        Returns:
            response with cross matching results
        """
        self._logMemoryUsage()
        inputJson: dict = self.request.json
        self.validateCrossMatchJson(inputJson)

        candidateFilters = self.getFiltersFromJson(inputJson=inputJson, objectType="candidates")
        referenceFilters = self.getFiltersFromJson(inputJson=inputJson, objectType="references")
        preparedCandidateFilters = await self.getPreparedFiltersAndCheckListsExistence(candidateFilters)
        preparedReferenceFilters = await self.getPreparedFiltersAndCheckListsExistence(referenceFilters)

        limit = inputJson.get("limit", DEFAULT_CROSS_MATCH_LIMIT)
        threshold = inputJson.get("threshold", 0.0)
        sortOn = self.getQueryParam("sorting", validateIntAsBool, default=True)

        crossMatchProcessor = CrossMatcher(
            facesDBContext=self.facesDBContext,
            eventsDBContext=self.eventsDBContext,
            attributesDBContext=self.attributesDBContext,
            candidateFilters=candidateFilters,
            referenceFilters=referenceFilters,
            limit=limit,
            threshold=threshold,
            sortOn=sortOn,
            logger=self.logger,
            descriptorType=self.descriptorType,
            descriptorVersion=self.descriptorVersion,
        )
        successMatchResult, errorMatchResult = await crossMatchProcessor.process()

        preparedRequestFilters = dict(
            candidates=preparedCandidateFilters, references=preparedReferenceFilters, limit=limit, threshold=threshold
        )
        # NOTE(review): presumably forces an uncompressed (identity) encoding for the streamed body — confirm
        self.request.headers["Accept-Encoding"] = "identity"
        try:
            async with await self.request.respond(status=200, content_type="application/msgpack") as sender:
                for chunk in packCrossMatchResult(preparedRequestFilters, successMatchResult, errorMatchResult):
                    await sender(chunk)
        except CancelledErrors:
            self.logger.info("Response streaming cancelled")
        except Exception as e:
            self.logger.error(f"Response streaming failed: {e}")
        else:
            self.logger.info("Response streaming completed")
        self._logMemoryUsage()
class FaceCrossMatcherHandler(CrossMatcherBaseHandler):
    """
    Cross-matching handler working with face descriptors.

    Resource: "/{api_version}/crossmatcher/faces"
    """

    # cross-matching on this resource always uses face descriptors
    descriptorType = DescriptorType.face

    @property
    def descriptorVersion(self) -> int:
        """Face descriptor version, read from the config option ``defaultFaceDescriptorVersion``."""
        serviceConfig = self.config
        return serviceConfig.defaultFaceDescriptorVersion
class BodyCrossMatcherHandler(CrossMatcherBaseHandler):
    """
    Cross-matching handler working with body descriptors.

    Resource: "/{api_version}/crossmatcher/bodies"
    """

    # cross-matching on this resource always uses body descriptors
    descriptorType = DescriptorType.body

    @property
    def descriptorVersion(self) -> int:
        """Human (body) descriptor version, read from the config option ``defaultHumanDescriptorVersion``."""
        serviceConfig = self.config
        return serviceConfig.defaultHumanDescriptorVersion
class UnwantedCrossMatcherHandler(FaceCrossMatcherHandler):
    """
    Deprecated handler for face cross-matching.

    Resource: "/{api_version}/crossmatcher"
    Deprecated alias for "/{api_version}/crossmatcher/faces": it behaves exactly like
    `FaceCrossMatcherHandler`, but also logs a deprecation warning on every request.
    """

    async def post(self) -> HTTPResponse:
        """
        Log a deprecation warning, then run the face cross-matching request.

        Returns:
            the result of `FaceCrossMatcherHandler.post`
        """
        self.logger.warning(
            f"Resource `/{self.app.ctx.apiVersion}/crossmatcher` is deprecated. "
            f"Use `/{self.app.ctx.apiVersion}/crossmatcher/faces` instead."
        )
        return await super().post()