Source code for blocks.hooks.visualize_thresholded_prediction
"""
Save images overlaid with their semi-transparent prediction masks for each of the given thresholds.
"""
import os
from typing import Tuple, Optional, Iterable
import numpy as np
import emloop as el
try:
import cv2
except ImportError as ex:
raise ImportError('This hook requires OpenCV.') from ex
class VisualizeThresholdedPrediction(el.AbstractHook):
    """
    Save images overlaid with their semi-transparent prediction masks for each of the given thresholds.

    The entire batch is saved as a single image matrix where columns correspond to the given thresholds
    and rows to the images in the batch. Matrices are written to
    ``<output_dir>/<stream_name>/epoch_<epoch>/<stream_name>_<batch>.png`` where ``<batch>`` is the
    1-based index of the saved batch of that stream within the epoch.
    """

    def __init__(self, output_dir: str, image_variable: str, probability_variable: str,
                 n_batches: int=1, n_epochs: int=1, thresholds: Optional[Iterable[float]]=None,
                 mask_opacity: float=0.5, color: Tuple[int, int, int]=(0, 255, 0), **kwargs):
        """
        :param output_dir: output directory where masks will be saved
        :param image_variable: name of the variable representing the source image
        :param probability_variable: name of the variable representing the probability image;
                                     maps of shape (H, W) and (H, W, 1) are accepted
        :param n_batches: count of batches (per stream and epoch) from which the masks will be saved
        :param n_epochs: count of epochs from which the masks will be saved
        :param thresholds: probability thresholds from which the masks will be saved;
                           defaults to ``np.arange(0.1, 1, 0.1)`` (0.1, 0.2, ..., 0.9)
        :param mask_opacity: opacity of the mask used for composition with image
                             (0 = mask invisible, 1 = mask fully opaque)
        :param color: color of the mask used for composition with image
        :raise ValueError: if ``thresholds`` is empty
        """
        self._output_dir = output_dir
        self._image_variable = image_variable
        self._probability_variable = probability_variable
        self._n_batches = n_batches
        self._n_epochs = n_epochs
        # The thresholds are iterated once per image, so a one-shot iterator (e.g. a generator)
        # would be exhausted after the first image; materialize it once here.
        self._thresholds = tuple(thresholds) if thresholds is not None else np.arange(0.1, 1, 0.1)
        if len(self._thresholds) == 0:
            raise ValueError('At least one threshold is required.')
        self._mask_opacity = mask_opacity
        self._color = color
        # Number of batches already saved in the current epoch, keyed by stream name. A single shared
        # counter would let the first stream of an epoch consume the whole ``n_batches`` budget.
        self._batch_counts = {}
        self._current_epoch = 0
        super().__init__(**kwargs)

    def after_batch(self, stream_name: str, batch_data: el.Batch):
        """
        Save the thresholded mask visualization of the given batch.

        Nothing is saved once ``n_epochs`` epochs have passed or ``n_batches`` batches of this stream
        have already been saved in the current epoch. Empty batches are skipped.

        :param stream_name: name of the stream the batch comes from
        :param batch_data: batch containing ``image_variable`` and ``probability_variable``
        :return: the composed image matrix if it was saved, ``None`` otherwise
        """
        super().after_batch(stream_name=stream_name, batch_data=batch_data)
        if self._current_epoch >= self._n_epochs:
            return None
        if self._batch_counts.get(stream_name, 0) >= self._n_batches:
            return None

        assert self._image_variable in batch_data, \
            'Variable `{}` not found in the batch.'.format(self._image_variable)
        assert self._probability_variable in batch_data, \
            'Variable `{}` not found in the batch.'.format(self._probability_variable)
        images = batch_data[self._image_variable]
        probabilities = batch_data[self._probability_variable]
        assert len(images) == len(probabilities), \
            'Batch contains {} images but {} probability maps.'.format(len(images), len(probabilities))
        if len(images) == 0:
            # np.vstack cannot compose an empty matrix; nothing to visualize.
            return None

        batch_count = self._batch_counts.get(stream_name, 0) + 1
        self._batch_counts[stream_name] = batch_count

        result = np.vstack([self._visualize_example(image, probability)
                            for image, probability in zip(images, probabilities)])
        save_dir = os.path.join(self._output_dir, stream_name, 'epoch_{}'.format(self._current_epoch))
        os.makedirs(save_dir, exist_ok=True)
        cv2.imwrite(os.path.join(save_dir, '{}_{}.png'.format(stream_name, batch_count)), result)
        return result

    def _visualize_example(self, image: np.ndarray, probability: np.ndarray) -> np.ndarray:
        """Return one row of the matrix: ``image`` overlaid with the mask of each threshold, side by side."""
        assert image.shape[:2] == probability.shape[:2], \
            'Image shape {} does not match probability shape {}.'.format(image.shape, probability.shape)
        # Drop a trailing singleton channel so the boolean mask indexes only the spatial dimensions.
        if probability.ndim == 3 and probability.shape[2] == 1:
            probability = probability[:, :, 0]
        overlays = []
        for threshold in self._thresholds:
            color_mask = image.copy()
            color_mask[probability > threshold] = self._color
            # The colored mask is weighted by ``mask_opacity`` and the original image by the remainder.
            overlays.append(cv2.addWeighted(color_mask, self._mask_opacity,
                                            image, 1. - self._mask_opacity, 0))
        return np.hstack(overlays)

    def after_epoch(self, epoch_id: int, **kwargs):
        """
        Reset the per-stream batch counters and move on to the next epoch.

        :param epoch_id: id of the epoch which has just finished
        """
        self._batch_counts = {}
        self._current_epoch = epoch_id + 1
        super().after_epoch(epoch_id=epoch_id, **kwargs)