Summary by CodyWild 1 month ago
This paper was a real delight to read, and even though I’m summarizing it here, I’d really encourage you, if you’re reading this, to read the paper itself, since I found it to be unusually clearly written. It tackles the problem of understanding how features of loss functions - these integral, yet arcane, objects defined in millions of parameter-dimensions - impact model performance. Loss function analysis is generally a difficult area, since the number of dimensions and the number of points at which the loss must be evaluated are both so high. The latter presents computational challenges, the former challenges of understanding: human brains and many-dimensional spaces are not a good fit. Overall, this paper contributes by 1) arguing for a new way of visualizing loss functions, 2) demonstrating how and in what cases “flatness” of loss function contributes to performance and trainability, and 3)
The authors review a few historically common ways of visualizing loss functions before introducing their variant. The simplest, one-dimensional technique, 1-D Linear Interpolation, works by taking two parameter settings (say, a random initialization and the final network minimum) and smoothly interpolating between them, using a convex combination of the two weighted by alpha. You can then plot the value of the loss at each of these parameter configurations as a function of alpha. If you want to plot in 2D, with a contour plot, you can do so in a pretty similar manner: pick two random “direction vectors” of the same size as the parameter vector, and then add amounts of those directions, weighted by alpha and beta, to your starting point. These random directions become your axes, and you get a snapshot of how the loss changes as you move along them.
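To make the two techniques concrete, here is a minimal NumPy sketch of both, under some loudly stated assumptions: `toy_loss` is a made-up quadratic standing in for a real network’s loss over a flattened parameter vector, and `interpolate_1d` and `surface_2d` are hypothetical helper names of my own, not the paper’s code. The 1-D curve evaluates the loss at (1 - alpha) * theta_a + alpha * theta_b, and the 2-D grid evaluates it at theta_star + alpha * delta + beta * eta for two random directions delta and eta.

```python
import numpy as np

def toy_loss(theta: np.ndarray) -> float:
    # Hypothetical stand-in for a network's loss, evaluated on a flat parameter vector
    return 0.5 * float(np.sum(theta ** 2))

def interpolate_1d(loss_fn, theta_a, theta_b, num=25):
    """Loss along (1 - alpha) * theta_a + alpha * theta_b for alpha in [0, 1]."""
    alphas = np.linspace(0.0, 1.0, num)
    losses = np.array([loss_fn((1.0 - alpha) * theta_a + alpha * theta_b) for alpha in alphas])
    return alphas, losses

def surface_2d(loss_fn, theta_star, delta, eta, span=1.0, num=21):
    """Loss on the plane theta_star + alpha * delta + beta * eta, alpha and beta in [-span, span]."""
    coords = np.linspace(-span, span, num)
    grid = np.empty((num, num))
    for i, alpha in enumerate(coords):
        for j, beta in enumerate(coords):
            grid[i, j] = loss_fn(theta_star + alpha * delta + beta * eta)
    return coords, grid

rng = np.random.default_rng(0)
n_params = 1_000
theta_init = rng.normal(size=n_params)   # e.g. a random initialization
theta_final = np.zeros(n_params)         # hypothetical trained minimum of toy_loss

alphas, losses = interpolate_1d(toy_loss, theta_init, theta_final)

# Two random direction vectors with the same size as the parameter vector
delta = rng.normal(size=n_params)
eta = rng.normal(size=n_params)
coords, grid = surface_2d(toy_loss, theta_final, delta, eta)
```

If matplotlib is available, `plt.contour(coords, coords, grid.T)` should draw the contour plot (the transpose puts beta on the vertical axis). These raw random directions are not rescaled in any way, which is the issue the authors turn to next.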
The authors then observe that these techniques can’t naively be used to compare two different models if the parameters of those models are on different scales. If you take a neural net (with ReLU activations), multiply one layer by 10, and then divide the next layer by 10, you’ve essentially done a no-op that won’t change the network’s outputs. However, if you’re moving by a fixed amount along your random direction in parameter space, you’ll have to move much farther in the layer that’s been multiplied by 10 to produce a commensurate change. To address this problem, they suggest a simple fix: after you’ve selected each of your random directions, scale the values in each direction vector by the norm of the filter they correspond to. This removes the sensitivity of your plots to the scale of the weights. (One thing I admit I’m a little confused by here is the fact that each value in the direction vector corresponds to a filter, rather than to a weight; I would naively have thought theta, and all the direction vectors, are of length number-of-model-parameters, and each value is a single weight. I think I still broadly grasp the intuition, but I’d value having a better sense of this.)
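Here is how I understand the filter normalization, as a hedged NumPy sketch rather than the authors’ implementation. As far as I can tell, the direction vector does have one entry per weight; a “filter” is a block of those entries (for a conv kernel of shape (out_channels, in_channels, k, k), the slice for one output channel), and each such block of the random direction is rescaled to have the same norm as the matching block of the trained weights: d_ij <- d_ij / ||d_ij|| * ||theta_ij||. The function name `filter_normalize`, the list-of-arrays layout, the small epsilon, and zeroing the direction for 1-D parameters such as biases are assumptions made for this sketch.

```python
import numpy as np

def filter_normalize(direction, weights, eps=1e-10):
    """Filter-wise normalization of a random direction (hedged sketch).

    direction, weights: lists of arrays, one per layer, with matching shapes.
    For a weight of shape (out_channels, ...), filter j is the slice w[j].
    Each d[j] is rescaled so that ||d[j]|| matches ||w[j]||.
    """
    normalized = []
    for d, w in zip(direction, weights):
        if w.ndim <= 1:
            # Assumption for this sketch: leave 1-D parameters (biases, norm scales) unperturbed
            normalized.append(np.zeros_like(d))
            continue
        d = d.astype(float)  # copy, so the caller's direction is not modified
        for j in range(w.shape[0]):
            d[j] = d[j] / (np.linalg.norm(d[j]) + eps) * np.linalg.norm(w[j])
        normalized.append(d)
    return normalized

rng = np.random.default_rng(0)
# Hypothetical bias-free ReLU net: a 3x3 conv with 8 output channels,
# then (after global pooling) a 10x8 dense layer.
w1 = rng.normal(size=(8, 3, 3, 3))
w2 = rng.normal(size=(10, 8))
weights = [w1, w2]
# Multiply one layer by 10 and divide the next by 10: same function, different weight scale
weights_rescaled = [10.0 * w1, w2 / 10.0]

direction = [rng.normal(size=w.shape) for w in weights]
d = filter_normalize(direction, weights)
d_rescaled = filter_normalize(direction, weights_rescaled)
```

Because each filter of the direction is tied to the norm of its own filter in the network, the rescaled network gets a direction that is 10x larger in the first layer and 10x smaller in the second, so a step of alpha means the same relative perturbation in both networks.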
To demonstrate the value of their normalization scheme, they compare the interpolation plots for models trained with small and large batch sizes, with and without weight decay. Small batches are widely thought to produce flatter loss functions around the eventual minimum, and flatness seems to co-occur with good generalization. That bears out in the linear interpolation without weight decay (figs a, b, c), where the small-batch model has the wider solution basin and also better performance. However, once weight decay is applied (figs d, e, f), the small-batch basin appears to shrink and become very narrow, even though small batches still give better performance. At first glance, this would seem to contradict the “flatter solutions generalize better” rule.
https://i.imgur.com/V0H13kK.png
But this is just because weight decay hits small-batch models more strongly: they take more parameter updates per epoch, and the weight decay penalty is applied at every update. This means that when weight decay is applied, the overall scale of weights in the small-batch network is lower, so its solution looks “sharp” when plotted on the same weight scale as the large-batch network. When filter normalization was used, this effect largely went away, and you once again saw higher performance with flatter loss functions. (Batch size and performance, with and without weight decay, shown normalized below.)
https://i.imgur.com/vEUIgo0.png
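The weight decay argument can be made concrete with a deliberately idealized calculation. Plain SGD with L2 weight decay multiplies the weights by (1 - lr * weight_decay) at each step, on top of the gradient update; ignoring the gradient term entirely, and assuming the same learning rate for both batch sizes, the shrinkage per epoch is that factor raised to the number of steps per epoch. The dataset size, batch sizes, learning rate, and decay value below are hypothetical illustration values, not the paper’s training configuration, and `decay_factor_per_epoch` is my own helper name.

```python
import math

def decay_factor_per_epoch(n_examples, batch_size, lr, weight_decay):
    """Idealized shrinkage of the weights over one epoch from weight decay alone.

    SGD with L2 weight decay updates w <- (1 - lr * weight_decay) * w - lr * grad.
    This ignores the gradient term and keeps lr fixed, so it only counts steps.
    """
    steps_per_epoch = math.ceil(n_examples / batch_size)
    return (1.0 - lr * weight_decay) ** steps_per_epoch

# Hypothetical illustration values: a CIFAR-10-sized dataset and a common lr / decay pair
small = decay_factor_per_epoch(50_000, 128, lr=0.1, weight_decay=5e-4)
large = decay_factor_per_epoch(50_000, 8192, lr=0.1, weight_decay=5e-4)
print(f"per-epoch shrinkage  small batch: {small:.4f}  large batch: {large:.4f}")
```

Per epoch the difference looks modest, but it compounds over every epoch of training, which is consistent with the small-batch network ending up with smaller weights.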
A few other, more scattered observations from the paper:
- I’ve heard explanations of skip connections in terms of “giving the model shorter gradient paths between parameters and output,” but haven’t really seen an argument for why skip connections lead to smoother loss functions, even though they empirically seem to.
https://i.imgur.com/g3QqRzh.png
- The authors also devise a technique for visualizing the change in loss along the trajectory taken by the optimization algorithm, so that different optimizers can be compared. The main problem with previous methods for this has been that optimization trajectories lie in a low-dimensional manifold within parameter space, while random directions in a very high-dimensional space are almost always nearly orthogonal to one another; so if you just randomly select directions, the trajectory barely projects onto them and you won’t see any interesting movement. To fix this, they choose as their axes the top principal components of a matrix built from the parameter values at each epoch (each epoch’s parameters minus the final parameters): these are the directions in parameter space along which the parameters varied most throughout training.
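A minimal sketch of the trajectory plot’s axes, using plain SVD-based PCA in NumPy. It assumes you saved a flattened copy of the parameters after every epoch; as I understand the paper, the rows fed to PCA are each epoch’s parameters minus the final ones, and this sketch also mean-centers them, as standard PCA does. The synthetic trajectory (a path confined to a random 2-D plane, plus tiny noise) and the helper name `trajectory_pca` are hypothetical, chosen so the result is easy to check: the PCA plane should recover the plane the path lives in, while a random direction should be nearly orthogonal to the path.

```python
import numpy as np

def trajectory_pca(checkpoints, k=2):
    """PCA axes for plotting an optimization trajectory (hedged sketch).

    checkpoints: array (n_checkpoints, n_params); row i is the flattened parameter
    vector saved after epoch i, and the last row is the final parameters theta_n.
    Returns (axes, coords): axes is (k, n_params) with orthonormal rows; coords is
    (n_checkpoints, k), the trajectory projected onto those axes, theta_n at the origin.
    """
    final = checkpoints[-1]
    diffs = checkpoints[:-1] - final              # rows: theta_i - theta_n
    centered = diffs - diffs.mean(axis=0)         # standard PCA centering
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    axes = vt[:k]                                 # top-k principal directions
    coords = (checkpoints - final) @ axes.T       # project every checkpoint
    return axes, coords

rng = np.random.default_rng(0)
n_params = 5_000
# Hypothetical trajectory: moves within the plane spanned by u and v, plus tiny noise
u, v = np.linalg.qr(rng.normal(size=(n_params, 2)))[0].T
t = np.linspace(0.0, 1.0, 21)
a = 5.0 * (1.0 - t) * np.cos(3.0 * t)
b = 5.0 * (1.0 - t) * np.sin(3.0 * t)
final = rng.normal(size=n_params)
checkpoints = final + np.outer(a, u) + np.outer(b, v)
checkpoints += 1e-3 * rng.normal(size=checkpoints.shape)

axes, coords = trajectory_pca(checkpoints)

# A random direction is nearly orthogonal to the path; the PCA plane is not
step = checkpoints[0] - checkpoints[-1]
r = rng.normal(size=n_params)
cos_random = r @ step / (np.linalg.norm(r) * np.linalg.norm(step))
print(f"|cos(random direction, path)| = {abs(cos_random):.3f}")
print(f"norm of u projected onto the PCA plane = {np.linalg.norm(axes @ u):.3f}")
```

Plotting `coords[:, 0]` against `coords[:, 1]` on top of a loss contour computed along the same two axes would give the kind of trajectory plot described above.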
