# Source code for blocks.hooks.visualize_masks

"""
Module with VisualizeMasks hook providing visualization of images and segmentation masks.
"""

from typing import Tuple

import cv2
import numpy as np
import emloop as el

from .save_images import SaveImages


class VisualizeMasks(SaveImages):
    """
    Join and save images with the corresponding segmentation masks after each batch.

    .. caution::

        Saving all the images may require considerable amount of time.
        Use this hook only for debugging or limit the number of saved images
        with ``image_count`` or ``batch_count`` parameters.

    .. code-block:: yaml
        :caption: Join and save image stored in variable ``imgs`` with mask stored in variable ``segments``.

        hooks:
          - blocks.hooks.VisualizeMasks:
              variable: imgs
              mask_variable: segments
    """

    def __init__(self, output_dir: str, mask_variable: str = 'mask', mask_factor: float = 255,
                 mask_opacity: float = 0.3, color: Tuple[int, int, int] = (0, 255, 0), **kwargs):
        """
        Create new VisualizeMasks hook.

        :param output_dir: output directory, forwarded to :py:class:`SaveImages`
        :param mask_variable: name of the batch variable representing the mask of image
        :param mask_factor: constant by which the mask is multiplied before it is converted to ``uint8``
        :param mask_opacity: opacity of the mask used for composition with image
                             (``0`` keeps the image untouched, ``1`` paints masked pixels with ``color``)
        :param color: color of the mask used in the resulting image (BGR colorspace)
        :param kwargs: additional keyword arguments forwarded to :py:class:`SaveImages`
        """
        super().__init__(output_dir, **kwargs)
        self._mask_variable = mask_variable
        self._mask_factor = mask_factor
        self._mask_opacity = mask_opacity
        self._color = color

    @property
    def image_suffix(self) -> str:
        """Image suffix in the form: ``<variable>_<mask_variable>``."""
        return "{}_{}".format(super().image_suffix, self._mask_variable)

    @staticmethod
    def _scale_to_uint8(values, factor: float) -> np.ndarray:
        """Multiply ``values`` by ``factor``, clip the result into [0, 255] and convert it to ``uint8``."""
        # Clipping before the cast avoids wrap-around / platform-dependent results
        # for values that would fall outside of the uint8 range.
        return np.clip(np.asarray(values) * factor, 0, 255).astype(np.uint8)

    @staticmethod
    def _mask_region(mask: np.ndarray) -> np.ndarray:
        """Return a boolean array over the first two axes marking pixels non-zero in any channel."""
        region = mask != 0
        if region.ndim > 2:
            # Collapse trailing (channel) axes so (H, W), (H, W, 1) and (H, W, C) masks all work.
            region = region.any(axis=tuple(range(2, region.ndim)))
        return region

    def process_img(self, img_i: int, batch_data: el.Batch) -> np.ndarray:
        """
        Join the image with dyed mask.

        Every pixel whose scaled mask value (``mask_px * mask_factor`` clipped to [0, 255] and
        converted to ``uint8``) is non-zero in any channel is blended with ``color`` using
        ``mask_opacity``; all other pixels are kept unchanged.

        :param img_i: index of the image within the batch
        :param batch_data: batch containing both the image variable and the mask variable
        :return: the composed image (same shape as the colored image)
        :raises AssertionError: if the colored image and the colored mask shapes differ
        """
        image = self.to_color(self._scale_to_uint8(batch_data[self._variable][img_i], self._factor))
        mask = self._scale_to_uint8(batch_data[self._mask_variable][img_i], self._mask_factor)
        region = self._mask_region(mask)

        color_mask = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
        color_mask[region] = self._color
        assert color_mask.shape == image.shape, ('The shape of `{}`: {} and `{}`: {} '
                                                 'must be the same.'.format(self._variable, image.shape,
                                                                            self._mask_variable,
                                                                            color_mask.shape))

        # The mask is weighted by its opacity and the image by the complement.
        blended = cv2.addWeighted(image, 1. - self._mask_opacity, color_mask, self._mask_opacity, 0)
        # Keep the blend only inside the mask; outside of it ``color_mask`` is black and
        # blending there would merely darken the original image.
        joined = image.copy()
        joined[region] = blended[region]
        return joined

    def after_batch(self, stream_name: str, batch_data: el.Batch) -> None:
        """
        Check that the mask variable is present in the batch and delegate to the parent hook.

        :param stream_name: name of the stream the batch comes from
        :param batch_data: batch data to be visualized
        :raises KeyError: if ``mask_variable`` is not found in ``batch_data``
        """
        if self._mask_variable not in batch_data:
            raise KeyError('Variable `{}` to be visualized was not found in the batch data (of stream `{}`). '
                           'Available variables are `{}`.'.format(self._mask_variable, stream_name,
                                                                  batch_data.keys()))
        super().after_batch(stream_name, batch_data)