Source code for luna_handlers.sdk.sdk_loop.face_extractor

"""
Module containing functionality for batch estimation of face descriptors and basic attributes
"""
from itertools import chain
from typing import List, Tuple, Optional

from lunavl.sdk.estimators.face_estimators.basic_attributes import BasicAttributesEstimator
from lunavl.sdk.estimators.face_estimators.facewarper import FaceWarpedImage

from .base_extractor import BaseExtractorSubTask, BaseExtractorState, BaseExtractor
from .crutches_on_wheels.errors.errors import Error
from .sdk_task import SDKTask, FaceAttributes, tasksTimeMonitoring, FilteredEstimation
from .settings import FaceExtractorSettings
from .utils.recipes import grouper


class ExtractorState(BaseExtractorState):
    """
    Process-local extractor state. Final class.

    Besides the base state it holds:
        - the worker logger
        - the FaceEngine instance
        - the mapping from a descriptor version to an extractor instance
        - a basic attributes estimator shared by the whole process
    """

    # created once per process in `initialize`
    _basicAttributesEstimator: BasicAttributesEstimator

    @property
    def basicAttributesEstimator(self) -> BasicAttributesEstimator:
        """
        Basic attributes estimator of the current process.

        Returns:
            the estimator stored on ExtractorState by `initialize`
        """
        return ExtractorState._basicAttributesEstimator

    @classmethod
    def initialize(cls, workerName: str, settings: FaceExtractorSettings) -> bool:
        """
        Initialize the state (singleton): FaceEngine, extractors and basic attributes estimator.

        Args:
            workerName: worker name
            settings: settings for worker

        Returns:
            True on the first call in the process, otherwise False
        """
        if super().initialize(workerName, settings):
            # the estimator is built from `cls._faceEngine`, which the base initialization is
            # presumably responsible for creating
            cls._basicAttributesEstimator = cls._faceEngine.createBasicAttributesEstimator()
            return True
        return False
class FaceExtractorSubTask(BaseExtractorSubTask):
    """
    Sub task of the face extractor.

    Attributes:
        estimateBasicAttributes (bool): whether basic attributes have to be estimated for this sub task
    """

    __slots__ = ("estimateBasicAttributes",)

    def __init__(
        self,
        attributes: FaceAttributes,
        warps: List[FaceWarpedImage],
        taskId: int,
        estimationId: Optional[str] = None,
        estimateBasicAttributes: bool = False,
        descriptorVersion: int = 0,
    ):
        # everything except the basic attributes flag is handled by the base sub task
        super().__init__(
            taskId=taskId,
            estimationId=estimationId,
            attributes=attributes,
            warps=warps,
            descriptorVersion=descriptorVersion,
        )
        self.estimateBasicAttributes: bool = estimateBasicAttributes
class FaceExtractor(BaseExtractor[ExtractorState]):
    """
    Face extractor.

    Descriptor extraction comes from the base extractor; this class adds basic attributes
    (age, gender, ethnicity) estimation for face sub tasks.
    """

    # state class
    _state = ExtractorState

    def batchExtractBasicAttributes(self, subTasks: List[FaceExtractorSubTask]):
        """
        Batch extract basic attributes for sub tasks with one warp.

        Only sub tasks with `estimateBasicAttributes` set are processed. Batches are chunked
        (by `settings.optimalBatchSize`) to reduce FSDK memory consumption.

        Args:
            subTasks: sub tasks; only the first warp of every sub task is used
        """
        subTasksForEstimation = [subTask for subTask in subTasks if subTask.estimateBasicAttributes]
        if not subTasksForEstimation:
            return
        for subTaskChunk in grouper(subTasksForEstimation, self.state.settings.optimalBatchSize):
            warps = [subTask.warps[0] for subTask in subTaskChunk]
            # per-warp results are needed here, the aggregated result (second item) is ignored
            basicAttributes, _ = self.state.basicAttributesEstimator.estimateBasicAttributesBatch(
                warps, estimateAge=True, estimateGender=True, estimateEthnicity=True
            )
            for basicAttribute, subTask in zip(basicAttributes, subTaskChunk):
                subTask.attributes.basicAttributes = basicAttribute.asDict()

    def extractAggregateBasicAttribute(self, subTask: FaceExtractorSubTask):
        """
        Estimate basic attributes aggregated over all warps of a sub task.

        Args:
            subTask: sub task with several warps; its `attributes.basicAttributes` is overwritten
        """
        # only the aggregated result (second item) is needed, per-warp results are discarded
        _, aggregateAttributes = self.state.basicAttributesEstimator.estimateBasicAttributesBatch(
            subTask.warps, estimateAge=True, estimateGender=True, estimateEthnicity=True, aggregate=True
        )
        subTask.attributes.basicAttributes = aggregateAttributes.asDict()

    def batchExtractBasicAttributesWithAggregation(self, subTasks: List[FaceExtractorSubTask]):
        """
        Extract aggregated basic attributes for sub tasks with several warps.

        Sub tasks without `estimateBasicAttributes` set are skipped.

        Args:
            subTasks: sub tasks
        """
        for subTask in subTasks:
            if subTask.estimateBasicAttributes:
                self.extractAggregateBasicAttribute(subTask)
