"""
This module contains schemas for policies.
"""
import asyncio
from typing import Any, Type, TypeVar, Optional, List, Dict, Tuple, Union
from uuid import UUID
from luna3.common.exceptions import LunaApiException
from pydantic import Field, root_validator, ValidationError
from pydantic.error_wrappers import ErrorWrapper
from luna3.client import Client
from lunavl.sdk.estimators.face_estimators.livenessv1 import LivenessPrediction
from lunavl.sdk.estimators.face_estimators.mask import MaskState
from vlutils.jobs.async_runner import AsyncRunner
from app.api_sdk_adaptors.handler import APISDKHandlerAdaptor
from app.global_vars.constants import MAX_ANGLE
from app.global_vars.context_vars import requestIdCtx
from classes.event import Event, Location
from classes.event_parts import RawDescriptorResult, RawImageResult
from classes.monitoring import HandlersMonitoringData as DataForMonitoring
from classes.raw_descriptor_data import RawDescriptorData
from classes.schemas.base_schema import BaseSchema
from classes.schemas.conditional_tags_policy import ConditionalTagsPolicy
from classes.schemas.detect_policies import HandlerDetectPolicy
from classes.schemas.extract_policies import HandlerExtractPolicy
from classes.schemas.filters import MatchFilter, AttributesFilters
from classes.schemas.match_policy import MatchPolicy
from classes.schemas.storage_policy import StoragePolicy
from classes.schemas.types import MAX_POLICY_LIST_LENGTH
from configs.config import PLUGINS_PUBLISHING_CONCURRENCY
from crutches_on_wheels.enums.attributes import Liveness
from crutches_on_wheels.errors.errors import Error
from crutches_on_wheels.errors.exception import VLException
from crutches_on_wheels.monitoring.points import monitorTime
from crutches_on_wheels.plugins.manager import PluginManager
from crutches_on_wheels.utils.log import Logger
from img_utils.utils import getExif
from redis_db.redis_context import RedisContext
from sdk.sdk_loop.enums import MultifacePolicy
from sdk.sdk_loop.estimation_targets import SDKEstimationTargets, SDKFaceEstimationTargets, SDKHumanEstimationTargets
from sdk.sdk_loop.sdk_task import (
SDKTaskFilters,
SDKTask,
FaceWarp,
SDKDetectableImage,
TaskDataSource,
HumanWarp,
)
from sdk.sdk_loop.task_loop import SDKTaskLoop
T = TypeVar("T")


def getObjectRecursively(data: Any, expectedType: Type[T]) -> List[T]:
    """
    Recursively collect all objects of the expected type.

    The search descends into lists and into the attribute values (``__dict__``) of ``BaseSchema``
    instances. An object matching ``expectedType`` is collected and, if it is itself a list or a
    schema, searched further too. Other containers (tuples, dicts, sets) are not traversed.

    Args:
        data: root object to search in
        expectedType: type of objects to collect

    Returns:
        found objects in depth-first (pre-order) order
    """
    res: List[T] = []

    def collectObjects(dataPart: Any) -> None:
        """Append ``dataPart`` (if it matches ``expectedType``) and its nested matches to ``res``."""
        if isinstance(dataPart, expectedType):
            res.append(dataPart)
        # plain loops: recursion is executed for its side effect, no list has to be built
        if isinstance(dataPart, list):
            for row in dataPart:
                collectObjects(row)
        if isinstance(dataPart, BaseSchema):
            for row in dataPart.__dict__.values():
                collectObjects(row)

    collectObjects(data)
    return res
def getImagesReplyData(
    sdkData: List[Union[HumanWarp, FaceWarp, SDKDetectableImage]],
    images: List[Union[RawImageResult, RawDescriptorResult]],
    extractExif: int,
) -> List[dict]:
    """
    Build the per-image part of the handler reply.

    Args:
        sdkData: images/warps which were passed to the sdk
        images: processing results; each one is keyed by its ``id``
        extractExif: if truthy, an "exif" entry is added for every sdk item without an error

    Returns:
        one dict ("error", "status", "filename" and optionally "exif") per distinct image id,
        in order of the first occurrence of that id in ``images``
    """
    replyById = {}
    for processed in images:
        key = processed.id
        replyById[key] = {
            "error": processed.error.asDict(),
            "status": processed.status,
            "filename": processed.filename,
        }

    if not extractExif:
        return list(replyById.values())

    for item in sdkData:
        # exif is collected only for successfully processed sdk items
        if item.error is not None:
            continue
        replyById[item.id]["exif"] = getExif(item)
    return list(replyById.values())
