import contextlib
from asyncio import CancelledError
from dataclasses import dataclass
from datetime import datetime
from itertools import chain
from typing import Any, Callable, Iterable, Literal, Optional, Union
from uuid import uuid4
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from asyncpg import ForeignKeyViolationError, UndefinedFunctionError, UniqueViolationError
from sqlalchemy import Column, and_, asc, delete, desc, exists, func, insert, select, text, union_all, update
from sqlalchemy.dialects import oracle, postgresql
from sqlalchemy.exc import DatabaseError, IntegrityError
from sqlalchemy.orm import Query, aliased
from sqlalchemy.orm.util import AliasedClass
from sqlalchemy.sql.elements import BooleanClauseList, ColumnClause, TextClause, not_
from sqlalchemy.sql.operators import mod
from sqlalchemy.sql.selectable import CompoundSelect
from tzlocal import get_localzone_name
from vlutils.helpers import bytesToBase64, convertTimeToString
from vlutils.jobs.async_runner import AsyncRunner
from app.handlers.helpers import SearchFacesFilters, SearchListsFilters
from attributes_db.model import BasicAttributes, Descriptor, TemporaryAttributes
from configs.config import BACKGROUND_REMOVAL_BATCH_SIZE, DB_CONNECT_TIMEOUT
from configs.configs.configs.settings.classes import DBSetting
from crutches_on_wheels.cow.adaptors.connection import AbstractDBConnection
from crutches_on_wheels.cow.db.base_context import BaseDBContext
from crutches_on_wheels.cow.enums.attributes import TemporaryAttributeTargets as FaceAttributeTargets
from crutches_on_wheels.cow.errors.errors import Error
from crutches_on_wheels.cow.errors.exception import VLException
from crutches_on_wheels.cow.utils.db_functions import currentDBTimestamp, dbExceptionWrap as exceptionWrap
from crutches_on_wheels.cow.utils.functions import getPageCount
from crutches_on_wheels.cow.utils.log import Logger
from db.faces_db_tools.models import faces_models as models
from db.faces_db_tools.models.enums import AttributeSample, SampleType
from db.faces_db_tools.models.faces_models import DB_LIST_LIMIT, LINK_SEQUENCE
# Column-name building blocks for attribute tables: "<attribute>" + postfix.
SAMPLE_POSTFIX = "_samples"
OBTAINING_METHOD_POSTFIX = "_obtaining_method"
VERSION_POSTFIX = "_version"
DESCRIPTOR = "descriptor"
# Minimum attribute generation.
MIN_GEN = 0
# Targets that can be served from the face-list link tables alone (no face table needed).
FACE_TO_LIST_TARGETS = frozenset({"face_id", "lists", "account_id"})
# Oracle index names, filled in by DBContext.setOracleIndexNames for Oracle databases.
indexFaceName = None
indexDescriptorName = None
@dataclass
class Reference:
    """
    Reference to a face descriptor.

    Only the descriptor version is mandatory; the descriptor bytes, attribute id and face id are optional
    and default to None.
    """

    # descriptor version
    descriptorVersion: int
    # raw descriptor bytes, if known
    descriptor: Optional[bytes] = None
    # attribute id, if the reference points to an attribute
    attributeId: Optional[str] = None
    # face id, if the reference points to a face
    faceId: Optional[str] = None
# Raw SQL templates, keyed by database type ("postgres" / "oracle") and then by key kind ("link" / "unlink").
# Each template selects, for one list, the greatest link key (table list_face) or unlink key (table
# unlink_attributes_log) whose parity mod(key, 2) equals {mod}; the dialects differ only in how the single
# top row is taken (LIMIT 1 for postgres, ROWNUM <= 1 for oracle).
# Placeholders: {listId} - list id, inlined inside a quoted SQL string literal; {mod} - key parity (0 or 1).
# NOTE(review): the templates are presumably filled with str.format, so {listId} is not escaped - only
# trusted/validated list ids must be substituted.
RAW_SQLS = {
"postgres": {
"link": """(SELECT listId, linkKey FROM (SELECT list_id AS listId, link_key AS linkKey FROM list_face WHERE
list_id = '{listId}' AND mod(link_key, 2) = {mod} ORDER BY link_key DESC) as f LIMIT 1)""",
"unlink": """(SELECT listId, unLinkKey FROM (SELECT list_id AS listId, unlink_key AS unLinkKey FROM
unlink_attributes_log WHERE list_id = '{listId}' AND mod(unlink_key, 2) = {mod} ORDER BY
unlink_key DESC) as f LIMIT 1)""",
},
"oracle": {
"link": """(SELECT listId, linkKey FROM (SELECT list_id AS listId, link_key AS linkKey FROM list_face WHERE
list_id = '{listId}' AND mod(link_key, 2) = {mod} ORDER BY link_key DESC) WHERE ROWNUM <= 1)""",
"unlink": """(SELECT listId, unLinkKey FROM (SELECT list_id AS listId, unlink_key AS unLinkKey FROM
unlink_attributes_log WHERE list_id = '{listId}' AND mod(unlink_key, 2) = {mod} ORDER BY
unlink_key DESC) WHERE ROWNUM <= 1)""",
},
}
def getCompiledQuery(query: Union[Query, CompoundSelect], dbType: str) -> str:
    """
    Render a query as SQL text with all bound parameters inlined (`literal_binds=True`).

    Args:
        query: ORM query made by sqlalchemy, or a CompoundSelect such as the result of `union_all` over n queries
        dbType: database type; "oracle" selects the Oracle dialect, any other value the PostgreSQL dialect
    Returns:
        string with the compiled query
    """
    dialectModule = oracle if dbType == "oracle" else postgresql
    compileOptions = {"dialect": dialectModule.dialect(), "compile_kwargs": {"literal_binds": True}}
    # an ORM Query is compiled through its core statement, a compound select is compiled as is
    statement = query.statement if isinstance(query, Query) else query
    return str(statement.compile(**compileOptions))
def makeOutputFaces(faces: list[dict[str, Any]], storageTime: str) -> list[dict[str, Any]]:
    """
    Convert faces from the database reply into the user-facing representation.

    The face dicts are modified in place: `create_time` is converted to a string (UTC when `storageTime` is
    "UTC"), and None values of `external_id`, `user_data` and `avatar` are replaced with empty strings.
    Fields missing from a face are left missing.

    Args:
        faces: face list
        storageTime: storage time
    Returns:
        the same `faces` list, with changed fields
    """
    isUTC = storageTime == "UTC"

    def emptyIfNone(value):
        return "" if value is None else value

    converters = {
        "create_time": lambda value: convertTimeToString(value, isUTC),
        "external_id": emptyIfNone,
        "user_data": emptyIfNone,
        "avatar": emptyIfNone,
    }
    for face in faces:
        presentFields = [fieldName for fieldName in converters if fieldName in face]
        for fieldName in presentFields:
            face[fieldName] = converters[fieldName](face[fieldName])
    return faces
def makeOutputLists(lists: list[dict[str, Any]], storageTime: str) -> list[dict[str, Any]]:
    """
    Convert lists from the database reply into the user-facing representation.

    New dicts are built (the input rows are not modified): `create_time` and `last_update_time` are converted
    to strings (UTC when `storageTime` is "UTC"), a None `user_data` becomes an empty string, and every other
    field is copied unchanged.

    Args:
        lists: list with lists (rows exposing `.items()`)
        storageTime: storage time
    Returns:
        lists with changed fields
    """
    isUTC = storageTime == "UTC"

    def convertField(fieldName: str, value: Any) -> Any:
        if fieldName in ("create_time", "last_update_time"):
            return convertTimeToString(value, isUTC)
        if fieldName == "user_data":
            return "" if value is None else value
        return value

    return [{fieldName: convertField(fieldName, value) for fieldName, value in row.items()} for row in lists]
class DBContext(BaseDBContext):
    """
    Faces database context.

    Besides the per-request query helpers, the class holds background machinery shared by every instance:
    a logger, a task runner, a scheduler and the queue of lists waiting for a deferred `last_update_time` update.

    Attributes:
        logger: request logger
    """

    # default descriptor version, set by initDBContext
    defaultDescriptorVersion: Optional[int] = None
    # (Logger): logger for background tasks
    backgroundLogger: Optional[Logger] = None
    # (AsyncRunner): background task executor
    backgroundWorker: Optional[AsyncRunner] = None
    # (AsyncIOScheduler): scheduler for periodic background jobs
    backgroundScheduler: Optional[AsyncIOScheduler] = None
    # (set): ids of lists whose last_update_time is updated later (deferred); a class attribute shared by all instances
    _listsForUpdate: set = set()
@classmethod
async def _deferredUpdateListLastUpdateTime(cls):
    """
    Flush deferred list `last_update_time` updates queued in `_listsForUpdate`.

    Lists are updated one by one. Each id is taken off the queue *before* its update is awaited, so:
    - an id queued again while the update is running stays queued for the next flush;
    - a concurrent flush over the same id does not fail on removal;
    - a failed or cancelled update puts the id back to be retried.
    """
    if not cls._listsForUpdate:
        return
    cls.backgroundLogger.debug(f"start to update {len(cls._listsForUpdate)} lists last update time")
    context = cls(cls.backgroundLogger)
    for listId in list(cls._listsForUpdate):
        # discard, not remove: another flush may already have taken this id
        cls._listsForUpdate.discard(listId)
        try:
            # batch update provokes deadlocks (for several luna-faces instances)
            await context.updateListLastUpdateTime(listId)
        except Exception:
            cls._listsForUpdate.add(listId)
            cls.backgroundLogger.exception(f"failed to update last update time for {listId}")
        except BaseException:
            # cancellation (e.g. on shutdown): keep the list queued for the next flush and propagate
            cls._listsForUpdate.add(listId)
            raise
        else:
            cls.backgroundLogger.debug(f"updated last update time for {listId}")
@classmethod
def _addListsToDeferrerUpdateListLastUpdateTime(cls, listIds: Iterable[str]):
    """
    Queue lists for a deferred `last_update_time` update.

    The ids are accumulated in the class-level `_listsForUpdate` set, which is consumed by
    `_deferredUpdateListLastUpdateTime`.

    Args:
        listIds: list ids
    """
    pendingLists = cls._listsForUpdate
    for listId in listIds:
        pendingLists.add(listId)
