Introduction
Consider the general machine learning set-up. We have a class of models parameterized by $\theta$ (e.g. for linear regression, $\theta$ would be the coefficients of the model). Each of these models takes in some input $x$ and outputs a result $f(x; \theta)$. We want to select the parameter $\theta$ which minimizes some population loss:

$$L_P(\theta) = \mathbb{E}_{(x, y) \sim P}\left[\ell\big(y, f(x; \theta)\big)\right],$$

where $\ell$ is the loss incurred for a single data point, and $P$ is the population distribution for our data. Unfortunately we don't know $P$ and hence can't evaluate $L_P$ exactly. Instead, we often have a dataset $S = \{(x_1, y_1), \dots, (x_n, y_n)\}$ with which we can define the empirical loss

$$L_S(\theta) = \frac{1}{n} \sum_{i=1}^n \ell\big(y_i, f(x_i; \theta)\big).$$

A viable approach is to select $\theta$ by minimizing the empirical loss: the hope here is that the empirical loss is a good approximation to the population loss.
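To make the notation concrete, here is a minimal sketch of empirical risk minimization for linear regression with squared-error loss. The synthetic dataset and the closed-form least-squares fit are illustrative choices, not part of the set-up above.

```python
import numpy as np

# Empirical risk minimization for linear regression with squared-error loss.
# The synthetic dataset and the closed-form least-squares fit are illustrative
# assumptions, not part of the discussion above.
rng = np.random.default_rng(0)
n, d = 100, 3
X = rng.normal(size=(n, d))
theta_star = np.array([1.0, -2.0, 0.5])          # "true" coefficients (made up)
y = X @ theta_star + 0.1 * rng.normal(size=n)    # noisy observations

def empirical_loss(theta, X, y):
    """L_S(theta) = (1/n) * sum_i (y_i - x_i^T theta)^2."""
    return np.mean((y - X @ theta) ** 2)

# Minimize L_S; for this convex problem the least-squares solution is exact.
theta_hat = np.linalg.lstsq(X, y, rcond=None)[0]
print("empirical loss at theta_hat:", empirical_loss(theta_hat, X, y))
```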
For many modern ML models (especially overparameterized models), the loss functions are non-convex with multiple local and global minima. In this setup, it's known that the parameter $\hat{\theta}$ obtained by minimizing the empirical loss does not necessarily translate into small population loss. That is, good performance on the training set does not "generalize" well. In fact, many different $\theta$ can give similar values of $L_S(\theta)$ but very different values of $L_P(\theta)$.
“Flatness” and generalization performance
One thing that is emerging from the literature is a connection between the "flatness" of minima and generalization performance. In particular, models corresponding to a minimum point whose loss-function neighborhood is relatively "flat" tend to have better generalization performance. The intuition is as follows: think of the empirical loss as a random function, with the randomness coming from which data points end up in the sample. As we draw several of these empirical loss functions, minima associated with flat areas of the function tend to stay in the same area (and hence are "robust"), while minima associated with sharp areas move around a lot.
Here is a stylized example to illustrate the intuition. Imagine that the line in black is the population loss. There are 10 blue lines in the figure, each representing a possible empirical loss function. You can see that there is a lot of variation in the part of the loss associated with the sharper minimum on the left as compared to the flatter minimum on the right.
Imagine that for each of the 10 empirical loss functions, we locate the $\theta$ value of each of the two minima, but record the corresponding value of the true population loss. The points in red correspond to the population loss for the sharper minimum, while the points in blue correspond to that for the flatter minimum. We can see that the loss values for the blue points don't fluctuate as much as those for the red points.
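The figure's intuition can be reproduced with a small simulation. The sketch below is a hypothetical stand-in for the figure: the particular loss (two Gaussian-shaped wells, one narrow and one wide) and the random horizontal shifts used to generate the "empirical" losses are illustrative assumptions, not the post's actual construction. It draws 10 empirical losses, finds the minimizer in each well, and records the population loss at those minimizers.

```python
import numpy as np

# Toy 1-D loss with a sharp well near theta = -1 and a flat well near theta = +1.
def loss(theta, sharp_center=-1.0, flat_center=1.0):
    sharp = 1.0 - np.exp(-((theta - sharp_center) ** 2) / 0.01)  # narrow -> sharp minimum
    flat = 1.0 - np.exp(-((theta - flat_center) ** 2) / 1.0)     # wide  -> flat minimum
    return np.minimum(sharp, flat)

rng = np.random.default_rng(0)
grid = np.linspace(-2.0, 2.0, 4001)
pop = loss(grid)                                  # population loss on a fine grid

sharp_on_pop, flat_on_pop = [], []
for _ in range(10):
    # One "empirical" loss: the two wells sit at randomly perturbed locations.
    d_sharp, d_flat = 0.1 * rng.normal(size=2)
    emp = loss(grid, sharp_center=-1.0 + d_sharp, flat_center=1.0 + d_flat)
    # Locate each empirical minimizer, then record its value on the population loss.
    left, right = grid < 0, grid >= 0
    sharp_on_pop.append(pop[left][np.argmin(emp[left])])
    flat_on_pop.append(pop[right][np.argmin(emp[right])])

print("population loss at sharp empirical minima:", np.round(sharp_on_pop, 3))
print("population loss at flat empirical minima: ", np.round(flat_on_pop, 3))
```

The printed values for the sharp minimum fluctuate far more than those for the flat minimum, which is the behavior of the red versus blue points described above.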
Sharpness-aware minimization (SAM)
There are many ways to define "flatness" or "sharpness". Sharpness-aware minimization (SAM), introduced by Foret et al. (2020) (Reference 1), is one way to formalize the notion of sharpness and use it in model training. Instead of finding parameter values which minimize the empirical loss at a point:

$$\min_\theta \; L_S(\theta),$$

find parameter values whose entire neighborhoods have uniformly small empirical loss:

$$\min_\theta \; L_S^{SAM}(\theta), \qquad L_S^{SAM}(\theta) := \max_{\|\epsilon\|_p \le \rho} L_S(\theta + \epsilon),$$

where $\rho \ge 0$ is a hyperparameter and $\|\cdot\|_p$ is the $L^p$-norm, with $p$ typically chosen to be 2.
The figure below shows what $L_S^{SAM}(\theta)$ looks like for different values of $\rho$ in our simple example. As $\rho$ increases, the value of $L_S^{SAM}(\theta)$ increases a lot for the sharp minimum on the left but not very much for the flat minimum on the right.
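In one dimension the inner maximum can simply be evaluated by grid search over the $\epsilon$-ball, which makes the effect of $\rho$ easy to see. The sketch below does exactly that at a sharp and a flat minimum; the specific toy loss is an illustrative assumption (the same two-well shape as the earlier sketch), and the brute-force maximization is only feasible because the example is one-dimensional.

```python
import numpy as np

# Brute-force evaluation of the SAM objective max_{|eps| <= rho} L_S(theta + eps)
# for a toy 1-D loss with a sharp minimum at theta = -1 and a flat one at theta = +1.
def toy_loss(theta):
    sharp = 1.0 - np.exp(-((theta + 1.0) ** 2) / 0.01)
    flat = 1.0 - np.exp(-((theta - 1.0) ** 2) / 1.0)
    return np.minimum(sharp, flat)

def sam_loss(theta, rho, n_eps=201):
    """Worst-case loss over the ball |eps| <= rho, found by grid search."""
    eps = np.linspace(-rho, rho, n_eps)
    return np.max(toy_loss(theta + eps))

for rho in (0.0, 0.05, 0.2):
    print(f"rho={rho:4.2f}  SAM loss at sharp min: {sam_loss(-1.0, rho):.3f}"
          f"  at flat min: {sam_loss(1.0, rho):.3f}")
```

As $\rho$ grows, the SAM loss at the sharp minimum climbs quickly while the SAM loss at the flat minimum barely moves.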
In practice we don't actually minimize $L_S^{SAM}(\theta)$, but minimize the SAM loss with an L2 regularization term. Also, the inner maximum $\max_{\|\epsilon\|_p \le \rho} L_S(\theta + \epsilon)$ can't be computed analytically: the paper has details on how to get around it (via approximations and such). One final note is on what value $\rho$ should take. In the paper, $\rho$ is treated as a hyperparameter which needs to be tuned (e.g. with grid search).
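As a concrete illustration of what one practical SAM update can look like, here is a sketch for linear regression using the paper's first-order approximation of the inner maximization for $p = 2$: perturb the weights by $\hat{\epsilon} = \rho \, \nabla L_S(\theta) / \|\nabla L_S(\theta)\|_2$, then take a descent step using the gradient evaluated at $\theta + \hat{\epsilon}$. The model, learning rate, $\rho$, and weight-decay values are arbitrary illustrative choices, not the paper's experimental settings.

```python
import numpy as np

def grad_empirical_loss(theta, X, y):
    """Gradient of L_S(theta) = (1/n) * ||y - X theta||^2."""
    return -2.0 * X.T @ (y - X @ theta) / len(y)

def sam_step(theta, X, y, lr=0.1, rho=0.05, wd=1e-4):
    # Ascent step: first-order approximation of the worst-case perturbation (p = 2).
    g = grad_empirical_loss(theta, X, y)
    eps_hat = rho * g / (np.linalg.norm(g) + 1e-12)
    # Descent step: gradient taken at the perturbed point, plus L2 regularization (wd).
    g_sam = grad_empirical_loss(theta + eps_hat, X, y)
    return theta - lr * (g_sam + wd * theta)

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.normal(size=100)

theta = np.zeros(3)
for _ in range(200):
    theta = sam_step(theta, X, y)
print("fitted coefficients:", np.round(theta, 3))
```

Note that each SAM step costs roughly two gradient evaluations (one at $\theta$, one at $\theta + \hat{\epsilon}$) compared with one for plain gradient descent.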
References:
- Foret, P., Kleiner, A., Mobahi, H., and Neyshabur, B. (2020). Sharpness-Aware Minimization for Efficiently Improving Generalization.