"""
Module contains schemas for policies
"""
import asyncio
from functools import cached_property
from typing import Any, Type, TypeVar, Union
from uuid import UUID
from luna3.client import Client
from luna3.common.exceptions import LunaApiException
from lunavl.sdk.estimators.face_estimators.livenessv1 import LivenessPrediction
from lunavl.sdk.estimators.face_estimators.mask import MaskState
from pydantic import Field, ValidationError, root_validator
from pydantic.error_wrappers import ErrorWrapper
from vlutils.jobs.async_runner import AsyncRunner
from vlutils.structures.dataclasses import dataclass
from app.api_sdk_adaptors.base import (
HandlerEstimations,
LoopEstimationsAlwaysOn,
LoopEstimationsExtended,
buildImageFilteredDetections,
executeSDKTask,
)
from app.api_sdk_adaptors.handler import APISDKHandlerAdaptor
from app.global_vars.constants import MAX_ANGLE
from app.global_vars.context_vars import requestIdCtx
from classes.event import Event, EventMetadata
from classes.image_meta import ProcessedImageData
from classes.monitoring import HandlersMonitoringData
from classes.raw_descriptor_data import RawDescriptorData
from classes.schemas.base_schema import BaseSchema
from classes.schemas.conditional_tags_policy import ConditionalTagsPolicy
from classes.schemas.detect_policies import HandlerDetectPolicy
from classes.schemas.extract_policies import HandlerExtractPolicy
from classes.schemas.filters import AttributesFilters, MatchFilter
from classes.schemas.match_policy import MatchPolicy
from classes.schemas.storage_policy import StoragePolicy, StorePolicyConfig
from classes.schemas.types import MAX_POLICY_LIST_LENGTH
from configs.config import PLUGINS_PUBLISHING_CONCURRENCY
from crutches_on_wheels.enums.attributes import Liveness
from crutches_on_wheels.errors.errors import Error
from crutches_on_wheels.errors.exception import VLException
from crutches_on_wheels.monitoring.points import monitorTime
from crutches_on_wheels.plugins.manager import PluginManager
from redis_db.redis_context import RedisContext
from sdk.sdk_loop.enums import LoopEstimations, MultifacePolicy
from sdk.sdk_loop.errors.errors import MultipleFaces
from sdk.sdk_loop.models.image import InputImage
from sdk.sdk_loop.tasks.filters import FaceDetectionFilters, Filters
from sdk.sdk_loop.tasks.task import LivenessV1Params, TaskEstimationParams, TaskParams
T = TypeVar("T")
def getObjectRecursively(data: Any, expectedType: Type[T]) -> list[T]:
    """
    Recursively collect all objects of the expected type.

    The search descends into lists and into the attribute values (`__dict__`) of `BaseSchema` instances;
    other containers (tuples, dicts, sets) are not traversed. An object that matches `expectedType` is
    collected and, if it is itself a list or a `BaseSchema`, still searched for nested matches.

    Args:
        data: root object to search in
        expectedType: type of objects to collect
    Returns:
        found objects in depth-first traversal order
    """
    res: list[T] = []

    def collectObjects(dataPart: Any) -> None:
        """Collect objects of the expected type from `dataPart` into `res`."""
        if isinstance(dataPart, expectedType):
            res.append(dataPart)
        if isinstance(dataPart, list):
            for row in dataPart:
                collectObjects(row)
        if isinstance(dataPart, BaseSchema):
            for value in dataPart.__dict__.values():
                collectObjects(value)

    collectObjects(data)
    return res
