"""
Module contains agent-example websocket provider.

The tasks of this module:

    - to provide possibility of comparison streams, its analytics and websocket subscribers

    - to provide possibilities of stream addition and removing

    - to provide possibilities of websocket subscription addition and removing

    - to execute sending required messages for all required websocket subscriptions

"""

import asyncio


class WSStreamMeta:
    """WS stream meta info"""

    __slots__ = ("accountId", "analyticIndexes")

    def __init__(self, analyticIndexes: list[int], accountId: str | None = None):
        self.accountId = accountId
        self.analyticIndexes = analyticIndexes


class WSProvider:
    """
    WS data provider which is intended to provide events from one or several streams to one or several subscribers
    Also it need check stream account id and analytics index affiliation before add subscribed (see `contains` method)
    """

    # stream id to connections map
    _websockets: dict[str, dict[int, set["WSConnection"]]]
    # stream id to account id map
    _streamsMeta: dict[str, WSStreamMeta]

    def __init__(self):
        self._websockets = {}
        self._streamsMeta = {}

    def contains(self, streamId: str, analyticIndex: int, accountId: str | None = None):
        """Check whether ws provider has knowledge about stream with specified id/account id"""
        if streamId not in self._streamsMeta:
            return False
        match = True
        if accountId is not None:
            match &= self._streamsMeta[streamId].accountId == accountId
        match &= analyticIndex in self._streamsMeta[streamId].analyticIndexes
        return match

    def wsBroadcast(
        self,
        stream_id: str,
        streamError: str | None = None,
        analyticsError: dict | None = None,
        analyticsResults: dict | list[dict] | None = None,
    ) -> None:
        """Broadcast message to all subscribers"""
        item2Send: dict
        if streamError is not None:
            item2Send = {"stream_status": "error", "error": streamError}
        elif analyticsError is not None:
            item2Send = {
                "stream_status": "in_progress",
                "error": None,
                "analytics_results": analyticsError,
            }
        elif analyticsResults is not None:
            item2Send = {
                "stream_status": "in_progress",
                "error": None,
                "analytics_results": analyticsResults,
            }
        else:
            raise NotImplementedError

        for analyticIdx, connections in self._websockets[stream_id].items():
            for connection in connections:
                asyncio.create_task(connection.send_json(item2Send))

    def addStream(self, stream: dict):
        """Add stream"""
        analyticIndexes = [row["idx"] for row in stream["analytics"]]
        streamId = stream["stream_id"]
        if streamId not in self._websockets:
            self._websockets[streamId] = {}
        for analyticIdx in analyticIndexes:
            if analyticIdx not in self._websockets[streamId]:
                self._websockets[streamId][analyticIdx] = set()
        if streamId not in self._streamsMeta:
            self._streamsMeta[streamId] = WSStreamMeta(accountId=stream["account_id"], analyticIndexes=analyticIndexes)

    def removeStream(self, streamId: str):
        """Remove stream"""
        if streamId in self._websockets:
            for analyticIdx, connections in self._websockets[streamId].items():
                for connection in connections:
                    asyncio.create_task(connection.send_json({"stream_status": "error", "error": "processing stopped"}))
        self._websockets.pop(streamId, None)
        self._streamsMeta.pop(streamId, None)

    def finishStream(self, streamId: str):
        """Remove stream from pool"""
        if streamId in self._websockets:
            for analyticIdx, connections in self._websockets[streamId].items():
                for connection in connections:
                    asyncio.create_task(connection.send_json({"stream_status": "finished", "error": None}))
        self._websockets.pop(streamId, None)
        self._streamsMeta.pop(streamId, None)

    def addSubscriber(self, streamId: str, analyticIndex: int, wsConnection):
        """Add subscriber for stream"""
        if streamId not in self._websockets:
            self._websockets[streamId] = {}
        if analyticIndex not in self._websockets[streamId]:
            self._websockets[streamId][analyticIndex] = set()
        self._websockets[streamId][analyticIndex].add(wsConnection)

    def removeSubscriber(self, streamId: str, analyticIndex: int, wsConnection):
        """Remove subscriber for stream"""
        if streamId in self._websockets:
            if analyticIndex in self._websockets[streamId]:
                self._websockets[streamId][analyticIndex].remove(wsConnection)

    async def close(self):
        coros = []
        for setOfAnalytics in self._websockets.values():
            for wsSet in setOfAnalytics.values():
                for ws in wsSet:
                    coros.append(ws.close(code=1001, reason="Server shutdown"))
        self._websockets = {}
        await asyncio.gather(*coros)
