"""
Module contains schemas for policies
"""
import asyncio
from typing import Any, Optional, Type, TypeVar, Union
from uuid import UUID
import cachetools
from luna3.client import Client
from luna3.common.luna_response import LunaResponse
from luna3.remote_sdk import http_objs as sdk
from luna_plugins.base.manager import PluginManager
from pydantic import Field, ValidationError, model_validator
from vlutils.cache.cache import cache
from app.global_vars.context_vars import requestIdCtx
from classes.event import HandlerEvent as Event
from classes.schemas.base_schema import BaseSchema
from classes.schemas.conditional_tags_policy import ConditionalTagsPolicy
from classes.schemas.detect_policies import HandlerDetectPolicy
from classes.schemas.extract_policies import HandlerExtractPolicy
from classes.schemas.filters import AttributesFilters, MatchFilter
from classes.schemas.match_policy import MatchPolicy
from classes.schemas.storage_policy import StoragePolicy, StorePolicyConfig
from classes.schemas.types import MAX_POLICY_LIST_LENGTH
from crutches_on_wheels.cow.errors.errors import Error
from crutches_on_wheels.cow.errors.exception import VLException
from crutches_on_wheels.cow.errors.pydantic_errors import PydanticError
from crutches_on_wheels.cow.monitoring.points import DataForMonitoring, monitorTime
from crutches_on_wheels.cow.utils.healthcheck import logger
from redis_db.redis_context import RedisContext
# type variable of the objects collected by `getObjectRecursively`
T = TypeVar("T")
# short-lived cache (up to 256 entries, 10 seconds TTL) used by `checkListAvailability`
# to keep luna-faces list availability responses
cacheTTL = cachetools.TTLCache(maxsize=256, ttl=10)
def getObjectRecursively(data: Any, expectedType: Type[T]) -> list[T]:
    """
    Recursively collect all objects of the expected type.

    Walks `data` depth-first: lists are traversed element by element and `BaseSchema` instances are
    traversed field by field (through their `__dict__`). Every visited object that is an instance of
    `expectedType` is collected, including the root itself and schemas that are also traversed further.

    Args:
        data: object to search in (a list, a schema or any other value)
        expectedType: type of the objects to collect
    Returns:
        found objects in traversal order
    """
    found: list[T] = []

    def collectObjects(dataPart: Any) -> None:
        """Append `dataPart` to `found` if it has the expected type, then descend into its children."""
        if isinstance(dataPart, expectedType):
            found.append(dataPart)
        # plain loops: the traversal is performed only for its side effect on `found`
        if isinstance(dataPart, list):
            for row in dataPart:
                collectObjects(row)
        if isinstance(dataPart, BaseSchema):
            for row in dataPart.__dict__.values():
                collectObjects(row)

    collectObjects(data)
    return found
# NOTE(review): the cache key is the hash of the concatenated ids (the client is ignored), so two
# (listId, accountId) pairs with equal concatenations or colliding hashes would share a cache entry.
@cache(lambda: cacheTTL, keyGen=lambda luna3Client, listId, accountId: hash(f"{listId}{accountId}"))
async def checkListAvailability(luna3Client: Client, listId: str, accountId: Optional[str] = None) -> LunaResponse:
    """
    Ask luna-faces whether a list is available; responses are cached in `cacheTTL`.

    Args:
        luna3Client: luna platform client
        listId: list id
        accountId: optional account id, passed to luna-faces as is
    Returns:
        luna-faces response of the list check request
    """
    facesClient = luna3Client.lunaFaces
    return await facesClient.checkList(listId=listId, accountId=accountId)
