# Source code for luna_handlers.classes.schemas.policies

"""
Module contains schemas for handler policies (detect, extract, match, conditional tags and storage policies)
"""
import asyncio
from typing import Any, Optional, Type, TypeVar, Union
from uuid import UUID

import cachetools
from luna3.client import Client
from luna3.common.luna_response import LunaResponse
from luna3.remote_sdk import http_objs as sdk
from luna_plugins.base.manager import PluginManager
from pydantic import Field, ValidationError, model_validator
from vlutils.cache.cache import cache

from app.global_vars.context_vars import requestIdCtx
from classes.event import HandlerEvent as Event
from classes.schemas.base_schema import BaseSchema
from classes.schemas.conditional_tags_policy import ConditionalTagsPolicy
from classes.schemas.detect_policies import HandlerDetectPolicy
from classes.schemas.extract_policies import HandlerExtractPolicy
from classes.schemas.filters import AttributesFilters, MatchFilter
from classes.schemas.match_policy import MatchPolicy
from classes.schemas.storage_policy import StoragePolicy, StorePolicyConfig
from classes.schemas.types import MAX_POLICY_LIST_LENGTH
from crutches_on_wheels.cow.errors.errors import Error
from crutches_on_wheels.cow.errors.exception import VLException
from crutches_on_wheels.cow.errors.pydantic_errors import PydanticError
from crutches_on_wheels.cow.monitoring.points import DataForMonitoring, monitorTime
from crutches_on_wheels.cow.utils.healthcheck import logger
from redis_db.redis_context import RedisContext

# generic element type returned by ``getObjectRecursively``
T = TypeVar("T")
# cache for list availability check responses (used by ``checkListAvailability``):
# at most 256 entries, each kept for 10 seconds
cacheTTL = cachetools.TTLCache(maxsize=256, ttl=10)


def getObjectRecursively(data: Any, expectedType: Type[T]) -> list[T]:
    """
    Recursively collect all objects of the expected type.

    The search is depth-first. Lists are traversed element by element, and ``BaseSchema`` instances are
    traversed through their attribute values (``__dict__``). An object that matches ``expectedType`` is
    collected. If it is also a list or a ``BaseSchema``, it is traversed as well.

    Args:
        data: object to search in
        expectedType: type of objects to collect
    Returns:
        found objects in traversal order
    """
    found: list[T] = []

    def collectObjects(dataPart: Any) -> None:
        """Collect ``dataPart`` into ``found`` if it has the expected type, then descend into its children."""
        if isinstance(dataPart, expectedType):
            found.append(dataPart)
        # independent checks (not elif): a matched object may itself contain further matches
        if isinstance(dataPart, list):
            for item in dataPart:
                collectObjects(item)
        if isinstance(dataPart, BaseSchema):
            for value in dataPart.__dict__.values():
                collectObjects(value)

    collectObjects(data)
    return found
@cache(
    lambda: cacheTTL,
    # Hash a tuple instead of a concatenated string: concatenation maps ("ab", "c") and ("a", "bc") to the
    # same key and cannot tell a missing account (None) from the literal string "None".
    # ``accountId`` defaults to None to mirror the wrapped function's signature.
    keyGen=lambda luna3Client, listId, accountId=None: hash((listId, accountId)),
)
async def checkListAvailability(luna3Client: Client, listId: str, accountId: Optional[str] = None) -> LunaResponse:
    """
    Check availability of list.

    Responses are cached in ``cacheTTL`` per (list id, account id) pair. Failed replies are cached as well,
    for the lifetime of a cache entry.

    Args:
        luna3Client: luna platform client
        listId: list id
        accountId: account id
    Returns:
        Response from luna-faces service
    """
    return await luna3Client.lunaFaces.checkList(listId=listId, accountId=accountId)
