import contextlib
import uuid
from collections import defaultdict
from datetime import datetime, timedelta
from itertools import chain
from typing import Any, Awaitable, Iterable, List, Literal, Optional, Tuple, TypeVar, Union
import shapely.geometry
import ujson
import ujson as json
from asyncpg import UniqueViolationError
from geo_utils.types import Geometry
from sqlalchemy import Interval, and_, asc, delete, desc, func, insert, or_, select, text, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Query
from sqlalchemy.sql.elements import BooleanClauseList
from sqlalchemy.sql.expression import literal, union_all
from app.handlers.classes.enums import AutoRestartStatus, LogTarget, StreamStatus
from app.handlers.classes.filters import StreamSearchFilters
from configs.config import BG_BATCH_LIMIT
from crutches_on_wheels.cow.adaptors.connection import AbstractDBConnection
from crutches_on_wheels.cow.db.base_context import BaseDBContext
from crutches_on_wheels.cow.enums.databases import Databases
from crutches_on_wheels.cow.errors.errors import Error
from crutches_on_wheels.cow.errors.exception import VLException
from crutches_on_wheels.cow.utils import mixins
from crutches_on_wheels.cow.utils.check_connection import checkConnectionToDB
from crutches_on_wheels.cow.utils.db_functions import currentDBTimestamp
from crutches_on_wheels.cow.utils.healthcheck import checkSql, checkSqlMigration
from crutches_on_wheels.cow.utils.log import Logger, logger
from db.exceptions import ContextException, exceptionWrap
from db.helpers import StreamInfo
from db.loaders import makeOutputGroups, makeOutputLogs, prepareStreamV1, prepareStreamV2
from db.streams_db_tools.models import streams_db_models as models
from db.streams_db_tools.models.config import DBConfig
from schemas.groups import NewGroupModel, PatchGroupModel
from schemas.stream import DeleteStreamsFilters, NewStreamModel, StreamPatchingDataModel
from schemas.stream_v2 import NewStreamModelV2
from schemas.streams_feedback import OneStreamFeedback
# prevents sphinx warning
T_SDO_GEOMETRY = TypeVar("T_SDO_GEOMETRY")
# fatal stream errors which prevent stream autorestart
FATAL_ERRORS = ["Failed to authorize in Luna Platform"]
class DBContext(BaseDBContext, mixins.Initializable):
"""
DB context
"""
#: sdo geometry object for geo_position | required for oracle db only
_sdoGeometryType = None
    #: sdo point object for geo_position | required for oracle db only
_sdoGeometryPoint = None
    def __init__(self, logger: Logger, dbSettings: Optional["DBSetting"] = None, storageTime: Optional[str] = None):
super().__init__(logger)
self.dbSettings = dbSettings
self.storageTime = storageTime
async def initialize(self):
"""Initialize context"""
if self.dbSettings and self.storageTime:
DBConfig.initialize(self.dbSettings.type)
await DBContext.initDBContext(
dbSettings=self.dbSettings,
storageTime=self.storageTime,
)
if DBConfig.dbType == Databases.ORACLE.value:
await DBContext.initSDO()
async def probe(self) -> bool:
"""
        Ensure the provided config is valid by creating a new connection. Can be used without initialization.
        Supports the mixins.Initializable protocol.
"""
return await checkConnectionToDB(dbSetting=self.dbSettings, postfix="events", asyncCheck=True)
    async def close(self):
"""Close context"""
await DBContext.closeDBContext()
def getRuntimeChecks(
self, includeLunaServices: bool = False
) -> List[Tuple[str, Awaitable]]: # pylint: disable-msg=W0613
"""
Returns configured system checks, pairs of (name, coroutine).
Args:
includeLunaServices: A bool, whether to return checks for luna services.
"""
checks = [
("db", checkSql(self.adaptor)),
("db_migration", checkSqlMigration(self.adaptor)),
]
return checks
@classmethod
async def initSDO(cls) -> None:
"""
Initialize `SDO_GEOMETRY` to improve insertion speed | Required only for oracle database
See cx_oracle `samples/insert_geometry.py` for details
"""
async with DBContext.adaptor.connection(logger) as connection:
cls._sdoGeometryType = connection.connection.connection.connection.gettype("MDSYS.SDO_GEOMETRY")
cls._sdoGeometryPoint = connection.connection.connection.connection.gettype("MDSYS.SDO_POINT_TYPE")
def splitList(self, values: list, divider: int = 1) -> list[list]:
"""
        Split a list into several lists, respecting the maximum number of SQL query parameters.
        Args:
            values: list of values
            divider: additional divider for the split (useful when several parameters are bound per value)
        Returns:
            list of value batches
"""
if not values:
return []
maxLength = 1000 if self.dbType == Databases.ORACLE.value else 32765
maxLength = int(maxLength / divider)
return [values[x : x + maxLength] for x in range(0, len(values), maxLength)]
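    # A minimal usage sketch (hypothetical caller code, not part of this class): batching a
    # large id list so that each query stays under the backend's bind-parameter limit.
    #
    #     ids = [str(uuid.uuid4()) for _ in range(5000)]
    #     for batch in context.splitList(ids):
    #         selectSt = select([models.Stream.stream_id]).where(models.Stream.stream_id.in_(batch))
    #         rows = await connection.fetchall(selectSt)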
@staticmethod
def prepareSearchQueryFilters(filters: StreamSearchFilters) -> BooleanClauseList:
"""
Prepare search query filters on "Stream" model.
Args:
filters: query filters
Returns:
filters for select query
"""
sqlFilters = [
models.Stream.account_id == filters.accountId if filters.accountId is not None else True,
models.Stream.stream_id.in_(filters.streamIds) if filters.streamIds is not None else True,
models.Stream.name.in_(filters.streamNames) if filters.streamNames is not None else True,
models.Stream.type.in_(filters.streamTypes) if filters.streamTypes is not None else True,
models.Stream.reference == filters.reference if filters.reference is not None else True,
(
models.Stream.status.in_([StreamStatus[x].value for x in filters.statuses])
if filters.statuses is not None
else True
),
models.Stream.create_time >= filters.createTimeGte if filters.createTimeGte else True,
models.Stream.create_time < filters.createTimeLt if filters.createTimeLt else True,
models.Stream.stream_id >= filters.streamIdGte if filters.streamIdGte is not None else True,
models.Stream.stream_id < filters.streamIdLt if filters.streamIdLt is not None else True,
]
return and_(*sqlFilters)
def getGeoPosition(self, geoPosition: dict) -> Union[Geometry, T_SDO_GEOMETRY]:
"""
Get geo position prepared for insertion to database
Args:
geoPosition: raw dict with latitude and longitude
Returns:
`Geometry` object for postgres database or `SDO_GEOMETRY` object for oracle database
"""
if self.dbType == Databases.ORACLE.value:
sdo = self._sdoGeometryType.newobject()
sdoPoint = self._sdoGeometryPoint.newobject()
sdoPoint.X = geoPosition["longitude"]
sdoPoint.Y = geoPosition["latitude"]
sdo.SDO_POINT = sdoPoint
return sdo
return Geometry(shapely.geometry.Point(geoPosition["longitude"], geoPosition["latitude"]))
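    # Illustrative call (the input shape follows the dict access above): returns a
    # shapely-backed `Geometry` on postgres and an SDO_GEOMETRY object on oracle.
    #
    #     geo = context.getGeoPosition({"longitude": 37.62, "latitude": 55.75})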
@exceptionWrap
async def createStream(
self,
stream: NewStreamModel,
streamId: Optional[str] = None,
version: int = 1,
dbConnection: Optional[AbstractDBConnection] = None,
) -> str:
"""
Save stream.
Args:
stream: stream to save
streamId: optional id to put stream by
version: optional stream version
dbConnection: stream database connection, if need to do something within the transaction
Returns:
stream id in uuid4 format
"""
async with contextlib.AsyncExitStack() as stack:
connection = (
await stack.enter_async_context(DBContext.adaptor.connection(self.logger))
if dbConnection is None
else dbConnection
)
insertStreamSt = (
insert(models.Stream)
.values(
stream_id=streamId or str(uuid.uuid4()),
account_id=stream.accountId,
name=stream.name,
description=stream.description,
type=stream.data.type,
mask=stream.data.mask if stream.data.type == "images" else None,
endless=stream.data.endless if stream.data.type in ("tcp", "udp") else None,
reference=stream.data.reference,
roi=json.dumps(stream.data.roi) if isinstance(stream.data.roi, list) else stream.data.roi.json(),
droi=(
json.dumps(stream.data.droi) if isinstance(stream.data.droi, list) else stream.data.droi.json()
),
rotation=stream.data.rotation,
preferred_program_stream_frame_width=stream.data.preferredProgramStreamFrameWidth,
status=StreamStatus[stream.status].value,
version=version,
api_version=1,
)
.returning(models.Stream.stream_id)
)
streamId = await connection.scalar(insertStreamSt)
if stream.groupName is not None or stream.groupId is not None:
groupSeqId, groupName = await self._getGroupIdAndName(
connection, stream.groupName, stream.groupId, accountId=stream.accountId, block=True
)
insertStreamGroupSt = insert(models.GroupStream).values(
id=groupSeqId, group_name=groupName, stream_id=streamId
)
await connection.execute(insertStreamGroupSt)
            insertHandlerSt = insert(models.Handler).values(
stream_id=streamId,
event_handler=stream.eventHandler.json(by_alias=True, exclude_none=True),
policies=stream.policies.json(by_alias=True, exclude_none=True),
)
            await connection.execute(insertHandlerSt)
if locationData := stream.location.dict(by_alias=True, exclude_none=True):
if rawGeoPosition := locationData.get("geo_position"):
locationData["geo_position"] = self.getGeoPosition(rawGeoPosition)
insertLocationSt = insert(models.Location).values(stream_id=streamId, **locationData)
await connection.execute(insertLocationSt)
insertRestartSt = insert(models.Restart).values(
stream_id=streamId,
status=AutoRestartStatus(stream.autorestart.restart).value,
**stream.autorestart.dict(by_alias=True),
)
await connection.execute(insertRestartSt)
insertLogSt = insert(models.Log).values(
stream_id=streamId, status=StreamStatus[stream.status].value, stream_version=version
)
await connection.execute(insertLogSt)
return streamId
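    # A hedged usage sketch (`newStream` is an assumed, populated NewStreamModel): create a
    # stream standalone, or inside an outer transaction by passing `dbConnection`.
    #
    #     streamId = await context.createStream(stream=newStream)
    #     async with DBContext.adaptor.connection(logger) as connection:
    #         await context.createStream(stream=newStream, dbConnection=connection)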
@exceptionWrap
async def createStreamV2(
self,
stream: NewStreamModelV2,
streamId: Optional[str] = None,
version: int = 1,
dbConnection: Optional[AbstractDBConnection] = None,
) -> str:
"""
Save stream.
Args:
stream: stream to save in V2 format
streamId: optional id to put stream by
version: optional stream version
dbConnection: stream database connection, if need to do something within the transaction
Returns:
stream id in uuid4 format
"""
async with contextlib.AsyncExitStack() as stack:
connection = (
await stack.enter_async_context(DBContext.adaptor.connection(self.logger))
if dbConnection is None
else dbConnection
)
insertStreamSt = (
insert(models.Stream)
.values(
stream_id=streamId or str(uuid.uuid4()),
account_id=stream.accountId,
name=stream.name,
description=stream.description,
type=stream.data.type,
mask=stream.data.mask if stream.data.type == "images" else None,
endless=stream.data.endless if stream.data.type in ("tcp", "udp") else None,
reference=stream.data.reference,
roi=(
json.dumps(stream.data.roi)
                    if isinstance(stream.data.roi, list)
else stream.data.roi.json()
),
droi=(
json.dumps(stream.data.analytics[0].droi)
if isinstance(stream.data.analytics[0].droi, list)
else stream.data.analytics[0].droi.json()
),
rotation=stream.data.rotation,
status=StreamStatus[stream.status].value,
version=version,
preferred_program_stream_frame_width=stream.data.preferredProgramStreamFrameWidth,
api_version=2,
)
.returning(models.Stream.stream_id)
)
streamId = await connection.scalar(insertStreamSt)
if stream.groupName is not None or stream.groupId is not None:
groupSeqId, groupName = await self._getGroupIdAndName(
connection, stream.groupName, stream.groupId, accountId=stream.accountId, block=True
)
insertStreamGroupSt = insert(models.GroupStream).values(
id=groupSeqId, group_name=groupName, stream_id=streamId
)
await connection.execute(insertStreamGroupSt)
combinedAnalytics = {
"frame_processing_mode": stream.data.frameProcessingMode,
"ffmpeg_threads_number": stream.data.ffmpegThreadsNumber,
"real_time_mode_fps": stream.data.realTimeModeFps,
"analytics": [analytic.asDict() for analytic in stream.data.analytics],
}
insertHandlerSt = insert(models.Handler).values(
stream_id=streamId,
event_handler=stream.data.analytics[0].eventHandler.json(),
policies=json.dumps(combinedAnalytics),
)
await connection.execute(insertHandlerSt)
if locationData := stream.location.dict(by_alias=True, exclude_none=True):
if rawGeoPosition := locationData.get("geo_position"):
locationData["geo_position"] = self.getGeoPosition(rawGeoPosition)
insertLocationSt = insert(models.Location).values(stream_id=streamId, **locationData)
await connection.execute(insertLocationSt)
insertRestartSt = insert(models.Restart).values(
stream_id=streamId,
status=AutoRestartStatus(stream.autorestart.restart).value,
**stream.autorestart.dict(by_alias=True),
)
await connection.execute(insertRestartSt)
insertLogSt = insert(models.Log).values(
stream_id=streamId, status=StreamStatus[stream.status].value, stream_version=version
)
await connection.execute(insertLogSt)
return streamId
@exceptionWrap
async def blockStreams(self, connection, streamIds: list[str], accountId: Optional[str] = None) -> list[str]:
"""
Block streams for update.
Args:
connection: current connection
streamIds: stream ids
accountId: account id
Returns:
ids of streams blocked
"""
blockSt = (
select([models.Stream.stream_id])
.where(
and_(
models.Stream.stream_id.in_(streamIds),
models.Stream.account_id == accountId if accountId is not None else True,
)
)
.order_by(models.Stream.stream_id)
.with_for_update()
)
return await connection.fetchall(blockSt)
@exceptionWrap
async def checkStreamAllowedForUpdate(self, connection, streamId: str, newStatus: Optional[str] = None):
"""
Check stream is allowed for update with new status
Args:
connection: connection to database
streamId: stream id
newStatus: new status for stream
Raises:
VLException(Error.UnableToStopProcessing.format(streamId), 400, False) if unable to stop stream
VLException(Error.UnableToCancelProcessing.format(streamId), 400, False) if unable to cancel stream
"""
if not newStatus:
return
selectSt = select([models.Stream.status, models.Stream.type]).where(models.Stream.stream_id == streamId)
dbReply = await connection.fetchone(selectSt)
streamStatus, streamType = dbReply["status"], dbReply["type"]
if all(
(
StreamStatus[newStatus] == StreamStatus.pause,
StreamStatus(streamStatus) == StreamStatus.in_progress,
streamType == "videofile",
)
):
raise VLException(Error.UnableToStopProcessing.format(streamId), 400, False)
if all(
(
StreamStatus[newStatus] == StreamStatus.cancel,
StreamStatus(streamStatus) in (StreamStatus.done, StreamStatus.failure),
)
):
raise VLException(Error.UnableToCancelProcessing.format(streamId), 400, False)
@exceptionWrap
async def updateStream(self, streamId: str, streamData: StreamPatchingDataModel) -> bool:
"""
Update stream.
Args:
streamId: stream id
streamData: stream data to update
Returns:
True if stream was patched
"""
async with DBContext.adaptor.connection(self.logger) as connection:
if await self.blockStreams(connection, streamIds=[streamId]):
await self.checkStreamAllowedForUpdate(connection, streamId=streamId, newStatus=streamData.status)
values = {}
if streamData.description is not None:
values["description"] = streamData.description
if streamData.status is not None:
values["status"] = StreamStatus[streamData.status].value
values["status_last_update_time"] = self.currentDBTimestamp
updateStreamSt = update(models.Stream).where(models.Stream.stream_id == streamId).values(**values)
await connection.execute(updateStreamSt)
if streamData.status:
insertLogSt = insert(models.Log).from_select(
("stream_id", "status", "stream_version"),
select([models.Stream.stream_id, models.Stream.status, models.Stream.version]).where(
models.Stream.stream_id == streamId
),
)
await connection.execute(insertLogSt)
return True
return False
@exceptionWrap
async def putStream(self, streamId: str, stream: NewStreamModel | NewStreamModelV2, streamFormat: int = 1) -> int:
"""
Put stream.
There are 3 cases:
1) Replacing existing stream
2) Creating new stream
3) Recreating new stream
Args:
streamId: id to put stream by
stream: stream to put
streamFormat: stream data format version
Raises:
VLException(Error.StreamNotFound.format(streamId), 404, False) if the stream was deleted
Returns:
stream current version
"""
async with DBContext.adaptor.connection(self.logger) as connection:
if await self.blockStreams(connection, streamIds=[streamId], accountId=stream.accountId):
streamVersion = await connection.scalar(
select([models.Stream.version]).where(models.Stream.stream_id == streamId)
)
                # a stream with the given id exists now or existed in the past
newStreamVersion = streamVersion + 1
await connection.execute(delete(models.Stream).where(models.Stream.stream_id == streamId))
# restarting existing stream
insertLogSt = insert(models.Log).values(
stream_id=streamId, status=StreamStatus.restart.value, stream_version=newStreamVersion
)
await connection.execute(insertLogSt)
createMethod = self.createStream
if streamFormat == 2:
createMethod = self.createStreamV2
await createMethod(stream=stream, streamId=streamId, version=newStreamVersion, dbConnection=connection)
return newStreamVersion
# stream has been deleted or not created yet
raise VLException(Error.StreamNotFound.format(streamId), 404, False)
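    # Version semantics sketch (hypothetical values): putting over an existing id deletes the
    # old row, writes a `restart` log entry and recreates the stream with version + 1; putting
    # to an unknown or deleted id raises StreamNotFound.
    #
    #     version = await context.putStream(streamId=existingId, stream=newStream)  # e.g. 1 -> 2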
@exceptionWrap
async def _getStreams(
self,
filters: StreamSearchFilters,
page: Optional[int] = None,
pageSize: Optional[int] = None,
):
"""
Get streams.
Args:
filters: stream search filters
page: page number
pageSize: page size
Returns:
list of streams
"""
async with DBContext.adaptor.connection(self.logger) as connection:
async def getStreamLogMap(streamIds_: Iterable[str]) -> dict[str, dict]:
"""Get log rows by stream ids."""
if not streamIds_:
return {}
targets = [models.Log.stream_id, models.Log.error, models.Log.video_info, models.Log.preview]
maxIdsSts = (
select([func.max(models.Log.id)]).where(models.Log.stream_id == streamId_)
for streamId_ in streamIds_
)
selectSt = select(targets).where(models.Log.id.in_(union_all(*maxIdsSts)))
dbReply = await connection.fetchall(selectSt)
return {
streamId_: {
"last_error": error,
"video_info": ujson.loads(videoInfo) if videoInfo is not None else None,
"preview": ujson.loads(preview) if preview is not None else None,
}
for streamId_, error, videoInfo, preview in dbReply
}
if DBContext.adaptor.dbType == Databases.POSTGRES.value:
await connection.execute(text("SET LOCAL random_page_cost=1;"))
geoPositionTarget = func.ST_AsText(models.Location.geo_position).label("geo_position")
else:
geoPositionTarget = models.Location.geo_position
targets = [
models.Stream.stream_id,
models.Stream.account_id,
models.Stream.name,
models.Stream.description,
models.Stream.type,
models.Stream.reference,
models.Stream.mask,
models.Stream.endless,
models.Stream.roi,
models.Stream.droi,
models.Stream.rotation,
models.Stream.preferred_program_stream_frame_width,
models.Stream.status,
models.Stream.version,
models.Stream.create_time,
models.Stream.api_version,
models.Handler.event_handler,
models.Handler.policies,
models.Location.city,
models.Location.area,
models.Location.district,
models.Location.street,
models.Location.house_number,
geoPositionTarget,
models.Restart.restart,
models.Restart.attempt_count,
models.Restart.delay,
models.Restart.current_attempt,
models.Restart.last_attempt_time,
models.Restart.status.label("autorestart_status"),
]
query = (Query(targets).join(models.Restart).join(models.Handler).outerjoin(models.Location)).filter(
self.prepareSearchQueryFilters(filters)
)
if filters.streamIdGte or filters.streamIdLt:
order = asc if filters.order == "asc" else desc
query = query.order_by(order(models.Stream.stream_id))
elif filters.order is not None:
order = asc if filters.order == "asc" else desc
query = query.order_by(order(models.Stream.create_time))
else:
query = query.order_by(models.Stream.id)
if filters.groupName:
            # Use a nested select instead of a join so that offset and limit work properly
streamsInGroup = select([models.GroupStream.stream_id]).where(
models.GroupStream.group_name == filters.groupName
)
query = query.filter(models.Stream.stream_id.in_(streamsInGroup))
if pageSize is not None:
query = query.offset((page - 1) * pageSize).limit(pageSize)
streamRows = await connection.fetchall(query.statement)
streamIds = [row["stream_id"] for row in streamRows]
streamLogMap = await getStreamLogMap(streamIds)
# Fetch stream groups
streamGroupMapSt = (
Query([models.GroupStream.stream_id, models.GroupStream.group_name])
.filter(models.GroupStream.stream_id.in_(streamIds))
.statement
)
streamGroupMap = defaultdict(list)
for streamId, groupName in await connection.fetchall(streamGroupMapSt):
streamGroupMap[streamId].append(groupName)
return streamRows, streamLogMap, streamGroupMap
@staticmethod
def _getStreamVersion(streamRow) -> int:
"""
        Get stream API version.
        Returns:
            stream API version (1 or 2)
"""
return streamRow["api_version"]
async def getStreams(
self,
filters: StreamSearchFilters,
page: Optional[int] = None,
pageSize: Optional[int] = None,
targetFormat: int = 1,
) -> list[dict[str, Any]]:
"""
Get streams.
Args:
filters: stream search filters
page: page number
pageSize: page size
targetFormat: stream target format
Returns:
list of streams
"""
streamRows, streamLogMap, streamGroupMap = await self._getStreams(filters=filters, pageSize=pageSize, page=page)
streams = []
for streamRow in streamRows:
streamVersion = self._getStreamVersion(streamRow)
if streamVersion == 1:
stream = prepareStreamV1(streamRow=streamRow, targetFormat=targetFormat, storageTime=self.storageTime)
else:
stream = prepareStreamV2(streamRow=streamRow, targetFormat=targetFormat, storageTime=self.storageTime)
            stream.update(streamLogMap.get(stream["stream_id"], {"last_error": None, "video_info": None, "preview": None}))
stream["groups"] = streamGroupMap.get(stream["stream_id"], [])
streams.append(stream)
return streams
@exceptionWrap
async def countStreams(self, filters: StreamSearchFilters) -> int:
"""
Count streams according to filters.
Args:
filters: stream search filters
Returns:
streams count
"""
async with DBContext.adaptor.connection(self.logger) as connection:
selectSt = select([func.count(models.Stream.id)]).where(self.prepareSearchQueryFilters(filters))
streamCount = await connection.scalar(selectSt)
return streamCount
def _deleteFiltersClauses(self, filters: DeleteStreamsFilters) -> BooleanClauseList:
"""Yields sa filters for given DeleteStreamsFilter obj."""
if filters.streamIds:
yield models.Stream.stream_id.in_(filters.streamIds)
if filters.names:
yield models.Stream.name.in_(filters.names)
@exceptionWrap
async def _deleteStreams(self, where: BooleanClauseList) -> int:
"""
        Delete streams matching the given clause while maintaining log records.
        Args:
            where: filter clause selecting the streams to delete
        Returns:
            number of removed streams
"""
sel = select((models.Stream.stream_id, StreamStatus.deleted.value, models.Stream.version + 1)).where(where)
insertLogSt = insert(models.Log).from_select(
(models.Log.stream_id, models.Log.status, models.Log.stream_version), sel
)
removeSt = delete(models.Stream).where(where)
async with DBContext.adaptor.connection(self.logger) as connection:
await connection.execute(insertLogSt)
removedCount = await connection.execute(removeSt)
self.logger.debug(f"streams have been deleted: {removedCount}")
return removedCount
async def deleteStreamsByDeleteFilters(self, filters: DeleteStreamsFilters, accountId: Optional[str] = None):
"""
Delete streams by filters (OR).
Args:
filters: streams delete filters
accountId: account id
"""
clauses = tuple(self._deleteFiltersClauses(filters))
if not clauses:
return
where = and_(models.Stream.account_id == accountId if accountId is not None else True, or_(*clauses))
await self._deleteStreams(where)
async def removeStreams(self, streamIds: list[str], accountId: Optional[str] = None) -> int:
"""
Remove streams by streamIds.
Args:
streamIds: stream ids
accountId: account id
Returns:
removed streams count
"""
where = and_(
models.Stream.stream_id.in_(streamIds),
models.Stream.account_id == accountId if accountId is not None else True,
)
return await self._deleteStreams(where)
@staticmethod
def _selectFromQueue(
limit: int,
streamIds: Optional[list[str]] = None,
streamNames: Optional[list[str]] = None,
groupIds: Optional[list[str]] = None,
groupNames: Optional[list[str]] = None,
) -> select:
"""
Prepare "select from queue" statement.
Args:
streamIds: stream ids filter
streamNames: stream names filter
limit: max amount to return
groupNames: group names filter
            groupIds: group ids filter
Returns:
prepared select
"""
streamQueueAlias = models.StreamQueueView.alias()
return (
select([streamQueueAlias.c.stream_id])
.where(
and_(
streamQueueAlias.c.stream_id.in_(streamIds) if streamIds is not None else True,
streamQueueAlias.c.name.in_(streamNames) if streamNames is not None else True,
models.GroupStream.group_name.in_(groupNames) if groupNames is not None else True,
(
and_(
models.Group.group_id.in_(groupIds),
models.Group.id == models.GroupStream.id,
)
if groupIds is not None
else True
),
(
streamQueueAlias.c.stream_id == models.GroupStream.stream_id
if any((groupIds, groupNames))
else True
),
)
)
.limit(limit)
)
@exceptionWrap
async def getStreamsQueue(
self,
limit: int,
streamIds: Optional[list[str]] = None,
streamNames: Optional[list[str]] = None,
groupIds: Optional[list[str]] = None,
groupNames: Optional[list[str]] = None,
targetFormat: int = 1,
) -> list[dict[str, Any]]:
"""
Get streams from the queue.
Args:
limit: the maximum amount to return
streamIds: list of stream ids
streamNames: list of stream names
groupIds: list of group ids
groupNames: list of group names
targetFormat: stream target format
Returns:
streams as dicts
"""
streamIdsSt = self._selectFromQueue(
limit=limit, streamIds=streamIds, streamNames=streamNames, groupIds=groupIds, groupNames=groupNames
)
async with DBContext.adaptor.connection(self.logger) as connection:
res = await connection.fetchall(streamIdsSt)
foundStreamIds = list(chain(*res))
result = await self.getStreams(StreamSearchFilters(streamIds=foundStreamIds), targetFormat=targetFormat)
result.sort(key=lambda stream: foundStreamIds.index(stream["stream_id"]))
return result
@exceptionWrap
async def pullStreamsQueue(
self,
limit: int,
streamIds: Optional[list[str]] = None,
streamNames: Optional[list[str]] = None,
groupIds: Optional[list[str]] = None,
groupNames: Optional[list[str]] = None,
targetFormat: int = 1,
) -> list[dict[str, Any]]:
"""
Pull streams from the queue.
Args:
limit: the maximum amount to return
            streamIds: list of stream ids
streamNames: list of stream names
groupNames: list of group names
groupIds: list of group ids
targetFormat: stream target format - 1 or 2
Returns:
streams as dicts
"""
baseStreamIdsSt = self._selectFromQueue(
limit=limit, streamIds=streamIds, streamNames=streamNames, groupIds=groupIds, groupNames=groupNames
)
if DBContext.adaptor.dbType == Databases.ORACLE.value:
baseStreamIdsSt = baseStreamIdsSt.limit(1)
streamIdSt = (
update(models.StreamQueueView)
.where(
and_(
models.StreamQueueView.c.stream_id.in_(baseStreamIdsSt),
models.StreamQueueView.c.status == StreamStatus.pending.value,
)
)
.values(status=StreamStatus.in_progress.value, status_last_update_time=self.currentDBTimestamp)
.returning(models.StreamQueueView.c.stream_id)
)
streamIdsFromDb = []
for _ in range(limit):
async with DBContext.adaptor.connection(self.logger) as connection:
# TODO use one statement with the nested result structure (when oracle RETURNING added) LUNA-5313
streamId = await connection.scalar(streamIdSt)
if streamId is None:
break
insertLogSt = insert(models.Log).from_select(
("stream_id", "status", "stream_version"),
select([models.Stream.stream_id, StreamStatus.in_progress.value, models.Stream.version]).where(
models.Stream.stream_id == streamId
),
)
await connection.execute(insertLogSt)
streamIdsFromDb.append(streamId)
else:
streamIdsSt = (
update(models.StreamQueueView)
.where(
and_(
models.StreamQueueView.c.stream_id.in_(baseStreamIdsSt),
models.StreamQueueView.c.status == StreamStatus.pending.value,
)
)
.values(status=StreamStatus.in_progress.value, status_last_update_time=self.currentDBTimestamp)
.returning(models.StreamQueueView.c.stream_id)
)
async with DBContext.adaptor.connection(self.logger) as connection:
res = await connection.fetchall(streamIdsSt)
streamIdsFromDb = list(chain(*res))
insertLogSt = insert(models.Log).from_select(
("stream_id", "status", "stream_version"),
select([models.Stream.stream_id, StreamStatus.in_progress.value, models.Stream.version]).where(
models.Stream.stream_id.in_(streamIdsFromDb)
),
)
await connection.execute(insertLogSt)
result = await self.getStreams(StreamSearchFilters(streamIds=streamIdsFromDb), targetFormat=targetFormat)
result.sort(key=lambda stream: streamIdsFromDb.index(stream["stream_id"]))
return result
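    # A worker-loop sketch (hypothetical poller; `process` is illustrative): pulling marks the
    # returned streams `in_progress` and logs the transition, so each pending stream is handed
    # to at most one worker.
    #
    #     while True:
    #         for stream in await context.pullStreamsQueue(limit=10):
    #             await process(stream)
    #         await asyncio.sleep(1)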
@staticmethod
async def denyStreamAutoRestart(connection: AbstractDBConnection, streamId: str):
"""
Deny stream auto restart
Args:
connection: connection
streamId: stream id
"""
updateSt = (
update(models.Restart)
.where(models.Restart.stream_id == streamId)
.values(status=AutoRestartStatus.denied.value)
)
await connection.execute(updateSt)
@exceptionWrap
async def saveFeedback(self, streams: list[OneStreamFeedback]) -> dict[str, list[StreamInfo]]:
"""
Save feedback, return stream info.
Args:
streams: list of stream feedbacks
Returns:
two lists of db stream info objects: {"modified": [{<stream info>}], "unmodified": [{<stream info>}]}
"""
modified, unmodified = [], []
async with DBContext.adaptor.connection(self.logger) as connection:
# select streams and lock them
selectSt = (
select([models.Stream.stream_id, models.Stream.status, models.Stream.version])
.where(models.Stream.stream_id.in_(stream.streamId for stream in streams))
.order_by(models.Stream.stream_id)
.with_for_update()
)
streamsFromDb = {
streamId: StreamInfo(streamId=streamId, status=StreamStatus(status), version=version)
for (streamId, status, version) in await connection.fetchall(selectSt)
}
# todo RETURNING in oracle???
for stream in streams:
streamInDB = streamsFromDb.get(stream.streamId)
if streamInDB is None:
# stream not found
streamInDB = StreamInfo(streamId=stream.streamId, status=StreamStatus.not_found, version=0)
else: # stream found
# update stream state
newStatus = StreamStatus[stream.status]
if (streamInDB.status == StreamStatus.in_progress and streamInDB.status != newStatus) or (
newStatus == StreamStatus.in_progress and streamInDB.status == StreamStatus.pending
):
updateSt = (
update(models.Stream)
.where(models.Stream.stream_id == stream.streamId)
.values(status=newStatus.value, status_last_update_time=self.currentDBTimestamp)
)
await connection.execute(updateSt)
streamInDB.status = newStatus
# insert log
videoInfo = None
if stream.videoInfo:
videoInfo = stream.videoInfo.json(by_alias=True, exclude_defaults=True)
preview = None
if stream.preview:
preview = stream.preview.json(by_alias=True, exclude_defaults=True)
insertSt = insert(models.Log).values(
stream_id=stream.streamId,
error=stream.error,
status=StreamStatus[stream.status].value,
video_info=videoInfo,
stream_version=streamInDB.version,
preview=preview,
)
await connection.execute(insertSt)
if stream.error in FATAL_ERRORS:
await self.denyStreamAutoRestart(connection, stream.streamId)
if (streamInDB.status.name, streamInDB.version) != (stream.status, stream.version):
modified.append(streamInDB)
else:
unmodified.append(streamInDB)
await connection.execute(
update(models.Stream)
.where(models.Stream.stream_id.in_((streamInfo.streamId for streamInfo in unmodified)))
.values(last_feedback_time=self.currentDBTimestamp)
)
return {"modified": modified, "unmodified": unmodified}
async def _insertStreamLogs(
self,
connection,
streamIds: list[str],
streamLogStatuses: list[StreamStatus],
streamToVersionMap: dict[str, int],
) -> None:
"""
Insert stream log with each status for each stream
Args:
connection: database connection
streamIds: stream ids
streamLogStatuses: log statuses
streamToVersionMap: stream to version map
"""
if self.dbType == Databases.ORACLE.value:
logInsertSts = [
insert(models.Log).values(
stream_id=streamId, status=status.value, stream_version=streamToVersionMap[streamId]
)
for status in streamLogStatuses
for streamId in streamIds
]
else:
logInsertSts = [
insert(models.Log).values(
[
{
"id": select([text(f"nextval('{models.logIdSeq.name}')")]),
"stream_id": streamId,
"status": status.value,
"stream_version": streamToVersionMap[streamId],
}
for status in streamLogStatuses
for streamId in streamIds
]
)
]
for insertSt in logInsertSts:
await connection.execute(insertSt)
@exceptionWrap
async def downgradeStreamStatus(self, streamInactiveTime) -> int:
"""
Downgrade stream status.
Make stream.status `in_progress`->`pending` for all long-time-ago updated streams.
Args:
streamInactiveTime: period without feedback
Returns:
amount of the downgraded streams
"""
def tooLongAgo(timeColumn):
"""
Function defines "too long ago" criteria.
Args:
timeColumn: column to check
Returns:
prepared clause
"""
return timeColumn < self.currentDBTimestamp - func.cast(timedelta(seconds=streamInactiveTime), Interval)
streamsWithFeedbackSt = (
select([models.Stream.stream_id, models.Stream.version])
.where(
and_(
models.Stream.status == StreamStatus.in_progress.value,
tooLongAgo(models.Stream.last_feedback_time),
tooLongAgo(models.Stream.status_last_update_time),
)
)
            .limit(BG_BATCH_LIMIT // 2)
.with_for_update()
)
streamsWithoutFeedbackSt = (
select([models.Stream.stream_id, models.Stream.version])
.where(
and_(
models.Stream.status == StreamStatus.in_progress.value,
models.Stream.last_feedback_time.is_(None),
tooLongAgo(models.Stream.status_last_update_time),
)
)
            .limit(BG_BATCH_LIMIT // 2)
.with_for_update()
)
if self.dbType == Databases.POSTGRES.value:
streamsWithFeedbackSt = streamsWithFeedbackSt.order_by(models.Stream.stream_id)
streamsWithoutFeedbackSt = streamsWithoutFeedbackSt.order_by(models.Stream.stream_id)
async with DBContext.adaptor.connection(self.logger) as connection:
streamsWithFeedback = await connection.fetchall(streamsWithFeedbackSt)
streamsWithoutFeedback = await connection.fetchall(streamsWithoutFeedbackSt)
streamToVersionMap = {
streamId: version for (streamId, version) in chain(streamsWithFeedback, streamsWithoutFeedback)
}
streamIds = [streamId for (streamId, version) in chain(streamsWithFeedback, streamsWithoutFeedback)]
if not streamIds:
return 0
self.logger.warning(f"Downgrading {len(streamsWithoutFeedback)} streams without any feedback")
self.logger.debug(f"Downgrading streams without any feedback: {streamsWithoutFeedback}")
self.logger.warning(f"Downgrading {len(streamsWithFeedback)} streams with outdated feedback")
self.logger.debug(f"Downgrading streams with outdated feedback: {streamsWithFeedback}")
count = 0
for streamIdBatch in self.splitList(streamIds):
count += await connection.execute(
update(models.Stream)
.where(models.Stream.stream_id.in_(streamIdBatch))
.values(status=StreamStatus.pending.value, status_last_update_time=self.currentDBTimestamp)
)
streamLogStatuses = [StreamStatus.handler_lost, StreamStatus.restart, StreamStatus.pending]
for streamsLogBatch in self.splitList(streamIdBatch, divider=9):
await self._insertStreamLogs(
connection=connection,
streamIds=streamsLogBatch,
streamLogStatuses=streamLogStatuses,
streamToVersionMap=streamToVersionMap,
)
return count
@exceptionWrap
async def autorestartStreams(self) -> int:
"""
Autorestart streams.
        Change stream.status `failure` -> `pending` for every failed stream that allows autorestart,
        taking the restart delay into account.
        Reset stream.restart.current_attempt to 0 (meaning autorestart has completed) for every stream
        whose status differs from `failure` after `delay` seconds with `current_attempt` >= 1.
Returns:
amount of the restarted streams
"""
if self.dbType == Databases.POSTGRES.value:
attemptTimeCondition = (
text(
f"EXTRACT (EPOCH FROM {self.currentDBTimestamp.text} - "
f"{models.Restart.__tablename__}.{models.Restart.last_attempt_time.property.key})"
)
>= models.Restart.delay
)
else:
attemptTimeCondition = (
text(
f"get_datediff_as_epoch({models.Restart.__tablename__}."
f"{models.Restart.last_attempt_time.property.key})"
)
>= models.Restart.delay
)
generalConditions = and_(
models.Restart.restart == 1,
models.Stream.stream_id == models.Restart.stream_id,
)
streamsRequiredRestartResetSt = (
select([models.Restart.stream_id])
.where(
and_(
generalConditions,
models.Stream.status != StreamStatus.failure.value,
models.Restart.status == AutoRestartStatus.in_progress.value,
attemptTimeCondition,
)
)
.limit(BG_BATCH_LIMIT)
.with_for_update()
)
streamsRequiredFailSt = (
select([models.Restart.stream_id])
.where(
and_(
generalConditions,
models.Stream.status == StreamStatus.failure.value,
models.Restart.status == AutoRestartStatus.in_progress.value,
models.Restart.current_attempt == models.Restart.attempt_count,
)
)
.limit(BG_BATCH_LIMIT)
)
streamsRequiredRestartSt = (
select([models.Restart.stream_id, models.Stream.version])
.where(
and_(
generalConditions,
models.Stream.status == StreamStatus.failure.value,
models.Restart.status.in_([AutoRestartStatus.enabled.value, AutoRestartStatus.in_progress.value]),
or_(
models.Restart.current_attempt != models.Restart.attempt_count,
                        models.Restart.current_attempt.is_(None),
),
                    or_(attemptTimeCondition, models.Restart.last_attempt_time.is_(None)),
)
)
.limit(BG_BATCH_LIMIT)
.with_for_update()
)
async with DBContext.adaptor.connection(self.logger) as connection:
streamsRequiredRestartReset = self.splitList(
[row[0] for row in await connection.fetchall(streamsRequiredRestartResetSt)]
)
streamsRequiredFail = self.splitList([row[0] for row in await connection.fetchall(streamsRequiredFailSt)])
streamsRequiredRestart = self.splitList(await connection.fetchall(streamsRequiredRestartSt))
for requiredRestartBatch in streamsRequiredRestartReset:
await connection.execute(
update(models.Restart)
.where(models.Restart.stream_id.in_(requiredRestartBatch))
.values(current_attempt=0, status=AutoRestartStatus.enabled.value)
)
for requiredFailBatch in streamsRequiredFail:
await connection.execute(
update(models.Restart)
.where(models.Restart.stream_id.in_(requiredFailBatch))
.values(status=AutoRestartStatus.failed.value)
)
count = 0
for requiredRestartBatch in streamsRequiredRestart:
streamIdToVersionMap = {streamId: version for (streamId, version) in requiredRestartBatch}
streamIds = list(streamIdToVersionMap.keys()) # make ide silent
self.logger.warning(f"Restarting {len(streamIds)} streams")
self.logger.debug(f"Restarting stream ids: {streamIds}")
updateStreamSt = (
update(models.Stream)
.where(models.Stream.stream_id.in_(streamIds))
.values(status=StreamStatus.pending.value)
)
await connection.execute(updateStreamSt)
updateRestartSt = (
update(models.Restart)
.where(models.Restart.stream_id.in_(streamIds))
.values(
current_attempt=models.Restart.current_attempt + 1,
last_attempt_time=self.currentDBTimestamp,
status=AutoRestartStatus.in_progress.value,
)
)
await connection.execute(updateRestartSt)
for streamsLogBatch in self.splitList(streamIds, divider=6):
await self._insertStreamLogs(
connection=connection,
streamIds=streamsLogBatch,
streamLogStatuses=[StreamStatus.restart, StreamStatus.pending],
streamToVersionMap=streamIdToVersionMap,
)
                count += len(streamIds)
return count
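    # Background maintenance sketch (the scheduling is assumed, not defined in this module):
    # both `downgradeStreamStatus` and `autorestartStreams` are batch-limited and intended to
    # be polled periodically.
    #
    #     async def maintenance(context, streamInactiveTime=30):
    #         while True:
    #             await context.downgradeStreamStatus(streamInactiveTime)
    #             await context.autorestartStreams()
    #             await asyncio.sleep(10)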
@exceptionWrap
async def getStreamsLogs(
self,
page: int,
pageSize: int,
targets: Optional[list[str]] = None,
accountId: Optional[str] = None,
streamIds: Optional[list[int]] = None,
statuses: Optional[list[str]] = None,
logTimeLt: Optional[datetime] = None,
logTimeGte: Optional[datetime] = None,
) -> list[dict[str, Any]]:
"""
Get streams logs by filters
Args:
page: page
pageSize: page size
targets: targets
accountId: account id
streamIds: stream ids
statuses: stream statuses
logTimeLt: upper excluding bound of stream log creation time
logTimeGte: lower including bound of stream log creation time
"""
preparedTargets = list(LogTarget) if targets is None else [LogTarget[target] for target in targets]
selectSt = (
select([getattr(models.Log, target.value) for target in preparedTargets])
.where(
and_(
models.Stream.stream_id == models.Log.stream_id if accountId is not None else True,
models.Stream.account_id == accountId if accountId is not None else True,
models.Log.stream_id.in_(streamIds) if streamIds is not None else True,
(
models.Log.status.in_([StreamStatus[status].value for status in statuses])
if statuses is not None
else True
),
models.Log.time >= logTimeGte if logTimeGte else True,
models.Log.time < logTimeLt if logTimeLt else True,
)
)
.order_by(models.Log.id.desc())
.offset((page - 1) * pageSize)
.limit(pageSize)
)
async with DBContext.adaptor.connection(self.logger) as connection:
dbResult = await connection.fetchall(selectSt)
return makeOutputLogs(dbResult, preparedTargets, self.storageTime)
@exceptionWrap
async def deleteStreamsLogs(self, logTimeLt: datetime | None = None) -> int:
"""
Delete streams logs except the last one
Args:
            logTimeLt: upper excluding bound of stream log creation time
        Returns:
deleted logs count
"""
async with DBContext.adaptor.connection(self.logger) as connection:
maxIdsSts = select([func.max(models.Log.id)]).group_by(models.Log.stream_id)
deleteSt = delete(models.Log).where(
and_(models.Log.time < logTimeLt if logTimeLt else True, models.Log.id.notin_(maxIdsSts))
)
return await connection.execute(deleteSt)
@exceptionWrap
async def createGroup(self, group: NewGroupModel) -> str:
"""
Create group.
Args:
group: group to create
Returns:
unique group id
"""
groupId = str(uuid.uuid4())
async with DBContext.adaptor.connection(self.logger) as connection:
try:
insertGroupSt = insert(models.Group).values(
group_id=groupId,
group_name=group.groupName,
account_id=group.accountId,
description=group.description,
)
await connection.execute(insertGroupSt)
except (IntegrityError, UniqueViolationError):
raise VLException(Error.StreamGroupAlreadyExist.format(group.groupName), 409, False)
return groupId
@exceptionWrap
async def getGroups(
self,
groupNames: Optional[list[str]] = None,
groupIds: Optional[list[str]] = None,
accountId: Optional[str] = None,
page: int = 1,
pageSize: Optional[int] = None,
) -> list[dict[str, Any]]:
"""
Get groups.
Args:
groupNames: group names
groupIds: group ids
accountId: account id
page: page number
pageSize: page size
Returns:
list of groups
"""
selectGroupSt = (
select(
[
models.Group.group_id,
models.Group.group_name,
models.Group.account_id,
models.Group.description,
models.Group.create_time,
]
)
.where(
and_(
models.Group.group_name.in_(groupNames) if groupNames is not None else True,
models.Group.group_id.in_(groupIds) if groupIds is not None else True,
models.Group.account_id == accountId if accountId is not None else True,
)
)
.order_by(models.Group.id)
)
if pageSize is not None:
selectGroupSt = selectGroupSt.offset((page - 1) * pageSize).limit(pageSize)
async with DBContext.adaptor.connection(self.logger) as connection:
groupRows = await connection.fetchall(selectGroupSt)
return makeOutputGroups(groupRows, self.storageTime)
@exceptionWrap
async def updateGroup(self, groupId: str, group: PatchGroupModel, accountId: Optional[str] = None) -> int:
"""
Update group.
Args:
groupId: group id
group: group to update
accountId: account id
Returns:
updated group count
"""
async with DBContext.adaptor.connection(self.logger) as connection:
updateGroupSt = (
update(models.Group)
.where(
and_(
models.Group.account_id == accountId if accountId is not None else True,
models.Group.group_id == groupId,
)
)
.values(description=group.description)
)
return await connection.execute(updateGroupSt)
@exceptionWrap
async def removeGroup(self, groupId: str, accountId: Optional[str] = None) -> int:
"""
Remove group.
Args:
groupId: group id
accountId: account id
Returns:
deleted group count
"""
async with DBContext.adaptor.connection(self.logger) as connection:
deleteGroupSt = delete(models.Group).where(
and_(
models.Group.account_id == accountId if accountId is not None else True,
models.Group.group_id == groupId,
)
)
return await connection.execute(deleteGroupSt)
@exceptionWrap
async def countGroups(
self,
groupNames: Optional[list[str]] = None,
groupIds: Optional[list[str]] = None,
accountId: Optional[str] = None,
) -> int:
"""
Get count of groups.
Args:
groupNames: group names
groupIds: group ids
accountId: account id
Returns:
count of groups
"""
filters = and_(
models.Group.group_name.in_(groupNames) if groupNames is not None else True,
models.Group.group_id.in_(groupIds) if groupIds is not None else True,
models.Group.account_id == accountId if accountId is not None else True,
)
selectGroupSt = select([func.count(models.Group.id)]).where(filters)
async with DBContext.adaptor.connection(self.logger) as connection:
return await connection.scalar(selectGroupSt)
@exceptionWrap
async def _getGroupIdAndName(
self,
connection: AbstractDBConnection,
groupName: Optional[str] = None,
groupId: Optional[str] = None,
accountId: Optional[str] = None,
block: bool = False,
) -> tuple[int, str]:
"""
Get existing sequence id and group name by filters.
Args:
connection: current connection with started transaction
accountId: accountId
groupName: group name
groupId: group id
block: whether to perform pessimistic locking of the group
Raises:
            ContextException(Error.StreamGroupNameNotFound.format(groupName)) if group name doesn't exist
            ContextException(Error.StreamGroupNotFound.format(groupId)) if group id doesn't exist
Returns:
seq id and name of the group
"""
if groupId is None:
if groupName is None or (groupName and accountId is None):
raise RuntimeError("Incorrect arguments passed")
query = select([models.Group.id, models.Group.group_name]).where(
and_(
models.Group.group_name == groupName if groupName is not None else True,
models.Group.group_id == groupId if groupId is not None else True,
models.Group.account_id == accountId if accountId is not None else True,
)
)
if block:
query = query.with_for_update()
res = await connection.fetchone(query)
if not res:
error = (
Error.StreamGroupNotFound.format(groupId)
if groupId is not None
else Error.StreamGroupNameNotFound.format(groupName)
)
raise ContextException(error)
return res
@exceptionWrap
async def linkStreamsToGroup(self, streamIds: list[str], groupId: str, accountId: Optional[str] = None):
"""
Attach streamIds to group.
Args:
streamIds: stream ids
groupId: group id
accountId: account id
Raises:
            ContextException(Error.StreamNotFound.format(streamId)) if stream doesn't exist
"""
async with DBContext.adaptor.connection(self.logger) as conn:
groupSeqId, groupName = await self._getGroupIdAndName(
conn, groupId=groupId, accountId=accountId, block=True
)
            # Even though the query below guarantees that no relations between
            # nonexistent streamIds and the group will be made, we want to inform
            # the user about those streamIds.
lockedStreams = {x[0] for x in await self.blockStreams(conn, streamIds, accountId)}
for streamId in streamIds:
if streamId not in lockedStreams:
raise ContextException(Error.StreamNotFound.format(streamId))
            # Database transaction management alone doesn't make bulk inserts idempotent.
            # Provide idempotency via a nested query that excludes streamIds already linked
            # to the specified group.
insertSt = insert(models.GroupStream).from_select(
("id", "group_name", "stream_id"),
select([literal(groupSeqId), literal(groupName), models.Stream.stream_id]).where(
and_(
models.Stream.stream_id.in_(streamIds),
models.Stream.stream_id.notin_(
select([models.GroupStream.stream_id]).where(
and_(models.GroupStream.stream_id.in_(streamIds), models.GroupStream.id == groupSeqId)
)
),
)
),
)
await conn.execute(insertSt)
@exceptionWrap
async def unlinkStreamsFromGroup(self, streamIds: list[str], groupId: str, accountId: Optional[str] = None):
"""
Unlink streamIds from group.
Args:
streamIds: stream ids
groupId: group id
accountId: account id
"""
async with DBContext.adaptor.connection(self.logger) as conn:
            # Fetch groupId ahead because oracle XE doesn't support multiple-table DELETE statements
groupSeqId, _ = await self._getGroupIdAndName(conn, groupId=groupId, accountId=accountId)
deleteSt = delete(models.GroupStream).where(
and_(models.GroupStream.stream_id.in_(streamIds), models.GroupStream.id == groupSeqId)
)
await conn.execute(deleteSt)
@exceptionWrap
async def getPreview(
self, streamId: str, previewType: Literal["last_frame", "live"], accountId: Optional[str] = None
) -> str:
"""
Get preview url
Args:
streamId: stream id
previewType: preview type - frame or live
accountId: account id
Returns:
preview url
Raises:
VLException(Error.PreviewNotFound.format(previewType, streamId), 404, False) if preview not found
VLException(Error.StreamNotFound.format(streamId), 404, False) if stream not found
"""
filters = (
and_(models.Stream.account_id == accountId, models.Log.stream_id == models.Stream.stream_id)
if accountId is not None
else True
)
async with DBContext.adaptor.connection(self.logger) as conn:
selectSt = (
select([models.Log.preview])
.where(and_(models.Log.stream_id == streamId, filters))
.order_by(models.Log.id.desc())
.limit(1)
)
dbReply = await conn.fetchone(selectSt)
if dbReply is None:
streamExistsSt = select([func.count(models.Stream.stream_id)]).where(
and_(models.Stream.stream_id == streamId, filters)
)
isStreamExists = await conn.scalar(streamExistsSt)
if isStreamExists:
raise VLException(Error.PreviewNotFound.format(previewType, streamId), 404, False)
raise VLException(Error.StreamNotFound.format(streamId), 404, False)
(previewStr,) = dbReply
if previewStr is None or (url := ujson.loads(previewStr).get(previewType, {}).get("url")) is None:
raise VLException(Error.PreviewNotFound.format(previewType, streamId), 404, False)
return url
class SingleProcessLock:
"""
Single-process lock realisation.
Attributes:
logger: logger
name: lock name (SingleProcessLock.name)
context: database context
        currentDBTimestamp: SQL expression for calculating the current database timestamp
"""
def __init__(self, logger: Logger, name: str):
self.logger = logger
self.name = name
self.context = DBContext(self.logger)
        self.currentDBTimestamp = currentDBTimestamp(DBContext.adaptor.dbType)
@exceptionWrap
async def heartbeat(self, name: str, ttl: int, sessionId: str) -> bool:
"""
Master heartbeat.
Args:
sessionId: current lock session
name: lock name
ttl: lock ttl
Returns:
            True if the lock still holds the same session id, otherwise False
"""
expiration = self.currentDBTimestamp + func.cast(timedelta(seconds=ttl), Interval)
updateSt = (
update(models.SingleProcessLock)
.where(and_(models.SingleProcessLock.name == name, models.SingleProcessLock.session_id == sessionId))
.values(expiration=expiration, checkpoint=self.currentDBTimestamp)
)
async with DBContext.adaptor.connection(self.logger) as connection:
count = await connection.execute(updateSt)
return bool(count)
@exceptionWrap
async def lockProcess(self, name: str, ttl: int, sessionId: str) -> bool:
"""
        Try to acquire the lock.
        The lock can be acquired if:
        1) expiration is null
        2) expiration is in the past
Args:
name: lock name
ttl: lock ttl
sessionId: new lock session id
Returns:
            True if the lock was free and we acquired it (set the expiration time), otherwise False
"""
expiration = self.currentDBTimestamp + func.cast(timedelta(seconds=ttl), Interval)
updateSt = (
update(models.SingleProcessLock)
.where(
and_(
or_(
models.SingleProcessLock.expiration < self.currentDBTimestamp,
                        models.SingleProcessLock.expiration.is_(None),
),
models.SingleProcessLock.name == name,
)
)
.values(expiration=expiration, checkpoint=self.currentDBTimestamp, session_id=sessionId)
)
async with DBContext.adaptor.connection(self.logger) as connection:
count = await connection.execute(updateSt)
return bool(count)
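# A lock-lifecycle sketch (hypothetical master election loop; the session id scheme is an
# assumption): acquire the named lock once, then keep it alive with `heartbeat` within the ttl.
#
#     lock = SingleProcessLock(logger, name="bg-worker")
#     sessionId = str(uuid.uuid4())
#     if await lock.lockProcess(name=lock.name, ttl=30, sessionId=sessionId):
#         while await lock.heartbeat(name=lock.name, ttl=30, sessionId=sessionId):
#             await asyncio.sleep(10)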