"""
Module contains schemas for storage policy
"""
import asyncio
import io
from typing import Union, Optional
from uuid import UUID, uuid4
from PIL import Image
from luna3.common.http_objs import RawDescriptor
from luna3.faces.faces import FacesApi
from lunavl.sdk.descriptors.descriptors import FaceDescriptor
from lunavl.sdk.estimators.body_estimators.humanwarper import HumanWarpedImage
from lunavl.sdk.estimators.face_estimators.facewarper import FaceWarpedImage
from vlutils.structures.dataclasses import dataclass
from app.global_vars.context_vars import requestIdCtx
from luna3.client import Client
from luna3.image_store.image_store import StoreApi
from pydantic import Field
from classes.image_meta import ProcessedImageData
from classes.monitoring import HandlersMonitoringData
from crutches_on_wheels.maps.vl_maps import ETHNIC_MAP
from redis_db.redis_context import RedisContext
from vlutils.jobs.async_runner import AsyncRunner
from classes.event import Event, FaceExtractedAttribute
from classes.luna3event import eventAsLuna3
from classes.schemas import types
from classes.schemas.base_schema import BaseSchema, HandlerSettings
from classes.schemas.filters import ComplexFilter
from configs.config import MAX_IMAGE_ORIGIN_EXTERNAL_URL_LENGTH, WARP_FORMAT
from crutches_on_wheels.monitoring.points import monitorTime
from crutches_on_wheels.utils.log import logger
from sdk.sdk_loop.models.face_sample import FaceSample, AggregatedFaceSample
from sdk.sdk_loop.models.image import ImageType
# Luna API version used as the leading path segment of the relative resource locations
# built in this module (event avatars, image origins and sample references)
CURRENT_LUNA_API_VERSION = 6
@dataclass(withSlots=True)
class StorePolicyConfig:
    """Handler config that policies should apply."""

    # bucket name handed to the face sample policy
    facesBucket: str
    # bucket name handed to the body sample policy
    bodiesBucket: str
    # bucket name handed to the image origin policy
    originBucket: str
    # whether luna-events is in use; the event policy stores nothing when it is false
    lunaEventsUsage: bool
    # whether luna-sender is in use; the notification policy publishes nothing when it is false
    lunaSenderUsage: bool
async def saveSamples(
    warpsToSave: list[MetaWarp], bucket: str, accountId: str, storeApiClient: StoreApi
) -> None:
    """
    Save warps in LIS.

    One `putImage` request is created per warp and all of them are awaited together,
    so the uploads run concurrently. Requests are made with `raiseError=True`.

    Args:
        warpsToSave: SDK warps to save; each one provides `asBytes()` and `sampleId`
        bucket: bucket name
        accountId: account id
        storeApiClient: image-store client
    """
    uploads = [
        storeApiClient.putImage(
            imageInBytes=warp.asBytes(),
            imageId=warp.sampleId,
            accountId=accountId,
            bucketName=bucket,
            raiseError=True,
        )
        for warp in warpsToSave
    ]
    await asyncio.gather(*uploads)
async def saveAttributes(
    attributesToSave: list[FaceExtractedAttribute], accountId: str, facesClient: FacesApi, ttl: int
) -> None:
    """
    Save attributes.

    Builds one `putAttribute` request per attribute that has a descriptor and/or basic
    attributes and awaits all requests together. Attributes with neither are skipped.

    Args:
        attributesToSave: attributes
        accountId: account id
        facesClient: faces client
        ttl: time to store attribute
    """
    requests = []
    for attribute in attributesToSave:
        payload = {}

        descriptor = attribute.descriptor
        if descriptor is not None:
            payload["descriptors"] = [
                # SDK descriptors are wrapped into the client's raw form, anything else is sent as is
                RawDescriptor(version=descriptor.model, descriptor=descriptor.asBytes)
                if isinstance(descriptor, FaceDescriptor)
                else descriptor
            ]
            payload["descriptorSamples"] = attribute.sampleIds

        basicAttributes = attribute.basicAttributes
        if basicAttributes is not None:
            payload["basicAttributes"] = dict(
                ethnicity=ETHNIC_MAP[str(basicAttributes.ethnicity.predominantEthnicity)],
                age=round(basicAttributes.age),
                gender=round(basicAttributes.gender),
            )
            payload["basicAttributesSamples"] = attribute.sampleIds

        if not payload:
            continue  # no need to create empty attribute

        requests.append(
            facesClient.putAttribute(
                attributeId=attribute.attributeId,
                accountId=accountId,
                **payload,
                ttl=ttl,
                raiseError=True,
            )
        )
    await asyncio.gather(*requests)
