# Source code for blocks.hooks.compute_ctc_stats

from typing import Callable, Sequence

import numpy as np
from emloop.hooks import AccumulateVariables

from ..utils.ctc import decode_ctc_prediction, decode_ctc_prediction_with_eos


class ComputeCTCStats(AccumulateVariables):
    """
    Accumulate CTC labels, labels_mask and predictions in order to compute
    sentence error rate and mean edit distance after each epoch.

    - **labels**: 2d array of zero-padded target labels
    - **labels_mask**: 2d array of labels masks wherein ones mark the valid values
    - **predictions**: 3d array of predictions of shape (batch x position x probs)
      wherein probs have dim ``num_classes + 1``

    .. code-block:: yaml
        :caption: compute ctc stats for each epoch

        hooks:
          - blocks.hooks.ComputeCTCStats

    """

    def __init__(self,
                 labels: str = 'labels',
                 predictions: str = 'predictions',
                 labels_mask: str = 'labels_mask',
                 decode_fn: Callable = (lambda x: decode_ctc_prediction(x)[0]),
                 **kwargs):
        """
        Create new ComputeCTCStats hook.

        :param labels: ``labels`` variable name
        :param predictions: ``predictions`` variable name
        :param labels_mask: ``labels_mask`` variable name
        :param decode_fn: function for decoding raw CTC output of a single example; by default the first
                          item of the :py:func:`decode_ctc_prediction` result
        :param kwargs: forwarded to the parent hook constructor
        """
        self._decode_fn = decode_fn
        self._labels = labels
        self._predictions = predictions
        self._labels_mask = labels_mask
        # the parent hook is given the names of the variables to accumulate
        super().__init__((labels, predictions, labels_mask), **kwargs)

    @staticmethod
    def _edit_distance(s1: Sequence, s2: Sequence) -> int:
        """
        Compute Levenshtein edit distance between two sequences.

        Only ``len()``, iteration and element equality are required, so strings, lists and 1d arrays
        are all accepted.

        :param s1: first sequence
        :param s2: second sequence
        :return: minimal number of insertions, deletions and substitutions turning ``s1`` into ``s2``
        """
        # previous[j] holds the distance between the already processed prefix of s1 and s2[:j];
        # only two rows of the DP table are ever needed
        previous = list(range(len(s2) + 1))
        for i, item1 in enumerate(s1, start=1):
            current = [i]  # turning s1[:i] into an empty sequence takes i deletions
            for j, item2 in enumerate(s2, start=1):
                cost = 0 if item1 == item2 else 1
                current.append(min(current[j - 1] + 1,        # insertion
                                   previous[j] + 1,           # deletion
                                   previous[j - 1] + cost))   # match / substitution
            previous = current
        # the last cell is returned explicitly so that empty inputs need no special handling
        return previous[-1]

    def after_epoch(self, epoch_data, **kwargs) -> None:
        """
        Compute sentence error rate and mean edit distance for each stream and save it to the ``epoch_data``
        as ``sentence_error`` and ``edit_distance``, respectively.

        :param epoch_data: epoch data
        :param kwargs: forwarded to the parent ``after_epoch``
        """
        for stream in epoch_data:
            stream_data = self._accumulator[stream]
            # strip the padding: keep only the positions where the mask is positive
            labels = [label[mask > 0]
                      for label, mask in zip(stream_data[self._labels], stream_data[self._labels_mask])]
            predictions = [self._decode_fn(prediction) for prediction in stream_data[self._predictions]]
            pairs = list(zip(labels, predictions))

            # NOTE: when nothing was accumulated for a stream, both means are taken over an empty list
            # and numpy yields NaN for them
            exact_matches = [np.array_equal(label, prediction) for label, prediction in pairs]
            epoch_data[stream]['sentence_error'] = 1 - np.mean(exact_matches)
            epoch_data[stream]['edit_distance'] = np.mean([self._edit_distance(label, prediction)
                                                           for label, prediction in pairs])
        # called once, after every stream has been processed
        super().after_epoch(epoch_data=epoch_data, **kwargs)
class ComputeCTCStatsWithEOS(ComputeCTCStats):
    """Variant of :py:class:`ComputeCTCStats` which decodes with :py:func:`decode_ctc_prediction_with_eos`."""

    def __init__(self, **kwargs):
        """
        Create new ComputeCTCStatsWithEOS hook.

        :param kwargs: forwarded to the :py:class:`ComputeCTCStats` constructor
        """
        def _decode_first(raw_prediction):
            # only the first item of the decoder result is used
            return decode_ctc_prediction_with_eos(raw_prediction)[0]

        super().__init__(decode_fn=_decode_first, **kwargs)