Source code for blocks.hooks.visualize_masks_gt

"""
Module with the VisualizeMasksGT hook, which joins an image with a dyed prediction layer.
"""

from typing import Tuple

import cv2
import numpy as np
import emloop as el

from .save_images import SaveImages


class VisualizeMasksGT(SaveImages):
    """
    Join and save image with dyed segmentation mask prediction.

    The scaled ground-truth mask is compared pixel-wise with the scaled, rounded prediction:

    - pixels where they are equal are dyed to ``color_true``,
    - pixels where the mask is greater than the prediction are dyed to ``color_false_neg``,
    - pixels where the mask is lower than the prediction are dyed to ``color_false_pos``.

    The dyed layer is then blended with the image.

    .. caution::

        Saving all the images may require considerable amount of time.
        Use this hook only for debugging or limit the number of saved images
        with ``image_count`` or ``batch_count`` parameters.

    .. code-block:: yaml
        :caption: Join image stored in variable ``imgs`` with dyed layer obtained from mask stored in variable
                  ``mask`` and predictions stored in variable ``results``.

        hooks:
          - blocks.hooks.VisualizeMasksGT:
              variable: imgs
              mask_variable: mask
              predictions_variable: results
    """

    def __init__(self, output_dir: str, mask_variable: str='mask', predictions_variable: str='predictions',
                 mask_opacity: float=0.8, mask_factor: float=255, predictions_factor: float=255,
                 color_true: Tuple[int, int, int]=(0, 255, 0), color_false_pos: Tuple[int, int, int]=(0, 0, 255),
                 color_false_neg: Tuple[int, int, int]=(255, 0, 0), **kwargs):
        """
        Create new VisualizeMasksGT hook.

        :param output_dir: output directory, forwarded to :py:class:`SaveImages`
        :param mask_variable: name of the variable representing the mask of image
        :param predictions_variable: name of the variable representing the predictions (model output)
        :param mask_opacity: blending weight of the *image* in the composition; the dyed layer is blended
                             with weight ``1 - mask_opacity`` (so the default ``0.8`` keeps the dyed layer faint)
        :param mask_factor: the multiplier of the mask
        :param predictions_factor: the multiplier of the predictions variable (applied after rounding)
        :param color_true: color for the true positive and true negative pixels (BGR colorspace)
        :param color_false_pos: color for the false positive pixels (BGR colorspace)
        :param color_false_neg: color for the false negative pixels (BGR colorspace)
        :param kwargs: additional arguments forwarded to :py:class:`SaveImages`
        """
        super().__init__(output_dir, **kwargs)
        self._mask_variable = mask_variable
        self._predictions_variable = predictions_variable
        self._mask_opacity = mask_opacity
        self._mask_factor = mask_factor
        self._predictions_factor = predictions_factor
        self._color_true = color_true
        self._color_false_pos = color_false_pos
        self._color_false_neg = color_false_neg

    @property
    def image_suffix(self) -> str:
        """Image suffix in the form: ``<variable>_<mask_variable>_<predictions_variable>``."""
        return "{}_{}_{}".format(super().image_suffix, self._mask_variable, self._predictions_variable)

    @staticmethod
    def _drop_channel_axis(array: np.ndarray) -> np.ndarray:
        """Return ``array`` without a trailing singleton channel axis, i.e. ``(H, W, 1)`` becomes ``(H, W)``."""
        if array.ndim == 3 and array.shape[-1] == 1:
            return array[..., 0]
        return array

    def process_img(self, img_i: int, batch_data: el.Batch) -> np.ndarray:
        """
        Join the image with the layer where pixels for true positive and true negative are dyed to ``color_true``,
        pixels for false positive are dyed to ``color_false_pos`` and pixels for false negative are dyed to
        ``color_false_neg``.

        :param img_i: index of the image within the batch
        :param batch_data: batch containing the image, mask and predictions variables
        :return: the image blended with the dyed layer
        """
        image = (batch_data[self._variable][img_i] * self._factor).astype('uint8')
        mask = (batch_data[self._mask_variable][img_i] * self._mask_factor).astype('uint8')
        predict = (np.around(batch_data[self._predictions_variable][img_i]) * self._predictions_factor).astype('uint8')

        # Boolean indexing of the 3-channel layer below needs a 2-D index; masks/predictions stored with an
        # explicit single channel ``(H, W, 1)`` would otherwise raise an IndexError.
        mask = self._drop_channel_axis(mask)
        predict = self._drop_channel_axis(predict)

        assert mask.shape == predict.shape, ('The shape of `{}`: {} and `{}`: {} '
                                             'must be the same.'.format(self._predictions_variable, predict.shape,
                                                                        self._mask_variable, mask.shape))
        image = self.to_color(image)
        assert image.shape[:2] == predict.shape[:2], ('The first two ndims of `{}`: {} and `{}`: {} '
                                                      'must be the same.'.format(self._variable, image.shape,
                                                                                 self._mask_variable, mask.shape))

        # The three conditions are mutually exclusive and exhaustive, so every pixel gets exactly one color.
        gt = np.zeros(image.shape, dtype='uint8')
        gt[mask == predict] = self._color_true
        gt[mask > predict] = self._color_false_neg
        gt[mask < predict] = self._color_false_pos

        # result = image * mask_opacity + gt * (1 - mask_opacity)
        joined = cv2.addWeighted(image, self._mask_opacity, gt, 1. - self._mask_opacity, 0)
        return joined

    def after_batch(self, stream_name: str, batch_data: el.Batch):
        """
        Assert mask and prediction variables are in batch data, then delegate to :py:class:`SaveImages`.

        :raises KeyError: if the mask or predictions variable is missing from ``batch_data``
        """
        if self._mask_variable not in batch_data:
            raise KeyError('Mask variable `{}` was not found in the batch data (of stream `{}`). '
                           'Available variables are `{}`.'.format(self._mask_variable, stream_name, batch_data.keys()))
        if self._predictions_variable not in batch_data:
            raise KeyError('Predictions variable `{}` was not found in the batch data (of stream `{}`). '
                           'Available variables are `{}`.'.format(self._predictions_variable, stream_name,
                                                                  batch_data.keys()))
        super().after_batch(stream_name, batch_data)