"""
Module with binary segmentation auto-encoder model.
"""
import logging
from typing import Mapping, Tuple, Sequence, Union
import tensorflow as tf
import tensorflow.contrib.slim as slim
import emloop_tensorflow as eltf
from ..utils import get_balancing_weights
class BinarySegmentation(eltf.BaseModel):
    """
    Configurable binary segmentation auto-encoder with skip-connections and stuff.

    The segmentation works in parallel with multiple masks, thus, the following outputs are named accordingly.

    **Inputs**

    - ``images`` (4-dim tensor NHWC) scaled to 0-255
    - ``<name>`` (3-dim tensor NHW) scaled to 0/255 for each <name> in ``mask_names``

    **Outputs**

    - ``<name>_probabilities`` and ``<name>_predictions`` (3-dim tensor NHW) scaled to 0-1 and 0/1 respectively
      for each <name> in ``mask_names``
    - ``loss`` and ``<name>_pixel_loss`` optimization targets for each <name> in ``mask_names``
    - ``<name>_f1``, ``<name>_recall`` and ``<name>_precision`` performance measures for each <name>
      in ``mask_names``

    **Requirements**

    - The dataset has to provide ``img_shape()`` method returning a 2- or 3- tuple or list with the image shape.
      Only the channel dimension needs to be specified, other values are ignored.

    .. code-block:: yaml
        :caption: example usage in config

        model:
          name: SegmentationNet
          class: blocks.models.BinarySegmentation
          input_name: images
          mask_names: [masks, masks_eroded]

          architecture:
            encoder_config: [16c3, 16c3, 16c3, 16c3, mp2,
                             32c3, 32c3, 32c3, 32c3, mp2,
                             64c3, 64c3, 64c3]
            use_bn: true
            use_ln: false
            skip_connections: true

          l2: 0.00001
          balance_loss: false

          optimizer:
            class: AdamOptimizer
            learning_rate: 0.0001

          inputs: [images, masks, masks_eroded]
          outputs: [loss,
                    masks_predictions, masks_probabilities, masks_f1,
                    masks_eroded_predictions, masks_eroded_probabilities, masks_eroded_f1]
    """

    def _create_model(self,
                      architecture: Mapping,
                      loss_type: str = 'mse',
                      balance_loss: Union[bool, str, Mapping[str, str]] = False,
                      l2: float = 0.0,
                      input_name: str = 'images',
                      mask_names: Sequence[str] = ('masks',),
                      final_kernel: Tuple[int, int] = (5, 5)) -> None:
        """
        Create new binary segmentation auto-encoder.

        :param architecture: architecture configuration; forwarded as keyword arguments to
                             ``eltf.models.cnn_autoencoder``
        :param loss_type: loss type (either ``mse``, ``l1``, or ``xtropy``)
        :param balance_loss: 0/1 pixel loss balancing. If false, all pixel losses will remain untouched. If true,
                             each pixel loss will be balanced according to the corresponding mask. If string, all
                             pixel losses will be balanced according to the mask identified by the string. If mapping,
                             each pixel loss ``l`` will be balanced by ``balance_loss[l]``; pixel losses not present
                             in the mapping will not be balanced at all.
        :param l2: l2 weights regularization rate
        :param input_name: stream source name providing the input images
        :param mask_names: sequence of stream source names providing the target segmentations
        :param final_kernel: kernel size of the final convolution
        :raise ValueError: if ``loss_type`` is not recognized or ``balance_loss`` is of an unsupported type
        :raise AssertionError: if ``balance_loss`` refers to a mask which is not listed in ``mask_names``
        """
        # 1. inputs
        img_shape = list(self._dataset.img_shape())
        # only the channel dimension (if any) is taken from the dataset; images are rescaled from 0-255 to 0-1
        images = tf.placeholder(dtype=tf.float32, shape=[None, None, None] + img_shape[2:], name=input_name) / 255
        # masks are fed as 0/255 integers; dividing by 255 and rounding turns them into 0/1 integer labels
        masks = {mask_name: tf.cast(
                     tf.round(tf.placeholder(dtype=tf.int32, shape=(None, None, None), name=mask_name) / 255),
                     tf.int32)
                 for mask_name in mask_names}

        # If input shape is without channel, create last channel of size 1
        if len(img_shape) == 2:
            images = tf.expand_dims(images, -1)

        # 2. auto-encoder, output probabilities and pixel losses
        logging.info('\tBuilding auto-encoder')
        if l2 > 0:
            logging.info('\tUsing weights l2 regularization: %s', l2)

        with slim.arg_scope([slim.conv2d], weights_regularizer=slim.l2_regularizer(l2),
                            weights_initializer=slim.variance_scaling_initializer(1/2., 'FAN_AVG')):
            with tf.variable_scope('CNN_autoencoder'):
                _, net = eltf.models.cnn_autoencoder(x=images, is_training=self.is_training, **architecture)

            with tf.variable_scope('CNN_final'):
                if loss_type in ('mse', 'l1'):
                    # both regression losses share the same sigmoid head (one channel per mask)
                    # and differ only in the per-pixel loss function
                    if loss_type == 'mse':
                        logging.info('\tUsing MSE loss')
                        regression_loss = tf.losses.mean_squared_error
                    else:
                        logging.info('\tUsing L1 loss')
                        regression_loss = tf.losses.absolute_difference
                    net = slim.conv2d(net, len(mask_names), final_kernel, activation_fn=tf.nn.sigmoid,
                                      scope='cnn_final_inner')
                    probabilities = {mask_name: tf.identity(net[:, :, :, i],
                                                            name='{}_probabilities'.format(mask_name))
                                     for i, mask_name in enumerate(mask_names)}
                    # Reduction.NONE keeps the loss per pixel so that it can be balanced below
                    pixel_losses = {mask_name: tf.identity(
                                        regression_loss(labels=masks[mask_name],
                                                        predictions=probabilities[mask_name],
                                                        reduction=tf.losses.Reduction.NONE),
                                        name='{}_pixel_loss'.format(mask_name))
                                    for mask_name in mask_names}
                elif loss_type == 'xtropy':
                    logging.info('\tUsing cross-entropy loss')
                    # two logits per mask; channels [2*i, 2*i+1] belong to the i-th mask
                    net = slim.conv2d(net, 2 * len(mask_names), final_kernel, activation_fn=tf.identity,
                                      scope='cnn_final_inner')
                    # the probability of the mask is the last (class 1) channel of the per-mask softmax
                    probabilities = {mask_name: tf.identity(tf.nn.softmax(net[:, :, :, 2*i:2*(i+1)])[:, :, :, -1],
                                                            name='{}_probabilities'.format(mask_name))
                                     for i, mask_name in enumerate(mask_names)}
                    pixel_losses = {mask_name: tf.identity(
                                        tf.nn.sparse_softmax_cross_entropy_with_logits(
                                            labels=masks[mask_name],
                                            logits=net[:, :, :, 2*i:2*(i+1)]),
                                        name='{}_pixel_loss'.format(mask_name))
                                    for i, mask_name in enumerate(mask_names)}
                else:
                    raise ValueError('Unrecognized loss_type `{}`. Recognized types are `mse`, `l1`, and `xtropy`'
                                     .format(loss_type))

        # Name the probabilities once more outside of the scopes above, presumably so that they are reachable
        # under the plain `<name>_probabilities` names.
        # NOTE(review): the `<name>_pixel_loss` tensors are not re-named this way and the balanced pixel losses
        # computed below carry no explicit name - confirm this is what the consumers of these outputs expect.
        probabilities = {mask_name: tf.identity(probabilities[mask_name], name='{}_probabilities'.format(mask_name))
                         for mask_name in mask_names}

        # 3. pixel loss balancing
        if isinstance(balance_loss, bool):  # make sure `balance_loss=1` is an invalid option
            if balance_loss:
                logging.info('Balancing all pixel losses by their corresponding masks.')
                for mask_name in mask_names:
                    pixel_losses[mask_name] *= get_balancing_weights(masks[mask_name])
            else:
                logging.info('The balancing of all pixel losses is turned off.')
        elif isinstance(balance_loss, str):
            logging.info('Balancing all pixel losses by a single mask: `%s`.', balance_loss)
            # raised explicitly (instead of using `assert`) so that the check is not stripped under `python -O`;
            # the exception type is kept for backward compatibility
            if balance_loss not in mask_names:
                raise AssertionError('Mask `{}` not found in mask_names (`{}`). Balancing aborted.'
                                     .format(balance_loss, mask_names))
            for mask_name in mask_names:
                pixel_losses[mask_name] *= get_balancing_weights(masks[balance_loss])
        elif isinstance(balance_loss, Mapping):
            logging.info('Balancing the pixel losses by mapping: `%s`.', balance_loss)
            for mask_name in mask_names:
                if mask_name in balance_loss:
                    balance_mask_name = balance_loss[mask_name]
                    logging.debug('\tBalancing mask `%s` by `%s`.', mask_name, balance_mask_name)
                    # explicit raise for the same reason as in the `str` branch above
                    if balance_mask_name not in mask_names:
                        raise AssertionError('Mask `{}` not found in mask_names (`{}`). Balancing aborted.'
                                             .format(balance_mask_name, mask_names))
                    pixel_losses[mask_name] *= get_balancing_weights(masks[balance_mask_name])
                else:
                    logging.warning('\tCannot balance mask `%s`; missing entry in the mapping.', mask_name)
        else:
            raise ValueError('Unsupported balance_loss type `{}`. Supported types are {{bool, str, Mapping[str, str]}}'
                             .format(type(balance_loss)))

        # 4. final loss aggregation
        # average over the masks first (axis 0 of the stack), then over both spatial dimensions;
        # the resulting `loss` keeps the batch dimension and is registered in the graph by its name only
        pixel_losses_concated = tf.stack([pixel_losses[mask_name] for mask_name in mask_names], axis=0)
        pixel_loss = tf.reduce_mean(pixel_losses_concated, axis=0, name='pixel_loss')
        tf.reduce_mean(pixel_loss, axis=(1, 2), name='loss')

        # 5. predictions, statistics
        # the `<name>_predictions` name is attached to the rounded probabilities; the int32 cast of them
        # is what gets passed to the statistics below
        predictions = {mask_name: tf.cast(tf.round(probabilities[mask_name],
                                                   name='{}_predictions'.format(mask_name)),
                                          tf.int32)
                       for mask_name in mask_names}
        for mask_name in mask_names:
            # presumably creates the `<name>_f1`, `<name>_recall` and `<name>_precision` outputs listed in the
            # class docstring - TODO confirm against `eltf.bin_stats`
            eltf.bin_stats(predictions=tf.layers.flatten(predictions[mask_name]),
                           labels=tf.layers.flatten(masks[mask_name]),
                           prefix=mask_name)