Source code for luna_handlers.classes.schemas.policies

"""
Module contains schemas for policies
"""
import asyncio
from typing import Any, Type, TypeVar, Optional, List, Dict, Tuple, Union
from uuid import UUID

from luna3.common.exceptions import LunaApiException
from pydantic import Field, root_validator, ValidationError
from pydantic.error_wrappers import ErrorWrapper

from luna3.client import Client
from lunavl.sdk.estimators.face_estimators.livenessv1 import LivenessPrediction
from lunavl.sdk.estimators.face_estimators.mask import MaskState
from vlutils.jobs.async_runner import AsyncRunner
from app.api_sdk_adaptors.handler import APISDKHandlerAdaptor
from app.global_vars.constants import MAX_ANGLE
from app.global_vars.context_vars import requestIdCtx
from classes.event import Event, Location
from classes.event_parts import RawDescriptorResult, RawImageResult
from classes.monitoring import HandlersMonitoringData as DataForMonitoring
from classes.raw_descriptor_data import RawDescriptorData
from classes.schemas.base_schema import BaseSchema
from classes.schemas.conditional_tags_policy import ConditionalTagsPolicy
from classes.schemas.detect_policies import HandlerDetectPolicy
from classes.schemas.extract_policies import HandlerExtractPolicy
from classes.schemas.filters import MatchFilter, AttributesFilters
from classes.schemas.match_policy import MatchPolicy
from classes.schemas.storage_policy import StoragePolicy
from classes.schemas.types import MAX_POLICY_LIST_LENGTH
from configs.config import PLUGINS_PUBLISHING_CONCURRENCY
from crutches_on_wheels.enums.attributes import Liveness
from crutches_on_wheels.errors.errors import Error
from crutches_on_wheels.errors.exception import VLException
from crutches_on_wheels.monitoring.points import monitorTime
from crutches_on_wheels.plugins.manager import PluginManager
from crutches_on_wheels.utils.log import Logger
from img_utils.utils import getExif
from redis_db.redis_context import RedisContext
from sdk.sdk_loop.enums import MultifacePolicy
from sdk.sdk_loop.estimation_targets import SDKEstimationTargets, SDKFaceEstimationTargets, SDKHumanEstimationTargets
from sdk.sdk_loop.sdk_task import (
    SDKTaskFilters,
    SDKTask,
    FaceWarp,
    SDKDetectableImage,
    TaskDataSource,
    HumanWarp,
)
from sdk.sdk_loop.task_loop import SDKTaskLoop

# generic type variable: ties the `expectedType` argument of `getObjectRecursively` to the item type of its result
T = TypeVar("T")


def getObjectRecursively(data: Any, expectedType: Type[T]) -> List[T]:
    """
    Recursively get objects of expected type.

    Lists are walked element by element, `BaseSchema` instances are walked through the values of their
    `__dict__`; other containers (tuples, dicts, sets, ...) are not descended into. An object which matches
    `expectedType` is collected and, if it is a list or a schema itself, is still traversed.

    Args:
        data: root object to inspect
        expectedType: type of objects to collect
    Returns:
        found objects in depth-first order, a parent goes before the objects nested in it
    """
    res: List[T] = []

    def collectObjects(dataPart: Any) -> None:
        """Collect object of expected type to 'res'"""
        if isinstance(dataPart, expectedType):
            res.append(dataPart)
        # plain loops: the traversal is done for its side effect only, no need to build lists of `None`
        if isinstance(dataPart, list):
            for row in dataPart:
                collectObjects(row)
        elif isinstance(dataPart, BaseSchema):
            for row in dataPart.__dict__.values():
                collectObjects(row)

    collectObjects(data)
    return res
