# Source code for luna_api.app.handlers.ws_proxy_handler

"""Module contains proxy handler for websockets"""
import asyncio
from contextlib import asynccontextmanager
from copy import deepcopy
from typing import Dict, Optional

import aiohttp
import ujson
from aiohttp import ClientConnectorError, ClientWebSocketResponse
from websockets import WebSocketCommonProtocol
from yarl import URL

from app.app import ApiRequest
from app.handlers.base_handler import (
    ATTRIBUTE_RELATIVE_URL,
    BODY_SAMPLE_RELATIVE_URL,
    EVENT_RELATIVE_URL,
    FACE_RELATIVE_URL,
    FACE_SAMPLE_RELATIVE_URL,
    BaseRequestHandler,
)
from crutches_on_wheels.cow.errors.errors import Error
from crutches_on_wheels.cow.errors.exception import VLException
from crutches_on_wheels.cow.web.handlers import WSBaseHandler


class WSProxyHandler(BaseRequestHandler, WSBaseHandler):
    """
    Handler for proxying a client websocket connection to luna-sender.

    Attributes:
        upstreamWs (Optional[ClientWebSocketResponse]): open websocket connection to sender
            (set by :meth:`wsSession`, None until then)
    """

    def __init__(self, request: ApiRequest, ws: WebSocketCommonProtocol):
        """
        Args:
            request: incoming api request (passed to the base handler)
            ws: client websocket (not stored by this constructor)
        """
        super().__init__(request)
        self.upstreamWs: Optional[ClientWebSocketResponse] = None

    def checkTokenPermissions(self) -> None:
        """
        Description see :func:`~BaseRequestHandler.checkTokenPermissions`.

        This override has an empty body, so it performs no checks.
        NOTE(review): confirm that replacing the base check with a no-op is intended.
        """

    @staticmethod
    def _replaceUrl(obj: Optional[dict], template: str, apiVersion, idPlaceholder: str, idKey: str) -> None:
        """
        Replace ``obj["url"]`` in place with ``template`` formatted with the api version and the object id.

        Nothing is changed when ``obj`` is None or ``obj["url"]`` is None.

        Args:
            obj: part of an event holding a "url" key and an id key (may be None)
            template: format string with an ``apiVersion`` field and an ``idPlaceholder`` field
            apiVersion: value substituted for ``apiVersion``
            idPlaceholder: name of the id field in ``template`` (e.g. "faceId")
            idKey: key of ``obj`` holding the id value (e.g. "face_id")
        """
        if obj is None or obj["url"] is None:
            return
        obj["url"] = template.format(apiVersion=apiVersion, **{idPlaceholder: obj[idKey]})

    def adoptEvent(self, sourceEvent: Dict) -> dict:
        """
        Rewrite the urls of an event received from luna-sender using the api ``*_RELATIVE_URL`` templates.

        Rewritten urls: the event itself, its face, its face attributes, and the face/body samples of every
        detection. A url is rewritten only when it is not None (the owning object may also be None).

        Args:
            sourceEvent: event dict from luna-sender; it is not modified

        Returns:
            deep copy of ``sourceEvent`` with rewritten urls
        """
        event = deepcopy(sourceEvent)
        apiVersion = self.app.ctx.apiVersion
        self._replaceUrl(event, EVENT_RELATIVE_URL, apiVersion, "eventId", "event_id")
        self._replaceUrl(event["face"], FACE_RELATIVE_URL, apiVersion, "faceId", "face_id")
        self._replaceUrl(
            event["face_attributes"], ATTRIBUTE_RELATIVE_URL, apiVersion, "attributeId", "attribute_id"
        )
        for detection in event["detections"]:
            samples = detection["samples"]
            self._replaceUrl(samples["face"], FACE_SAMPLE_RELATIVE_URL, apiVersion, "sampleId", "sample_id")
            self._replaceUrl(samples["body"], BODY_SAMPLE_RELATIVE_URL, apiVersion, "sampleId", "sample_id")
        return event

    async def get(self, ws: WebSocketCommonProtocol):
        """
        WS proxy handler: relay messages from the luna-sender connection to the client websocket.

        Every text message is parsed as JSON, its "event" is rewritten by :meth:`adoptEvent`
        and the result is sent to the client. Exceptions are logged rather than propagated.

        Args:
            ws: client websocket
        """
        try:
            # NOTE: aiohttp's async iterator stops on CLOSE/CLOSING/CLOSED messages, so the close branch
            # below is only reached by other message types (e.g. ERROR) arriving once upstream is closed.
            async for msg in self.upstreamWs:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    msgWithEvent = ujson.loads(msg.data)
                    msgWithEvent["event"] = self.adoptEvent(msgWithEvent["event"])
                    await ws.send(ujson.dumps(msgWithEvent))
                elif self.upstreamWs.closed:
                    # msg.extra is None for non-CLOSE messages; the client close reason must be a str
                    await ws.close(code=self.upstreamWs.close_code, reason=msg.extra or "")
                else:
                    raise ValueError(f'unexpected message type: "{msg}"')
        except asyncio.CancelledError:
            # cancellation is treated as a client disconnect and deliberately not re-raised
            self.logger.debug("Client has disconnected")
        except Exception:
            # a message is passed explicitly: stdlib Logger.exception requires one, and a TypeError
            # raised inside this handler would escape it
            self.logger.exception("failed to proxy websocket messages from sender")

    @asynccontextmanager
    async def wsSession(self, request: ApiRequest):
        """
        Web socket session. Validate input data and open connection to luna-sender before websocket handshake.

        While the session is active the upstream connection is available as ``self.upstreamWs``;
        it is closed when the session exits.

        Args:
            request: incoming api request

        Raises:
            VLException: 403 if the request has no account id or luna-sender usage is disabled,
                400 if luna-sender rejects the websocket handshake,
                500 if the connection to luna-sender fails
        """
        if not request.credentials.accountId:
            raise VLException(Error.AccountRequired4Subscription, 403, isCriticalError=False)
        # todo at LUNA-5822
        # NOTE(review): the header is set on self.request while request.headers is what gets forwarded below;
        # presumably both are the same object — confirm.
        self.request.headers["Luna-Account-Id"] = request.credentials.accountId
        if not self.config.additionalServicesUsage.lunaSender:
            raise VLException(Error.LunaSenderIsDisabled, statusCode=403, isCriticalError=False)

        # the account id is sent in the Luna-Account-Id header set above, so drop it from the forwarded query
        query = request.args.copy()
        query.pop("account_id", None)
        address = URL(f"ws://{URL(self.config.senderAddress.origin).raw_authority}")
        # `/` and `%` share precedence and bind left to right; yarl's `%` sets the query string
        upstreamUrl = address / str(self.config.senderAddress.apiVersion) / "ws" % query
        async with aiohttp.ClientSession() as clientSession:
            # only the connect call is guarded, so exceptions raised by the session body are not
            # misreported as sender connection errors
            try:
                upstreamWs = await clientSession.ws_connect(upstreamUrl, headers=request.headers)
            except aiohttp.WSServerHandshakeError as e:
                raise VLException(Error.AiohttpWSServerHandshakeError.format(e), 400, False) from e
            except ClientConnectorError as e:
                self.logger.exception("failed to connect to sender")
                raise VLException(Error.AiohttpClientConnectionError.format(e), 500, True) from e
            async with upstreamWs:
                self.upstreamWs = upstreamWs
                yield