How to choose cross-entropy loss in TensorFlow?


Classification problems, such as logistic regression or multinomial
logistic regression, optimize a cross-entropy loss.
Normally, the cross-entropy layer follows the softmax layer,
which produces a probability distribution.

In TensorFlow, there are at least a dozen different cross-entropy loss functions:

  • tf.losses.softmax_cross_entropy
  • tf.losses.sparse_softmax_cross_entropy
  • tf.losses.sigmoid_cross_entropy
  • tf.contrib.losses.softmax_cross_entropy
  • tf.contrib.losses.sigmoid_cross_entropy
  • tf.nn.softmax_cross_entropy_with_logits
  • tf.nn.sigmoid_cross_entropy_with_logits

Which ones work only for binary classification, and which are suitable for multi-class problems? When should you use sigmoid instead of softmax? How are the sparse functions different from the others, and why do they exist only for softmax?

Related (more math-oriented) discussion: What are the differences between all these cross-entropy losses in Keras and TensorFlow?

Preliminary facts

  • In a functional sense, the sigmoid is a special case of the softmax function when the number of classes equals 2. Both of them do the same operation: transform the logits (see below) into probabilities.

    In simple binary classification, there’s no big difference between the two;
    however, in multinomial classification, sigmoid allows dealing
    with non-exclusive labels (a.k.a. multi-labels), while softmax deals
    with exclusive classes (see below).

  • A logit (also called a score) is a raw unscaled value associated with a class, before computing the probability. In terms of neural network architecture, this means that a logit is an output of a dense (fully-connected) layer.

    TensorFlow naming is a bit strange: all of the functions below accept logits, not probabilities, and apply the transformation themselves (which is more efficient and numerically more stable).
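The claim that the sigmoid is the 2-class case of the softmax can be checked numerically. This is a framework-free sketch in plain Python; `sigmoid` and `softmax` are hand-written helpers for illustration, not TensorFlow functions. It relies on the fact that a softmax over the logits [z, 0] gives the first class the probability e^z / (e^z + 1) = 1 / (1 + e^(-z)).

```python
import math

def sigmoid(x):
    # Logistic function: maps a single logit to the probability of class 1.
    return 1.0 / (1.0 + math.exp(-x))

def softmax(logits):
    # Subtract the max logit before exponentiating to avoid overflow.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

# A 2-class softmax over logits [z, 0] assigns the first class
# the same probability as the sigmoid of z.
z = 1.5
p_sigmoid = sigmoid(z)
p_softmax = softmax([z, 0.0])[0]
print(p_sigmoid, p_softmax)
```

More generally, a 2-class softmax only depends on the difference of the two logits, so a single logit (and a sigmoid) is enough for a binary decision.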

Sigmoid functions family

As stated earlier, the sigmoid loss function is for binary classification.
But the TensorFlow functions are more general and allow
multi-label classification, when the classes are independent.
In other words, tf.nn.sigmoid_cross_entropy_with_logits solves N
binary classifications at once.

The labels must be 0/1 per class (one-hot, or multi-hot when several classes can be present at once) or can contain soft class probabilities.
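To make "N binary classifications at once" concrete, here is a plain-Python sketch of the element-wise computation for a single example. It uses the numerically stable expression stated in the TensorFlow documentation for tf.nn.sigmoid_cross_entropy_with_logits, max(x, 0) - x * z + log(1 + exp(-|x|)); the function below is a hand-written stand-in, not the TensorFlow implementation, and the labels and logits are made up.

```python
import math

def sigmoid_cross_entropy_with_logits(labels, logits):
    # Element-wise binary cross-entropy computed directly from logits,
    # using the stable form max(x, 0) - x*z + log(1 + exp(-|x|)).
    return [max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))
            for z, x in zip(labels, logits)]

# One example, three independent classes (e.g. cat, dog, donkey).
# Multi-hot labels: cat and dog are both present, donkey is not.
labels = [1.0, 1.0, 0.0]
logits = [2.0, 0.5, -1.0]
per_class = sigmoid_cross_entropy_with_logits(labels, logits)
print(per_class)       # one binary loss per class
print(sum(per_class))  # total loss for the example
```

Each class contributes its own binary loss and nothing forces the predicted probabilities to sum to 1, which is what allows several classes to be present at once. The same formula accepts soft labels z in [0, 1].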

tf.losses.sigmoid_cross_entropy additionally allows setting in-batch weights,
i.e. making some examples more important than others.
tf.nn.weighted_cross_entropy_with_logits allows setting class weights through its pos_weight argument
(remember, the classification is binary), i.e. making positive errors larger than
negative errors. This is useful when the training data is unbalanced.
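The effect of the class weight can be sketched the same way. TensorFlow documents tf.nn.weighted_cross_entropy_with_logits as pos_weight * z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)); the plain-Python function below evaluates that expression in a numerically stable rearrangement and is an illustrative stand-in, not the library code.

```python
import math

def weighted_cross_entropy_with_logits(labels, logits, pos_weight):
    # Binary cross-entropy whose positive term is scaled by pos_weight:
    #   pos_weight * z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)),
    # rewritten as (1 - z) * x + (1 + (pos_weight - 1) * z) * -log(sigmoid(x)).
    losses = []
    for z, x in zip(labels, logits):
        # Stable -log(sigmoid(x)) = log(1 + exp(-|x|)) + max(-x, 0).
        neg_log_sigmoid = math.log1p(math.exp(-abs(x))) + max(-x, 0.0)
        l = 1.0 + (pos_weight - 1.0) * z
        losses.append((1.0 - z) * x + l * neg_log_sigmoid)
    return losses

labels = [1.0, 0.0]    # one positive, one negative example
logits = [-0.5, -0.5]  # the same prediction for both
print(weighted_cross_entropy_with_logits(labels, logits, pos_weight=1.0))
print(weighted_cross_entropy_with_logits(labels, logits, pos_weight=3.0))
```

With pos_weight=3.0 the loss of the positive example becomes exactly three times larger while the negative example is unaffected, so a pos_weight above 1 penalizes missed positives (false negatives) more heavily.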

Softmax functions family

These loss functions should be used for multinomial mutually exclusive classification,
i.e. picking one out of N classes. They are also applicable when N = 2.

The labels must be one-hot encoded or can contain soft class probabilities:
a particular example can belong to class A with 50% probability and class B
with 50% probability. Note that strictly speaking it doesn’t mean that
it belongs to both classes, but one can interpret the probabilities this way.
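For the softmax family, the loss for one example is -sum_i p_i * log(softmax(x)_i), where p is the label distribution. The sketch below uses hand-written plain-Python helpers rather than TensorFlow calls; it computes log(softmax) through log-sum-exp so that large logits do not overflow, and it runs one-hot and soft labels through the same formula.

```python
import math

def log_softmax(logits):
    # log(softmax(x)) = x - logsumexp(x); subtracting the max keeps exp() in range.
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - lse for v in logits]

def softmax_cross_entropy_with_logits(labels, logits):
    # Cross-entropy between a label distribution and softmax(logits)
    # for a single example: -sum_i labels[i] * log(softmax(logits)[i]).
    return -sum(p * lp for p, lp in zip(labels, log_softmax(logits)))

logits = [2.0, 1.0, 0.1]
one_hot = [1.0, 0.0, 0.0]  # the example belongs to class 0
soft = [0.5, 0.5, 0.0]     # 50% class 0, 50% class 1
print(softmax_cross_entropy_with_logits(one_hot, logits))
print(softmax_cross_entropy_with_logits(soft, logits))
```

Because the formula is linear in the labels, the soft-label loss is the 50/50 mixture of the losses for class 0 and class 1.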

Just like in the sigmoid family, tf.losses.softmax_cross_entropy allows
setting in-batch weights, i.e. making some examples more important than others.
As far as I know, as of TensorFlow 1.3, there’s no built-in way to set class weights.
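Per-example (in-batch) weights simply rescale each example's contribution before the losses are reduced to a single number. In the plain-Python sketch below, normalizing by the total weight is an assumption made for illustration; tf.losses offers several reduction modes, so check which one your version uses. The second call shows a common workaround for the missing class weights: deriving each example's weight from its label through a hypothetical class_weights table. This is not a built-in feature.

```python
def weighted_mean_loss(per_example_losses, weights):
    # Scale each example's loss by its weight, then normalize by the total
    # weight (an assumed reduction; other normalizations are possible).
    total_weight = sum(weights)
    return sum(l * w for l, w in zip(per_example_losses, weights)) / total_weight

# Per-example losses for a batch of 4, e.g. produced by a softmax cross-entropy.
losses = [0.2, 1.5, 0.7, 0.1]
labels = [0, 2, 1, 0]  # hypothetical integer class of each example

# Plain in-batch weights: make example 1 count twice as much.
weights = [1.0, 2.0, 1.0, 1.0]
print(weighted_mean_loss(losses, weights))

# Hypothetical class-weight workaround: look up each example's weight by its
# class and pass the result as ordinary in-batch weights.
class_weights = [1.0, 1.0, 5.0]  # upweight the (rare) class 2
print(weighted_mean_loss(losses, [class_weights[c] for c in labels]))
```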

[UPD] In TensorFlow 1.5, a v2 version was introduced and the original softmax_cross_entropy_with_logits loss was deprecated. The only difference between them is that in the newer version, backpropagation happens into both logits and labels (here’s a discussion of why this may be useful).

Sparse functions family

Like ordinary softmax above, these loss functions should be used for
multinomial mutually exclusive classification, i.e. picking one out of N classes.
The difference is in the label encoding: the classes are specified as integers (class indices in the range [0, N)),
not one-hot vectors. Obviously, this doesn’t allow soft classes, but it
can save some memory when there are thousands or millions of classes.
However, note that the logits argument must still contain a logit for each class,
so it consumes at least [batch_size, classes] memory.
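The relationship between the sparse and the one-hot versions fits in a few lines: the sparse loss reads the log-probability at the integer index, which equals the dense loss with the matching one-hot vector. This is a plain-Python sketch with hand-written helpers, not TensorFlow code; note that the logits list still needs one entry per class.

```python
import math

def log_softmax(logits):
    # Stable log(softmax(x)) via log-sum-exp.
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - lse for v in logits]

def softmax_ce(one_hot_labels, logits):
    # Dense labels: a full vector over all classes.
    return -sum(p * lp for p, lp in zip(one_hot_labels, log_softmax(logits)))

def sparse_softmax_ce(label_index, logits):
    # Sparse label: a single integer in [0, num_classes). Only the
    # log-probability of that class is read, but logits has one value per class.
    return -log_softmax(logits)[label_index]

logits = [0.3, 2.2, -1.0, 0.5]
label = 1
one_hot = [1.0 if i == label else 0.0 for i in range(len(logits))]
print(sparse_softmax_ce(label, logits), softmax_ce(one_hot, logits))
```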

As above, the tf.losses version has a weights argument which allows
setting the in-batch weights.

Sampled softmax functions family

These functions (e.g. tf.nn.sampled_softmax_loss) provide another alternative for dealing with a huge number of classes.
Instead of computing and comparing an exact probability distribution, they compute
a loss estimate from a random sample.

The arguments weights and biases specify a separate fully-connected layer that
is used to compute the logits for a chosen sample.

As above, labels are not one-hot encoded; they have the shape [batch_size, num_true], where num_true is the number of target classes per example.

Sampled functions are only suitable for training. At test time, it’s recommended to
use a standard softmax loss (either sparse or one-hot) to get an actual distribution.
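A toy version of the sampled idea, in plain Python. Assumptions: the output layer is given as hypothetical `weights`/`biases` lists, each example has one true class (num_true = 1), and negatives are drawn uniformly without replacement from the non-target classes. tf.nn.sampled_softmax_loss differs in its details: by default it uses a log-uniform sampler, subtracts a log-expected-count correction from the logits, and removes accidental hits. Under a uniform sampler that correction is the same for every candidate and cancels inside the softmax, which is why this sketch omits it.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def log_sum_exp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def full_softmax_loss(weights, biases, h, label):
    # Exact loss: logits for all classes from the output layer (W, b).
    logits = [dot(w, h) + b for w, b in zip(weights, biases)]
    return log_sum_exp(logits) - logits[label]

def sampled_softmax_loss_sketch(weights, biases, h, label, num_sampled, rng):
    # Only the true class and a few uniformly sampled negatives get logits.
    negatives = rng.sample([c for c in range(len(weights)) if c != label], num_sampled)
    candidates = [label] + negatives
    logits = [dot(weights[c], h) + biases[c] for c in candidates]
    # Softmax cross-entropy over the candidates, true class at position 0.
    return log_sum_exp(logits) - logits[0]

rng = random.Random(0)
num_classes, dim = 1000, 8
weights = [[rng.gauss(0.0, 0.1) for _ in range(dim)] for _ in range(num_classes)]
biases = [0.0] * num_classes
h = [rng.gauss(0.0, 1.0) for _ in range(dim)]  # hidden representation of one example
label = 42

# Training-time estimates from 20 negatives vs. the exact loss over 1000 classes.
estimates = [sampled_softmax_loss_sketch(weights, biases, h, label, 20, rng) for _ in range(5)]
print(estimates)
print(full_softmax_loss(weights, biases, h, label))
```

Because the sampled loss sums over only a subset of the classes (always including the true one), it is never larger than the exact loss for the same example, which is one reason to switch to the full softmax when evaluating.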

Another alternative loss is tf.nn.nce_loss, which performs noise-contrastive estimation (if you’re interested, see this very detailed discussion). I’ve included this function in the softmax family because NCE guarantees approximation to softmax in the limit.

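However, from version 1.5 on, softmax_cross_entropy_with_logits_v2 should be used instead, and its arguments must be passed by keyword (key=...), for example:

import tensorflow as tf  # TensorFlow 1.x, version 1.5 or later
# y: one-hot (or soft) labels; my_prediction: unscaled logits of the same shape
loss = tf.nn.softmax_cross_entropy_with_logits_v2(
    labels=y, logits=my_prediction, dim=-1)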

While it is great that the accepted answer contains a lot more information than was asked for, I felt that sharing a few general rules of thumb would make the answer more compact and intuitive:

  • There is just one real loss function: cross-entropy (CE). For the special case of binary classification, this loss is called binary CE (note that the formula does not change), and for non-binary or multi-class situations the same loss is called categorical CE (CCE). Sparse functions are a special case of categorical CE where the expected values are not one-hot encoded but are integers.
  • We have the softmax formula, which is an activation for the multi-class scenario. For the binary scenario, the same formula is given a special name: sigmoid activation.
  • Because there are sometimes numerical instabilities (for extreme values) when dealing with logarithmic functions, TF recommends combining the activation layer and the loss layer into one single function. This combined function is numerically more stable. TF provides these combined functions, and they are suffixed with _with_logits.
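The stability argument can be demonstrated without TensorFlow. The plain-Python sketch below compares computing the sigmoid first and then taking the log (as two separate layers would) against the fused expression that the TensorFlow documentation gives for sigmoid_cross_entropy_with_logits. Both functions are hand-written for illustration, and the logit of 40 is an arbitrary example of a confident wrong prediction.

```python
import math

def naive_sigmoid_ce(label, logit):
    # Separate steps: sigmoid first, then the log of the probability.
    p = 1.0 / (1.0 + math.exp(-logit))
    try:
        return -(label * math.log(p) + (1.0 - label) * math.log(1.0 - p))
    except ValueError:  # math.log(0.0): p was rounded to exactly 0.0 or 1.0
        return float("inf")

def fused_sigmoid_ce(label, logit):
    # The same loss computed directly from the logit (the *_with_logits idea).
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

# A confident, wrong prediction: logit 40 for a negative example.
print(naive_sigmoid_ce(0.0, 40.0))  # inf: sigmoid(40.0) rounds to exactly 1.0
print(fused_sigmoid_ce(0.0, 40.0))  # 40.0
```

For moderate logits the two functions agree; the fused form only matters at the extremes, which is exactly where a diverging loss would break training.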

With this, let us now approach some situations. Say there is a simple binary classification problem: is a cat present in the image or not? What is the choice of activation and loss function? It will be a sigmoid activation and a (binary) CE. So one could use sigmoid_cross_entropy_with_logits, which combines the activation and the loss function and is numerically stable. (tf.losses.sigmoid_cross_entropy also takes logits; it wraps the same computation and adds in-batch weights.)

What about multi-class classification? Say we want to know whether a cat, a dog or a donkey is present in the image. What is the choice of activation and loss function? It will be a softmax activation and a (categorical) CE. So one could use softmax_cross_entropy_with_logits (the _v2 variant from TensorFlow 1.5 on), which combines the activation and the loss; tf.losses.softmax_cross_entropy also takes logits. We assume that the expected value is one-hot encoded (100 or 010 or 001). If this is not the case and the expected value is an integer class index (0, 1 or 2), you could use the ‘sparse’ counterparts of the above functions.

There could be a third case: multi-label classification, where there could be a dog and a cat in the same image. How do we handle this? The trick here is to treat this situation as multiple binary classification problems: cat or no cat, dog or no dog, and donkey or no donkey. Find the loss for each of the 3 binary classifications and then add them up. So essentially this boils down to using the sigmoid_cross_entropy_with_logits loss.
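The three cases can be summarized as a small lookup. choose_tf_loss is a hypothetical helper written for illustration (it is not part of TensorFlow), and the function names it returns are from the TensorFlow 1.x tf.nn API.

```python
def choose_tf_loss(problem: str, integer_labels: bool = False) -> str:
    """Hypothetical helper encoding the rules of thumb above; not a TensorFlow API."""
    if problem in ("binary", "multi-label"):
        # Independent yes/no decisions per class (cat? dog? donkey?):
        # sigmoid activation + binary CE per class, then summed.
        return "tf.nn.sigmoid_cross_entropy_with_logits"
    if problem == "multi-class":
        # Exactly one class per image: softmax activation + categorical CE.
        if integer_labels:
            # Labels given as class indices 0..N-1 instead of one-hot vectors.
            return "tf.nn.sparse_softmax_cross_entropy_with_logits"
        # One-hot (or soft) labels; the non-deprecated TF 1.5+ name ends in _v2.
        return "tf.nn.softmax_cross_entropy_with_logits_v2"
    raise ValueError(f"unknown problem type: {problem!r}")

print(choose_tf_loss("binary"))
print(choose_tf_loss("multi-class", integer_labels=True))
print(choose_tf_loss("multi-label"))
```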

This answers the 3 specific questions you asked. The functions shared above are all that is needed. You can ignore the tf.contrib family, which is deprecated and should not be used.


The answers/resolutions are collected from Stack Overflow and are licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.