How to Train Your ResNet is a series of blog posts by David Page and colleagues at Myrtle.ai that I've really enjoyed.

Over eight blog posts they describe how they managed to train a custom ResNet to 94% test accuracy on CIFAR10 in 26 seconds on a single GPU! To this day, the winning approach to this challenge on DAWNBench is based on the collection of tricks described in the 8th post of the series.

In this blog post, I'll summarize the main findings from each of the eight articles.

Beyond the technical details, I really like how the authors describe their process: they keep asking why methods work rather than just applying them. Combined with their style of writing, that makes the series a great read.

The chapters of the series are:

  1. The baseline (297 seconds)
  2. Increasing the mini-batch size (256 seconds)
  3. Profiling and Regularisation (154 seconds)
  4. Simplifying the Architecture (79 seconds)
  5. How hard is hyper-parameter tuning really? (theoretical study, not directly focused on acceleration)
  6. Weight decay and learning rate dynamics (theoretical study, not directly focused on acceleration)
  7. Batch norm does reduce internal covariate shift (theoretical study, not directly focused on acceleration)
  8. The famous bag of tricks (26 seconds)

Here are my highlights from each of the posts:

1. The Baseline

In part 1, they begin with the then-leading baseline by Ben Johnson (356 seconds; main differentiating characteristics: ResNet18, 1Cycle learning rate policy, mixed-precision training, similar to this).

  • They remove an architectural redundancy (down to 323 seconds).
  • They do image preprocessing once and store the results, instead of repeating this in every epoch (down to 308 seconds).
  • They batch the calls made to random number generators during data augmentation and revert to doing data augmentation on a single process to avoid parallelization overhead, bringing the time to reach 94% accuracy down to 297 seconds (see the sketch below).
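
As an illustration of the batched-RNG idea, here is a minimal sketch (my own, not the authors' code): instead of asking the random number generator for crop offsets and flip decisions one image at a time inside the data loader, you can draw all augmentation parameters for an epoch with a few vectorized calls up front.

```python
import numpy as np

rng = np.random.default_rng(0)
num_images, pad, crop = 50_000, 4, 32

# One vectorized RNG call per parameter for the whole epoch,
# instead of one call per image inside the data loader.
xs = rng.integers(0, 2 * pad + 1, size=num_images)   # crop x-offsets
ys = rng.integers(0, 2 * pad + 1, size=num_images)   # crop y-offsets
flips = rng.random(num_images) < 0.5                 # horizontal flip decisions

def augment(image, i):
    """Apply the i-th pre-drawn crop and flip to a 32x32x3 (HWC) image."""
    padded = np.pad(image, ((pad, pad), (pad, pad), (0, 0)), mode="reflect")
    out = padded[ys[i]:ys[i] + crop, xs[i]:xs[i] + crop]
    return out[:, ::-1] if flips[i] else out
```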

2. Increasing the mini-batch size

In part 2, they increase the mini-batch size from 128 to 512 and it works: the training time is reduced to 256 seconds.

The most interesting part of this post, however, is the discussion of why this increase in batch size (while keeping approximately the same learning rate schedule that was hand-picked for the 128 mini-batch setting) does not cause more problems. Ultimately the authors suggest that in the 128 mini-batch setting, forgetfulness dominates training: if the learning rate were increased, parameter updates early in an epoch might be essentially "cancelled" by updates later in the same epoch, and the network would forget the earlier samples. In the 512 mini-batch setting, however, curvature effects dominate.
One conclusion from this is that if you want to train with very large learning rates, you'll have to scale your mini-batch size with the number of samples in your data set to avoid forgetfulness.

3. Profiling and Regularisation

In part 3, the authors fix a mixed-precision issue that should no longer be a problem if you're working with PyTorch AMP (down to 186 seconds). Profiling their setup, they find that the convolutions make up the vast majority of the run time and that batch norm also takes a significant amount of time.

Additionally, they apply Cutout regularisation, in which a random 8x8 square of the input is zeroed out, and they shorten the ramp-up phase of the 1Cycle learning rate schedule. In combination, these modifications allow them to increase the mini-batch size to 768 and train for fewer epochs, reducing the training time to 154 seconds.
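
A minimal sketch of Cutout as described here (my own illustration for a CHW float tensor, not the authors' implementation; details such as whether the square may extend past the border vary between implementations):

```python
import torch

def cutout(img: torch.Tensor, size: int = 8) -> torch.Tensor:
    """Zero out a random size x size square of a CHW image tensor."""
    _, h, w = img.shape
    # Pick the centre of the square uniformly at random, clipping at the borders.
    cy, cx = torch.randint(h, (1,)).item(), torch.randint(w, (1,)).item()
    y0, y1 = max(0, cy - size // 2), min(h, cy + size // 2)
    x0, x1 = max(0, cx - size // 2), min(w, cx + size // 2)
    out = img.clone()
    out[:, y0:y1, x0:x1] = 0.0
    return out

augmented = cutout(torch.randn(3, 32, 32))
```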

4. Simplifying the Architecture

In part 4, the authors adapt the original model architecture in a number of ways. I found the process the authors use to identify a minimal architecture the most interesting bit of this post – rather than the specific architectural changes:

  1. Find a small subset of the baseline architecture that performs reasonably well by itself.
  2. Modify the architecture to improve its performance as a stand-alone network.
  3. Once you hit diminishing returns, add back additional layers.

Steps 1) and 2) are illustrated nicely in the original post.

I'd definitely recommend that you read the whole post. These architectural modifications bring the training time down to 79 seconds.

5. How hard is hyper-parameter tuning really?

In part 5 of the series, the authors tackle hyper-parameter tuning and how it can be made less expensive. This is the first of three posts that aim less directly at accelerating training and instead focus on understanding some of its dynamics.

The authors show empirically that, in their setting, there are a number of flat directions in hyper-parameter space: the test accuracy stays essentially constant if any of \(\frac{\lambda}{N}\), \(\frac{\lambda}{1-\rho}\), or \(\lambda\alpha\) is held constant, where \(\lambda\) is the learning rate, \(N\) the batch size, \(\rho\) the momentum and \(\alpha\) the weight decay. The first of these is known as the Linear Scaling Rule: when the mini-batch size is multiplied by k, multiply the learning rate by k.
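
To make the flat directions concrete, here is a small sketch (my own, with hypothetical helper names and example values) that rescales a hyper-parameter configuration along each of them; under the empirical result above, each rescaled configuration should reach roughly the same test accuracy as the original:

```python
from dataclasses import dataclass, replace

@dataclass
class Hyperparams:
    lr: float            # lambda
    batch_size: int      # N
    momentum: float      # rho
    weight_decay: float  # alpha

def scale_batch(h: Hyperparams, k: float) -> Hyperparams:
    """Linear Scaling Rule: keep lr / batch_size fixed."""
    return replace(h, lr=h.lr * k, batch_size=int(h.batch_size * k))

def scale_momentum(h: Hyperparams, new_momentum: float) -> Hyperparams:
    """Keep lr / (1 - momentum) fixed."""
    return replace(h, lr=h.lr * (1 - new_momentum) / (1 - h.momentum),
                   momentum=new_momentum)

def scale_weight_decay(h: Hyperparams, k: float) -> Hyperparams:
    """Keep lr * weight_decay fixed."""
    return replace(h, lr=h.lr * k, weight_decay=h.weight_decay / k)

base = Hyperparams(lr=0.4, batch_size=512, momentum=0.9, weight_decay=5e-4)
print(scale_batch(base, 2))        # lr=0.8, batch_size=1024
print(scale_momentum(base, 0.95))  # lr=0.2, momentum=0.95
```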

Based on this they suggest that simple coordinate descent will quickly find good hyper-parameters in this setting. Care has to be taken, however, to align the flat directions with the axes for this to actually be fast.

They then reason about why this makes sense; the next post of the series summarises their arguments as follows:

The first explained the observed weak dependence of training trajectories on the choice of learning rate \(\lambda\) and batch size \(N\) when the ratio \(\frac{\lambda}{N}\) was held fixed. A similar argument applied to momentum \(\rho\) when \(\frac{\lambda}{1-\rho}\) was fixed. Both relied on a simple matching of first order terms in the weight update equations of SGD + momentum.
A third argument, regarding the weight decay \(\alpha\), was rather different and distinctly not first order in nature. We argued that for weights with a scaling symmetry, the gradient with respect to the weight norm vanishes and a second order effect (Pythagorean growth from orthogonal gradient updates) appears at leading order in the dynamics. What is worse, we argued that, although weight norms are irrelevant to the forward computation of the network, they determine the effective learning rate for the other parameters.

Further on scaling symmetry:

For weights with a scaling symmetry – which includes all the convolutional layers of our network because of subsequent batch normalisation – gradients are orthogonal to weights. As a result, gradient updates lead to an increase in weight norm whilst weight decay leads to a decrease.
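
This is easy to check numerically. The snippet below (my own sketch, not from the post) builds a convolution followed by batch norm and verifies that the loss gradient is essentially orthogonal to the convolution weights; the tiny residual comes from the epsilon inside batch norm.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A conv layer whose output goes straight into batch norm, as in the network.
conv = nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(16)

x = torch.randn(64, 3, 32, 32)
loss = bn(conv(x)).relu().mean()  # any scalar function of the block's output
loss.backward()

w, g = conv.weight.detach().flatten(), conv.weight.grad.flatten()
# Scaling symmetry: rescaling the conv weights is undone by batch norm,
# so the gradient has (almost) no component along the weights themselves.
print(torch.dot(g, w) / (g.norm() * w.norm()))  # ~0
```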

Here's their main takeaway from the post:

Weight decay in the presence of batch normalisation acts as a stable control mechanism on the effective step size. If gradient updates get too small, weight decay shrinks the weights and boosts gradient step sizes until equilibrium is restored. The reverse happens when gradient updates grow too large.

6. Weight decay and learning rate dynamics

This post is the second of three more theoretically oriented posts, expanding on some of the arguments of the previous post by studying the learning rate dynamics more closely.

Their main result is that SGD with momentum and weight decay actually displays dynamics quite similar to those of Layer-wise Adaptive Rate Scaling (LARS). They also suggest that SGD with momentum provides a scaling of update step sizes that is scale invariant: step sizes that work across layers with different initialisations, sizes and positions within the model.
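
A back-of-the-envelope version of the underlying argument (my own simplification of the posts' reasoning, ignoring momentum and treating the gradient norm \(|g|\) as roughly constant): for a scale-invariant weight \(w\), the gradient \(g\) is orthogonal to \(w\), so an SGD step with learning rate \(\lambda\) and weight decay \(\alpha\) changes the weight norm as

\[ |w_{t+1}|^2 = (1 - \lambda\alpha)^2 |w_t|^2 + \lambda^2 |g_t|^2 \approx |w_t|^2 - 2\lambda\alpha |w_t|^2 + \lambda^2 |g_t|^2 . \]

The decay and growth terms balance at \(|w|^2 \approx \lambda |g|^2 / (2\alpha)\), and since the effective step size of a scale-invariant weight behaves like \(\lambda / |w|^2\), a weight whose gradients shrink drifts towards a smaller norm and therefore a larger effective step, and vice versa. This is the self-stabilising, LARS-like behaviour described above.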

7. Batch norm does reduce internal covariate shift

This might be my favourite post of the series! The motivation for the post was to better understand why the authors couldn't optimise away the expensive batch norm layers without making their model much worse.

The authors summarise the post as follows (my emphasis):

First we reviewed the result that, in the absence of batch norm, deep networks with standard initialisations tend to produce ‘bad’, almost constant output distributions in which the inputs are ignored. We discussed how batch norm prevents this and that this can also be fixed at initialisation by using ‘frozen’ batch norm.

Next we turned to training and showed that ‘bad’ configurations exist near to the parameter configurations traversed during a typical training run. We explicitly found such nearby configurations by computing gradients of the means of the per channel output distributions. Later we extended this by computing gradients of the variance and skew of the per channel output distributions, arguing that changing these higher order statistics would also lead to a large increase in loss. We explained how batch norm, by preventing the propagation of changes to the statistics of internal layer distributions, greatly reduces the gradients in these directions.

Finally, we investigated the leading eigenvalues and eigenvectors of the Hessian of the loss, which account for the instability of SGD, and showed that the leading eigenvectors lie primarily in the low dimensional subspaces of gradients of output statistics that we computed before. The interpretation of this fact is that the cause of instability is indeed the highly curved loss landscape that is produced by failing to enforce appropriate constraints on the moments of the output distribution.

I strongly recommend that you read the whole post.

The post also reminded me of the paper Gradient Descent Happens in a Tiny Subspace by Gur-Ari et al.

8. The famous bag of tricks

In part 8, the authors apply a set of optimizations that reduce the training time from 70 seconds to 26 seconds. Code for these can be found here.

  1. Doing preprocessing on the GPU rather than the CPU.
  2. Interesting observation: max-pooling commutes with monotonically increasing functions such as ReLU, that is, max_pool(f(x)) == f(max_pool(x)), where we would expect the RHS to be cheaper than the LHS since the activation is applied to fewer elements. In their experiments, they do indeed get a speed-up by moving max-pooling before the activation function. They also chose to take a small drop in accuracy in exchange for a further speed-up by moving the max-pooling before the batch norm as well (a quick check of the identity follows this list).
  3. Applying label smoothing.
  4. Using the CELU activation (a smoothed version of ReLU).
  5. Using ghost batch norm: The regularising effect of batch norm seems to be best-balanced at a batch size of 32, but the authors need to use larger batch sizes to accelerate training. Ghost batch norm, then, applies batch norm to subsets of a given batch to get the best of both worlds.
  6. Freezing batch norm scales: Observing that the batch norm scale parameters do not seem to learn much throughout training, the authors freeze them and correspondingly adjust the CELU \(\alpha\) and the batch norm learning rate and weight decay parameters.
  7. Using input patch whitening: basically PCA whitening of 3x3 input patches, implemented as a fixed 3x3 convolution used as the first layer.
  8. Applying exponential moving averages. Citing directly from the post: "Parameter averaging methods allow training to continue at a higher rate whilst potentially approaching minima along noisy or oscillatory directions by averaging over multiple iterates." This increases their model's accuracy, making it possible to train for one epoch less than before.
  9. Using test-time augmentation:  "present both the input image and its horizontally flipped version and come to a consensus by averaging network outputs for the two versions, thus guaranteeing invariance." This gives them an accuracy boost, allowing them to train for 2 epochs less than before.
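
Here is the quick check promised in item 2 (my own sketch): pooling before the activation gives exactly the same result as pooling after it, while the activation is applied to four times fewer elements.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(8, 64, 32, 32)

after = F.max_pool2d(F.relu(x), 2)   # activation first, then pool (the usual order)
before = F.relu(F.max_pool2d(x, 2))  # pool first: ReLU sees 4x fewer elements

print(torch.allclose(after, before))  # True, since ReLU is monotonically increasing
```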

All of the steps are also implemented in this colab written by David Page.

Conclusions

So what can we learn from all of this?

  1. Try to find a minimal model that will solve your problem.
  2. Profile your code.
  3. Increase your batch size.
  4. Look for redundant or duplicate computations.
  5. Don't forget to ask why things work!