What is Sharpness-Aware Minimization (SAM)?

Introduction

Consider the general machine learning set-up. We have a class of models parameterized by w \in \mathcal{W} \subseteq \mathbb{R}^d (e.g. for linear regression, w would be the coefficients of the model). Each of these models takes in some input x \in \mathcal{X} and outputs a result y \in \mathcal{Y}. We want to select the parameter w which minimizes some population loss:

\begin{aligned} L_P(w) = \mathbb{E}_{(x,y) \sim D}[l(w, x, y)], \end{aligned}

where l is the loss incurred for a single data point, and D is the population distribution for our data. Unfortunately we don’t know D and hence can’t evaluate L_P(w) exactly. Instead, we often have a dataset (x_1, y_1), \dots, (x_n, y_n) with which we can define the empirical loss

\begin{aligned} L_S(w) = \frac{1}{n}\sum_{i=1}^n l(w, x_i, y_i). \end{aligned}

A viable approach is to select w by minimizing the empirical loss: the hope here is that the empirical loss is a good approximation to the population loss.
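To make the notation concrete, here is a minimal sketch of empirical loss minimization for linear regression with squared-error loss. The data-generating process, the names `empirical_loss`, `w_true` and `w_hat`, and the sample size are all assumptions made up for illustration; nothing here comes from the SAM paper.

```python
import numpy as np

def empirical_loss(w, X, y):
    """L_S(w): average of the per-example squared-error loss l(w, x_i, y_i)."""
    residuals = X @ w - y
    return np.mean(residuals ** 2)

# Hypothetical dataset: n points drawn from a linear model with Gaussian noise.
rng = np.random.default_rng(0)
n, d = 200, 3
w_true = np.array([1.0, -2.0, 0.5])  # assumed "population" coefficients
X = rng.normal(size=(n, d))
y = X @ w_true + 0.1 * rng.normal(size=n)

# For squared loss, the minimizer of L_S is the ordinary least squares solution.
w_hat, *_ = np.linalg.lstsq(X, y, rcond=None)

print("empirical loss at w_hat: ", empirical_loss(w_hat, X, y))
print("empirical loss at w_true:", empirical_loss(w_true, X, y))
```

In this convex example the empirical minimizer lands close to the population coefficients. The difficulties discussed next arise when the loss is non-convex.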

For many modern ML models (especially overparameterized models), the loss functions are non-convex with multiple local and global minima. In this setup, it’s known that the parameter \hat{w} obtained by minimizing the empirical loss does not necessarily have small population loss. That is, good performance on the training set does not necessarily “generalize” to unseen data. In fact, many different w can give similar values of L_S(w) but very different values of L_P(w).

“Flatness” and generalization performance

One thing that is emerging from the literature is a connection between the “flatness” of minima and generalization performance. In particular, models corresponding to a minimum whose surrounding loss neighborhood is relatively “flat” tend to have better generalization performance. The intuition is as follows: think of the empirical loss L_S as a random function, with the randomness coming from which data points are chosen for the sample. As we draw several of these empirical loss functions, minima associated with flat areas of the function tend to stay in the same area (and hence are “robust”), while minima associated with sharp areas move around a lot.

Here is a stylized example to illustrate the intuition. Imagine that the line in black is the population loss. There are 10 blue lines in the figure, each representing a possible empirical loss function. You can see that there is a lot of variation in the part of the loss associated with the sharper minimum on the left as compared to the flatter minimum on the right.

Imagine that for each of the 10 empirical loss functions, we locate the x values of the two minima, but record their values on the true population loss. The points in red correspond to the population loss for the sharper minimum while the points in blue correspond to that for the flatter minimum. We can see that the loss values for the blue points don’t fluctuate as much as those for the red points.

Sharpness-aware minimization (SAM)

There are many ways to define “flatness” or “sharpness”. Sharpness-aware minimization (SAM), introduced by Foret et al. (2020) (Reference 1), is one way to formalize the notion of sharpness and use it in model training. Instead of finding parameter values which minimize the empirical loss at a point:

\begin{aligned} \hat{w} = \text{argmin}_{w \in \mathcal{W}} L_S(w), \end{aligned}

find parameter values whose entire neighborhoods have uniformly small empirical loss:

\begin{aligned} \hat{w} = \text{argmin}_{w \in \mathcal{W}} L_S^{SAM}(w) = \text{argmin}_{w \in \mathcal{W}} \left\{ \max_{\| \epsilon \|_p \leq \rho} L_S(w + \epsilon) \right\}, \end{aligned}

where \rho \geq 0 is a hyperparameter and \| \cdot \|_p is the L_p norm, with p typically chosen to be 2.

The figure below shows what L_S^{SAM} looks like for different values of \rho in our simple example. As \rho increases, the value of L_S^{SAM} increases a lot for the sharp minimum on the left but not very much for the flat minimum on the right.
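The same effect can be reproduced numerically. Below is a rough sketch that evaluates L_S^{SAM} by brute force on a made-up one-dimensional loss with a sharp minimum at w = -1 and a flat minimum at w = 2. The function `loss`, its curvature constants, and the grid-based `sam_loss` are assumptions for illustration only (this is not the loss used in the figure, and not how SAM is computed in practice). In one dimension every L_p norm is just the absolute value, so the constraint is simply |\epsilon| \leq \rho.

```python
import numpy as np

def loss(w):
    """Hypothetical 1-D loss: sharp minimum at w = -1, flat minimum at w = 2."""
    sharp = 25.0 * (w + 1.0) ** 2   # high curvature
    flat = 0.5 * (w - 2.0) ** 2     # low curvature
    return np.minimum(sharp, flat)  # both minima have loss value 0

def sam_loss(w, rho, num=201):
    """Brute-force L^SAM(w): the largest loss over a grid of eps in [-rho, rho]."""
    eps = np.linspace(-rho, rho, num)
    return np.max(loss(w + eps))

for rho in (0.0, 0.1, 0.2):
    print(f"rho={rho}: sharp minimum {sam_loss(-1.0, rho):.4f}, "
          f"flat minimum {sam_loss(2.0, rho):.4f}")
```

Both minima have the same plain loss, but the SAM loss at the sharp minimum grows much faster with \rho, so minimizing the SAM loss prefers the flat one.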

In practice we don’t minimize L_S^{SAM} on its own, but rather the SAM loss plus an L2 regularization term, i.e. L_S^{SAM}(w) + \lambda \| w \|_2^2. Also, L_S^{SAM} can’t be computed analytically, because the inner maximization has no closed form for general models: the paper gets around this by approximating the inner maximization with a first-order Taylor expansion. One final note is on what value \rho should take. In the paper, \rho is treated as a hyperparameter which needs to be tuned (e.g. with grid search).
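For p = 2, the paper's first-order approximation of the inner maximization gives \hat{\epsilon} = \rho \nabla L_S(w) / \| \nabla L_S(w) \|_2, and the SAM gradient is then approximated by the ordinary gradient evaluated at w + \hat{\epsilon}. Here is a minimal sketch of one such update on a toy quadratic loss. The function name `sam_step`, the matrix `A`, the learning rate and the small constant in the denominator are my own assumptions for illustration, and the L2 regularization term is left out to keep things short.

```python
import numpy as np

def sam_step(w, grad_fn, rho=0.05, lr=0.1):
    """One approximate SAM update for p = 2 (a sketch, not the paper's code)."""
    g = grad_fn(w)                                # gradient at the current point
    eps = rho * g / (np.linalg.norm(g) + 1e-12)   # approximate worst-case perturbation
    g_sam = grad_fn(w + eps)                      # gradient at the perturbed point
    return w - lr * g_sam                         # plain gradient step using g_sam

# Toy loss L(w) = 0.5 * w^T A w, whose gradient is A w (assumed for illustration).
A = np.diag([10.0, 1.0])

def loss_fn(w):
    return 0.5 * w @ A @ w

def grad_fn(w):
    return A @ w

w = np.array([1.0, 1.0])
print("initial loss:", loss_fn(w))
for _ in range(100):
    w = sam_step(w, grad_fn)
print("final loss:  ", loss_fn(w))
```

Note that each update needs two gradient evaluations, one at w and one at w + \hat{\epsilon}, so a SAM step costs roughly twice as much as a plain gradient step.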

References:

  1. Foret, P., et al. (2020). Sharpness-Aware Minimization for Efficiently Improving Generalization.