Understanding Batch Normalization, Bjorck, Gomes, Selman, Weinberger; 2018 - Summary
author: dalmeraz
score: 5 / 10

Goal:

Clarify vague intuitions about batch normalization through experimentation.

Background

Normalizing input data is known to help training, and batch normalization extends this idea to intermediate layers by normalizing each layer's activations with statistics computed over the current mini-batch. This speeds up training significantly and produces better results; however, no one fully understands why it works.
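As a concrete reference for what "normalizing on batches" means, here is a minimal NumPy sketch of the training-time forward pass. It assumes a 2-D (batch, features) input and ignores the running statistics used at inference; the function name and the `eps` default are illustrative choices, not taken from the paper.

```python
import numpy as np

def batch_norm_forward(x, gamma, beta, eps=1e-5):
    """Normalize a (batch, features) array per feature, then scale and shift."""
    mean = x.mean(axis=0)                     # per-feature mean over the batch
    var = x.var(axis=0)                       # per-feature variance over the batch
    x_hat = (x - mean) / np.sqrt(var + eps)   # zero mean, unit variance per feature
    return gamma * x_hat + beta               # learnable scale and shift

rng = np.random.default_rng(0)
# Hypothetical activations with a large mean and spread.
activations = rng.normal(loc=5.0, scale=3.0, size=(64, 4))
normalized = batch_norm_forward(activations, gamma=np.ones(4), beta=np.zeros(4))
```

With `gamma` set to ones and `beta` to zeros, each feature column of `normalized` has mean close to 0 and standard deviation close to 1, regardless of the scale of the incoming activations.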

Batch Norm and Learning Rate

Activations and gradients in deep networks without batch normalization are heavy-tailed, which causes exploding means and variances within the model, particularly in later layers. The paper's analysis finds that this “exploding” is what causes the loss to diverge.

This heavy tail was visualized through the following plot which shows the magnitude of gradients when first starting to train a model:

The usual remedy is to use a small learning rate; however, small learning rates cause issues of their own:

With batch norm there is no need for a small learning rate, since it “corrects” the values at each layer and thus avoids the issues that small learning rates bring.

This theory was further tested on several models trained with and without batch norm.

Test details:

At low learning rates, the model achieves the same accuracy with and without batch normalization; however, because batch normalization allows high learning rates, it reaches better results overall. The paper therefore argues that batch normalization succeeds largely because it makes larger learning rates possible. The authors elaborate on this theoretically with an upper bound on the noise of the gradient step estimated by SGD:
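The bound itself did not survive in this summary. As best recalled from the paper's simple model of SGD, it has roughly the form below; the notation is an assumption and should be checked against the paper. Here α is the learning rate, B is the sampled mini-batch, ℓ is the full loss, ℓ_i is the loss on a single example, and ∇_SGD is the mini-batch gradient estimate. The point to notice is that the bound grows with α² and shrinks with the batch size |B|.

```latex
\mathbb{E}\left[ \left\lVert \alpha \nabla \ell(x) - \alpha \nabla_{SGD}(x) \right\rVert^2 \right]
\le \frac{\alpha^2}{|B|} \, C,
\qquad
C = \mathbb{E}\left[ \left\lVert \nabla \ell_i(x) \right\rVert^2 \right]
```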

Where:

The key point is that increasing the learning rate also increases the noise in the SGD gradient step. This matters because prior work has found that noise in SGD plays an important role in regularizing networks. So, in summary, higher learning rates:

Looking at the gradients, we also see that changes in step size cause divergence (indicated by an increase in loss) much more readily without batch norm.

When a network without batch normalization diverges, the means and variances of its activations explode in the later layers. Batch normalization fixes this by correcting the means and variances and not allowing them to explode. The paper also finds empirically that this effect is so important that adding batch normalization only to the final layer of the model captures about two-thirds of the overall batch normalization improvement.

Random Initialization’s Effect on Gradient Explosions

The paper also explores another weakness that batch normalization fixes: initialization. If we multiply M Gaussian matrices of size N×N, the singular value distribution of the product “blows up” near the origin (most singular values are tiny while a few are very large), and the effect becomes more pronounced as M increases. This can be seen as emulating the weight matrices of a neural network, and shows how gradient explosion arises as networks become deeper:

Since the ratio between the largest and smallest singular values grows drastically with depth, models without batch normalization are strongly prone to gradient explosion. Batch normalization fixes this issue by resetting the scale of the values after each matrix multiplication. This is supported by the observation that smaller (less deep) ResNet models can use larger learning rates.
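The matrix-product experiment can be sketched in a few lines of NumPy. This is an illustration under assumptions, not the paper's exact setup: the matrix size (50), the depths (1, 5, 10), and the 1/sqrt(N) scaling are arbitrary choices, and the spread is measured as the largest singular value over the median one rather than over the smallest, to keep the numbers numerically stable.

```python
import numpy as np

def product_singular_values(num_matrices, size, rng):
    """Singular values of a product of `num_matrices` random Gaussian matrices."""
    product = np.eye(size)
    for _ in range(num_matrices):
        # Scale by 1/sqrt(size) so a single matrix has singular values of order 1.
        product = product @ (rng.normal(size=(size, size)) / np.sqrt(size))
    return np.linalg.svd(product, compute_uv=False)

rng = np.random.default_rng(0)
ratios = {}
for depth in (1, 5, 10):
    s = product_singular_values(depth, size=50, rng=rng)
    # Spread of the spectrum: largest singular value relative to the median one.
    ratios[depth] = s.max() / np.median(s)
    print(f"M={depth:2d}  max/median singular value: {ratios[depth]:.1f}")
```

The printed ratio grows quickly with the number of multiplied matrices, which mirrors the argument above: a deeper stack of randomly initialized layers stretches a few directions enormously while squashing most others.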

TL;DR