Understanding Batch Normalization, Bjorck, Gomes, Selman, Weinberger; 2018 - Summary
author: timchen0618
score: 8 / 10

Background

Normalizing the input data to zero mean and unit standard deviation has long been recognized as good practice for training neural networks.

Batch normalization (BN) extends this idea to the inputs of each layer and has achieved immense success in various areas.
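To make the idea concrete, here is a minimal sketch of the BN transform for fully connected activations (the function name, shapes, and values are my own, not from the paper): each feature is normalized over the mini-batch to zero mean and unit variance, then rescaled and shifted by learnable parameters.

```python
import torch

def batch_norm(x, gamma, beta, eps=1e-5):
    """Normalize each feature of x (shape: batch x features) over the batch,
    then apply the learnable scale (gamma) and shift (beta)."""
    mean = x.mean(dim=0, keepdim=True)
    var = x.var(dim=0, unbiased=False, keepdim=True)
    x_hat = (x - mean) / torch.sqrt(var + eps)
    return gamma * x_hat + beta

x = torch.randn(64, 10) * 4.0 + 2.0   # a mini-batch with non-zero mean / non-unit std
y = batch_norm(x, gamma=torch.ones(10), beta=torch.zeros(10))
print(y.mean(dim=0))                  # roughly 0 for every feature
print(y.std(dim=0))                   # roughly 1 for every feature
```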

It enables faster training, larger learning rates, and higher test accuracy (and better generalization).

However, the reason for the improvement is still debated.

(e.g., the paper that proposed BN attributed its success to reducing “internal covariate shift”, but later work has suggested otherwise.)

Core idea

The paper tried to answer two main questions:

Q: What Makes BN Beneficial?

A: It’s mainly because BN allows for larger learning rates, which encourage updates along flat regions of the loss landscape and prevent the network from getting trapped in local minima.

Q: Why Do We Need BN? Can’t We Just Use High Learning Rates Anyway?

A: No, we can’t. The paper shows empirically that large learning rates cannot be used in networks without BN, since they cause the loss to diverge.

The same divergence is not observed in networks with batch normalization.

Experimental Results

Large Learning Rates Are the Key

As shown in the figure, training with or without BN at a small learning rate (lr = 0.0001) yields similar performance. However, networks with BN have a clear advantage when the learning rate is large. The authors conclude that enabling high learning rates is what makes BN beneficial.

(Figure: Testing Accuracy)

Why Are Large Learning Rates Good?

The upper bound on the estimation error of a gradient step grows with the learning rate \(\alpha\). A higher learning rate means a larger estimation error (i.e., more gradient noise), which in turn means better generalization: the optimizer does not follow the exact gradient and therefore does not overfit the training loss landscape.

(Figure: Formula)

(Here \(\nabla l(x)\) denotes the true gradient, and \(\nabla_{SGD}(x)\) denotes the update taken by SGD.)
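The formula itself is only shown in the figure; as a rough reconstruction from the surrounding description (treat it as a sketch, the exact statement is in the paper’s lemma), writing the SGD step as \(\nabla_{SGD}(x) = \frac{\alpha}{|B|}\sum_{i \in B} \nabla l_i(x)\) for a uniformly sampled mini-batch \(B\), the bound has the form

\[
\mathbb{E}\Big[\,\big\lVert \nabla_{SGD}(x) - \alpha\,\nabla l(x) \big\rVert^{2}\,\Big] \;\le\; \frac{\alpha^{2}}{|B|}\,C,
\]

where \(C\) depends on the individual per-example gradients \(\nabla l_i(x)\). The takeaway is that the noise injected by an SGD step scales with \(\alpha^2\) and shrinks with the batch size.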

Divergence of NN Without BN

They show that large learning rates cannot be used for networks without BN, since they lead to divergence.

Here they define a relative loss, which measures how much the loss grows after a gradient step.

During the first few updates, the network without BN exhibits a very large relative loss (i.e., divergence) once the step size exceeds \(10^{-3}\). The same does not happen to the network with BN.

(Figure: Step Size)
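As a heavily simplified illustration of this measurement, here is a toy sketch with synthetic data and an MLP of my own choosing (not the paper’s ResNet setup): take a single gradient step at various step sizes and compare the loss before and after, with and without BN. The exact step size at which the un-normalized network blows up will of course depend on the architecture and data.

```python
import torch
import torch.nn as nn

def make_mlp(width=256, depth=8, use_bn=False):
    # Stack Linear (+ optional BatchNorm) + ReLU blocks, then a classifier head.
    layers, d = [], 32
    for _ in range(depth):
        layers.append(nn.Linear(d, width))
        if use_bn:
            layers.append(nn.BatchNorm1d(width))
        layers.append(nn.ReLU())
        d = width
    layers.append(nn.Linear(d, 10))
    return nn.Sequential(*layers)

def relative_loss_after_one_step(use_bn, lr, seed=0):
    torch.manual_seed(seed)
    net = make_mlp(use_bn=use_bn)
    x = torch.randn(128, 32)                     # synthetic inputs
    y = torch.randint(0, 10, (128,))             # synthetic labels
    loss_fn = nn.CrossEntropyLoss()
    opt = torch.optim.SGD(net.parameters(), lr=lr)

    loss_before = loss_fn(net(x), y)
    opt.zero_grad()
    loss_before.backward()
    opt.step()
    loss_after = loss_fn(net(x), y)
    return (loss_after / loss_before).item()     # > 1 means the loss grew

for lr in [1e-4, 1e-3, 1e-1]:                    # arbitrary step sizes for the sketch
    print(lr,
          "no BN:", relative_loss_after_one_step(False, lr),
          "BN:", relative_loss_after_one_step(True, lr))
```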

Besides the training loss, they also examine the activations of the network trained without BN.

The activations of the upper (deeper) layers are extremely large, several orders of magnitude larger than those of the lower layers (notice the scale). The authors argue that this implies divergence is caused by exploding activations, and that BN fixes the problem by normalizing each layer’s input, preventing large activations from propagating.

(Figure: Activation Heat Map)

As seen in the figure below, when BN is applied, the activations indeed stay at a roughly constant scale across layers, whereas without BN the channel means and standard deviations grow exponentially with depth.

(Figure: Channel Mean and Variance)
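A simple way to inspect this kind of behaviour, sketched below with synthetic data and my own layer sizes (the paper’s plots are for an unnormalized ResNet, so the exact growth pattern will differ), is to forward a batch through a freshly initialized deep network and record the activation scale at every layer, with and without BN.

```python
import torch
import torch.nn as nn

def activation_stds(use_bn, depth=20, width=512, seed=0):
    """Forward a random batch through `depth` freshly initialized
    Linear (+ optional BatchNorm) + ReLU blocks and record the
    standard deviation of the activations after each block."""
    torch.manual_seed(seed)
    x = torch.randn(128, width)
    stds = []
    for _ in range(depth):
        block = [nn.Linear(width, width)]
        if use_bn:
            block.append(nn.BatchNorm1d(width))
        block.append(nn.ReLU())
        x = nn.Sequential(*block)(x)
        stds.append(x.std().item())
    return stds

print("no BN:", [round(s, 3) for s in activation_stds(use_bn=False)])
print("BN:   ", [round(s, 3) for s in activation_stds(use_bn=True)])
```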

Gradients With and Without BN

Below are the gradients of the final output layer, with larger gradients shown in yellow. The authors use this plot to show that the gradients of examples in the same batch (the same row in the plot) point in roughly the same direction, so they add up to a large magnitude. This results in the network making the same, mostly wrong, predictions during the first few steps, and in even larger gradients later to correct them (see the yellow regions in the left figure).

(Figure: Gradient Heat Map)
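Here is a toy sketch (my own small network and random data, not the paper’s setup) of how one could measure this kind of within-batch gradient alignment: compute the final layer’s gradient for each example separately and look at the average pairwise cosine similarity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(32, 128), nn.ReLU(), nn.Linear(128, 10))
final_layer = net[-1]
x = torch.randn(64, 32)
y = torch.randint(0, 10, (64,))

# Gradient of the final layer's weights for each example separately.
per_example_grads = []
for i in range(x.size(0)):
    net.zero_grad()
    loss = F.cross_entropy(net(x[i:i + 1]), y[i:i + 1])
    loss.backward()
    per_example_grads.append(final_layer.weight.grad.flatten().clone())

G = torch.stack(per_example_grads)            # (batch, num_weights)
G = F.normalize(G, dim=1)
cos = G @ G.t()                               # pairwise cosine similarities
off_diag = cos[~torch.eye(len(G), dtype=torch.bool)]
print("mean pairwise cosine similarity:", off_diag.mean().item())
```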

What interesting variants are explored?

Random Initialization

They conducted experiments showing that the exploding gradients in networks without BN are caused by the random initialization.
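As a hedged illustration of why the initialization scale matters (this is not the paper’s analysis, just a toy demonstration), the output norm of a chain of randomly initialized linear maps grows or shrinks exponentially with depth depending on the per-layer gain.

```python
import torch

def output_norm(depth, gain, width=512, seed=0):
    """Push a random vector through `depth` random linear maps whose entries
    are i.i.d. N(0, gain^2 / width) and return the norm of the result."""
    torch.manual_seed(seed)
    x = torch.randn(width)
    for _ in range(depth):
        W = torch.randn(width, width) * (gain / width ** 0.5)
        x = W @ x
    return x.norm().item()

for gain in [0.9, 1.0, 1.1]:
    # Norms at depths 5, 10, 20: growing for gain > 1, shrinking for gain < 1.
    print(f"gain={gain}:", [round(output_norm(d, gain), 2) for d in (5, 10, 20)])
```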

TL;DR