Source code for luna_handlers.classes.schemas.match_policy

"""
Module contains schemas for match policy
"""
from enum import Enum
from typing import Any, List, Optional, Union

from luna3.client import Client
from luna3.common.luna_response import LunaResponse
from luna3.python_matcher.match_objects import (
    AttributeFilters,
    Candidates,
    EventFilters,
    FaceFilters,
    RawDescriptorReference,
)
from pydantic import AfterValidator, Field, StrictInt, WrapValidator, conlist, field_validator
from typing_extensions import Annotated, Literal
from vlutils.helpers import convertToSnakeCase

from classes.event import HandlerEvent as Event
from classes.schemas import types
from classes.schemas.base_schema import BaseSchema
from classes.schemas.filters import AttributesFilters
from classes.schemas.types import MAX_ITEMS_LIST_LENGTH, UUIDObject
from crutches_on_wheels.cow.errors.errors import ErrorInfo
from crutches_on_wheels.cow.errors.exception import VLException
from crutches_on_wheels.cow.errors.pydantic_errors import PydanticError
from crutches_on_wheels.cow.maps.vl_maps import (
    CLOTHING_COLOR_MAP,
    DEEPFAKE_MAP,
    EMOTION_MAP,
    ETHNIC_MAP,
    EVENT_TARGET_MAP,
    FACE_TARGET_MAP,
    LIVENESS_MAP,
    LOWER_GARMENT_MAP,
    MASK_MAP,
    PRIMITIVE_CLOTHING_COLOR_MAP,
    SLEEVE_LENGTH_MAP,
    TEMPORARY_ATTRIBUTE_MATCH_TARGET_MAP,
)
from crutches_on_wheels.cow.pydantic.bases import NullableFieldsModelMixin
from crutches_on_wheels.cow.pydantic.meta import Meta
from crutches_on_wheels.cow.pydantic.types import CustomExtendedDatetime, OptionalNotNullable, validateUniqueItems
from crutches_on_wheels.cow.pydantic.validators import excludeDiscriminatorFromException
from crutches_on_wheels.cow.utils.functions import getAnnotationWithNullable
from crutches_on_wheels.cow.utils.timer import timer


class OriginEnum(str, Enum):
    """
    Match candidates origin.

    String-valued enumeration: every member is also a ``str`` equal to its own name.
    """

    # candidates are faces
    faces = "faces"
    # candidates are events
    events = "events"
    # candidates are (temporary) attributes
    attributes = "attributes"