class _BaseStoragePolicy(BaseSchema):
    """Common base of all storage policies: holds the policy filters and filtering helpers."""

    # storage policy complex filters (matching+attributes)
    filters: ComplexFilter = ComplexFilter()

    def filterEventsByFilters(self, events: list[Event]) -> list[Event]:
        """
        Select the events this policy should process.

        Args:
            events: all events
        Returns:
            the given list itself when no filters are set, otherwise a new list with
            only the events that satisfy the filters
        """
        policyFilters = self.filters
        if not policyFilters.isEmpty:
            return [candidate for candidate in events if policyFilters.isEventSatisfies(candidate)]
        return events

    def isEventSatisfyFilters(self, event: Event) -> bool:
        """Check a single event against the policy filters (always true when no filters are set)."""
        if self.filters.isEmpty:
            return True
        return self.filters.isEventSatisfies(event)
class FaceSamplePolicy(_BaseStoragePolicy):
    """Face sample policy"""

    # whether to store sample
    storeSample: types.Int01 = 1

    async def execute(self, events: list[Event], sources: list[ProcessedImageData], bucket: str, luna3Client: Client):
        """
        Save face samples.

        For every face detection of the filtered events a warp is registered for upload;
        the detection receives the new sample id and its image-store url, and the sample id
        is appended to the event face attributes (when the event has them).

        Args:
            events: events
            sources: origin image sources
            bucket: bucket name
            luna3Client: client
        """
        if not self.storeSample:
            return

        imagesMap = {source.image.origin.filename: source for source in sources}
        warps = []
        for event in self.filterEventsByFilters(events):
            for estimation in event.sdkEstimations:
                # the source image already carries a sample id - nothing new to store for it
                if imagesMap[estimation.filename].meta.sampleId:
                    continue
                face = estimation.detection.face
                if face is None:
                    continue
                sample = MetaWarp(face.sdkEstimation.warp)
                warps.append(sample)
                face.sampleId = sample.sampleId
                face.url = f"{luna3Client.lunaFaceSamplesStore.baseUri}/buckets/{bucket}/images/{sample.sampleId}"
                if event.faceAttributes:
                    event.faceAttributes.sampleIds.append(sample.sampleId)

        if not warps:
            return

        await saveSamples(
            warpsToSave=warps,
            bucket=bucket,
            # NOTE(review): the account id is taken from the first event - assumes one account per call
            accountId=events[0].meta.accountId,
            storeApiClient=luna3Client.lunaFaceSamplesStore,
        )
class BodySamplePolicy(_BaseStoragePolicy):
    """Body sample policy"""

    # whether to store sample
    storeSample: types.Int01 = 1

    async def execute(self, events: list[Event], sources: list[ProcessedImageData], bucket: str, luna3Client: Client):
        """
        Save body samples.

        For every body detection of the filtered events a warp is registered for upload;
        the detection receives the new sample id and its image-store url, and the sample id
        is appended to the event body attributes (when the event has them).

        Args:
            events: events
            sources: origin image sources
            bucket: bucket name
            luna3Client: client
        """
        if not self.storeSample:
            return

        imagesMap = {source.image.origin.filename: source for source in sources}
        warps = []
        for event in self.filterEventsByFilters(events=events):
            for estimation in event.sdkEstimations:
                # the source image already carries a sample id - nothing new to store for it
                if imagesMap[estimation.filename].meta.sampleId:
                    continue
                body = estimation.detection.body
                if body is None:
                    continue
                sample = MetaWarp(body.sdkEstimation.warp)
                warps.append(sample)
                body.sampleId = sample.sampleId
                body.url = f"{luna3Client.lunaBodySamplesStore.baseUri}/buckets/{bucket}/images/{sample.sampleId}"
                if event.bodyAttributes:
                    event.bodyAttributes.sampleIds.append(sample.sampleId)

        if not warps:
            return

        await saveSamples(
            warpsToSave=warps,
            bucket=bucket,
            # NOTE(review): the account id is taken from the first event - assumes one account per call
            accountId=events[0].meta.accountId,
            storeApiClient=luna3Client.lunaBodySamplesStore,
        )
