"""
This module contains the pydantic schemas for the handlers lambda.
"""
from collections.abc import Mapping
from typing import Any, Literal, Union, final

from cow.pydantic.types import Float, OptionalNotNullable
from pydantic import ConfigDict, Field, StrictBytes, conlist, model_validator
from vlutils.structures.pydantic import BaseModel as _BaseModel

from luna_lambda_tools.public.schemas.base import BoundingBox, Location
[docs]
@final
class ImageOrigin(_BaseModel):
    """Image origin: the original image body plus optional metadata about it."""

    # image origin body, accepted either as raw bytes or as a string
    # NOTE(review): the string form is presumably a URL or an encoded payload — confirm with producers
    body: str | bytes
    # optional metadata about the image body (untyped dict)
    bodyMeta: dict | None = None
[docs]
@final
class SourceData(_BaseModel):
    """Source data: user-supplied context attached to an event (source, stream, tags, times, location).

    NOTE(review): most fields are annotated as non-optional (``str``, ``list[str]``, ``dict``)
    but default to ``None`` through ``default_factory``. Unless the base model enables default
    validation, an omitted field silently becomes ``None`` while an explicit ``None`` is rejected,
    so ``model_dump()`` of a default instance may not validate back. Confirm this
    "optional but not nullable" behaviour is intended.
    """

    # source (if present)
    source: str = Field(default_factory=lambda: None)
    # stream id
    streamId: str = Field(default_factory=lambda: None)
    # tags (if present)
    tags: list[str] = Field(default_factory=lambda: None)
    # user data; defaults to an empty string rather than None
    userData: str = ""
    # external id; defaults to an empty string rather than None
    externalId: str = ""
    # track id
    trackId: str = Field(default_factory=lambda: None)
    # meta information provided by user (if present)
    meta: dict = Field(default_factory=lambda: None)
    # luna event time (if present); plain string, format is not validated here
    eventTime: str = Field(default_factory=lambda: None)
    # luna event end time (if present); plain string, format is not validated here
    eventEndTime: str = Field(default_factory=lambda: None)
    # luna event location (nullable)
    location: Location | None = None
[docs]
class EventSourceBase(_BaseModel):
    """Event source base: optional fields shared by every event source variant.

    Each field defaults to ``None`` although its annotation is non-optional, so an omitted field
    is ``None`` but an explicit ``None`` fails validation.
    """

    # source file name (if present)
    filename: str = Field(default_factory=lambda: None)
    # detection time as a plain string; its format is not validated here
    detectTime: str = Field(default_factory=lambda: None)
    # detection timestamp as a plain string; its format is not validated here
    detectTs: str = Field(default_factory=lambda: None)
    # image origin: either a structured ImageOrigin or a plain string
    imageOrigin: ImageOrigin | str = Field(default_factory=lambda: None)
[docs]
class Angles(_BaseModel):
    """Pitch / roll / yaw angles, each constrained to the range -180..180.

    NOTE(review): the range suggests degrees and the bounds are presumably inclusive — confirm
    against cow's ``Float`` semantics.
    """

    # pitch angle
    pitch: Float(minValue=-180, maxValue=180)
    # roll angle
    roll: Float(minValue=-180, maxValue=180)
    # yaw angle
    yaw: Float(minValue=-180, maxValue=180)
[docs]
class BodyDetectionData(_BaseModel):
    """Body detection data supplied together with a raw image.

    Unknown keys are kept on the model instead of being rejected (``extra="allow"``).
    """

    # pydantic v2 style config; the class-based ``Config`` is deprecated since v2.0
    model_config = ConfigDict(extra="allow")

    # bounding box of the detection (required)
    boundingBox: BoundingBox
    # bounding box in the origin image; presumably may be omitted but must not be null (OptionalNotNullable)
    originBoundingBox: BoundingBox = OptionalNotNullable()
[docs]
class FaceDetectionData(_BaseModel):
    """Face detection data supplied together with a raw image.

    Unknown keys are kept on the model instead of being rejected (``extra="allow"``).
    """

    # pydantic v2 style config; the class-based ``Config`` is deprecated since v2.0
    model_config = ConfigDict(extra="allow")

    # bounding box of the detection (required)
    boundingBox: BoundingBox
    # bounding box in the origin image; presumably may be omitted but must not be null (OptionalNotNullable)
    originBoundingBox: BoundingBox = OptionalNotNullable()
    # face angles (optional, not nullable)
    angles: Angles = OptionalNotNullable()
    # detection score constrained to 0..1 (optional, not nullable)
    score: Float(minValue=0, maxValue=1) = OptionalNotNullable()
