Source code for blocks.utils.tf
import tensorflow as tf
from . import raise_moved
# Public API of this module. cnn_encoder / cnn_autoencoder are not listed here,
# so `from ... import *` does not export them.
__all__ = ['get_balancing_weights']
def cnn_encoder(*_, **__) -> None:
    """
    Former location of ``cnn_encoder``; kept only to redirect callers.

    Every positional and keyword argument is accepted and ignored. The call is handed to
    ``raise_moved`` together with the new module path ``emloop_tensorflow.models.conv``
    (presumably ``raise_moved`` raises an error telling the caller where to import from - confirm).
    """
    moved_to = 'emloop_tensorflow.models.conv'
    raise_moved('cnn_encoder', moved_to)
def cnn_autoencoder(*_, **__) -> None:
    """
    Former location of ``cnn_autoencoder``; kept only to redirect callers.

    Every positional and keyword argument is accepted and ignored. The call is handed to
    ``raise_moved`` together with the new module path ``emloop_tensorflow.models.conv``
    (presumably ``raise_moved`` raises an error telling the caller where to import from - confirm).
    """
    # Report this function's own name; it previously passed 'cnn_encoder' (copy-paste from the
    # stub above), so the error message pointed users at the wrong function.
    raise_moved('cnn_autoencoder', 'emloop_tensorflow.models.conv')
def get_balancing_weights(masks: tf.Tensor, correction: float = 1e-3) -> tf.Tensor:
    """
    Compute weights balancing zeros and ones in the given masks tensor.

    Positive (one) pixels are up-weighted by the ratio ``mean(negatives) / (mean(positives) + correction)``
    so that positives and negatives have (approximately) equal total impact. The weights are then divided by
    their mean, so a pixel has weight 1 on average and the overall loss scale is unchanged.

    If the masks contain no zeros at all, the normalization coefficient is exactly zero. There is then nothing
    to balance, and all-ones weights are returned instead of the NaNs a 0/0 division would produce.

    .. warning::
        The masks tensor must be binary - i.e., contain only ones and zeros!

    :param masks: zero/one masks to be balanced
    :param correction: correction parameter to avoid divergence of weights
    :return: weights balancing the given mask (having the same shape)
    """
    positives = tf.cast(masks, tf.float32)
    negatives = tf.ones_like(positives) - positives

    # Make the impact of the positive and negative pixels equal.
    # Example: if every 10th pixel is positive, multiply the weight of positive pixels by 9.
    positives_weight = positives * tf.reduce_mean(negatives) / (tf.reduce_mean(positives) + correction)
    negatives_weight = negatives

    # Normalize the weights, so that the weight for a pixel is 1 on average (to avoid changing the loss).
    norm_coef = tf.reduce_mean(positives_weight + negatives_weight)

    def _normalized() -> tf.Tensor:
        # Same arithmetic (and operation order) as the unguarded normalization.
        return (positives_weight / norm_coef) * positives + (negatives_weight / norm_coef) * negatives

    def _uniform() -> tf.Tensor:
        # norm_coef == 0 only when there are no negative pixels (mean(negatives) == 0 zeroes positives_weight);
        # dividing would yield 0/0 = NaN everywhere, so fall back to neutral weights of 1.
        return tf.ones_like(positives)

    return tf.cond(tf.not_equal(norm_coef, 0.), _normalized, _uniform)