# Source code for luna_api.app.handlers.ws_proxy_handler
"""Module contains proxy handler for websockets"""
import asyncio
from contextlib import asynccontextmanager
from copy import deepcopy
from typing import Dict, Optional
import aiohttp
import ujson
from aiohttp import ClientConnectorError, ClientWebSocketResponse
from websockets import WebSocketCommonProtocol
from yarl import URL
from app.app import ApiRequest
from app.handlers.base_handler import (
    ATTRIBUTE_RELATIVE_URL,
    BODY_SAMPLE_RELATIVE_URL,
    EVENT_RELATIVE_URL,
    FACE_RELATIVE_URL,
    FACE_SAMPLE_RELATIVE_URL,
    BaseRequestHandler,
)
from crutches_on_wheels.errors.errors import Error
from crutches_on_wheels.errors.exception import VLException
from crutches_on_wheels.web.handlers import WSBaseHandler
class WSProxyHandler(BaseRequestHandler, WSBaseHandler):
    """
    Handler for proxying ws connection

    Text messages read from the upstream websocket are decoded, their ``event`` payload is rewritten by
    ``adoptEvent`` and the result is sent to the client websocket.

    Attributes:
        upstreamWs (Optional[ClientWebSocketResponse]): open websocket connection to sender
    """

    def __init__(self, request: ApiRequest, ws: WebSocketCommonProtocol):
        """
        Initialize the handler.

        Args:
            request: incoming request
            ws: client websocket (not used by this constructor)
        """
        super().__init__(request)
        # set by `wsSession` once the upstream connection is established
        self.upstreamWs: Optional[ClientWebSocketResponse] = None

    def adoptEvent(self, sourceEvent: Dict) -> dict:
        """
        Build a copy of an event whose urls are replaced with urls formatted from the relative url templates.

        The source event is not modified: all changes are applied to a deep copy.

        Args:
            sourceEvent: event received from the upstream websocket
        Returns:
            copy of the event with rewritten `url` fields
        """
        event = deepcopy(sourceEvent)
        apiVersion = self.app.ctx.apiVersion

        if event["url"] is not None:
            event["url"] = EVENT_RELATIVE_URL.format(apiVersion=apiVersion, eventId=event["event_id"])

        face = event["face"]
        # NOTE: the face url is checked for truthiness (not only for None), unlike the other urls
        if face is not None and face["url"]:
            face["url"] = FACE_RELATIVE_URL.format(apiVersion=apiVersion, faceId=face["face_id"])

        faceAttributes = event["face_attributes"]
        if faceAttributes is not None and faceAttributes["url"] is not None:
            faceAttributes["url"] = ATTRIBUTE_RELATIVE_URL.format(
                apiVersion=apiVersion,
                attributeId=faceAttributes["attribute_id"],
            )

        for detection in event["detections"]:
            samples = detection["samples"]
            faceSample = samples["face"]
            if faceSample is not None and faceSample["url"] is not None:
                faceSample["url"] = FACE_SAMPLE_RELATIVE_URL.format(
                    apiVersion=apiVersion,
                    sampleId=faceSample["sample_id"],
                )
            bodySample = samples["body"]
            if bodySample is not None and bodySample["url"] is not None:
                bodySample["url"] = BODY_SAMPLE_RELATIVE_URL.format(
                    apiVersion=apiVersion,
                    sampleId=bodySample["sample_id"],
                )
        return event

    async def get(self, ws: WebSocketCommonProtocol):
        """
        WS proxy handler.

        Forward messages from the upstream websocket (`self.upstreamWs`) to the client websocket until
        the upstream iteration ends. Errors are logged and not re-raised.

        Args:
            ws: client websocket
        """
        try:
            async for msg in self.upstreamWs:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    msgWithEvent = ujson.loads(msg.data)
                    msgWithEvent["event"] = self.adoptEvent(msgWithEvent["event"])
                    await ws.send(ujson.dumps(msgWithEvent))
                elif self.upstreamWs.closed:
                    # `msg.extra` may be empty for non-close frames; a close reason is expected to be a string
                    await ws.close(code=self.upstreamWs.close_code, reason=msg.extra or "")
                else:
                    raise ValueError(f'unexpected message type: "{msg}"')
        except asyncio.CancelledError:
            self.logger.debug("Client has disconnected")
        except Exception:
            # `exception` requires a message; the traceback of the active exception is attached by the logger
            self.logger.exception("failed to proxy websocket messages")

    @asynccontextmanager
    async def wsSession(self, request: ApiRequest):
        """
        Web socket session.
        Validate input data and open connection to luna-sender before websocket handshake.

        The upstream connection is stored in `self.upstreamWs` and stays open while the context is active.

        Args:
            request: incoming request; its query args and headers are passed to the upstream connection
        Raises:
            VLException(Error.AccountRequired4Subscription): request has no account id
            VLException(Error.LunaSenderIsDisabled): luna-sender usage is disabled in the config
            VLException(Error.AiohttpWSServerHandshakeError): upstream websocket handshake failed
            VLException(Error.AiohttpClientConnectionError): connection to sender failed
        """
        if not request.accountId:
            raise VLException(Error.AccountRequired4Subscription, 403, isCriticalError=False)
        if not self.config.additionalServicesUsage.lunaSender:
            raise VLException(Error.LunaSenderIsDisabled, statusCode=403, isCriticalError=False)

        query = request.args.copy()
        # only the authority (host[:port]) of the configured origin is reused; the scheme is always `ws`
        address = URL(f"ws://{URL(self.config.senderAddress.origin).raw_authority}")
        upstreamUrl = address / str(self.config.senderAddress.apiVersion) / "ws" % query

        async with aiohttp.ClientSession() as clientSession:
            try:
                async with clientSession.ws_connect(upstreamUrl, headers=request.headers) as upstreamWs:
                    self.upstreamWs = upstreamWs
                    yield
            except aiohttp.WSServerHandshakeError as e:
                raise VLException(Error.AiohttpWSServerHandshakeError.format(e), 400, False) from e
            except ClientConnectorError as e:
                self.logger.exception("failed to connect to sender")
                raise VLException(Error.AiohttpClientConnectionError.format(e), 500, True) from e