class ImageOriginPolicy(_BaseStoragePolicy):
    """Image origin policy"""

    # whether to store origin image
    storeImage: types.Int01 = 0
    # use external reference as image origin
    useExternalReferences: types.Int01 = 1

    async def execute(
        self, events: list[Event], sources: list[ProcessedImageData], bucket: str, luna3Client: Client,
    ) -> None:
        """
        Save origin images.

        For each estimation without an image origin, either an external reference is
        assigned (source url for plain images, stored sample location for warps) or the
        source image body is uploaded to the image origin store under a fresh id.

        Args:
            events: events
            sources: origin image sources
            bucket: bucket name
            luna3Client: client
        Raises:
            RuntimeError: if a source image has an unsupported image type
        """
        if not self.storeImage:
            return
        if not events:
            return

        imagesMap = {source.image.origin.filename: source for source in sources}
        uploads = []
        for event in self.filterEventsByFilters(events=events):
            for estimation in event.sdkEstimations:
                source = imagesMap[estimation.filename]
                if estimation.imageOrigin:
                    continue

                reference = self._externalReference(estimation, source)
                if reference is None:
                    # no usable external reference - upload the image itself under a fresh id
                    imageId = str(uuid4())
                    uploads.append(
                        luna3Client.lunaImageOriginStore.putImage(
                            imageInBytes=source.image.origin.body,
                            imageId=imageId,
                            accountId=event.meta.accountId,
                            bucketName=bucket,
                            raiseError=True,
                        )
                    )
                    reference = f"/{CURRENT_LUNA_API_VERSION}/images/{imageId}"
                estimation.imageOrigin = reference
        await asyncio.gather(*uploads)

    def _externalReference(self, estimation, source) -> Optional[str]:
        """Return the external image origin for the estimation, or None when the image must be uploaded."""
        imageType = source.image.origin.imageType
        if imageType == ImageType.IMAGE:
            url = source.meta.url
            if self.useExternalReferences and url and len(url) <= MAX_IMAGE_ORIGIN_EXTERNAL_URL_LENGTH:
                return url
            return None
        if imageType == ImageType.FACE_WARP:
            detection, samplesKind = estimation.detection.face, "faces"
        elif imageType == ImageType.BODY_WARP:
            detection, samplesKind = estimation.detection.body, "bodies"
        else:
            raise RuntimeError("Unsupported image type")
        sampleId = detection.sampleId if detection else None
        if self.useExternalReferences and sampleId:
            return f"/{CURRENT_LUNA_API_VERSION}/samples/{samplesKind}/{sampleId}"
        return None
class AttributeStorePolicy(_BaseStoragePolicy):
    """Attribute store policy"""

    # whether to store attribute
    storeAttribute: types.Int01 = 0
    # attribute storage ttl
    ttl: types.IntAttributeTTL = Field(default_factory=lambda: HandlerSettings.defaultAttributeTTL)

    async def execute(self, events: list[Event], luna3Client: Client) -> None:
        """
        Save attributes.

        Every filtered event that has face attributes gets a fresh attribute id and url
        assigned to them before they are sent to the faces service.

        Args:
            events: events
            luna3Client: client
        """
        if not self.storeAttribute:
            return

        attributesToSave = []
        for event in self.filterEventsByFilters(events):
            attributes = event.faceAttributes
            if attributes is None:
                continue
            attributes.attributeId = str(uuid4())
            attributes.url = f"{luna3Client.lunaFaces.baseUri}/attributes/{attributes.attributeId}"
            attributesToSave.append(attributes)

        if not attributesToSave:
            return

        await saveAttributes(
            # NOTE(review): the account id is taken from the first event - assumes one account per call
            attributesToSave, accountId=events[0].meta.accountId, facesClient=luna3Client.lunaFaces, ttl=self.ttl
        )
class LinkToListsPolicy(_BaseStoragePolicy):
    """Link to lists policy schema"""

    # list id to link faces to
    listId: UUID
