Using the Laplace method to approximate the posterior distribution of neural network parameters and outputs

In this previous post, we described how the Laplace method (or Laplace approximation) allows us to approximate a probability distribution with a Gaussian distribution with suitably chosen parameters. In this post, we show how it can be used to estimate the posterior distribution of the parameters and output of a neural network (NN), or really any prediction function. This post generally follows the exposition in Section 5.7.1 of Bishop (2006) (Reference 1).

Set-up

Assume that we are in the supervised learning setting with i.i.d. data \mathcal{D} = \{ x_i, y_i \}_{i=1}^n, where x_i \in \mathbb{R}^p and y_i \in \mathbb{R}. We have a neural network model f(\cdot ; w) which takes in an input vector x and outputs a prediction for y. We would like to use the Bayesian framework to learn w. To do that, we need to specify a prior distribution for the weights w and a likelihood function for the data, and then we can turn the proverbial Bayesian crank to get posterior distributions.

Assume the prior distribution of the network weights is normal with some mean \mu_w and precision matrix \Lambda (inverse of covariance matrix):

\begin{aligned} p(w) = \mathcal{N}(w | \mu_w, \Lambda^{-1}) \end{aligned}

Assume that the conditional distribution of the target variable y given x is normal as well, with mean at the neural network output and some precision parameter \beta \in \mathbb{R}:

\begin{aligned} p(y \mid x, w) = \mathcal{N}(y | f(x; w), \beta^{-1}). \end{aligned}

Since our data is i.i.d., if we assume the x_i's are fixed, our likelihood function is

\begin{aligned} p(\mathcal{D} | w) = \prod_{i=1}^n \mathcal{N}(y_i | f(x_i ; w), \beta^{-1}). \end{aligned}

Posterior distribution of NN parameters

We now have all the ingredients we need for Bayesian inference. The posterior distribution is proportional to the product of the prior and the likelihood:

\begin{aligned} p(w \mid \mathcal{D}) \propto p(\mathcal{D}|w) p(w). \end{aligned}

Taking logarithms on both sides:

\begin{aligned} \log p(w \mid \mathcal{D}) &= -\frac{1}{2}(w-\mu_w)^\top \Lambda (w - \mu_w) - \frac{\beta}{2}\sum_{i=1}^n \left[ f(x_i; w) - y_i \right]^2 + C \end{aligned}

for some constant C. If f(x_i; w) were linear in w, the RHS would be a quadratic expression in w, which would imply that p(w \mid \mathcal{D}) has a Gaussian distribution with parameters we can derive in closed form. Unfortunately, for NNs and many other prediction models, f(x_i; w) is not linear in w.
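As a concrete (and purely illustrative) example, here is a minimal sketch of this log posterior for a tiny one-hidden-layer network on synthetic data, written in Python with JAX so that gradients and Hessians are easy to obtain later. The architecture, data, and the values of \beta and the prior precision are my own choices, not from the post or from Bishop.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
n, p, H = 50, 2, 8                  # n observations, p inputs, H hidden units
beta = 25.0                         # noise precision (assumed known)
prior_prec = 1.0                    # prior precision Lambda = prior_prec * I, with mu_w = 0

# Synthetic regression data.
kx, ky, kw = jax.random.split(key, 3)
X = jax.random.normal(kx, (n, p))
y = jnp.sin(X[:, 0]) + 0.2 * jax.random.normal(ky, (n,))

d = H * p + H + H + 1               # total number of weights, flattened into a vector w

def f(x, w):
    """One-hidden-layer network f(x; w) with tanh activation."""
    W1 = w[:H * p].reshape(H, p)
    b1 = w[H * p:H * p + H]
    W2 = w[H * p + H:H * p + 2 * H]
    b2 = w[-1]
    return W2 @ jnp.tanh(W1 @ x + b1) + b2

def log_posterior(w):
    """log p(w | D), up to an additive constant."""
    log_prior = -0.5 * prior_prec * jnp.sum(w ** 2)
    resid = jax.vmap(lambda x: f(x, w))(X) - y
    log_lik = -0.5 * beta * jnp.sum(resid ** 2)
    return log_prior + log_lik
```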

Since the log posterior is not quadratic in w, we instead approximate the posterior with the Laplace approximation. The details of the Laplace approximation are in this previous post; we omit them here. If we let w_{MAP} be the mode of the posterior p(w \mid \mathcal{D}), then we can approximate the posterior as

\begin{aligned} p(w \mid \mathcal{D}) \approx \mathcal{N}(w | w_{MAP}, A^{-1}), \end{aligned}

where A is the Hessian of the negative log posterior, evaluated at w_{MAP}:

\begin{aligned} A &= \left[ - \nabla^2 [\log p(w \mid \mathcal{D})] \right]_{w = w_{MAP}} \\  &= \Lambda + \frac{\beta}{2} \left[ \sum_{i=1}^n \nabla^2 \left[ f(x_i; w) - y_i \right]^2 \right]_{w = w_{MAP}}. \end{aligned}
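Continuing the sketch above (it reuses log_posterior, d and kw from there), one way to carry this out is to maximize the log posterior with plain gradient ascent to get w_{MAP}, then take the Hessian of the negative log posterior with automatic differentiation. In practice a damped or Gauss-Newton approximation of the Hessian is often used to guarantee positive definiteness; this sketch just uses the exact Hessian.

```python
# Step 1: find w_MAP by gradient ascent on log p(w | D).
grad_lp = jax.jit(jax.grad(log_posterior))
w = 0.1 * jax.random.normal(kw, (d,))
for _ in range(10_000):                          # any optimizer would do here
    w = w + 1e-3 * grad_lp(w)
w_map = w

# Step 2: posterior precision A = -[Hessian of log p(w | D)] at w_MAP.
A_post = -jax.hessian(log_posterior)(w_map)      # (d, d) matrix
cov_post = jnp.linalg.inv(A_post)                # Laplace posterior covariance
print(jnp.diag(cov_post)[:5])
```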

Posterior distribution of target variable

Once we have the posterior distribution of the parameters, we can get the posterior predictive distribution of the target by marginalizing out the parameters:

\begin{aligned} p(y \mid x, \mathcal{D}) &= \int p(y \mid x, w) p(w \mid \mathcal{D}) dw. \end{aligned}

If the mean of y \mid x were linear in w, then this would be a linear Gaussian model, and we could use standard results for that model to obtain the posterior predictive distribution (see this previous post). Unfortunately, for NNs this is not the case.

Again we turn to a Taylor approximation to move forward. Expanding f(x; w) around w_{MAP} and assuming we can ignore quadratic and higher-order terms,

\begin{aligned} f(x; w) &\approx f(x; w_{MAP}) + g^\top (w -w_{MAP}), \\  g &= \left[ \nabla_w f(x; w) \right]_{w = w_{MAP}}. \end{aligned}

With this approximation, we have a linear Gaussian model:

\begin{aligned} p(w \mid \mathcal{D}) &= \mathcal{N}(w | w_{MAP}, A^{-1}), \\  p(y \mid x, w) &= \mathcal{N} \left( y \mid  f(x; w_{MAP}) + g^\top (w -w_{MAP}), \beta^{-1} \right). \end{aligned}

Applying the result for the linear Gaussian model,

\begin{aligned} p(y \mid x, \mathcal{D}) &= \mathcal{N}\left( y \mid f(x; w_{MAP}) + g^\top (w_{MAP} -w_{MAP}), \beta^{-1} + g^\top A^{-1} g \right) \\  &= \mathcal{N}\left( y \mid f(x; w_{MAP}), \beta^{-1} + g^\top A^{-1} g \right). \end{aligned}
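Continuing the same sketch (reusing f, w_map, A_post and beta from above), the predictive mean and variance at a new input follow directly from the last equation:

```python
x_new = jnp.array([0.3, -1.2])

g = jax.grad(lambda w: f(x_new, w))(w_map)               # gradient of f w.r.t. the weights at w_MAP
pred_mean = f(x_new, w_map)                              # f(x; w_MAP)
pred_var = 1.0 / beta + g @ jnp.linalg.solve(A_post, g)  # beta^{-1} + g^T A^{-1} g
print(float(pred_mean), float(pred_var))
```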

References:

  1. Bishop, C. M. (2006). Pattern Recognition and Machine Learning (Section 5.7.1).

Conditional distributions for the linear Gaussian model

A linear Gaussian model is a model where (i) the input variables are jointly Gaussian, and (ii) conditional on the inputs, the output variables are jointly Gaussian with means that are linear combinations of the input variables (and possibly a bias term). Linear Gaussian models are popular because several key quantities end up having Gaussian distributions with parameters that we can compute in closed form. Here is the main theorem:

Theorem. Assume that

\begin{aligned} p(x) &= \mathcal{N}\left( x \mid \mu, \Lambda^{-1} \right) , \\  p(y \mid x) &= \mathcal{N} \left( y \mid Ax + b, L^{-1} \right). \end{aligned}

Then

\begin{aligned} p(y) &= \mathcal{N}\left( y \mid A\mu + b, L^{-1} + A \Lambda^{-1}A^\top \right) , \\  p(x \mid y) &= \mathcal{N} \left( x \mid \Sigma \left[ A^\top L (y - b) + \Lambda \mu \right], \Sigma \right), \end{aligned}

where

\begin{aligned} \Sigma = \left( \Lambda + A^\top L A \right)^{-1}. \end{aligned}
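As a quick sanity check of the marginal result for p(y) (my own sketch, with arbitrary choices of \mu, \Lambda, A, b and L), we can sample from the model and compare the empirical mean and covariance of y with A\mu + b and L^{-1} + A \Lambda^{-1} A^\top:

```python
import numpy as np

rng = np.random.default_rng(0)
mu = np.array([1.0, -2.0])
Lam = np.array([[2.0, 0.5], [0.5, 1.0]])            # precision of x
A = np.array([[1.0, -1.0], [0.5, 2.0], [0.0, 1.0]])
b = np.array([0.0, 1.0, -1.0])
L = np.diag([4.0, 4.0, 4.0])                        # precision of y given x

n = 200_000
x = rng.multivariate_normal(mu, np.linalg.inv(Lam), size=n)
y = x @ A.T + b + rng.multivariate_normal(np.zeros(3), np.linalg.inv(L), size=n)

print(y.mean(axis=0), A @ mu + b)                   # empirical vs theoretical mean
print(np.cov(y, rowvar=False))                      # empirical covariance of y
print(np.linalg.inv(L) + A @ np.linalg.inv(Lam) @ A.T)
```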

The details of the proof are a bit involved but the overall idea is quite simple. If you compute the log of the joint distribution p(x, y) = p(x) p(y\mid x), you will find that \log p(x, y) is a quadratic expression in (x, y), which implies that (x, y) has a Gaussian distribution. Some algebra on this expression, along with the block matrix inversion formula, gives the mean and covariance matrix of the joint distribution. We can then pick out the relevant terms for the marginal distribution of y.

The result for the conditional distribution x\mid y follows directly from the result in this previous post.

See Section 2.3.3 of Bishop (2006) (Reference 1) for the gory details.

References:

  1. Bishop, C. M. (2006). Pattern Recognition and Machine Learning (Section 2.3.3).

Using the Laplace approximation to approximate a probability distribution

The Laplace approximation is a method for approximating a probability distribution p over \mathbb{R}^d with a Gaussian distribution \hat{p}. The approximation consists of two steps: choosing the mean \mu and the covariance matrix \Sigma of the approximating Gaussian distribution.

Step 1: Estimate the mean

The Laplace approximation places the mean of the Gaussian distribution \mu at the mode of p, i.e. the value z_0 which maximizes p. The mode is usually determined using standard optimization algorithms like stochastic gradient descent (SGD).

Step 2: Estimate the variance

Having located the mean of the Gaussian distribution at z_0, we now choose the covariance matrix \Sigma so that the curvature of \hat{p} at z_0 matches that of the true distribution p. Recall that the log probability density function (PDF) of a Gaussian distribution can be written as

\begin{aligned} \log \hat{p}(z) &= C - \frac{1}{2}(z-z_0)^\top \Sigma^{-1}(z-z_0) \end{aligned}

for some constant C. (Here we have assumed that the mean of the distribution is at z_0.) If we let \Lambda be the precision matrix, \Lambda = \Sigma^{-1}, then the above becomes

\begin{aligned} \log \hat{p}(z) &= C - \frac{1}{2}(z-z_0)^\top \Lambda(z-z_0). \end{aligned}

Taking a Taylor expansion of the log PDF of the true distribution around z = z_0 and keeping terms up to second order:

\begin{aligned} \log p(z) &\approx \log p(z_0) - \frac{1}{2}(z-z_0)^\top A (z-z_0), \\  A &= \left[ - \nabla^2 [\log p(z)] \right]_{z = z_0}. \end{aligned}

There is no linear term because [\nabla \log p(z)]_{z = z_0} = 0, since z_0 is a mode of p. Comparing the two equations above, we see that setting \Lambda = A makes the two expressions the same. This is exactly what the Laplace approximation does: it sets the precision matrix of the approximate Gaussian distribution to the Hessian of the negative log PDF of the true distribution, evaluated at the mode of the distribution.
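As a one-dimensional illustration (my own, not from Bishop), the sketch below applies the two steps to the unnormalized log density of a Gamma(a, b) distribution: the mode is (a-1)/b for a > 1, and the Gaussian precision is the second derivative of the negative log density at the mode.

```python
import jax
import jax.numpy as jnp

a, b = 5.0, 1.0                          # shape and rate of the true density

def log_p(z):                            # log density, up to an additive constant
    return (a - 1.0) * jnp.log(z) - b * z

z0 = (a - 1.0) / b                       # mode (here available in closed form)
prec = -jax.grad(jax.grad(log_p))(z0)    # A = -(d^2/dz^2) log p(z) at z = z0
sigma2 = 1.0 / prec                      # variance of the Laplace approximation

print(z0, sigma2)                        # \hat{p} is N(z0, sigma2)
```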

(Note: I previously wrote a post on using the Laplace approximation as a way to approximate the denominator of a posterior distribution in Bayesian inference. The method is basically the same as that in this post, but I find the use case in this post more prevalent.)

References:

  1. Bishop, C. M. (2006). Pattern Recognition and Machine Learning (Section 4.4).

KL divergence between two normal distributions

In a recent blog post, John Cook gave an explicit formula for the KL divergence between two 1-dimensional normal distributions. It turns out that there is a similar formula for the KL divergence between two multivariate normal distributions: this post states and proves the formula.

Let P be the normal distribution \mathcal{N}(\mu_1, \Sigma_1) on \mathbb{R}^d and let Q be the normal distribution \mathcal{N}(\mu_2, \Sigma_2) on \mathbb{R}^d. Then the KL divergence D_{KL}(P || Q) is

\begin{aligned} D_{KL}(P || Q) &= \frac{1}{2}\left[ \log \frac{|\Sigma_2|}{|\Sigma_1|} - d + \text{tr}\left( \Sigma_2^{-1} \Sigma_1 \right) + (\mu_2 - \mu_1)^\top \Sigma_2^{-1} (\mu_2 - \mu_1) \right]. \end{aligned}

Here’s the proof. If p and q denote the PDF of P and Q respectively, then

\begin{aligned} D_{KL}(P || Q) &= \int_{\mathbb{R}^d} p(x) \log \dfrac{p(x)}{q(x)} dx \\  &= \int_{\mathbb{R}^d}  \log \dfrac{|\Sigma_1|^{-1/2} \exp \left\{ -\frac{1}{2}(x-\mu_1)^\top \Sigma_1^{-1} (x - \mu_1) \right\}}{|\Sigma_2|^{-1/2} \exp \left\{ -\frac{1}{2}(x-\mu_2)^\top \Sigma_2^{-1} (x - \mu_2) \right\}} p(x) dx \\  &= \int_{\mathbb{R}^d} \left[ \frac{1}{2}\log \frac{|\Sigma_2|}{|\Sigma_1|} + \frac{1}{2}(x-\mu_2)^\top \Sigma_2^{-1} (x - \mu_2) - \frac{1}{2}(x-\mu_1)^\top \Sigma_1^{-1} (x - \mu_1) \right] p(x) dx \\  &= \frac{1}{2} \left[ \log \frac{|\Sigma_2|}{|\Sigma_1|} + \mathbb{E}_P \left[ (X-\mu_2)^\top \Sigma_2^{-1} (X - \mu_2) \right] - \mathbb{E}_P \left[ (X-\mu_1)^\top \Sigma_1^{-1} (X - \mu_1) \right] \right], \end{aligned}

where \mathbb{E}_P[f(X)] denotes the expectation of f(X) assuming that X \sim P. We now employ the trace trick to complete the proof: if Y \in \mathbb{R}^n is a random vector such that \mathbb{E}[Y] = \mu and \text{Var}(Y) = \Sigma, then for any fixed matrix A \in \mathbb{R}^{n \times n},

\begin{aligned} \mathbb{E} \left[ Y^\top AY \right] = \mu^\top A \mu + \text{tr}(A\Sigma). \end{aligned}

(See this previous post for the statement and proof.) Applying this result on X - \mu_1 with X \sim P:

\begin{aligned} \mathbb{E}_P \left[ (X-\mu_1)^\top \Sigma_1^{-1} (X - \mu_1) \right] &= 0^\top \Sigma_1^{-1} 0 + \text{tr}(\Sigma_1^{-1} \Sigma_1) \\  &= \text{tr}(I_d) \\  &= d. \end{aligned}

Applying this result on X - \mu_2 with X \sim P:

\begin{aligned} \mathbb{E}_P \left[ (X-\mu_2)^\top \Sigma_2^{-1} (X - \mu_2) \right] &= (\mu_1 - \mu_2)^\top \Sigma_2^{-1} (\mu_1 - \mu_2) + \text{tr}(\Sigma_2^{-1} \Sigma_1) \\  &= (\mu_2 - \mu_1)^\top \Sigma_2^{-1} (\mu_2 - \mu_1) + \text{tr}(\Sigma_2^{-1} \Sigma_1). \end{aligned}

Plugging these expressions into the equation above gives us the desired result.
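Here is a short sketch (my own) that implements the formula and checks it against a Monte Carlo estimate of \mathbb{E}_P[\log p(X) - \log q(X)]; the particular means and covariances are arbitrary.

```python
import numpy as np
from scipy.stats import multivariate_normal

def kl_mvn(mu1, S1, mu2, S2):
    """KL divergence between N(mu1, S1) and N(mu2, S2), using the formula above."""
    d = len(mu1)
    S2inv = np.linalg.inv(S2)
    diff = mu2 - mu1
    return 0.5 * (np.log(np.linalg.det(S2) / np.linalg.det(S1)) - d
                  + np.trace(S2inv @ S1) + diff @ S2inv @ diff)

mu1, S1 = np.array([0.0, 0.0]), np.array([[1.0, 0.3], [0.3, 2.0]])
mu2, S2 = np.array([1.0, -1.0]), np.array([[2.0, -0.5], [-0.5, 1.0]])

rng = np.random.default_rng(0)
x = rng.multivariate_normal(mu1, S1, size=200_000)
mc = np.mean(multivariate_normal(mu1, S1).logpdf(x)
             - multivariate_normal(mu2, S2).logpdf(x))

print(kl_mvn(mu1, S1, mu2, S2), mc)      # the two estimates should be close
```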

Note 1: When d = 1, i.e. P and Q are 1-dimensional normal distributions, the formula simplifies to

\begin{aligned} D_{KL}(P || Q) &= \frac{1}{2}\left[ \log \frac{\sigma_2^2}{\sigma_1^2} - 1 + \frac{\sigma_1^2}{\sigma_2^2} + \frac{(\mu_2 - \mu_1)^2}{\sigma_2^2} \right] \\  &= \log \frac{\sigma_2}{\sigma_1} - \frac{1}{2} + \frac{\sigma_1^2 + (\mu_2 - \mu_1)^2}{2\sigma_2^2}, \end{aligned}

which is the same formula as that in John Cook’s post.

Note 2: When \mu_1 = \mu_2 = \mu, the formula simplifies to

\begin{aligned} D_{KL}(\mathcal{N}(\mu, \Sigma_1) || \mathcal{N}(\mu, \Sigma_2)) &= \frac{1}{2}\left[ \log \frac{|\Sigma_2|}{|\Sigma_1|} - d + \text{tr}\left( \Sigma_2^{-1} \Sigma_1 \right) \right]. \end{aligned}

Note 3: When \Sigma_1 = \Sigma_2 = \Sigma, the formula simplifies to

\begin{aligned} D_{KL}(\mathcal{N}(\mu_1, \Sigma) || \mathcal{N}(\mu_2, \Sigma)) &= \frac{1}{2}\left[ \log 1 - d + \text{tr}\left( I_d \right) + (\mu_2 - \mu_1)^\top \Sigma^{-1} (\mu_2 - \mu_1) \right] \\  &= \frac{1}{2}(\mu_2 - \mu_1)^\top \Sigma^{-1} (\mu_2 - \mu_1). \end{aligned}

References:

  1. StackExchange. KL divergence between two multivariate Gaussians.

Conditional distribution for the multivariate normal distribution

Assume that Y_1, Y_2 are random vectors such that \begin{pmatrix} Y_1 & Y_2 \end{pmatrix}^\top has multivariate normal distribution:

\begin{aligned} \begin{pmatrix} Y_1 \\ Y_2 \end{pmatrix} &\sim \mathcal{N}(\mu, \Sigma), \\  \mu = \begin{pmatrix} \mu_1 \\ \mu_2 \end{pmatrix}, &\quad \Sigma = \begin{pmatrix} \Sigma_{11} & \Sigma_{12} \\ \Sigma_{21} & \Sigma_{22} \end{pmatrix}. \end{aligned}

Assume further that \Sigma_{22} is invertible. Then the conditional distribution of Y_1 given Y_2 = y_2 is also multivariate normal, and we have explicit formulas for the conditional mean and variance:

\begin{aligned} Y_1 &| Y_2 = y_2 \sim \mathcal{N}(\mu', \Sigma'), \\  \mu' &= \mu_1 + \Sigma_{12}\Sigma_{22}^{-1}(y_2 - \mu_2), \\  \Sigma' &= \Sigma_{11} - \Sigma_{12} \Sigma_{22}^{-1} \Sigma_{21}. \end{aligned}

The conditional variance \Sigma' is the Schur complement of the block \Sigma_{22} of the matrix \Sigma.
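As a quick numerical check (my own sketch, with an arbitrary 3-dimensional example where Y_1 is the first two coordinates and Y_2 the last), we can sample (Y_1, Y_2) jointly, keep the draws whose Y_2 falls in a small window around y_2, and compare the empirical mean and covariance of Y_1 with \mu' and \Sigma':

```python
import numpy as np

rng = np.random.default_rng(0)
mu = np.array([1.0, -1.0, 2.0])
Sigma = np.array([[2.0, 0.6, 0.8],
                  [0.6, 1.5, -0.4],
                  [0.8, -0.4, 1.0]])
S11, S12 = Sigma[:2, :2], Sigma[:2, 2:]
S21, S22 = Sigma[2:, :2], Sigma[2:, 2:]

y2 = np.array([2.5])
mu_cond = mu[:2] + S12 @ np.linalg.solve(S22, y2 - mu[2:])   # mu'
Sigma_cond = S11 - S12 @ np.linalg.solve(S22, S21)           # Sigma'

draws = rng.multivariate_normal(mu, Sigma, size=2_000_000)
keep = draws[np.abs(draws[:, 2] - y2[0]) < 0.02]             # crude conditioning window
print(mu_cond, keep[:, :2].mean(axis=0))                     # should be close
print(Sigma_cond)
print(np.cov(keep[:, :2], rowvar=False))                     # should be close
```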

In the rest of this post, I sketch two ways to prove this formula.

Sketch of proof 1

According to the law of conditional probability, we know that the probability density function (PDF) of the conditional distribution satisfies

\begin{aligned} p(y_1 \mid y_2) &= \dfrac{p(y_1, y_2)}{p(y_2)}. \end{aligned}

Since we have formulas for the PDF of the joint distribution (Y_1, Y_2) and the marginal distribution Y_2, we can plug these expressions into the formula above. After some messy algebra, we obtain the PDF of the multivariate normal with the desired parameters \mu' and \Sigma'. The only other technical result we need in this approach is the matrix inversion formula for a block matrix. See Reference 3 for details of this proof.

Sketch of proof 2

Consider the random vector Z = Y_1 - AY_2, where A is a constant matrix of the appropriate dimensions. We know that \begin{pmatrix} Z & Y_2 \end{pmatrix}^\top still has a multivariate normal distribution, and by standard transformation formulas we can show that

\begin{aligned} \text{Cov}(Z, Y_2) &= \Sigma_{12} - A \Sigma_{22}. \end{aligned}

If we set A = \Sigma_{12}\Sigma_{22}^{-1}, then the covariance between Z and Y_2 is zero, and since they are jointly normal, Z and Y_2 are independent. In other words, the conditional distribution of Y_1 - AY_2 given Y_2 = y_2 is the same as the unconditional distribution of Y_1 - AY_2. Since we know this distribution is normal, it remains to compute its mean and variance, which we can do with standard formulas. See Reference 4 for details of this proof.
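The following small sketch (again my own) checks the key algebra in this proof: with A = \Sigma_{12}\Sigma_{22}^{-1}, the cross-covariance \Sigma_{12} - A\Sigma_{22} vanishes, and \text{Var}(Z) = \text{Var}(Y_1 - AY_2) equals the Schur complement, i.e. the conditional covariance \Sigma' above.

```python
import numpy as np

Sigma = np.array([[2.0, 0.6, 0.8],
                  [0.6, 1.5, -0.4],
                  [0.8, -0.4, 1.0]])
S11, S12 = Sigma[:2, :2], Sigma[:2, 2:]
S21, S22 = Sigma[2:, :2], Sigma[2:, 2:]

A = S12 @ np.linalg.inv(S22)
print(S12 - A @ S22)                                 # Cov(Z, Y2) = 0
print(S11 - A @ S21 - S12 @ A.T + A @ S22 @ A.T)     # Var(Z) = Var(Y1 - A Y2)
print(S11 - S12 @ np.linalg.inv(S22) @ S21)          # ... equals the Schur complement
```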

References:

  1. Wikipedia. Multivariate normal distribution: Conditional distributions.
  2. StackExchange. Deriving the conditional distributions of a multivariate normal distribution.
  3. The Book of Statistical Proofs. Proof: Conditional distributions of the multivariate normal distribution.
  4. Owen, A. Appendix B: Probability Review (Appendix B.9).