"""
Module contains schemas for match policy
"""
from enum import Enum
from typing import Any, List, Optional, Union

from luna3.client import Client
from luna3.common.luna_response import LunaResponse
from luna3.python_matcher.match_objects import (
    AttributeFilters,
    Candidates,
    EventFilters,
    FaceFilters,
    RawDescriptorReference,
)
from pydantic import AfterValidator, Field, StrictInt, WrapValidator, conlist, field_validator
from typing_extensions import Annotated, Literal
from vlutils.helpers import convertToSnakeCase

from classes.event import HandlerEvent as Event
from classes.schemas import types
from classes.schemas.base_schema import BaseSchema
from classes.schemas.filters import AttributesFilters
from classes.schemas.types import MAX_ITEMS_LIST_LENGTH, UUIDObject
from crutches_on_wheels.cow.errors.errors import ErrorInfo
from crutches_on_wheels.cow.errors.exception import VLException
from crutches_on_wheels.cow.errors.pydantic_errors import PydanticError
from crutches_on_wheels.cow.maps.vl_maps import (
    CLOTHING_COLOR_MAP,
    DEEPFAKE_MAP,
    EMOTION_MAP,
    ETHNIC_MAP,
    EVENT_TARGET_MAP,
    FACE_TARGET_MAP,
    LIVENESS_MAP,
    LOWER_GARMENT_MAP,
    MASK_MAP,
    PRIMITIVE_CLOTHING_COLOR_MAP,
    SLEEVE_LENGTH_MAP,
    TEMPORARY_ATTRIBUTE_MATCH_TARGET_MAP,
)
from crutches_on_wheels.cow.pydantic.bases import NullableFieldsModelMixin
from crutches_on_wheels.cow.pydantic.meta import Meta
from crutches_on_wheels.cow.pydantic.types import CustomExtendedDatetime, OptionalNotNullable, validateUniqueItems
from crutches_on_wheels.cow.pydantic.validators import excludeDiscriminatorFromException
from crutches_on_wheels.cow.utils.functions import getAnnotationWithNullable
from crutches_on_wheels.cow.utils.timer import timer
class OriginEnum(str, Enum):
    """
    Origin of match candidates.

    Members are also plain strings, so they compare equal to the raw ``origin`` values.
    """

    # candidates are selected among faces
    faces = "faces"
    # candidates are selected among events
    events = "events"
    # candidates are selected among temporary attributes
    attributes = "attributes"