@classmethod
async def initDBContext(cls, dbSettings: DBSetting, storageTime: str, defaultDescriptorVersion: int) -> None:
    """
    Initialize the context: the base database adaptor plus the shared background machinery.

    A background worker is created with half of the adaptor sessions (at least one). A scheduler flushes
    deferred list updates every second. For Oracle, the forced-index names are loaded.

    Args:
        dbSettings: database settings
        storageTime: storage time
        defaultDescriptorVersion: default descriptor version
    """
    await super().initDBContext(dbSettings=dbSettings, storageTime=storageTime, connectTimeout=DB_CONNECT_TIMEOUT)
    cls.backgroundLogger = Logger(template="backgroundDatabase")
    cls.defaultDescriptorVersion = defaultDescriptorVersion

    scheduler = AsyncIOScheduler(timezone=get_localzone_name())
    cls.backgroundScheduler = scheduler
    workerCount = max(cls.adaptor.sessionCount // 2, 1)
    cls.backgroundWorker = AsyncRunner(workerCount, closeTimeout=5)
    scheduler.add_job(cls._deferredUpdateListLastUpdateTime, trigger="interval", seconds=1)
    scheduler.start()

    if dbSettings.type == "oracle":
        await cls.setOracleIndexNames(Logger("initialize faces db"))
@property
def currentDBTimestamp(self) -> TextClause:
    """
    SQL expression that yields the current database timestamp for the configured database type.

    Returns:
        function for computation current db timestamp
    """
    databaseType = self.dbConfig.type
    return currentDBTimestamp(databaseType)
@classmethod
async def closeDBContext(cls):
    """
    Close all database contexts.

    Order of shutdown:
    1. stop the periodic scheduler, so no new deferred-update flush is started from now on;
    2. close the background worker (a CancelledError from it is ignored);
    3. flush the pending deferred list updates;
    4. close the adaptor.

    Components that were never created are skipped, so closing a context that was not (fully)
    initialized does not fail.
    """
    if cls.backgroundScheduler:
        # stop periodic flushes before the final one and before the adaptor is closed
        cls.backgroundScheduler.shutdown()
    if cls.backgroundWorker is not None:
        with contextlib.suppress(CancelledError):
            await cls.backgroundWorker.close()
    await cls._deferredUpdateListLastUpdateTime()
    if cls.adaptor:
        await cls.adaptor.close()
@staticmethod
async def setOracleIndexNames(logger):
    """
    Load the index names used by requests with 'use-force-index' into the module-level
    `indexDescriptorName` and `indexFaceName`.

    Args:
        logger: logger for the database connection
    """
    global indexDescriptorName, indexFaceName
    indexQuery = (
        "select index_name from all_indexes where table_name in ('FACE', 'DESCRIPTOR') and "
        "index_name like 'SYS_%' order by table_name"
    )
    async with DBContext.adaptor.connection(logger) as connection:
        rows = await connection.fetchall(indexQuery)
        # rows are ordered by table name, so the 'DESCRIPTOR' index precedes the 'FACE' one
        # NOTE(review): the unpacking assumes exactly one matching index per table
        descriptorIndex, faceIndex = [row[0] for row in rows]
        indexDescriptorName, indexFaceName = descriptorIndex, faceIndex
async def ping(self, pingCount: int) -> bool:
    """
    Ping the database by executing 'SELECT 1', retrying up to `pingCount` times.

    Every failed attempt is logged.

    Args:
        pingCount: ping count
    Returns:
        True as soon as one attempt succeeds, False if all of them fail (or `pingCount` is not positive)
    """

    async def pingOnce() -> None:
        async with DBContext.adaptor.connection(self.logger) as connection:
            await connection.execute(select([1]))

    for attempt in range(1, pingCount + 1):
        try:
            await pingOnce()
        except Exception:
            self.logger.exception(f"failed db ping, try {attempt}")
            continue
        return True
    return False
@exceptionWrap
async def checkDbMatchFunctionExistence(
    self, saFuncCall: ColumnClause, comparator: Callable[[Any], bool], *, name: str
) -> bool:
    """
    Call a custom function that should exist in the database and validate its result.

    Args:
        saFuncCall: constructed sqlalchemy function object
        comparator: predicate over the fetched rows of `SELECT saFuncCall`
        name: function name for logging
    Returns:
        True if the function exists and `comparator` accepts its result, otherwise False
    """
    self.logger.debug(f"Check database function '{name}' correctness: begin")
    try:
        async with DBContext.adaptor.connection(self.logger) as connection:
            dbRes = await connection.fetchall(select([saFuncCall]))
    except (DatabaseError, UndefinedFunctionError):
        self.logger.debug(f"Database function '{name}' is not implemented.")
        return False

    if not comparator(dbRes):
        self.logger.debug(
            f"Check database function '{name}' correctness: function returns incorrect result '{dbRes}'."
        )
        return False
    self.logger.debug(f"Check database function '{name}' correctness: success")
    return True
async def insertFaceAttributeData(
    self, connection: AbstractDBConnection, faceId: str, attribute: TemporaryAttributes
) -> None:
    """
    Insert the rows describing one face attribute: the attribute row, its samples and its descriptors.

    Basic attributes (age, gender, ethnicity), when present, are stored with obtaining method 1 and version 1.
    Descriptors are stored with obtaining method 1 and generation 0. On Oracle every sample/descriptor row
    is inserted with its own statement; otherwise each group is inserted with a single multi-row insert.

    Args:
        connection: db connection
        faceId: face id
        attribute: temporary attribute container
    """
    attributeRow = {
        "face_id": faceId,
        "descriptor_samples_generation": 0,
        "create_time": attribute.createTime.replace(tzinfo=None),
    }
    basicAttributes = attribute.basicAttributes
    if basicAttributes:
        attributeRow.update(
            age=basicAttributes.age, gender=basicAttributes.gender, ethnicity=basicAttributes.ethnicity
        )
        for attributeName in ("gender", "age", "ethnicity"):
            attributeRow[attributeName + OBTAINING_METHOD_POSTFIX] = 1
            attributeRow[attributeName + VERSION_POSTFIX] = 1

    sampleGroups = (
        (SampleType.basic_attributes, attribute.basicAttributesSamples),
        (SampleType.face_descriptor, attribute.descriptorSamples),
    )
    sampleRows = [
        {"face_id": faceId, "sample_id": sampleId, "type": sampleType.value}
        for sampleType, sampleIds in sampleGroups
        for sampleId in sampleIds
    ]
    descriptorRows = [
        {
            "descriptor_version": descriptor.version,
            "face_id": faceId,
            "descriptor": descriptor.descriptor,
            "descriptor_obtaining_method": 1,
            "descriptor_generation": 0,
        }
        for descriptor in attribute.descriptors
    ]

    await connection.execute(insert(models.Attribute).values(attributeRow))
    if self.dbConfig.type == "oracle":
        # one INSERT statement per row
        for sampleRow in sampleRows:
            await connection.execute(insert(models.Sample).values(sampleRow))
        for descriptorRow in descriptorRows:
            await connection.execute(insert(models.Descriptor).values(descriptorRow))
        return
    if sampleRows:
        await connection.execute(insert(models.Sample).values(sampleRows))
    if descriptorRows:
        await connection.execute(insert(models.Descriptor).values(descriptorRows))
async def createFace(
    self,
    externalFaceId: Optional[str] = None,
    listIds: Optional[set[str]] = None,
    attribute: Optional[TemporaryAttributes] = None,
    **kwargs,
) -> str:
    """
    Create face.

    The face row is inserted first ("fast way"). If a face with the same id already exists, it is deleted and
    inserted again in a new transaction ("slow way"). In both cases the attribute data and list links are
    written within the same transaction as the face row.

    Args:
        externalFaceId: external faceId
        listIds: luna lists
        attribute: temporary attribute container
    Keyword Args:
        event_id: reference to event which created face
        user_data: face information
        account_id: id of account, required
        external_id: external id of the face, if it has its own mapping in external system
        avatar: image url that represents the face
    Returns:
        faceId: face id in uuid4 format
    Raises:
        VLException: (400, ListsNotFound) if one of `listIds` does not exist for the account
    """
    kwargs["face_id"] = faceId = str(uuid4()) if externalFaceId is None else externalFaceId
    accountId = kwargs["account_id"]

    async def processFaceListsAndAttributes(openConnection: AbstractDBConnection) -> str:
        """Insert attribute data and link the face to `listIds` using `openConnection`."""
        if attribute is not None:
            await self.insertFaceAttributeData(openConnection, faceId=faceId, attribute=attribute)
        if listIds:
            # the face id is inlined as an SQL string literal: double the single quotes so that an
            # externally supplied id cannot terminate the literal
            faceIdLiteral = text("'{}'".format(faceId.replace("'", "''")))
            selectLists = select([models.List.list_id, faceIdLiteral, LINK_SEQUENCE.next_value()]).where(
                and_(models.List.account_id == accountId, models.List.list_id.in_(listIds))
            )
            insertFaceListSt = insert(models.ListFace).from_select(
                [models.ListFace.list_id, models.ListFace.face_id, models.ListFace.link_key], selectLists
            )
            listCount = await openConnection.execute(insertFaceListSt)
            if len(listIds) != listCount:
                # the reported row count is not reliable (oracle reports only one inserted row),
                # so check explicitly which lists exist and fail only if one is really missing
                selectListSt = select([models.List.list_id]).where(
                    and_(models.List.list_id.in_(listIds), models.List.account_id == accountId)
                )
                existingLists = set(chain(*await openConnection.fetchall(selectListSt)))
                nonExistListIds = listIds - existingLists
                if nonExistListIds:
                    raise VLException(Error.ListsNotFound.format(next(iter(nonExistListIds))), 400, False)
        return faceId

    deleteSt = delete(models.Face).where(models.Face.face_id == faceId)
    insertSt = insert(models.Face).values(**kwargs)
    async with DBContext.adaptor.connection(self.logger) as connection:
        # fast way creating face
        try:
            await connection.execute(insertSt)
        except (IntegrityError, UniqueViolationError):
            # face with same id already exists, need remove it, go to slow way
            faceAlreadyExists = True
        else:
            faceAlreadyExists = False
            await processFaceListsAndAttributes(connection)

    if faceAlreadyExists:
        # slow way: replace the existing face
        async with DBContext.adaptor.connection(self.logger) as connection:
            await connection.execute(deleteSt)
            await connection.execute(insertSt)
            await processFaceListsAndAttributes(connection)

    if listIds:
        self._addListsToDeferrerUpdateListLastUpdateTime(listIds)
    return faceId
@exceptionWrap
async def updateFace(self, faceId: str, accountId: Optional[str] = None, **kwargs) -> int:
    """
    Update face columns given as keyword arguments.

    Args:
        faceId: face id
        accountId: account id (combined into the condition with the project `@` column operator)
    Keyword Args:
        user_data: face information
        event_id: reference to event which created face
        external_id: external id of the face, if it has its own mapping in external system
        avatar: image url that represents the face
    Returns:
        updated face count
    """
    faceCondition = and_(models.Face.face_id == faceId, models.Face.account_id @ accountId)
    updateFaceSt = update(models.Face).where(faceCondition).values(**kwargs)
    async with DBContext.adaptor.connection(self.logger) as connection:
        return await connection.execute(updateFaceSt)
@exceptionWrap
async def linkFacesToList(self, listId: str, faces: list[str]) -> tuple[list[str], list[str]]:
    """
    Attach faces to a list.

    First `last_update_time` of all given faces is refreshed and the list is queued for a deferred
    update. Then every face is linked in its own connection, so a failing link does not affect the others.
    Integrity, unique and foreign key violations are not raised; the face is reported as failed instead.

    Args:
        listId: list id
        faces: face ids
    Returns:
        list of success link faces and list of failed link faces
    """
    linkedFaces = []
    failedFaces = []
    async with DBContext.adaptor.connection(self.logger) as connection:
        refreshFacesSt = (
            update(models.Face)
            .where(models.Face.face_id.in_(faces))
            .values(last_update_time=self.currentDBTimestamp)
        )
        await connection.execute(refreshFacesSt)
    self._addListsToDeferrerUpdateListLastUpdateTime((listId,))

    for faceId in faces:
        async with DBContext.adaptor.connection(self.logger) as connection:
            linkSt = insert(models.ListFace).values(
                list_id=listId, face_id=faceId, link_key=LINK_SEQUENCE.next_value()
            )
            try:
                await connection.execute(linkSt)
            except (IntegrityError, UniqueViolationError, ForeignKeyViolationError):
                failedFaces.append(faceId)
            else:
                linkedFaces.append(faceId)
    return linkedFaces, failedFaces
@exceptionWrap
async def unlinkFacesFromList(self, listId: str, faces: list[str], accountId: Optional[str] = None):
    """
    Unlink faces from list.

    Within one transaction the faces are blocked, the removed links are written to the unlink log, the links are
    deleted and `last_update_time` of the faces is refreshed; afterwards the list is queued for a deferred update.

    Args:
        listId: list id
        faces: face ids
        accountId: account id; when given, links are logged and removed only if the list belongs to this account
    """
    # Ownership check as an uncorrelated EXISTS over the list table. Referencing `List.account_id` directly in
    # the DELETE would add the list table to the statement without a join condition (a cross join that matches
    # whenever the account owns *any* list), and multi-table DELETE is not portable.
    listOwnedByAccount = (
        exists().where(and_(models.List.list_id == listId, models.List.account_id == accountId))
        if accountId is not None
        else True
    )

    async def updateHistoryLog(connection: AbstractDBConnection) -> None:
        """Write the links that are about to be removed to the unlink log."""
        querySelect = select([models.ListFace.list_id, models.Face.face_id, models.ListFace.link_key]).where(
            and_(
                models.Face.face_id.in_(faces),
                models.Face.face_id == models.ListFace.face_id,
                models.ListFace.list_id == listId,
                listOwnedByAccount,
            )
        )
        insertSt = insert(models.UnlinkAttributesLog).from_select(
            [
                models.UnlinkAttributesLog.list_id,
                models.UnlinkAttributesLog.face_id,
                models.UnlinkAttributesLog.link_key,
            ],
            querySelect,
        )
        await connection.execute(insertSt)

    async def unlinkFaces(connection: AbstractDBConnection) -> None:
        """Delete the links and refresh `last_update_time` of the faces."""
        deleteFacesSt = delete(models.ListFace).where(
            and_(
                models.ListFace.face_id.in_(faces),
                models.ListFace.list_id == listId,
                listOwnedByAccount,
            )
        )
        await connection.execute(deleteFacesSt)
        updateFaceSt = (
            update(models.Face)
            .where(models.Face.face_id.in_(faces))
            .values(last_update_time=self.currentDBTimestamp)
        )
        await connection.execute(updateFaceSt)

    async with DBContext.adaptor.connection(self.logger) as connection:
        await self.blockFaces(connection, [models.Face.face_id.in_(faces)])
        await updateHistoryLog(connection)
        await unlinkFaces(connection)
    self._addListsToDeferrerUpdateListLastUpdateTime((listId,))
@exceptionWrap
async def createList(self, account_id: str, listId: str, user_data: Optional[str] = "") -> None:
    """
    Create a list for an account.

    Args:
        account_id: account id
        listId: list id
        user_data: user data
    Raises:
        VLException: (409, ListAlreadyExist) if a list with the same id already exists
    """
    insertListSt = insert(models.List).values(list_id=listId, account_id=account_id, user_data=user_data)
    async with DBContext.adaptor.connection(self.logger) as connection:
        try:
            await connection.execute(insertListSt)
        except (IntegrityError, UniqueViolationError):
            raise VLException(Error.ListAlreadyExist.format(listId), 409, isCriticalError=False)
@staticmethod
def prepareSearchFacesFilters(
    filters: SearchFacesFilters,
    modelsFace: Optional[AliasedClass] = None,
) -> list[BooleanClauseList]:
    """
    Build the search conditions on the face table.

    The result always has nine entries in a fixed order; a filter that is not set yields a neutral `True`.

    Args:
        filters: query filters
        modelsFace: database table model (or alias) for faces, `models.Face` by default
    Returns:
        filters for Face model
    """
    face = modelsFace or models.Face

    def whenSet(value, makeClause):
        return True if value is None else makeClause(value)

    return [
        whenSet(filters.account_id, lambda value: face.account_id == value),
        whenSet(filters.user_data, lambda value: face.user_data.like("%{}%".format(value))),
        whenSet(filters.event_id, lambda value: face.event_id == value),
        whenSet(filters.create_time__lt, lambda value: face.create_time < value),
        whenSet(filters.create_time__gte, lambda value: face.create_time >= value),
        whenSet(filters.face_id__lt, lambda value: face.face_id < value),
        whenSet(filters.face_id__gte, lambda value: face.face_id >= value),
        whenSet(filters.external_ids, lambda value: face.external_id.in_(value)),
        whenSet(filters.face_ids, lambda value: face.face_id.in_(value)),
    ]
@staticmethod
def setOffsetAndOrderToQuery(
    query: Query,
    order: Optional[Literal["desc", "asc"]] = "desc",
    orderColumn: Optional[Column] = None,
    page: Optional[int] = 1,
    pageSize: Optional[int] = 100,
) -> Query:
    """
    Apply `ORDER BY` and `OFFSET`/`LIMIT` to a query.

    Args:
        query: query object (anything with order_by/offset/limit)
        order: result sort order, "asc" or "desc" (any value other than "asc" sorts descending)
        orderColumn: result sort column; no ordering is applied when None
        page: pagination page value (1-based)
        pageSize: pagination page size value, set -1 to get all rows
    Returns:
        updated query
    """
    if orderColumn is not None:
        direction = asc if order == "asc" else desc
        query = query.order_by(direction(orderColumn))
    if pageSize == -1:
        return query
    return query.offset((page - 1) * pageSize).limit(pageSize)
@staticmethod
def prepareListQueryFilters(filters: SearchListsFilters) -> BooleanClauseList:
    """
    Build the combined search condition on the list table.

    `account_id` and `user_data__eq` use the project `@` column operator; every other filter that is not
    set yields a neutral `True`.

    Args:
        filters: query filters
    Returns:
        filters for List select query
    """
    listModel = models.List

    def whenSet(value, makeClause):
        return True if value is None else makeClause(value)

    return and_(
        listModel.account_id @ filters.account_id,
        whenSet(filters.user_data, lambda value: listModel.user_data.like("%{}%".format(value))),
        listModel.user_data @ filters.user_data__eq,
        whenSet(filters.create_time__lt, lambda value: listModel.create_time < value),
        whenSet(filters.create_time__gte, lambda value: listModel.create_time >= value),
        whenSet(filters.last_update_time__lt, lambda value: listModel.last_update_time < value),
        whenSet(filters.last_update_time__gte, lambda value: listModel.last_update_time >= value),
        whenSet(filters.list_id__lt, lambda value: listModel.list_id < value),
        whenSet(filters.list_id__gte, lambda value: listModel.list_id >= value),
        whenSet(filters.list_ids, lambda value: listModel.list_id.in_(value)),
    )
@exceptionWrap
async def executeSearchFaces(
    self,
    filters: SearchFacesFilters,
    targets: Optional[list[str]],
    page: Optional[int] = 1,
    pageSize: Optional[int] = 100,
) -> list[dict[str, str]]:
    """
    Get faces searched by filters.

    Faces are ordered by face id when a face id range filter is set, otherwise by create time. When "lists"
    is requested, the lists of the found faces are fetched with a second query, in batches of
    DB_LIST_LIMIT face ids.

    Args:
        filters: raw search request filters
        targets: target Face columns' names to get info on ("lists" - ids of the lists the face is linked to)
        page: page
        pageSize: page size (-1 - all faces)
    Returns:
        faces list
    """
    queryColumns = [getattr(models.Face, target) for target in targets if target != "lists"]
    orderColumn = (
        models.Face.face_id if any((filters.face_id__gte, filters.face_id__lt)) else models.Face.create_time
    )
    if isNeedReturnLists := "lists" in targets:
        # needed to get lists by face IDs
        if "face_id" not in targets:
            queryColumns.append(models.Face.face_id)
    query = Query(queryColumns).filter(
        and_(
            models.ListFace.face_id == models.Face.face_id if filters.list_id is not None else True,
            models.ListFace.list_id == filters.list_id if filters.list_id is not None else True,
            *self.prepareSearchFacesFilters(filters),
        )
    )
    query = self.setOffsetAndOrderToQuery(query=query, orderColumn=orderColumn, page=page, pageSize=pageSize)
    async with DBContext.adaptor.connection(self.logger) as connection:
        faceRows = await connection.fetchall(query.statement)
        if isNeedReturnLists:
            mapFaceToList: dict[str, list[str]] = {}
            faceIds = [row["face_id"] for row in faceRows]
            # batch the IN filter like the other id lists in this module: a page may be unbounded (pageSize=-1)
            # and databases limit the IN list size (e.g. Oracle ORA-01795)
            for idx in range(0, len(faceIds), DB_LIST_LIMIT):
                listQuery = Query([models.ListFace.list_id, models.ListFace.face_id]).filter(
                    models.ListFace.face_id.in_(faceIds[idx : idx + DB_LIST_LIMIT])
                )
                for lunaList, face in await connection.fetchall(listQuery.statement):
                    mapFaceToList.setdefault(face, []).append(lunaList)
            faces = [
                {
                    target: mapFaceToList.get(row["face_id"], []) if target == "lists" else row[target]
                    for target in targets
                }
                for row in faceRows
            ]
        else:
            faces = [{target: row[target] for target in targets} for row in faceRows]
    return makeOutputFaces(faces, self.storageTime)
@exceptionWrap
async def executeSearchFacesFilteredByList(
    self, filters: SearchFacesFilters, targets: list[str], page: Optional[int] = 1, pageSize: Optional[int] = 100
) -> list[dict[str, str]]:
    """
    Get faces searched by filters using the face-list link table (joined with the list table when needed).

    Rows are ordered by face id. Only the targets "lists", "face_id" and "account_id" are supported.

    Args:
        filters: raw search request filters
        targets: target columns' names to get info on
        page: page
        pageSize: page size
    Returns:
        faces list; "lists" holds the single list id of the matched link row
    """
    # available target columns
    columnsMap = {
        "lists": models.ListFace.list_id,
        "face_id": models.ListFace.face_id,
        "account_id": models.List.account_id,
    }
    # whenever the list table is filtered on or selected from, it must be joined to list_face; otherwise
    # the statement degenerates into a cross join of both tables
    isListTableUsed = filters.list_id is not None or filters.account_id is not None or "account_id" in targets
    queryFilters = and_(
        models.List.account_id == filters.account_id if filters.account_id is not None else True,
        models.List.list_id == filters.list_id if filters.list_id is not None else True,
        models.ListFace.list_id == models.List.list_id if isListTableUsed else True,
        models.ListFace.face_id < filters.face_id__lt if filters.face_id__lt is not None else True,
        models.ListFace.face_id >= filters.face_id__gte if filters.face_id__gte is not None else True,
        models.ListFace.face_id.in_(filters.face_ids) if filters.face_ids is not None else True,
    )
    query = select([columnsMap[target] for target in targets]).where(queryFilters)
    query = self.setOffsetAndOrderToQuery(
        query=query, orderColumn=models.ListFace.face_id, page=page, pageSize=pageSize
    )
    async with DBContext.adaptor.connection(self.logger) as connection:
        faceRows = await connection.fetchall(query)
    # `one to many` relation, one `list_id` for lists
    return [
        {target: [row["list_id"]] if target == "lists" else row[target] for target in targets} for row in faceRows
    ]
@exceptionWrap
async def getFaces(
    self,
    filters: SearchFacesFilters,
    targets: Optional[list[str]] = None,
    page: Optional[int] = 1,
    pageSize: Optional[int] = 100,
) -> list[dict[str, str]]:
    """
    Get faces, choosing the cheaper query strategy.

    The face-list link model is used when list-face filters are used, a face id range is given and every
    requested target is available from the link tables. This is more profitable for large lists. Otherwise
    the face table is queried.

    Args:
        filters: raw search request filters
        targets: target Face columns' names to get info on
        page: page
        pageSize: page size
    Returns:
        list of faces
    """
    useListFaceModel = (
        filters.listFaceFiltersAreUsed()
        and any((filters.face_id__gte, filters.face_id__lt))
        and set(targets).issubset(FACE_TO_LIST_TARGETS)
    )
    if useListFaceModel:
        return await self.executeSearchFacesFilteredByList(filters, targets, page=page, pageSize=pageSize)
    return await self.executeSearchFaces(filters, targets, page=page, pageSize=pageSize)
@exceptionWrap
async def getNonexistentFaceId(
    self,
    requiredFaceIds: set[str],
    accountId: Optional[str] = None,
    dbConnection: Optional[AbstractDBConnection] = None,
) -> Union[str, None]:
    """
    Get one of `requiredFaceIds` that does not exist (for the account, when given).

    Args:
        requiredFaceIds: set of required face ids
        accountId: account id
        dbConnection: current connection. Needed not to get new connection from pool
                      or to do some stuff in scope of the transaction.
    Returns:
        Non existing face id, or None if all faces exist
    """
    query = Query([models.Face.face_id]).filter(
        models.Face.account_id == accountId if accountId is not None else True,
        models.Face.face_id.in_(requiredFaceIds),
    )

    async def findMissing(connection: AbstractDBConnection) -> Union[str, None]:
        foundFaceIds = set(chain.from_iterable(await connection.fetchall(query.statement)))
        return next(iter(requiredFaceIds - foundFaceIds), None)

    if dbConnection is not None:
        return await findMissing(dbConnection)
    async with DBContext.adaptor.connection(self.logger) as connection:
        return await findMissing(connection)
@exceptionWrap
async def getFacesCountByList(self, filters: SearchFacesFilters) -> int:
    """
    Count faces linked to the list `filters.list_id`.

    When an account filter is given, the list ownership is checked first; a list that does not belong to the
    account yields 0.

    Args:
        filters: raw search request filters
    Returns:
        Number of faces
    """
    linkConditions = [
        models.ListFace.list_id == filters.list_id,
        models.ListFace.face_id < filters.face_id__lt if filters.face_id__lt is not None else True,
        models.ListFace.face_id >= filters.face_id__gte if filters.face_id__gte is not None else True,
        models.ListFace.face_id.in_(filters.face_ids) if filters.face_ids is not None else True,
    ]
    async with DBContext.adaptor.connection(self.logger) as connection:
        if filters.account_id is not None:
            # checking the list by account before getting the faces count is more profitable
            ownedListsSt = select([func.count(models.List.list_id)]).where(
                and_(models.List.account_id == filters.account_id, models.List.list_id == filters.list_id)
            )
            if not await connection.scalar(ownedListsSt):
                return 0
        faceCountSt = select([func.count(models.ListFace.face_id)]).where(and_(*linkConditions))
        return await connection.scalar(faceCountSt)
@exceptionWrap
async def getFacesCount(self, filters: SearchFacesFilters) -> int:
    """
    Count faces matching the filters.

    When list-face filters are used, the count is delegated to `getFacesCountByList`.

    Args:
        filters: raw search request filters
    Returns:
        Number of faces
    """
    if filters.listFaceFiltersAreUsed():
        return await self.getFacesCountByList(filters)

    filterByList = filters.list_id is not None
    countSt = select([func.count(models.Face.face_id)]).where(
        and_(
            models.ListFace.face_id == models.Face.face_id if filterByList else True,
            models.ListFace.list_id == filters.list_id if filterByList else True,
            *self.prepareSearchFacesFilters(filters),
        )
    )
    async with DBContext.adaptor.connection(self.logger) as connection:
        return await connection.scalar(countSt)
@exceptionWrap
async def getFacesAttributesCount(self, accountId: Optional[str] = None) -> int:
    """
    Return face attribute count.

    Args:
        accountId: account id; when not None, only attributes of faces of this account are counted
    Returns:
        face attribute count
    """
    selectSt = select([func.count(models.Attribute.face_id)])
    # compare with None like the other account filters in this module: an empty account id must not
    # fall through to counting the attributes of every account
    if accountId is not None:
        selectSt = selectSt.where(
            and_(models.Face.account_id == accountId, models.Attribute.face_id == models.Face.face_id)
        )
    async with DBContext.adaptor.connection(self.logger) as connection:
        return await connection.scalar(selectSt)
@exceptionWrap
async def getLists(
    self, filters: SearchListsFilters, page: Optional[int] = 1, pageSize: Optional[int] = 100
) -> list[dict[str, str]]:
    """
    Get lists.

    Lists are sorted in descending order by list id when a list id range filter is set, otherwise by create
    time. Pagination uses the shared `setOffsetAndOrderToQuery` helper, so `pageSize=-1` returns all lists.

    Args:
        filters: raw search request filters
        page: page
        pageSize: page size (-1 - all lists)
    Returns:
        List with dicts
    """
    queryFilters = self.prepareListQueryFilters(filters)
    orderColumn = (
        models.List.list_id
        if filters.list_id__gte is not None or filters.list_id__lt is not None
        else models.List.create_time
    )
    query = self.setOffsetAndOrderToQuery(
        query=Query(models.List).filter(queryFilters),
        order="desc",
        orderColumn=orderColumn,
        page=page,
        pageSize=pageSize,
    )
    async with DBContext.adaptor.connection(self.logger) as connection:
        listRows = await connection.fetchall(query.statement)
    return makeOutputLists(listRows, storageTime=self.storageTime)
@exceptionWrap
async def getNonexistentListId(
    self, requiredListIds: set[str], accountId: Optional[str] = None
) -> Union[str, None]:
    """
    Get one of `requiredListIds` that does not exist (for the account, when given).

    Args:
        requiredListIds: set of required list ids
        accountId: account id (combined into the condition with the project `@` column operator)
    Returns:
        Non existing list id, or None if all lists exist
    """
    query = Query(models.List.list_id).filter(
        models.List.account_id @ accountId, models.List.list_id.in_(requiredListIds)
    )
    async with DBContext.adaptor.connection(self.logger) as connection:
        rows = await connection.fetchall(query.statement)
    foundListIds = set(chain.from_iterable(rows))
    return next(iter(requiredListIds - foundListIds), None)
@exceptionWrap
async def getListsCount(self, filters: SearchListsFilters) -> int:
    """
    Count lists matching the filters.

    Args:
        filters: raw search request filters
    Returns:
        Count of lists
    """
    countQuery = Query([func.count(models.List.list_id).label("count")]).filter(
        self.prepareListQueryFilters(filters)
    )
    async with DBContext.adaptor.connection(self.logger) as connection:
        return await connection.scalar(countQuery.statement)
@exceptionWrap
async def getListsAndKeysFromMV(self, listIds: set[str], useParity: int) -> list[dict[str, Any]]:
    """
    Get lists with last link and unlink keys from materialized views.

    List ids are queried in batches of DB_LIST_LIMIT ids. Every materialized view is full-joined to the
    list table on the view's own `list_id` column.

    Args:
        listIds: list ids
        useParity: if 1 - get even and odd max link/unlink keys else max link/unlink keys
    Returns:
        List with list ids and its link/unlink keys: dicts with keys "list_id", "link_key_even",
        "link_key_odd", "unlink_key_even", "unlink_key_odd" when `useParity` is truthy, otherwise
        "list_id", "link_key", "unlink_key"
    """
    if useParity:
        columns = [
            models.List.list_id,
            models.MV_LINK_0.link_key.label("link_key_even"),
            models.MV_LINK_1.link_key.label("link_key_odd"),
            models.MV_UNLINK_0.unlink_key.label("unlink_key_even"),
            models.MV_UNLINK_1.unlink_key.label("unlink_key_odd"),
        ]
        views = (models.MV_LINK_0, models.MV_LINK_1, models.MV_UNLINK_0, models.MV_UNLINK_1)
        resultKeys = ("list_id", "link_key_even", "link_key_odd", "unlink_key_even", "unlink_key_odd")
    else:
        columns = [models.List.list_id, models.MV_LINK.link_key, models.MV_UNLINK.unlink_key]
        views = (models.MV_LINK, models.MV_UNLINK)
        resultKeys = ("list_id", "link_key", "unlink_key")

    def makeQuery(listIdsBatch: list[str]) -> Query:
        """Build the select for one batch of list ids."""
        query = Query(columns)
        for view in views:
            # each view is joined on its own list_id column
            query = query.join(view, view.list_id == models.List.list_id, full=True)
        return query.filter(models.List.list_id.in_(listIdsBatch))

    listIdsList = list(listIds)
    selectResult = []
    async with DBContext.adaptor.connection(self.logger) as connection:
        for idx in range(0, len(listIdsList), DB_LIST_LIMIT):
            batchQuery = makeQuery(listIdsList[idx : idx + DB_LIST_LIMIT])
            selectResult.extend(await connection.fetchall(batchQuery.statement))
    return [dict(zip(resultKeys, row)) for row in selectResult]
@exceptionWrap
async def getListsWithKeys(self, listIds: set[str], useParity: int) -> list[dict[str, Any]]:
    """
    Get lists with last link and unlink keys from usual tables.

    Only lists that exist in the `List` table are returned. Lookups are done in batches of
    `DB_LIST_LIMIT` list ids.

    Args:
        listIds: list ids
        useParity: if truthy - get even and odd max link/unlink keys, otherwise max link/unlink keys
    Returns:
        list of dicts with the "list_id" key and either the "link_key_even", "unlink_key_even",
        "link_key_odd", "unlink_key_odd" keys (parity mode) or the "link_key", "unlink_key" keys;
        a key is None when the lookup returned nothing for the list
    """

    def splitIntoBatches(ids: list[str]) -> list[list[str]]:
        """Split ids into consecutive chunks of at most DB_LIST_LIMIT items."""
        return [ids[idx : idx + DB_LIST_LIMIT] for idx in range(0, len(ids), DB_LIST_LIMIT)]

    async def getKeysFromDb(
        searchLists: list[str], dbConnection, keyType: str, parity: Optional[str] = None
    ) -> dict:
        """
        Get link/unlink keys from db
        Args:
            searchLists: ids of the lists to look up
            dbConnection: database connection
            keyType: "link" or "unlink"
            parity: None, "odd" or "even"
        Returns:
            {<list id>: <key>} map built from the fetched rows; without parity the key is the
            greatest link/unlink key of the list and lists without records are absent
        """
        if parity is not None:
            # List ids are interpolated into raw SQL text here. Callers below pass only ids that were
            # read back from the `List` table, never raw request input.
            rawSql = " UNION ALL ".join(
                RAW_SQLS[self.dbConfig.type][keyType].format(listId=listId, mod=1 if parity == "odd" else 0)
                for listId in searchLists
            )
            return dict(await dbConnection.fetchall(rawSql))
        queries = []
        for listId in searchLists:
            if keyType == "link":
                keyColumn = models.ListFace.link_key
                selectParams = [models.ListFace.list_id, keyColumn]
                basicFilters = models.ListFace.list_id == listId
            else:
                keyColumn = models.UnlinkAttributesLog.unlink_key
                selectParams = [models.UnlinkAttributesLog.list_id, keyColumn]
                basicFilters = models.UnlinkAttributesLog.list_id == listId
            # the greatest key of the list: order descending and take one row
            queries.append(Query(selectParams).filter(basicFilters).order_by(keyColumn.desc()).limit(1))
        return dict(await dbConnection.fetchall(union_all(*queries)))

    async with DBContext.adaptor.connection(self.logger) as connection:
        # keep only the lists that exist in the database
        existsLists = []
        for listIdsBatch in splitIntoBatches(list(listIds)):
            existListSt = Query(models.List.list_id).filter(models.List.list_id.in_(listIdsBatch)).statement
            existsLists.extend(row[0] for row in await connection.fetchall(existListSt))
        if not existsLists:
            return []
        result = {}
        for existsListsBatch in splitIntoBatches(existsLists):
            if useParity:
                oddLinkKeys = await getKeysFromDb(existsListsBatch, connection, "link", "odd")
                oddUnlinkKeys = await getKeysFromDb(existsListsBatch, connection, "unlink", "odd")
                evenLinkKeys = await getKeysFromDb(existsListsBatch, connection, "link", "even")
                evenUnlinkKeys = await getKeysFromDb(existsListsBatch, connection, "unlink", "even")
                keyMaps = {
                    "link_key_even": evenLinkKeys,
                    "unlink_key_even": evenUnlinkKeys,
                    "link_key_odd": oddLinkKeys,
                    "unlink_key_odd": oddUnlinkKeys,
                }
            else:
                keyMaps = {
                    "link_key": await getKeysFromDb(existsListsBatch, connection, "link"),
                    "unlink_key": await getKeysFromDb(existsListsBatch, connection, "unlink"),
                }
            for listId in existsListsBatch:
                result[listId] = {keyName: keyMap.get(listId) for keyName, keyMap in keyMaps.items()}
    return [{"list_id": listId, **listData} for listId, listData in result.items()]
@exceptionWrap
async def deleteFaces(self, faces: list[str], accountId: Optional[str] = None) -> int:
    """
    Remove faces after moving their list links to the history log.
    Args:
        faces: ids of the faces to remove
        accountId: account id the faces must belong to
    Returns:
        removed face count
    """
    faceFilters = [models.Face.face_id.in_(faces), models.Face.account_id @ accountId]
    async with DBContext.adaptor.connection(self.logger) as connection:
        self.logger.debug(f"start delete {len(faces)} faces")
        # lock the matching faces; their ids drive the history-log update
        _, lockedFaces = await self.blockFaces(connection, filters=faceFilters, returnBlockedFaces=True)
        affectedLists = await self._moveFacesLinksToLog(lockedFaces, connection)
        self.logger.debug("update history log")
        deletedCount = await connection.execute(delete(models.Face).where(and_(*faceFilters)))
        self.logger.debug(f"delete {deletedCount} faces")
    self._addListsToDeferrerUpdateListLastUpdateTime(affectedLists)
    return deletedCount
@exceptionWrap
async def _deleteLists(self, lists: list[str], accountId: Optional[str] = None) -> int:
    """
    Remove the List rows with the given ids that belong to the account.
    Args:
        lists: ids of the lists to remove
        accountId: account id the lists must belong to
    Returns:
        removed list count
    Warnings:
        trigger `trg_lists_deletion_log` will insert a data to the table `ListsDeletionLog`
    """
    listFilter = and_(models.List.list_id.in_(lists), models.List.account_id @ accountId)
    async with DBContext.adaptor.connection(self.logger) as connection:
        removedCount = await connection.execute(delete(models.List).where(listFilter))
    return removedCount
@exceptionWrap
async def deleteLists(self, lists: list[str], accountId: Optional[str] = None, withFaces: bool = False) -> int:
    """
    Remove lists.
    Args:
        accountId: lists account id
        lists: ids of the lists to remove
        withFaces: also remove every face linked to these lists; faces are removed by background
            `deleteFaces` calls, in pages of BACKGROUND_REMOVAL_BATCH_SIZE ids
    Returns:
        removed list count
    """
    if not withFaces:
        return await self._deleteLists(lists, accountId)
    listsFilters = [models.List.list_id.in_(lists), models.List.account_id @ accountId]
    async with DBContext.adaptor.connection(self.logger) as connection:
        # SELECT ... FOR UPDATE: hold the list rows for the rest of the transaction
        lockQuery = Query([models.List.list_id]).filter(and_(*listsFilters)).with_for_update()
        lockedLists = [row[0] for row in await connection.fetchall(lockQuery.statement)]
        if not lockedLists:
            return 0
        linkedFacesQuery = select([models.ListFace.face_id]).where(
            and_(
                models.List.list_id.in_(lockedLists),
                models.List.list_id == models.ListFace.list_id,
                models.List.account_id @ accountId,
            )
        )
        linkedFaceRows = await connection.fetchall(linkedFacesQuery)
        # a face linked to several of the removed lists appears once per link: keep each id once, in order
        faceIdsToDeletion = list(dict.fromkeys(row[0] for row in linkedFaceRows))
        removedListCount = await connection.execute(delete(models.List).where(and_(*listsFilters)))
    # schedule face removal only once the list deletion transaction has finished
    pageSize = BACKGROUND_REMOVAL_BATCH_SIZE
    pageCount = getPageCount(len(faceIdsToDeletion), pageSize)
    coros = [
        self.deleteFaces(faceIdsToDeletion[page * pageSize : (page + 1) * pageSize], accountId)
        for page in range(pageCount)
    ]
    DBContext.backgroundWorker.runNoWait(coros)
    return removedListCount
@exceptionWrap
async def updateListUserData(self, listId: str, userData: str, accountId: Optional[str] = None) -> int:
    """
    Replace the user data of a list.
    Args:
        listId: id of the list to update
        userData: new user data
        accountId: account id the list must belong to
    Returns:
        updated list count, as reported by `connection.execute`
    """
    targetList = and_(models.List.list_id == listId, models.List.account_id @ accountId)
    updateStatement = update(models.List).where(targetList).values(user_data=userData)
    async with DBContext.adaptor.connection(self.logger) as connection:
        updatedCount = await connection.execute(updateStatement)
    return updatedCount
@exceptionWrap
async def isFacesExist(self, faceIds: list[str], accountId: Optional[str] = None) -> bool:
    """
    Check that every requested face exists.
    Args:
        accountId: account id the faces must belong to
        faceIds: face ids; duplicated ids are counted once
    Returns:
        True if all faces exist else False
    """
    async with DBContext.adaptor.connection(self.logger) as connection:
        query = Query([func.count(models.Face.face_id).label("count")]).filter(
            models.Face.face_id.in_(faceIds), models.Face.account_id @ accountId
        )
        res = await connection.fetchone(query.statement)
    # COUNT sees each existing face once however many times its id was requested
    # (face_id presumably unique), so compare against the number of distinct ids
    return res[0] == len(set(faceIds))
@exceptionWrap
async def isListsExist(self, listIds: set[str], accountId: Optional[str] = None) -> bool:
    """
    Check that every requested list exists.
    Args:
        accountId: account id the lists must belong to
        listIds: list ids; a list/iterable is accepted too, duplicated ids are counted once
    Returns:
        True if all lists exist else False
    """
    uniqueListIds = set(listIds)
    async with DBContext.adaptor.connection(self.logger) as connection:
        query = Query([func.count(models.List.list_id).label("count")]).filter(
            models.List.list_id.in_(uniqueListIds), models.List.account_id @ accountId
        )
        res = await connection.fetchone(query.statement)
    return res[0] == len(uniqueListIds)
@exceptionWrap
async def getListPlusDelta(
    self,
    listId,
    linkKeyGte: Optional[int] = None,
    linkKeyLt: Optional[int] = None,
    limit: int = 10000,
    parity: Optional[int] = None,
) -> list[dict]:
    """
    Get attach attributes to lists: faces linked to the list, ordered by ascending link key.
    Only faces that have a row in the Attribute table are returned.
    Args:
        listId: list id
        linkKeyLt: exclusive upper bound of link key value, no bound if None
        linkKeyGte: inclusive lower bound of link key value, no bound if None
        limit: maximum number of returned records
        parity: 0 for even or 1 for odd link keys to search for (`link_key % 2 == parity`), any if None
    Returns:
        List of dicts with following keys: "attribute_id" (the face id), "link_key"
    """
    conditions = [
        models.ListFace.list_id == listId,
        models.Attribute.face_id == models.ListFace.face_id,
    ]
    if linkKeyLt is not None:
        conditions.append(models.ListFace.link_key < linkKeyLt)
    if linkKeyGte is not None:
        conditions.append(models.ListFace.link_key >= linkKeyGte)
    if parity is not None:
        conditions.append(mod(models.ListFace.link_key, 2) == parity)
    query = (
        Query([models.Attribute.face_id, models.ListFace.link_key])
        .filter(and_(*conditions))
        .order_by(models.ListFace.link_key.asc())
        .limit(limit)
    )
    async with DBContext.adaptor.connection(self.logger) as connection:
        attributesRows = await connection.fetchall(query.statement)
    # used by matcher
    return [dict(zip(("attribute_id", "link_key"), attributesRow)) for attributesRow in attributesRows]
@exceptionWrap
async def getListMinusDelta(
    self,
    listId,
    unlinkKeyGte: Optional[int] = None,
    unlinkKeyLt: Optional[int] = None,
    limit: int = 10000,
    parity: Optional[int] = None,
) -> list[dict]:
    """
    Get history of detach attribute to lists, ordered by ascending unlink key.
    Log records without a face id are skipped.
    Args:
        listId: list id
        unlinkKeyLt: exclusive upper bound of unlink key value, no bound if None
        unlinkKeyGte: inclusive lower bound of unlink key value, no bound if None
        limit: maximum number of returned records
        parity: 0 for even or 1 for odd unlink keys to search for (`unlink_key % 2 == parity`), any if None
    Returns:
        List of dicts with following keys: "attributes_id" (the face id), "link_key", "unlink_key"
    """
    conditions = [
        models.UnlinkAttributesLog.list_id == listId,
        models.UnlinkAttributesLog.face_id.isnot(None),
    ]
    if unlinkKeyLt is not None:
        conditions.append(models.UnlinkAttributesLog.unlink_key < unlinkKeyLt)
    if unlinkKeyGte is not None:
        conditions.append(models.UnlinkAttributesLog.unlink_key >= unlinkKeyGte)
    if parity is not None:
        conditions.append(mod(models.UnlinkAttributesLog.unlink_key, 2) == parity)
    query = (
        Query(
            [
                models.UnlinkAttributesLog.face_id,
                models.UnlinkAttributesLog.link_key,
                models.UnlinkAttributesLog.unlink_key,
            ]
        )
        .filter(and_(*conditions))
        .order_by(models.UnlinkAttributesLog.unlink_key.asc())
        .limit(limit)
    )
    async with DBContext.adaptor.connection(self.logger) as connection:
        attributesRows = await connection.fetchall(query.statement)
    # NOTE(review): the id key is "attributes_id" here but "attribute_id" in getListPlusDelta; kept as is
    # because consumers may read these exact keys — confirm before unifying
    return [
        dict(zip(("attributes_id", "link_key", "unlink_key"), attributesRow))
        for attributesRow in attributesRows
    ]
@exceptionWrap
async def cleanLog(self, updateTimeLt: Optional[datetime] = None) -> None:
    """
    Remove notes from unlink tables.
    Args:
        updateTimeLt: exclusive upper bound of update time: records with `update_time < updateTimeLt`
            are removed. If None, every record of the unlink log is removed.
    """
    cleanLogSt = delete(models.UnlinkAttributesLog)
    if updateTimeLt is not None:
        cleanLogSt = cleanLogSt.where(models.UnlinkAttributesLog.update_time < updateTimeLt)
    async with DBContext.adaptor.connection(self.logger) as connection:
        await connection.execute(cleanLogSt)
@exceptionWrap
async def getFacesAttributes(
    self,
    faceIds: list[str],
    retrieveAttrs: Iterable[str],
    descriptorVersion: int,
    accountId: Optional[str],
    getDescriptorAsBytes: bool = False,
) -> list[dict[str, Any]]:
    """
    Retrieve attributes
    Args:
        faceIds: face ids
        retrieveAttrs: names of the attribute targets to retrieve
        descriptorVersion: requested descriptor version
        accountId: account id
        getDescriptorAsBytes: whether to get descriptor as bytes otherwise as base64
    Returns:
        face attributes list with items with the following properties:
            face_id: face id
            attributes: target-specific face attributes dict
        If some attribute data is not found it will be filled with 'empty' value:
            None for attributes, [] for samples
        Faces from `faceIds` that are not found for the account are omitted.
    """
    # Materialize once: the targets are tested many times below, which would exhaust a one-shot iterator.
    retrieveAttrs = set(retrieveAttrs)
    attributeModelTargets = {
        FaceAttributeTargets.createTime.value,
        FaceAttributeTargets.basicAttributes.value,
        FaceAttributeTargets.faceDescriptor.value,
    }
    sampleModelTargets = {
        FaceAttributeTargets.basicAttributesSamples.value,
        FaceAttributeTargets.faceDescriptorSamples.value,
    }
    faceAttributes = {}
    async with DBContext.adaptor.connection(self.logger) as connection:

        async def getAttributeModelData(
            faceIds_: list[str], accountId_: Optional[str], retrieveAttrs_: Iterable[str], descriptorVersion_: int
        ) -> Iterable[dict]:
            """Get attribute data from Attribute and Descriptor tables (outer-joined to Face)."""
            filters = [models.Face.face_id.in_(faceIds_), models.Face.account_id @ accountId_]
            selectQuery = Query([models.Face.face_id])
            needCreateTime = FaceAttributeTargets.createTime.value in retrieveAttrs_
            needBasicAttributes = FaceAttributeTargets.basicAttributes.value in retrieveAttrs_
            if needCreateTime or needBasicAttributes:
                if needCreateTime:
                    selectQuery = selectQuery.add_column(models.Attribute.create_time)
                if needBasicAttributes:
                    selectQuery = selectQuery.add_columns(
                        models.Attribute.age, models.Attribute.gender, models.Attribute.ethnicity
                    )
                selectQuery = selectQuery.outerjoin(
                    models.Attribute, models.Attribute.face_id == models.Face.face_id
                )
            if FaceAttributeTargets.faceDescriptor.value in retrieveAttrs_:
                selectQuery = selectQuery.add_column(models.Descriptor.descriptor)
                selectQuery = selectQuery.outerjoin(
                    models.Descriptor,
                    and_(
                        models.Descriptor.face_id == models.Face.face_id,
                        models.Descriptor.descriptor_version == descriptorVersion_,
                    ),
                )
            selectQuery = selectQuery.filter(and_(*filters))
            attributesData = await connection.fetchall(selectQuery.statement)
            responseFields = [str(field) for field in selectQuery.statement.columns]
            return [dict(zip(responseFields, row)) for row in attributesData]

        async def getSampleModelData(faceIds_: list[str], accountId_: Optional[str]) -> Iterable[dict]:
            """Get attribute data from Sample table (outer-joined to Face)."""
            filters = [models.Face.face_id.in_(faceIds_), models.Face.account_id @ accountId_]
            selectQuery = (
                Query([models.Face.face_id, models.Sample.sample_id, models.Sample.type])
                .outerjoin(models.Sample, models.Sample.face_id == models.Face.face_id)
                .filter(and_(*filters))
            )
            samplesData = await connection.fetchall(selectQuery.statement)
            responseFields = [str(field) for field in selectQuery.statement.columns]
            return [dict(zip(responseFields, row)) for row in samplesData]

        def updateFaceAttributesWithAttributeData(retrieveAttrs_: Iterable[str], attributes_: Iterable[dict]):
            """Update faceAttributes result dict with attribute data."""
            for attribute in attributes_:
                attributeFaceId = attribute.pop("face_id")
                faceAttributes[attributeFaceId] = {}
                if FaceAttributeTargets.createTime.value in retrieveAttrs_:
                    createTime = attribute.pop(FaceAttributeTargets.createTime.value)
                    if createTime is not None:
                        createTime = convertTimeToString(createTime, self.storageTime == "UTC")
                    faceAttributes[attributeFaceId][FaceAttributeTargets.createTime.value] = createTime
                if FaceAttributeTargets.faceDescriptor.value in retrieveAttrs_:
                    descriptor = attribute.pop("descriptor")
                    descriptorData = (
                        {
                            "descriptor": descriptor if getDescriptorAsBytes else bytesToBase64(descriptor),
                            "descriptor_version": descriptorVersion,
                        }
                        if descriptor is not None
                        else None
                    )
                    faceAttributes[attributeFaceId][FaceAttributeTargets.faceDescriptor.value] = descriptorData
                if FaceAttributeTargets.basicAttributes.value in retrieveAttrs_:
                    # only the basic attribute columns are left in `attribute` after the pops above
                    basicAttributes = attribute if any(attr is not None for attr in attribute.values()) else None
                    faceAttributes[attributeFaceId][FaceAttributeTargets.basicAttributes.value] = basicAttributes

        def updateFaceAttributesWithSampleData(retrieveAttrs_: Iterable[str], samples_: Iterable[dict]):
            """Update faceAttributes result dict with sample data."""
            for sample in samples_:
                sampleType = sample["type"]
                if sampleType is None:
                    # outer join row of a face without any sample
                    for attr in (SampleType.face_descriptor, SampleType.basic_attributes):
                        attrType = f"{attr.name}{SAMPLE_POSTFIX}"
                        if attrType in retrieveAttrs_:
                            faceAttributes.setdefault(sample["face_id"], {})[attrType] = []
                else:
                    attrType = f"{SampleType(sampleType).name}{SAMPLE_POSTFIX}"
                    if attrType in retrieveAttrs_:
                        faceAttributes.setdefault(sample["face_id"], {}).setdefault(attrType, []).append(
                            sample["sample_id"]
                        )

        needAttributeData = bool(attributeModelTargets & retrieveAttrs)
        needSampleData = bool(sampleModelTargets & retrieveAttrs)
        # Both queries return at least one row per found face. The attribute query still runs when no
        # sample query is made, so that found faces are always present in the reply.
        if needAttributeData or not needSampleData:
            attributes = await getAttributeModelData(
                faceIds_=faceIds,
                accountId_=accountId,
                retrieveAttrs_=retrieveAttrs,
                descriptorVersion_=descriptorVersion,
            )
            updateFaceAttributesWithAttributeData(retrieveAttrs_=retrieveAttrs, attributes_=attributes)
        if needSampleData:
            samples = await getSampleModelData(faceIds_=faceIds, accountId_=accountId)
            updateFaceAttributesWithSampleData(retrieveAttrs_=retrieveAttrs, samples_=samples)
    return [{"face_id": faceId, "attributes": attribute} for faceId, attribute in faceAttributes.items()]
@exceptionWrap
async def getFacesAttributeSamples(self, faceId: str, accountId: Optional[str]) -> list[str]:
    """
    Get all the attribute samples of the specified face
    Args:
        faceId: face id
        accountId: account id
    Returns:
        ids of the face samples of every sample type, without duplicates and in no particular order
    Raises:
        VLException(Error.FaceNotFound.format(faceId), 404, isCriticalError=False) if face not found
    """
    async with DBContext.adaptor.connection(self.logger) as connection:
        selectSt = select([models.Sample.sample_id]).where(
            and_(
                models.Sample.face_id == faceId,
                models.Face.account_id @ accountId,
                models.Sample.face_id == models.Face.face_id,
            )
        )
        sampleIds = list(set(chain(*await connection.fetchall(selectSt))))
        if not sampleIds:
            # an empty reply is ambiguous: the face has no samples or does not exist for the account
            selectSt = select([models.Face.face_id]).where(
                and_(models.Face.face_id == faceId, models.Face.account_id @ accountId)
            )
            if await connection.scalar(selectSt) is None:
                raise VLException(Error.FaceNotFound.format(faceId), 404, isCriticalError=False)
    return sampleIds
@exceptionWrap
async def putFaceAttributes(self, faceId: str, attribute: TemporaryAttributes, accountId: Optional[str]) -> None:
    """
    Replace the attribute data of a face.
    Args:
        faceId: face id
        attribute: temporary attribute container
        accountId: account id
    Raises:
        VLException(Error.FaceNotFound.format(faceId), 404, isCriticalError=False) if face not found
    """
    faceFilters = [models.Face.face_id == faceId, models.Face.account_id @ accountId]
    async with DBContext.adaptor.connection(self.logger) as connection:
        # select for update
        lockedCount, _ = await self.blockFaces(connection, filters=faceFilters)
        if lockedCount == 0:
            raise VLException(Error.FaceNotFound.format(faceId), 404, isCriticalError=False)
        # history log: move the current links of the face, then give its links a fresh link key
        linkedLists = await self._moveFacesLinksToLog((faceId,), connection)
        relinkSt = (
            update(models.ListFace)
            .where(models.ListFace.face_id == faceId)
            .values(last_update_time=self.currentDBTimestamp, link_key=LINK_SEQUENCE.next_value())
        )
        await connection.execute(relinkSt)
        # drop the old attribute row and write the new data
        await connection.execute(delete(models.Attribute).where(models.Attribute.face_id == faceId))
        await self.insertFaceAttributeData(connection, faceId=faceId, attribute=attribute)
    self._addListsToDeferrerUpdateListLastUpdateTime(linkedLists)
@exceptionWrap
async def _updateAttributeDescriptor(
    self, connection: AbstractDBConnection, faceId: str, descriptors: list[Descriptor], generation: int
) -> None:
    """
    Partially update face descriptors: update the row of each (face id, descriptor version),
    inserting a new row when none was updated.
    Args:
        connection: connection
        faceId: face id
        descriptors: list of descriptor containers for patch
        generation: descriptor generation, written only into newly inserted rows
    """
    for descriptor in descriptors:
        version = descriptor.version
        changedValues = {"descriptor": descriptor.descriptor, "descriptor_obtaining_method": 1}
        sameVersion = and_(
            models.Descriptor.face_id == faceId,
            models.Descriptor.descriptor_version == version,
        )
        # NOTE(review): the update path leaves descriptor_generation untouched — confirm this is intended
        updatedCount = await connection.execute(update(models.Descriptor).where(sameVersion).values(**changedValues))
        if updatedCount:
            continue
        # nothing updated: insert the full row
        await connection.execute(
            insert(models.Descriptor).values(
                descriptor_version=version,
                **changedValues,
                face_id=faceId,
                descriptor_generation=generation,
            )
        )
@exceptionWrap
async def updateFaceAttributes(
    self,
    faceId: str,
    attribute: TemporaryAttributes,
    accountId: Optional[str] = None,
    forceUpdate: Optional[bool] = False,
) -> None:
    """
    Update attribute
    Args:
        faceId: updated attribute face id
        attribute: temporary attributes container
        accountId: account id
        forceUpdate: whether not to compare existing basic attributes samples with new ones;
            samples are also stored only when it is set
    Raises:
        VLException(Error.FaceSampleConflict.format(faceId), 400, False) if samples conflict
        VLException(Error.AttributesForUpdateNotFound.format(faceId), 400, False) if attribute for update not found
        VLException(Error.FaceNotFound.format(faceId), 404, False) if face not found
    """
    async with DBContext.adaptor.connection(self.logger) as connection:

        async def checkSamples(
            faceId_: str, filteredSamples_: dict[int, set[str]], updDescriptorVersions_: list[int]
        ) -> None:
            """
            Check conformity of old and new sample ids.
            Args:
                faceId_: face id
                filteredSamples_: {<sample type>: <samples>} map
                updDescriptorVersions_: versions of updating descriptors
            Raises:
                VLException(Error.FaceSampleConflict.format(faceId_), 400, False) if old samples do not
                match specified samples in attribute
            """
            isNeedToCheckSamples = not forceUpdate
            for sampleTypeToCheck in SampleType:
                sampleTypeToCheck = sampleTypeToCheck.value
                if sampleTypeToCheck not in filteredSamples_:
                    continue
                newSamplesSet = filteredSamples_[sampleTypeToCheck]
                selectSt = select([models.Sample.sample_id]).where(
                    and_(models.Sample.face_id == faceId_, models.Sample.type == sampleTypeToCheck)
                )
                if sampleTypeToCheck == SampleType.face_descriptor.value:
                    # checked regardless of forceUpdate: different descriptor samples are a conflict
                    # unless every stored descriptor version is being updated too
                    dbReply = await connection.fetchall(selectSt)
                    existingSampleIds = set(chain(*dbReply))
                    if existingSampleIds and existingSampleIds != newSamplesSet:
                        notUpdatedVersionsSt = select([models.Descriptor.descriptor_version]).where(
                            and_(
                                models.Descriptor.face_id == faceId_,
                                not_(models.Descriptor.descriptor_version.in_(updDescriptorVersions_)),
                            )
                        )
                        notUpdatedDescriptorVersions = await connection.fetchall(notUpdatedVersionsSt)
                        if notUpdatedDescriptorVersions:
                            raise VLException(Error.FaceSampleConflict.format(faceId_), 400, False)
                else:
                    # basic attributes samples
                    if not isNeedToCheckSamples:
                        continue
                    dbReply = await connection.fetchall(selectSt)
                    existingSampleIds = set(chain(*dbReply))
                    if existingSampleIds and existingSampleIds != newSamplesSet:
                        raise VLException(Error.FaceSampleConflict.format(faceId_), 400, False)

        async def updateAttributeModel(
            faceId_: str,
            accountId_: Optional[str],
            basicAttributes: Optional[BasicAttributes],
            isNeedToUpdGeneration: bool,
        ) -> int:
            """
            Update attribute model.
            Increase current descriptor's generation if it needs (depends on descriptor samples in attributes)
            Args:
                faceId_: face id
                accountId_: account id
                basicAttributes: basic attributes container
                isNeedToUpdGeneration: whether to increase descriptor sample generation or not
            Returns:
                descriptor generation
            Raises:
                VLException(Error.AttributesForUpdateNotFound.format(faceId_), 400) if no attribute row was updated
            """
            if accountId_ is None:
                filters = and_(models.Attribute.face_id == faceId_)
            else:
                # only update an attribute whose face belongs to the account
                filters = exists([models.Attribute.face_id]).where(
                    and_(
                        models.Face.face_id == faceId_,
                        models.Face.account_id == accountId_,
                        models.Attribute.face_id == models.Face.face_id,
                    )
                )
            values = {
                "descriptor_samples_generation": models.Attribute.descriptor_samples_generation
                + int(isNeedToUpdGeneration)
            }
            if basicAttributes is not None:
                values.update(
                    {
                        "age": basicAttributes.age,
                        "gender": basicAttributes.gender,
                        "ethnicity": basicAttributes.ethnicity,
                    }
                )
            updateQuery = (
                update(models.Attribute)
                .where(filters)
                .values(**values)
                .returning(models.Attribute.descriptor_samples_generation)
            )
            generation = await connection.scalar(updateQuery)
            if generation is None:
                raise VLException(Error.AttributesForUpdateNotFound.format(faceId_), 400, isCriticalError=False)
            return generation

        async def replaceSamples(faceId_: str, filteredSamples_: dict[int, set[str]]) -> None:
            """
            Replace samples of every given type.
            Args:
                faceId_: face id
                filteredSamples_: {<sample type>: <samples>} map
            """
            for sampleType in filteredSamples_:
                deleteSt = delete(models.Sample).where(
                    and_(models.Sample.type == sampleType, models.Sample.face_id == faceId_)
                )
                await connection.execute(deleteSt)
            for sampleType, sampleIds in filteredSamples_.items():
                for sampleId in sampleIds:
                    query = insert(models.Sample).values(face_id=faceId_, sample_id=sampleId, type=sampleType)
                    await connection.execute(query)

        async def recreateLinks(faceId_: str, updDescriptorVersions_: list[int]) -> list[str]:
            """
            Recreate face-list links if the descriptor of the current version was changed.
            Needed for matcher if a descriptor was changed.
            Args:
                faceId_: face id
                updDescriptorVersions_: versions of updating descriptors
            Returns:
                lists returned by `_moveFacesLinksToLog`, or [] if the default version was not updated
            """
            if self.defaultDescriptorVersion not in updDescriptorVersions_:
                return []
            updateLinkKeysSt = (
                update(models.ListFace)
                .where(models.ListFace.face_id.in_([faceId_]))
                .values(link_key=LINK_SEQUENCE.next_value())
            )
            linkedLists = await self._moveFacesLinksToLog((faceId_,), connection)
            await connection.execute(updateLinkKeysSt)
            return linkedLists

        # Do select for update
        lockFaceCount, _ = await self.blockFaces(
            connection, filters=[models.Face.face_id == faceId, models.Face.account_id @ accountId]
        )
        if lockFaceCount == 0:
            raise VLException(Error.FaceNotFound.format(faceId), 404, isCriticalError=False)
        # Prepare map: {<sample type>: <samples>}
        filteredSamples = {}
        if attribute.basicAttributesSamples:
            filteredSamples[SampleType[AttributeSample.basic_attributes.value].value] = set(
                attribute.basicAttributesSamples
            )
        if attribute.descriptorSamples:
            filteredSamples[SampleType[AttributeSample.face_descriptors.value].value] = set(
                attribute.descriptorSamples
            )
        # Check no sample conflict
        updDescriptorVersions = [descr.version for descr in attribute.descriptors]
        await checkSamples(
            faceId_=faceId, filteredSamples_=filteredSamples, updDescriptorVersions_=updDescriptorVersions
        )
        # Update attribute data
        isNeedToUpdGeneration = filteredSamples.get(SampleType.face_descriptor.value) is not None
        generation = await updateAttributeModel(
            faceId_=faceId,
            accountId_=accountId,
            basicAttributes=attribute.basicAttributes,
            isNeedToUpdGeneration=isNeedToUpdGeneration,
        )
        # NOTE(review): without forceUpdate the new samples are not written — confirm this is intended
        if forceUpdate:
            await replaceSamples(faceId_=faceId, filteredSamples_=filteredSamples)
        await self._updateAttributeDescriptor(
            connection=connection, faceId=faceId, descriptors=attribute.descriptors, generation=generation
        )
        listsForUpdate = await recreateLinks(faceId_=faceId, updDescriptorVersions_=updDescriptorVersions)
    self._addListsToDeferrerUpdateListLastUpdateTime(listsForUpdate)
@exceptionWrap
async def deleteFaceAttributes(self, faceId: str, accountId: Optional[str] = None) -> None:
    """
    Delete face attributes.
    Also sets `last_update_time` of the face and of every list linked to it to the current DB timestamp.
    Args:
        faceId: face id to remove its attributes
        accountId: account id the face must belong to
    Raises:
        VLException(Error.FaceNotFound.format(faceId), 404, isCriticalError=False) if face not found
    """
    faceFilters = [models.Face.face_id == faceId, models.Face.account_id @ accountId]
    faceCheckFilter = and_(*faceFilters)
    async with DBContext.adaptor.connection(self.logger) as connection:
        # select for update, restricted to the account so that a face of another account is never locked
        await self.blockFaces(connection, filters=faceFilters)
        updateFaceSt = update(models.Face).where(faceCheckFilter).values(last_update_time=self.currentDBTimestamp)
        faceCount = await connection.execute(updateFaceSt)
        if not faceCount:
            raise VLException(Error.FaceNotFound.format(faceId), 404, isCriticalError=False)
        updateListSt = (
            update(models.List)
            .where(
                exists(
                    select([models.List.list_id]).where(
                        and_(
                            models.List.list_id == models.ListFace.list_id,
                            models.ListFace.face_id == models.Face.face_id,
                            faceCheckFilter,
                        )
                    )
                )
            )
            .values(last_update_time=self.currentDBTimestamp)
        )
        await connection.execute(updateListSt)
        deleteAttributeSt = delete(models.Attribute).where(
            exists(
                select([models.Attribute.face_id]).where(
                    and_(faceCheckFilter, models.Attribute.face_id == models.Face.face_id)
                )
            )
        )
        await connection.execute(deleteAttributeSt)
@exceptionWrap
async def getDescriptorsBatchByFaceIds(
    self,
    facesIds: set[str],
    accountId: Optional[str] = None,
    checkObjectExistence: bool = True,
    descriptorVersion: Optional[int] = None,
    receiveExternalId: Optional[bool] = None,
) -> dict[str, Union[list[bytes], int, list[str]]]:
    """
    Get descriptors batch by faces Ids.
    ! Splitting into batches is used due to enormous queries (25Mb for 100k faces) that fails Oracle DB. (LUNA-3734)
    Every `IN (...)` list below is limited to DB_LIST_LIMIT ids.
    Args:
        facesIds: faces ids the descriptors were attached to
        accountId: account id of the faces
        checkObjectExistence: check faces existence or not
        descriptorVersion: requested descriptor version, the default version if None
        receiveExternalId: receive or not external ids
    Returns:
        Dict:
            descriptors: list[bytes]
            uuids: list[str] - face ids aligned with `descriptors`
            descriptorVersion: int
            notFoundUuids: list[str] - filled only when `checkObjectExistence` is not set
            notExtractedUuids: list[str] - faces without a descriptor of the requested version
            externalIds: list[str] aligned with `uuids` ("" when missing), or None if not requested
    Raises:
        VLException(Error.FacesNotFound, 400) if `checkObjectExistence` is set and some face does not exist
    """
    if descriptorVersion is None:
        descriptorVersion = self.defaultDescriptorVersion
    facesIds = set(facesIds)

    def splitIntoBatches(ids: list[str]) -> list[list[str]]:
        """Split ids into consecutive chunks of at most DB_LIST_LIMIT items."""
        return [ids[idx : idx + DB_LIST_LIMIT] for idx in range(0, len(ids), DB_LIST_LIMIT)]

    selectFields = [models.Descriptor.descriptor, models.Face.face_id]
    if receiveExternalId:
        selectFields.append(models.Face.external_id)

    def buildDescriptorsQuery(faceIdsBatch: list[str]):
        """Select descriptors of the requested version for one batch of face ids."""
        return select(selectFields).where(
            and_(
                models.Descriptor.face_id.in_(faceIdsBatch),
                models.Descriptor.descriptor_version == descriptorVersion,
                # if need filtration by account
                models.Face.face_id == models.Descriptor.face_id,
                models.Face.account_id @ accountId,
            )
        )

    def buildNotExtractedQuery(faceIdsBatch: list[str]) -> Query:
        """Select faces of the batch that have no descriptor of the requested version."""
        return Query([models.Face.face_id]).filter(
            and_(
                models.Face.face_id.in_(faceIdsBatch),
                models.Face.account_id @ accountId,
                ~exists(
                    select([models.Face.face_id]).where(
                        and_(
                            models.Face.face_id == models.Descriptor.face_id,
                            models.Descriptor.descriptor_version == descriptorVersion,
                        )
                    )
                ),
            )
        )

    async with DBContext.adaptor.connection(self.logger) as connection:
        unitedResults = []
        for faceIdsBatch in splitIntoBatches(list(facesIds)):
            unitedResults.extend(await connection.fetchall(buildDescriptorsQuery(faceIdsBatch)))
        # 1k batching join
        if receiveExternalId:
            descriptors, replyFaceIds, externalIds = tuple(zip(*unitedResults)) or ([], [], [])
            externalIds = [externalId or "" for externalId in externalIds]
        else:
            descriptors, replyFaceIds = tuple(zip(*unitedResults)) or ([], [])
            externalIds = None
        replyFaceIds = list(replyFaceIds)
        notFoundUuids = []
        if len(replyFaceIds) != len(facesIds):
            # the ids without a reply may be many more than DB_LIST_LIMIT, so they are batched too
            facesWithoutReply = list(facesIds.difference(replyFaceIds))
            notExtractedFaceIds = []
            for faceIdsBatch in splitIntoBatches(facesWithoutReply):
                rows = await connection.fetchall(buildNotExtractedQuery(faceIdsBatch).statement)
                notExtractedFaceIds.extend(row[0] for row in rows)
            if checkObjectExistence:
                firstNotFoundFace = await self.getNonexistentFaceId(facesIds, accountId, connection)
                if firstNotFoundFace:
                    raise VLException(Error.FacesNotFound.format(firstNotFoundFace), 400, isCriticalError=False)
            else:
                notFoundUuids = list(facesIds.difference(notExtractedFaceIds).difference(replyFaceIds))
        else:
            notExtractedFaceIds = []
    return dict(
        descriptors=list(map(bytes, descriptors)),
        uuids=replyFaceIds,
        descriptorVersion=descriptorVersion,
        notFoundUuids=notFoundUuids,
        notExtractedUuids=notExtractedFaceIds,
        externalIds=externalIds,
    )
@exceptionWrap
async def getListFacesDescriptorsBatch(
    self,
    listId: str,
    linkKeyGte: int,
    limit: int,
    descriptorVersion: Optional[int] = None,
    parity: Optional[int] = None,
    receiveExternalId: Optional[bool] = None,
) -> dict[str, Union[list[bytes], int, list[int], list[str]]]:
    """
    Get a batch of descriptors of the list faces ordered by link key.

    Args:
        listId: list id the descriptors were attached to
        linkKeyGte: the lower including boundary
        limit: descriptors count to return
        descriptorVersion: descriptor version (the default descriptor version if None)
        parity: 0 for even or 1 for odd link keys to search for (`link_key % 2 == parity`); no filtration if None
        receiveExternalId: receive or not external ids
    Returns:
        Dict:
            descriptors: descriptor column values
            descriptorVersion: int
            linkKeys: link keys of the faces in the list
            uuids: face ids
            externalIds: list[str] ("" for faces without external id) or None if `receiveExternalId` is falsy
    """
    if descriptorVersion is None:
        descriptorVersion = self.defaultDescriptorVersion
    selectFields = [models.Descriptor.descriptor, models.ListFace.link_key, models.Descriptor.face_id]
    if receiveExternalId:
        selectFields.append(models.Face.external_id)
    query = Query(selectFields).filter(
        # faces table is joined only when external ids are requested
        models.Descriptor.face_id == models.Face.face_id if receiveExternalId else True,
        models.Descriptor.face_id == models.ListFace.face_id,
        models.Descriptor.descriptor_version == descriptorVersion,
        models.ListFace.list_id == listId,
        models.ListFace.link_key >= linkKeyGte,
        mod(models.ListFace.link_key, 2) == parity if parity is not None else True,
    )
    if parity is not None:
        query = query.order_by(
            models.ListFace.list_id, mod(models.ListFace.link_key, 2), models.ListFace.link_key.asc()
        )
    else:
        query = query.order_by(models.ListFace.link_key.asc())
    query = query.limit(limit)
    async with DBContext.adaptor.connection(self.logger) as connection:
        if self.dbConfig.type == "oracle":
            # indexFaceName / indexDescriptorName are module-level names (only read here)
            if indexFaceName is None and indexDescriptorName is None:
                self.logger.warning("Database indexes may not using")
            else:
                # hint only the indexes whose names are known; never render "INDEX(FACE None)"
                hints = ["INDEX(LIST_FACE LINK_KEY_FUNC_INDEX)"]
                if indexFaceName is not None:
                    hints.append(f"INDEX(FACE {indexFaceName})")
                if indexDescriptorName is not None:
                    hints.append(f"INDEX(DESCRIPTOR {indexDescriptorName})")
                query = query.with_hint(models.ListFace, " ".join(hints))
            queryCursor = await connection.fetchall(getCompiledQuery(query, self.dbConfig.type))
        else:
            # for postgres compiling wrong request in mod(models.ListFace.link_key, 2) == parity
            queryCursor = await connection.fetchall(query.statement)
    if receiveExternalId:
        descriptors, linkKeys, faceIds, externalIds = tuple(zip(*queryCursor)) or ([], [], [], [])
        externalIds = [externalId or "" for externalId in externalIds]
    else:
        descriptors, linkKeys, faceIds = tuple(zip(*queryCursor)) or ([], [], [])
        externalIds = None
    return dict(
        descriptors=descriptors,
        descriptorVersion=descriptorVersion,
        linkKeys=linkKeys,
        uuids=faceIds,
        externalIds=externalIds,
    )
def _prepareMissingDescriptorsFilters(
    self,
    missingVersion: int,
    accountId: Optional[str] = None,
    faceIds: Optional[list[str]] = None,
    faceIdGte: Optional[str] = None,
    faceIdLt: Optional[str] = None,
) -> BooleanClauseList:
    """
    Prepare filters for missing descriptors queries.

    An attribute matches when its face has a descriptor of the default descriptor version, but no descriptor of
    `missingVersion` version whose generation equals the attribute `descriptor_samples_generation`.

    Args:
        missingVersion: missing descriptor version
        accountId: account id of the attributes (faces table is joined only when it is set)
        faceIds: list of face ids
        faceIdGte: lower face id including boundary
        faceIdLt: upper face id excluding boundary
    Returns:
        prepared sa.and_() with filters
    """
    nestedDescriptor = aliased(models.Descriptor)  # table from the nested query need to be aliased
    return and_(
        # whether we need to reextract
        models.Attribute.face_id == models.Descriptor.face_id,
        models.Descriptor.descriptor_version == self.defaultDescriptorVersion,
        # filters
        models.Attribute.face_id == models.Face.face_id if accountId is not None else True,
        models.Face.account_id @ accountId,
        models.Attribute.face_id >= faceIdGte if faceIdGte is not None else True,
        models.Attribute.face_id < faceIdLt if faceIdLt is not None else True,
        models.Attribute.face_id.in_(faceIds) if faceIds is not None else True,
        # whether descriptor was not reextracted
        ~exists(
            select([nestedDescriptor.descriptor_version]).where(
                and_(
                    nestedDescriptor.face_id == models.Attribute.face_id,
                    nestedDescriptor.descriptor_version == missingVersion,
                    nestedDescriptor.descriptor_generation == models.Attribute.descriptor_samples_generation,
                )
            )
        ),
    )
@exceptionWrap
async def getMissingDescriptors(
    self,
    missingVersion: int,
    accountId: Optional[str] = None,
    faceIds: Optional[list[str]] = None,
    faceIdGte: Optional[str] = None,
    faceIdLt: Optional[str] = None,
    limit: int = 1000,
) -> list[dict[str, Union[str, list[str]]]]:
    """
    Get missing descriptors of 'missingVersion' version:
        get attributes not having descriptor (ordered by face id, at most `limit`)
        get their face descriptor samples
    Args:
        missingVersion: missing descriptor version
        accountId: account id of the attributes
        faceIds: list of face ids
        faceIdGte: lower face id including boundary
        faceIdLt: upper face id excluding boundary
        limit: maximum attributes amount to return
    Returns:
        data array in the format:
        {"face_id": "<face_id>", "samples": ["<sample_id>"]}
    """
    async with DBContext.adaptor.connection(self.logger) as connection:
        selectSt = (
            select([models.Attribute.face_id])
            .where(self._prepareMissingDescriptorsFilters(missingVersion, accountId, faceIds, faceIdGte, faceIdLt))
            .order_by(models.Attribute.face_id.asc())
            .limit(limit)
        )
        facesRes = await connection.fetchall(selectSt)
        # {face_id: {face_id: id, samples:[]}}
        result = {faceId: {"face_id": faceId, "samples": []} for (faceId,) in facesRes}
        if not result:
            # nothing matched: skip the samples query with an empty IN-list
            return []
        samplesSt = select([models.Sample.face_id, models.Sample.sample_id]).where(
            and_(models.Sample.face_id.in_(list(result)), models.Sample.type == SampleType.face_descriptor.value)
        )
        samplesRes = await connection.fetchall(samplesSt)
        for faceId, sampleId in samplesRes:
            result[faceId]["samples"].append(sampleId)
    return list(result.values())
@exceptionWrap
async def getMissingDescriptorsCount(
    self,
    missingVersion: int,
    accountId: Optional[str] = None,
    faceIds: Optional[list[str]] = None,
    faceIdGte: Optional[str] = None,
    faceIdLt: Optional[str] = None,
) -> int:
    """
    Count attributes whose descriptor of 'missingVersion' version is missing.

    Args:
        missingVersion: missing descriptor version
        accountId: account id of the attributes
        faceIds: list of face ids
        faceIdGte: lower face id including boundary
        faceIdLt: upper face id excluding boundary
    Returns:
        missing descriptors count
    """
    missingFilters = self._prepareMissingDescriptorsFilters(missingVersion, accountId, faceIds, faceIdGte, faceIdLt)
    countSt = select([func.count(models.Attribute.face_id)]).where(missingFilters)
    async with DBContext.adaptor.connection(self.logger) as connection:
        return await connection.scalar(countSt)
@exceptionWrap
async def getDescriptorsCount(self, accountId: Optional[str] = None) -> list[dict[str, int]]:
    """
    Count face descriptors grouped by descriptor version.

    Args:
        accountId: account id (no account filtration if None)
    Returns:
        list of dicts with "descriptor_version" and "descriptor_count" keys, one per version
    """
    perVersionSt = (
        select([models.Descriptor.descriptor_version, func.count(models.Descriptor.face_id)])
        .where(and_(models.Face.account_id @ accountId, models.Face.face_id == models.Descriptor.face_id))
        .group_by(models.Descriptor.descriptor_version)
    )
    async with DBContext.adaptor.connection(self.logger) as connection:
        rows = await connection.fetchall(perVersionSt)
    resultKeys = ("descriptor_version", "descriptor_count")
    return [dict(zip(resultKeys, (version, count))) for version, count in rows]
def _prepareMissingBasicAttrsFilters(
    self,
    accountId: Optional[str] = None,
    faceIds: Optional[list[str]] = None,
    faceIdGte: Optional[str] = None,
    faceIdLt: Optional[str] = None,
) -> BooleanClauseList:
    """
    Prepare filters for missing basic attributes queries.

    An attribute matches when its `ethnicity` column is NULL (used here as the "basic attributes were not
    extracted" marker).

    Args:
        accountId: account id of the attributes (faces table is joined only when it is set)
        faceIds: list of face ids
        faceIdGte: lower face id including boundary
        faceIdLt: upper face id excluding boundary
    Returns:
        prepared sa.and_() with filters
    """
    return and_(
        # whether we need to reextract
        models.Attribute.ethnicity.is_(None),
        # filters
        models.Attribute.face_id == models.Face.face_id if accountId is not None else True,
        models.Face.account_id @ accountId,
        models.Attribute.face_id >= faceIdGte if faceIdGte is not None else True,
        models.Attribute.face_id < faceIdLt if faceIdLt is not None else True,
        models.Attribute.face_id.in_(faceIds) if faceIds is not None else True,
    )
@exceptionWrap
async def getMissingBasicAttrs(
    self,
    accountId: Optional[str] = None,
    faceIds: Optional[list[str]] = None,
    faceIdGte: Optional[str] = None,
    faceIdLt: Optional[str] = None,
    limit: int = 1000,
) -> list[dict[str, Union[str, list[str]]]]:
    """
    Get missing basic attributes:
        get attributes without basic attributes (ordered by face id, at most `limit`)
        get their face descriptor samples
    Args:
        accountId: account id of the attributes
        faceIds: list of face ids
        faceIdGte: lower face id including boundary
        faceIdLt: upper face id excluding boundary
        limit: maximum attributes amount to return
    Returns:
        data array in the format:
        {"face_id": "<face_id>", "samples": ["<sample_id>"]}
    """
    async with DBContext.adaptor.connection(self.logger) as connection:
        selectSt = (
            select([models.Attribute.face_id])
            .where(self._prepareMissingBasicAttrsFilters(accountId, faceIds, faceIdGte, faceIdLt))
            .order_by(models.Attribute.face_id.asc())
            .limit(limit)
        )
        facesRes = await connection.fetchall(selectSt)
        # {face_id: {face_id: id, samples:[]}}
        result = {faceId: {"face_id": faceId, "samples": []} for (faceId,) in facesRes}
        if not result:
            # nothing matched: skip the samples query with an empty IN-list
            return []
        samplesSt = select([models.Sample.face_id, models.Sample.sample_id]).where(
            and_(models.Sample.face_id.in_(list(result)), models.Sample.type == SampleType.face_descriptor.value)
        )
        samplesRes = await connection.fetchall(samplesSt)
        for faceId, sampleId in samplesRes:
            result[faceId]["samples"].append(sampleId)
    return list(result.values())
@exceptionWrap
async def getMissingBasicAttrsCount(
    self,
    accountId: Optional[str] = None,
    faceIds: Optional[list[str]] = None,
    faceIdGte: Optional[str] = None,
    faceIdLt: Optional[str] = None,
) -> int:
    """
    Count attributes that have no basic attributes extracted yet.

    Args:
        accountId: account id of the attributes
        faceIds: list of face ids
        faceIdGte: lower face id including boundary
        faceIdLt: upper face id excluding boundary
    Returns:
        missing basic attributes count
    """
    missingFilters = self._prepareMissingBasicAttrsFilters(accountId, faceIds, faceIdGte, faceIdLt)
    countSt = select([func.count(models.Attribute.face_id)]).where(missingFilters)
    async with DBContext.adaptor.connection(self.logger) as connection:
        return await connection.scalar(countSt)
@exceptionWrap
async def getBasicAttrsCount(self, accountId: Optional[str] = None) -> int:
    """
    Count attributes that already have basic attributes (non-NULL `ethnicity`).

    Args:
        accountId: account id of the attributes (faces table is joined only when it is set)
    Returns:
        basic attributes count
    """
    conditions = [
        models.Attribute.ethnicity.isnot(None),
        models.Face.account_id @ accountId,
        models.Attribute.face_id == models.Face.face_id if accountId is not None else True,
    ]
    async with DBContext.adaptor.connection(self.logger) as connection:
        countSt = select([func.count(models.Attribute.face_id)]).where(and_(*conditions))
        return await connection.scalar(countSt)
@staticmethod
def _prepareDescriptorsFilters(
    descriptorVersion: int, faceIdGte: Optional[str] = None, faceIdLt: Optional[str] = None
) -> BooleanClauseList:
    """
    Build the filter selecting descriptors of one version within an optional face id range.

    Args:
        descriptorVersion: descriptor version
        faceIdGte: lower face id including boundary (no lower bound if None)
        faceIdLt: upper face id excluding boundary (no upper bound if None)
    Returns:
        prepared sa.and_() with filters
    """
    faceIdColumn = models.Descriptor.face_id
    versionClause = models.Descriptor.descriptor_version == descriptorVersion
    lowerBound = True if faceIdGte is None else faceIdColumn >= faceIdGte
    upperBound = True if faceIdLt is None else faceIdColumn < faceIdLt
    return and_(versionClause, lowerBound, upperBound)
@exceptionWrap
async def deleteDescriptors(
    self, descriptorVersion: int, faceIdGte: Optional[str] = None, faceIdLt: Optional[str] = None, limit: int = 1000
) -> list[str]:
    """
    Delete descriptors by version:
        delete descriptors of preassigned version (at most `limit` of them, in face id order)
    Args:
        descriptorVersion: descriptor version
        faceIdGte: lower face id including boundary
        faceIdLt: upper face id excluding boundary
        limit: maximum descriptors amount to delete
    Returns:
        list of face ids whose descriptors of preassigned version were deleted
    """
    async with DBContext.adaptor.connection(self.logger) as connection:
        selectSt = (
            select([models.Descriptor.face_id])
            .where(self._prepareDescriptorsFilters(descriptorVersion, faceIdGte, faceIdLt))
            .order_by(models.Descriptor.face_id.asc())
            .limit(limit)
        )
        faces = [row["face_id"] for row in await connection.fetchall(selectSt)]
        if not faces:
            # nothing to delete: avoid a DELETE with an empty IN-list
            return faces
        deleteSt = delete(models.Descriptor).where(
            and_(models.Descriptor.face_id.in_(faces), models.Descriptor.descriptor_version == descriptorVersion)
        )
        await connection.execute(deleteSt)
    return faces
async def blockFaces(
    self, connection, filters: list[ColumnClause], returnBlockedFaces: bool = False
) -> tuple[int, Union[tuple[str, ...], None]]:
    """
    Lock face rows matching the filters using "select ... for update".

    Args:
        connection: open connection the lock is taken in
        filters: filters applied (joined with AND) to the faces table
        returnBlockedFaces: whether to return the locked face ids
    Returns:
        number of blocked faces and a tuple of blocked face ids if `returnBlockedFaces` is set, else None
    """
    lockSt = select([models.Face.face_id]).where(and_(*filters)).with_for_update()
    if self.dbConfig.type == "oracle" or returnBlockedFaces:
        # rows are fetched directly; for oracle the count-over-subquery form below is not used
        # (presumably because FOR UPDATE is not allowed inside a subquery there — confirm)
        lockedRows = await connection.fetchall(lockSt)
        blockedFaces = tuple(row[0] for row in lockedRows)
        blockedFaceCount = len(blockedFaces)
    else:
        blockedFaces = None
        blockedFaceCount = await connection.scalar(select([func.count()], from_obj=aliased(lockSt)))
    self.logger.debug(f"lock {blockedFaceCount} faces")
    return blockedFaceCount, (blockedFaces if returnBlockedFaces else None)
async def updateListLastUpdateTime(self, listId: str):
    """
    Set the list last update time to the current database timestamp.

    Args:
        listId: list id
    """
    async with DBContext.adaptor.connection(self.logger) as connection:
        touchListSt = update(models.List).where(models.List.list_id == listId)
        touchListSt = touchListSt.values(last_update_time=self.currentDBTimestamp)
        await connection.execute(touchListSt)
async def _moveFacesLinksToLog(
    self, faceIds: Union[list[str], tuple[str, ...]], connection: AbstractDBConnection
) -> list[str]:
    """
    Copy face-to-list links of the given faces into the unlink attributes log.

    Args:
        faceIds: face ids
        connection: open connection
    Returns:
        ids of the lists the faces are linked to (one entry per link)
    """
    linksSt = select([models.ListFace.list_id, models.ListFace.face_id, models.ListFace.link_key]).where(
        models.ListFace.face_id.in_(faceIds)
    )
    logColumns = [
        models.UnlinkAttributesLog.list_id,
        models.UnlinkAttributesLog.face_id,
        models.UnlinkAttributesLog.link_key,
    ]
    copyLinksSt = insert(models.UnlinkAttributesLog).from_select(logColumns, linksSt)
    if self.dbConfig.type != "oracle":
        # list ids are read back from the inserted rows via RETURNING
        linkedRows = await connection.fetchall(copyLinksSt.returning(models.UnlinkAttributesLog.list_id))
    else:
        # oracle: insert first, then re-read the list ids from the links only when something was inserted
        insertedCount = await connection.execute(copyLinksSt)
        if insertedCount:
            listIdsSt = select([models.ListFace.list_id]).where(models.ListFace.face_id.in_(faceIds))
            linkedRows = await connection.fetchall(listIdsSt)
        else:
            linkedRows = []
    return [row[0] for row in linkedRows]
async def getListDeletions(
    self,
    deletionTimeLt: Optional[datetime] = None,
    deletionTimeGte: Optional[datetime] = None,
    page: Optional[int] = 1,
    pageSize: Optional[int] = 100,
) -> list[dict[str, Any]]:
    """
    Get lists deletion logs
    Args:
        deletionTimeLt: upper (excluding) bound of list deletion time
        deletionTimeGte: lower (including) bound of list deletion time
        page: page (1-based)
        pageSize: page size
    Warnings:
        trigger `trg_lists_deletion_log` inserts a data for table `ListsDeletionLog`
    Returns:
        list of deletions (list_id, account_id, create_time, deletion_id, deletion_time; times converted to
        strings) in the reverse order of deletion of lists
    """
    query = (
        select(
            [
                models.ListsDeletionLog.list_id,
                models.ListsDeletionLog.account_id,
                models.ListsDeletionLog.create_time,
                models.ListsDeletionLog.deletion_id,
                models.ListsDeletionLog.deletion_time,
            ]
        )
        .where(
            and_(
                models.ListsDeletionLog.deletion_time >= deletionTimeGte if deletionTimeGte is not None else True,
                models.ListsDeletionLog.deletion_time < deletionTimeLt if deletionTimeLt is not None else True,
            )
        )
        .order_by(models.ListsDeletionLog.deletion_time.desc())
        # pagination is applied exactly once
        .offset((page - 1) * pageSize)
        .limit(pageSize)
    )
    async with DBContext.adaptor.connection(self.logger) as connection:
        records = await connection.fetchall(query)
    isUTC = self.storageTime == "UTC"
    listsDeletions = []
    for row in records:
        removedList = dict(row)
        removedList["create_time"] = convertTimeToString(removedList["create_time"], isUTC)
        removedList["deletion_time"] = convertTimeToString(removedList["deletion_time"], isUTC)
        listsDeletions.append(removedList)
    return listsDeletions
async def cleanListsDeletionLog(self, deletionTimeLt: datetime) -> int:
    """
    Remove lists deletion log records older than the given time.

    Args:
        deletionTimeLt: upper (excluding) bound of list deletion time
    Returns:
        count of removed rows
    """
    isOutdated = models.ListsDeletionLog.deletion_time < deletionTimeLt
    async with DBContext.adaptor.connection(self.logger) as connection:
        return await connection.execute(delete(models.ListsDeletionLog).where(isOutdated))