Recurrent Batch Normalization
Cooijmans, Tim and Ballas, Nicolas and Laurent, César and Courville, Aaron. 2016
Paper summary: abhshkdz
This paper presents a re-parameterization of the LSTM that allows batch normalization to be applied successfully, which results in faster convergence and improved generalization on several sequential tasks. Main contributions:
- Batch normalization is applied to the input-to-hidden and hidden-to-hidden projections.
- Separate statistics are maintained for each timestep, estimated over each minibatch during training and over the whole dataset during test.
- For generalization to longer sequences during test time, population statistics of time T\_max are used for all time steps beyond it.
- The cell state is left untouched so as not to hinder the gradient flow.
- Proper initialization of batch normalization parameters to avoid vanishing gradients.
- They plot the norm of the gradient of the loss with respect to the hidden state at different time steps, for different initializations of the BN scale parameter $\gamma$. A large scale ($\gamma = 1$) causes gradients to die quickly by driving activations into the saturation region.
- Initializing $\gamma$ to 0.1 works well.
## Strengths
- Simple idea, and the authors finally got it to work. Proper initialization of BN parameters and maintaining separate estimates for each time step play a key role.
## Weaknesses / Notes
- It would be useful in practice to put down a proper formulation for using batch normalization with variable-length training sequences.
This paper describes how to apply batch normalization (BN) successfully to recurrent neural networks, specifically to LSTM networks. The technique involves the following three ideas:
**1) Careful initialization of the BN scaling parameter.** While standard practice is to initialize it to 1 (to have unit variance), they show that this situation creates problems with the gradient flow through time, which vanishes quickly. A value around 0.1 (used in the experiments) preserves gradient flow much better.
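The effect of the scale initialization can be illustrated with a toy calculation. The sketch below is not the paper's experiment: it assumes a single tanh unit whose batch-normalized pre-activation is standard normal, ignores the recurrent weights and gates, and uses a hypothetical helper `mean_tanh_slope` to estimate the average slope of tanh after scaling by $\gamma$.

```python
import math
import random

def mean_tanh_slope(gamma, n_samples=10000, seed=0):
    """Monte Carlo estimate of E[tanh'(gamma * z)] for z ~ N(0, 1).

    z stands in for a batch-normalized pre-activation (zero mean, unit
    variance); gamma is the BN scale, with the BN shift assumed to be 0.
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        z = rng.gauss(0.0, 1.0)
        a = gamma * z                      # pre-activation after BN scaling
        total += 1.0 - math.tanh(a) ** 2   # derivative of tanh at a
    return total / n_samples

slope_unit = mean_tanh_slope(1.0)    # gamma = 1: many samples land in the saturated region
slope_small = mean_tanh_slope(0.1)   # gamma = 0.1: samples stay in the near-linear region

print(f"average tanh slope, gamma=1.0: {slope_unit:.3f}")
print(f"average tanh slope, gamma=0.1: {slope_small:.3f}")
```

Because a factor like this is multiplied in at every time step during backpropagation through time, an average slope well below 1 compounds into a rapidly shrinking gradient, whereas a slope close to 1 does not.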
**2) Separate BN for the "hiddens to hiddens" pre-activation and for the "inputs to hiddens" pre-activation.** In other words, 2 separate BN operators are applied, one on each contribution to the pre-activation, before summing and passing through the tanh and sigmoid non-linearities.
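Idea 2) can be written out as follows. The notation is an assumption made for this sketch: $W_h$ and $W_x$ are the recurrent and input weight matrices, $b$ is the bias, $\tilde{f}_t, \tilde{i}_t, \tilde{o}_t, \tilde{g}_t$ are the forget, input, output and candidate pre-activations, and the hatted mean and variance are minibatch estimates. Only the gate pre-activations and the cell update are shown.

$$
\mathrm{BN}(z; \gamma, \beta) = \beta + \gamma \odot \frac{z - \widehat{\mathbb{E}}[z]}{\sqrt{\widehat{\mathrm{Var}}[z] + \epsilon}}
$$

$$
\begin{pmatrix} \tilde{f}_t \\ \tilde{i}_t \\ \tilde{o}_t \\ \tilde{g}_t \end{pmatrix} = \mathrm{BN}(W_h h_{t-1}; \gamma_h, \beta_h) + \mathrm{BN}(W_x x_t; \gamma_x, \beta_x) + b
$$

$$
c_t = \sigma(\tilde{f}_t) \odot c_{t-1} + \sigma(\tilde{i}_t) \odot \tanh(\tilde{g}_t)
$$

Each projection gets its own scale and statistics, so the two contributions are normalized independently before being summed. The recurrence from $c_{t-1}$ to $c_t$ contains no normalization.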
**3) Use of largest time-step BN statistics for longer test-time sequences.** Indeed, one issue with applying BN to RNNs is that if the input sequences have varying length, and if one uses per-time-step mean/variance statistics in the BN transformation (which is the natural thing to do), it hasn't been clear how to deal with the last time steps of longer sequences seen at test time, for which BN has no statistics from the training set. The paper shows evidence that the pre-activation statistics tend to gradually converge to stationary values over time steps, which supports the idea of simply using the training set's last time step statistics.
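A minimal sketch of this test-time lookup, under stated assumptions: the function name `bn_at_test_time`, the 0-indexed time steps, the scalar pre-activation, and the toy statistics are all hypothetical and chosen only for illustration.

```python
import math

def bn_at_test_time(x, t, pop_mean, pop_var, gamma=0.1, beta=0.0, eps=1e-5):
    """Normalize a scalar pre-activation x at (0-indexed) time step t.

    pop_mean and pop_var hold one population estimate per training time
    step, so their length is the longest training sequence length.
    """
    t_max = len(pop_mean)
    idx = min(t, t_max - 1)  # past the training horizon, reuse the last step's statistics
    return beta + gamma * (x - pop_mean[idx]) / math.sqrt(pop_var[idx] + eps)

# Toy per-time-step statistics for sequences of length 3 (made-up values).
pop_mean = [0.0, 1.0, 2.0]
pop_var = [1.0, 1.0, 4.0]

in_range = bn_at_test_time(4.0, 2, pop_mean, pop_var)     # last training step
beyond = bn_at_test_time(4.0, 10, pop_mean, pop_var)      # longer test sequence
```

Here `beyond` is computed with exactly the same statistics as `in_range`, which is only reasonable if the statistics have become approximately stationary by the last training step.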
Among these ideas, I believe the most impactful is 1). The paper mentions towards the end that improper initialization of the BN scaling parameter probably explains previous failed attempts to apply BN to recurrent networks.
Experiments on 4 datasets confirm the method's success.
**My two cents**
This is an excellent development for LSTMs. BN has had an important impact on our success in training deep neural networks, and this approach might very well have a similar impact on the success of LSTMs in practice.