class Policies(BaseSchema):
    """
    Policies schema.

    Holds the detect, extract, match, conditional tags and storage policies of a handler, validates
    their mutual compatibility (``validatePolicies``) and executes them for input data (``execute``).
    """

    # detect policy
    detectPolicy: HandlerDetectPolicy = Field(default_factory=lambda: HandlerDetectPolicy())
    # extract policy
    extractPolicy: HandlerExtractPolicy = HandlerExtractPolicy()
    # matching policy list
    matchPolicy: List[MatchPolicy] = Field([], max_items=MAX_POLICY_LIST_LENGTH)
    # conditional tags policy list
    conditionalTagsPolicy: List[ConditionalTagsPolicy] = Field([], max_items=MAX_POLICY_LIST_LENGTH)
    # storage policy
    storagePolicy: StoragePolicy = Field(default_factory=lambda: StoragePolicy())

    # NOTE: `pluginsAsyncRunner` is a class attribute assigned in `onStartup`; it is not declared in the
    # class body (presumably so that pydantic does not treat it as a model field).

    @staticmethod
    def validateMatchAndExtractCompatibility(matchPolicies: List[MatchPolicy], extractPolicy: HandlerExtractPolicy):
        """
        Validate match and extract policies compatibility.

        Raises:
            ValueError: if matching is requested while face descriptor extraction is disabled
        """
        if matchPolicies and not extractPolicy.extractFaceDescriptor:
            raise ValueError("extract_face_descriptor should be equal to 1 for using matching policy")

    @staticmethod
    def validateMatchLabelsCompatibility(
        matchPolicies: List[MatchPolicy],
        conditionalTagsPolicies: List[ConditionalTagsPolicy],
        storagePolicy: StoragePolicy,
    ):
        """
        Validate that every match filter (in conditional tags and storage policies) refers to a label
        declared in the match policy.

        Raises:
            ValueError: if some match filter label is not present in the match policy
        """
        matchPolicyMatchingLabels = {matchPolicy.label for matchPolicy in matchPolicies}
        matchFiltersMatchingLabels = {
            matchFilter.label
            for matchFilter in getObjectRecursively([conditionalTagsPolicies, storagePolicy], MatchFilter)
        }
        for matchFilterLabel in matchFiltersMatchingLabels:
            if matchFilterLabel not in matchPolicyMatchingLabels:
                raise ValueError(
                    f'"{matchFilterLabel}" should be in match policy for filtration based on a matching by this label'
                )

    @staticmethod
    def validateGeneratedAttributesFilters(
        detectPolicy: HandlerDetectPolicy,
        extractPolicy: HandlerExtractPolicy,
        matchPolicies: List[MatchPolicy],
        conditionalTagsPolicies: List[ConditionalTagsPolicy],
        storagePolicy: StoragePolicy,
    ):
        """
        Validate that attribute filters (found anywhere in match, conditional tags and storage policies)
        only use attributes which are actually estimated by the detect/extract policies.

        Raises:
            ValueError: if basic attributes filtration is used without basic attributes extraction, or
                liveness filtration is used without liveness estimation
        """
        attributeFilters: List[AttributesFilters] = getObjectRecursively(
            data=[matchPolicies, conditionalTagsPolicies, storagePolicy], expectedType=AttributesFilters
        )
        if not extractPolicy.extractBasicAttributes:
            basicAttributesNeeded = any(
                any(
                    (
                        filters.ethnicities is not None,
                        filters.ageLt is not None,
                        filters.ageGte is not None,
                        filters.gender is not None,
                    )
                )
                for filters in attributeFilters
            )
            if basicAttributesNeeded:
                raise ValueError(
                    "extract_basic_attributes should be equal to 1 for filtration based on basic attributes"
                )
        if not detectPolicy.estimateLiveness.estimate:
            livenessNeeded = any((filters.liveness is not None) for filters in attributeFilters)
            if livenessNeeded:
                raise ValueError("estimate_liveness.estimate should be equal to 1 for filtration based on liveness")

    @staticmethod
    def validateMatchPolicyUniqueLabels(matchPolicies: List[MatchPolicy]):
        """
        Validate match policy matching label uniqueness.

        Raises:
            ValidationError: located at "match_policy", if some label is used more than once
        """
        labels = [matchPolicy.label for matchPolicy in matchPolicies]
        if len(labels) != len(set(labels)):
            error = ValueError("Matching allowed only by unique labels")
            # `loc` is a one-element tuple (the original `("match_policy")` was a plain string)
            raise ValidationError([ErrorWrapper(exc=error, loc=("match_policy",))], Policies)

    @staticmethod
    def validateDetectPolicyNotEmpty(detectPolicy: HandlerDetectPolicy):
        """
        Validate non-empty detect policy.

        Raises:
            ValueError: if neither face nor body detection is enabled, or if a non-default multiface
                policy is set without face detection
        """
        if not detectPolicy.detectFace and not detectPolicy.detectBody:
            raise ValueError("At least one of *detect_face* or *detect_body* should be equal to 1")
        if not detectPolicy.detectFace and MultifacePolicy(detectPolicy.multifacePolicy) is not MultifacePolicy.allowed:
            raise ValueError("*detect_face* should be equal to 1 to set *multiface_policy* to 0 or 2")

    @staticmethod
    def validateDetectAndExtractCompatibility(detectPolicy: HandlerDetectPolicy, extractPolicy: HandlerExtractPolicy):
        """
        Validate detect and extract policies compatibility: face extractions require face detection,
        body extractions require body detection.

        Raises:
            ValueError: on the first incompatible extraction option
        """
        if extractPolicy.extractFaceDescriptor:
            if not detectPolicy.detectFace:
                raise ValueError("*detect_face* should be equal to 1 to enable *extract_face_descriptor*")
        if extractPolicy.extractBasicAttributes:
            if not detectPolicy.detectFace:
                raise ValueError("*detect_face* should be equal to 1 to enable *extract_basic_attributes*")
        if extractPolicy.extractBodyDescriptor:
            if not detectPolicy.detectBody:
                raise ValueError("*detect_body* should be equal to 1 to enable *extract_body_descriptor*")

    @root_validator(skip_on_failure=True)
    def validatePolicies(cls, values):
        """
        Execute all compatibility validators.

        Validators run one after another; the first failing one raises and stops validation.
        """
        detectPolicy = values["detectPolicy"]
        matchPolicies = values["matchPolicy"]
        extractPolicy = values["extractPolicy"]
        conditionalTagsPolicies = values["conditionalTagsPolicy"]
        storagePolicy = values["storagePolicy"]
        cls.validateDetectPolicyNotEmpty(detectPolicy=detectPolicy)
        cls.validateDetectAndExtractCompatibility(detectPolicy=detectPolicy, extractPolicy=extractPolicy)
        cls.validateMatchPolicyUniqueLabels(matchPolicies=matchPolicies)
        cls.validateMatchAndExtractCompatibility(matchPolicies=matchPolicies, extractPolicy=extractPolicy)
        cls.validateMatchLabelsCompatibility(
            matchPolicies=matchPolicies, conditionalTagsPolicies=conditionalTagsPolicies, storagePolicy=storagePolicy
        )
        cls.validateGeneratedAttributesFilters(
            detectPolicy=detectPolicy,
            extractPolicy=extractPolicy,
            matchPolicies=matchPolicies,
            conditionalTagsPolicies=conditionalTagsPolicies,
            storagePolicy=storagePolicy,
        )
        return values

    def getListIdsFromPolicies(self) -> List[UUID]:
        """
        Get list ids from matching and link to lists policies.

        Returns:
            list ids of face link-to-lists policies followed by non-empty candidate list ids of match
            policies (may contain duplicates)
        """
        candidateLists = list(
            filter(None, (match.candidates.listId for match in self.matchPolicy if hasattr(match.candidates, "listId")))
        )
        return [linkPolicy.listId for linkPolicy in self.storagePolicy.facePolicy.linkToListsPolicy] + candidateLists

    async def checkListsAvailability(self, luna3Client: Client, accountId: str) -> None:
        """
        Check availability of lists from matching and link to list policies.

        Args:
            luna3Client: luna platform client
            accountId: account id
        Raises:
            VLException(Error.ListNotFound.format(listId), 400, False), if some list is not found
            LunaApiException: any other error reported by the faces service
        """
        # the same list may be referenced by several policies: check each one only once (order preserved)
        for listId in dict.fromkeys(self.getListIdsFromPolicies()):
            try:
                await luna3Client.lunaFaces.checkList(listId=str(listId), accountId=accountId, raiseError=True)
            except LunaApiException as e:
                if e.statusCode == 404:
                    raise VLException(Error.ListNotFound.format(listId), 400, isCriticalError=False) from e
                raise

    @classmethod
    async def onStartup(cls):
        """Init Policies: create the async runner used to publish events to plugins."""
        cls.pluginsAsyncRunner = AsyncRunner(PLUGINS_PUBLISHING_CONCURRENCY, closeTimeout=1)

    @classmethod
    async def onShutdown(cls):
        """Stop Policies: close the plugins async runner if `onStartup` has created it."""
        runner = getattr(cls, "pluginsAsyncRunner", None)
        if runner is not None:
            await runner.close()

    def prepareSDKTask(self, sdkData: List[Union[SDKDetectableImage, FaceWarp, HumanWarp]], aggregate: int) -> SDKTask:
        """
        Prepare sdk task.

        Args:
            sdkData: a list of input images or warps
            aggregate: aggregate all extracted samples to one or not
        Returns:
            sdk task with estimation targets and filters built from detect/extract policies
        """
        faceTargets = SDKFaceEstimationTargets(
            estimateQuality=self.detectPolicy.estimateQuality,
            estimateMouthAttributes=self.detectPolicy.estimateMouthAttributes,
            estimateAGS=0,
            estimateGaze=self.detectPolicy.estimateGaze,
            estimateEyesAttributes=self.detectPolicy.estimateEyesAttributes,
            estimateEmotions=self.detectPolicy.estimateEmotions,
            estimateMask=self.detectPolicy.estimateMask,
            estimateHeadPose=self.detectPolicy.estimateHeadPose,
            estimateBasicAttributes=self.extractPolicy.extractBasicAttributes,
            estimateFaceDescriptor=self.extractPolicy.extractFaceDescriptor,
            estimateLiveness=self.detectPolicy.estimateLiveness.asSDKPolicy(),
        )
        humanTargets = SDKHumanEstimationTargets(estimateHumanDescriptor=self.extractPolicy.extractBodyDescriptor)
        estimateHuman = (
            self.detectPolicy.detectBody or self.extractPolicy.extractBodyDescriptor or not humanTargets.isEmpty()
        )
        toEstimate = SDKEstimationTargets(
            estimateHuman=estimateHuman,
            humanEstimationTargets=humanTargets,
            estimateFace=self.detectPolicy.detectFace,
            faceEstimationTargets=faceTargets,
        )

        def suitFilter(x):
            """Return useful thresholds: a value equal to MAX_ANGLE is replaced with None (no filter)."""
            if x != MAX_ANGLE:
                return x
            return None

        maskStates = None
        if self.detectPolicy.maskStates:
            maskStates = [MaskState(x) for x in self.detectPolicy.maskStates]
        livenessStates = None
        if self.detectPolicy.livenessStates is not None:
            livenessStates = [LivenessPrediction(Liveness(x).name) for x in self.detectPolicy.livenessStates]
        filters = SDKTaskFilters(
            yawThreshold=suitFilter(self.detectPolicy.yawThreshold),
            pitchThreshold=suitFilter(self.detectPolicy.pitchThreshold),
            rollThreshold=suitFilter(self.detectPolicy.rollThreshold),
            livenessStates=livenessStates,
            garbageScoreThreshold=self.extractPolicy.fdScoreThreshold,
            maskStates=maskStates,
        )
        return SDKTask(
            toEstimate,
            data=sdkData,
            filters=filters,
            aggregateAttributes=bool(aggregate),
            multifacePolicy=MultifacePolicy(self.detectPolicy.multifacePolicy),
        )

    def publishEventsToPlugins(
        self,
        events: List[Event],
        handlerId: str,
        createEventTime: str,
        endEventTime: str,
        plugins: PluginManager,
        logger: Logger,
        accountId: str,
    ) -> None:
        """
        Publish events to other services.

        The publication is handed to `pluginsAsyncRunner.runNoWait` and is not awaited here
        (presumably fire-and-forget - verify against AsyncRunner).

        Args:
            events: list of events
            handlerId: handler id
            createEventTime: event creation time
            endEventTime: event end time
            plugins: plugin manager
            logger: logger
            accountId: account id
        """
        futures = [
            plugins.sendEventToPlugins(
                "sending_event", events, handlerId, accountId, requestIdCtx.get(), createEventTime, endEventTime, logger
            )
        ]
        self.pluginsAsyncRunner.runNoWait(futures)

    async def execute(
        self,
        accountId: str,
        inputData: List[Union[SDKDetectableImage, FaceWarp, HumanWarp, RawDescriptorData]],
        sdkLoop: SDKTaskLoop,
        luna3Client: Client,
        aggregate: int,
        userData: str,
        externalId: Optional[str],
        source: str,
        userDefinedTags: List[str],
        location: Location,
        logger: Logger,
        createEventTime: str,
        endEventTime: str,
        handlerId: str,
        trackId: Optional[str],
        plugins: PluginManager,
        redisContext: RedisContext,
        facesBucket: str,
        bodiesBucket: str,
        originBucket: str,
        lunaEventsUsage: bool,
        lunaSenderUsage: bool,
    ) -> Tuple[List[dict], List[Event], Dict[str, List[dict]], DataForMonitoring]:
        """
        Execute all policies for handler.

        Steps: detect + extract (sdk for images/warps, separate handling for raw descriptors), matching,
        conditional tags, user data enrichment, storage, publishing to plugins.

        Args:
            accountId: account id
            inputData: a list of input images (or warps) or descriptors
            sdkLoop: sdk loop
            luna3Client: luna platform client
            aggregate: aggregate all extracted samples to one or not
            userData: user data for created faces
            externalId: external id for created faces
            source: user-defined source
            userDefinedTags: user-defined tags
            location: user-defined location
            logger: logger
            createEventTime: event creation time
            endEventTime: event end time
            handlerId: handler id
            trackId: event track id
            plugins: plugin manager
            redisContext: redis context
            facesBucket: faces sample bucket
            bodiesBucket: bodies sample bucket
            originBucket: origin image bucket
            lunaEventsUsage: luna events usage
            lunaSenderUsage: luna sender usage
        Returns:
            tuple of four elements: images reply data (see `getImagesReplyData`), events,
            filtered detections, monitoring data
        Raises:
            VLException(Error.AggregationNotSupported) if `aggregate` flag enabled and raw descriptor data received
        """
        # detect + extract
        # results are placed back at the positions of the corresponding input items
        processedDataSources: List[Optional[Union[RawImageResult, RawDescriptorResult]]] = [None] * len(inputData)
        descriptorIdxData = list(filter(lambda x: isinstance(x[1], RawDescriptorData), enumerate(inputData)))
        rawDescriptorIdxs, rawDescriptorData = zip(*descriptorIdxData) if descriptorIdxData else ([], [])
        rawImageIdxData = list(filter(lambda x: not isinstance(x[1], RawDescriptorData), enumerate(inputData)))
        rawImageIdxs, sdkData = zip(*rawImageIdxData) if rawImageIdxData else ([], [])
        if aggregate and rawDescriptorData:
            raise VLException(Error.AggregationNotSupported, 400, isCriticalError=False)

        sdkAdapter = APISDKHandlerAdaptor(logger=logger, accountId=accountId, sdkLoop=sdkLoop)
        if sdkData:
            task = self.prepareSDKTask(sdkData, aggregate=aggregate)
            # 68 landmarks are requested only for source images, not for warps
            detectLandmarks68 = self.detectPolicy.detectLandmarks68 and task.source == TaskDataSource.images
            events, filteredDetections, processedImages, monitoringData = await sdkAdapter.handle(
                task=task, detectLandmarks68=detectLandmarks68
            )
            for index in rawImageIdxs:
                processedDataSources[index] = processedImages.pop(0)
        else:
            events, filteredDetections = [], sdkAdapter.prepareFilteredDetections()
            monitoringData = DataForMonitoring()

        rawDescriptorEvents, rawDescriptorProcessedImages = sdkAdapter.handleRawDescriptors(rawDescriptorData)
        events += rawDescriptorEvents
        for index in rawDescriptorIdxs:
            processedDataSources[index] = rawDescriptorProcessedImages.pop(0)

        # matching: all match policies run concurrently
        with monitorTime(monitoringData.request, "match_policy_time"):
            await asyncio.gather(
                *[matchByListPolicy.execute(events, luna3Client) for matchByListPolicy in self.matchPolicy]
            )

        # tags
        for policy in self.conditionalTagsPolicy:
            policy.execute(events)

        # user data
        self.enrichEventsWithMetadata(
            events=events,
            source=source,
            tags=userDefinedTags,
            userData=userData,
            externalId=externalId,
            location=location,
            trackId=trackId,
        )

        # storage
        monitoringData += await self.storagePolicy.execute(
            events=events,
            accountId=accountId,
            luna3Client=luna3Client,
            originImages=sdkData,
            userData=userData,
            externalId=externalId,
            logger=logger,
            createEventTime=createEventTime,
            endEventTime=endEventTime,
            handlerId=handlerId,
            facesBucket=facesBucket,
            bodiesBucket=bodiesBucket,
            originBucket=originBucket,
            lunaEventsUsage=lunaEventsUsage,
            redisContext=redisContext,
            lunaSenderUsage=lunaSenderUsage,
        )

        self.publishEventsToPlugins(
            events=events,
            handlerId=handlerId,
            createEventTime=createEventTime,
            endEventTime=endEventTime,
            plugins=plugins,
            logger=logger,
            accountId=accountId,
        )
        imagesReply = getImagesReplyData(
            sdkData=sdkData, images=processedDataSources, extractExif=self.detectPolicy.extractExif
        )
        return imagesReply, events, filteredDetections, monitoringData