class FaceMatchCandidates(BaseSchema):
    """Face match candidates: filters selecting the faces to match against (origin ``faces``)."""

    # candidates origin - faces (discriminator value)
    origin: Literal[OriginEnum.faces.value]
    # face ids
    faceIds: conlist(UUIDObject, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # face account id
    accountId: UUIDObject = OptionalNotNullable()
    # face external ids
    externalIds: conlist(types.Str36, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # face user data
    userData: types.Str128 = OptionalNotNullable()
    # face create time lower boundary (inclusive, ">=")
    createTimeGte: CustomExtendedDatetime = OptionalNotNullable()
    # face create time upper boundary (exclusive, "<")
    createTimeLt: CustomExtendedDatetime = OptionalNotNullable()
    # face id lower boundary (inclusive, ">=")
    faceIdGte: UUIDObject = OptionalNotNullable()
    # face id upper boundary (exclusive, "<")
    faceIdLt: UUIDObject = OptionalNotNullable()
    # id of a list the faces are linked to
    listId: UUIDObject = OptionalNotNullable()
class GeoPosition(BaseSchema):
    """Geo position filter: origin coordinates together with longitude and latitude deltas."""

    # longitude of the origin point
    originLongitude: types.FloatLongitude
    # latitude of the origin point
    originLatitude: types.FloatLatitude
    # longitude delta relative to the origin, 0.01 unless specified
    longitudeDelta: types.FloatGeoDelta = Field(default=0.01)
    # latitude delta relative to the origin, 0.01 unless specified
    latitudeDelta: types.FloatGeoDelta = Field(default=0.01)
def _strictIntToList(value: StrictInt):
    """
    Enumerate every integer allowed by a bounded integer type.

    Args:
        value: bounded integer type (e.g. ``types.IntEmotions``) exposing ``ge`` and ``le`` bounds
    Returns:
        list of all integers from ``value.ge`` up to ``value.le``, both bounds included
    """
    lowerBound, upperBound = value.ge, value.le
    return [*range(lowerBound, upperBound + 1)]
class EventMatchCandidates(NullableFieldsModelMixin, BaseSchema):
    """Event match candidates: filters selecting the events to match against (origin ``events``)."""

    # candidates origin - events (discriminator value)
    origin: Literal[OriginEnum.events.value]
    # event ids
    eventIds: conlist(UUIDObject, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event account id
    accountId: UUIDObject = OptionalNotNullable()
    # event id lower boundary (inclusive, ">=")
    eventIdGte: UUIDObject = OptionalNotNullable()
    # event id upper boundary (exclusive, "<")
    eventIdLt: UUIDObject = OptionalNotNullable()
    # event create time lower boundary (inclusive, ">=")
    createTimeGte: CustomExtendedDatetime = OptionalNotNullable()
    # event create time upper boundary (exclusive, "<")
    createTimeLt: CustomExtendedDatetime = OptionalNotNullable()
    # event end time lower boundary (inclusive, ">=")
    endTimeGte: CustomExtendedDatetime = OptionalNotNullable()
    # event end time upper boundary (exclusive, "<")
    endTimeLt: CustomExtendedDatetime = OptionalNotNullable()
    # event handler ids
    handlerIds: conlist(UUIDObject, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event external ids
    externalIds: conlist(types.Str36, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event top matching candidates label
    topMatchingCandidatesLabel: types.Str36 = OptionalNotNullable()
    # event top similar object ids
    topSimilarObjectIds: conlist(UUIDObject, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event top similar external ids
    topSimilarExternalIds: conlist(types.Str36, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event top similar object similarity lower boundary (inclusive, ">=")
    topSimilarObjectSimilarityGte: types.StrictFloat01 = OptionalNotNullable()
    # event top similar object similarity upper boundary (exclusive, "<")
    topSimilarObjectSimilarityLt: types.StrictFloat01 = OptionalNotNullable()
    # event age lower boundary (inclusive, ">=")
    ageGte: types.IntAge = OptionalNotNullable()
    # event age upper boundary (exclusive, "<")
    ageLt: types.IntAge = OptionalNotNullable()
    # event gender
    gender: types.Int01 | None = None
    # event emotion list; items must be unique
    # NOTE(review): the "+ 1" in max_length presumably leaves room for a null item — confirm
    emotions: Annotated[
        list[getAnnotationWithNullable(_strictIntToList(types.IntEmotions))],
        AfterValidator(validateUniqueItems),
        Field(min_length=1, max_length=len(EMOTION_MAP) + 1),
    ] = OptionalNotNullable()
    # event mask list; items must be unique
    masks: Annotated[
        list[getAnnotationWithNullable(_strictIntToList(types.IntMasks))],
        AfterValidator(validateUniqueItems),
        Field(min_length=1, max_length=len(MASK_MAP) + 1),
    ] = OptionalNotNullable()
    # event liveness states; items must be unique
    liveness: Annotated[
        list[getAnnotationWithNullable(_strictIntToList(types.IntLiveness))],
        AfterValidator(validateUniqueItems),
        Field(min_length=1, max_length=len(LIVENESS_MAP) + 1),
    ] = OptionalNotNullable()
    # event deepfake states; items must be unique
    deepfake: Annotated[
        list[getAnnotationWithNullable(DEEPFAKE_MAP.values())],
        AfterValidator(validateUniqueItems),
        Field(min_length=1, max_length=len(DEEPFAKE_MAP) + 1),
    ] = OptionalNotNullable()
    # event ethnic group list; items must be unique
    ethnicGroups: Annotated[
        list[getAnnotationWithNullable(_strictIntToList(types.IntEthnicities))],
        AfterValidator(validateUniqueItems),
        Field(min_length=1, max_length=len(ETHNIC_MAP) + 1),
    ] = OptionalNotNullable()
    # event face ids
    faceIds: conlist(UUIDObject, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event user data
    userData: types.Str128 = OptionalNotNullable()
    # event sources
    sources: conlist(types.Str36, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event stream ids
    streamIds: conlist(UUIDObject, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event tags
    tags: conlist(types.Str36, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event cities (items may be null)
    cities: conlist(types.Str36 | None, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event areas (items may be null)
    areas: conlist(types.Str36 | None, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event districts (items may be null)
    districts: conlist(types.Str36 | None, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event streets (items may be null)
    streets: conlist(types.Str36 | None, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event house numbers (items may be null)
    # NOTE(review): snake_case unlike its siblings; renaming would likely change the accepted request key
    # (depends on BaseSchema alias configuration) — confirm before touching
    house_numbers: conlist(types.Str36 | None, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event geo position
    geoPosition: GeoPosition | None = None
    # event track ids (items may be null)
    trackIds: conlist(types.Str36 | None, min_length=1, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # apparent age upper boundary (exclusive, "<")
    apparentAgeLt: types.IntAge = OptionalNotNullable()
    # apparent age lower boundary (inclusive, ">=")
    apparentAgeGte: types.IntAge = OptionalNotNullable()
    # apparent gender list; items must be unique
    apparentGender: Annotated[
        List[types.Int02 | None],
        AfterValidator(validateUniqueItems),
        Field(min_length=1),
    ] = OptionalNotNullable()
    # headwear states; items must be unique
    headwearStates: Annotated[
        List[types.Int02 | None],
        AfterValidator(validateUniqueItems),
        Field(min_length=1),
    ] = OptionalNotNullable()
    # sleeve lengths; items must be unique
    sleeveLengths: Annotated[
        List[getAnnotationWithNullable(SLEEVE_LENGTH_MAP)], AfterValidator(validateUniqueItems), Field(min_length=1)
    ] = OptionalNotNullable()
    # upper clothing colors; items must be unique
    upperClothingColors: Annotated[
        List[getAnnotationWithNullable(CLOTHING_COLOR_MAP)], AfterValidator(validateUniqueItems), Field(min_length=1)
    ] = OptionalNotNullable()
    # backpack states; items must be unique
    backpackStates: Annotated[
        List[types.Int02 | None],
        AfterValidator(validateUniqueItems),
        Field(min_length=1),
    ] = OptionalNotNullable()
    # lower garment colors; items must be unique
    lowerGarmentColors: Annotated[
        List[getAnnotationWithNullable(CLOTHING_COLOR_MAP)], AfterValidator(validateUniqueItems), Field(min_length=1)
    ] = OptionalNotNullable()
    # lower garment types; items must be unique
    lowerGarmentTypes: Annotated[
        List[getAnnotationWithNullable(LOWER_GARMENT_MAP)], AfterValidator(validateUniqueItems), Field(min_length=1)
    ] = OptionalNotNullable()
    # headwear apparent colors; items must be unique
    headwearApparentColors: Annotated[
        List[getAnnotationWithNullable(PRIMITIVE_CLOTHING_COLOR_MAP)],
        AfterValidator(validateUniqueItems),
        Field(min_length=1),
    ] = OptionalNotNullable()
    # shoes apparent colors; items must be unique
    shoesApparentColors: Annotated[
        List[getAnnotationWithNullable(PRIMITIVE_CLOTHING_COLOR_MAP)],
        AfterValidator(validateUniqueItems),
        Field(min_length=1),
    ] = OptionalNotNullable()
    # meta filters
    meta: Meta = OptionalNotNullable()
class AttributeMatchCandidates(BaseSchema):
    """Temporary attribute match candidates: filters selecting the attributes to match against (origin ``attributes``)."""

    # candidates origin - attributes (discriminator value)
    origin: Literal[OriginEnum.attributes.value]
    # temporary attribute ids
    # NOTE(review): unlike the other id lists in this module there is no min_length=1,
    # so an empty list is accepted — confirm this is intended
    attributeIds: conlist(UUIDObject, max_length=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # temporary attribute account id
    accountId: UUIDObject = OptionalNotNullable()
class MatchingDescriptor(BaseSchema):
    """Parameters of the reference descriptor used for matching."""

    # which event descriptor is matched against the candidates: "face" (default) or "body"
    descriptorType: Literal["face", "body"] = Field(default="face")
class BaseMatchPolicy(BaseSchema):
    """
    Base match policy schema.

    Holds the candidates to match against together with the matching parameters, and implements
    the matching workflow (see ``execute``).
    """

    # matching label, stored together with the matching results in every matched event
    label: types.Str36 = ""
    # matching candidates, discriminated by their ``origin`` field
    candidates: Annotated[
        FaceMatchCandidates | EventMatchCandidates | AttributeMatchCandidates,
        WrapValidator(excludeDiscriminatorFromException),
        Field(discriminator="origin"),
    ]
    # matching candidate descriptor parameters
    descriptor: MatchingDescriptor = MatchingDescriptor()
    # filters an event must satisfy for its descriptor to be matched
    filters: AttributesFilters = AttributesFilters()
    # matching limit
    limit: Optional[types.IntMatchingLimit] = types.DEFAULT_MATCH_LIMIT
    # matching targets; must belong to the target map of the candidates origin
    targets: Optional[List[str]] = OptionalNotNullable()
    # matching threshold
    threshold: Optional[types.StrictFloat01] = OptionalNotNullable()

    @field_validator("targets", mode="before")
    def targetsValidator(cls, targets, values) -> Union[List[str], None]:
        """
        Validate targets against the target map of the chosen candidates origin.

        Args:
            targets: raw (not yet type-validated) targets value
            values: validation info; ``values.data`` holds the already validated fields
        Returns:
            targets unchanged, or None if candidates failed validation
        Raises:
            the error built by ``PydanticError.PydanticValidationError`` if a target is not allowed
            for the candidates origin
        """
        if targets is None:
            return targets
        if "candidates" not in values.data:
            # if validation fails on field (or that field is missing), it is not included in values
            return None
        candidatesType = values.data["candidates"].origin
        # `origin` holds enum values, so the map is keyed by values rather than member names
        candidateTargetMap = {
            OriginEnum.faces.value: FACE_TARGET_MAP,
            OriginEnum.events.value: EVENT_TARGET_MAP,
            OriginEnum.attributes.value: TEMPORARY_ATTRIBUTE_MATCH_TARGET_MAP,
        }[candidatesType]
        expectedTargets = list(candidateTargetMap)
        try:
            unexpectedTargets = [row for row in targets if row not in expectedTargets]
        except TypeError:
            # raw input is not iterable (e.g. a number): a TypeError would escape pydantic as-is,
            # so leave the value to the List[str] type validation which reports a proper error
            return targets
        if unexpectedTargets:
            raise PydanticError.PydanticValidationError.format(
                f"Bad matching policy targets {targets}, allowed targets for '{candidatesType}' candidates are: {expectedTargets}",
            )()
        return targets

    @field_validator("descriptor")
    def descriptorValidator(cls, descriptor, values) -> Optional[MatchingDescriptor]:
        """
        Validate descriptor parameters against the candidates origin.

        Face candidates can only be matched with face descriptors.

        Args:
            descriptor: validated descriptor parameters
            values: validation info; ``values.data`` holds the already validated fields
        Returns:
            descriptor unchanged, or None if candidates failed validation
        Raises:
            the error built by ``PydanticError.PydanticValidationError`` if a non-face descriptor
            type is requested for face candidates
        """
        if "candidates" not in values.data:
            # if validation fails on field (or that field is missing), it is not included in values
            return None
        if values.data["candidates"].origin == OriginEnum.faces.value and descriptor.descriptorType != "face":
            raise PydanticError.PydanticValidationError.format(
                f"Bad matching policy descriptor type '{descriptor.descriptorType}' for face candidates"
            )()
        return descriptor

    def getEventsReferences(self, events: list[Event]) -> list[RawDescriptorReference] | None:
        """
        Build descriptor references for the events that should be matched.

        An event is used when it satisfies the policy filters and has a descriptor of the
        configured type (face or body).

        Args:
            events: events that are supposed to be matched
        Returns:
            descriptor references (one per suitable event), or None if no event is suitable
        """
        references = []
        for event in events:
            if not self.filters.isEventSatisfies(event):
                continue
            if self.descriptor.descriptorType == "face" and event.faceDescriptor:
                refDescriptor = event.faceDescriptor
            elif self.descriptor.descriptorType == "body" and event.bodyDescriptor:
                refDescriptor = event.bodyDescriptor
            else:
                continue
            references.append(RawDescriptorReference(event.raw["event_id"], refDescriptor))
        return references or None

    def prepareCandidatesFilters(self) -> dict[str, Any]:
        """
        Build filter data from the candidates' fields.

        Returns:
            dict produced by ``self.candidates.asDict()``; a present ``geo_position`` value
            is additionally converted with ``convertToSnakeCase``
        """
        candidatesFilters = self.candidates.asDict()
        if candidatesFilters.get("geo_position") is not None:
            candidatesFilters["geo_position"] = convertToSnakeCase(candidatesFilters["geo_position"])
        return candidatesFilters

    def prepareMatchPayload(
        self, candidatesFilters: dict[str, Any], references: list[RawDescriptorReference]
    ) -> dict[str, list[Candidates | RawDescriptorReference]]:
        """
        Build the payload for the matching request.

        Args:
            candidatesFilters: filters built from the candidates' data; must contain "origin"
                (the dict itself is not modified)
            references: descriptor references built from the events that are supposed to be matched
        Returns:
            payload dictionary with "candidates" and "references" lists
        """
        filters = dict(candidatesFilters)  # work on a copy so the caller's dict is not mutated
        origin = filters.pop("origin")
        ObjectFilters = {
            OriginEnum.faces.value: FaceFilters,
            OriginEnum.events.value: EventFilters,
            OriginEnum.attributes.value: AttributeFilters,
        }[origin]
        targets = None
        if self.targets is not None:
            targets = list(self.targets)
            # similarity is always requested; do not duplicate it if the user already asked for it
            if "similarity" not in targets:
                targets.append("similarity")
        return {
            "candidates": [
                Candidates(
                    filters=ObjectFilters().initFromSnakeKwargs(filters),
                    targets=targets,
                    limit=self.limit,
                    threshold=self.threshold,
                )
            ],
            "references": references,
        }

    async def requestMatching(
        self, luna3Client: Client, matchPayload: dict[str, list[Candidates | RawDescriptorReference]]
    ) -> LunaResponse:
        """
        Send the matching request: face matching for "face" descriptors, body matching otherwise.

        Args:
            luna3Client: client for communication with the matching service
            matchPayload: payload dictionary with candidates and references lists
        Returns:
            result of matching
        """
        if self.descriptor.descriptorType == "face":
            return await luna3Client.lunaPythonMatcher.matchFaces(**matchPayload, raiseError=True)
        return await luna3Client.lunaPythonMatcher.matchBodies(**matchPayload, raiseError=True)

    def applyMatchingResultToEvents(self, events: list[Event], matchingResult: LunaResponse) -> None:
        """
        Unpack matching results and append them to the matched events.

        Every result is appended to ``event.raw["matches"]`` as ``{"label": ..., "candidates": ...}``.

        Args:
            events: events that are supposed to be matched
            matchingResult: result of matching
        Raises:
            VLException: if any match in the result carries an error; no event is modified then
        """
        # check for errors first so that events are never left partially updated
        for match in matchingResult.json:
            for singleMatch in match["matches"]:
                if (err := singleMatch.get("error")) is not None:
                    raise VLException(ErrorInfo.fromDict(err))
        eventsMap = {event.raw["event_id"]: event for event in events}
        for match in matchingResult.json:
            if event := eventsMap.get(match["reference"]["id"]):
                for singleMatch in match["matches"]:
                    event.raw["matches"].append({"label": self.label, "candidates": singleMatch["result"]})

    @timer
    async def execute(self, events: List[Event], luna3Client: Client) -> None:
        """
        Execute the match policy:
            * select events for matching (policy filters + descriptor of the configured type)
            * match the events' descriptors against the policy candidates
        Results are stored in the input events.

        Args:
            events: events (references)
            luna3Client: client
        """
        if (references := self.getEventsReferences(events)) is None:
            return
        candidatesFilters = self.prepareCandidatesFilters()
        matchPayload = self.prepareMatchPayload(candidatesFilters, references)
        matchingResult = await self.requestMatching(luna3Client, matchPayload)
        self.applyMatchingResultToEvents(events, matchingResult)
class MatchPolicy(BaseMatchPolicy):
    """Match policy schema: the base policy restricted to face or event candidates."""

    # matching candidates: faces or events only (no temporary attributes), discriminated by ``origin``
    candidates: Annotated[
        Union[FaceMatchCandidates, EventMatchCandidates],
        WrapValidator(excludeDiscriminatorFromException),
        Field(discriminator="origin"),
    ]