class FaceMatchCandidates(BaseSchema):
    """Face match candidates"""

    # Every filter below is declared with an `OptionalNotNullable()` default;
    # only `origin` has no default.

    # candidates origin, fixed to "faces"
    origin: Literal[OriginEnum.faces.value]
    # face ids (non-empty list, at most MAX_ITEMS_LIST_LENGTH items)
    faceIds: conlist(UUIDObject, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # face account id
    accountId: UUIDObject = OptionalNotNullable()
    # face external ids (non-empty list, at most MAX_ITEMS_LIST_LENGTH items)
    externalIds: conlist(types.Str36, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # face user data
    userData: types.Str128 = OptionalNotNullable()
    # face create time, inclusive `gte` boundary
    createTimeGte: CustomExtendedDatetime = OptionalNotNullable()
    # face create time, exclusive `lt` boundary
    createTimeLt: CustomExtendedDatetime = OptionalNotNullable()
    # face id, inclusive `gte` boundary
    faceIdGte: UUIDObject = OptionalNotNullable()
    # face id, exclusive `lt` boundary
    faceIdLt: UUIDObject = OptionalNotNullable()
    # id of the list the face is linked to
    listId: UUIDObject = OptionalNotNullable()
class GeoPosition(BaseSchema):
    """Geo position: longitude and latitude with deltas"""

    # The two origin coordinates are mandatory; both deltas fall back to 0.01.

    # longitude of the origin point
    originLongitude: types.FloatLongitude
    # latitude of the origin point
    originLatitude: types.FloatLatitude
    # longitude delta applied around the origin
    longitudeDelta: types.FloatGeoDelta = 0.01
    # latitude delta applied around the origin
    latitudeDelta: types.FloatGeoDelta = 0.01
def _strictIntToList(value: StrictInt):
    """
    Convert pydantic StrictInt to list

    Args:
        value: object exposing integer `ge` and `le` attributes
    Returns:
        every integer from `value.ge` up to and including `value.le`, in ascending order
    """
    lowest, highest = value.ge, value.le
    # `range` excludes its stop value, hence the + 1 to keep `le` in the result
    return [*range(lowest, highest + 1)]
class EventMatchCandidates(NullableFieldsModelMixin, BaseSchema):
    """Event match candidates"""

    # Apart from `origin`, `gender` and `geoPosition`, every filter is declared
    # with an `OptionalNotNullable()` default.

    # candidates origin, fixed to "events"
    origin: Literal[OriginEnum.events.value]

    # --- identity and time filters ---
    # event ids (non-empty list)
    eventIds: conlist(UUIDObject, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event account id
    accountId: UUIDObject = OptionalNotNullable()
    # event id, inclusive `gte` boundary
    eventIdGte: UUIDObject = OptionalNotNullable()
    # event id, exclusive `lt` boundary
    eventIdLt: UUIDObject = OptionalNotNullable()
    # event create time, inclusive `gte` boundary
    createTimeGte: CustomExtendedDatetime = OptionalNotNullable()
    # event create time, exclusive `lt` boundary
    createTimeLt: CustomExtendedDatetime = OptionalNotNullable()
    # event end time, inclusive `gte` boundary
    endTimeGte: CustomExtendedDatetime = OptionalNotNullable()
    # event end time, exclusive `lt` boundary
    endTimeLt: CustomExtendedDatetime = OptionalNotNullable()
    # event handler ids (non-empty list)
    handlerIds: conlist(UUIDObject, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event external ids (non-empty list)
    externalIds: conlist(types.Str36, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()

    # --- top match filters ---
    # event top matching candidates label
    topMatchingCandidatesLabel: types.Str36 = OptionalNotNullable()
    # event top similar object ids (non-empty list)
    topSimilarObjectIds: conlist(UUIDObject, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event top similar external ids (non-empty list)
    topSimilarExternalIds: conlist(types.Str36, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event top similar object similarity, inclusive `gte` boundary
    topSimilarObjectSimilarityGte: types.StrictFloat01 = OptionalNotNullable()
    # event top similar object similarity, exclusive `lt` boundary
    topSimilarObjectSimilarityLt: types.StrictFloat01 = OptionalNotNullable()

    # --- face attribute filters ---
    # event age, inclusive `gte` boundary
    ageGte: types.IntAge = OptionalNotNullable()
    # event age, exclusive `lt` boundary
    ageLt: types.IntAge = OptionalNotNullable()
    # event gender; defaults to None
    gender: types.Int01 | None = None
    # event emotion list; items must be unique
    # NOTE(review): the "+ 1" in max_length presumably leaves room for a null item - confirm
    emotions: Annotated[
        list[getAnnotationWithNullable(_strictIntToList(types.IntEmotions))],
        AfterValidator(validateUniqueItems),
        Field(min_length=1, max_length=len(EMOTION_MAP) + 1),
    ] = OptionalNotNullable()
    # event mask list; items must be unique
    masks: Annotated[
        list[getAnnotationWithNullable(_strictIntToList(types.IntMasks))],
        AfterValidator(validateUniqueItems),
        Field(min_length=1, max_length=len(MASK_MAP) + 1),
    ] = OptionalNotNullable()
    # event liveness states; items must be unique
    liveness: Annotated[
        list[getAnnotationWithNullable(_strictIntToList(types.IntLiveness))],
        AfterValidator(validateUniqueItems),
        Field(min_length=1, max_length=len(LIVENESS_MAP) + 1),
    ] = OptionalNotNullable()
    # event deepfake states; items must be unique (allowed values are taken from DEEPFAKE_MAP values)
    deepfake: Annotated[
        list[getAnnotationWithNullable(DEEPFAKE_MAP.values())],
        AfterValidator(validateUniqueItems),
        Field(min_length=1, max_length=len(DEEPFAKE_MAP) + 1),
    ] = OptionalNotNullable()
    # event ethnic group list; items must be unique
    ethnicGroups: Annotated[
        list[getAnnotationWithNullable(_strictIntToList(types.IntEthnicities))],
        AfterValidator(validateUniqueItems),
        Field(min_length=1, max_length=len(ETHNIC_MAP) + 1),
    ] = OptionalNotNullable()

    # --- source and location filters ---
    # event face ids (non-empty list)
    faceIds: conlist(UUIDObject, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event user data
    userData: types.Str128 = OptionalNotNullable()
    # event sources (non-empty list)
    sources: conlist(types.Str36, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event stream ids (non-empty list)
    streamIds: conlist(UUIDObject, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event tags (non-empty list)
    tags: conlist(types.Str36, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event cities (non-empty list, items may be None)
    cities: conlist(types.Str36 | None, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event areas (non-empty list, items may be None)
    areas: conlist(types.Str36 | None, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event districts (non-empty list, items may be None)
    districts: conlist(types.Str36 | None, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event streets (non-empty list, items may be None)
    streets: conlist(types.Str36 | None, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event house numbers (non-empty list, items may be None)
    # NOTE(review): the only snake_case field name in this schema - confirm it is intentional
    house_numbers: conlist(types.Str36 | None, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event geo position; defaults to None
    geoPosition: GeoPosition | None = None
    # event track ids (non-empty list, items may be None)
    trackIds: conlist(types.Str36 | None, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()

    # --- body attribute filters ---
    # apparent age, exclusive `lt` boundary
    apparentAgeLt: types.IntAge = OptionalNotNullable()
    # apparent age, inclusive `gte` boundary
    apparentAgeGte: types.IntAge = OptionalNotNullable()
    # apparent gender; items must be unique
    apparentGender: Annotated[
        List[types.Int02 | None],
        AfterValidator(validateUniqueItems),
        Field(min_length=1),
    ] = OptionalNotNullable()
    # headwear states; items must be unique
    headwearStates: Annotated[
        List[types.Int02 | None],
        AfterValidator(validateUniqueItems),
        Field(min_length=1),
    ] = OptionalNotNullable()
    # sleeve lengths; items must be unique
    sleeveLengths: Annotated[
        List[getAnnotationWithNullable(SLEEVE_LENGTH_MAP)], AfterValidator(validateUniqueItems), Field(min_length=1)
    ] = OptionalNotNullable()
    # upper clothing colors; items must be unique
    upperClothingColors: Annotated[
        List[getAnnotationWithNullable(CLOTHING_COLOR_MAP)], AfterValidator(validateUniqueItems), Field(min_length=1)
    ] = OptionalNotNullable()
    # backpack states; items must be unique
    backpackStates: Annotated[
        List[types.Int02 | None],
        AfterValidator(validateUniqueItems),
        Field(min_length=1),
    ] = OptionalNotNullable()
    # lower garment colors; items must be unique
    lowerGarmentColors: Annotated[
        List[getAnnotationWithNullable(CLOTHING_COLOR_MAP)], AfterValidator(validateUniqueItems), Field(min_length=1)
    ] = OptionalNotNullable()
    # lower garment types; items must be unique
    lowerGarmentTypes: Annotated[
        List[getAnnotationWithNullable(LOWER_GARMENT_MAP)], AfterValidator(validateUniqueItems), Field(min_length=1)
    ] = OptionalNotNullable()
    # headwear apparent colors; items must be unique
    headwearApparentColors: Annotated[
        List[getAnnotationWithNullable(PRIMITIVE_CLOTHING_COLOR_MAP)],
        AfterValidator(validateUniqueItems),
        Field(min_length=1),
    ] = OptionalNotNullable()
    # shoes apparent colors; items must be unique
    shoesApparentColors: Annotated[
        List[getAnnotationWithNullable(PRIMITIVE_CLOTHING_COLOR_MAP)],
        AfterValidator(validateUniqueItems),
        Field(min_length=1),
    ] = OptionalNotNullable()

    # meta filters
    meta: Meta = OptionalNotNullable()
class AttributeMatchCandidates(BaseSchema):
    """Attribute match candidates"""

    # candidates origin, fixed to "attributes"
    origin: Literal[OriginEnum.attributes.value]
    # temporary attribute ids (at most MAX_ITEMS_LIST_LENGTH items)
    # NOTE(review): unlike the id lists of the face and event candidates, no `min_length` is set here,
    # so an empty list passes the length constraint - confirm this is intended
    attributeIds: conlist(UUIDObject, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # temporary attribute account id
    accountId: UUIDObject = OptionalNotNullable()
class MatchingDescriptor(BaseSchema):
    """Matching descriptor parameters"""

    # kind of descriptor to match by: either "face" or "body"; "face" when not given
    descriptorType: Literal["face", "body"] = "face"
class BaseMatchPolicy(BaseSchema):
    """Base match policy schema"""

    # NOTE: `candidates` is declared before `descriptor` and `targets` on purpose - both validators below
    # read the already validated candidates from the validation info.

    # matching label
    label: types.Str36 = ""
    # matching candidates; the concrete schema is selected by the `origin` field
    candidates: Annotated[
        FaceMatchCandidates | EventMatchCandidates | AttributeMatchCandidates,
        WrapValidator(excludeDiscriminatorFromException),
        Field(discriminator="origin"),
    ]
    # matching candidate descriptor parameters
    descriptor: MatchingDescriptor = MatchingDescriptor()
    # matching filters (decide which events take part in matching, see `getEventsReferences`)
    filters: AttributesFilters = AttributesFilters()
    # matching limit
    limit: Optional[types.IntMatchingLimit] = types.DEFAULT_MATCH_LIMIT
    # matching targets
    targets: Optional[List[str]] = OptionalNotNullable()
    # matching threshold
    threshold: Optional[types.StrictFloat01] = OptionalNotNullable()

    @field_validator("targets", mode="before")
    @classmethod
    def targetsValidator(cls, targets, values) -> Union[List[str], None]:
        """
        Targets validator, depends on matching candidates

        Args:
            targets: raw (not yet type-validated) targets value
            values: validation info; `values.data` holds the fields validated so far
        Returns:
            the input targets unchanged, or None when the candidates are not available
        Raises:
            PydanticError.PydanticValidationError: if any target is not allowed for the candidates origin
        """
        if targets is None:
            return targets
        if "candidates" not in values.data:
            # if validation fails on field (or that field is missing), it is not included in values
            return None
        candidatesType = values.data["candidates"].origin
        candidateTargetMap = {
            OriginEnum.faces.value: FACE_TARGET_MAP,
            OriginEnum.events.value: EVENT_TARGET_MAP,
            OriginEnum.attributes.value: TEMPORARY_ATTRIBUTE_MATCH_TARGET_MAP,
        }[candidatesType]
        # kept as a list (not a set): it is rendered into the error message and the raw input items
        # are compared by equality only
        expectedTargets = list(candidateTargetMap)
        unexpectedTargets = [row for row in targets if row not in expectedTargets]
        if unexpectedTargets:
            raise PydanticError.PydanticValidationError.format(
                f"Bad matching policy targets {targets}, allowed targets for '{candidatesType}' candidates are: {expectedTargets}",
            )()
        return targets

    @field_validator("descriptor")
    @classmethod
    def descriptorValidator(cls, descriptor, values) -> Optional[MatchingDescriptor]:
        """
        Candidate descriptor parameters validator, depends on matching candidates

        Args:
            descriptor: validated descriptor parameters
            values: validation info; `values.data` holds the fields validated so far
        Returns:
            the input descriptor unchanged, or None when the candidates are not available
        Raises:
            PydanticError.PydanticValidationError: if a non-face descriptor type is used with face candidates
        """
        if "candidates" not in values.data:
            # if validation fails on field (or that field is missing), it is not included in values
            return None
        candidatesOrigin = values.data["candidates"].origin
        if candidatesOrigin == OriginEnum.faces.value and descriptor.descriptorType != "face":
            raise PydanticError.PydanticValidationError.format(
                f"Bad matching policy descriptor type '{descriptor.descriptorType}' for face candidates"
            )()
        return descriptor

    def getEventsReferences(self, events: list[Event]) -> list[RawDescriptorReference] | None:
        """
        Builds descriptor references list for events

        Events that do not satisfy the policy filters, or that carry no descriptor of the
        configured type, are skipped.

        Args:
            events: list of events that suppose to be matching
        Returns:
            List of descriptors references, None if there is nothing to match
        """
        descriptorType = self.descriptor.descriptorType
        references = []
        for event in events:
            if not self.filters.isEventSatisfies(event):
                continue
            if descriptorType == "face" and event.faceDescriptor:
                refDescriptor = event.faceDescriptor
            elif descriptorType == "body" and event.bodyDescriptor:
                refDescriptor = event.bodyDescriptor
            else:
                # no descriptor of the requested type on this event
                continue
            references.append(RawDescriptorReference(event.raw["event_id"], refDescriptor))
        return references or None

    def prepareCandidatesFilters(self) -> dict[str, Any]:
        """
        Builds filter data from candidates' fields, skipping unset ones

        Returns:
            Dict with actual data from matching candidate
        """
        candidatesFilters = self.candidates.asDict()
        if candidatesFilters.get("geo_position") is not None:
            candidatesFilters["geo_position"] = convertToSnakeCase(candidatesFilters["geo_position"])
        return candidatesFilters

    def prepareMatchPayload(
        self, candidatesFilters: dict[str, Any], references: list[RawDescriptorReference]
    ) -> dict[str, list[Candidates | RawDescriptorReference]]:
        """
        Builds payload with all the necessary attributes for matching request

        Args:
            candidatesFilters: filter built from candidate's data, must contain the "origin" key
            references: list of descriptor references built from events that suppose to be matching
        Returns:
            Payload dictionary with candidates and references lists
        """
        # work on a shallow copy so that the caller's dict keeps its "origin" key
        filters = dict(candidatesFilters)
        origin = filters.pop("origin")
        ObjectFilters = {"faces": FaceFilters, "events": EventFilters, "attributes": AttributeFilters}[origin]
        # explicitly requested targets are always extended with "similarity"
        targets = [*self.targets, "similarity"] if self.targets is not None else None
        return {
            "candidates": [
                Candidates(
                    filters=ObjectFilters().initFromSnakeKwargs(filters),
                    targets=targets,
                    limit=self.limit,
                    threshold=self.threshold,
                )
            ],
            "references": references,
        }

    async def requestMatching(
        self, luna3Client: Client, matchPayload: dict[str, list[Candidates | RawDescriptorReference]]
    ) -> LunaResponse:
        """
        Sends request for matching

        Face descriptors are matched with `matchFaces`, any other descriptor type with `matchBodies`.

        Args:
            luna3Client: client for communication with matching service
            matchPayload: payload dictionary with candidates and references lists
        Returns:
            Result of matching
        """
        matcher = luna3Client.lunaPythonMatcher
        if self.descriptor.descriptorType == "face":
            return await matcher.matchFaces(**matchPayload, raiseError=True)
        return await matcher.matchBodies(**matchPayload, raiseError=True)

    def applyMatchingResultToEvents(self, events: list[Event], matchingResult: LunaResponse) -> None:
        """
        Unpack response data and updates events

        Args:
            events: list of events that suppose to be matching
            matchingResult: result of matching
        Raises:
            VLException: if any match in the response carries an error
        """
        eventsMap = {event.raw["event_id"]: event for event in events}
        for match in matchingResult.json:
            for singleMatch in match["matches"]:
                err = singleMatch.get("error")
                if err is not None:
                    raise VLException(ErrorInfo.fromDict(err))
            # the payload carries exactly one candidates entry (see `prepareMatchPayload`),
            # so only the first item of "matches" is read
            if event := eventsMap.get(match["reference"]["id"]):
                event.raw["matches"].append({"label": self.label, "candidates": match["matches"][0]["result"]})

    @timer
    async def execute(self, events: List[Event], luna3Client: Client) -> None:
        """
        Execute match policy:

        * filter events for matching
        * match events' descriptors with policy candidates

        Results are stored in the input events.

        Args:
            events: events (references)
            luna3Client: client
        """
        if (references := self.getEventsReferences(events)) is None:
            # nothing to match
            return
        candidatesFilters = self.prepareCandidatesFilters()
        matchPayload = self.prepareMatchPayload(candidatesFilters, references)
        matchingResult = await self.requestMatching(luna3Client, matchPayload)
        self.applyMatchingResultToEvents(events, matchingResult)
class MatchPolicy(BaseMatchPolicy):
    """Match policy schema"""

    # matching candidates, redeclared to accept only face and event candidates;
    # the concrete schema is still selected by the `origin` field
    candidates: Annotated[
        FaceMatchCandidates | EventMatchCandidates,
        WrapValidator(excludeDiscriminatorFromException),
        Field(discriminator="origin"),
    ]