"""
Module contains schemas for match policy
"""
from copy import deepcopy
from enum import Enum
from typing import List, Optional, Union
from uuid import UUID
from luna3.client import Client
from luna3.common.http_objs import RawDescriptor
from luna3.python_matcher.match_objects import (
AttributeFilters,
Candidates,
EventFilters,
FaceFilters,
RawDescriptorReference,
)
from pydantic import conlist, validator
from typing_extensions import Literal
from vlutils.helpers import convertToSnakeCase
from vlutils.structures.pydantic import DiscriminatedUnion
from classes.event import Event
from classes.schemas import types
from classes.schemas.base_schema import BaseSchema
from classes.schemas.filters import AttributesFilters
from classes.schemas.types import MAX_ITEMS_LIST_LENGTH, OptionalNotNullable
from crutches_on_wheels.errors.errors import ErrorInfo
from crutches_on_wheels.errors.exception import VLException
from crutches_on_wheels.maps.vl_maps import (
CLOTHING_COLOR_MAP,
EMOTION_MAP,
ETHNIC_MAP,
EVENT_TARGET_MAP,
FACE_TARGET_MAP,
LIVENESS_MAP,
MASK_MAP,
SLEEVE_LENGTH_MAP,
TEMPORARY_ATTRIBUTE_MATCH_TARGET_MAP,
)
from crutches_on_wheels.utils.timer import timer
class OriginEnum(str, Enum):
    """
    Origin of match candidates.

    Members are also plain strings (``str`` mixin), so they compare equal to their values.
    """

    faces = "faces"  # candidates are faces
    events = "events"  # candidates are events
    attributes = "attributes"  # candidates are temporary attributes
