# Source code for luna_handlers.classes.schemas.policies

"""
Module contains schemas for policies
"""
import asyncio
from functools import cached_property
from typing import Any, Optional, Type, TypeVar, Union
from uuid import UUID

import cachetools
from luna3.client import Client
from luna3.common.luna_response import LunaResponse
from luna_plugins.base.manager import PluginManager
from lunavl.sdk.estimators.face_estimators.livenessv1 import LivenessPrediction
from lunavl.sdk.estimators.face_estimators.mask import MaskState
from pydantic import Field, ValidationError, root_validator
from pydantic.error_wrappers import ErrorWrapper
from vlutils.cache.cache import cache
from vlutils.jobs.async_runner import AsyncRunner
from vlutils.structures.dataclasses import dataclass

from app.api_sdk_adaptors.base import buildImageFilteredDetections, executeSDKTask
from app.api_sdk_adaptors.handler import APISDKHandlerAdaptor
from app.global_vars.constants import MAX_ANGLE
from app.global_vars.context_vars import requestIdCtx
from classes.event import Event, EventMetadata
from classes.image_meta import ProcessedImageData
from classes.monitoring import HandlersMonitoringData
from classes.raw_descriptor_data import RawDescriptorData
from classes.schemas.base_schema import BaseSchema
from classes.schemas.conditional_tags_policy import ConditionalTagsPolicy
from classes.schemas.detect_policies import HandlerDetectPolicy
from classes.schemas.estimation_targets import ExtendedEstimationTargets, ResponseEstimations, ResponseOnlyEstimations
from classes.schemas.extract_policies import HandlerExtractPolicy
from classes.schemas.filters import AttributesFilters, MatchFilter
from classes.schemas.match_policy import MatchPolicy
from classes.schemas.storage_policy import StoragePolicy, StorePolicyConfig
from classes.schemas.types import MAX_POLICY_LIST_LENGTH
from configs.config import PLUGINS_PUBLISHING_CONCURRENCY
from crutches_on_wheels.cow.enums.attributes import Liveness
from crutches_on_wheels.cow.errors.errors import Error
from crutches_on_wheels.cow.errors.exception import VLException
from crutches_on_wheels.cow.monitoring.points import monitorTime
from redis_db.redis_context import RedisContext
from sdk.sdk_loop.enums import LoopEstimations, MultifacePolicy
from sdk.sdk_loop.errors.errors import MultipleFaces
from sdk.sdk_loop.models.image import InputImage
from sdk.sdk_loop.tasks.filters import FaceDetectionFilters, Filters
from sdk.sdk_loop.tasks.task import LivenessV1Params, TaskEstimationParams, TaskParams

# generic type variable used to tie the requested type to the returned list items in getObjectRecursively
T = TypeVar("T")
# TTL cache for list availability check responses (used by checkListAvailability):
# holds up to 256 entries, each entry expires 10 seconds after it was stored
cacheTTL = cachetools.TTLCache(maxsize=256, ttl=10)


def getObjectRecursively(data: Any, expectedType: Type[T]) -> list[T]:
    """
    Recursively collect all objects of the expected type from nested data.

    Lists are traversed element by element, `BaseSchema` instances are traversed by the values of
    their instance `__dict__`. Objects of any other kind are checked themselves but not descended into.

    Args:
        data: object (or nested structure of lists / schemas) to search in
        expectedType: type of objects to collect
    Returns:
        found objects in depth-first traversal order
    """
    res: list[T] = []

    def collectObjects(dataPart: Any) -> None:
        """Append 'dataPart' to 'res' if it has the expected type and walk through its content"""
        if isinstance(dataPart, expectedType):
            res.append(dataPart)
        # the checks below are deliberately independent from the one above:
        # an object of the expected type is still traversed for nested matches
        if isinstance(dataPart, list):
            for row in dataPart:
                collectObjects(row)
        if isinstance(dataPart, BaseSchema):
            for row in dataPart.__dict__.values():
                collectObjects(row)

    collectObjects(data)
    return res
# NOTE(review): keyGen has no default for `accountId` while the function does - presumably the decorator
# always receives all three arguments (the only caller in this module passes them by keyword); confirm
@cache(lambda: cacheTTL, keyGen=lambda luna3Client, listId, accountId: hash(f"{listId}{accountId}"))
async def checkListAvailability(luna3Client: Client, listId: str, accountId: Optional[str] = None) -> LunaResponse:
    """
    Ask luna-faces whether the list exists (optionally within the account).

    The call is wrapped with the `cache` decorator backed by `cacheTTL`; the cache key is built
    from the list id and the account id only, the client itself is not a part of the key.

    Args:
        luna3Client: luna platform client
        listId: list id
        accountId: account id
    Returns:
        Response from luna-faces service
    """
    facesReply = await luna3Client.lunaFaces.checkList(listId=listId, accountId=accountId)
    return facesReply
