"""
Module containing pydantic schemas for the handlers lambda
"""
from typing import Literal, Union, final
from cow.pydantic.types import CustomFloat, Float, OptionalNotNullable
from pydantic import Field, StrictBytes, conlist, model_validator
from vlutils.structures.pydantic import BaseModel as _BaseModel
from luna_lambda_tools.public.schemas.base import BoundingBox, Location
[docs]
@final
class ImageOrigin(_BaseModel):
    """Image origin: the image body together with optional meta information about that body"""
    # image origin body; the annotation accepts either str or bytes
    body: str | bytes
    # image body meta; an arbitrary dict, None when not provided
    bodyMeta: dict | None = None
[docs]
@final
class SourceData(_BaseModel):
    """
    Source data: descriptive fields attached to a source.

    Every field declares a default, so the model can be instantiated without arguments.
    NOTE(review): several fields are annotated with a non-optional type but default to None through
    `default_factory`; whether an explicitly passed null is accepted depends on the base model — confirm.
    """
    # source (if present); None when omitted
    source: str = Field(default_factory=lambda: None)
    # stream id; None when omitted
    streamId: str = Field(default_factory=lambda: None)
    # tags (if present); None when omitted
    tags: list[str] = Field(default_factory=lambda: None)
    # user data; empty string when omitted
    userData: str = ""
    # external id; empty string when omitted
    externalId: str = ""
    # track id; None when omitted
    trackId: str = Field(default_factory=lambda: None)
    # meta information provided by user (if present); None when omitted
    meta: dict = Field(default_factory=lambda: None)
    # luna-event time (if present), passed as a string; None when omitted
    eventTime: str = Field(default_factory=lambda: None)
    # luna event end time (if present), passed as a string; None when omitted
    eventEndTime: str = Field(default_factory=lambda: None)
    # luna event location; None when omitted
    location: Location | None = None
[docs]
class DetectTs(CustomFloat):
    """User-defined timestamp relative to something, such as the start of a video"""
    # lower bound; presumably inclusive ("ge") — confirm against cow.pydantic.types.CustomFloat
    ge = 0.0
    # upper bound; presumably inclusive ("le") — NOTE(review): the origin of this constant is not visible here
    le = 158731466399.999
[docs]
class EventSourceBase(_BaseModel):
    """
    Event source base: fields shared by every event source schema.

    All fields default to None (through `default_factory`) when omitted.
    """
    # file name of the source; None when omitted
    filename: str = Field(default_factory=lambda: None)
    # detection time, passed as a string; None when omitted
    detectTime: str = Field(default_factory=lambda: None)
    # detection timestamp validated by `DetectTs`; None when omitted
    detectTs: DetectTs = Field(default_factory=lambda: None)
    # image origin: either an `ImageOrigin` object or a plain string; None when omitted
    imageOrigin: ImageOrigin | str = Field(default_factory=lambda: None)
[docs]
class Angles(_BaseModel):
    """Angles: pitch, roll and yaw, each required and limited to the range from -180 to 180"""
    # presumably degrees — TODO confirm the unit
    pitch: Float(minValue=-180, maxValue=180)
    roll: Float(minValue=-180, maxValue=180)
    yaw: Float(minValue=-180, maxValue=180)
[docs]
class BodyDetectionData(_BaseModel):
    """Body detection data: a required bounding box plus an optional origin bounding box"""
    # detection bounding box (required)
    boundingBox: BoundingBox
    # origin bounding box; presumably may be omitted but may not be null — confirm against OptionalNotNullable
    originBoundingBox: BoundingBox = OptionalNotNullable()
    class Config:
        # fields not declared on the model are allowed
        extra = "allow"
[docs]
class FaceDetectionData(_BaseModel):
    """Face detection data: a required bounding box plus optional origin bounding box, angles and score"""
    # detection bounding box (required)
    boundingBox: BoundingBox
    # the fields below use OptionalNotNullable: presumably they may be omitted but may not be null — confirm
    originBoundingBox: BoundingBox = OptionalNotNullable()
    # pitch / roll / yaw angles
    angles: Angles = OptionalNotNullable()
    # score limited to the range from 0 to 1
    score: Float(minValue=0, maxValue=1) = OptionalNotNullable()
    class Config:
        # fields not declared on the model are allowed
        extra = "allow"