def getImagesReplyData(
    sdkData: List[Union[HumanWarp, FaceWarp, SDKDetectableImage]],
    images: List[Union[RawImageResult, RawDescriptorResult]],
    extractExif: int,
) -> List[dict]:
    """
    Get images' processing data in reply format.

    Args:
        sdkData: sdk data (images or warps); used only as a source of exif
        images: processing results of input data
        extractExif: whether to extract exif
    Returns:
        list of dicts with "error", "status" and "filename" keys, one per unique id from `images`;
        if `extractExif` is set, the "exif" key is added for every item of `sdkData` which has no error
    """
    replyById = {}
    for image in images:
        imageId = image.id
        replyById[imageId] = {"error": image.error.asDict(), "status": image.status, "filename": image.filename}

    if not extractExif:
        return list(replyById.values())

    for sdkImage in sdkData:
        if sdkImage.error is not None:
            # exif is attached to successfully loaded images only
            continue
        replyById[sdkImage.id]["exif"] = getExif(sdkImage)
    return list(replyById.values())
class Policies(BaseSchema):
    """
    Policies schema.

    Aggregates detect, extract, match, conditional tags and storage policies of a handler, validates their
    mutual compatibility and executes them one by one (see `execute`).
    """

    # detect policy
    detectPolicy: HandlerDetectPolicy = Field(default_factory=lambda: HandlerDetectPolicy())
    # extract policy
    extractPolicy: HandlerExtractPolicy = HandlerExtractPolicy()
    # matching policy list
    matchPolicy: List[MatchPolicy] = Field([], max_items=MAX_POLICY_LIST_LENGTH)
    # conditional tags policy list
    conditionalTagsPolicy: List[ConditionalTagsPolicy] = Field([], max_items=MAX_POLICY_LIST_LENGTH)
    # storage policy
    storagePolicy: StoragePolicy = Field(default_factory=lambda: StoragePolicy())

    @staticmethod
    def validateMatchAndExtractCompatibility(matchPolicies: List[MatchPolicy], extractPolicy: HandlerExtractPolicy):
        """
        Validate match and extract policies compatibility.

        Args:
            matchPolicies: matching policy list
            extractPolicy: extract policy
        Raises:
            ValueError: if there is at least one match policy and face descriptor extraction is disabled
        """
        if matchPolicies and not extractPolicy.extractFaceDescriptor:
            raise ValueError("extract_face_descriptor should be equal to 1 for using matching policy")

    @staticmethod
    def validateMatchLabelsCompatibility(
        matchPolicies: List[MatchPolicy],
        conditionalTagsPolicies: List[ConditionalTagsPolicy],
        storagePolicy: StoragePolicy,
    ):
        """
        Validate matching label compatibility.

        Args:
            matchPolicies: matching policy list
            conditionalTagsPolicies: conditional tags policy list
            storagePolicy: storage policy
        Raises:
            ValueError: if a match filter from conditional tags or storage policies refers to a label
                        which is absent in match policies
        """
        matchPolicyMatchingLabels = {matchPolicy.label for matchPolicy in matchPolicies}
        matchFiltersMatchingLabels = {
            matchFilter.label
            for matchFilter in getObjectRecursively([conditionalTagsPolicies, storagePolicy], MatchFilter)
        }
        for matchFilterLabel in matchFiltersMatchingLabels:
            if matchFilterLabel not in matchPolicyMatchingLabels:
                raise ValueError(
                    f'"{matchFilterLabel}" should be in match policy for filtration based on a matching by this label'
                )

    @staticmethod
    def validateGeneratedAttributesFilters(
        detectPolicy: HandlerDetectPolicy,
        extractPolicy: HandlerExtractPolicy,
        matchPolicies: List[MatchPolicy],
        conditionalTagsPolicies: List[ConditionalTagsPolicy],
        storagePolicy: StoragePolicy,
    ):
        """
        Validate attributes and detect/extract policy compatibility.

        Args:
            detectPolicy: detect policy
            extractPolicy: extract policy
            matchPolicies: matching policy list
            conditionalTagsPolicies: conditional tags policy list
            storagePolicy: storage policy
        Raises:
            ValueError: if some filter uses basic attributes (ethnicities, age, gender) while basic attributes
                        extraction is disabled or uses liveness while liveness estimation is disabled
        """
        attributeFilters: List[AttributesFilters] = getObjectRecursively(
            data=[matchPolicies, conditionalTagsPolicies, storagePolicy], expectedType=AttributesFilters
        )
        if not extractPolicy.extractBasicAttributes:
            basicAttributesNeeded = any(
                any(
                    (
                        filters.ethnicities is not None,
                        filters.ageLt is not None,
                        filters.ageGte is not None,
                        filters.gender is not None,
                    )
                )
                for filters in attributeFilters
            )
            if basicAttributesNeeded:
                raise ValueError(
                    "extract_basic_attributes should be equal to 1 for filtration based on basic attributes"
                )
        if not detectPolicy.estimateLiveness.estimate:
            livenessNeeded = any((filters.liveness is not None) for filters in attributeFilters)
            if livenessNeeded:
                raise ValueError("estimate_liveness.estimate should be equal to 1 for filtration based on liveness")

    @staticmethod
    def validateMatchPolicyUniqueLabels(matchPolicies: List[MatchPolicy]):
        """
        Validate match policy matching label uniqueness.

        Args:
            matchPolicies: matching policy list
        Raises:
            ValidationError: if two or more match policies have the same label
        """
        labels = [matchPolicy.label for matchPolicy in matchPolicies]
        if len(labels) != len(set(labels)):
            error = ValueError("Matching allowed only by unique labels")
            # error location is a one-element tuple (mind the trailing comma)
            raise ValidationError([ErrorWrapper(exc=error, loc=("match_policy",))], Policies)

    @staticmethod
    def validateDetectPolicyNotEmpty(detectPolicy: HandlerDetectPolicy):
        """
        Validate non-empty detect policy.

        Args:
            detectPolicy: detect policy
        Raises:
            ValueError: if neither face nor body detection is enabled or if face detection is disabled
                        and multiface policy differs from `MultifacePolicy.allowed`
        """
        if not detectPolicy.detectFace and not detectPolicy.detectBody:
            raise ValueError("At least one of *detect_face* or *detect_body* should be equal to 1")
        if not detectPolicy.detectFace and MultifacePolicy(detectPolicy.multifacePolicy) is not MultifacePolicy.allowed:
            raise ValueError("*detect_face* should be equal to 1 to set *multiface_policy* to 0 or 2")

    @staticmethod
    def validateDetectAndExtractCompatibility(detectPolicy: HandlerDetectPolicy, extractPolicy: HandlerExtractPolicy):
        """
        Validate detect and extract policies compatibility.

        Args:
            detectPolicy: detect policy
            extractPolicy: extract policy
        Raises:
            ValueError: if face descriptor or basic attributes extraction is enabled without face detection
                        or body descriptor extraction is enabled without body detection
        """
        if extractPolicy.extractFaceDescriptor:
            if not detectPolicy.detectFace:
                raise ValueError("*detect_face* should be equal to 1 to enable *extract_face_descriptor*")
        if extractPolicy.extractBasicAttributes:
            if not detectPolicy.detectFace:
                raise ValueError("*detect_face* should be equal to 1 to enable *extract_basic_attributes*")
        if extractPolicy.extractBodyDescriptor:
            if not detectPolicy.detectBody:
                raise ValueError("*detect_body* should be equal to 1 to enable *extract_body_descriptor*")

    @root_validator(skip_on_failure=True)
    def validatePolicies(cls, values):
        """
        Execute all compatibility validators.

        Args:
            values: validated field values of the schema
        Returns:
            unchanged `values`
        """
        detectPolicy = values["detectPolicy"]
        matchPolicies = values["matchPolicy"]
        extractPolicy = values["extractPolicy"]
        conditionalTagsPolicies = values["conditionalTagsPolicy"]
        storagePolicy = values["storagePolicy"]

        cls.validateDetectPolicyNotEmpty(detectPolicy=detectPolicy)
        cls.validateDetectAndExtractCompatibility(detectPolicy=detectPolicy, extractPolicy=extractPolicy)
        cls.validateMatchPolicyUniqueLabels(matchPolicies=matchPolicies)
        cls.validateMatchAndExtractCompatibility(matchPolicies=matchPolicies, extractPolicy=extractPolicy)
        cls.validateMatchLabelsCompatibility(
            matchPolicies=matchPolicies, conditionalTagsPolicies=conditionalTagsPolicies, storagePolicy=storagePolicy
        )
        cls.validateGeneratedAttributesFilters(
            detectPolicy=detectPolicy,
            extractPolicy=extractPolicy,
            matchPolicies=matchPolicies,
            conditionalTagsPolicies=conditionalTagsPolicies,
            storagePolicy=storagePolicy,
        )
        return values

    def getListIdsFromPolicies(self) -> List[UUID]:
        """
        Get list ids from matching and link to lists policies.

        Returns:
            list ids: ids from link to lists policies go first, then non-empty ids of match candidates;
            an id which is used in several policies is returned several times
        """
        candidateLists = [
            match.candidates.listId
            for match in self.matchPolicy
            # not every kind of candidates has a list id; empty ids are skipped as well
            if hasattr(match.candidates, "listId") and match.candidates.listId
        ]
        return [linkPolicy.listId for linkPolicy in self.storagePolicy.facePolicy.linkToListsPolicy] + candidateLists

    async def checkListsAvailability(self, luna3Client: Client, accountId: str) -> None:
        """
        Check availability of lists from matching and link to list policies.

        Args:
            luna3Client: luna platform client
            accountId: account id
        Raises:
            VLException(Error.ListNotFound.format(listId), 400, False), if some list is not found
            LunaApiException: any other error of the list check request is re-raised as is
        """
        # the same list may be used by several policies: check every list once, keeping the original order
        for listId in dict.fromkeys(self.getListIdsFromPolicies()):
            try:
                await luna3Client.lunaFaces.checkList(listId=str(listId), accountId=accountId, raiseError=True)
            except LunaApiException as e:
                if e.statusCode == 404:
                    raise VLException(Error.ListNotFound.format(listId), 400, isCriticalError=False) from e
                raise

    @classmethod
    async def onStartup(cls):
        """
        Init Policies: create a runner which is used for publishing events to plugins
        """
        cls.pluginsAsyncRunner = AsyncRunner(PLUGINS_PUBLISHING_CONCURRENCY, closeTimeout=1)

    @classmethod
    async def onShutdown(cls):
        """
        Stop Policies: close the runner created in `onStartup`
        """
        await cls.pluginsAsyncRunner.close()

    @staticmethod
    def enrichEventsWithMetadata(
        events: List[Event],
        source: Union[str, None],
        tags: Union[List[str], None],
        userData: str,
        externalId: Union[str, None],
        location: Location,
        trackId: Optional[str] = None,
    ) -> None:
        """
        Enrich processing event with metadata.

        Events are updated in place. If `tags` is not None, event tags are replaced with unique tags
        from `tags` followed by unique tags which the event already has; otherwise event tags are not touched.

        Args:
            events: events
            source: event source
            tags: event tags
            userData: user data
            externalId: event external id
            location: event location
            trackId: (str) event track id
        """
        for event in events:
            event.source = source
            event.userData = userData
            event.externalId = externalId
            event.location = location
            event.trackId = trackId
            if tags is not None:
                # dict keys drop duplicates and keep insertion order, so the tag order is deterministic
                # (merging through a set gives an arbitrary order)
                event.tags = list(dict.fromkeys((*tags, *event.tags)))

    def prepareSDKTask(self, sdkData: List[Union[SDKDetectableImage, FaceWarp, HumanWarp]], aggregate: int) -> SDKTask:
        """
        Prepare sdk task

        Args:
            sdkData: a list of input images or warps
            aggregate: aggregate all extracted samples to one or not
        Returns:
            sdk task
        """
        faceTargets = SDKFaceEstimationTargets(
            estimateQuality=self.detectPolicy.estimateQuality,
            estimateMouthAttributes=self.detectPolicy.estimateMouthAttributes,
            estimateAGS=0,
            estimateGaze=self.detectPolicy.estimateGaze,
            estimateEyesAttributes=self.detectPolicy.estimateEyesAttributes,
            estimateEmotions=self.detectPolicy.estimateEmotions,
            estimateMask=self.detectPolicy.estimateMask,
            estimateHeadPose=self.detectPolicy.estimateHeadPose,
            estimateBasicAttributes=self.extractPolicy.extractBasicAttributes,
            estimateFaceDescriptor=self.extractPolicy.extractFaceDescriptor,
            estimateLiveness=self.detectPolicy.estimateLiveness.asSDKPolicy(),
        )
        humanTargets = SDKHumanEstimationTargets(estimateHumanDescriptor=self.extractPolicy.extractBodyDescriptor)
        estimateHuman = (
            self.detectPolicy.detectBody or self.extractPolicy.extractBodyDescriptor or not humanTargets.isEmpty()
        )
        toEstimate = SDKEstimationTargets(
            estimateHuman=estimateHuman,
            humanEstimationTargets=humanTargets,
            estimateFace=self.detectPolicy.detectFace,
            faceEstimationTargets=faceTargets,
        )

        def suitFilter(x):
            """
            Return useful thresholds: a threshold which is equal to `MAX_ANGLE` is replaced with None.
            """
            if x != MAX_ANGLE:
                return x
            return None

        maskStates = None
        if self.detectPolicy.maskStates:
            maskStates = [MaskState(x) for x in self.detectPolicy.maskStates]
        livenessStates = None
        if self.detectPolicy.livenessStates is not None:
            livenessStates = [LivenessPrediction(Liveness(x).name) for x in self.detectPolicy.livenessStates]

        filters = SDKTaskFilters(
            yawThreshold=suitFilter(self.detectPolicy.yawThreshold),
            pitchThreshold=suitFilter(self.detectPolicy.pitchThreshold),
            rollThreshold=suitFilter(self.detectPolicy.rollThreshold),
            livenessStates=livenessStates,
            garbageScoreThreshold=self.extractPolicy.fdScoreThreshold,
            maskStates=maskStates,
        )
        return SDKTask(
            toEstimate,
            data=sdkData,
            filters=filters,
            aggregateAttributes=bool(aggregate),
            multifacePolicy=MultifacePolicy(self.detectPolicy.multifacePolicy),
        )

    def publishEventsToPlugins(
        self,
        events: List[Event],
        handlerId: str,
        createEventTime: str,
        endEventTime: str,
        plugins: PluginManager,
        logger: Logger,
        accountId: str,
    ) -> None:
        """
        Publish events to other services.

        The "sending_event" plugin call is passed to the plugins runner (see `onStartup`) without waiting
        for its completion.

        Args:
            events: list of events
            handlerId: handler id
            createEventTime: event creation time
            endEventTime: event end time
            plugins: plugin manager
            logger: logger
            accountId: account id
        """
        futures = [
            plugins.sendEventToPlugins(
                "sending_event", events, handlerId, accountId, requestIdCtx.get(), createEventTime, endEventTime, logger
            )
        ]
        self.pluginsAsyncRunner.runNoWait(futures)

    async def execute(
        self,
        accountId: str,
        inputData: List[Union[SDKDetectableImage, FaceWarp, HumanWarp, RawDescriptorData]],
        sdkLoop: SDKTaskLoop,
        luna3Client: Client,
        aggregate: int,
        userData: str,
        externalId: Optional[str],
        source: str,
        userDefinedTags: List[str],
        location: Location,
        logger: Logger,
        createEventTime: str,
        endEventTime: str,
        handlerId: str,
        trackId: Optional[str],
        plugins: PluginManager,
        redisContext: RedisContext,
        facesBucket: str,
        bodiesBucket: str,
        originBucket: str,
        lunaEventsUsage: bool,
        lunaSenderUsage: bool,
    ) -> Tuple[List[dict], List[Event], Dict[str, List[dict]], DataForMonitoring]:
        """
        Execute all policies for handler.

        Steps: detect + extract (raw descriptors are processed separately from images and warps), matching,
        conditional tags, enrichment of events with metadata, storage, publishing of events to plugins.

        Args:
            accountId: account id
            inputData: a list of input images (or warps) or descriptors
            sdkLoop: sdk loop
            luna3Client: luna platform client
            aggregate: aggregate all extracted samples to one or not
            userData: user data for created faces
            externalId: external id for created faces
            source: user-defined source
            userDefinedTags: user-defined tags
            location: user-defined location
            logger: logger
            createEventTime: event creation time
            endEventTime: event end time
            handlerId: handler id
            trackId: event track id
            plugins: plugin manager
            redisContext: redis context
            facesBucket: faces sample bucket
            bodiesBucket: bodies sample bucket
            originBucket: origin image bucket
            lunaEventsUsage: luna events usage
            lunaSenderUsage: luna sender usage
        Returns:
            tuple of four items: images' processing data in reply format, events, filtered detections,
            monitoring data
        Raises:
            VLException(Error.AggregationNotSupported) if `aggregate` flag enabled and raw descriptor data received
        """
        # detect + extract
        # results are put to the positions of the corresponding input data, `...` is a "not filled yet" placeholder
        processedDataSources: List[Union[RawImageResult, RawDescriptorResult]] = [...] * len(inputData)

        # split input data to raw descriptors and images/warps, remembering the original positions
        descriptorIdxData = [
            (index, item) for index, item in enumerate(inputData) if isinstance(item, RawDescriptorData)
        ]
        rawDescriptorIdxs, rawDescriptorData = zip(*descriptorIdxData) if descriptorIdxData else ([], [])
        rawImageIdxData = [
            (index, item) for index, item in enumerate(inputData) if not isinstance(item, RawDescriptorData)
        ]
        rawImageIdxs, sdkData = zip(*rawImageIdxData) if rawImageIdxData else ([], [])

        if aggregate and rawDescriptorData:
            raise VLException(Error.AggregationNotSupported, 400, isCriticalError=False)

        sdkAdapter = APISDKHandlerAdaptor(logger=logger, accountId=accountId, sdkLoop=sdkLoop,)
        if sdkData:
            task = self.prepareSDKTask(sdkData, aggregate=aggregate)
            detectLandmarks68 = self.detectPolicy.detectLandmarks68 and task.source == TaskDataSource.images
            events, filteredDetections, processedImages, monitoringData = await sdkAdapter.handle(
                task=task, detectLandmarks68=detectLandmarks68
            )
            for index in rawImageIdxs:
                processedDataSources[index] = processedImages.pop(0)
        else:
            events, filteredDetections = [], sdkAdapter.prepareFilteredDetections()
            monitoringData = DataForMonitoring()

        rawDescriptorEvents, rawDescriptorProcessedImages = sdkAdapter.handleRawDescriptors(rawDescriptorData)
        events += rawDescriptorEvents
        for index in rawDescriptorIdxs:
            processedDataSources[index] = rawDescriptorProcessedImages.pop(0)

        # matching
        with monitorTime(monitoringData.request, "match_policy_time"):
            await asyncio.gather(
                *[matchByListPolicy.execute(events, luna3Client) for matchByListPolicy in self.matchPolicy]
            )

        # tags
        for policy in self.conditionalTagsPolicy:
            policy.execute(events)

        # user data
        self.enrichEventsWithMetadata(
            events=events,
            source=source,
            tags=userDefinedTags,
            userData=userData,
            externalId=externalId,
            location=location,
            trackId=trackId,
        )

        # storage
        monitoringData += await self.storagePolicy.execute(
            events=events,
            accountId=accountId,
            luna3Client=luna3Client,
            originImages=sdkData,
            userData=userData,
            externalId=externalId,
            logger=logger,
            createEventTime=createEventTime,
            endEventTime=endEventTime,
            handlerId=handlerId,
            facesBucket=facesBucket,
            bodiesBucket=bodiesBucket,
            originBucket=originBucket,
            lunaEventsUsage=lunaEventsUsage,
            redisContext=redisContext,
            lunaSenderUsage=lunaSenderUsage,
        )

        self.publishEventsToPlugins(
            events=events,
            handlerId=handlerId,
            createEventTime=createEventTime,
            endEventTime=endEventTime,
            plugins=plugins,
            logger=logger,
            accountId=accountId,
        )

        imagesReply = getImagesReplyData(
            sdkData=sdkData, images=processedDataSources, extractExif=self.detectPolicy.extractExif
        )
        return imagesReply, events, filteredDetections, monitoringData