# Source code for the blocks.hooks.visualize_thresholded_prediction module.

"""
Save images overlayed by their semi-transparent prediction masks for each of the given thresholds.
"""

import os
from typing import Tuple, Optional, Iterable

import numpy as np
import emloop as el

try:
    import cv2
except ImportError as ex:
    raise ImportError('This hook requires OpenCV.') from ex


class VisualizeThresholdedPrediction(el.AbstractHook):
    """
    Save images overlayed by their semi-transparent prediction masks for each of the given thresholds.

    The entire batch is saved as a single image matrix where columns correspond to given thresholds and
    rows to images in batch. For every stream, the first ``n_batches`` batches of each of the first
    ``n_epochs`` epochs are written to ``<output_dir>/<stream_name>/epoch_<epoch>/<stream_name>_<batch>.png``.
    """

    def __init__(self, output_dir: str, image_variable: str, probability_variable: str, n_batches: int=1,
                 n_epochs: int=1, thresholds: Optional[Iterable[float]]=None, mask_opacity: float=0.5,
                 color: Tuple[int, int, int]=(0, 255, 0), **kwargs):
        """
        :param output_dir: output directory where masks will be saved
        :param image_variable: name of the variable representing the source image
        :param probability_variable: name of the variable representing the probability image
        :param n_batches: count of batches (per stream and epoch) from which the masks will be saved
        :param n_epochs: count of epochs from which the masks will be saved
        :param thresholds: probability thresholds from which the masks will be saved, one column per
                           threshold; defaults to ``np.arange(0.1, 1, 0.1)`` (i.e. 0.1, 0.2, ..., 0.9)
        :param mask_opacity: opacity of the mask used for composition with image
                             (0 = mask invisible, 1 = mask fully opaque)
        :param color: color of the mask used for composition with image
        :param kwargs: passed to the parent ``el.AbstractHook``
        :raises ValueError: if ``thresholds`` is empty
        """
        self._output_dir = output_dir
        self._image_variable = image_variable
        self._probability_variable = probability_variable
        self._n_batches = n_batches
        self._n_epochs = n_epochs
        # Materialize the thresholds once: a one-shot iterable (e.g. a generator) would otherwise be
        # exhausted by the first image, leaving every following image without any column.
        self._thresholds = tuple(np.arange(0.1, 1, 0.1) if thresholds is None else thresholds)
        if not self._thresholds:
            raise ValueError('At least one threshold must be given.')
        self._mask_opacity = mask_opacity
        self._color = color
        # number of batches already saved in the current epoch, counted separately for every stream
        self._batch_counts = {}
        self._current_epoch = 0
        super().__init__(**kwargs)

    def after_batch(self, stream_name: str, batch_data: el.Batch):
        """
        Save the thresholded masks of the given batch as a single image matrix.

        Nothing is saved once ``n_epochs`` epochs have passed, once ``n_batches`` batches of the given
        stream have been saved in the current epoch, or when the batch is empty.

        :param stream_name: name of the stream the batch comes from
        :param batch_data: batch containing both the image and the probability variables
        :return: the composed image matrix if it was saved, ``None`` otherwise
        """
        super().after_batch(stream_name=stream_name, batch_data=batch_data)
        if self._current_epoch >= self._n_epochs:
            return None
        batch_count = self._batch_counts.get(stream_name, 0)
        if batch_count >= self._n_batches:
            return None

        assert self._image_variable in batch_data, \
            'Variable `{}` not found in the batch.'.format(self._image_variable)
        assert self._probability_variable in batch_data, \
            'Variable `{}` not found in the batch.'.format(self._probability_variable)
        images = batch_data[self._image_variable]
        probabilities = batch_data[self._probability_variable]
        assert len(images) == len(probabilities), \
            'Variables `{}` and `{}` have different lengths.'.format(self._image_variable,
                                                                     self._probability_variable)
        if len(images) == 0:
            # np.vstack cannot stack an empty list of rows
            return None

        batch_count += 1
        self._batch_counts[stream_name] = batch_count

        result = np.vstack([self._compose_row(image, probability)
                            for image, probability in zip(images, probabilities)])
        save_dir = os.path.join(self._output_dir, stream_name, 'epoch_{}'.format(self._current_epoch))
        os.makedirs(save_dir, exist_ok=True)
        cv2.imwrite(os.path.join(save_dir, '{}_{}.png'.format(stream_name, batch_count)), result)
        return result

    def _compose_row(self, image: np.ndarray, probability: np.ndarray) -> np.ndarray:
        """Return ``image`` overlayed by its mask for every threshold, stacked horizontally (one row)."""
        if probability.ndim == 3 and probability.shape[2] == 1:
            # drop a trailing singleton channel so the boolean mask selects pixels, not channels
            probability = probability[..., 0]
        assert image.shape[:2] == probability.shape[:2], \
            'Image and probability spatial shapes differ: {} vs {}.'.format(image.shape[:2],
                                                                            probability.shape[:2])
        columns = []
        for threshold in self._thresholds:
            color_mask = image.copy()
            color_mask[probability > threshold] = self._color
            # the colored mask gets `mask_opacity`, the original image the remaining weight;
            # outside the mask both inputs are equal, so those pixels stay unchanged
            columns.append(cv2.addWeighted(color_mask, self._mask_opacity, image, 1. - self._mask_opacity, 0))
        return np.hstack(columns)

    def after_epoch(self, epoch_id: int, **kwargs):
        """Reset the per-stream batch counters and move on to the next epoch."""
        self._batch_counts = {}
        self._current_epoch = epoch_id + 1
        super().after_epoch(epoch_id=epoch_id, **kwargs)