[docs]
class RawImageAggregated(_BaseModel):
    """Raw image source: image bytes with optional face/body detection data supplied by the caller."""

    # image body; StrictBytes accepts bytes only (no str coercion)
    body: StrictBytes
    # face detection data: a list with exactly one entry, if provided
    faceDetectionData: conlist(FaceDetectionData, min_length=1, max_length=1) = OptionalNotNullable()
    # body detection data: a list with exactly one entry, if provided
    bodyDetectionData: conlist(BodyDetectionData, min_length=1, max_length=1) = OptionalNotNullable()
    # optional meta information (untyped dict), defaults to None
    meta: dict | None = Field(default_factory=lambda: None)
    # 0 or 1 flag; NOTE(review): presumably 1 means the supplied detections are trusted as-is — confirm
    trustedDetections: Literal[0, 1] = 0
[docs]
class RawDetection(_BaseModel):
    """Raw detection: warp bytes plus bounding boxes and meta; every field defaults to None.

    The non-meta fields are annotated as non-optional, so an explicit ``None`` fails validation even
    though ``None`` is the default.
    """

    # warp image bytes (NOTE(review): presumably an aligned crop of the detection — confirm)
    warp: StrictBytes = Field(default_factory=lambda: None)
    # bounding box in the origin image
    originBoundingBox: BoundingBox = Field(default_factory=lambda: None)
    # bounding box of the detection
    boundingBox: BoundingBox = Field(default_factory=lambda: None)
    # optional meta information (untyped dict)
    meta: dict | None = Field(default_factory=lambda: None)
[docs]
class FaceDetection(RawDetection):
    """Face detection: a RawDetection extended with optional angles and score.

    Unknown keys are kept on the model instead of being rejected (``extra="allow"``).
    """

    # pydantic v2 style config; the class-based ``Config`` is deprecated since v2.0
    model_config = ConfigDict(extra="allow")

    # face angles (optional, not nullable)
    angles: Angles = OptionalNotNullable()
    # detection score constrained to 0..1 (optional, not nullable)
    score: Float(minValue=0, maxValue=1) = OptionalNotNullable()
[docs]
class RawDetectionsAggregated(_BaseModel):
    """Raw detections pair: an optional face detection and an optional body detection."""

    # face detection (defaults to None; explicit None fails validation)
    face: FaceDetection = Field(default_factory=lambda: None)
    # body detection (defaults to None; explicit None fails validation)
    body: RawDetection = Field(default_factory=lambda: None)
    # 0 or 1 flag; NOTE(review): presumably 1 means the supplied detections are trusted as-is — confirm
    trustedDetections: Literal[0, 1] = 0
[docs]
class RawImageNonAggregated(RawImageAggregated):
    """Raw image source for the non-aggregated mode: RawImageAggregated plus per-source data."""

    # per-source data; defaults to an empty SourceData() instance
    sourceData: SourceData = SourceData()
[docs]
class RawDetectionsNonAggregated(RawDetectionsAggregated):
    """Raw detections for the non-aggregated mode: RawDetectionsAggregated plus per-source data."""

    # per-source data; defaults to an empty SourceData() instance
    sourceData: SourceData = SourceData()
[docs]
class EventSourceAggregatedRawImage(EventSourceBase):
    """Event source holding one raw image; chosen by EventSourceSchema when aggregateAttributes == 1.

    Unknown keys are kept on the model instead of being rejected (``extra="allow"``).
    """

    # pydantic v2 style config; the class-based ``Config`` is deprecated since v2.0
    model_config = ConfigDict(extra="allow")

    # source type tag; only "raw_image" is accepted
    sourceType: Literal["raw_image"]
    # the raw image with optional detection data
    source: RawImageAggregated
[docs]
class EventSourceAggregatedDetections(EventSourceBase):
    """Event source holding a list of detections; chosen by EventSourceSchema when aggregateAttributes == 1.

    Unknown keys are kept on the model instead of being rejected (``extra="allow"``).
    """

    # pydantic v2 style config; the class-based ``Config`` is deprecated since v2.0
    model_config = ConfigDict(extra="allow")

    # source type tag; only "detections" is accepted
    sourceType: Literal["detections"]
    # face/body detection pairs (no length constraints)
    source: conlist(RawDetectionsAggregated)
[docs]
class EventSourceNonAggregatedRawImage(EventSourceBase):
    """Event source holding one raw image; chosen by EventSourceSchema when aggregateAttributes == 0.

    Unknown keys are kept on the model instead of being rejected (``extra="allow"``).
    """

    # pydantic v2 style config; the class-based ``Config`` is deprecated since v2.0
    model_config = ConfigDict(extra="allow")

    # source type tag; only "raw_image" is accepted
    sourceType: Literal["raw_image"]
    # the raw image with optional detection data and per-source data
    source: RawImageNonAggregated
