Introduction
In a previous post, we discussed the intuition behind sharpness-aware minimization (SAM). In this post, we describe how SAM is implemented.
Consider the general supervised learning set-up, and assume that we have a class of models parameterized by $w \in \mathbb{R}^d$. If we have a dataset $S = \{ (x_1, y_1), \dots, (x_n, y_n) \}$, a common approach is to select $w$ by minimizing the empirical loss

$$L_S(w) = \frac{1}{n} \sum_{i=1}^n \ell(w; x_i, y_i),$$

where $\ell(w; x_i, y_i)$ is the loss incurred for a single data point. Instead of minimizing the empirical loss, sharpness-aware minimization (SAM) (Foret et al. 2020, Reference 1) solves the problem

$$\min_w \; L_S^{SAM}(w) + \lambda \| w \|_2^2, \qquad L_S^{SAM}(w) := \max_{\| \epsilon \|_p \le \rho} L_S(w + \epsilon),$$

where $\rho \ge 0$ and $\lambda \ge 0$ are hyperparameters and $\| \cdot \|_p$ is the $L^p$-norm, with $p$ typically chosen to be 2.
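To make the empirical loss concrete, here is a minimal NumPy sketch for a linear model with squared loss; the model and loss choices are illustrative, not prescribed by the post.

```python
import numpy as np

def empirical_loss(w, X, y):
    """L_S(w) = (1/n) * sum_i loss(w; x_i, y_i), here with the illustrative
    choice loss(w; x, y) = (x @ w - y)**2 for a linear model."""
    preds = X @ w
    return np.mean((preds - y) ** 2)

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))     # n = 100 data points, d = 3 parameters
w_true = np.array([1.0, -2.0, 0.5])
y = X @ w_true                    # noiseless targets for simplicity
print(empirical_loss(w_true, X, y))   # 0.0 at the true parameters
```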
We could try to solve this problem using first-order methods, e.g. gradient descent:

$$w_{t+1} = w_t - \eta \nabla_w \left[ L_S^{SAM}(w_t) + \lambda \| w_t \|_2^2 \right]$$

for some step size $\eta > 0$. However, we need some tricks to approximate $\nabla_w L_S^{SAM}(w)$ efficiently.
Simplification 1: Taylor expansion
The main difficulty in computing $\nabla_w L_S^{SAM}(w)$ is that we need to solve the maximization problem before taking a gradient w.r.t. $w$. Let's approximate the objective function in the maximization problem with a first-order Taylor expansion. If we define

$$\epsilon^*(w) := \arg\max_{\| \epsilon \|_p \le \rho} L_S(w + \epsilon),$$

then

$$\epsilon^*(w) \approx \arg\max_{\| \epsilon \|_p \le \rho} \left[ L_S(w) + \epsilon^\top \nabla_w L_S(w) \right] = \arg\max_{\| \epsilon \|_p \le \rho} \epsilon^\top \nabla_w L_S(w).$$

Defining the last quantity as $\hat{\epsilon}(w)$, we recognize it as the solution to the classical dual norm problem:

$$\hat{\epsilon}(w) = \rho \, \mathrm{sign}\left( \nabla_w L_S(w) \right) \frac{\left| \nabla_w L_S(w) \right|^{q-1}}{\left( \| \nabla_w L_S(w) \|_q^q \right)^{1/p}},$$

where $1/p + 1/q = 1$ (the sign, absolute value, and power are applied elementwise). With this expression,

$$\nabla_w L_S^{SAM}(w) \approx \nabla_w L_S(w + \hat{\epsilon}(w)) = \frac{d(w + \hat{\epsilon}(w))}{dw} \, \nabla_w L_S(w) \Big|_{w + \hat{\epsilon}(w)}.$$
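For the common choice $p = 2$ (so $q = 2$), the dual-norm solution reduces to $\hat{\epsilon}(w) = \rho \, \nabla_w L_S(w) / \| \nabla_w L_S(w) \|_2$. A minimal sketch of this special case, with function names of my own choosing:

```python
import numpy as np

def epsilon_hat(grad, rho):
    """Dual-norm solution for p = 2: eps_hat = rho * g / ||g||_2,
    where g is the gradient of L_S at w."""
    norm = np.linalg.norm(grad)
    if norm == 0.0:            # at a stationary point, any direction is optimal
        return np.zeros_like(grad)
    return rho * grad / norm

g = np.array([3.0, 4.0])
eps = epsilon_hat(g, rho=0.05)
print(eps)                      # rho * [0.6, 0.8], i.e. ~[0.03, 0.04]
print(np.linalg.norm(eps))      # ~0.05: the perturbation lies on the rho-sphere
```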
Simplification 2: Drop second-order terms
Since $\hat{\epsilon}(w)$ is a function of $w$, computing

$$\frac{d(w + \hat{\epsilon}(w))}{dw} = I + \frac{d \hat{\epsilon}(w)}{dw}$$

implicitly depends on the Hessian of $L_S$, because $\hat{\epsilon}(w)$ is itself defined through $\nabla_w L_S(w)$. While Reference 1 notes that such a computation can be done, we can just drop this term to accelerate computation. This gives us our final gradient approximation, which we use in practice:

$$\nabla_w L_S^{SAM}(w) \approx \nabla_w L_S(w) \Big|_{w + \hat{\epsilon}(w)}.$$

In other words, the gradient of the SAM loss at $w$ is approximately the gradient of the empirical loss at $w + \hat{\epsilon}(w)$, a point that is distance $\rho$ away from $w$ (in $L^p$-norm) in a specially chosen direction.
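Putting the pieces together, a single SAM update for $p = 2$ can be sketched as follows. This is a minimal NumPy illustration with hypothetical helper names, not a production implementation (real implementations typically wrap an existing optimizer and reuse one forward/backward pass per gradient).

```python
import numpy as np

def sam_step(w, grad_fn, rho, lr, lam=0.0):
    """One SAM update (p = 2), following the approximation above:
    1) g = grad of L_S at w; 2) eps = rho * g / ||g||_2;
    3) descend along the gradient of L_S evaluated at w + eps.
    grad_fn returns the empirical-loss gradient; lam weights the L2 penalty."""
    g = grad_fn(w)
    norm = np.linalg.norm(g)
    eps = rho * g / norm if norm > 0 else np.zeros_like(w)
    g_sam = grad_fn(w + eps)                 # gradient at the perturbed point
    return w - lr * (g_sam + 2.0 * lam * w)  # add gradient of lam * ||w||^2

# Illustrative quadratic loss L_S(w) = 0.5 * ||w||^2, whose gradient is w.
grad_fn = lambda w: w
w = np.array([1.0, 1.0])
for _ in range(50):
    w = sam_step(w, grad_fn, rho=0.05, lr=0.1)
print(np.linalg.norm(w))  # shrinks toward the minimum at the origin
```

Note that each step calls `grad_fn` twice: once at $w$ to build the perturbation, and once at $w + \hat{\epsilon}(w)$ for the actual update, so SAM roughly doubles the per-step gradient cost.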
References:
- Foret, P., et al. (2020). Sharpness-Aware Minimization for Efficiently Improving Generalization.