"""Source code for blocks.utils.ctc: greedy CTC (Connectionist Temporal Classification) decoding helpers."""
from typing import Tuple
import numpy as np
# Public names exported by ``from blocks.utils.ctc import *``.
__all__ = ['decode_ctc_prediction', 'decode_ctc_prediction_with_eos']
def decode_ctc_prediction(ctc_prediction: np.ndarray) -> Tuple[list, list]:
    """
    Greedily decode a raw CTC prediction, treating the last class as the *empty* (blank) character.

    At every time step the most probable class is taken. A class is emitted at step ``t`` when it
    is not the blank and differs from the most probable class at step ``t - 1``. Consecutive
    repeats therefore collapse into one emission, while a blank between two equal classes lets
    that class be emitted twice.

    :param ctc_prediction: 2d array of time x probabilities
    :return: a tuple of (decoded predictions, indices of emitted characters)
    """
    blank = ctc_prediction.shape[1] - 1
    labels = np.argmax(ctc_prediction, axis=-1)
    # Treat the (virtual) step before t=0 as blank, so a leading non-blank class is always emitted.
    previous = [blank] + list(labels[:-1])
    emitted = [
        (step, label)
        for step, (label, before) in enumerate(zip(labels, previous))
        if label != blank and label != before
    ]
    decoded = [label for _, label in emitted]
    positions = [step for step, _ in emitted]
    return decoded, positions
def decode_ctc_prediction_with_eos(ctc_prediction: np.ndarray) -> Tuple[list, list]:
    """
    Same as :py:func:`decode_ctc_prediction` but cut the prediction (and indices) right after the
    first <EOS> character, assuming <EOS> = dim-2 and the blank = dim-1.

    The <EOS> character itself is kept as the last decoded element. If no <EOS> is emitted, the
    result of :py:func:`decode_ctc_prediction` is returned unchanged.

    :param ctc_prediction: 2d array of time x probabilities
    :return: a tuple of (decoded predictions, indices of emitted characters)
    """
    eos = ctc_prediction.shape[1] - 2  # dim-2 characters, <EOS> and <blank>
    prediction, indices = decode_ctc_prediction(ctc_prediction)
    try:
        # Single scan for the first <EOS>; +1 keeps the <EOS> itself in the output.
        eos_ix = prediction.index(eos) + 1
    except ValueError:
        # No <EOS> was emitted: nothing to cut.
        return prediction, indices
    return prediction[:eos_ix], indices[:eos_ix]