# Source code for blocks.hooks.save_numpy
"""
Dump Numpy arrays.
"""
import os
import numpy as np
import emloop as el
class SaveNumpy(el.AbstractHook):
    """
    Dump numpy arrays, one ``.npy`` file per example, after every batch.

    Files are written to ``<output_dir>/<stream_name>/epoch_<N>/`` and named
    ``<example id>_batch_<batch count>_<variable>.npy``. Path separators in the
    example id and in the variable name are replaced so that every file lands
    directly in the epoch directory.
    """

    def __init__(self, variable: str, id_variable: str, output_dir: str, **kwargs):
        """
        :param variable: name of the variable representing the numpy array
        :param id_variable: name of the variable which represents a unique example id
        :param output_dir: output directory where the arrays will be saved
        """
        super().__init__(**kwargs)
        self._variable = variable
        self._id_variable = id_variable
        self._output_dir = output_dir
        # Counts batches across *all* streams within the current epoch; reset in after_epoch.
        self._batch_count = 0
        # Index of the epoch directory the next batches are written to.
        self._current_epoch = 0

    @staticmethod
    def _sanitize(name, replacement: str) -> str:
        """
        Convert ``name`` to ``str`` and replace every path separator with ``replacement``.

        Both ``os.sep`` and ``os.altsep`` (``/`` on Windows, ``None`` on POSIX) are
        replaced, so the result is always a single file-name component.
        """
        text = str(name)
        for separator in (os.sep, os.altsep):
            if separator:
                text = text.replace(separator, replacement)
        return text

    def after_batch(self, stream_name: str, batch_data: el.Batch):
        """
        Save the ``variable`` array of every example in the batch to its own ``.npy`` file.

        :param stream_name: name of the stream the batch comes from; used as a sub-directory
        :param batch_data: batch containing both ``variable`` and ``id_variable``
        :raises KeyError: if ``variable`` or ``id_variable`` is missing from the batch
        :raises ValueError: if ``variable`` and ``id_variable`` hold a different number of examples
        """
        # Explicit checks instead of ``assert`` so validation survives ``python -O``;
        # otherwise ``zip`` below would silently drop unmatched examples.
        for required in (self._variable, self._id_variable):
            if required not in batch_data:
                raise KeyError(f'Variable `{required}` not found in the batch.')
        arrays = batch_data[self._variable]
        ids = batch_data[self._id_variable]
        if len(arrays) != len(ids):
            raise ValueError(f'Variables `{self._variable}` ({len(arrays)} examples) and '
                             f'`{self._id_variable}` ({len(ids)} examples) differ in length.')

        self._batch_count += 1
        out_dir = os.path.join(self._output_dir, stream_name, f'epoch_{self._current_epoch}')
        os.makedirs(out_dir, exist_ok=True)

        var_name = self._sanitize(self._variable, '_')
        for arr, arr_id in zip(arrays, ids):
            name = self._sanitize(arr_id, '___')
            np.save(file=os.path.join(out_dir, f'{name}_batch_{self._batch_count}_{var_name}.npy'),
                    arr=arr,
                    allow_pickle=False)  # object arrays are refused rather than pickled
        super().after_batch(stream_name=stream_name, batch_data=batch_data)

    def after_epoch(self, epoch_id: int, **kwargs):
        """
        Advance ``_current_epoch`` to ``epoch_id + 1`` and reset ``_batch_count`` to 0.

        Batches processed after this call are therefore saved under ``epoch_<epoch_id + 1>``.

        :param epoch_id: id of the epoch that has just finished
        """
        self._current_epoch = epoch_id + 1
        self._batch_count = 0
        super().after_epoch(epoch_id=epoch_id, **kwargs)