glmnet v4.0: generalizing the family parameter

I’ve had the privilege of working with Trevor Hastie on an extension of the glmnet package which has just been released. In essence, the glmnet() function’s family parameter can now be any object of class family. This enables the user to fit any generalized linear model with the elastic net penalty.

glmnet v4.0 is now available on CRAN here. We have also written a vignette which explains and demonstrates this new functionality. This blog post presents what we’ve done from a different perspective, and goes into a bit more detail on how we made it work.

I would love for you to download the package and try out this new functionality! Comments/suggestions/bug reports welcome 🙂

Background on generalized linear models (GLM)

Let’s say we have n observations x_1, \dots, x_n \in \mathbb{R}^p with corresponding responses y_1, \dots, y_n. The simple linear model assumes that the response is a linear combination of p features plus some (mean-zero) noise:

y_i = \beta^T x_i + \epsilon_i.

Generalized linear models (GLM) are an extension of linear models which allow more flexible modeling. A GLM consists of 3 parts:

  1. A linear predictor \eta_i = \beta_0 + \beta^T x_i,
  2. A link function \eta_i = g(\mu_i), and
  3. A random component y_i \sim f(y \mid \mu_i).

(You can read more about them here.) f and g are specified by the GLM, while \beta_0 and \beta are parameters that we have to estimate. With these 3 components and the data at hand, we can define the log-likelihood function and estimate \beta_0 and \beta via maximum likelihood estimation. If we define \ell_i (y_i, \eta_i) to be the negative log-likelihood associated with observation i, this amounts to solving

\begin{aligned} \underset{\beta_0, \beta}{\text{min}} \quad \frac{1}{n}\sum_{i = 1}^n \ell_i (y_i, \beta_0 + \beta^T x_i).  \end{aligned}
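To make \ell_i concrete: for logistic regression, \ell_i(y_i, \eta_i) = \log(1 + e^{\eta_i}) - y_i \eta_i. The sketch below is a minimal illustration on simulated data. The data, the seed and the use of optim() are illustrative choices, not anything glm() or glmnet does internally. It minimizes the averaged negative log-likelihood directly and compares the result with glm(). The 1/n factor rescales the objective but does not move its minimizer.

```r
set.seed(1)
n <- 500
x <- rnorm(n)
y <- rbinom(n, size = 1, prob = plogis(-0.5 + 1.2 * x))  # simulated responses

# Averaged negative log-likelihood for logistic regression:
#   l_i(y_i, eta_i) = log(1 + exp(eta_i)) - y_i * eta_i
avg_nll <- function(par) {
  eta <- par[1] + par[2] * x
  mean(log1p(exp(eta)) - y * eta)
}

# Its gradient, using d l_i / d eta_i = plogis(eta_i) - y_i
avg_nll_grad <- function(par) {
  r <- plogis(par[1] + par[2] * x) - y
  c(mean(r), mean(r * x))
}

fit_ml  <- optim(c(0, 0), avg_nll, avg_nll_grad, method = "BFGS",
                 control = list(reltol = 1e-12))
fit_glm <- glm(y ~ x, family = binomial())

# Both rows should agree to several decimal places
rbind(optim = fit_ml$par, glm = unname(coef(fit_glm)))
```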

This can be solved via a technique called iteratively reweighted least squares (IRLS). At a high level this is what it does:

Step 1: Compute new weights w_1, \dots, w_n and new working responses z_1, \dots, z_n.

Step 2: Solve the weighted least squares (WLS) problem

\begin{aligned} \underset{\beta_0, \beta}{\text{min}} \quad \sum_{i=1}^n w_i (z_i - \beta_0 - \beta^T x_i)^2. \end{aligned}

Step 3: Repeat Steps 1 and 2 until convergence.
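The sketch below implements these three steps in base R. It assumes Fisher-scoring weights and working responses of the form glm() uses, w_i = (d\mu_i/d\eta_i)^2 / V(\mu_i) and z_i = \eta_i + (y_i - \mu_i)/(d\mu_i/d\eta_i), all computed from an R family object. The function name irls() and the starting values are illustrative choices, not glm()'s actual internals.

```r
# IRLS for an unpenalized GLM, driven by an R family object.
irls <- function(X, y, family = gaussian(), maxit = 100, tol = 1e-10) {
  X1 <- cbind(1, X)                   # prepend an intercept column
  mu <- (y + mean(y)) / 2             # simple starting means (an assumption)
  eta <- family$linkfun(mu)
  beta <- rep(0, ncol(X1))
  for (iter in seq_len(maxit)) {
    # Step 1: new weights w_i and working responses z_i
    d_mu <- family$mu.eta(eta)        # d mu / d eta at the current eta
    w <- d_mu^2 / family$variance(mu)
    z <- eta + (y - mu) / d_mu
    # Step 2: weighted least squares of z on X1 with weights w
    beta_new <- unname(lm.wfit(X1, z, w)$coefficients)
    # Step 3: stop once the coefficients no longer change
    converged <- max(abs(beta_new - beta)) < tol
    beta <- beta_new
    eta <- drop(X1 %*% beta)
    mu <- family$linkinv(eta)
    if (converged) break
  }
  beta
}

# Usage: probit regression on simulated data, compared with glm()
set.seed(42)
n <- 400
X <- matrix(rnorm(n * 2), n, 2)
y <- rbinom(n, 1, pnorm(drop(0.3 + X %*% c(1, -0.5))))
cbind(irls = irls(X, y, binomial(link = "probit")),
      glm  = coef(glm(y ~ X, family = binomial(link = "probit"))))
```

Nothing in irls() is specific to the probit model. Swapping in poisson() or another family object changes only the functions called in Step 1.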

Essentially in each cycle we are making a quadratic approximation of the negative log-likelihood, then minimizing that approximation. (For more details, see this webpage or this older post.) The key point here is that once we find an efficient way to do Step 2 (solve WLS), we have a way to solve the GLM optimization problem.
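To see where w_i and z_i come from, write \ell_i' and \ell_i'' for the first and second derivatives of \ell_i(y_i, \eta) with respect to \eta. Then expand around the current fit \tilde\eta_i:

\begin{aligned} \ell_i(y_i, \eta_i) &\approx \ell_i(y_i, \tilde\eta_i) + \ell_i'(\tilde\eta_i)(\eta_i - \tilde\eta_i) + \frac{1}{2}\ell_i''(\tilde\eta_i)(\eta_i - \tilde\eta_i)^2 \\ &= \frac{1}{2} w_i (z_i - \eta_i)^2 + C_i, \quad \text{where } w_i = \ell_i''(\tilde\eta_i), \; z_i = \tilde\eta_i - \frac{\ell_i'(\tilde\eta_i)}{\ell_i''(\tilde\eta_i)}, \end{aligned}

and C_i does not depend on \eta_i. Averaging over i, the quadratic approximation of the objective becomes \frac{1}{2n}\sum_{i=1}^n w_i (z_i - \beta_0 - \beta^T x_i)^2 plus a constant. This is a weighted least squares problem, and the constant factor does not change its minimizer.

glm() replaces \ell_i'' with its expected value (Fisher scoring). This gives w_i = (d\mu_i/d\eta_i)^2 / V(\mu_i) and z_i = \tilde\eta_i + (y_i - \mu_i)/(d\mu_i/d\eta_i), where V is the family's variance function. For canonical links, such as those in logistic and Poisson regression, the two versions coincide.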

This is how the glm() function in R fits GLMs. We specify the link function and the random component of the GLM via an object of class family; the linear predictor comes from the model formula. Below are examples of how we might run logistic regression and probit regression in R:

glm(y ~ x, family = binomial())                 # logistic regression
glm(y ~ x, family = binomial(link = "probit"))  # probit regression
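A family object bundles the link g, its inverse, and the mean-variance relationship of f, which is everything the IRLS steps need. The short sketch below inspects binomial(link = "probit"). Every field used is a standard component of R's family objects.

```r
fam <- binomial(link = "probit")

fam$family          # "binomial": the random component f
fam$link            # "probit":   the link function g
fam$linkfun(0.5)    # g(mu) = qnorm(mu), so this is 0
fam$linkinv(0)      # g^{-1}(eta) = pnorm(eta), so this is 0.5
fam$variance(0.3)   # Var(y | mu) = mu * (1 - mu), so this is 0.21
fam$mu.eta(0)       # d mu / d eta = dnorm(eta), used for the IRLS weights

# Deviance residuals: for 0/1 responses these equal 2 * l_i(y_i, eta_i)
fam$dev.resids(c(1, 0), c(0.8, 0.3), c(1, 1))
```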

GLMs with the elastic net penalty

In many instances, instead of minimizing the negative log-likelihood, we want to minimize the sum of the negative log-likelihood and a penalty term. The penalty we choose will influence the properties that our solution will have. A popular choice of penalty is the elastic net penalty

\begin{aligned} \text{penalty} = \lambda \left[\frac{1-\alpha}{2}\|\beta\|_2^2 + \alpha \| \beta\|_1 \right].\end{aligned}

Here, \lambda \geq 0 and \alpha \in [0, 1] are hyperparameters that the user chooses. The elastic net penalty reduces to the lasso penalty when \alpha = 1, and to the ridge penalty when \alpha = 0.
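Here is a small illustration of the penalty in R. The function enet_penalty() is a hypothetical helper written for illustration, not a glmnet function. With \beta = (2, -1, 0), we have \|\beta\|_1 = 3 and \|\beta\|_2^2 = 5, so the limiting cases are easy to check by hand:

```r
# Elastic net penalty: lambda * [ (1 - alpha)/2 * ||beta||_2^2 + alpha * ||beta||_1 ]
enet_penalty <- function(beta, lambda, alpha) {
  stopifnot(lambda >= 0, alpha >= 0, alpha <= 1)
  lambda * ((1 - alpha) / 2 * sum(beta^2) + alpha * sum(abs(beta)))
}

beta <- c(2, -1, 0)
enet_penalty(beta, lambda = 0.5, alpha = 1)    # lasso: 0.5 * 3 = 1.5
enet_penalty(beta, lambda = 0.5, alpha = 0)    # ridge: 0.5 * 5/2 = 1.25
enet_penalty(beta, lambda = 0.5, alpha = 0.5)  # 0.5 * (1.25 + 1.5) = 1.375
```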

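The glmnet package solves the penalized problem

\begin{aligned} \underset{\beta_0, \beta}{\text{min}} \quad \frac{1}{n}\sum_{i = 1}^n \ell_i (y_i, \beta_0 + \beta^T x_i) + \lambda \left[\frac{1-\alpha}{2}\|\beta\|_2^2 + \alpha \| \beta\|_1 \right] \quad -(1) \end{aligned}

for a grid of \lambda values. The IRLS algorithm used to compute the GLM solution can be easily adapted to compute the solution to (1):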

Step 1: Compute new weights w_1, \dots, w_n and new working responses z_1, \dots, z_n.

Step 2: Solve the penalized WLS problem

\begin{aligned} \underset{\beta_0, \beta}{\text{min}} \quad \frac{1}{2n}\sum_{i=1}^n w_i (z_i - \beta_0 - \beta^T x_i)^2 + \lambda \left[\frac{1-\alpha}{2}\|\beta\|_2^2 + \alpha \| \beta\|_1 \right]. \quad -(2) \end{aligned}

Step 3: Repeat Steps 1 and 2 until convergence.

Step 1 is exactly the same as in the GLM algorithm, so as long as we have an efficient routine for solving (2), we have a way to optimize the penalized likelihood.
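Below is a minimal base-R sketch of the modified IRLS loop. It assumes Step 2 is solved by cyclical coordinate descent with soft-thresholding, the technique glmnet's solvers are built around. It is not the package's FORTRAN code and omits its refinements, such as warm starts along the \lambda path, active sets, standardization and step-halving. The helper names pwls_cd() and penalized_irls() are hypothetical. The squared-error term is scaled by 1/(2n), which is what a second-order expansion of the averaged negative log-likelihood produces, and the intercept is left unpenalized.

```r
soft_threshold <- function(u, t) sign(u) * pmax(abs(u) - t, 0)

# Cyclical coordinate descent for the penalized WLS problem
#   (1/(2n)) sum_i w_i (z_i - b0 - x_i' b)^2
#     + lambda * [(1 - alpha)/2 * ||b||_2^2 + alpha * ||b||_1]
pwls_cd <- function(X, z, w, lambda, alpha, b0, b, maxit = 1000, tol = 1e-10) {
  n <- nrow(X)
  r <- z - b0 - drop(X %*% b)                # current residuals
  for (pass in seq_len(maxit)) {
    delta <- sum(w * r) / sum(w)             # intercept: weighted mean residual
    b0 <- b0 + delta
    r <- r - delta
    max_change <- abs(delta)
    for (j in seq_len(ncol(X))) {
      xj <- X[, j]
      u <- sum(w * xj * (r + xj * b[j])) / n # uses the partial residual
      v <- sum(w * xj^2) / n
      bj_new <- soft_threshold(u, lambda * alpha) / (v + lambda * (1 - alpha))
      r <- r - xj * (bj_new - b[j])
      max_change <- max(max_change, abs(bj_new - b[j]))
      b[j] <- bj_new
    }
    if (max_change < tol) break
  }
  list(b0 = b0, b = b)
}

penalized_irls <- function(X, y, family, lambda, alpha = 1, maxit = 100, tol = 1e-8) {
  b0 <- family$linkfun(mean(y))              # start from the intercept-only fit
  b <- rep(0, ncol(X))
  for (iter in seq_len(maxit)) {
    # Step 1: weights and working responses at the current fit
    eta <- b0 + drop(X %*% b)
    mu <- family$linkinv(eta)
    d_mu <- family$mu.eta(eta)
    w <- d_mu^2 / family$variance(mu)
    z <- eta + (y - mu) / d_mu
    # Step 2: penalized WLS, warm-started at the current (b0, b)
    fit <- pwls_cd(X, z, w, lambda, alpha, b0, b)
    change <- max(abs(c(fit$b0 - b0, fit$b - b)))
    b0 <- fit$b0
    b <- fit$b
    # Step 3: stop once the coefficients settle
    if (change < tol) break
  }
  c(b0, b)
}

# Usage: lasso-penalized logistic regression on simulated data
set.seed(7)
n <- 300
X <- matrix(rnorm(n * 5), n, 5)
y <- rbinom(n, 1, plogis(drop(-0.2 + X %*% c(1.5, -1, 0, 0, 0))))
round(penalized_irls(X, y, binomial(), lambda = 0.05, alpha = 1), 3)
```

Only Step 1 touches the family object. Once the penalized WLS solver exists, any family can be plugged in unchanged.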

glmnet v4.0

In theory, we could do the above for any GLM family, but before v4.0 it had not been implemented in practice. glmnet() could only optimize the penalized likelihood for special GLM families (e.g., Gaussian linear regression, logistic regression, Poisson regression). For each family, which we specified via a character string for the family parameter, we had custom FORTRAN code that ran the modified IRLS algorithm above. While this was computationally efficient, it did not allow us to fit any penalized GLM of our choosing.

From v4.0 onwards, we can do the above for any GLM family in practice. We do so by passing an object of class family to the family parameter instead of a character string. For example, if we want to do probit regression with the elastic net penalty, we would do something like this:

glmnet(x, y, family = binomial(link = "probit"))

Under the hood, instead of having custom FORTRAN code for each family, we have a single FORTRAN subroutine that solves (2) efficiently. Everything surrounding that subroutine is essentially the same as in glmnet()/glm(). (Only in theory, of course: as with most engineering tasks, the devil is in the details!)

For the special families that glmnet could fit before v4.0, we still recommend passing a character string to the family parameter, as it runs more quickly. (The vignette has more details on this.)
