# Source code for luna_handlers.sdk.sdk_loop.face_extractor
"""
Module contains functionality for batch estimation of face descriptors and basic attributes.
"""
from itertools import chain
from typing import List, Tuple, Optional
from lunavl.sdk.estimators.face_estimators.basic_attributes import BasicAttributesEstimator
from lunavl.sdk.estimators.face_estimators.facewarper import FaceWarpedImage
from .base_extractor import BaseExtractorSubTask, BaseExtractorState, BaseExtractor
from .crutches_on_wheels.errors.errors import Error
from .sdk_task import SDKTask, FaceAttributes, tasksTimeMonitoring, FilteredEstimation
from .settings import FaceExtractorSettings
from .utils.recipes import grouper
class ExtractorState(BaseExtractorState):
    """
    Process-local extractor state. Final class.

    The state holds, for the worker process:
        - the worker logger
        - the basic attributes estimator
        - the FaceEngine instance
        - a mapping from descriptor version to extractor instance
    """

    _basicAttributesEstimator: BasicAttributesEstimator

    @property
    def basicAttributesEstimator(self) -> BasicAttributesEstimator:
        """
        Basic attributes estimator stored on the class by ``initialize``.

        Returns:
            the basic attributes estimator
        """
        return ExtractorState._basicAttributesEstimator

    @classmethod
    def initialize(cls, workerName: str, settings: FaceExtractorSettings) -> bool:
        """
        Initialize the state (singleton): FaceEngine, extractors and the basic attributes estimator.

        Args:
            workerName: worker name
            settings: settings for worker
        Returns:
            True on the first call in the process, False on any later call
        """
        isFirstCall = bool(super().initialize(workerName, settings))
        if isFirstCall:
            # relies on the base initialize having populated cls._faceEngine
            cls._basicAttributesEstimator = cls._faceEngine.createBasicAttributesEstimator()
        return isFirstCall
class FaceExtractorSubTask(BaseExtractorSubTask):
    """
    Sub task of the face extractor.

    Attributes:
        estimateBasicAttributes (bool): whether basic attributes must be estimated for this sub task
    """

    __slots__ = ("estimateBasicAttributes",)

    def __init__(
        self,
        attributes: FaceAttributes,
        warps: List[FaceWarpedImage],
        taskId: int,
        estimationId: Optional[str] = None,
        estimateBasicAttributes: bool = False,
        descriptorVersion: int = 0,
    ):
        """
        Args:
            attributes: container the extraction results are written to
            warps: face warps to process
            taskId: id of the task the sub task belongs to
            estimationId: id of the estimation the sub task was created for (None for aggregated sub tasks)
            estimateBasicAttributes: whether basic attributes must be estimated
            descriptorVersion: descriptor version to extract
        """
        # everything except the basic attributes flag is handled by the base sub task
        super().__init__(
            taskId=taskId,
            estimationId=estimationId,
            attributes=attributes,
            warps=warps,
            descriptorVersion=descriptorVersion,
        )
        self.estimateBasicAttributes = estimateBasicAttributes
