# Source code for the blocks.utils.ctc module.

from typing import Tuple

import numpy as np

# Public API of this module: the two greedy CTC decoders defined below.
__all__ = [
    "decode_ctc_prediction",
    "decode_ctc_prediction_with_eos",
]


def decode_ctc_prediction(ctc_prediction: np.ndarray) -> Tuple[list, list]:
    """
    Greedily decode a raw CTC prediction, treating the last class as the *empty* (blank) character.

    Every frame is reduced to its most probable class. A frame emits that class only
    if it is not the blank and differs from the previous frame's class. Runs of the
    same class therefore collapse to one symbol, while a blank between two identical
    classes separates them so that both are emitted.

    :param ctc_prediction: 2d array of time x probabilities
    :return: a tuple of (decoded predictions, indices of emitted characters)
    """
    blank = ctc_prediction.shape[1] - 1
    best = np.argmax(ctc_prediction, axis=-1)

    # Class of the preceding frame; the frame before the first one counts as blank.
    previous = np.empty_like(best)
    previous[:1] = blank
    previous[1:] = best[:-1]

    emitted = (best != blank) & (best != previous)
    return list(best[emitted]), np.flatnonzero(emitted).tolist()


def decode_ctc_prediction_with_eos(ctc_prediction: np.ndarray) -> Tuple[list, list]:
    """
    Same as :py:func:`decode_ctc_prediction`, but cut the prediction (and indices) after the first <EOS>.

    The second-to-last class (``dim - 2``) is treated as the <EOS> character and the
    last class (``dim - 1``) as the CTC blank. If <EOS> is emitted, both the decoded
    prediction and the indices are truncated right after its first occurrence, so
    that <EOS> is kept in the output. If no <EOS> is emitted, both are returned
    unchanged.

    :param ctc_prediction: 2d array of time x probabilities
    :return: a tuple of (decoded predictions, indices of emitted characters)
    """
    eos = ctc_prediction.shape[1] - 2  # dim-2 characters, <EOS> and <blank>
    prediction, indices = decode_ctc_prediction(ctc_prediction)
    try:
        # Single scan for the first <EOS>; +1 keeps the <EOS> itself.
        eos_ix = prediction.index(eos) + 1
    except ValueError:
        # No <EOS> was emitted: nothing to cut.
        return prediction, indices
    return prediction[:eos_ix], indices[:eos_ix]