Source code for blocks.hooks.threshold
"""
Threshold image.
"""
import numpy as np
import emloop as el
class Threshold(el.AbstractHook):
    """
    Hook that binarizes every item of one batch variable and stores the result under another.

    After each batch, every item of the input variable is compared element-wise
    against the configured threshold; the resulting masks are cast to ``uint8``
    and saved as a list under the output variable.
    """

    def __init__(self, input_variable: str, output_variable: str, threshold: float, **kwargs):
        """
        Remember the variable names and the threshold, then initialize the parent hook.

        :param input_variable: name of the batch variable holding the images to be thresholded
        :param output_variable: name of the batch variable that will receive the thresholded images
        :param threshold: cut-off value; entries strictly greater than it become 1, all others 0
        :param kwargs: remaining keyword arguments, forwarded unchanged to the parent constructor
        """
        self._input_variable = input_variable
        self._output_variable = output_variable
        self._threshold = threshold
        super().__init__(**kwargs)

    def after_batch(self, stream_name: str, batch_data: el.Batch):
        """
        Threshold the input variable of the given batch and add the result to the batch.

        :param stream_name: name of the stream being processed; passed on to the parent hook
        :param batch_data: batch to read the input variable from and write the output variable to
        :raise AssertionError: if the input variable is missing from the batch or
                               the output variable is already present in it
        """
        source_key, target_key = self._input_variable, self._output_variable

        # The input must be available and the output must not overwrite an existing entry.
        assert source_key in batch_data
        assert target_key not in batch_data

        binarized = []
        for image in batch_data[source_key]:
            # Strict comparison: values equal to the threshold map to 0.
            mask = image > self._threshold
            binarized.append(mask.astype(np.uint8))
        batch_data[target_key] = binarized

        super().after_batch(stream_name=stream_name, batch_data=batch_data)