class FaceStoragePolicy(_BaseStoragePolicy):
    """Face store policy"""

    # whether to store face
    storeFace: types.Int01 = 0
    # whether to set face sample as avatar
    setSampleAsAvatar: types.Int01 = 1
    # face link to lists policy list
    linkToListsPolicy: list[LinkToListsPolicy] = Field([], max_items=types.MAX_POLICY_LIST_LENGTH)

    async def execute(self, events: list[Event], luna3Client: Client) -> None:
        """
        Execute face policy (with link to list policy).

        For every filtered event: resolves the lists to link, optionally sets the avatar
        from the first stored face sample, assigns a fresh face id and url to the event
        and creates the face (with the event's attributes, if any) in the faces service.
        All face creation requests are awaited together.

        Args:
            events: processing events
            luna3Client: luna3 client
        """
        if not self.storeFace:
            return

        futures = []
        for event in self.filterEventsByFilters(events=events):
            # a set removes duplicated list ids; the resulting order is arbitrary
            lists = list(
                {
                    str(linkListPolicy.listId)
                    for linkListPolicy in self.linkToListsPolicy
                    if linkListPolicy.filters is None or linkListPolicy.filters.isEventSatisfies(event)
                }
            )
            faceDetection = (
                event.sdkEstimations[0].detection.face.asDict(
                    estimationTargets=event.sdkEstimations[0].estimationTargets
                )
                if event.sdkEstimations and event.sdkEstimations[0].detection.face
                else None
            )
            if faceDetection:
                # only the first estimation is known to have a face here, so each estimation
                # is checked for a face before its sample id is read
                firstWarpId = next(
                    (
                        estimation.detection.face.sampleId
                        for estimation in event.sdkEstimations
                        if estimation.detection.face and estimation.detection.face.sampleId
                    ),
                    None,
                )
                if self.setSampleAsAvatar and firstWarpId is not None:
                    event.avatar = f"/{CURRENT_LUNA_API_VERSION}/samples/faces/{firstWarpId}"
            event.linkedLists = lists

            attributeKwargs = {}
            attribute = event.faceAttributes
            if attribute and attribute.basicAttributes:
                attributeKwargs.update(
                    basicAttributesSamples=attribute.sampleIds,
                    basicAttributes=dict(
                        ethnicity=ETHNIC_MAP[str(attribute.basicAttributes.ethnicity.predominantEthnicity)],
                        age=round(attribute.basicAttributes.age),
                        gender=round(attribute.basicAttributes.gender),
                    ),
                )
            if attribute and attribute.descriptor:
                attributeKwargs.update(
                    descriptorSamples=attribute.sampleIds,
                    descriptors=[
                        # SDK descriptors are wrapped into the client's raw form, anything else is sent as is
                        RawDescriptor(version=attribute.descriptor.model, descriptor=attribute.descriptor.asBytes)
                        if isinstance(attribute.descriptor, FaceDescriptor)
                        else attribute.descriptor
                    ],
                )

            event.faceId = str(uuid4())
            # NOTE(review): the attribute policy builds its url from `lunaFaces.baseUri`, while
            # `lunaFaces.origin` is used here - confirm this difference is intended
            event.faceUrl = f"{luna3Client.lunaFaces.origin}/faces/{event.faceId}"
            futures.append(
                luna3Client.lunaFaces.putFace(
                    faceId=event.faceId,
                    accountId=event.meta.accountId,
                    eventId=event.eventId,
                    userData=event.meta.userData,
                    externalId=event.meta.externalId,
                    listIds=lists or None,
                    avatar=event.avatar,
                    **attributeKwargs,
                    raiseError=True,
                )
            )
        await asyncio.gather(*futures)
class NotificationStoragePolicy(_BaseStoragePolicy):
    """Notification store policy"""

    # whether to send notification
    sendNotification: types.Int01 = 1

    async def execute(self, events: list[Event], redisContext: RedisContext, lunaSenderUsage: bool) -> None:
        """
        Save notifications

        Publishes the events that satisfy the policy filters through the redis context,
        together with the current request id. Nothing is published when luna sender is
        not used, the policy is switched off or no event passes the filters.

        Args:
            events: events
            redisContext: redis context
            lunaSenderUsage: use or not luna sender
        """
        if not (lunaSenderUsage and self.sendNotification):
            return

        eventsToSend = [event for event in events if self.isEventSatisfyFilters(event)]
        if eventsToSend:
            await redisContext.publish(eventsToSend, requestIdCtx.get())