@dataclass(withSlots=True)
class HandlerConfig(StorePolicyConfig):
    """
    Handler settings the policies are executed with.

    Extends the storage policy config with the parameters that are passed to the sdk task
    (see `Policies.prepareSDKTaskParams`).
    """

    # forwarded to the sdk task as `aggregate`
    aggregate: bool
    # forwarded to the sdk task as `useExifInfo`
    useExifInfo: bool
    # forwarded to the sdk task as `autoRotation` and to the sdk task executor as `useAutoRotation`
    useAutoRotation: bool
    # face descriptor version for the sdk estimators
    faceDescriptorVersion: int
    # body descriptor version for the sdk estimators
    bodyDescriptorVersion: int
class Policies(BaseSchema):
    """
    Policies schema.

    Aggregates detect / extract / match / conditional tags / storage policies of a handler,
    validates their mutual compatibility and executes them against the input data.
    """

    # detect policy
    detectPolicy: HandlerDetectPolicy = Field(default_factory=lambda: HandlerDetectPolicy())
    # extract policy
    extractPolicy: HandlerExtractPolicy = HandlerExtractPolicy()
    # matching policy list
    matchPolicy: list[MatchPolicy] = Field([], max_items=MAX_POLICY_LIST_LENGTH)
    # conditional tags policy list
    conditionalTagsPolicy: list[ConditionalTagsPolicy] = Field([], max_items=MAX_POLICY_LIST_LENGTH)
    # storage policy
    storagePolicy: StoragePolicy = Field(default_factory=lambda: StoragePolicy())

    @staticmethod
    def validateMatchAndExtractCompatibility(matchPolicies: list[MatchPolicy], extractPolicy: HandlerExtractPolicy):
        """
        Validate match and extract policies compatibility.

        Args:
            matchPolicies: match policies
            extractPolicy: extract policy
        Raises:
            ValueError: if matching by face/body descriptor is requested but the descriptor is not extracted
        """
        if not extractPolicy.extractFaceDescriptor and any(
            matchPolicy.descriptor.descriptorType == "face" for matchPolicy in matchPolicies
        ):
            raise ValueError("extract_face_descriptor should be equal to 1 for using face matching policy")
        if not extractPolicy.extractBodyDescriptor and any(
            matchPolicy.descriptor.descriptorType == "body" for matchPolicy in matchPolicies
        ):
            raise ValueError("extract_body_descriptor should be equal to 1 for using body matching policy")

    @staticmethod
    def validateMatchLabelsCompatibility(
        matchPolicies: list[MatchPolicy],
        conditionalTagsPolicies: list[ConditionalTagsPolicy],
        storagePolicy: StoragePolicy,
    ):
        """
        Validate matching label compatibility: every label used by a match filter
        of conditional tags / storage policies must be declared by some match policy.

        Args:
            matchPolicies: match policies
            conditionalTagsPolicies: conditional tags policies
            storagePolicy: storage policy
        Raises:
            ValueError: if a match filter refers to a label that is absent in match policies
        """
        matchPolicyMatchingLabels = {matchPolicy.label for matchPolicy in matchPolicies}
        # walk filters in traversal order so that the reported label is deterministic
        for matchFilter in getObjectRecursively([conditionalTagsPolicies, storagePolicy], MatchFilter):
            if matchFilter.label not in matchPolicyMatchingLabels:
                raise ValueError(
                    f'"{matchFilter.label}" should be in match policy for filtration based on a matching by this label'
                )

    @staticmethod
    def validateGeneratedAttributesFilters(
        detectPolicy: HandlerDetectPolicy,
        extractPolicy: HandlerExtractPolicy,
        matchPolicies: list[MatchPolicy],
        conditionalTagsPolicies: list[ConditionalTagsPolicy],
        storagePolicy: StoragePolicy,
    ):
        """
        Validate attributes filters and detect/extract policy compatibility.

        Args:
            detectPolicy: detect policy
            extractPolicy: extract policy
            matchPolicies: match policies
            conditionalTagsPolicies: conditional tags policies
            storagePolicy: storage policy
        Raises:
            ValueError: if some filter is based on basic attributes / liveness that are not estimated
        """
        attributeFilters: list[AttributesFilters] = getObjectRecursively(
            data=[matchPolicies, conditionalTagsPolicies, storagePolicy], expectedType=AttributesFilters
        )
        if not extractPolicy.extractBasicAttributes:
            basicAttributesNeeded = any(
                filters.ethnicities is not None
                or filters.ageLt is not None
                or filters.ageGte is not None
                or filters.gender is not None
                for filters in attributeFilters
            )
            if basicAttributesNeeded:
                raise ValueError(
                    "extract_basic_attributes should be equal to 1 for filtration based on basic attributes"
                )
        if not detectPolicy.estimateLiveness.estimate:
            livenessNeeded = any(filters.liveness is not None for filters in attributeFilters)
            if livenessNeeded:
                raise ValueError("estimate_liveness.estimate should be equal to 1 for filtration based on liveness")

    @staticmethod
    def validateMatchPolicyUniqueLabels(matchPolicies: list[MatchPolicy]):
        """
        Validate match policy matching label uniqueness.

        Args:
            matchPolicies: match policies
        Raises:
            ValidationError: if two or more match policies share the same label
        """
        labels = [matchPolicy.label for matchPolicy in matchPolicies]
        if len(labels) != len(set(labels)):
            error = ValueError("Matching allowed only by unique labels")
            # explicit one-element tuple: the error location is the "match_policy" field
            raise ValidationError([ErrorWrapper(exc=error, loc=("match_policy",))], Policies)

    @staticmethod
    def validateDetectPolicyNotEmpty(detectPolicy: HandlerDetectPolicy):
        """
        Validate non-empty detect policy.

        Args:
            detectPolicy: detect policy
        Raises:
            ValueError: if neither face nor body detection is enabled or
                        if a non-default multiface policy is set without face detection
        """
        if not (detectPolicy.detectFace or detectPolicy.detectBody):
            raise ValueError("At least one of *detect_face* or *detect_body* should be equal to 1")
        if not detectPolicy.detectFace and MultifacePolicy(detectPolicy.multifacePolicy) is not MultifacePolicy.allowed:
            raise ValueError("*detect_face* should be equal to 1 to set *multiface_policy* to 0 or 2")

    @staticmethod
    def validateDetectAndExtractCompatibility(detectPolicy: HandlerDetectPolicy, extractPolicy: HandlerExtractPolicy):
        """
        Validate detect and extract policies compatibility.

        Args:
            detectPolicy: detect policy
            extractPolicy: extract policy
        Raises:
            ValueError: if an extraction is enabled without the detection it depends on
        """
        if extractPolicy.extractFaceDescriptor and not detectPolicy.detectFace:
            raise ValueError("*detect_face* should be equal to 1 to enable *extract_face_descriptor*")
        if extractPolicy.extractBasicAttributes and not detectPolicy.detectFace:
            raise ValueError("*detect_face* should be equal to 1 to enable *extract_basic_attributes*")
        if extractPolicy.extractBodyDescriptor and not detectPolicy.detectBody:
            raise ValueError("*detect_body* should be equal to 1 to enable *extract_body_descriptor*")

    @root_validator(skip_on_failure=True)
    def validatePolicies(cls, values):
        """
        Execute all compatibility validators.

        Runs only if the field validation succeeded (`skip_on_failure=True`).

        Args:
            values: validated field values
        Returns:
            the same values
        """
        detectPolicy = values["detectPolicy"]
        matchPolicies = values["matchPolicy"]
        extractPolicy = values["extractPolicy"]
        conditionalTagsPolicies = values["conditionalTagsPolicy"]
        storagePolicy = values["storagePolicy"]

        cls.validateDetectPolicyNotEmpty(detectPolicy=detectPolicy)
        cls.validateDetectAndExtractCompatibility(detectPolicy=detectPolicy, extractPolicy=extractPolicy)
        cls.validateMatchPolicyUniqueLabels(matchPolicies=matchPolicies)
        cls.validateMatchAndExtractCompatibility(matchPolicies=matchPolicies, extractPolicy=extractPolicy)
        cls.validateMatchLabelsCompatibility(
            matchPolicies=matchPolicies, conditionalTagsPolicies=conditionalTagsPolicies, storagePolicy=storagePolicy
        )
        cls.validateGeneratedAttributesFilters(
            detectPolicy=detectPolicy,
            extractPolicy=extractPolicy,
            matchPolicies=matchPolicies,
            conditionalTagsPolicies=conditionalTagsPolicies,
            storagePolicy=storagePolicy,
        )
        return values

    @staticmethod
    async def _checkListsAvailability(
        luna3Client: Client, listIds: list[UUID], accountId: Optional[Union[str, UUID]] = None
    ) -> None:
        """
        Check availability of lists from matching and link to list policies.

        Lists are checked one by one, the first unavailable list interrupts the check.

        Args:
            luna3Client: luna platform client
            listIds: list ids
            accountId: account id
        Raises:
            VLException(Error.ListNotFound.format(listId), 400, False), if some list is not found
            VLException(Error.UnknownServiceError, <reply status code>, False), if luna-faces replied
                with any other unsuccessful status
        """
        for listId in listIds:
            reply = await checkListAvailability(
                luna3Client=luna3Client, accountId=str(accountId) if accountId is not None else None, listId=str(listId)
            )
            if reply.success:
                continue
            if reply.statusCode == 404:
                raise VLException(Error.ListNotFound.format(listId), 400, False)
            raise VLException(
                Error.UnknownServiceError.format(
                    "luna-faces", "HEAD", f"{luna3Client.lunaFaces.baseUri}/lists/{listId}"
                ),
                reply.statusCode,
                False,
            )

    async def _checkLinkingListsAvailability(self, luna3Client: Client, accountId: Optional[str] = None) -> None:
        """
        Check availability of lists from linking policies.

        Args:
            luna3Client: luna platform client
            accountId: account id the lists are checked for
        """
        linkingListIds = [linkPolicy.listId for linkPolicy in self.storagePolicy.facePolicy.linkToListsPolicy]
        await self._checkListsAvailability(luna3Client=luna3Client, listIds=linkingListIds, accountId=accountId)

    async def _checkMatchingListsAvailability(self, luna3Client: Client) -> None:
        """
        Check availability of lists from matching policies.

        Match policies whose candidates have no list id are skipped; each list is checked
        for the account id taken from the candidates of its match policy.

        Args:
            luna3Client: luna platform client
        """
        for matchPolicy in self.matchPolicy:
            if (listId := getattr(matchPolicy.candidates, "listId", None)) is None:
                continue
            await self._checkListsAvailability(
                luna3Client=luna3Client,
                listIds=[listId],
                accountId=matchPolicy.candidates.accountId,
            )

    async def checkListsAvailability(self, luna3Client: Client, accountId: str) -> None:
        """
        Check availability of lists from matching and link to list policies.

        Args:
            luna3Client: luna platform client
            accountId: account id (used for the lists of link to list policies)
        Raises:
            VLException(Error.ListNotFound.format(listId), 400, False), if some list is not found
        """
        await self._checkMatchingListsAvailability(luna3Client=luna3Client)
        await self._checkLinkingListsAvailability(luna3Client=luna3Client, accountId=accountId)

    @classmethod
    async def onStartup(cls):
        """Init Policies: create the class-wide async runner used to publish events to plugins"""
        cls.pluginsAsyncRunner = AsyncRunner(PLUGINS_PUBLISHING_CONCURRENCY, closeTimeout=1)

    @classmethod
    async def onShutdown(cls):
        """Stop Policies: close the async runner created by `onStartup`"""
        await cls.pluginsAsyncRunner.close()

    @cached_property
    def extendedEstimationTargets(self) -> ExtendedEstimationTargets:
        """
        Get extended estimation targets.

        Returns:
            sdk targets together with response targets (sdk targets extended with face quality targets)
        """
        # compute sdk targets once and reuse them for both fields
        sdkTargets = self._sdkTargets()
        return ExtendedEstimationTargets(
            sdkTargets=sdkTargets,
            responseTargets=sdkTargets | self.detectPolicy.faceQualityTargets,
        )

    def _sdkTargets(self) -> set[ResponseEstimations]:
        """
        Prepare sdk task targets

        Returns:
            sdk task targets
        """
        # NOTE(review): the block nesting was restored from a whitespace-collapsed listing; face-only
        # estimations are placed under `detectFace` (consistent with `sdkFilters`) - confirm
        detectPolicy = self.detectPolicy
        extractPolicy = self.extractPolicy
        targets = set()

        if detectPolicy.extractExif:
            targets.add(LoopEstimations.exif)

        if detectPolicy.detectFace:
            targets.add(LoopEstimations.faceLandmarks5)
            targets.add(LoopEstimations.faceDetection)
            targets.add(LoopEstimations.faceWarp)

            # if there are filters, estimations should be present in the result - so lets put them into targets
            isHeadFiltersEnabled = (
                detectPolicy.yawThreshold is not None
                or detectPolicy.pitchThreshold is not None
                or detectPolicy.rollThreshold is not None
            )
            isLivenessFiltersEnabled = bool(detectPolicy.livenessStates)
            isMaskFiltersEnabled = bool(detectPolicy.maskStates)

            if detectPolicy.detectLandmarks68:
                targets.add(LoopEstimations.faceLandmarks68)
            if detectPolicy.estimateQuality:
                targets.add(LoopEstimations.faceWarpQuality)
            if detectPolicy.estimateMouthAttributes:
                targets.add(LoopEstimations.mouthAttributes)
            if detectPolicy.estimateGaze:
                targets.add(LoopEstimations.gaze)
            if detectPolicy.estimateEyesAttributes:
                targets.add(LoopEstimations.eyes)
            if detectPolicy.estimateEmotions:
                targets.add(LoopEstimations.emotions)
            if detectPolicy.estimateMask or isMaskFiltersEnabled:
                targets.add(LoopEstimations.mask)
            if detectPolicy.estimateHeadPose or isHeadFiltersEnabled:
                targets.add(LoopEstimations.headPose)
            if detectPolicy.estimateLiveness.estimate or isLivenessFiltersEnabled:
                targets.add(LoopEstimations.livenessV1)
            if extractPolicy.extractBasicAttributes:
                targets.add(LoopEstimations.basicAttributes)
            if extractPolicy.extractFaceDescriptor:
                targets.add(LoopEstimations.faceDescriptor)

        if detectPolicy.detectBody:
            targets.add(LoopEstimations.bodyDetection)
            targets.add(LoopEstimations.bodyWarp)
            if extractPolicy.extractBodyDescriptor:
                targets.add(LoopEstimations.bodyDescriptor)

        # every requested group of body attributes needs the common sdk estimation
        # plus its own response-only target
        if detectPolicy.detectBody and detectPolicy.bodyAttributes.estimateBasicAttributes:
            targets.add(LoopEstimations.bodyAttributes)
            targets.add(ResponseOnlyEstimations.bodyBasicAttributes)
        if detectPolicy.detectBody and detectPolicy.bodyAttributes.estimateUpperBody:
            targets.add(LoopEstimations.bodyAttributes)
            targets.add(ResponseOnlyEstimations.upperBody)
        if detectPolicy.detectBody and detectPolicy.bodyAttributes.estimateAccessories:
            targets.add(LoopEstimations.bodyAttributes)
            targets.add(ResponseOnlyEstimations.accessories)
        return targets

    @cached_property
    def sdkFilters(self) -> Filters:
        """
        Prepare sdk task filters

        Face detection filters are filled from the policies only if face detection is enabled,
        otherwise default (empty) face detection filters are used.

        Returns:
            sdk task filters
        """

        def suitFilter(x):
            """Return useful thresholds: a threshold equal to MAX_ANGLE is replaced with None."""
            return x if x != MAX_ANGLE else None

        detectPolicy = self.detectPolicy

        maskStates = None
        if detectPolicy.maskStates:
            maskStates = [MaskState(x) for x in detectPolicy.maskStates]

        # NOTE(review): unlike mask states, an empty liveness states list is passed as [] (not None) - confirm
        livenessStates = None
        if detectPolicy.livenessStates is not None:
            livenessStates = [LivenessPrediction(Liveness(x).name) for x in detectPolicy.livenessStates]

        if detectPolicy.detectFace:
            faceFilters = FaceDetectionFilters(
                yawThreshold=suitFilter(detectPolicy.yawThreshold),
                pitchThreshold=suitFilter(detectPolicy.pitchThreshold),
                rollThreshold=suitFilter(detectPolicy.rollThreshold),
                livenessStates=livenessStates,
                gcThreshold=self.extractPolicy.fdScoreThreshold or None,  # ignore 0.0
                maskStates=maskStates,
            )
        else:
            faceFilters = FaceDetectionFilters()
        return Filters(faceDetection=faceFilters)

    def publishEventsToPlugins(self, events: list[Event], plugins: PluginManager) -> None:
        """
        Publish events to other services.

        The sending is scheduled on the async runner created by `onStartup` without waiting for the result.

        Args:
            events: list of events
            plugins: plugin manager
        """
        futures = [plugins.sendEventToPlugins("sending_event", events, requestIdCtx.get())]
        self.pluginsAsyncRunner.runNoWait(futures)

    def prepareSDKTaskParams(self, config: HandlerConfig):
        """
        Prepare sdk task parameters

        Args:
            config: handler configuration parameters
        Returns:
            sdk task parameters
        """
        livenessPolicy = self.detectPolicy.estimateLiveness
        estimatorsParams = TaskEstimationParams(
            faceDescriptorVersion=config.faceDescriptorVersion,
            bodyDescriptorVersion=config.bodyDescriptorVersion,
            livenessv1=LivenessV1Params(
                scoreThreshold=livenessPolicy.livenessThreshold,
                qualityThreshold=livenessPolicy.qualityThreshold,
            ),
        )
        return TaskParams(
            targets=self.extendedEstimationTargets.responseTargets,
            filters=self.sdkFilters,
            estimatorsParams=estimatorsParams,
            multifacePolicy=MultifacePolicy(self.detectPolicy.multifacePolicy),
            useExifInfo=config.useExifInfo,
            autoRotation=config.useAutoRotation,
            aggregate=config.aggregate,
        )

    async def execute(
        self,
        inputData: list[Union[RawDescriptorData, InputImage]],
        eventMetadata: EventMetadata,
        config: HandlerConfig,
        luna3Client: Client,
        redisContext: RedisContext,
        plugins: PluginManager,
    ) -> tuple[dict, HandlersMonitoringData]:
        """
        Execute all policies for handler.

        Steps: sdk task execution -> per-source result building -> events creation ->
        matching -> conditional tags -> storage -> events publishing to plugins.

        Args:
            inputData: input data (images / raw descriptors)
            eventMetadata: user defined event metadata
            config: handler configuration parameters
            luna3Client: luna platform client
            redisContext: redis context
            plugins: plugin manager
        Returns:
            * estimations in api format
            * monitoring data
        Raises:
            VLException(<image error>, 400, isCriticalError=False), if an image failed with `MultipleFaces` error
        """
        sdkTargets = self.extendedEstimationTargets.sdkTargets
        processedSources, aggregatedSample, monitoringData = await executeSDKTask(
            params=self.prepareSDKTaskParams(config),
            inputData=inputData,
            useAutoRotation=config.useAutoRotation,
            sdkTargets=sdkTargets,
        )

        result = {"images": [], "events": [], "filtered_detections": {"face_detections": []}}
        for source in processedSources:
            if isinstance(source, RawDescriptorData):
                imageRes = {
                    "filename": source.filename,
                    "status": int(source.error == Error.Success),
                    "error": source.error.asDict(),
                }
            else:
                image = source.image
                if isinstance(image.error, MultipleFaces):
                    raise VLException(image.error, 400, isCriticalError=False)
                if self.detectPolicy.isFaceQualityChecksEnabled():
                    self.detectPolicy.faceQuality.processSource(source)
                    # face quality usage is reported equal to the face detector usage
                    monitoringData.sdkUsages.faceQualityEstimator = monitoringData.sdkUsages.faceDetector
                imageRes = {
                    "filename": image.origin.filename,
                    "status": int(not image.error),
                    "error": (source.meta.error or image.error or Error.Success).asDict(),
                }
                if image.exif is not None:
                    imageRes["exif"] = image.exif
                result["filtered_detections"]["face_detections"].extend(
                    buildImageFilteredDetections(image=image, estimationTargets=sdkTargets)
                )
            result["images"].append(imageRes)

        events = APISDKHandlerAdaptor.createEvents(processedSources, aggregatedSample, eventMetadata, sdkTargets)

        # matching
        with monitorTime(monitoringData.request, "match_policy_time"):
            await asyncio.gather(
                *[matchByListPolicy.execute(events, luna3Client) for matchByListPolicy in self.matchPolicy]
            )

        # tags
        for policy in self.conditionalTagsPolicy:
            policy.execute(events)

        # storage (only image sources are stored, raw descriptors are skipped)
        imageSources = [source for source in processedSources if isinstance(source, ProcessedImageData)]
        monitoringData += await self.storagePolicy.execute(
            accountId=eventMetadata.accountId,
            sources=imageSources,
            events=events,
            config=config,
            luna3Client=luna3Client,
            redisContext=redisContext,
        )

        self.publishEventsToPlugins(events=events, plugins=plugins)
        result["events"] = [event.asDict() for event in events]
        return result, monitoringData