[docs]
class RawImageAggregated(_BaseModel):
    """Raw image source: image bytes with optional face/body detection data, meta and a trusted-detections flag"""
    # image body; strictly bytes (required)
    body: StrictBytes
    # face detection data; when given, the list must contain exactly one element
    faceDetectionData: conlist(FaceDetectionData, min_length=1, max_length=1) = OptionalNotNullable()
    # body detection data; when given, the list must contain exactly one element
    bodyDetectionData: conlist(BodyDetectionData, min_length=1, max_length=1) = OptionalNotNullable()
    # arbitrary meta dict; None when omitted
    meta: dict | None = Field(default_factory=lambda: None)
    # trusted detections flag, 0 or 1; 0 by default
    trustedDetections: Literal[0, 1] = 0
[docs]
class RawDetection(_BaseModel):
    """
    Raw detection: warp bytes, bounding boxes and meta.

    All fields default to None (through `default_factory`) when omitted.
    """
    # warp; strictly bytes
    warp: StrictBytes = Field(default_factory=lambda: None)
    # origin bounding box
    originBoundingBox: BoundingBox = Field(default_factory=lambda: None)
    # detection bounding box
    boundingBox: BoundingBox = Field(default_factory=lambda: None)
    # arbitrary meta dict
    meta: dict | None = Field(default_factory=lambda: None)
[docs]
class FaceDetection(RawDetection):
    """Face detection: a raw detection extended with optional angles and score"""
    # the fields below use OptionalNotNullable: presumably they may be omitted but may not be null — confirm
    # pitch / roll / yaw angles
    angles: Angles = OptionalNotNullable()
    # score limited to the range from 0 to 1
    score: Float(minValue=0, maxValue=1) = OptionalNotNullable()
    class Config:
        # fields not declared on the model are allowed
        extra = "allow"
[docs]
class RawDetectionsAggregated(_BaseModel):
    """Raw detections pair: an optional face detection and an optional body detection"""
    # face detection; None when omitted
    face: FaceDetection = Field(default_factory=lambda: None)
    # body detection; None when omitted
    body: RawDetection = Field(default_factory=lambda: None)
    # trusted detections flag, 0 or 1; 0 by default
    trustedDetections: Literal[0, 1] = 0
[docs]
class RawImageNonAggregated(RawImageAggregated):
    """Raw image non aggregated: the aggregated raw image schema extended with its own source data"""
    # source data; the default is a `SourceData` built without arguments
    sourceData: SourceData = SourceData()
[docs]
class RawDetectionsNonAggregated(RawDetectionsAggregated):
    """Raw detections non aggregated: the aggregated detections pair extended with its own source data"""
    # source data; the default is a `SourceData` built without arguments
    sourceData: SourceData = SourceData()
[docs]
class EventSourceAggregatedRawImage(EventSourceBase):
    """Event source of type "raw_image" whose source follows the aggregated raw image schema"""
    # source type tag; must be exactly "raw_image"
    sourceType: Literal["raw_image"]
    # raw image payload
    source: RawImageAggregated
    class Config:
        # fields not declared on the model are allowed
        extra = "allow"
[docs]
class EventSourceAggregatedDetections(EventSourceBase):
    """Event source of type "detections" whose source is a list of aggregated detection pairs"""
    # source type tag; must be exactly "detections"
    sourceType: Literal["detections"]
    # detection pairs; no length constraints are set on the list
    source: conlist(RawDetectionsAggregated)
    class Config:
        # fields not declared on the model are allowed
        extra = "allow"
[docs]
class EventSourceNonAggregatedRawImage(EventSourceBase):
    """Event source of type "raw_image" whose source follows the non aggregated raw image schema"""
    # source type tag; must be exactly "raw_image"
    sourceType: Literal["raw_image"]
    # raw image payload carrying its own source data
    source: RawImageNonAggregated
    class Config:
        # fields not declared on the model are allowed
        extra = "allow"