class FaceExtractor(BaseExtractor[ExtractorState]):
    """
    Face extractor.

    Adds basic attributes estimation (per single warp and aggregated over several warps) on top of the
    descriptor extraction inherited from ``BaseExtractor``.
    """

    # state class
    _state = ExtractorState

    def batchExtractBasicAttributes(self, subTasks: List[FaceExtractorSubTask]) -> None:
        """
        Batch extract basic attributes for sub tasks with one warp.

        Only sub tasks with ``estimateBasicAttributes`` set are processed. Their first warps are sent to the
        estimator in chunks of ``settings.optimalBatchSize`` to reduce FSDK memory consumption. Each result is
        stored in ``subTask.attributes.basicAttributes``.

        Args:
            subTasks: sub tasks
        """
        subTasksForEstimation = [subTask for subTask in subTasks if subTask.estimateBasicAttributes]
        if not subTasksForEstimation:
            return
        # NOTE(review): if `grouper` pads the last chunk (like the itertools recipe with fillvalue=None),
        # `subTask.warps[0]` would fail on the padding - confirm the project helper does not pad.
        for subTaskChunk in grouper(subTasksForEstimation, self.state.settings.optimalBatchSize):
            warps = [subTask.warps[0] for subTask in subTaskChunk]
            basicAttributes, _ = self.state.basicAttributesEstimator.estimateBasicAttributesBatch(
                warps, estimateAge=True, estimateGender=True, estimateEthnicity=True
            )
            # results are paired with sub tasks by position, i.e. the estimator is assumed to keep input order
            for basicAttribute, subTask in zip(basicAttributes, subTaskChunk):
                subTask.attributes.basicAttributes = basicAttribute.asDict()

    def extractAggregateBasicAttribute(self, subTask: FaceExtractorSubTask) -> None:
        """
        Estimate basic attributes aggregated over all warps of one sub task.

        The aggregated result is stored in ``subTask.attributes.basicAttributes``; per-warp results are discarded.

        Args:
            subTask: sub task with several warps
        """
        _, aggregateAttributes = self.state.basicAttributesEstimator.estimateBasicAttributesBatch(
            subTask.warps, estimateAge=True, estimateGender=True, estimateEthnicity=True, aggregate=True
        )
        subTask.attributes.basicAttributes = aggregateAttributes.asDict()

    def batchExtractBasicAttributesWithAggregation(self, subTasks: List[FaceExtractorSubTask]) -> None:
        """
        Extract aggregated basic attributes for sub tasks with several warps.

        Sub tasks without ``estimateBasicAttributes`` are skipped; the rest are processed one sub task at a time.

        Args:
            subTasks: sub tasks
        """
        for subTask in subTasks:
            if subTask.estimateBasicAttributes:
                self.extractAggregateBasicAttribute(subTask)
def getSubTasksForTasks(tasks: List[SDKTask]) -> Tuple[List[BaseExtractorSubTask], List[BaseExtractorSubTask]]:
    """
    Split tasks into extractor sub tasks.

    A task that does not request aggregation, or has exactly one face warp, produces one sub task per estimation
    of each of its images without an error. Any other task produces a single sub task covering all its face warps.
    A warp rejected by ``FaceWarpedImage`` (``ValueError``) produces no sub task. Instead, the ``BadWarpImage``
    error is stored on the warp in the per-estimation case, or on the task in the aggregated case.

    Args:
        tasks: tasks
    Returns:
        tuple (sub tasks without aggregation, sub tasks with aggregation)
    """
    plainSubTasks = []
    aggregatedSubTasks = []
    for task in tasks:
        targets = task.toEstimation.faceEstimationTargets

        if task.aggregateAttributes and len(task.faceWarps) != 1:
            # a single sub task over every face warp of the task
            try:
                aggregatedSubTask = FaceExtractorSubTask(
                    attributes=FaceAttributes(warps=task.faceWarps),
                    warps=[FaceWarpedImage(faceWarp.body) for faceWarp in task.faceWarps],
                    estimateBasicAttributes=bool(targets.estimateBasicAttributes),
                    descriptorVersion=targets.estimateFaceDescriptor,
                    taskId=task.taskId,
                )
            except ValueError as exc:
                task.error = Error.BadWarpImage.format(exc)
            else:
                aggregatedSubTasks.append(aggregatedSubTask)
            continue

        # one sub task per estimation of every image without an error
        estimations = chain(*(image.estimations for image in task.images if not image.error))
        for estimation in estimations:
            faceWarp = estimation.face.warp
            try:
                plainSubTasks.append(
                    FaceExtractorSubTask(
                        attributes=FaceAttributes(warps=[faceWarp]),
                        warps=[FaceWarpedImage(faceWarp.body)],
                        estimateBasicAttributes=bool(targets.estimateBasicAttributes),
                        descriptorVersion=targets.estimateFaceDescriptor,
                        taskId=task.taskId,
                        estimationId=estimation.id,
                    )
                )
            except ValueError as exc:
                faceWarp.error = Error.BadWarpImage.format(exc)
    return plainSubTasks, aggregatedSubTasks
