Source code for luna_handlers.classes.schemas.match_policy

"""
Module contains schemas for match policy
"""
from copy import deepcopy
from enum import Enum
from typing import List, Optional, Union
from uuid import UUID

from luna3.client import Client
from luna3.common.http_objs import RawDescriptor
from luna3.python_matcher.match_objects import (
    AttributeFilters,
    Candidates,
    EventFilters,
    FaceFilters,
    RawDescriptorReference,
)
from pydantic import conlist, validator
from typing_extensions import Literal
from vlutils.helpers import convertToSnakeCase
from vlutils.structures.pydantic import DiscriminatedUnion

from classes.event import Event
from classes.schemas import types
from classes.schemas.base_schema import BaseSchema
from classes.schemas.filters import AttributesFilters
from classes.schemas.types import MAX_ITEMS_LIST_LENGTH, OptionalNotNullable
from crutches_on_wheels.cow.errors.errors import ErrorInfo
from crutches_on_wheels.cow.errors.exception import VLException
from crutches_on_wheels.cow.maps.vl_maps import (
    CLOTHING_COLOR_MAP,
    EMOTION_MAP,
    ETHNIC_MAP,
    EVENT_TARGET_MAP,
    FACE_TARGET_MAP,
    LIVENESS_MAP,
    MASK_MAP,
    SLEEVE_LENGTH_MAP,
    TEMPORARY_ATTRIBUTE_MATCH_TARGET_MAP,
)
from crutches_on_wheels.cow.utils.timer import timer


class OriginEnum(str, Enum):
    """
    Origin of match candidates.

    Members subclass ``str``, so each one compares equal to its plain string value.
    """

    # candidates are faces
    faces = "faces"
    # candidates are events
    events = "events"
    # candidates are temporary attributes
    attributes = "attributes"