@dataclass(withSlots=True)
class HandlerConfig(StorePolicyConfig):
    """
    Handler config that policies should apply.

    Extends the storage policy config with the settings `Policies` needs to build and run an sdk task.
    The field order is part of the constructor signature and must not be changed.
    """

    # forwarded to the sdk task parameters as `aggregate`
    aggregate: bool
    # forwarded to the sdk task parameters as `useExifInfo`
    useExifInfo: bool
    # forwarded to the sdk task parameters as `autoRotation` and to the sdk task execution
    useAutoRotation: bool
    # face descriptor version passed to the sdk task estimators parameters
    faceDescriptorVersion: int
    # body descriptor version passed to the sdk task estimators parameters
    bodyDescriptorVersion: int
class Policies(BaseSchema):
    """
    Policies schema.

    Combines the detect, extract, match, conditional tags and storage policies of a handler, validates
    that they are compatible with each other and executes them for incoming images / raw descriptors.

    `onStartup` must be awaited before `publishEventsToPlugins` / `execute` are used: it creates the
    class-level `pluginsAsyncRunner` those methods rely on.
    """

    # detect policy
    detectPolicy: HandlerDetectPolicy = Field(default_factory=lambda: HandlerDetectPolicy())
    # extract policy
    # NOTE(review): unlike the sibling fields this default is an instance, not a default_factory -
    # presumably pydantic copies field defaults per model; confirm before relying on it
    extractPolicy: HandlerExtractPolicy = HandlerExtractPolicy()
    # matching policy list
    matchPolicy: list[MatchPolicy] = Field([], max_items=MAX_POLICY_LIST_LENGTH)
    # conditional tags policy list
    conditionalTagsPolicy: list[ConditionalTagsPolicy] = Field([], max_items=MAX_POLICY_LIST_LENGTH)
    # storage policy
    storagePolicy: StoragePolicy = Field(default_factory=lambda: StoragePolicy())

    @staticmethod
    def validateMatchAndExtractCompatibility(matchPolicies: list[MatchPolicy], extractPolicy: HandlerExtractPolicy):
        """
        Validate match and extract policies compatibility.

        Args:
            matchPolicies: match policies
            extractPolicy: extract policy
        Raises:
            ValueError: if there is a match policy but face descriptor extraction is disabled
        """
        if matchPolicies and not extractPolicy.extractFaceDescriptor:
            raise ValueError("extract_face_descriptor should be equal to 1 for using matching policy")

    @staticmethod
    def validateMatchLabelsCompatibility(
        matchPolicies: list[MatchPolicy],
        conditionalTagsPolicies: list[ConditionalTagsPolicy],
        storagePolicy: StoragePolicy,
    ):
        """
        Validate matching label compatibility.

        Every match filter found in the conditional tags / storage policies must refer to a label
        declared by some match policy.

        Args:
            matchPolicies: match policies
            conditionalTagsPolicies: conditional tags policies
            storagePolicy: storage policy
        Raises:
            ValueError: if a match filter uses a label which is absent in the match policies
        """
        matchPolicyMatchingLabels = {matchPolicy.label for matchPolicy in matchPolicies}
        # walk the filters in traversal order (not through a set) so that the reported label is
        # deterministic when several labels are missing
        for matchFilter in getObjectRecursively([conditionalTagsPolicies, storagePolicy], MatchFilter):
            matchFilterLabel = matchFilter.label
            if matchFilterLabel not in matchPolicyMatchingLabels:
                raise ValueError(
                    f'"{matchFilterLabel}" should be in match policy for filtration based on a matching by this label'
                )

    @staticmethod
    def validateGeneratedAttributesFilters(
        detectPolicy: HandlerDetectPolicy,
        extractPolicy: HandlerExtractPolicy,
        matchPolicies: list[MatchPolicy],
        conditionalTagsPolicies: list[ConditionalTagsPolicy],
        storagePolicy: StoragePolicy,
    ):
        """
        Validate attributes filters and detect/extract policy compatibility.

        Args:
            detectPolicy: detect policy
            extractPolicy: extract policy
            matchPolicies: match policies
            conditionalTagsPolicies: conditional tags policies
            storagePolicy: storage policy
        Raises:
            ValueError: if some filter needs basic attributes or liveness which is not estimated
        """
        attributeFilters: list[AttributesFilters] = getObjectRecursively(
            data=[matchPolicies, conditionalTagsPolicies, storagePolicy], expectedType=AttributesFilters
        )
        if not extractPolicy.extractBasicAttributes:
            basicAttributesNeeded = any(
                filters.ethnicities is not None
                or filters.ageLt is not None
                or filters.ageGte is not None
                or filters.gender is not None
                for filters in attributeFilters
            )
            if basicAttributesNeeded:
                raise ValueError(
                    "extract_basic_attributes should be equal to 1 for filtration based on basic attributes"
                )
        if not detectPolicy.estimateLiveness.estimate:
            livenessNeeded = any(filters.liveness is not None for filters in attributeFilters)
            if livenessNeeded:
                raise ValueError("estimate_liveness.estimate should be equal to 1 for filtration based on liveness")

    @staticmethod
    def validateMatchPolicyUniqueLabels(matchPolicies: list[MatchPolicy]):
        """
        Validate match policy matching label uniqueness.

        Args:
            matchPolicies: match policies
        Raises:
            ValidationError: if two or more match policies share the same label
        """
        labels = [matchPolicy.label for matchPolicy in matchPolicies]
        if len(labels) != len(set(labels)):
            error = ValueError("Matching allowed only by unique labels")
            # the location is an explicit one-element tuple (a parenthesised string is still a string)
            raise ValidationError([ErrorWrapper(exc=error, loc=("match_policy",))], Policies)

    @staticmethod
    def validateDetectPolicyNotEmpty(detectPolicy: HandlerDetectPolicy):
        """
        Validate non-empty detect policy.

        Args:
            detectPolicy: detect policy
        Raises:
            ValueError: if neither face nor body detection is enabled, or if a multiface policy other
                        than "allowed" is set while face detection is disabled
        """
        if not detectPolicy.detectFace and not detectPolicy.detectBody:
            raise ValueError("At least one of *detect_face* or *detect_body* should be equal to 1")
        if not detectPolicy.detectFace and MultifacePolicy(detectPolicy.multifacePolicy) is not MultifacePolicy.allowed:
            raise ValueError("*detect_face* should be equal to 1 to set *multiface_policy* to 0 or 2")

    @staticmethod
    def validateDetectAndExtractCompatibility(detectPolicy: HandlerDetectPolicy, extractPolicy: HandlerExtractPolicy):
        """
        Validate detect and extract policies compatibility.

        Args:
            detectPolicy: detect policy
            extractPolicy: extract policy
        Raises:
            ValueError: if an extraction is enabled while the detection it relies on is disabled
        """
        if extractPolicy.extractFaceDescriptor and not detectPolicy.detectFace:
            raise ValueError("*detect_face* should be equal to 1 to enable *extract_face_descriptor*")
        if extractPolicy.extractBasicAttributes and not detectPolicy.detectFace:
            raise ValueError("*detect_face* should be equal to 1 to enable *extract_basic_attributes*")
        if extractPolicy.extractBodyDescriptor and not detectPolicy.detectBody:
            raise ValueError("*detect_body* should be equal to 1 to enable *extract_body_descriptor*")

    @root_validator(skip_on_failure=True)
    def validatePolicies(cls, values):
        """
        Execute all compatibility validators.

        Args:
            values: already validated field values
        Returns:
            the same values, unchanged
        """
        detectPolicy = values["detectPolicy"]
        matchPolicies = values["matchPolicy"]
        extractPolicy = values["extractPolicy"]
        conditionalTagsPolicies = values["conditionalTagsPolicy"]
        storagePolicy = values["storagePolicy"]

        cls.validateDetectPolicyNotEmpty(detectPolicy=detectPolicy)
        cls.validateDetectAndExtractCompatibility(detectPolicy=detectPolicy, extractPolicy=extractPolicy)
        cls.validateMatchPolicyUniqueLabels(matchPolicies=matchPolicies)
        cls.validateMatchAndExtractCompatibility(matchPolicies=matchPolicies, extractPolicy=extractPolicy)
        cls.validateMatchLabelsCompatibility(
            matchPolicies=matchPolicies, conditionalTagsPolicies=conditionalTagsPolicies, storagePolicy=storagePolicy
        )
        cls.validateGeneratedAttributesFilters(
            detectPolicy=detectPolicy,
            extractPolicy=extractPolicy,
            matchPolicies=matchPolicies,
            conditionalTagsPolicies=conditionalTagsPolicies,
            storagePolicy=storagePolicy,
        )
        return values

    def getListIdsFromPolicies(self) -> list[UUID]:
        """
        Get list ids from matching and link to lists policies.

        Returns:
            list ids: ids from the link to lists policies first, then non-empty candidates list ids
        """
        candidateLists = [
            listId for match in self.matchPolicy if (listId := getattr(match.candidates, "listId", None))
        ]
        return [linkPolicy.listId for linkPolicy in self.storagePolicy.facePolicy.linkToListsPolicy] + candidateLists

    async def checkListsAvailability(self, luna3Client: Client, accountId: str) -> None:
        """
        Check availability of lists from matching and link to list policies.

        Args:
            luna3Client: luna platform client
            accountId: account id
        Raises:
            VLException(Error.ListNotFound.format(listId), 400, False), if some list is not found
            LunaApiException: any other error of the list check is re-raised as is
        """
        for listId in self.getListIdsFromPolicies():
            try:
                await luna3Client.lunaFaces.checkList(listId=str(listId), accountId=accountId, raiseError=True)
            except LunaApiException as e:
                if e.statusCode == 404:
                    # keep the original api error as the explicit cause of the user-facing one
                    raise VLException(Error.ListNotFound.format(listId), 400, isCriticalError=False) from e
                raise

    @classmethod
    async def onStartup(cls):
        """Init Policies: create the class-level runner used to publish events to plugins."""
        cls.pluginsAsyncRunner = AsyncRunner(PLUGINS_PUBLISHING_CONCURRENCY, closeTimeout=1)

    @classmethod
    async def onShutdown(cls):
        """Stop Policies: close the runner created by `onStartup`."""
        await cls.pluginsAsyncRunner.close()

    @cached_property
    def sdkTargets(self) -> set[HandlerEstimations]:
        """
        Prepare sdk task targets

        Returns:
            sdk task targets
        """
        detectPolicy = self.detectPolicy
        extractPolicy = self.extractPolicy
        targets: set[HandlerEstimations] = set()

        if detectPolicy.extractExif:
            targets.add(LoopEstimations.exif)

        if detectPolicy.detectFace:
            targets.add(LoopEstimationsAlwaysOn.faceLandmarks5)
            targets.add(LoopEstimations.faceDetection)

            # if there are filters, estimations should be present in the result - so lets put them into targets
            isHeadFiltersEnabled = (
                detectPolicy.yawThreshold is not None
                or detectPolicy.pitchThreshold is not None
                or detectPolicy.rollThreshold is not None
            )
            isLivenessFiltersEnabled = bool(detectPolicy.livenessStates)
            isMaskFiltersEnabled = bool(detectPolicy.maskStates)

            # face estimations are scheduled only together with face detection; `sdkFilters` gates the
            # face filters on `detectFace` in the same way
            faceEstimations = (
                (detectPolicy.detectLandmarks68, LoopEstimations.faceLandmarks68),
                (detectPolicy.estimateQuality, LoopEstimations.faceWarpQuality),
                (detectPolicy.estimateMouthAttributes, LoopEstimations.mouthAttributes),
                (detectPolicy.estimateGaze, LoopEstimations.gaze),
                (detectPolicy.estimateEyesAttributes, LoopEstimations.eyes),
                (detectPolicy.estimateEmotions, LoopEstimations.emotions),
                (detectPolicy.estimateMask or isMaskFiltersEnabled, LoopEstimations.mask),
                (detectPolicy.estimateHeadPose or isHeadFiltersEnabled, LoopEstimations.headPose),
                (detectPolicy.estimateLiveness.estimate or isLivenessFiltersEnabled, LoopEstimations.livenessV1),
            )
            targets.update(target for enabled, target in faceEstimations if enabled)

        # `validateDetectAndExtractCompatibility` guarantees that the corresponding detection is enabled
        # for every extraction below
        if extractPolicy.extractBasicAttributes:
            targets.add(LoopEstimations.basicAttributes)
        if extractPolicy.extractFaceDescriptor:
            targets.add(LoopEstimations.faceDescriptor)

        if detectPolicy.detectBody:
            targets.add(LoopEstimations.bodyDetection)
        if extractPolicy.extractBodyDescriptor:
            targets.add(LoopEstimations.bodyDescriptor)

        if detectPolicy.detectBody:
            bodyAttributes = detectPolicy.bodyAttributes
            extendedBodyEstimations = (
                (bodyAttributes.estimateBasicAttributes, LoopEstimationsExtended.bodyBasicAttributes),
                (bodyAttributes.estimateUpperBody, LoopEstimationsExtended.upperBody),
                (bodyAttributes.estimateAccessories, LoopEstimationsExtended.accessories),
            )
            for enabled, extendedTarget in extendedBodyEstimations:
                if enabled:
                    # every extended body estimation also needs the common body attributes target
                    targets.add(LoopEstimations.bodyAttributes)
                    targets.add(extendedTarget)
        return targets

    @cached_property
    def sdkFilters(self) -> Filters:
        """
        Prepare sdk task filters

        Returns:
            sdk task filters; face detection filters are empty when face detection is disabled
        """

        def suitFilter(x):
            """Return useful thresholds: a threshold equal to `MAX_ANGLE` is replaced with None."""
            if x != MAX_ANGLE:
                return x
            return None

        detectPolicy = self.detectPolicy
        if not detectPolicy.detectFace:
            return Filters(faceDetection=FaceDetectionFilters())

        maskStates = None
        if detectPolicy.maskStates:
            maskStates = [MaskState(x) for x in detectPolicy.maskStates]

        livenessStates = None
        # an empty list means "no filter", same as for mask states: `sdkTargets` schedules the liveness
        # estimation for filtering only when the list is non-empty, so the two must agree
        if detectPolicy.livenessStates:
            livenessStates = [LivenessPrediction(Liveness(x).name) for x in detectPolicy.livenessStates]

        faceFilters = FaceDetectionFilters(
            yawThreshold=suitFilter(detectPolicy.yawThreshold),
            pitchThreshold=suitFilter(detectPolicy.pitchThreshold),
            rollThreshold=suitFilter(detectPolicy.rollThreshold),
            livenessStates=livenessStates,
            gcThreshold=self.extractPolicy.fdScoreThreshold or None,  # ignore 0.0
            maskStates=maskStates,
        )
        return Filters(faceDetection=faceFilters)

    def publishEventsToPlugins(self, events: list[Event], plugins: PluginManager) -> None:
        """
        Publish events to other services.

        The sending is scheduled on the class-level `pluginsAsyncRunner` and is not awaited here.

        Args:
            events: list of events
            plugins: plugin manager
        """
        futures = [plugins.sendEventToPlugins("sending_event", events, requestIdCtx.get())]
        self.pluginsAsyncRunner.runNoWait(futures)

    def prepareSDKTaskParams(self, config: HandlerConfig) -> TaskParams:
        """
        Prepare sdk task parameters

        Args:
            config: handler configuration parameters
        Returns:
            sdk task parameters
        """
        detectPolicy = self.detectPolicy
        estimatorsParams = TaskEstimationParams(
            faceDescriptorVersion=config.faceDescriptorVersion,
            bodyDescriptorVersion=config.bodyDescriptorVersion,
            livenessv1=LivenessV1Params(
                scoreThreshold=detectPolicy.estimateLiveness.livenessThreshold,
                qualityThreshold=detectPolicy.estimateLiveness.qualityThreshold,
            ),
        )
        return TaskParams(
            # the face quality checks may need estimations which are not among the policy targets
            targets=self.sdkTargets | detectPolicy.faceQualityTargets,
            filters=self.sdkFilters,
            estimatorsParams=estimatorsParams,
            multifacePolicy=MultifacePolicy(detectPolicy.multifacePolicy),
            useExifInfo=config.useExifInfo,
            autoRotation=config.useAutoRotation,
            aggregate=config.aggregate,
        )

    async def execute(
        self,
        inputData: list[Union[RawDescriptorData, InputImage]],
        eventMetadata: EventMetadata,
        config: HandlerConfig,
        luna3Client: Client,
        redisContext: RedisContext,
        plugins: PluginManager,
    ) -> tuple[dict, HandlersMonitoringData]:
        """
        Execute all policies for handler.

        Args:
            inputData: input data (images / raw descriptors)
            eventMetadata: user defined event metadata
            config: handler configuration parameters
            luna3Client: luna platform client
            redisContext: redis context
            plugins: plugin manager
        Returns:
            * estimations in api format
            * monitoring data
        Raises:
            VLException: with status code 400 if some image failed with the multiple faces error
        """
        processedSources, aggregatedSample, monitoringData = await executeSDKTask(
            params=self.prepareSDKTaskParams(config),
            inputData=inputData,
            useAutoRotation=config.useAutoRotation,
            sdkTargets=self.sdkTargets,
        )

        result = {"images": [], "events": [], "filtered_detections": {"face_detections": []}}
        for source in processedSources:
            if isinstance(source, RawDescriptorData):
                imageRes = {
                    "filename": source.filename,
                    "status": int(source.error == Error.Success),
                    "error": source.error.asDict(),
                }
            else:
                image = source.image
                if isinstance(image.error, MultipleFaces):
                    raise VLException(image.error, 400, isCriticalError=False)
                if self.detectPolicy.isFaceQualityChecksEnabled():
                    self.detectPolicy.faceQuality.processSource(source)
                    monitoringData.sdkUsages.faceQualityEstimator = monitoringData.sdkUsages.faceDetector
                # NOTE(review): "status" looks only at image.error while "error" reports
                # source.meta.error first - confirm that this asymmetry is intended
                imageRes = {
                    "filename": image.origin.filename,
                    "status": int(not image.error),
                    "error": (source.meta.error or image.error or Error.Success).asDict(),
                }
                if image.exif is not None:
                    imageRes["exif"] = image.exif
                result["filtered_detections"]["face_detections"].extend(
                    buildImageFilteredDetections(image=image, estimationTargets=self.sdkTargets)
                )
            result["images"].append(imageRes)

        events = APISDKHandlerAdaptor.createEvents(processedSources, aggregatedSample, eventMetadata, self.sdkTargets)

        # matching
        with monitorTime(monitoringData.request, "match_policy_time"):
            await asyncio.gather(
                *[matchByListPolicy.execute(events, luna3Client) for matchByListPolicy in self.matchPolicy]
            )

        # tags
        for policy in self.conditionalTagsPolicy:
            policy.execute(events)

        # storage
        imageSources = [source for source in processedSources if isinstance(source, ProcessedImageData)]
        monitoringData += await self.storagePolicy.execute(
            sources=imageSources,
            events=events,
            config=config,
            luna3Client=luna3Client,
            redisContext=redisContext,
        )

        self.publishEventsToPlugins(events=events, plugins=plugins)
        result["events"] = [event.asDict() for event in events]
        return result, monitoringData