def updateTaskWithFilteredResults(task: SDKTask, filtrationRes: dict, attribute: FaceAttributes):
    """
    Update task with filtered results.

    Every estimation of the task images whose warp ``sampleId`` equals the ``sampleId`` of one of the attribute
    warps goes through these steps:
        - its warp is marked as filtered;
        - the filtration result and the attribute are stored on its face;
        - a ``FilteredEstimation`` is appended to ``task.filteredEstimations``;
        - it is removed from ``image.estimations``.

    Args:
        task: task
        filtrationRes: filtration result
        attribute: filtered attribute
    """
    filteredSampleIds = [warp.sampleId for warp in attribute.warps]
    for image in task.images:
        matchedIndexes = []
        for index, estimation in enumerate(image.estimations):
            if estimation.face.warp.sampleId not in filteredSampleIds:
                continue
            estimation.face.warp.isFiltered = True
            estimation.face.filter = filtrationRes
            estimation.face.extractedAttributes = attribute
            task.filteredEstimations.append(FilteredEstimation(filename=image.filename, estimation=estimation.face))
            matchedIndexes.append(index)
        # Remove only after the scan: popping while enumerating skipped the element following each removed one,
        # so consecutive matches (e.g. all estimations of an aggregated task) were left unfiltered.
        # Popping in reverse keeps the remaining indexes valid.
        for index in reversed(matchedIndexes):
            image.estimations.pop(index)
def checkAttributesFiltered(task: SDKTask, attribute: FaceAttributes) -> bool:
    """
    Check whether an extracted attribute is filtered by garbage score, and update the task if it is.

    The check runs only when the attribute has a descriptor and the task filters have a truthy garbage score
    threshold. A filtered attribute gets ``filtered = True``, and the task is updated with the filtration result.

    Args:
        task: task
        attribute: extracted attribute
    Returns:
        True if attribute is filtered, otherwise False
    """
    if not (attribute.descriptor and task.filters.garbageScoreThreshold):
        return False
    filtrationRes = task.filters.checkFilterByGS(attribute.descriptor["score"])
    if not filtrationRes["is_filtered"]:
        return False
    attribute.filtered = True
    updateTaskWithFilteredResults(task=task, filtrationRes=filtrationRes, attribute=attribute)
    return True
def extract(tasks: List[SDKTask]):
    """
    Extract face attributes (descriptors and basic attributes) for tasks.

    Steps:
        1. Split the tasks into sub tasks with and without aggregation.
        2. Run the batch extractions, recording their timings on the tasks.
        3. Copy each sub task's result (or error) back to its task.

    Args:
        tasks: tasks
    Returns:
        processed tasks
    """
    state = ExtractorState()
    state.logger.info(f"got {len(tasks)} tasks")
    subTasksWithoutAggregation, subTasksWithAggregation = getSubTasksForTasks(tasks)
    extractor = FaceExtractor()
    if subTasksWithoutAggregation:
        with tasksTimeMonitoring(fieldName="faceDescriptorExtractTime", tasks=tasks):
            extractor.batchExtractDescriptors(subTasksWithoutAggregation)
        with tasksTimeMonitoring(fieldName="basicAttributesExtractTime", tasks=tasks):
            extractor.batchExtractBasicAttributes(subTasksWithoutAggregation)
    if subTasksWithAggregation:
        with tasksTimeMonitoring(fieldName="faceDescriptorExtractTime", tasks=tasks):
            extractor.batchExtractDescriptorsWithAggregation(subTasksWithAggregation)
        with tasksTimeMonitoring(fieldName="basicAttributesExtractTime", tasks=tasks):
            extractor.batchExtractBasicAttributesWithAggregation(subTasksWithAggregation)

    # Group sub tasks by task id once, instead of scanning every sub task for every task.
    # Per-task order is unchanged: sub tasks without aggregation first, then aggregated ones.
    subTasksByTaskId = {}
    for subTask in subTasksWithoutAggregation + subTasksWithAggregation:
        subTasksByTaskId.setdefault(subTask.taskId, []).append(subTask)

    for task in tasks:
        for subTask in subTasksByTaskId.get(task.taskId, ()):
            if subTask.error:
                task.error = subTask.error
                continue
            if checkAttributesFiltered(task=task, attribute=subTask.attributes):
                continue
            if task.aggregateAttributes:
                task.aggregatedEstimations.extraction.face = subTask.attributes
            else:
                task.updateEstimationsWithExtractionResult(subTask)
    state.logger.info(f"performed {len(tasks)} tasks")
    return tasks
def initWorker(settings: FaceExtractorSettings):
    """
    Initialize the extractor worker.

    Initializes the process-local extractor state under the name "luna-handlers-f-extractor". This sets up the
    logger, FSDK (FaceEngine), the basic attributes estimator and the descriptor extractors. It then logs that
    the worker is initialized.

    Args:
        settings: face extractor settings
    """
    ExtractorState.initialize("luna-handlers-f-extractor", settings=settings)
    ExtractorState().logger.info("extractor worker is initialized")