def getSubTasksForTasks(tasks: List[SDKTask]) -> Tuple[List[BaseExtractorSubTask], List[BaseExtractorSubTask]]:
    """
    Split tasks into extractor sub tasks.

    A task without attribute aggregation (or with exactly one face warp) yields one sub task per
    estimation of each image without an error. Any other task yields a single sub task holding all
    of its face warps.

    An invalid warp does not abort processing. For a per-estimation sub task the error is stored on
    the warp; for an aggregated sub task it is stored on the task.

    Args:
        tasks: tasks

    Returns:
        tuple (sub tasks without aggregation, sub tasks with aggregation)
    """
    subTasksWithoutAggregation = []
    subTasksWithAggregation = []
    for task in tasks:
        targets = task.toEstimation.faceEstimationTargets
        # per-task invariant, computed once instead of once per estimation
        estimateBasicAttributes = bool(targets.estimateBasicAttributes)
        if not task.aggregateAttributes or len(task.faceWarps) == 1:
            estimations = chain.from_iterable(image.estimations for image in task.images if not image.error)
            for estimation in estimations:
                warp = estimation.face.warp
                try:
                    subTask = FaceExtractorSubTask(
                        attributes=FaceAttributes(warps=[warp]),
                        warps=[FaceWarpedImage(warp.body)],
                        estimateBasicAttributes=estimateBasicAttributes,
                        descriptorVersion=targets.estimateFaceDescriptor,
                        taskId=task.taskId,
                        estimationId=estimation.id,
                    )
                except ValueError as e:
                    # presumably raised for an invalid warp body; only this warp is marked as bad
                    warp.error = Error.BadWarpImage.format(e)
                else:
                    subTasksWithoutAggregation.append(subTask)
        else:
            try:
                subTask = FaceExtractorSubTask(
                    attributes=FaceAttributes(warps=task.faceWarps),
                    warps=[FaceWarpedImage(warp.body) for warp in task.faceWarps],
                    estimateBasicAttributes=estimateBasicAttributes,
                    descriptorVersion=targets.estimateFaceDescriptor,
                    taskId=task.taskId,
                )
            except ValueError as e:
                # a single bad warp fails the whole aggregated task
                task.error = Error.BadWarpImage.format(e)
            else:
                subTasksWithAggregation.append(subTask)
    return subTasksWithoutAggregation, subTasksWithAggregation