class FaceMatchCandidates(BaseSchema):
    """
    Face match candidates.

    Every field except ``origin`` is an optional (but not nullable) filter that narrows the set of faces
    used as matching candidates. ``*Gte`` fields are inclusive lower boundaries and ``*Lt`` fields are
    exclusive upper boundaries.
    """

    # candidates origin - faces (discriminator value for the candidates union)
    origin: Literal[OriginEnum.faces.value]
    # face ids
    faceIds: conlist(UUID, min_items=1, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # face account id
    accountId: UUID = types.OptionalNotNullable()
    # face external ids
    externalIds: conlist(types.Str36, min_items=1, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # face user data
    userData: types.Str128 = types.OptionalNotNullable()
    # face create time lower including boundary (>=)
    createTimeGte: types.CustomExtendedDatetime = types.OptionalNotNullable()
    # face create time upper excluding boundary (<)
    createTimeLt: types.CustomExtendedDatetime = types.OptionalNotNullable()
    # face id lower including boundary (>=)
    faceIdGte: UUID = types.OptionalNotNullable()
    # face id upper excluding boundary (<)
    faceIdLt: UUID = types.OptionalNotNullable()
    # face linked list id
    listId: UUID = types.OptionalNotNullable()
class GeoPosition(BaseSchema):
    """
    Geo position: longitude and latitude of the origin point with deltas.

    NOTE(review): the deltas presumably define the search area around the origin and are expressed in
    the same units as the coordinates (degrees) - confirm against the matcher's geo filter.
    """

    # geo position longitude origin
    originLongitude: types.FloatLongitude
    # geo position latitude origin
    originLatitude: types.FloatLatitude
    # geo position longitude delta (defaults to 0.01)
    longitudeDelta: types.FloatGeoDelta = 0.01
    # geo position latitude delta (defaults to 0.01)
    latitudeDelta: types.FloatGeoDelta = 0.01
class EventMatchCandidates(BaseSchema):
    """
    Event match candidates.

    Every field except ``origin`` is an optional (but not nullable) filter that narrows the set of events
    used as matching candidates. ``*Gte`` fields are inclusive lower boundaries and ``*Lt`` fields are
    exclusive upper boundaries.
    """

    # candidates origin - events (discriminator value for the candidates union)
    origin: Literal[OriginEnum.events.value]
    # event ids
    eventIds: conlist(UUID, min_items=1, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event account id
    accountId: UUID = types.OptionalNotNullable()
    # event id lower including boundary (>=)
    eventIdGte: UUID = types.OptionalNotNullable()
    # event id upper excluding boundary (<)
    eventIdLt: UUID = types.OptionalNotNullable()
    # event create time lower including boundary (>=)
    createTimeGte: types.CustomExtendedDatetime = types.OptionalNotNullable()
    # event create time upper excluding boundary (<)
    createTimeLt: types.CustomExtendedDatetime = types.OptionalNotNullable()
    # event end time lower including boundary (>=)
    endTimeGte: types.CustomExtendedDatetime = types.OptionalNotNullable()
    # event end time upper excluding boundary (<)
    endTimeLt: types.CustomExtendedDatetime = types.OptionalNotNullable()
    # event handler ids
    handlerIds: conlist(UUID, min_items=1, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event external ids
    externalIds: conlist(types.Str36, min_items=1, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event top matching candidates label
    topMatchingCandidatesLabel: types.Str36 = types.OptionalNotNullable()
    # event top similar object ids
    topSimilarObjectIds: conlist(UUID, min_items=1, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event top similar external ids
    topSimilarExternalIds: conlist(types.Str36, min_items=1, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event top similar object similarity lower including boundary (>=)
    topSimilarObjectSimilarityGte: types.StrictFloat01 = types.OptionalNotNullable()
    # event top similar object similarity upper excluding boundary (<)
    topSimilarObjectSimilarityLt: types.StrictFloat01 = types.OptionalNotNullable()
    # event age lower including boundary (>=)
    ageGte: types.IntAge = types.OptionalNotNullable()
    # event age upper excluding boundary (<)
    ageLt: types.IntAge = types.OptionalNotNullable()
    # event gender
    gender: types.Int01 = types.OptionalNotNullable()
    # event emotion list (at most one entry per known emotion)
    emotions: conlist(types.IntEmotions, min_items=1, max_items=len(EMOTION_MAP)) = OptionalNotNullable()
    # event mask list (at most one entry per known mask state)
    masks: conlist(types.IntMasks, min_items=1, max_items=len(MASK_MAP)) = OptionalNotNullable()
    # event liveness states (at most one entry per known liveness state)
    liveness: conlist(types.IntLiveness, min_items=1, max_items=len(LIVENESS_MAP)) = OptionalNotNullable()
    # event ethnic group list (at most one entry per known ethnic group)
    ethnicGroups: conlist(types.IntEthnicities, min_items=1, max_items=len(ETHNIC_MAP)) = OptionalNotNullable()
    # event face ids
    faceIds: conlist(UUID, min_items=1, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event user data
    userData: types.Str128 = types.OptionalNotNullable()
    # event sources
    sources: conlist(types.Str36, min_items=1, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event tags
    tags: conlist(types.Str36, min_items=1, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event cities
    cities: conlist(types.Str36, min_items=1, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event areas
    areas: conlist(types.Str36, min_items=1, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event districts
    districts: conlist(types.Str36, min_items=1, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event streets
    streets: conlist(types.Str36, min_items=1, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event house numbers
    # NOTE(review): snake_case, unlike every other (camelCase) field; renaming it would presumably change
    # the key accepted in requests, so it is kept as is
    house_numbers: conlist(types.Str36, min_items=1, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # event geo position
    geoPosition: GeoPosition = types.OptionalNotNullable()
    # event track ids
    trackIds: conlist(types.Str36, min_items=1, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # apparent age upper excluding boundary (<)
    apparentAgeLt: types.IntAge = types.OptionalNotNullable()
    # apparent age lower including boundary (>=)
    apparentAgeGte: types.IntAge = types.OptionalNotNullable()
    # apparent gender (unique values)
    apparentGender: conlist(types.Int02, min_items=1, unique_items=True) = types.OptionalNotNullable()
    # headwear states (unique values)
    headwearStates: conlist(types.Int02, min_items=1, unique_items=True) = OptionalNotNullable()
    # sleeve lengths (unique values, each must be a key of SLEEVE_LENGTH_MAP)
    sleeveLengths: conlist(Literal[tuple(SLEEVE_LENGTH_MAP)], min_items=1, unique_items=True) = OptionalNotNullable()
    # upper clothing colors (unique values, each must be a key of CLOTHING_COLOR_MAP)
    upperClothingColors: conlist(
        Literal[tuple(CLOTHING_COLOR_MAP)], min_items=1, unique_items=True
    ) = OptionalNotNullable()
    # backpack states (unique values)
    backpackStates: conlist(types.Int02, min_items=1, unique_items=True) = OptionalNotNullable()
class AttributeMatchCandidates(BaseSchema):
    """
    Temporary attribute match candidates.

    Every field except ``origin`` is an optional (but not nullable) filter that narrows the set of
    temporary attributes used as matching candidates.
    """

    # candidates origin - attributes (discriminator value for the candidates union)
    origin: Literal[OriginEnum.attributes.value]
    # temporary attribute ids
    # NOTE(review): unlike the other id lists in this module there is no `min_items=1`, so an empty list
    # passes validation - confirm this is intended
    attributeIds: conlist(UUID, max_items=MAX_ITEMS_LIST_LENGTH) = OptionalNotNullable()
    # temporary attribute account id
    accountId: UUID = types.OptionalNotNullable()
class EventMatchResult:
    """
    Event match result built from one item of a matcher reply.

    Attributes:
        candidates: match candidates (the ``result`` entry of the reply item)
        listInfo: matching list info (the ``list`` entry of the reply item filters, None if absent)
        matchingLabel: matching label
    """

    __slots__ = ("candidates", "listInfo", "matchingLabel")

    def __init__(self, matchingLabel: str, matchResult: dict):
        self.matchingLabel = matchingLabel
        self.candidates = matchResult["result"]
        replyFilters = matchResult["filters"]
        self.listInfo = replyFilters.get("list")

    def asDict(self) -> dict:
        """
        Get the event match result as a dict; the list info is not included.

        Returns:
            dict with a deep copy of the matching candidates (``candidates``) and the matching label (``label``)
        """
        resultDict = {"candidates": deepcopy(self.candidates)}
        resultDict["label"] = self.matchingLabel
        return resultDict
class BaseMatchPolicy(BaseSchema):
    """Base match policy schema"""

    # matching label
    label: types.Str36 = ""
    # matching candidates
    candidates: DiscriminatedUnion("origin", [FaceMatchCandidates, EventMatchCandidates, AttributeMatchCandidates])
    # matching filters (events that do not satisfy them are not matched)
    filters: AttributesFilters = AttributesFilters()
    # matching limit
    limit: Optional[types.IntMatchingLimit] = types.DEFAULT_MATCH_LIMIT
    # matching targets
    targets: Optional[List[str]] = types.OptionalNotNullable()
    # matching threshold
    threshold: Optional[types.StrictFloat01] = types.OptionalNotNullable()

    @validator("targets")
    def targetsValidator(cls, targets, values) -> Union[List[str], None]:
        """
        Targets validator, depends on matching candidates.

        Args:
            targets: requested matching targets
            values: already validated fields of the policy

        Returns:
            targets unchanged, or None when the candidates field itself failed validation

        Raises:
            ValueError: if some target is not allowed for the candidates origin
        """
        if targets is None:
            return targets
        if "candidates" not in values:
            # if validation fails on field (or that field is missing), it is not included in values
            return None
        # `origin` holds the enum *value* (see the `Literal[OriginEnum.<x>.value]` annotations)
        candidatesType = values["candidates"].origin
        candidateTargetMap = {
            OriginEnum.faces.value: FACE_TARGET_MAP,
            OriginEnum.events.value: EVENT_TARGET_MAP,
            OriginEnum.attributes.value: TEMPORARY_ATTRIBUTE_MATCH_TARGET_MAP,
        }[candidatesType]
        expectedTargets = list(candidateTargetMap)
        allowedTargets = set(expectedTargets)
        unexpectedTargets = [row for row in targets if row not in allowedTargets]
        if unexpectedTargets:
            raise ValueError(
                f"Bad matching policy targets {targets}, "
                f"allowed targets for '{candidatesType}' candidates are: {expectedTargets}"
            )
        return targets

    @timer
    async def execute(self, events: List[Event], luna3Client: Client) -> None:
        """
        Execute match policy:

            * filter events for matching
            * match events' descriptors with policy candidates

        Results are stored in the input events (appended to ``event.matches``).

        Args:
            events: events (references)
            luna3Client: client

        Raises:
            VLException: if the matcher reports an error for any match; no event is updated in that case
        """
        # only events that satisfy the policy filters and have a face descriptor are matched
        eventsForMatch = [
            event
            for event in events
            if self.filters.isEventSatisfies(event) and event.faceAttributes and event.faceAttributes.descriptor
        ]
        if not eventsForMatch:
            return

        candidatesFilters = self.candidates.asDict()
        origin = candidatesFilters.pop("origin")
        if "geo_position" in candidatesFilters:
            candidatesFilters["geo_position"] = convertToSnakeCase(candidatesFilters["geo_position"])

        references = []
        for event in eventsForMatch:
            descriptor = event.faceAttributes.descriptor
            if not isinstance(descriptor, RawDescriptor):
                descriptor = RawDescriptor(version=descriptor.model, descriptor=descriptor.asBytes)
            references.append(RawDescriptorReference(event.eventId, descriptor))

        # similarity is always requested in addition to the user targets; avoid sending it twice
        if self.targets is None:
            targets = None
        elif "similarity" in self.targets:
            targets = list(self.targets)
        else:
            targets = self.targets + ["similarity"]

        ObjectFilters = {
            OriginEnum.faces.value: FaceFilters,
            OriginEnum.events.value: EventFilters,
            OriginEnum.attributes.value: AttributeFilters,
        }[origin]
        matchKwargs = {
            "candidates": [
                Candidates(
                    filters=ObjectFilters().initFromSnakeKwargs(candidatesFilters),
                    targets=targets,
                    limit=self.limit,
                    threshold=self.threshold,
                )
            ],
            "references": references,
        }
        reply = await luna3Client.lunaPythonMatcher.matchFaces(**matchKwargs, raiseError=True)
        replyMatches = reply.json

        # the matcher reports per-candidate errors inside the reply body: check all of them once,
        # before any event is updated
        for match in replyMatches:
            for singleMatch in match["matches"]:
                err = singleMatch.get("error")
                if err is not None:
                    raise VLException(ErrorInfo.fromDict(err))

        # references were built from event ids, so reply items are keyed back by reference id;
        # only the first candidates result of each item is used (a single `Candidates` is sent)
        matchesByReferenceId = {}
        for match in replyMatches:
            if match["matches"]:
                matchesByReferenceId.setdefault(match["reference"]["id"], []).append(match["matches"][0])

        for event in events:
            for matchResult in matchesByReferenceId.get(event.eventId, ()):
                event.matches.append(EventMatchResult(matchingLabel=self.label, matchResult=matchResult))
class MatchPolicy(BaseMatchPolicy):
    """
    Match policy schema.

    Same as the base match policy, but candidates may only originate from faces or events
    (temporary attribute candidates are not accepted).
    """

    # matching candidates (faces or events only)
    candidates: DiscriminatedUnion("origin", [FaceMatchCandidates, EventMatchCandidates])