class FaceMatchCandidates(BaseSchema):
    """
    Face match candidates.

    Filters that select the faces to match against. For range filters, ``*Gte`` is the
    lower (inclusive, ``>=``) boundary and ``*Lt`` is the upper (exclusive, ``<``) boundary.
    """

    # candidates origin - faces
    origin: Literal[OriginEnum.faces.value]
    # face ids
    faceIds: conlist(UUID, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # face account id
    accountId: UUID = OptionalNotNullable()
    # face external ids
    externalIds: conlist(types.Str36, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # face user data
    userData: types.Str128 = OptionalNotNullable()
    # face create time lower boundary (inclusive, >=)
    createTimeGte: types.CustomExtendedDatetime = OptionalNotNullable()
    # face create time upper boundary (exclusive, <)
    createTimeLt: types.CustomExtendedDatetime = OptionalNotNullable()
    # face id lower boundary (inclusive, >=)
    faceIdGte: UUID = OptionalNotNullable()
    # face id upper boundary (exclusive, <)
    faceIdLt: UUID = OptionalNotNullable()
    # id of the list the faces are linked to
    listId: UUID = OptionalNotNullable()
[docs]class GeoPosition(BaseSchema):
    """
    Geo position filter: an origin point (longitude, latitude) plus per-axis deltas.

    NOTE(review): presumably objects within origin +/- delta on each axis are selected;
    confirm against the matcher's geo filter semantics.
    """
    # origin point longitude
    originLongitude: types.FloatLongitude
    # origin point latitude
    originLatitude: types.FloatLatitude
    # longitude delta around the origin (defaults to 0.01)
    longitudeDelta: types.FloatGeoDelta = 0.01
    # latitude delta around the origin (defaults to 0.01)
    latitudeDelta: types.FloatGeoDelta = 0.01
class EventMatchCandidates(BaseSchema):
    """
    Event match candidates.

    Filters that select the events to match against. For range filters, ``*Gte`` is the
    lower (inclusive, ``>=``) boundary and ``*Lt`` is the upper (exclusive, ``<``) boundary.
    List filters, when given, must contain at least one item.
    """

    # candidates origin - events
    origin: Literal[OriginEnum.events.value]
    # event ids
    eventIds: conlist(UUID, min_items=1, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event account id
    accountId: UUID = OptionalNotNullable()
    # event id lower boundary (inclusive, >=)
    eventIdGte: UUID = OptionalNotNullable()
    # event id upper boundary (exclusive, <)
    eventIdLt: UUID = OptionalNotNullable()
    # event create time lower boundary (inclusive, >=)
    createTimeGte: types.CustomExtendedDatetime = OptionalNotNullable()
    # event create time upper boundary (exclusive, <)
    createTimeLt: types.CustomExtendedDatetime = OptionalNotNullable()
    # event end time lower boundary (inclusive, >=)
    endTimeGte: types.CustomExtendedDatetime = OptionalNotNullable()
    # event end time upper boundary (exclusive, <)
    endTimeLt: types.CustomExtendedDatetime = OptionalNotNullable()
    # event handler ids
    handlerIds: conlist(UUID, min_items=1, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event external ids
    externalIds: conlist(types.Str36, min_items=1, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event top matching candidates label
    topMatchingCandidatesLabel: types.Str36 = OptionalNotNullable()
    # event top similar object ids
    topSimilarObjectIds: conlist(UUID, min_items=1, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event top similar external ids
    topSimilarExternalIds: conlist(types.Str36, min_items=1, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event top similar object similarity lower boundary (inclusive, >=)
    topSimilarObjectSimilarityGte: types.Float01 = OptionalNotNullable()
    # event top similar object similarity upper boundary (exclusive, <)
    topSimilarObjectSimilarityLt: types.Float01 = OptionalNotNullable()
    # event age lower boundary (inclusive, >=)
    ageGte: types.IntAge = OptionalNotNullable()
    # event age upper boundary (exclusive, <)
    ageLt: types.IntAge = OptionalNotNullable()
    # event gender
    gender: types.Int01 = OptionalNotNullable()
    # event emotion list (at most one entry per known emotion)
    emotions: conlist(types.IntEmotions, min_items=1, max_items=len(EMOTION_MAP)) = OptionalNotNullable()
    # event mask list (at most one entry per known mask state)
    masks: conlist(types.IntMasks, min_items=1, max_items=len(MASK_MAP)) = OptionalNotNullable()
    # event liveness states (at most one entry per known liveness state)
    liveness: conlist(types.IntLiveness, min_items=1, max_items=len(LIVENESS_MAP)) = OptionalNotNullable()
    # event ethnic group list (at most one entry per known ethnic group)
    ethnicGroups: conlist(types.IntEthnicities, min_items=1, max_items=len(ETHNIC_MAP)) = OptionalNotNullable()
    # event face ids
    faceIds: conlist(UUID, min_items=1, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event user data
    userData: types.Str128 = OptionalNotNullable()
    # event sources
    sources: conlist(types.Str36, min_items=1, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event tags
    tags: conlist(types.Str36, min_items=1, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event cities
    cities: conlist(types.Str36, min_items=1, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event areas
    areas: conlist(types.Str36, min_items=1, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event districts
    districts: conlist(types.Str36, min_items=1, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event streets
    streets: conlist(types.Str36, min_items=1, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event house numbers
    # NOTE(review): snake_case unlike the sibling fields; presumably kept for API compatibility,
    # renaming would change the accepted request key - confirm before touching
    house_numbers: conlist(types.Str36, min_items=1, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event geo position
    geoPosition: GeoPosition = OptionalNotNullable()
    # event track ids
    trackIds: conlist(types.Str36, min_items=1, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # apparent age upper boundary (exclusive, <)
    apparentAgeLt: types.IntAge = OptionalNotNullable()
    # apparent age lower boundary (inclusive, >=)
    apparentAgeGte: types.IntAge = OptionalNotNullable()
    # apparent gender values (unique)
    apparentGender: conlist(types.Int02, min_items=1, unique_items=True) = OptionalNotNullable()
    # headwear states (unique)
    headwearStates: conlist(types.Int02, min_items=1, unique_items=True) = OptionalNotNullable()
    # sleeve lengths, restricted to SLEEVE_LENGTH_MAP keys (unique)
    sleeveLengths: conlist(Literal[tuple(SLEEVE_LENGTH_MAP)], min_items=1, unique_items=True) = OptionalNotNullable()
    # upper clothing colors, restricted to CLOTHING_COLOR_MAP keys (unique)
    upperClothingColors: conlist(
        Literal[tuple(CLOTHING_COLOR_MAP)], min_items=1, unique_items=True
    ) = OptionalNotNullable()
    # backpack states (unique)
    backpackStates: conlist(types.Int02, min_items=1, unique_items=True) = OptionalNotNullable()
class AttributeMatchCandidates(BaseSchema):
    """Temporary attribute match candidates: filters that select the attributes to match against."""

    # candidates origin - attributes
    origin: Literal[OriginEnum.attributes.value]
    # temporary attribute ids
    attributeIds: conlist(UUID, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # temporary attribute account id
    accountId: UUID = OptionalNotNullable()
class EventMatchResult:
    """
    Event match result.

    Attributes:
        candidates: match candidates (the "result" part of a single matcher match)
        listInfo: matching list info ("list" entry of the match filters, None if absent)
        matchingLabel: label of the match policy that produced the result
    """

    __slots__ = ("candidates", "listInfo", "matchingLabel")

    def __init__(self, matchingLabel: str, matchResult: dict):
        """
        Args:
            matchingLabel: match policy label
            matchResult: single match from the matcher reply; must contain "result" and "filters" keys
        """
        self.candidates = matchResult["result"]
        self.listInfo = matchResult["filters"].get("list")
        self.matchingLabel = matchingLabel

    def asDict(self) -> dict:
        """
        Get event match result as dict (list info is not included).

        Returns:
            dict with a deep copy of the matching candidates ("candidates") and the matching label ("label")
        """
        # deep copy so callers can mutate the result without affecting this object
        return dict(candidates=deepcopy(self.candidates), label=self.matchingLabel)
class BaseMatchPolicy(BaseSchema):
    """
    Base match policy schema.

    A match policy matches face descriptors of events against candidates (faces, events or
    temporary attributes) and appends the results to the events.
    """

    # matching label
    label: types.Str36 = ""
    # matching candidates, discriminated by their "origin" field
    candidates: DiscriminatedUnion(
        "origin", [FaceMatchCandidates, EventMatchCandidates, AttributeMatchCandidates]
    )
    # filters an event must satisfy to be matched
    filters: AttributesFilters = AttributesFilters()
    # matching limit
    limit: Optional[types.IntMatchingLimit] = types.DEFAULT_MATCH_LIMIT
    # matching targets
    targets: Optional[List[str]] = types.OptionalNotNullable()
    # matching threshold
    threshold: Optional[types.Float01] = types.OptionalNotNullable()

    @validator("targets")
    def targetsValidator(cls, targets, values) -> Union[List[str], None]:
        """
        Validate targets against the target map of the chosen candidates origin.

        Args:
            targets: requested matching targets
            values: already validated fields (pydantic)

        Returns:
            targets unchanged; None if candidates are absent from values

        Raises:
            ValueError: if some target is not allowed for the candidates origin
        """
        if targets is None:
            return targets
        if "candidates" not in values:
            # if validation fails on field (or that field is missing), it is not included in values
            return None
        candidatesType = values["candidates"].origin
        candidateTargetMap = {
            OriginEnum.faces.value: FACE_TARGET_MAP,
            OriginEnum.events.value: EVENT_TARGET_MAP,
            OriginEnum.attributes.value: TEMPORARY_ATTRIBUTE_MATCH_TARGET_MAP,
        }[candidatesType]
        expectedTargets = list(candidateTargetMap)
        unexpectedTargets = [target for target in targets if target not in expectedTargets]
        if unexpectedTargets:
            raise ValueError(
                f"Bad matching policy targets {targets}, "
                f"allowed targets for '{candidatesType}' candidates are: {expectedTargets}"
            )
        return targets

    @timer
    async def execute(self, events: List[Event], luna3Client: Client) -> None:
        """
        Execute match policy:
            * select events that satisfy the policy filters and have a face descriptor
            * match their descriptors with the policy candidates
        Results are appended to `matches` of the input events.

        Args:
            events: events (references)
            luna3Client: client

        Raises:
            VLException: if the matcher reports an error for any match; no event is modified in that case
        """
        eventsForMatch = [event for event in events if self._isEventForMatch(event)]
        if not eventsForMatch:
            return
        reply = await luna3Client.lunaPythonMatcher.matchFaces(
            candidates=[self._buildCandidates()],
            references=self._buildReferences(eventsForMatch),
            raiseError=True,
        )
        self._applyMatchResults(events, reply.json)

    def _isEventForMatch(self, event: Event) -> bool:
        """Check whether the event satisfies the policy filters and has a face descriptor."""
        return bool(
            self.filters.isEventSatisfies(event) and event.faceAttributes and event.faceAttributes.descriptor
        )

    def _buildCandidates(self) -> Candidates:
        """Build matcher candidates (filters, targets, limit, threshold) from the policy."""
        candidatesFilters = self.candidates.asDict()
        origin = candidatesFilters.pop("origin")
        if "geo_position" in candidatesFilters:
            # presumably converts the nested geo position keys for initFromSnakeKwargs - confirm
            candidatesFilters["geo_position"] = convertToSnakeCase(candidatesFilters["geo_position"])
        objectFilters = {
            OriginEnum.faces.value: FaceFilters,
            OriginEnum.events.value: EventFilters,
            OriginEnum.attributes.value: AttributeFilters,
        }[origin]
        targets = None
        if self.targets is not None:
            # similarity is always requested; avoid sending it twice if the user already asked for it
            targets = list(self.targets)
            if "similarity" not in targets:
                targets.append("similarity")
        return Candidates(
            filters=objectFilters().initFromSnakeKwargs(candidatesFilters),
            targets=targets,
            limit=self.limit,
            threshold=self.threshold,
        )

    @staticmethod
    def _buildReferences(eventsForMatch: List[Event]) -> List[RawDescriptorReference]:
        """Build matcher references (raw descriptors keyed by event id) for the given events."""
        references = []
        for event in eventsForMatch:
            descriptor = event.faceAttributes.descriptor
            if not isinstance(descriptor, RawDescriptor):
                descriptor = RawDescriptor(version=descriptor.model, descriptor=descriptor.asBytes)
            references.append(RawDescriptorReference(event.eventId, descriptor))
        return references

    def _applyMatchResults(self, events: List[Event], matchReply: List[dict]) -> None:
        """
        Append matcher results to the events they belong to.

        Raises:
            VLException: on the first matcher error (in reply order), before any event is modified
        """
        # check every match first so that an error never leaves events partially updated
        for match in matchReply:
            for singleMatch in match["matches"]:
                err = singleMatch.get("error")
                if err is not None:
                    raise VLException(ErrorInfo.fromDict(err))
        # index matches by reference id once instead of rescanning the reply for every event
        matchesByReferenceId = {}
        for match in matchReply:
            matchesByReferenceId.setdefault(match["reference"]["id"], match)
        for event in events:
            match = matchesByReferenceId.get(event.eventId)
            if match is not None:
                event.matches.append(EventMatchResult(matchingLabel=self.label, matchResult=match["matches"][0]))
[docs]class MatchPolicy(BaseMatchPolicy):
    """
    Match policy schema.

    Same as BaseMatchPolicy, but candidates are restricted to faces and events
    (attribute candidates are not accepted).
    """
    # matching candidates: faces or events only, discriminated by their "origin" field
    candidates: DiscriminatedUnion("origin", [FaceMatchCandidates, EventMatchCandidates])