# Source code for blocks.hooks.save_numpy

"""
Dump Numpy arrays.
"""

import os

import numpy as np
import emloop as el


class SaveNumpy(el.AbstractHook):
    """
    Dump numpy arrays.

    For every example of every batch, the array stored under ``variable`` is written to
    ``<output_dir>/<stream_name>/epoch_<epoch>/<id>_batch_<batch>_<variable>.npy``, where ``<id>`` is the
    example's value of ``id_variable``. Path separators in ``<id>`` are replaced by ``___`` and in
    ``<variable>`` by ``_``, so each of them stays a single file-name component.

    ``<epoch>`` is 0 until the first :py:meth:`after_epoch` call and ``epoch_id + 1`` afterwards.
    ``<batch>`` is a 1-based counter that is shared by all streams and reset in :py:meth:`after_epoch`.
    """

    def __init__(self, variable: str, id_variable: str, output_dir: str, **kwargs):
        """
        :param variable: name of the variable representing the numpy array
        :param id_variable: name of the variable which represents a unique example id
        :param output_dir: output directory where the arrays will be saved
        """
        super().__init__(**kwargs)
        self._variable = variable
        self._id_variable = id_variable
        self._output_dir = output_dir
        self._batch_count = 0
        self._current_epoch = 0

    def after_batch(self, stream_name: str, batch_data: el.Batch) -> None:
        """
        Save the array of each example in the batch to a separate ``.npy`` file.

        :param stream_name: name of the stream the batch comes from; used as a sub-directory of ``output_dir``
        :param batch_data: batch which must contain both ``variable`` and ``id_variable`` of equal length
        :raises AssertionError: if a required variable is missing or the two variables differ in length
        """
        # Explicit checks instead of ``assert`` statements so the validation is not stripped by ``python -O``;
        # AssertionError is kept so callers relying on the original asserts see the same exception type.
        for required in (self._variable, self._id_variable):
            if required not in batch_data:
                raise AssertionError(f'Variable `{required}` is missing in the batch')
        arrays = batch_data[self._variable]
        ids = batch_data[self._id_variable]
        if len(arrays) != len(ids):
            raise AssertionError(f'Variables `{self._variable}` and `{self._id_variable}` have different lengths '
                                 f'({len(arrays)} != {len(ids)})')

        self._batch_count += 1
        out_dir = os.path.join(self._output_dir, stream_name, f'epoch_{self._current_epoch}')
        os.makedirs(out_dir, exist_ok=True)

        var_name = self._replace_separators(self._variable, '_')
        for arr, arr_id in zip(arrays, ids):
            name = self._replace_separators(arr_id, '___')
            file_path = os.path.join(out_dir, f'{name}_batch_{self._batch_count}_{var_name}.npy')
            np.save(file=file_path, arr=arr, allow_pickle=False)

        super().after_batch(stream_name=stream_name, batch_data=batch_data)

    def after_epoch(self, epoch_id: int, **kwargs) -> None:
        """
        Set ``_current_epoch`` (used to distinguish between epoch directories) to ``epoch_id + 1``.
        Reset ``_batch_count`` to its initial value.

        :param epoch_id: id of the epoch which has just finished
        """
        self._current_epoch = epoch_id + 1
        self._batch_count = 0
        super().after_epoch(epoch_id=epoch_id, **kwargs)

    @staticmethod
    def _replace_separators(name, replacement: str) -> str:
        """Return ``name`` as text with every path separator replaced by ``replacement``."""
        # ids are not guaranteed to be ``str`` (e.g. ints or bytes); convert them before sanitizing
        if isinstance(name, bytes):
            name = name.decode()
        name = str(name)
        # os.altsep is '/' on Windows and None on POSIX; both separators must be neutralized there
        for separator in (os.sep, os.altsep):
            if separator:
                name = name.replace(separator, replacement)
        return name