class EventStoragePolicy(_BaseStoragePolicy):
    """Event store policy"""

    # whether to store event
    storeEvent: types.Int01 = 1
    # whether to wait events saving (response will be received only after events will be saved)
    waitSaving: types.Int01 = 1

    @classmethod
    async def onStartup(cls):
        """Init Policies: create the runner used to save events without waiting for the result."""
        cls.saveEventsAsyncRunner = AsyncRunner(100, closeTimeout=1)

    @classmethod
    async def onShutdown(cls):
        """Stop Policies: close the runner created in `onStartup`."""
        await cls.saveEventsAsyncRunner.close()

    async def execute(self, events: list[Event], luna3Client: Client, lunaEventsUsage: bool) -> None:
        """
        Save events.

        Every event that satisfies the policy filters gets its event url assigned and is
        sent to luna-events. With `waitSaving` set the request is awaited and made with
        `raiseError` enabled; otherwise it is handed to the background runner and a failed
        reply is only logged.

        Args:
            events: events
            luna3Client: client
            lunaEventsUsage: use or not luna events
        """
        # neither flag depends on a particular event, so they are checked once up front
        if not lunaEventsUsage or not self.storeEvent:
            return

        eventsToSend = []
        for event in events:
            if not self.isEventSatisfyFilters(event):
                continue
            event.eventUrl = f"{luna3Client.lunaEvents.baseUri}/events/{event.eventId}"
            eventsToSend.append(eventAsLuna3(event))
        if not eventsToSend:
            return

        waitSaving = bool(self.waitSaving)

        async def saveEvents():
            reply = await luna3Client.lunaEvents.saveEvents(
                eventsToSend, waitEventsSaving=waitSaving, raiseError=waitSaving
            )
            if not reply.success:
                logger.warning(
                    f"Failed save events to luna-event, receive response "
                    f"with status code {reply.statusCode}, body {reply.text}"
                )

        if waitSaving:
            await saveEvents()
        else:
            self.saveEventsAsyncRunner.runNoWait((saveEvents(),))
class StoragePolicy(BaseSchema):
    """Storage policy schema"""

    # face sample storage policy
    faceSamplePolicy: FaceSamplePolicy = FaceSamplePolicy()
    # body sample storage policy
    bodySamplePolicy: BodySamplePolicy = BodySamplePolicy()
    # image origin storage policy
    imageOriginPolicy: ImageOriginPolicy = ImageOriginPolicy()
    # attribute storage policy
    attributePolicy: AttributeStorePolicy = Field(default_factory=lambda: AttributeStorePolicy())
    # face storage policy
    facePolicy: FaceStoragePolicy = FaceStoragePolicy()
    # event storage policy
    eventPolicy: EventStoragePolicy = EventStoragePolicy()
    # notification storage policy
    notificationPolicy: NotificationStoragePolicy = NotificationStoragePolicy()

    async def execute(
        self,
        sources: list[ProcessedImageData],
        events: list[Event],
        config: StorePolicyConfig,
        luna3Client: Client,
        redisContext: RedisContext,
    ) -> HandlersMonitoringData:
        """
        Execute storage policy - save objects.

        Policies run in three stages; policies inside one stage run concurrently and the
        execution time of each policy is recorded in the returned monitoring data.

        Args:
            sources: origin image sources
            events: events to process
            config: app config
            luna3Client: client
            redisContext: redis context
        Returns:
            monitoring data
        """
        monitoringData = HandlersMonitoringData()

        async def timed(metric: str, policyRun) -> None:
            """Await one policy run, recording its duration under the given metric name."""
            with monitorTime(monitoringData.request, metric):
                await policyRun

        await asyncio.gather(
            timed(
                "face_sample_storage_policy_time",
                self.faceSamplePolicy.execute(events, sources, config.facesBucket, luna3Client),
            ),
            timed(
                "body_sample_storage_policy_time",
                self.bodySamplePolicy.execute(events, sources, config.bodiesBucket, luna3Client),
            ),
        )
        # save attribute and face only after executing previous policies (^^^ samples and images are updated here ^^^)
        await asyncio.gather(
            timed(
                "image_origin_storage_policy_time",
                self.imageOriginPolicy.execute(events, sources, config.originBucket, luna3Client),
            ),
            timed("face_attribute_storage_policy_time", self.attributePolicy.execute(events, luna3Client)),
            timed("face_storage_policy_time", self.facePolicy.execute(events, luna3Client)),
        )
        # save events only after executing previous policies (^^^ events are updated here ^^^)
        await asyncio.gather(
            timed(
                "event_storage_policy_time",
                self.eventPolicy.execute(events, luna3Client, config.lunaEventsUsage),
            ),
            timed(
                "notification_storage_policy_time",
                self.notificationPolicy.execute(events, redisContext, config.lunaSenderUsage),
            ),
        )
        return monitoringData