def updateTaskWithFilteredResults(task: SDKTask, filtrationRes: dict, attribute: FaceAttributes):
    """
    Move estimations belonging to a filtered attribute into the task's filtered estimations.

    Every estimation whose warp `sampleId` matches one of `attribute.warps` goes through these steps:
        - it is marked as filtered;
        - the filtration result and the extracted attribute are attached to it;
        - it is appended to `task.filteredEstimations`;
        - it is removed from its image's estimations.

    Args:
        task: task
        filtrationRes: filtration result
        attribute: filtered attribute
    """
    # hoisted: the attribute's warps do not change while the estimations are scanned
    filteredSampleIds = [warp.sampleId for warp in attribute.warps]
    for image in task.images:
        keptEstimations = []
        for estimation in image.estimations:
            face = estimation.face
            if face.warp.sampleId not in filteredSampleIds:
                keptEstimations.append(estimation)
                continue
            face.warp.isFiltered = True
            face.filter = filtrationRes
            face.extractedAttributes = attribute
            task.filteredEstimations.append(FilteredEstimation(filename=image.filename, estimation=face))
        # Rebuild the list instead of popping while iterating: popping skipped the element that
        # followed each removed one. Updating in place keeps other references to the list valid.
        image.estimations[:] = keptEstimations
def checkAttributesFiltered(task: SDKTask, attribute: FaceAttributes) -> bool:
    """
    Check whether an extracted attribute is rejected by the garbage score filter.

    A rejected attribute is flagged as filtered and the task is updated with the filtration result.

    Args:
        task: task
        attribute: extracted attribute

    Returns:
        True if the attribute is filtered, otherwise False
    """
    # nothing to check without a descriptor or without a configured threshold
    if not (attribute.descriptor and task.filters.garbageScoreThreshold):
        return False
    filtrationRes = task.filters.checkFilterByGS(attribute.descriptor["score"])
    if not filtrationRes["is_filtered"]:
        return False
    attribute.filtered = True
    updateTaskWithFilteredResults(task=task, filtrationRes=filtrationRes, attribute=attribute)
    return True
def extract(tasks: List[SDKTask]):
    """
    Extract face attributes (descriptors and basic attributes) for tasks.

    Args:
        tasks: tasks

    Returns:
        processed tasks
    """
    state = ExtractorState()
    state.logger.info(f"got {len(tasks)} tasks")
    subTasksWithoutAggregation, subTasksWithAggregation = getSubTasksForTasks(tasks)
    extractor = FaceExtractor()
    if subTasksWithoutAggregation:
        with tasksTimeMonitoring(fieldName="faceDescriptorExtractTime", tasks=tasks):
            extractor.batchExtractDescriptors(subTasksWithoutAggregation)
        with tasksTimeMonitoring(fieldName="basicAttributesExtractTime", tasks=tasks):
            extractor.batchExtractBasicAttributes(subTasksWithoutAggregation)
    if subTasksWithAggregation:
        with tasksTimeMonitoring(fieldName="faceDescriptorExtractTime", tasks=tasks):
            extractor.batchExtractDescriptorsWithAggregation(subTasksWithAggregation)
        with tasksTimeMonitoring(fieldName="basicAttributesExtractTime", tasks=tasks):
            extractor.batchExtractBasicAttributesWithAggregation(subTasksWithAggregation)

    # Group sub tasks by task id once. Previously every task rescanned a freshly concatenated
    # list of all sub tasks. Per-task order is unchanged.
    subTasksByTaskId = {}
    for subTask in chain(subTasksWithoutAggregation, subTasksWithAggregation):
        subTasksByTaskId.setdefault(subTask.taskId, []).append(subTask)

    for task in tasks:
        for subTask in subTasksByTaskId.get(task.taskId, ()):
            if subTask.error:
                task.error = subTask.error
                continue
            if checkAttributesFiltered(task=task, attribute=subTask.attributes):
                continue
            # single-warp tasks with aggregation enabled are extracted per estimation
            # (see getSubTasksForTasks) but still report into the aggregated estimations
            if task.aggregateAttributes:
                task.aggregatedEstimations.extraction.face = subTask.attributes
            else:
                task.updateEstimationsWithExtractionResult(subTask)
    state.logger.info(f"performed {len(tasks)} tasks")
    return tasks
def initWorker(settings: FaceExtractorSettings):
    """
    Set up an extractor worker.

    Initializes the process-local extractor state (logger, FSDK, basic attributes estimator,
    descriptor extractors) and logs that the worker is ready.

    Args:
        settings: extractor worker settings
    """
    workerName = "luna-handlers-f-extractor"
    ExtractorState.initialize(workerName, settings=settings)
    state = ExtractorState()
    state.logger.info("extractor worker is initialized")