Source code for luna_sender.app.handlers.ws_handler

import asyncio
from contextlib import asynccontextmanager
from typing import Callable, List

from websockets import ConnectionClosed, WebSocketCommonProtocol

from app.app import SenderRequest
from app.classes.ws_connection import Filters, WSConnection
from app.handlers.base_handler import BaseSenderRequestHandler
from configs.config import WSCODE_GOING_AWAY
from crutches_on_wheels.cow.errors.errors import Error
from crutches_on_wheels.cow.errors.exception import VLException
from crutches_on_wheels.cow.web.handlers import WSBaseHandler
from crutches_on_wheels.cow.web.query_getters import (
    apparentGenderGetter,
    backpackStatesGetter,
    clothingColorGetter,
    emotionsGetter,
    ethnicGroupGetter,
    float01Getter,
    headwearStatesGetter,
    int01Getter,
    listStringsGetter,
    listUUIDsGetter,
    livenessGetter,
    masksGetter,
    sleeveLengthGetter,
)
from redis_db import redis_context


class WSHandler(BaseSenderRequestHandler, WSBaseHandler):
    """
    Handler for creating ws connection

    Attributes:
        filters: ws filters prepared before handshake (set by `prepareWSFilters`)
    """

    def __init__(self, request, ws: WebSocketCommonProtocol):
        # `ws` is accepted to keep the constructor signature, but it is not used here;
        # the socket is handed to `get` instead.
        super().__init__(request)

    async def get(self, ws: WebSocketCommonProtocol):
        """
        WS handshake.

        Wraps the socket into a `WSConnection`, registers it in `app.ctx.websockets`
        under its subscription id and keeps reading from the socket while the redis
        subscription context is open. The method returns only after the read loop
        is interrupted by an exception.

        Args:
            ws: websocket to serve
        """
        connection = WSConnection(ws, self.filters, self.requestId)
        # NOTE(review): the connection is registered here but never removed in this
        # method - presumably cleanup happens elsewhere; confirm.
        self.app.ctx.websockets[connection.subscriptionId] = connection
        try:
            async with self.redisContext.subscribe(connection, self.accountId, self.logger):
                # Received messages are discarded; the loop ends only via an exception.
                while True:
                    await connection.wsResponse.recv()
        except ConnectionClosed as e:
            self.logger.error(f"WS connection closed with code {e.code}: {e.reason}")
        except asyncio.CancelledError:
            self.logger.debug("Client disconnected session")
        except redis_context.RedisDisconnected:
            await connection.wsResponse.close(code=WSCODE_GOING_AWAY, reason="Server shutdown, try later")
        except Exception:
            # `Exception` instead of a bare `except:` so that KeyboardInterrupt/SystemExit
            # are not swallowed; the explicit message is required by the standard
            # `logging.Logger.exception` signature.
            self.logger.exception("Unexpected error in WS connection")

    @asynccontextmanager
    async def wsSession(self, request: SenderRequest):
        """
        WSBaseHandler `wsSession` abstract method implementation.
        See `_websocket_handler` implementation:
        http://git.visionlabs.ru/luna/crutches_on_wheels.cow.-/blob/platform_5/web/application.py

        Validates input data before websocket handshake.

        Args:
            request: incoming request (not used directly; filters are read from `self.request`)
        """
        await self.prepareWSFilters()
        yield

    async def prepareWSFilters(self):
        """
        WS request validation before handshake.

        Prepares filters from websocket request and stores them in `self.filters`.

        Raises:
            VLException(Error.BadQueryParams, 400): if format of any filter value is invalid
        """

        def getFilter(name: str, validator: Callable):
            """
            Gets filter from request

            Args:
                name: filter name
                validator: validator
            Returns:
                valid filter value (lists are converted to sets), or None if not set
            Raises:
                VLException(Error.BadQueryParams, 400): if format of value is invalid
            """
            value = self.request.args.get(name)
            if value is None:
                return None
            try:
                validValue = validator(value)
            except ValueError as err:
                # Keep the original validation error as the explicit cause.
                raise VLException(Error.BadQueryParams.format(name), 400, isCriticalError=False) from err
            if isinstance(validValue, list):
                return set(validValue)
            return validValue

        self.filters = Filters(
            handlers=getFilter("handler_ids", listUUIDsGetter),
            matchingCandidatesLabels=getFilter("matching_candidates_labels", listStringsGetter),
            objectSimilarityGte=getFilter("object_similarity__gte", float01Getter),
            objectSimilarityLt=getFilter("object_similarity__lt", float01Getter),
            ethnicGroups=getFilter("ethnic_groups", ethnicGroupGetter),
            gender=getFilter("gender", int01Getter),
            ageGte=getFilter("age__gte", int),
            ageLt=getFilter("age__lt", int),
            sources=getFilter("sources", listStringsGetter),
            tags=getFilter("tags", listStringsGetter),
            cities=getFilter("cities", listStringsGetter),
            areas=getFilter("areas", listStringsGetter),
            districts=getFilter("districts", listStringsGetter),
            streets=getFilter("streets", listStringsGetter),
            houseNumbers=getFilter("house_numbers", listStringsGetter),
            liveness=getFilter("liveness", livenessGetter),
            masks=getFilter("masks", masksGetter),
            emotions=getFilter("emotions", emotionsGetter),
            apparentGender=getFilter("apparent_gender", apparentGenderGetter),
            apparentAgeGte=getFilter("apparent_age__gte", int),
            apparentAgeLt=getFilter("apparent_age__lt", int),
            headwearStates=getFilter("headwear_states", headwearStatesGetter),
            backpackStates=getFilter("backpack_states", backpackStatesGetter),
            sleeveLengths=getFilter("sleeve_lengths", sleeveLengthGetter),
            upperClothingColors=getFilter("upper_clothing_colors", clothingColorGetter),
        )