import asyncio
from contextlib import asynccontextmanager
from typing import Callable, List
from websockets import ConnectionClosed, WebSocketCommonProtocol
from app.app import SenderRequest
from app.classes.ws_connection import Filters, WSConnection
from app.handlers.base_handler import BaseSenderRequestHandler
from configs.config import WSCODE_GOING_AWAY
from crutches_on_wheels.cow.errors.errors import Error
from crutches_on_wheels.cow.errors.exception import VLException
from crutches_on_wheels.cow.web.handlers import WSBaseHandler
from crutches_on_wheels.cow.web.query_getters import (
apparentGenderGetter,
backpackStatesGetter,
clothingColorGetter,
emotionsGetter,
ethnicGroupGetter,
float01Getter,
headwearStatesGetter,
int01Getter,
listStringsGetter,
listUUIDsGetter,
livenessGetter,
masksGetter,
sleeveLengthGetter,
)
from redis_db import redis_context
class WSHandler(BaseSenderRequestHandler, WSBaseHandler):
    """
    Handler for websocket connections.

    Before the handshake `wsSession` validates query parameters into `filters`; after the
    handshake `get` keeps the connection subscribed via `redisContext` until it ends.

    Attributes:
        filters: ws filters prepared before handshake (set by `prepareWSFilters`)
    """

    filters: Filters

    def __init__(self, request, ws: WebSocketCommonProtocol):
        # `ws` is not used here (the websocket is handed to `get` instead); presumably kept
        # because the caller constructs handlers with this signature — TODO confirm.
        super().__init__(request)

    async def get(self, ws: WebSocketCommonProtocol):
        """
        Serve a websocket session after the handshake.

        Wraps `ws` into a `WSConnection` built with the prepared filters, registers it in
        `app.ctx.websockets` and keeps it subscribed through `redisContext.subscribe` until the
        connection ends. Messages sent by the client are read and discarded. The registry
        entry is always removed when the session ends.

        Args:
            ws: established websocket connection
        """
        connection = WSConnection(ws, self.filters, self.requestId)
        registry = self.app.ctx.websockets
        registry[connection.subscriptionId] = connection
        try:
            async with self.redisContext.subscribe(connection, self.accountId, self.logger):
                while True:
                    # Client messages are ignored; the read loop only serves to detect
                    # disconnection (surfaced as an exception below).
                    await connection.wsResponse.recv()
        except ConnectionClosed as e:
            self.logger.error(f"WS connection closed with code {e.code}: {e.reason}")
        except asyncio.CancelledError:
            # NOTE(review): cancellation is logged and swallowed, not re-raised — confirm intended.
            self.logger.debug("Client disconnected session")
        except redis_context.RedisDisconnected:
            await connection.wsResponse.close(code=WSCODE_GOING_AWAY, reason="Server shutdown, try later")
        except Exception:
            # Catch only regular errors so KeyboardInterrupt/SystemExit/GeneratorExit propagate.
            self.logger.exception()
        finally:
            # Unregister the finished session so the registry does not grow without bound.
            # NOTE(review): if other code iterates `app.ctx.websockets` across awaits, it should
            # iterate over a copy.
            registry.pop(connection.subscriptionId, None)

    @asynccontextmanager
    async def wsSession(self, request: SenderRequest):
        """
        WSBaseHandler `wsSession` abstract method implementation.

        See `_websocket_handler` implementation:
        http://git.visionlabs.ru/luna/crutches_on_wheels.cow.-/blob/platform_5/web/application.py

        Validates input data before the websocket handshake.

        Args:
            request: incoming request (not used directly; filters are read from `self.request`)
        """
        await self.prepareWSFilters()
        yield

    async def prepareWSFilters(self):
        """
        WS request validation before handshake. Prepares filters from websocket request.

        Reads query parameters from `self.request.args` and stores the result in `self.filters`.

        Raises:
            VLException(Error.BadQueryParams, 400): if format of any filter value is invalid
        """

        def getFilter(name: str, validator: Callable):
            """
            Gets filter from request

            Args:
                name: filter name (query parameter name)
                validator: callable converting the raw value; raises ValueError on bad input

            Returns:
                valid filter value (lists are converted to sets), or None if not set

            Raises:
                VLException(Error.BadQueryParams, 400): if format of value is invalid
            """
            value = self.request.args.get(name)
            if value is None:
                return None
            try:
                validValue = validator(value)
            except ValueError as exc:
                raise VLException(Error.BadQueryParams.format(name), 400, isCriticalError=False) from exc
            # List-valued filters are stored as sets.
            if isinstance(validValue, list):
                return set(validValue)
            return validValue

        self.filters = Filters(
            handlers=getFilter("handler_ids", listUUIDsGetter),
            matchingCandidatesLabels=getFilter("matching_candidates_labels", listStringsGetter),
            objectSimilarityGte=getFilter("object_similarity__gte", float01Getter),
            objectSimilarityLt=getFilter("object_similarity__lt", float01Getter),
            ethnicGroups=getFilter("ethnic_groups", ethnicGroupGetter),
            gender=getFilter("gender", int01Getter),
            ageGte=getFilter("age__gte", int),
            ageLt=getFilter("age__lt", int),
            sources=getFilter("sources", listStringsGetter),
            tags=getFilter("tags", listStringsGetter),
            cities=getFilter("cities", listStringsGetter),
            areas=getFilter("areas", listStringsGetter),
            districts=getFilter("districts", listStringsGetter),
            streets=getFilter("streets", listStringsGetter),
            houseNumbers=getFilter("house_numbers", listStringsGetter),
            liveness=getFilter("liveness", livenessGetter),
            masks=getFilter("masks", masksGetter),
            emotions=getFilter("emotions", emotionsGetter),
            apparentGender=getFilter("apparent_gender", apparentGenderGetter),
            apparentAgeGte=getFilter("apparent_age__gte", int),
            apparentAgeLt=getFilter("apparent_age__lt", int),
            headwearStates=getFilter("headwear_states", headwearStatesGetter),
            backpackStates=getFilter("backpack_states", backpackStatesGetter),
            sleeveLengths=getFilter("sleeve_lengths", sleeveLengthGetter),
            upperClothingColors=getFilter("upper_clothing_colors", clothingColorGetter),
        )