# Source code for blocks.hooks.compute_ctc_stats
import numpy as np
from emloop.hooks import AccumulateVariables
from ..utils.ctc import decode_ctc_prediction, decode_ctc_prediction_with_eos
class ComputeCTCStats(AccumulateVariables):
    """
    Accumulate CTC labels, labels_mask and predictions in order to compute sentence error rate and mean edit
    distance after each epoch.

    The accumulated variables are expected to be:

    - **labels**: 2d array of zero-padded target labels
    - **labels_mask**: 2d array of labels masks wherein ones mark the valid values
    - **predictions**: 3d array of predictions of shape (batch x position x probs) wherein probs have
      dim ``num_classes + 1``

    .. code-block:: yaml
        :caption: compute ctc stats for each epoch

        hooks:
          - blocks.hooks.ComputeCTCStats
    """

    def __init__(self, labels: str='labels', predictions: str='predictions', labels_mask: str='labels_mask',
                 decode_fn=(lambda x: decode_ctc_prediction(x)[0]), **kwargs):
        """
        Create new ComputeCTCStats hook.

        :param labels: ``labels`` variable name
        :param predictions: ``predictions`` variable name
        :param labels_mask: ``labels_mask`` variable name
        :param decode_fn: function for decoding raw CTC output of a single example; by default the first item
                          returned by :py:func:`decode_ctc_prediction`
        :param kwargs: additional keyword arguments passed to the parent hook
        """
        self._decode_fn = decode_fn
        self._labels = labels
        self._predictions = predictions
        self._labels_mask = labels_mask
        # the parent hook is told which variables to accumulate
        super().__init__((labels, predictions, labels_mask), **kwargs)

    @staticmethod
    def _edit_distance(s1, s2) -> int:
        """
        Compute Levenshtein edit distance between two sequences.

        The sequences may be strings, lists or 1d arrays; their items are compared with ``==``.

        :param s1: first sequence
        :param s2: second sequence
        :return: minimal number of single-item insertions, deletions and substitutions turning ``s1`` into ``s2``
        """
        # previous_row[j] is the distance between the already processed prefix of s1 and s2[:j];
        # for the empty prefix of s1 that is j insertions
        previous_row = list(range(len(s2) + 1))
        for i, item1 in enumerate(s1, start=1):
            # distance between s1[:i] and the empty prefix of s2 is i deletions
            current_row = [i]
            for j, item2 in enumerate(s2, start=1):
                cost = 0 if item1 == item2 else 1
                current_row.append(min(current_row[j - 1] + 1,       # insertion
                                       previous_row[j] + 1,          # deletion
                                       previous_row[j - 1] + cost))  # match / substitution
            previous_row = current_row
        # the last cell is read explicitly, so the result does not depend on loop variables
        # and is well-defined when either sequence is empty
        return previous_row[-1]

    def after_epoch(self, epoch_data, **kwargs) -> None:
        """
        Compute sentence error rate and mean edit distance for each stream and save it to the ``epoch_data`` as
        ``sentence_error`` and ``edit_distance``, respectively.

        :param epoch_data: epoch data
        :param kwargs: additional keyword arguments passed to the parent ``after_epoch``
        """
        for stream in epoch_data:
            stream_data = self._accumulator[stream]
            # keep only the label positions marked as valid by the mask (drops the padding)
            labels = [label[mask > 0]
                      for label, mask in zip(stream_data[self._labels], stream_data[self._labels_mask])]
            predictions = [self._decode_fn(prediction) for prediction in stream_data[self._predictions]]
            pairs = list(zip(labels, predictions))

            # a sentence counts as correct only if the decoded prediction equals the label exactly
            epoch_data[stream]['sentence_error'] = 1 - np.mean([np.array_equal(label, prediction)
                                                                for label, prediction in pairs])
            epoch_data[stream]['edit_distance'] = np.mean([ComputeCTCStats._edit_distance(label, prediction)
                                                           for label, prediction in pairs])
        super().after_epoch(epoch_data=epoch_data, **kwargs)
class ComputeCTCStatsWithEOS(ComputeCTCStats):
    """
    Variant of :py:class:`ComputeCTCStats` which decodes raw predictions with
    :py:func:`decode_ctc_prediction_with_eos`.
    """

    def __init__(self, **kwargs):
        """
        Create new ComputeCTCStatsWithEOS hook.

        :param kwargs: keyword arguments passed to :py:class:`ComputeCTCStats`
        """
        def decode_with_eos(raw_prediction):
            # only the first item returned by the decoder is used
            return decode_ctc_prediction_with_eos(raw_prediction)[0]

        super().__init__(decode_fn=decode_with_eos, **kwargs)