What is elastic weight consolidation?

Elastic weight consolidation (EWC) is a technique proposed by Kirkpatrick et al. (2016) (Reference 1) to avoid catastrophic forgetting in neural networks.

Set-up and notation

Imagine that we want to train a neural network to be able to perform well on several different tasks. In real-world settings, it’s not always possible to have all the data from all tasks available at the beginning of model training, so we want our model to be able to learn continually: “that is, [have] the ability to learn consecutive tasks without forgetting how to perform previously trained tasks.”

This turns out to be difficult for neural networks because of a phenomenon known as catastrophic forgetting. Assume that we have already trained a model to be good at one task, task A. When we further train this model to be good at a second task, task B, performance on task A tends to degrade, with the drop often happening abruptly (hence “catastrophic”). This happens because “the weights in the network that are important for task A are changed to meet the objectives of task B”.

Let’s rewrite the previous paragraph with some mathematical notation. Assume that we have training data \mathcal{D}_A and \mathcal{D}_B for tasks A and B respectively. Let \mathcal{D} = \mathcal{D}_A \cup \mathcal{D}_B. Our neural network is parameterized by \theta: model training corresponds to finding an “optimal” value of \theta.

Let’s say that training on task A gives us the network parameters \theta_A^\star. Catastrophic forgetting means that further training on task B moves the parameters to a new optimum \theta_B^\star that is too far away from \theta_A^\star, and hence performs poorly on task A.
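To see the phenomenon concretely, here is a minimal sketch in PyTorch (my own toy example, not from the paper: two synthetic binary classification tasks and a small MLP, all arbitrary choices) in which plain sequential training on task B erodes task A performance:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def make_task(n=500, d=20):
    # A random linearly-separable binary classification task.
    X = torch.randn(n, d)
    w = torch.randn(d)
    y = (X @ w > 0).long()
    return X, y

def train(model, X, y, epochs=200, lr=1e-2):
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(epochs):
        opt.zero_grad()
        nn.functional.cross_entropy(model(X), y).backward()
        opt.step()

def accuracy(model, X, y):
    return (model(X).argmax(dim=1) == y).float().mean().item()

model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 2))
X_A, y_A = make_task()
X_B, y_B = make_task()

train(model, X_A, y_A)
print("task A accuracy after training on A:", accuracy(model, X_A, y_A))

train(model, X_B, y_B)  # plain fine-tuning on task B, no constraint
print("task A accuracy after training on B:", accuracy(model, X_A, y_A))  # typically drops
print("task B accuracy after training on B:", accuracy(model, X_B, y_B))
```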

Elastic weight consolidation (EWC)

The intuition behind elastic weight consolidation (EWC) is the following: since large neural networks tend to be over-parameterized, it is likely that for any one task, there are many values of \theta that will give similar performance. Hence, when further training on task B, we should constrain the model parameters to be “close” to \theta_A^\star in some sense. Figure 1 of Reference 1 illustrates this intuition.

How do we make this intuition concrete? Assume that we are in a Bayesian setting. What we would like to do is maximize the (log) posterior probability of \theta given all the data \mathcal{D}:

\begin{aligned} \log p(\theta | \mathcal{D}) = \log p(\mathcal{D} | \theta) + \log p(\theta) - \log p(\mathcal{D}). \end{aligned}

Note that the equation above also applies if we replace all instances of \mathcal{D} with \mathcal{D}_A or \mathcal{D}_B. Assuming the data for tasks A and B are independent (so that p(\mathcal{D} | \theta) = p(\mathcal{D}_A | \theta) p(\mathcal{D}_B | \theta) and p(\mathcal{D}) = p(\mathcal{D}_A) p(\mathcal{D}_B)), let’s rewrite the equation above, applying Bayes’ rule to task A’s data in the second equality:

\begin{aligned} \log p(\theta | \mathcal{D}) &= \log p(\mathcal{D}_A | \theta) + \log p(\mathcal{D}_B | \theta) + \log p(\theta) - \log p(\mathcal{D}_A) - \log p(\mathcal{D}_B) \\  &= \log p(\mathcal{D}_B | \theta) + \log p(\theta) - \log p(\mathcal{D}_A) - \log p(\mathcal{D}_B) \\  &\qquad + [\log p(\theta | \mathcal{D}_A) - \log p(\theta) + \log p (\mathcal{D}_A) ] \\  &= \log p(\mathcal{D}_B | \theta) + \log p(\theta | \mathcal{D}_A) - \log p(\mathcal{D}_B). \end{aligned}

Thus, maximizing the log posterior probability of \theta is equivalent to maximizing

\log p(\mathcal{D}_B | \theta) + \log p(\theta | \mathcal{D}_A). \qquad (1)

Assume that we have already trained our model in some way to get to \theta_A^\star, which does well on task A. We recognize the first term above as the log-likelihood for task B’s training data: maximizing the first term alone corresponds to the usual maximum likelihood estimation (MLE). For the second term, we can approximate the posterior distribution of \theta given task A’s data by a Gaussian distribution centered at \theta_A^\star with some covariance (this is known as the Laplace approximation). Simplifying even further: instead of taking the precision matrix of this Gaussian to be the full Fisher information matrix F, we assume that the precision matrix is diagonal, with entries equal to the diagonal of F.
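In code, the diagonal of the Fisher information at \theta_A^\star can be estimated as the average squared gradient of the per-example log-likelihood, with labels drawn from the model’s own predictive distribution. Below is a minimal sketch for a classifier, continuing the toy example above (the helper name fisher_diagonal and the per-example loop are my own illustrative choices, not the paper’s implementation):

```python
import torch

def fisher_diagonal(model, X):
    # Accumulate squared gradients of the per-example log-likelihood,
    # sampling each label from the model's own predictive distribution.
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    model.eval()
    for x in X:
        logits = model(x.unsqueeze(0))
        y = torch.distributions.Categorical(logits=logits).sample().item()
        loglik = torch.log_softmax(logits, dim=1)[0, y]
        model.zero_grad()
        loglik.backward()
        for n, p in model.named_parameters():
            fisher[n] += p.grad.detach() ** 2
    return {n: f / len(X) for n, f in fisher.items()}
```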

Putting it all together, maximizing (1) is equivalent to minimizing

\mathcal{L}(\theta) = \mathcal{L}_B (\theta) + \sum_i \frac{\lambda}{2} F_i (\theta_i - \theta_{A,i}^\star)^2,

where \mathcal{L}_B(\theta) is the loss for task B only, i indexes the weights, and \lambda \geq 0 is a hyperparameter that balances the importance of the two tasks. EWC minimizes \mathcal{L}(\theta).
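With the Fisher diagonal and \theta_A^\star in hand, this objective translates almost line-for-line into code. The sketch below continues the toy example, replacing the plain fine-tuning step from the first sketch (ewc_loss is a hypothetical helper, and \lambda = 100 is an arbitrary choice):

```python
import torch
import torch.nn as nn

def ewc_loss(model, X_B, y_B, fisher, theta_A_star, lam):
    # Task-B loss plus the quadratic EWC penalty from the equation above.
    loss_B = nn.functional.cross_entropy(model(X_B), y_B)
    penalty = sum((fisher[n] * (p - theta_A_star[n]) ** 2).sum()
                  for n, p in model.named_parameters())
    return loss_B + (lam / 2.0) * penalty

# Snapshot theta_A^* and the Fisher diagonal right after training on task A,
# then train on task B against the combined loss.
theta_A_star = {n: p.detach().clone() for n, p in model.named_parameters()}
fisher = fisher_diagonal(model, X_A)

opt = torch.optim.Adam(model.parameters(), lr=1e-2)
model.train()
for _ in range(200):
    opt.zero_grad()
    ewc_loss(model, X_B, y_B, fisher, theta_A_star, lam=100.0).backward()
    opt.step()
```

With a suitably large \lambda, task A accuracy after this loop should stay close to its original value, at some cost to task B accuracy; \lambda = 0 recovers plain fine-tuning.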

Why the name “elastic weight consolidation”? From the authors:

“This constraint is implemented as a quadratic penalty, and can therefore be imagined as a spring anchoring the parameters to the previous solution, hence the name elastic. Importantly, the stiffness of this spring should not be the same for all parameters; rather, it should be greater for those parameters that matter most to the performance during task A.”

References:

  1. Kirkpatrick, J., et al. (2016). Overcoming catastrophic forgetting in neural networks. arXiv:1612.00796.
