# Source code for blocks.hooks.visualize_masks
"""
Module with VisualizeMasks hook providing visualization of images and segmentation masks.
"""
from typing import Tuple
import cv2
import numpy as np
import emloop as el
from .save_images import SaveImages
class VisualizeMasks(SaveImages):
    """
    Join and save images with the corresponding segmentation masks after each batch.

    Image pixels where the (scaled) mask is nonzero are blended with ``color``; all other pixels are kept unchanged.

    .. caution::

        Saving all the images may require considerable amount of time. Use this hook only for debugging or
        limit the number of saved images with ``image_count`` or ``batch_count`` parameters.

    .. code-block:: yaml
        :caption: Join and save image stored in variable ``imgs`` with mask stored in variable ``segments``.

        hooks:
          - blocks.hooks.VisualizeMasks:
              variable: imgs
              mask_variable: segments
    """

    def __init__(self, output_dir: str, mask_variable: str='mask', mask_factor: float=255,
                 mask_opacity: float=0.3, color: Tuple[int, int, int]=(0, 255, 0), **kwargs):
        """
        Create new VisualizeMasks hook.

        :param output_dir: output directory, forwarded to :py:class:`SaveImages`
        :param mask_variable: name of the variable representing the mask of image
        :param mask_factor: constant by which the mask is multiplied before it is cast to ``uint8``
        :param mask_opacity: opacity of the mask used for composition with image; ``0`` leaves the masked pixels
                             unchanged, ``1`` paints them with solid ``color``
        :param color: color of the mask used in the resulting image (BGR colorspace)
        :param kwargs: remaining keyword arguments forwarded to :py:class:`SaveImages`
        """
        super().__init__(output_dir, **kwargs)
        self._mask_variable = mask_variable
        self._mask_factor = mask_factor
        self._mask_opacity = mask_opacity
        self._color = color

    @property
    def image_suffix(self) -> str:
        """Image suffix in the form: ``<variable>_<mask_variable>``."""
        return "{}_{}".format(super().image_suffix, self._mask_variable)

    def process_img(self, img_i: int, batch_data: el.Batch) -> np.ndarray:
        """
        Join the image with dyed mask.

        A pixel is dyed when its scaled mask value (``mask * mask_factor`` cast to ``uint8``) is nonzero; for a mask
        with a trailing channel axis, a pixel is dyed when any of its channels is nonzero. Dyed pixels become
        ``mask_opacity * color + (1 - mask_opacity) * image``; all other pixels are copied from the image unchanged.

        :param img_i: index of the image within the batch
        :param batch_data: batch containing both the image variable and the mask variable
        :return: the image with the colored mask composed over it
        :raise ValueError: if the (color) image and the mask do not have matching shapes
        """
        image = (batch_data[self._variable][img_i] * self._factor).astype('uint8')
        mask = (batch_data[self._mask_variable][img_i] * self._mask_factor).astype('uint8')
        image = self.to_color(image)

        # Boolean map of the pixels to dye; a trailing channel axis (e.g. a (H, W, 1) mask) is collapsed so that
        # boolean indexing addresses whole pixels rather than individual channel values.
        dyed = mask != 0
        if dyed.ndim == 3:
            dyed = dyed.any(axis=-1)

        color_mask = np.zeros(dyed.shape + (3,), dtype=np.uint8)
        # explicit check instead of `assert`, which is stripped when Python runs with -O
        if color_mask.shape != image.shape:
            raise ValueError('The shape of `{}`: {} and `{}`: {} must be the same.'.format(
                self._variable, image.shape, self._mask_variable, color_mask.shape))
        color_mask[dyed] = self._color

        # `mask_opacity` weights the mask (the original code weighted the image with it). The blend is computed for
        # the whole frame (cv2 saturates to uint8) but only kept on the dyed pixels, so the rest of the image is not
        # darkened by the black background of `color_mask`.
        blended = cv2.addWeighted(color_mask, self._mask_opacity, image, 1. - self._mask_opacity, 0)
        joined = image.copy()
        joined[dyed] = blended[dyed]
        return joined

    def after_batch(self, stream_name: str, batch_data: el.Batch):
        """
        Check that the mask variable is present in the batch, then delegate to :py:class:`SaveImages`.

        :param stream_name: name of the stream the batch comes from (used in the error message)
        :param batch_data: batch data to be visualized
        :raise KeyError: if ``mask_variable`` is not found in ``batch_data``
        """
        if self._mask_variable not in batch_data:
            raise KeyError('Variable `{}` to be visualized was not found in the batch data (of stream `{}`). '
                           'Available variables are `{}`.'.format(self._mask_variable, stream_name,
                                                                  list(batch_data.keys())))
        super().after_batch(stream_name, batch_data)