Source code for blocks.models.segmentation

"""
Module with binary segmentation auto-encoder model.
"""
import logging
from typing import Mapping, Tuple, Sequence, Union

import tensorflow as tf
import tensorflow.contrib.slim as slim
import emloop_tensorflow as eltf

from ..utils import get_balancing_weights


class BinarySegmentation(eltf.BaseModel):
    """
    Configurable binary segmentation auto-encoder with skip-connections and stuff.

    The segmentation works in parallel with multiple masks, thus, the following outputs are named accordingly.

    **Inputs**

    - ``images`` (4-dim tensor NHWC) scaled to 0-255
    - ``<name>`` (3-dim tensor NHW) scaled to 0/255 for each <name> in ``mask_names``

    **Outputs**

    - ``<name>_probabilities`` and ``<name>_predictions`` (3-dim tensor NHW) scaled to 0-1 and 0/1 respectively
      for each <name> in ``mask_names``
    - ``loss`` and ``<name>_pixel_loss`` optimization targets for each <name> in ``mask_names``
    - ``<name>_f1``, ``<name>_recall`` and ``<name>_precision`` performance measures
      for each <name> in ``mask_names``

    **Requirements**

    - The dataset has to provide ``img_shape()`` method returning a 2- or 3- tuple or list with the image shape.
      Only the channel dimension needs to be specified, other values are ignored.

    .. code-block:: yaml
        :caption: example usage in config

        model:
          name: SegmentationNet
          class: blocks.models.BinarySegmentation

          input_name: images
          mask_names: [masks, masks_eroded]

          architecture:
            encoder_config: [16c3, 16c3, 16c3, 16c3, mp2, 32c3, 32c3, 32c3, 32c3, mp2, 64c3, 64c3, 64c3]
            use_bn: true
            use_ln: false
            skip_connections: true
          l2: 0.00001
          balance_loss: false

          optimizer:
            class: AdamOptimizer
            learning_rate: 0.0001

          inputs: [images, masks, masks_eroded]
          outputs: [loss, masks_predictions, masks_probabilities, masks_f1,
                    masks_eroded_predictions, masks_eroded_probabilities, masks_eroded_f1]

    """

    def _create_model(self, architecture: Mapping, loss_type: str = 'mse',
                      balance_loss: Union[bool, str, Mapping[str, str]] = False, l2: float = 0.0,
                      input_name: str = 'images', mask_names: Sequence[str] = ('masks',),
                      final_kernel: Tuple[int, int] = (5, 5)) -> None:
        """
        Create new binary segmentation auto-encoder.

        :param architecture: architecture configuration; passed as keyword arguments
                             to ``eltf.models.cnn_autoencoder``
        :param loss_type: loss type (either ``mse``, ``l1``, or ``xtropy``)
        :param balance_loss: 0/1 pixel loss balancing.
                             If false, all pixel losses will remain untouched.
                             If true, each pixel loss will be balanced according to the corresponding mask.
                             If string, all pixel losses will be balanced according to the mask identified
                             by the string.
                             If mapping, each pixel loss ``l`` will be balanced by ``balance_loss[l]``;
                             pixel losses not present in the mapping will not be balanced at all.
        :param l2: l2 weights regularization rate
        :param input_name: stream source name providing the input images
        :param mask_names: sequence of stream source names providing the target segmentations
        :param final_kernel: kernel size of the final convolution
        :raise ValueError: if ``mask_names`` is empty, ``loss_type`` is not recognized, ``balance_loss`` has
                           an unsupported type, or ``balance_loss`` refers to a mask not in ``mask_names``
        """
        if not mask_names:
            # fail early with a clear message instead of somewhere deep in the graph construction
            raise ValueError('At least one mask name must be specified in `mask_names`.')

        def check_balancing_mask(balancing_mask_name: str) -> None:
            """Raise ``ValueError`` if the given balancing mask is not one of the model masks."""
            # explicit raise (instead of assert) so that the check survives ``python -O``
            if balancing_mask_name not in mask_names:
                raise ValueError('Mask `{}` not found in mask_names (`{}`). Balancing aborted.'
                                 .format(balancing_mask_name, mask_names))

        # 1. inputs
        img_shape = list(self._dataset.img_shape())
        # only the channel dimension (if any) is taken from the dataset; batch, height and width stay dynamic
        images = tf.placeholder(dtype=tf.float32, shape=[None, None, None] + img_shape[2:], name=input_name) / 255
        # 0/255 masks are rescaled and rounded to 0/1 integer labels
        masks = {mask_name: tf.cast(
                     tf.round(tf.placeholder(dtype=tf.int32, shape=(None, None, None), name=mask_name) / 255),
                     tf.int32)
                 for mask_name in mask_names}

        # If input shape is without channel, create last channel of size 1
        if len(img_shape) == 2:
            images = tf.expand_dims(images, -1)

        # 2. auto-encoder, output probabilities and pixel losses
        logging.info('\tBuilding auto-encoder')
        if l2 > 0:
            logging.info('\tUsing weights l2 regularization: %s', l2)

        # the sigmoid-based losses differ only in the loss function and the logged label
        regression_losses = {'mse': ('MSE', tf.losses.mean_squared_error),
                             'l1': ('L1', tf.losses.absolute_difference)}

        with slim.arg_scope([slim.conv2d], weights_regularizer=slim.l2_regularizer(l2),
                            weights_initializer=slim.variance_scaling_initializer(1/2., 'FAN_AVG')):
            with tf.variable_scope('CNN_autoencoder'):
                _, net = eltf.models.cnn_autoencoder(x=images, is_training=self.is_training, **architecture)

            # NOTE(review): the ``<name>_pixel_loss`` ops below are created inside the ``CNN_final`` scope and
            # before balancing; confirm that this is the tensor meant by the documented ``<name>_pixel_loss``.
            with tf.variable_scope('CNN_final'):
                if loss_type in regression_losses:
                    loss_label, loss_fn = regression_losses[loss_type]
                    logging.info('\tUsing %s loss', loss_label)
                    # one sigmoid output channel per mask
                    net = slim.conv2d(net, len(mask_names), final_kernel, activation_fn=tf.nn.sigmoid,
                                      scope='cnn_final_inner')
                    probabilities = {mask_name: tf.identity(net[:, :, :, i],
                                                            name='{}_probabilities'.format(mask_name))
                                     for i, mask_name in enumerate(mask_names)}
                    pixel_losses = {mask_name: tf.identity(
                                        loss_fn(labels=masks[mask_name],
                                                predictions=probabilities[mask_name],
                                                reduction=tf.losses.Reduction.NONE),
                                        name='{}_pixel_loss'.format(mask_name))
                                    for mask_name in mask_names}
                elif loss_type == 'xtropy':
                    logging.info('\tUsing cross-entropy loss')
                    # two logit channels per mask; the last softmax channel is taken as the mask probability
                    net = slim.conv2d(net, 2 * len(mask_names), final_kernel, activation_fn=tf.identity,
                                      scope='cnn_final_inner')
                    probabilities = {mask_name: tf.identity(
                                         tf.nn.softmax(net[:, :, :, 2*i:2*(i+1)])[:, :, :, -1],
                                         name='{}_probabilities'.format(mask_name))
                                     for i, mask_name in enumerate(mask_names)}
                    pixel_losses = {mask_name: tf.identity(
                                        tf.nn.sparse_softmax_cross_entropy_with_logits(
                                            labels=masks[mask_name], logits=net[:, :, :, 2*i:2*(i+1)]),
                                        name='{}_pixel_loss'.format(mask_name))
                                    for i, mask_name in enumerate(mask_names)}
                else:
                    raise ValueError('Unrecognized loss_type `{}`. Recognized types are `mse`, `l1`, and `xtropy`'
                                     .format(loss_type))

        # name the probabilities once more, this time outside of the ``CNN_final`` scope
        probabilities = {mask_name: tf.identity(probabilities[mask_name], name='{}_probabilities'.format(mask_name))
                         for mask_name in mask_names}

        # 3. pixel loss balancing
        if isinstance(balance_loss, bool):  # make sure `balance_loss=1` is an invalid option
            if balance_loss:
                logging.info('Balancing all pixel losses by their corresponding masks.')
                for mask_name in mask_names:
                    pixel_losses[mask_name] *= get_balancing_weights(masks[mask_name])
            else:
                logging.info('The balancing of all pixel losses is turned off.')
        elif isinstance(balance_loss, str):
            logging.info('Balancing all pixel losses by a single mask: `%s`.', balance_loss)
            check_balancing_mask(balance_loss)
            for mask_name in mask_names:
                pixel_losses[mask_name] *= get_balancing_weights(masks[balance_loss])
        elif isinstance(balance_loss, Mapping):
            logging.info('Balancing the pixel losses by mapping: `%s`.', balance_loss)
            for mask_name in mask_names:
                if mask_name in balance_loss:
                    balance_mask_name = balance_loss[mask_name]
                    logging.debug('\tBalancing mask `%s` by `%s`.', mask_name, balance_mask_name)
                    check_balancing_mask(balance_mask_name)
                    pixel_losses[mask_name] *= get_balancing_weights(masks[balance_mask_name])
                else:
                    logging.warning('\tCannot balance mask `%s`; missing entry in the mapping.', mask_name)
        else:
            raise ValueError('Unsupported balance_loss type `{}`. Supported types are {{bool, str, Mapping[str, str]}}'
                             .format(type(balance_loss)))

        # 4. final loss aggregation
        # average over the masks first, then over the two remaining non-batch axes
        pixel_losses_concated = tf.stack([pixel_losses[mask_name] for mask_name in mask_names], axis=0)
        pixel_loss = tf.reduce_mean(pixel_losses_concated, axis=0, name='pixel_loss')
        tf.reduce_mean(pixel_loss, axis=(1, 2), name='loss')

        # 5. predictions, statistics
        # the ``<name>_predictions`` name is attached to the rounding op; the int32 cast is used for the statistics
        predictions = {mask_name: tf.cast(tf.round(probabilities[mask_name],
                                                   name='{}_predictions'.format(mask_name)),
                                          tf.int32)
                       for mask_name in mask_names}
        for mask_name in mask_names:
            eltf.bin_stats(predictions=tf.layers.flatten(predictions[mask_name]),
                           labels=tf.layers.flatten(masks[mask_name]),
                           prefix=mask_name)