[docs]
class EventSourceNonAggregatedDetections(EventSourceBase):
    """Event source of type "detections" whose source is a list of non aggregated detection pairs"""
    # source type tag; must be exactly "detections"
    sourceType: Literal["detections"]
    # detection pairs, each carrying its own source data; no length constraints are set on the list
    source: conlist(RawDetectionsNonAggregated)
    class Config:
        # fields not declared on the model are allowed
        extra = "allow"
[docs]
class EventSourceSchema(_BaseModel):
    """
    Event source schema: a list of event sources plus the aggregation switch.

    The concrete model of every `sources` entry is chosen by a "before" validator from the raw
    `aggregate_attributes` flag and from the raw `source_type` value of the entry.
    NOTE(review): the validator reads the raw keys `aggregate_attributes` / `source_type`, presumably the
    aliases that the base model generates for `aggregateAttributes` / `sourceType` — confirm against vlutils.
    """
    # event sources; every entry is one of the four concrete event source models
    sources: list[
        Union[
            EventSourceNonAggregatedRawImage,
            EventSourceNonAggregatedDetections,
            EventSourceAggregatedRawImage,
            EventSourceAggregatedDetections,
        ]
    ]
    # aggregation switch: 0 selects the non aggregated source models, 1 selects the aggregated ones
    aggregateAttributes: Literal[0, 1]
    # whether to use exif info; True by default
    useExifInfo: bool = True
    # source data; None when omitted
    sourceData: SourceData | None = None
    @model_validator(mode="before")
    @classmethod
    def discriminateSources(cls, values):
        """
        Replace raw `sources` entries with instances of the model matching the aggregation flag.

        Args:
            values: raw input of the model; anything that is not a dict is returned unchanged so that
                the regular model validation reports it
        Returns:
            a shallow copy of `values` with `sources` replaced by model instances, or `values` itself when
            there is nothing to discriminate (non-dict input, missing or non-list `sources`)
        Raises:
            ValueError: if `aggregate_attributes` is neither 0 nor 1 or an entry has an unsupported `source_type`
        """
        if not isinstance(values, dict):
            return values
        aggregateAttrs = values.get("aggregate_attributes")
        if aggregateAttrs == 0:
            rawImageClass, detectionsClass = EventSourceNonAggregatedRawImage, EventSourceNonAggregatedDetections
        elif aggregateAttrs == 1:
            rawImageClass, detectionsClass = EventSourceAggregatedRawImage, EventSourceAggregatedDetections
        else:
            raise ValueError(f"Invalid value for aggregate_attributes {aggregateAttrs}")
        sourceItems = values.get("sources")
        # a missing or malformed `sources` is left as is: injecting an empty list here would hide the
        # "field required" / "list expected" error of the regular field validation
        if not isinstance(sourceItems, (list, tuple)):
            return values
        processedSources = [cls.discriminateSourceType(item, rawImageClass, detectionsClass) for item in sourceItems]
        # build a new mapping instead of mutating the caller's input
        return {**values, "sources": processedSources}
    @staticmethod
    def discriminateSourceType(item, rawImageClass, detectionsClass):
        """
        Build the event source model that matches the `source_type` of a raw entry.

        Args:
            item: raw `sources` entry; anything that is not a dict is returned unchanged so that
                the regular field validation reports it
            rawImageClass: model used for the "raw_image" source type
            detectionsClass: model used for the "detections" source type
        Returns:
            an instance of `rawImageClass` or `detectionsClass`, or `item` itself when it is not a dict
        Raises:
            ValueError: if `source_type` is neither "raw_image" nor "detections"
        """
        if not isinstance(item, dict):
            return item
        sourceType = item.get("source_type")
        if sourceType == "raw_image":
            return rawImageClass(**item)
        if sourceType == "detections":
            return detectionsClass(**item)
        raise ValueError(f"Unsupported source_type {sourceType}")