Source code for blocks.hooks.save_images
"""
Module with SaveImages hook which saves images.
"""
import os
import collections
from typing import Iterable
import logging
import cv2
import numpy as np
import emloop as el
class SaveImages(el.AbstractHook):
    """
    Save images of the provided streams.

    This hook can be used as a base class for the hooks, which need to systematically save their results.
    For this purpose, ``process_img`` method and ``image_suffix`` property should be overridden.

    Images are written to
    ``<output_dir>/visual/epoch_<epoch>/<stream>/<id>_batch_<batch>_<image_suffix>.<out_format>``,
    where ``<epoch>`` is ``_`` until the first ``after_epoch`` call and ``epoch_id + 1`` afterwards.

    .. caution::
        Saving all the images may require considerable amount of time. Use this hook only for debugging or
        limit the number of saved images with ``img_count`` or ``batch_count`` parameters.

    .. code-block:: yaml
        :caption: Save images from test and valid streams, stored in variable x.

        hooks:
          - blocks.hooks.SaveImages:
              variable: x
              streams: [test, valid]

    .. code-block:: yaml
        :caption: Save the first two images from the first ten batches (from the train stream).

        hooks:
          - blocks.hooks.SaveImages:
              img_count: 2
              batch_count: 10
    """

    # subdirectory of ``output_dir`` under which all the images are stored
    _ROOT_DIR = 'visual'

    def __init__(self, output_dir: str, streams: Iterable[str]=('train',), variable: str='images',
                 id_variable: str='ids', out_format: str='png', factor: float=1, img_count: int=None,
                 batch_count: int=None, **kwargs):
        """
        :param output_dir: output directory where images will be saved
        :param streams: list of stream names to be dumped (a single stream name is accepted as well)
        :param variable: name of the variable representing the source image
        :param id_variable: name of the variable which represents a unique example id
        :param out_format: extension of the saved image
        :param factor: a constant by which the image is multiplied
        :param img_count: count of images which will be saved from each batch (first ``img_count`` images will be
                          saved); ``None`` or ``0`` means all images
        :param batch_count: count of batches from which the images will be saved (first ``batch_count`` will be
                            processed); ``None`` or ``0`` means all batches
        """
        super().__init__(**kwargs)
        self._output_dir = output_dir
        # Materialize the streams once: a generator would be exhausted by the first membership test and
        # a bare string would turn ``stream_name in self._streams`` into a substring test.
        self._streams = [streams] if isinstance(streams, str) else list(streams)
        self._variable = variable
        self._id_variable = id_variable
        self._out_format = out_format
        self._factor = factor
        self._img_count = img_count
        self._batch_count = batch_count
        # placeholder used in the directory name until ``after_epoch`` sets a real epoch number
        self._current_epoch_id = '_'
        self._reset()

    @property
    def image_suffix(self) -> str:
        """The suffix of the saved image, used to distinguish between images from different hooks."""
        return self._variable

    def process_img(self, img_i: int, batch_data: el.Batch) -> np.ndarray:
        """
        This method is called in ``after_batch`` method and its purpose is to prepare image for save.
        If convenient, this method can be overridden in a subclass.

        The ``img_i``-th example of the configured variable is multiplied by ``factor``, clipped to the
        ``[0, 255]`` range and converted to ``uint8``.

        :param img_i: index of the example within the batch
        :param batch_data: batch data containing the configured variable
        :return: image ready to be written by ``cv2.imwrite``
        """
        image = np.asarray(batch_data[self._variable][img_i]) * self._factor
        # Clip before the cast; a bare ``astype(np.uint8)`` silently wraps out-of-range values (e.g. 256 -> 0).
        return np.clip(image, 0, 255).astype(np.uint8)

    def _check_batch_data(self, stream_name: str, batch_data: el.Batch) -> None:
        """Raise ``KeyError`` if the id variable or the image variable is missing in ``batch_data``."""
        if self._id_variable not in batch_data:
            raise KeyError('Variable `{}` to be used as a unique id was not found in the batch data '
                           '(of stream `{}`). Available variables are `{}`.'.format(self._id_variable, stream_name,
                                                                                  batch_data.keys()))
        if self._variable not in batch_data:
            raise KeyError('Variable `{}` to be saved was not found in the batch data (of stream `{}`). '
                           'Available variables are `{}`.'.format(self._variable, stream_name,
                                                                  batch_data.keys()))

    def after_batch(self, stream_name: str, batch_data: el.Batch):
        """
        Save images in provided streams from selected variable. The amount of batches and images to be processed is
        possible to control by ``batch_count`` and ``img_count`` parameters.

        :param stream_name: name of the stream the batch comes from; other than configured streams are ignored
        :param batch_data: batch data containing the id variable and the image variable
        :raises KeyError: if the id variable or the image variable is missing in the batch data
        """
        if stream_name not in self._streams:
            return
        self._check_batch_data(stream_name, batch_data)

        self._batch_done[stream_name] += 1
        batch_number = self._batch_done[stream_name]
        if self._batch_count and batch_number > self._batch_count:
            return

        stream_out_dir = os.path.join(self._output_dir, self._ROOT_DIR,
                                      'epoch_{}'.format(self._current_epoch_id), stream_name)
        os.makedirs(stream_out_dir, exist_ok=True)

        for i, ex_id in enumerate(batch_data[self._id_variable]):
            if self._img_count and i >= self._img_count:
                break
            img_name = '{}_batch_{}_{}.{}'.format(ex_id, batch_number, self.image_suffix, self._out_format)
            # Example ids may contain path separators; replace them so the file stays inside ``stream_out_dir``
            # (``os.altsep`` is ``/`` on Windows and ``None`` on POSIX).
            for separator in (os.sep, os.altsep):
                if separator:
                    img_name = img_name.replace(separator, '___')
            image_path = os.path.join(stream_out_dir, img_name)
            image = self.process_img(i, batch_data)
            if not cv2.imwrite(image_path, image):
                logging.error('Cannot save image `%s`', image_path)

    def after_epoch(self, epoch_id: int, **_):
        """
        Set ``_current_epoch_id``, which is used to distinguish between epoch directories,
        and call the ``_reset`` method.

        :param epoch_id: id of the finished epoch; subsequent images go to ``epoch_<epoch_id + 1>``
        """
        self._current_epoch_id = epoch_id + 1
        self._reset()

    def _reset(self) -> None:
        """Reset the per-stream counters of processed batches (``_batch_done``) to zero."""
        self._batch_done = collections.defaultdict(int)

    def to_color(self, image):
        """
        If the ``image`` is not a 3-channel image, it will be converted from grayscale to RGB.

        :param image: image array; anything other than shape ``(H, W, 3)`` is passed to ``cv2.COLOR_GRAY2RGB``
        :return: 3-channel image
        """
        if not (len(image.shape) == 3 and image.shape[2] == 3):
            return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        return image