[docs]
class EventSourceNonAggregatedDetections(EventSourceBase):
    """Event source holding a list of detections; chosen by EventSourceSchema when aggregateAttributes == 0.

    Unknown keys are kept on the model instead of being rejected (``extra="allow"``).
    """

    # pydantic v2 style config; the class-based ``Config`` is deprecated since v2.0
    model_config = ConfigDict(extra="allow")

    # source type tag; only "detections" is accepted
    sourceType: Literal["detections"]
    # face/body detection pairs with per-source data (no length constraints)
    source: conlist(RawDetectionsNonAggregated)
[docs]
class EventSourceSchema(_BaseModel):
    """Event source schema: a batch of sources plus processing flags.

    ``sources`` items are not left to pydantic's union matching. The ``discriminateSources``
    before-validator picks the concrete model from ``aggregateAttributes`` (aggregated vs.
    non-aggregated family) and from each item's source type (``"raw_image"`` vs. ``"detections"``).
    """

    # sources, each resolved to a concrete event source model by `discriminateSources`
    sources: list[
        Union[
            EventSourceNonAggregatedRawImage,
            EventSourceNonAggregatedDetections,
            EventSourceAggregatedRawImage,
            EventSourceAggregatedDetections,
        ]
    ]
    # 0 -> non-aggregated source models, 1 -> aggregated source models
    aggregateAttributes: Literal[0, 1]
    # whether EXIF info should be used (consumed outside this schema)
    useExifInfo: bool = True
    # optional source data for the whole request
    sourceData: SourceData | None = None

    @model_validator(mode="before")
    @classmethod
    def discriminateSources(cls, values: Any) -> Any:
        """Resolve every raw ``sources`` item to its concrete event source model.

        ``aggregateAttributes`` selects the model family. The legacy ``aggregate_attributes``
        spelling is still accepted. Each item is then built by ``discriminateSourceType``.
        A missing ``sources`` key is replaced with an empty list, as before.

        :param values: raw input passed to model validation.
        :return: a shallow copy of ``values`` whose ``sources`` holds model instances.
            Non-mapping input, or a non-list ``sources`` value, is returned for pydantic to report.
        :raises ValueError: if ``aggregateAttributes`` is not 0/1, or an item is unsupported
            (pydantic surfaces it as a ValidationError).
        """
        if not isinstance(values, Mapping):
            # e.g. an already-built model instance: nothing to discriminate here
            return values
        # work on a copy so the caller's input dict is not mutated
        values = dict(values)
        # the declared field is camelCase; the snake_case key is kept for backward compatibility
        aggregateAttrs = next(
            (values[key] for key in ("aggregateAttributes", "aggregate_attributes") if key in values),
            None,
        )
        if aggregateAttrs == 0:
            rawImageClass, detectionsClass = EventSourceNonAggregatedRawImage, EventSourceNonAggregatedDetections
        elif aggregateAttrs == 1:
            rawImageClass, detectionsClass = EventSourceAggregatedRawImage, EventSourceAggregatedDetections
        else:
            raise ValueError(f"Invalid value for aggregate_attributes {aggregateAttrs}")
        sourceItems = values.get("sources", [])
        if not isinstance(sourceItems, (list, tuple)):
            # not a sequence: let the `sources` field validation produce a proper error
            return values
        values["sources"] = [
            cls.discriminateSourceType(item, rawImageClass, detectionsClass) for item in sourceItems
        ]
        return values

    @staticmethod
    def discriminateSourceType(item, rawImageClass, detectionsClass):
        """Build the concrete model for one ``sources`` item.

        :param item: a mapping with a source type under ``sourceType`` (or legacy ``source_type``),
            or an instance that already is ``rawImageClass`` / ``detectionsClass``.
        :param rawImageClass: model used for the ``"raw_image"`` source type.
        :param detectionsClass: model used for the ``"detections"`` source type.
        :return: the constructed (or already-built) model instance.
        :raises ValueError: for a non-mapping item or an unsupported source type.
        """
        if isinstance(item, (rawImageClass, detectionsClass)):
            return item
        if not isinstance(item, Mapping):
            raise ValueError(f"Unsupported source item {type(item).__name__}")
        # the declared field is camelCase; the snake_case key is kept for backward compatibility
        sourceType = next((item[key] for key in ("sourceType", "source_type") if key in item), None)
        if sourceType == "raw_image":
            return rawImageClass(**item)
        if sourceType == "detections":
            return detectionsClass(**item)
        raise ValueError(f"Unsupported source_type {sourceType}")