"""
Module with the VisualizeMasksGT hook, which joins an image with a layer dyed according to how the predictions agree with the ground-truth mask.
"""
from typing import Tuple
import cv2
import numpy as np
import emloop as el
from .save_images import SaveImages
class VisualizeMasksGT(SaveImages):
    """
    Join and save images with a dyed layer comparing the ground-truth segmentation mask with the predictions.

    .. caution::

        Saving all the images may require considerable amount of time. Use this hook only for debugging or
        limit the number of saved images with ``image_count`` or ``batch_count`` parameters.

    .. code-block:: yaml
        :caption: Join image stored in variable ``imgs`` with dyed layer obtained from mask stored in variable ``mask`` and predictions stored in variable ``results``.

        hooks:
          - blocks.hooks.VisualizeMasksGT:
              variable: imgs
              mask_variable: mask
              predictions_variable: results
    """

    def __init__(self, output_dir: str, mask_variable: str='mask', predictions_variable: str='predictions',
                 mask_opacity: float=0.8, mask_factor: float=255, predictions_factor: float=255,
                 color_true: Tuple[int, int, int]=(0, 255, 0), color_false_pos: Tuple[int, int, int]=(0, 0, 255),
                 color_false_neg: Tuple[int, int, int]=(255, 0, 0), **kwargs):
        """
        Create new VisualizeMasksGT hook.

        :param output_dir: output directory, passed to :py:class:`SaveImages`
        :param mask_variable: name of the variable representing the mask of image
        :param predictions_variable: name of the variable representing the predictions (model output)
        :param mask_opacity: blending weight of the *image*; the dyed comparison layer is blended with
                             weight ``1 - mask_opacity`` (so higher values make the dyed layer fainter)
        :param mask_factor: the multiplier of the mask
        :param predictions_factor: the multiplier of the (rounded) predictions variable
        :param color_true: color for the true positive and true negative pixels (BGR colorspace)
        :param color_false_pos: color for the false positive pixels (BGR colorspace)
        :param color_false_neg: color for the false negative pixels (BGR colorspace)
        :param kwargs: additional arguments passed to :py:class:`SaveImages`
        """
        super().__init__(output_dir, **kwargs)
        self._mask_variable = mask_variable
        self._predictions_variable = predictions_variable
        self._mask_opacity = mask_opacity
        self._mask_factor = mask_factor
        self._predictions_factor = predictions_factor
        self._color_true = color_true
        self._color_false_pos = color_false_pos
        self._color_false_neg = color_false_neg

    @property
    def image_suffix(self) -> str:
        """Image suffix in the form: ``<variable>_<mask_variable>_<predictions_variable>``."""
        return "{}_{}_{}".format(super().image_suffix, self._mask_variable, self._predictions_variable)

    def process_img(self, img_i: int, batch_data: el.Batch) -> np.ndarray:
        """
        Join the image with the layer where pixels for true positive and true negative are dyed
        to ``color_true``, pixels for false positive are dyed to ``color_false_pos`` and pixels for
        false negative are dyed to ``color_false_neg``.

        The mask (scaled by ``mask_factor``) and the rounded predictions (scaled by ``predictions_factor``)
        are compared element-wise: equal values count as correct, ``mask > predictions`` as a false
        negative and ``mask < predictions`` as a false positive. A trailing singleton channel axis of the
        mask and predictions (e.g. shape ``(H, W, 1)``) is dropped before the comparison.

        :param img_i: index of the image within the batch
        :param batch_data: batch data holding the image, mask and predictions variables
        :return: the image (converted with :py:meth:`to_color`) blended with the dyed layer
        :raise AssertionError: if the mask and predictions shapes differ, or if their first two
                               dimensions differ from those of the image
        """
        image = (batch_data[self._variable][img_i] * self._factor).astype('uint8')
        mask = (batch_data[self._mask_variable][img_i] * self._mask_factor).astype('uint8')
        predict = (np.around(batch_data[self._predictions_variable][img_i]) * self._predictions_factor).astype('uint8')
        assert mask.shape == predict.shape, ('The shape of `{}`: {} and `{}`: {} '
                                             'must be the same.'.format(self._predictions_variable, predict.shape,
                                                                        self._mask_variable, mask.shape))

        # Boolean indexing of the multi-channel ``gt`` layer below requires a per-pixel (H, W) index;
        # an (H, W, 1) index would raise IndexError, so drop the singleton channel axis.
        if mask.ndim == 3 and mask.shape[-1] == 1:
            mask = mask[..., 0]
            predict = predict[..., 0]

        image = self.to_color(image)
        assert image.shape[:2] == predict.shape[:2], ('The first two dimensions of `{}`: {} and `{}`: {} '
                                                      'must be the same.'.format(self._variable, image.shape,
                                                                                 self._predictions_variable,
                                                                                 predict.shape))

        gt = np.zeros(image.shape, dtype='uint8')
        gt[mask == predict] = self._color_true
        gt[mask > predict] = self._color_false_neg
        gt[mask < predict] = self._color_false_pos

        # NOTE: ``mask_opacity`` weights the image, the dyed layer gets the complement.
        return cv2.addWeighted(image, self._mask_opacity, gt, 1. - self._mask_opacity, 0)

    def after_batch(self, stream_name: str, batch_data: el.Batch):
        """
        Assert mask and prediction variables are in batch data, then delegate to :py:class:`SaveImages`.

        :param stream_name: name of the processed stream (used in error messages)
        :param batch_data: batch data to be checked and processed
        :raise KeyError: if the mask or predictions variable is missing from ``batch_data``
        """
        if self._mask_variable not in batch_data:
            raise KeyError('Mask variable `{}` was not found in the batch data (of stream `{}`). '
                           'Available variables are `{}`.'.format(self._mask_variable, stream_name,
                                                                  batch_data.keys()))
        if self._predictions_variable not in batch_data:
            raise KeyError('Predictions variable `{}` was not found in the batch data (of stream `{}`). '
                           'Available variables are `{}`.'.format(self._predictions_variable, stream_name,
                                                                  batch_data.keys()))
        super().after_batch(stream_name, batch_data)