Source code for blocks.hooks.save_images

"""
Module with SaveImages hook which saves images.
"""

import os
import collections
from typing import Iterable
import logging

import cv2
import numpy as np
import emloop as el


class SaveImages(el.AbstractHook):
    """
    Save images of the provided streams.

    This hook can be used as a base class for hooks which need to systematically save their results.
    For this purpose, the ``process_img`` method and the ``image_suffix`` property should be overridden.

    Images are written to ``<output_dir>/visual/epoch_<id>/<stream_name>/``, where ``<id>`` is ``_`` until
    the first ``after_epoch`` call and ``epoch_id + 1`` afterwards.

    .. caution::
        Saving all the images may require a considerable amount of time.
        Use this hook only for debugging or limit the number of saved images with the ``img_count`` or
        ``batch_count`` parameters.

    .. code-block:: yaml
        :caption: Save images from test and valid streams, stored in variable x.

        hooks:
          - blocks.hooks.SaveImages:
              variable: x
              streams: [test, valid]

    .. code-block:: yaml
        :caption: Save the first two images from the first ten batches (from the train stream).

        hooks:
          - blocks.hooks.SaveImages:
              img_count: 2
              batch_count: 10
    """

    # Sub-directory of ``output_dir`` under which all the images are stored.
    _ROOT_DIR = 'visual'

    def __init__(self, output_dir: str, streams: Iterable[str] = ('train',), variable: str = 'images',
                 id_variable: str = 'ids', out_format: str = 'png', factor: float = 1, img_count: int = None,
                 batch_count: int = None, **kwargs):
        """
        :param output_dir: output directory where images will be saved
        :param streams: names of the streams to be dumped; a single string is treated as one stream name
        :param variable: name of the variable representing the source image
        :param id_variable: name of the variable which represents a unique example id
        :param out_format: extension of the saved image (determines the format used by ``cv2.imwrite``)
        :param factor: a constant by which the image is multiplied
        :param img_count: number of images saved from each batch (the first ``img_count`` images are saved);
                          ``None`` means all images
        :param batch_count: number of batches per stream and epoch from which the images are saved (the first
                            ``batch_count`` batches are processed); ``None`` means all batches
        """
        super().__init__(**kwargs)
        self._output_dir = output_dir
        # Materialize the streams: ``after_batch`` tests membership repeatedly, which would exhaust a one-shot
        # iterable (e.g. a generator), and a bare string must not be matched by substring.
        self._streams = [streams] if isinstance(streams, str) else list(streams)
        self._variable = variable
        self._id_variable = id_variable
        self._out_format = out_format
        self._factor = factor
        self._img_count = img_count
        self._batch_count = batch_count
        # Epoch directory suffix used before the first ``after_epoch`` call.
        self._current_epoch_id = '_'
        self._reset()

    @property
    def image_suffix(self) -> str:
        """The suffix of the saved image, used to distinguish between images from different hooks."""
        return self._variable

    def process_img(self, img_i: int, batch_data: el.Batch) -> np.ndarray:
        """
        Prepare the ``img_i``-th image of the batch for saving.

        This method is called in the ``after_batch`` method for every saved image. If convenient, it can be
        overridden in a subclass. The default implementation multiplies the image by ``factor``, clips the
        result to ``[0, 255]`` and converts it to ``uint8``.

        :param img_i: index of the image within the batch
        :param batch_data: batch containing the ``variable`` to be saved
        :return: image ready to be written by ``cv2.imwrite``
        """
        scaled = batch_data[self._variable][img_i] * self._factor
        # Clip before the cast: out-of-range values would otherwise wrap around (or be undefined for floats).
        return np.clip(scaled, 0, 255).astype(np.uint8)

    def after_batch(self, stream_name: str, batch_data: el.Batch) -> None:
        """
        Save images of the selected variable if ``stream_name`` is one of the selected streams.

        The amount of batches and images to be processed can be controlled by the ``batch_count`` and
        ``img_count`` parameters.

        :param stream_name: name of the stream the batch comes from
        :param batch_data: batch containing both ``id_variable`` and ``variable``
        :raises KeyError: if ``id_variable`` or ``variable`` is missing in ``batch_data``
        """
        if stream_name not in self._streams:
            return

        # assert variables in batch data
        if self._id_variable not in batch_data:
            raise KeyError('Variable `{}` to be used as a unique id was not found in the batch data '
                           '(of stream `{}`). Available variables are `{}`.'.format(self._id_variable, stream_name,
                                                                                     batch_data.keys()))
        if self._variable not in batch_data:
            raise KeyError('Variable `{}` to be saved was not found in the batch data (of stream `{}`). '
                           'Available variables are `{}`.'.format(self._variable, stream_name, batch_data.keys()))

        self._batch_done[stream_name] += 1
        batch_idx = self._batch_done[stream_name]
        # ``is not None`` so that an explicit 0 means "no batches", as documented, rather than "all batches".
        if self._batch_count is not None and batch_idx > self._batch_count:
            return

        stream_out_dir = os.path.join(self._output_dir, self._ROOT_DIR,
                                      'epoch_{}'.format(self._current_epoch_id), stream_name)
        os.makedirs(stream_out_dir, exist_ok=True)

        for i, ex_id in enumerate(batch_data[self._id_variable]):
            if self._img_count is not None and i >= self._img_count:
                break
            img_name = '{}_batch_{}_{}.{}'.format(ex_id, batch_idx, self.image_suffix, self._out_format)
            # Example ids may contain path separators; replace them so each image stays in ``stream_out_dir``.
            for sep in (os.sep, os.altsep):
                if sep:
                    img_name = img_name.replace(sep, '___')
            image_path = os.path.join(stream_out_dir, img_name)
            image = self.process_img(i, batch_data)
            if not cv2.imwrite(image_path, image):
                logging.error('Cannot save image `%s`', image_path)

    def after_epoch(self, epoch_id: int, **_) -> None:
        """
        Set ``_current_epoch_id`` (used to distinguish between epoch directories) to the next epoch and reset
        the per-stream batch counters via ``_reset``.
        """
        self._current_epoch_id = epoch_id + 1
        self._reset()

    def _reset(self) -> None:
        """Reset the per-stream counters of processed batches (``_batch_done``) to zero."""
        self._batch_done = collections.defaultdict(int)

    def to_color(self, image):
        """
        Convert a grayscale ``image`` to RGB; return it unchanged if it already has three channels.

        Anything that is not of shape ``(H, W, 3)`` is passed to ``cv2.cvtColor`` with ``COLOR_GRAY2RGB``,
        so it must be a single-channel image.
        """
        if not (len(image.shape) == 3 and image.shape[2] == 3):
            return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        return image