class Policies(BaseSchema):
    """
    Handler policies schema.

    Groups the detect, extract, match, conditional tags and storage policies of a handler, validates
    their mutual compatibility and executes them for a batch of events.
    """

    # detect policy
    detectPolicy: HandlerDetectPolicy = Field(default_factory=lambda: HandlerDetectPolicy())
    # extract policy
    extractPolicy: HandlerExtractPolicy = HandlerExtractPolicy()
    # matching policy list (``max_length`` is the pydantic v2 name of the deprecated ``max_items``)
    matchPolicy: list[MatchPolicy] = Field([], max_length=MAX_POLICY_LIST_LENGTH)
    # conditional tags policy list
    conditionalTagsPolicy: list[ConditionalTagsPolicy] = Field([], max_length=MAX_POLICY_LIST_LENGTH)
    # storage policy
    storagePolicy: StoragePolicy = Field(default_factory=lambda: StoragePolicy())

    @staticmethod
    def validateMatchAndExtractCompatibility(matchPolicies: list[MatchPolicy], extractPolicy: HandlerExtractPolicy):
        """
        Validate match and extract policies compatibility.

        Matching by face (body) descriptors requires the extract policy to extract face (body) descriptors.

        Args:
            matchPolicies: match policies
            extractPolicy: extract policy
        Raises:
            error built by PydanticError.PydanticValidationError if a required descriptor is not extracted
        """
        usesFaceMatching = any(policy.descriptor.descriptorType == "face" for policy in matchPolicies)
        if usesFaceMatching and not extractPolicy.extractFaceDescriptor:
            raise PydanticError.PydanticValidationError.format(
                "extract_face_descriptor should be equal to 1 for using face matching policy"
            )()
        usesBodyMatching = any(policy.descriptor.descriptorType == "body" for policy in matchPolicies)
        if usesBodyMatching and not extractPolicy.extractBodyDescriptor:
            raise PydanticError.PydanticValidationError.format(
                "extract_body_descriptor should be equal to 1 for using body matching policy"
            )()

    @staticmethod
    def validateMatchLabelsCompatibility(
        matchPolicies: list[MatchPolicy],
        conditionalTagsPolicies: list[ConditionalTagsPolicy],
        storagePolicy: StoragePolicy,
    ):
        """
        Validate matching label compatibility.

        Every ``MatchFilter`` label found in the conditional tags and storage policies must refer to a
        label defined by some match policy.

        Args:
            matchPolicies: match policies
            conditionalTagsPolicies: conditional tags policies
            storagePolicy: storage policy
        Raises:
            error built by PydanticError.PydanticValidationError for the first unknown label found
        """
        matchPolicyMatchingLabels = {matchPolicy.label for matchPolicy in matchPolicies}
        matchFiltersMatchingLabels = {
            matchFilter.label
            for matchFilter in getObjectRecursively([conditionalTagsPolicies, storagePolicy], MatchFilter)
        }
        for matchFilterLabel in matchFiltersMatchingLabels:
            if matchFilterLabel not in matchPolicyMatchingLabels:
                raise PydanticError.PydanticValidationError.format(
                    f'"{matchFilterLabel}" should be in match policy for filtration based on a matching by this label',
                )()

    @staticmethod
    def validateGeneratedAttributesFilters(
        detectPolicy: HandlerDetectPolicy,
        extractPolicy: HandlerExtractPolicy,
        matchPolicies: list[MatchPolicy],
        conditionalTagsPolicies: list[ConditionalTagsPolicy],
        storagePolicy: StoragePolicy,
    ):
        """
        Validate attributes filters and detect/extract policy compatibility.

        Attribute filters are collected from the match, conditional tags and storage policies. Filtering by
        basic attributes, liveness or deepfake requires the corresponding estimation to be enabled.

        Args:
            detectPolicy: detect policy
            extractPolicy: extract policy
            matchPolicies: match policies
            conditionalTagsPolicies: conditional tags policies
            storagePolicy: storage policy
        Raises:
            error built by PydanticError.PydanticValidationError if a filter needs a disabled estimation
        """
        attributeFilters: list[AttributesFilters] = getObjectRecursively(
            data=[matchPolicies, conditionalTagsPolicies, storagePolicy], expectedType=AttributesFilters
        )
        if not extractPolicy.extractBasicAttributes:
            basicAttributesNeeded = any(
                any(
                    (
                        filters.ethnicities is not None,
                        filters.ageLt is not None,
                        filters.ageGte is not None,
                        filters.gender is not None,
                    )
                )
                for filters in attributeFilters
            )
            if basicAttributesNeeded:
                raise PydanticError.PydanticValidationError.format(
                    "extract_basic_attributes should be equal to 1 for filtration based on basic attributes",
                )()
        if not detectPolicy.estimateLiveness.estimate:
            livenessNeeded = any(filters.liveness is not None for filters in attributeFilters)
            if livenessNeeded:
                raise PydanticError.PydanticValidationError.format(
                    "estimate_liveness.estimate should be equal to 1 for filtration based on liveness"
                )()
        if not detectPolicy.estimateDeepfake.estimate:
            deepfakeNeeded = any(filters.deepfake is not None for filters in attributeFilters)
            if deepfakeNeeded:
                # raised like the sibling checks; a bare ValueError would be reported by pydantic as a
                # generic value error with a different message format
                raise PydanticError.PydanticValidationError.format(
                    "estimate_deepfake.estimate should be equal to 1 for filtration based on deepfake"
                )()

    @staticmethod
    def validateMatchPolicyUniqueLabels(matchPolicies: list[MatchPolicy]):
        """
        Validate match policy matching label uniqueness.

        Args:
            matchPolicies: match policies
        Raises:
            ValidationError: if two match policies share a label
        """
        labels = [matchPolicy.label for matchPolicy in matchPolicies]
        if len(labels) != len(set(labels)):
            error = PydanticError.PydanticValidationError.format("Matching allowed only by unique labels")()
            raise ValidationError.from_exception_data(
                title="Label uniqueness error",
                # ``input`` is the offending data (the labels), not the MatchPolicy class object
                line_errors=[{"type": error, "loc": ("match_policy",), "input": labels, "ctx": {"labels": labels}}],
            )

    @staticmethod
    def validateDetectPolicyNotEmpty(detectPolicy: HandlerDetectPolicy):
        """
        Validate non-empty detect policy.

        Args:
            detectPolicy: detect policy
        Raises:
            error built by PydanticError.PydanticValidationError if neither faces nor bodies are detected, or
            if a non-default multiface policy is set without face detection
        """
        if not detectPolicy.detectFace and not detectPolicy.detectBody:
            raise PydanticError.PydanticValidationError.format(
                "At least one of *detect_face* or *detect_body* should be equal to 1"
            )()
        # multiface_policy only makes sense for faces; 1 is the only value allowed without face detection
        if not detectPolicy.detectFace and detectPolicy.multifacePolicy != 1:
            raise PydanticError.PydanticValidationError.format(
                "*detect_face* should be equal to 1 to set *multiface_policy* to 0 or 2"
            )()

    @staticmethod
    def validateDetectAndExtractCompatibility(detectPolicy: HandlerDetectPolicy, extractPolicy: HandlerExtractPolicy):
        """
        Validate detect and extract policies compatibility.

        Face descriptor and basic attributes extraction require face detection; body descriptor extraction
        requires body detection.

        Args:
            detectPolicy: detect policy
            extractPolicy: extract policy
        Raises:
            error built by PydanticError.PydanticValidationError on the first incompatibility found
        """
        if extractPolicy.extractFaceDescriptor and not detectPolicy.detectFace:
            raise PydanticError.PydanticValidationError.format(
                "*detect_face* should be equal to 1 to enable *extract_face_descriptor*"
            )()
        if extractPolicy.extractBasicAttributes and not detectPolicy.detectFace:
            raise PydanticError.PydanticValidationError.format(
                "*detect_face* should be equal to 1 to enable *extract_basic_attributes*"
            )()
        if extractPolicy.extractBodyDescriptor and not detectPolicy.detectBody:
            raise PydanticError.PydanticValidationError.format(
                "*detect_body* should be equal to 1 to enable *extract_body_descriptor*"
            )()

    @model_validator(mode="after")
    def validatePolicies(cls, values):
        """
        Execute all compatibility validators.

        NOTE(review): the first parameter is named ``cls``, so pydantic treats this after-validator as a
        classmethod, and ``values`` is the validated model instance. Newer pydantic prefers an instance
        method ``(self)``; confirm before changing.

        Returns:
            the validated model (``values``) unchanged
        """
        detectPolicy = values.detectPolicy
        matchPolicies = values.matchPolicy
        extractPolicy = values.extractPolicy
        conditionalTagsPolicies = values.conditionalTagsPolicy
        storagePolicy = values.storagePolicy

        cls.validateDetectPolicyNotEmpty(detectPolicy=detectPolicy)
        cls.validateDetectAndExtractCompatibility(detectPolicy=detectPolicy, extractPolicy=extractPolicy)
        cls.validateMatchPolicyUniqueLabels(matchPolicies=matchPolicies)
        cls.validateMatchAndExtractCompatibility(matchPolicies=matchPolicies, extractPolicy=extractPolicy)
        cls.validateMatchLabelsCompatibility(
            matchPolicies=matchPolicies, conditionalTagsPolicies=conditionalTagsPolicies, storagePolicy=storagePolicy
        )
        cls.validateGeneratedAttributesFilters(
            detectPolicy=detectPolicy,
            extractPolicy=extractPolicy,
            matchPolicies=matchPolicies,
            conditionalTagsPolicies=conditionalTagsPolicies,
            storagePolicy=storagePolicy,
        )
        return values

    @staticmethod
    async def _checkListsAvailability(
        luna3Client: Client, listIds: list[UUID], accountId: Optional[Union[str, UUID]] = None
    ) -> None:
        """
        Check availability of the given lists one by one.

        Args:
            luna3Client: luna platform client
            listIds: list ids
            accountId: account id (passed to luna-faces as a string, if given)
        Raises:
            VLException(Error.ListNotFound.format(listId), 400, False), if some list is not found
            VLException(Error.UnknownServiceError...), with the reply status code, on any other failure
        """
        accountIdStr = str(accountId) if accountId is not None else None
        for listId in listIds:
            reply = await checkListAvailability(luna3Client=luna3Client, accountId=accountIdStr, listId=str(listId))
            if reply.success:
                continue
            if reply.statusCode == 404:
                raise VLException(Error.ListNotFound.format(listId), 400, False)
            raise VLException(
                Error.UnknownServiceError.format(
                    "luna-faces", "HEAD", f"{luna3Client.lunaFaces.baseUri}/lists/{listId}"
                ),
                reply.statusCode,
                False,
            )

    async def _checkLinkingListsAvailability(self, luna3Client: Client, accountId: Optional[str] = None) -> None:
        """
        Check availability of lists from the face link-to-lists storage policy.

        Args:
            luna3Client: luna platform client
            accountId: account id
        """
        await self._checkListsAvailability(
            luna3Client=luna3Client,
            listIds=[linkPolicy.listId for linkPolicy in self.storagePolicy.facePolicy.linkToListsPolicy],
            accountId=accountId,
        )

    async def _checkMatchingListsAvailability(self, luna3Client: Client) -> None:
        """
        Check availability of lists from matching policies.

        Only match policies whose candidates define a ``listId`` are checked, using the candidates' account id.

        Args:
            luna3Client: luna platform client
        """
        for matchPolicy in self.matchPolicy:
            if (listId := getattr(matchPolicy.candidates, "listId", None)) is None:
                continue
            await self._checkListsAvailability(
                luna3Client=luna3Client,
                listIds=[listId],
                accountId=matchPolicy.candidates.accountId,
            )

    async def checkListsAvailability(self, luna3Client: Client, accountId: str) -> None:
        """
        Check availability of lists from matching and link to list policies.

        Args:
            luna3Client: luna platform client
            accountId: account id
        Raises:
            VLException(Error.ListNotFound.format(listId), 400, False), if some list is not found
        """
        await self._checkMatchingListsAvailability(luna3Client=luna3Client)
        await self._checkLinkingListsAvailability(luna3Client=luna3Client, accountId=accountId)

    @property
    def estimator(self):
        """
        Build a remote SDK estimator from the policies.

        Face targets and face detection filters are enabled only when face detection is enabled, and body
        targets only when body detection is enabled. Warps are requested when the corresponding samples
        are stored.

        Returns:
            sdk.Estimator configured with targets, filters, estimator params and multiface policy
        """
        dp = self.detectPolicy
        ep = self.extractPolicy
        sp = self.storagePolicy
        detectFace = dp.detectFace
        detectBody = dp.detectBody

        estimateFaceQuality = detectFace * dp.faceQuality.estimate
        targets = sdk.Targets(
            exif=dp.extractExif,
            peopleCount=dp.estimatePeopleCount or None,
            faceDetection=detectFace,
            # NOTE(review): detectFace * detectFace equals detectFace for 0/1 values; looks like a typo for a
            # different factor, kept as-is to preserve the produced value/type - confirm
            faceLandmarks5=detectFace * detectFace,
            faceLandmarks68=detectFace * dp.detectLandmarks68,
            faceWarp=detectFace * sp.faceSamplePolicy.storeSample,
            gaze=detectFace * dp.estimateGaze,
            # head pose is also needed when any head pose angle threshold filter is set
            headPose=detectFace * (dp.estimateHeadPose | bool(dp.rollThreshold or dp.yawThreshold or dp.pitchThreshold)),
            eyes=detectFace * dp.estimateEyesAttributes,
            mouthAttributes=detectFace * dp.estimateMouthAttributes,
            faceWarpQuality=detectFace * dp.estimateQuality,
            emotions=detectFace * dp.estimateEmotions,
            # mask/liveness/deepfake are also needed when the corresponding state filters are set
            mask=detectFace * (dp.estimateMask | bool(dp.maskStates)),
            glasses=detectFace * dp.estimateGlasses,
            liveness=detectFace * (dp.estimateLiveness.estimate | bool(dp.livenessStates)),
            deepfake=detectFace * (dp.estimateDeepfake.estimate | bool(dp.deepfakeStates)),
            faceQuality=dp.faceQuality if estimateFaceQuality else None,
            bodyDetection=detectBody,
            bodyWarp=detectBody * sp.bodySamplePolicy.storeSample,
            faceDescriptor=ep.extractFaceDescriptor,
            basicAttributes=ep.extractBasicAttributes,
            bodyDescriptor=ep.extractBodyDescriptor,
            bodyAttributes=detectBody * dp.bodyAttributes.estimateBasicAttributes,
            upperBody=detectBody * dp.bodyAttributes.estimateUpperBody,
            lowerBody=detectBody * dp.bodyAttributes.estimateLowerBody,
            accessories=detectBody * dp.bodyAttributes.estimateAccessories,
        )
        filters = sdk.Filters(
            faceDetectionFilters=sdk.FaceDetectionFilters(
                rollThreshold=dp.rollThreshold if detectFace else None,
                yawThreshold=dp.yawThreshold if detectFace else None,
                pitchThreshold=dp.pitchThreshold if detectFace else None,
                livenessStates=dp.livenessStates if detectFace else None,
                deepfakeStates=dp.deepfakeStates if detectFace else None,
                maskStates=dp.maskStates if detectFace else None,
                scoreThreshold=ep.fdScoreThreshold,
            )
        )
        estimationConfig = sdk.EstimatorsParams(
            livenessParams=sdk.LivenessParams(
                scoreThreshold=dp.estimateLiveness.livenessThreshold,
                qualityThreshold=dp.estimateLiveness.qualityThreshold,
            ),
            deepfakeParams=sdk.DeepfakeParams(
                realThreshold=dp.estimateDeepfake.realThreshold,
                mode=dp.estimateDeepfake.mode,
            ),
        )
        params = sdk.Params(
            targets=targets,
            filters=filters,
            estimatorsParams=estimationConfig,
            multifacePolicy=dp.multifacePolicy,
        )
        # NOTE(review): ``images`` is the Ellipsis literal here; presumably a placeholder that is replaced
        # before the estimator is used - confirm
        return sdk.Estimator(params, images=...)

    def publishEventsToPlugins(self, events: list[Event], accountId: str, plugins: PluginManager) -> None:
        """
        Publish events to other services via the "sending_event" plugin hook.

        Handler id and create/end times are taken from the first event. Does nothing for an empty batch.

        Args:
            events: list of events
            accountId: account id
            plugins: plugin manager
        """
        if not events:
            return
        pluginName = "sending_event"
        requestId = requestIdCtx.get()
        firstEvent = events[0]
        plugins.sendEventToPlugins(
            pluginName,
            events,
            firstEvent.handlerId,
            accountId,
            requestId,
            firstEvent.createTime,
            firstEvent.endTime,
            logger,
        )

    async def execute(
        self,
        events: list[Event],
        accountId: str,
        monitoring: DataForMonitoring,
        config: StorePolicyConfig,
        luna3Client: Client,
        redisContext: RedisContext,
        plugins: PluginManager,
    ):
        """
        Execute all policies for handler.

        Order: match policies (concurrently, timed as "match_policy_time"), then conditional tags policies,
        then storage policy, then publishing events to plugins.

        Args:
            events: events
            accountId: account id
            monitoring: monitoring data
            config: handler configuration parameters
            luna3Client: luna platform client
            redisContext: redis context
            plugins: plugin manager
        """
        # matching
        with monitorTime(monitoring, "match_policy_time"):
            await asyncio.gather(
                *[matchByListPolicy.execute(events, luna3Client) for matchByListPolicy in self.matchPolicy]
            )

        # tags
        for policy in self.conditionalTagsPolicy:
            policy.execute(events)

        # storage
        await self.storagePolicy.execute(
            events=events,
            accountId=accountId,
            monitoring=monitoring,
            config=config,
            luna3Client=luna3Client,
            redisContext=redisContext,
        )

        self.publishEventsToPlugins(events=events, accountId=accountId, plugins=plugins)