class Policies(BaseSchema):
    """
    Policies schema.

    Aggregates the detect, extract, matching, conditional tags and storage policies of a handler,
    cross-validates them after model construction (see `validatePolicies`) and executes them for events.
    """

    # detect policy
    detectPolicy: HandlerDetectPolicy = Field(default_factory=lambda: HandlerDetectPolicy())
    # extract policy
    extractPolicy: HandlerExtractPolicy = HandlerExtractPolicy()
    # matching policy list (`max_length` is the pydantic v2 name of the deprecated `max_items` constraint)
    matchPolicy: list[MatchPolicy] = Field([], max_length=MAX_POLICY_LIST_LENGTH)
    # conditional tags policy list
    conditionalTagsPolicy: list[ConditionalTagsPolicy] = Field([], max_length=MAX_POLICY_LIST_LENGTH)
    # storage policy
    storagePolicy: StoragePolicy = Field(default_factory=lambda: StoragePolicy())

    @staticmethod
    def validateMatchAndExtractCompatibility(matchPolicies: list[MatchPolicy], extractPolicy: HandlerExtractPolicy):
        """
        Validate match and extract policies compatibility.

        Args:
            matchPolicies: matching policies
            extractPolicy: extract policy
        Raises:
            error built by `PydanticError.PydanticValidationError`: if a face (body) matching policy is
                used while face (body) descriptor extraction is disabled
        """
        if any(policy.descriptor.descriptorType == "face" for policy in matchPolicies) and (
            not extractPolicy.extractFaceDescriptor
        ):
            raise PydanticError.PydanticValidationError.format(
                "extract_face_descriptor should be equal to 1 for using face matching policy"
            )()
        if any(policy.descriptor.descriptorType == "body" for policy in matchPolicies) and (
            not extractPolicy.extractBodyDescriptor
        ):
            raise PydanticError.PydanticValidationError.format(
                "extract_body_descriptor should be equal to 1 for using body matching policy"
            )()

    @staticmethod
    def validateMatchLabelsCompatibility(
        matchPolicies: list[MatchPolicy],
        conditionalTagsPolicies: list[ConditionalTagsPolicy],
        storagePolicy: StoragePolicy,
    ):
        """
        Validate that every match filter refers to a label of some matching policy.

        Args:
            matchPolicies: matching policies
            conditionalTagsPolicies: conditional tags policies (searched for match filters)
            storagePolicy: storage policy (searched for match filters)
        Raises:
            error built by `PydanticError.PydanticValidationError`: for a match filter label absent
                from the matching policies
        """
        matchPolicyMatchingLabels = {matchPolicy.label for matchPolicy in matchPolicies}
        matchFiltersMatchingLabels = {
            matchFilter.label
            for matchFilter in getObjectRecursively([conditionalTagsPolicies, storagePolicy], MatchFilter)
        }
        for matchFilterLabel in matchFiltersMatchingLabels:
            if matchFilterLabel not in matchPolicyMatchingLabels:
                raise PydanticError.PydanticValidationError.format(
                    f'"{matchFilterLabel}" should be in match policy for filtration based on a matching by this label',
                )()

    @staticmethod
    def validateGeneratedAttributesFilters(
        detectPolicy: HandlerDetectPolicy,
        extractPolicy: HandlerExtractPolicy,
        matchPolicies: list[MatchPolicy],
        conditionalTagsPolicies: list[ConditionalTagsPolicy],
        storagePolicy: StoragePolicy,
    ):
        """
        Validate that attribute filters only use attributes the detect/extract policies produce.

        Args:
            detectPolicy: detect policy
            extractPolicy: extract policy
            matchPolicies: matching policies (searched for attribute filters)
            conditionalTagsPolicies: conditional tags policies (searched for attribute filters)
            storagePolicy: storage policy (searched for attribute filters)
        Raises:
            error built by `PydanticError.PydanticValidationError`: if a filter needs basic attributes,
                liveness or deepfake estimation that is disabled
        """
        attributeFilters: list[AttributesFilters] = getObjectRecursively(
            data=[matchPolicies, conditionalTagsPolicies, storagePolicy], expectedType=AttributesFilters
        )
        if not extractPolicy.extractBasicAttributes:
            basicAttributesNeeded = any(
                any(
                    (
                        filters.ethnicities is not None,
                        filters.ageLt is not None,
                        filters.ageGte is not None,
                        filters.gender is not None,
                    )
                )
                for filters in attributeFilters
            )
            if basicAttributesNeeded:
                raise PydanticError.PydanticValidationError.format(
                    "extract_basic_attributes should be equal to 1 for filtration based on basic attributes",
                )()
        if not detectPolicy.estimateLiveness.estimate:
            livenessNeeded = any((filters.liveness is not None) for filters in attributeFilters)
            if livenessNeeded:
                raise PydanticError.PydanticValidationError.format(
                    "estimate_liveness.estimate should be equal to 1 for filtration based on liveness"
                )()
        if not detectPolicy.estimateDeepfake.estimate:
            deepfakeNeeded = any((filters.deepfake is not None) for filters in attributeFilters)
            if deepfakeNeeded:
                # same error type as every other policy validator (was a bare ValueError)
                raise PydanticError.PydanticValidationError.format(
                    "estimate_deepfake.estimate should be equal to 1 for filtration based on deepfake"
                )()

    @staticmethod
    def validateMatchPolicyUniqueLabels(matchPolicies: list[MatchPolicy]):
        """
        Validate that matching policy labels are unique.

        Args:
            matchPolicies: matching policies
        Raises:
            ValidationError: located at `match_policy` if some label is repeated
        """
        labels = [matchPolicy.label for matchPolicy in matchPolicies]
        if len(labels) != len(set(labels)):
            error = PydanticError.PydanticValidationError.format("Matching allowed only by unique labels")()
            # NOTE(review): "input" is the MatchPolicy class rather than the offending value
            # (e.g. `labels`) — confirm whether this is intended.
            raise ValidationError.from_exception_data(
                title="Label uniqueness error",
                line_errors=[
                    {"type": error, "loc": ("match_policy",), "input": MatchPolicy, "ctx": {"labels": labels}}
                ],
            )

    @staticmethod
    def validateDetectPolicyNotEmpty(detectPolicy: HandlerDetectPolicy):
        """
        Validate that the detect policy detects something and uses multiface policy consistently.

        Args:
            detectPolicy: detect policy
        Raises:
            error built by `PydanticError.PydanticValidationError`: if neither faces nor bodies are
                detected, or `multifacePolicy` differs from 1 while face detection is disabled
        """
        if not detectPolicy.detectFace and not detectPolicy.detectBody:
            raise PydanticError.PydanticValidationError.format(
                "At least one of *detect_face* or *detect_body* should be equal to 1"
            )()
        if not detectPolicy.detectFace and detectPolicy.multifacePolicy != 1:
            raise PydanticError.PydanticValidationError.format(
                "*detect_face* should be equal to 1 to set *multiface_policy* to 0 or 2"
            )()

    @staticmethod
    def validateDetectAndExtractCompatibility(detectPolicy: HandlerDetectPolicy, extractPolicy: HandlerExtractPolicy):
        """
        Validate that every enabled extraction has the detection it depends on.

        Args:
            detectPolicy: detect policy
            extractPolicy: extract policy
        Raises:
            error built by `PydanticError.PydanticValidationError`: if a face (body) extraction is enabled
                without face (body) detection
        """
        if extractPolicy.extractFaceDescriptor and not detectPolicy.detectFace:
            raise PydanticError.PydanticValidationError.format(
                "*detect_face* should be equal to 1 to enable *extract_face_descriptor*"
            )()
        if extractPolicy.extractBasicAttributes and not detectPolicy.detectFace:
            raise PydanticError.PydanticValidationError.format(
                "*detect_face* should be equal to 1 to enable *extract_basic_attributes*"
            )()
        if extractPolicy.extractBodyDescriptor and not detectPolicy.detectBody:
            raise PydanticError.PydanticValidationError.format(
                "*detect_body* should be equal to 1 to enable *extract_body_descriptor*"
            )()

    @model_validator(mode="after")
    def validatePolicies(self):
        """
        Execute all compatibility validators on the constructed model.

        Written as an instance method, the documented form of a pydantic v2 "after" model validator.

        Returns:
            the validated model itself
        """
        self.validateDetectPolicyNotEmpty(detectPolicy=self.detectPolicy)
        self.validateDetectAndExtractCompatibility(detectPolicy=self.detectPolicy, extractPolicy=self.extractPolicy)
        self.validateMatchPolicyUniqueLabels(matchPolicies=self.matchPolicy)
        self.validateMatchAndExtractCompatibility(matchPolicies=self.matchPolicy, extractPolicy=self.extractPolicy)
        self.validateMatchLabelsCompatibility(
            matchPolicies=self.matchPolicy,
            conditionalTagsPolicies=self.conditionalTagsPolicy,
            storagePolicy=self.storagePolicy,
        )
        self.validateGeneratedAttributesFilters(
            detectPolicy=self.detectPolicy,
            extractPolicy=self.extractPolicy,
            matchPolicies=self.matchPolicy,
            conditionalTagsPolicies=self.conditionalTagsPolicy,
            storagePolicy=self.storagePolicy,
        )
        return self

    @staticmethod
    async def _checkListsAvailability(
        luna3Client: Client, listIds: list[UUID], accountId: Optional[Union[str, UUID]] = None
    ) -> None:
        """
        Check availability of the given lists one by one.

        Args:
            luna3Client: luna platform client
            listIds: list ids
            accountId: account id
        Raises:
            VLException(Error.ListNotFound.format(listId), 400, False): if luna-faces answers 404 for a list
            VLException(Error.UnknownServiceError...): with the reply status code for any other failure
        """
        for listId in listIds:
            reply = await checkListAvailability(
                luna3Client=luna3Client, accountId=str(accountId) if accountId is not None else None, listId=str(listId)
            )
            if reply.success:
                continue
            if reply.statusCode == 404:
                raise VLException(Error.ListNotFound.format(listId), 400, False)
            raise VLException(
                Error.UnknownServiceError.format(
                    "luna-faces", "HEAD", f"{luna3Client.lunaFaces.baseUri}/lists/{listId}"
                ),
                reply.statusCode,
                False,
            )

    async def _checkLinkingListsAvailability(self, luna3Client: Client, accountId: Optional[str] = None) -> None:
        """
        Check availability of lists from face link-to-lists policies.

        Args:
            luna3Client: luna platform client
            accountId: account id the lists are checked for
        """
        await self._checkListsAvailability(
            luna3Client=luna3Client,
            listIds=[linkPolicy.listId for linkPolicy in self.storagePolicy.facePolicy.linkToListsPolicy],
            accountId=accountId,
        )

    async def _checkMatchingListsAvailability(self, luna3Client: Client) -> None:
        """
        Check availability of candidate lists from matching policies.

        Policies whose candidates have no (or a None) `listId` are skipped; each list is checked
        with the account id of its own candidates.

        Args:
            luna3Client: luna platform client
        """
        for matchPolicy in self.matchPolicy:
            if (listId := getattr(matchPolicy.candidates, "listId", None)) is None:
                continue
            await self._checkListsAvailability(
                luna3Client=luna3Client,
                listIds=[listId],
                accountId=matchPolicy.candidates.accountId,
            )

    async def checkListsAvailability(self, luna3Client: Client, accountId: str) -> None:
        """
        Check availability of lists from matching and link to list policies.

        Args:
            luna3Client: luna platform client
            accountId: account id (used for link to list policies)
        Raises:
            VLException(Error.ListNotFound.format(listId), 400, False), if some list is not found
        """
        await self._checkMatchingListsAvailability(luna3Client=luna3Client)
        await self._checkLinkingListsAvailability(luna3Client=luna3Client, accountId=accountId)

    @property
    def estimator(self):
        """
        Build an sdk estimator request from the policies.

        Targets, detection filters and estimator params are derived from the detect, extract and storage
        policies; most face (body) targets are gated by `detectPolicy.detectFace` (`detectBody`) through
        multiplication, and face detection filters are dropped when face detection is disabled.

        Returns:
            `sdk.Estimator` built from the assembled params
        """
        estimateFaceQuality = self.detectPolicy.detectFace * self.detectPolicy.faceQuality.estimate
        targets = sdk.Targets(
            exif=self.detectPolicy.extractExif,
            peopleCount=self.detectPolicy.estimatePeopleCount or None,
            faceDetection=self.detectPolicy.detectFace,
            # NOTE(review): detectFace * detectFace reduces to detectFace; possibly a dedicated
            # landmarks5 flag was intended — confirm.
            faceLandmarks5=self.detectPolicy.detectFace * self.detectPolicy.detectFace,
            faceLandmarks68=self.detectPolicy.detectFace * self.detectPolicy.detectLandmarks68,
            faceWarp=self.detectPolicy.detectFace * self.storagePolicy.faceSamplePolicy.storeSample,
            gaze=self.detectPolicy.detectFace * self.detectPolicy.estimateGaze,
            # head pose is also required when any angle threshold filter is set
            headPose=(dp := self.detectPolicy).detectFace
            * (dp.estimateHeadPose | bool(dp.rollThreshold or dp.yawThreshold or dp.pitchThreshold)),
            eyes=self.detectPolicy.detectFace * self.detectPolicy.estimateEyesAttributes,
            mouthAttributes=self.detectPolicy.detectFace * self.detectPolicy.estimateMouthAttributes,
            faceWarpQuality=self.detectPolicy.detectFace * self.detectPolicy.estimateQuality,
            emotions=self.detectPolicy.detectFace * self.detectPolicy.estimateEmotions,
            # mask/liveness/deepfake are also required when filtering by their states
            mask=self.detectPolicy.detectFace * (self.detectPolicy.estimateMask | bool(self.detectPolicy.maskStates)),
            glasses=self.detectPolicy.detectFace * self.detectPolicy.estimateGlasses,
            liveness=self.detectPolicy.detectFace
            * (self.detectPolicy.estimateLiveness.estimate | bool(self.detectPolicy.livenessStates)),
            deepfake=self.detectPolicy.detectFace
            * (self.detectPolicy.estimateDeepfake.estimate | bool(self.detectPolicy.deepfakeStates)),
            faceQuality=self.detectPolicy.faceQuality if estimateFaceQuality else None,
            bodyDetection=self.detectPolicy.detectBody,
            bodyWarp=self.detectPolicy.detectBody * self.storagePolicy.bodySamplePolicy.storeSample,
            faceDescriptor=self.extractPolicy.extractFaceDescriptor,
            basicAttributes=self.extractPolicy.extractBasicAttributes,
            bodyDescriptor=self.extractPolicy.extractBodyDescriptor,
            bodyAttributes=self.detectPolicy.detectBody * self.detectPolicy.bodyAttributes.estimateBasicAttributes,
            upperBody=self.detectPolicy.detectBody * self.detectPolicy.bodyAttributes.estimateUpperBody,
            lowerBody=self.detectPolicy.detectBody * self.detectPolicy.bodyAttributes.estimateLowerBody,
            accessories=self.detectPolicy.detectBody * self.detectPolicy.bodyAttributes.estimateAccessories,
        )
        filters = sdk.Filters(
            faceDetectionFilters=sdk.FaceDetectionFilters(
                rollThreshold=self.detectPolicy.rollThreshold if self.detectPolicy.detectFace else None,
                yawThreshold=self.detectPolicy.yawThreshold if self.detectPolicy.detectFace else None,
                pitchThreshold=self.detectPolicy.pitchThreshold if self.detectPolicy.detectFace else None,
                livenessStates=self.detectPolicy.livenessStates if self.detectPolicy.detectFace else None,
                deepfakeStates=self.detectPolicy.deepfakeStates if self.detectPolicy.detectFace else None,
                maskStates=self.detectPolicy.maskStates if self.detectPolicy.detectFace else None,
                scoreThreshold=self.extractPolicy.fdScoreThreshold,
            )
        )
        estimationConfig = sdk.EstimatorsParams(
            livenessParams=sdk.LivenessParams(
                scoreThreshold=self.detectPolicy.estimateLiveness.livenessThreshold,
                qualityThreshold=self.detectPolicy.estimateLiveness.qualityThreshold,
            ),
            deepfakeParams=sdk.DeepfakeParams(
                realThreshold=self.detectPolicy.estimateDeepfake.realThreshold,
                mode=self.detectPolicy.estimateDeepfake.mode,
            ),
        )
        params = sdk.Params(
            targets=targets,
            filters=filters,
            estimatorsParams=estimationConfig,
            multifacePolicy=self.detectPolicy.multifacePolicy,
        )
        # NOTE(review): `images` is the Ellipsis literal — presumably a placeholder filled in later; confirm.
        return sdk.Estimator(params, images=...)

    def publishEventsToPlugins(self, events: list[Event], accountId: str, plugins: PluginManager) -> None:
        """
        Publish events to other services through the "sending_event" plugin hook.

        Handler id, create time and end time are taken from the first event; nothing is sent
        for an empty event list.

        Args:
            events: list of events
            accountId: account id
            plugins: plugin manager
        """
        if not events:
            return
        pluginName = "sending_event"
        requestId = requestIdCtx.get()
        handlerId = events[0].handlerId
        createTime = events[0].createTime
        endTime = events[0].endTime
        plugins.sendEventToPlugins(pluginName, events, handlerId, accountId, requestId, createTime, endTime, logger)

    async def execute(
        self,
        events: list[Event],
        accountId: str,
        monitoring: DataForMonitoring,
        config: StorePolicyConfig,
        luna3Client: Client,
        redisContext: RedisContext,
        plugins: PluginManager,
    ):
        """
        Execute all policies for handler.

        Matching policies run concurrently (timed as "match_policy_time"), then conditional tags
        policies, then the storage policy; finally events are published to plugins.

        Args:
            events: events
            accountId: account id
            monitoring: monitoring data
            config: handler configuration parameters
            luna3Client: luna platform client
            redisContext: redis context
            plugins: plugin manager
        """
        # matching
        with monitorTime(monitoring, "match_policy_time"):
            await asyncio.gather(
                *[matchByListPolicy.execute(events, luna3Client) for matchByListPolicy in self.matchPolicy]
            )
        # tags
        for policy in self.conditionalTagsPolicy:
            policy.execute(events)
        # storage
        await self.storagePolicy.execute(
            events=events,
            accountId=accountId,
            monitoring=monitoring,
            config=config,
            luna3Client=luna3Client,
            redisContext=redisContext,
        )
        self.publishEventsToPlugins(events=events, accountId=accountId, plugins=plugins)