Author: Arnaud Autef
Contents
<aside> 💡 In this reading group session: we review the LassoNet method, from its motivation to its theoretical justification.
In the next session: we review experimental results with the LassoNet and discuss potential applications and extensions of this algorithm
</aside>
The LassoNet paper, by Ismael Lemhadri, Feng Ruan, Louis Abraham, and Rob Tibshirani, is to be published in JMLR. The paper's website is available at https://lassonet.ml; most images throughout this presentation are taken from the paper.
Supervised learning problem
Dataset $D = (x_i,~y_i)_{1 \le i \le n}$
Loss function $l(y',~y)$ to evaluate models (differentiable)
<aside> 💡 Task: Find a minimal set of features $k \subset [d]$ to model $y$
Problem: The relationship between the response $y$ and input variables $x_i \in \mathbb{R}^d$ is non-linear, but well modelled by a neural network function approximator.
Goal: Can we efficiently select input features when the mapping from $x_i$ to $y_i$ is a neural network?
</aside>
The LassoNet procedure that the authors propose can be broken down into three steps:
Augment the neural network model space with a skip connection
Let $\{g_{W} ~|~ W \in \mathbb{R}^p \}$ be the space of neural network function approximators under consideration, with parameters $W \in \mathbb{R}^p$.
We assume that there exists some $W \in \mathbb{R}^p$ such that
$$ \forall i \in D, \quad y_i \approx g_W(x_i) $$
Consider the larger space of models
$$ \{f_{\theta,W} : x \mapsto g_W(x) + x^T\theta ~|~ \theta \in \mathbb{R}^d,~W \in \mathbb{R}^p \} $$
In a picture:
Adapted from the LassoNet paper Figure 3.
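To make the augmented model space concrete, here is a minimal PyTorch sketch of such a model. The class name `LassoNetModel`, the single hidden layer, and the hidden width are illustrative assumptions for this presentation, not the authors' reference implementation (see the paper website for that).

```python
import torch
import torch.nn as nn

class LassoNetModel(nn.Module):
    """Feed-forward network g_W augmented with a linear skip connection x -> x^T theta."""

    def __init__(self, d: int, hidden: int = 32):
        super().__init__()
        # Arbitrary network g_W; only its first layer W^(1) enters the constraint below.
        self.g = nn.Sequential(
            nn.Linear(d, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )
        # Skip-connection weights theta in R^d (no bias: biases live inside g_W).
        self.theta = nn.Linear(d, 1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # f_{theta, W}(x) = g_W(x) + x^T theta
        return self.g(x) + self.theta(x)
```

With `nn.Linear`, the first-layer weight matrix has shape (hidden, d), so the column `self.g[0].weight[:, j]` plays the role of $W_j^{(1)}$ in the constraint below.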
Define a sparsity inducing loss function to select variables in those models
$$ \min_{\theta,W} \quad L(\theta,~W) + \lambda ||\theta||_1\\ \text{ }\\ \text{s.t.}\quad \forall 1 \le j \le d,\quad ||W_j^{(1)}||_{\infty} \le M |\theta_j| $$
With,
$$ L(\theta,~W) = \frac{1}{n} \sum_{1 \le i \le n} l(f_{\theta,W}(x_i), y_i) $$
Discussion
Penalize skip connection weights $\theta$ with a Lasso-like L1 penalty
→ Only a subset $k \subset [d]$ of input variables will have a non-zero weight $\theta_j$ in the skip connection, for high enough $\lambda$.
Enforce the following constraint
<aside> 💡 If input feature $j \in [d]$ is not "Lasso-selected" in the skip connection, zero out $j$ from the network inputs.
In mathematical terms: $\theta_j =0 \Rightarrow W_j^{(1)} = 0$
</aside>
→ This is enforced by the constraint
$$ \forall 1 \le j \le d,\quad ||W_j^{(1)}||_{\infty} \le M |\theta_j| $$
→ More generally, the parameter $M \ge 0$ controls how large $W_j^{(1)}$ can be compared to $\theta_j$, for any input feature $j$ (a small code check of this hierarchy follows this list):
$M = 0$ reduces to Lasso linear regression
$M \rightarrow + \infty$ reduces to an unpenalized fit of the underlying neural network
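As a sanity check of the hierarchy property, here is a hedged sketch (assuming the hypothetical `LassoNetModel` above and a model already trained under the constraint) of how the selected feature set $k$ can be read off from $\theta$:

```python
import torch

def selected_features(model: LassoNetModel, M: float) -> torch.Tensor:
    """Boolean mask of the selected features, i.e. those with theta_j != 0."""
    theta = model.theta.weight.detach().squeeze(0)   # shape (d,)
    W1 = model.g[0].weight.detach()                  # shape (hidden, d); column j = W_j^(1)
    # Hierarchy constraint ||W_j^(1)||_inf <= M |theta_j|: if theta_j = 0, column j must be 0 too.
    assert torch.all(W1.abs().amax(dim=0) <= M * theta.abs() + 1e-6)
    return theta != 0
```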
Design an efficient learning algorithm to minimize this loss
Starting Idea: look for something close to Stochastic Gradient Descent (SGD), which works well for neural networks
Problem 1 - non-differentiability
$$ L(\theta,~W) + \lambda ||\theta||_1 $$
Solution 1 - use proximal gradient descent
Proximal gradient descent applies to problems of the form $f(x) = g(x) + h(x)$ with $g$ differentiable and $h$ non-differentiable: only the smooth part $g$ is linearized, while $h$ is handled exactly through a proximal step. Here, $g$ corresponds to the differentiable loss $L(\theta, W)$ and $h$ to the non-smooth penalty $\lambda ||\theta||_1$ (the constraint is folded into the proximal step as well).
Replace Gradient Descent
$$ x_{t+1} \leftarrow x_t - \alpha \nabla f(x_t) = \argmin_z \frac{1}{2\alpha}||z - (x_t - \alpha \nabla f(x_t))||_2^2 $$
By Proximal Gradient Descent
$$ x_{t+1} \leftarrow \argmin_z \frac{1}{2\alpha}||z - (x_t - \alpha \nabla g(x_t))||_2^2 + h(z) $$
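For intuition, here is proximal gradient descent specialized to the plain Lasso, where $g(x) = \frac{1}{2}||Ax - y||_2^2$ and $h(x) = \lambda ||x||_1$: the proximal step then reduces to closed-form soft-thresholding. This is a generic textbook example (ISTA), not LassoNet-specific code:

```python
import numpy as np

def soft_threshold(z: np.ndarray, tau: float) -> np.ndarray:
    """Prox of tau * ||.||_1: argmin_x 1/2 ||x - z||_2^2 + tau * ||x||_1, in closed form."""
    return np.sign(z) * np.maximum(np.abs(z) - tau, 0.0)

def ista(A: np.ndarray, y: np.ndarray, lam: float, alpha: float, n_iters: int = 500) -> np.ndarray:
    """Proximal gradient descent for min_x 1/2 ||Ax - y||_2^2 + lam * ||x||_1.

    alpha should be at most 1 / ||A||_2^2 (the Lipschitz constant of the gradient of g).
    """
    x = np.zeros(A.shape[1])
    for _ in range(n_iters):
        grad = A.T @ (A @ x - y)                            # gradient of the smooth part g
        x = soft_threshold(x - alpha * grad, alpha * lam)   # proximal step on h
    return x
```

The point reused by LassoNet is that the proximal step stays cheap as long as the corresponding subproblem has a closed-form (or near-closed-form) solution.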
Quick reference for curious readers
Where do the SGD starting point and the proximal step take us? To a "Stochastic Proximal Gradient Descent" procedure:
Iterate for $E$ epochs
Iterate over batches $b = (x_i,y_i)_i$ of size $B$
Solve the batch-level proximal optimization problem
$$ \begin{aligned} \argmin_{\theta',~W'} &\frac{1}{2\alpha}||\theta' - (\theta - \alpha \nabla_\theta L_B(\theta, W))||_2^2\\ +~&\frac{1}{2\alpha}||W' - (W - \alpha \nabla_W L_B(\theta, W))||_2^2\\ +~&\lambda ||\theta'||_1\\ \text{ }\\ \text{s.t.}\quad &\forall 1 \le j \le d,\quad ||{W'}_j^{(1)}||_{\infty} \le M |\theta'_j| \end{aligned} $$
With the batch-level loss
$$ L_B(\theta,~W) = \frac{1}{B} \sum_{i \in b} l(f_{\theta,W}(x_i), y_i) $$
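In code, one epoch of this procedure is an ordinary SGD epoch with an extra proximal step after each gradient update. The sketch below assumes the hypothetical `LassoNetModel` from above; `hier_prox`, which solves the batch-level proximal problem exactly, is the operator introduced next (a sketch of it follows the algorithm below).

```python
import torch

def train_one_epoch(model, loader, loss_fn, alpha, lam, M):
    """One epoch of stochastic proximal gradient descent on the penalized objective."""
    optimizer = torch.optim.SGD(model.parameters(), lr=alpha)
    for xb, yb in loader:
        # 1) Plain SGD step on the smooth batch loss L_B(theta, W).
        optimizer.zero_grad()
        loss_fn(model(xb), yb).backward()
        optimizer.step()
        # 2) Proximal step on (theta, W^(1)): prox of alpha * (lam ||.||_1 + constraint indicator),
        #    computed exactly by the Hier-Prox operator sketched further below.
        with torch.no_grad():
            theta = model.theta.weight.squeeze(0)   # shape (d,)
            W1 = model.g[0].weight                  # shape (hidden, d)
            new_theta, new_W1 = hier_prox(theta, W1, lam=alpha * lam, M=M)
            model.theta.weight.copy_(new_theta.unsqueeze(0))
            model.g[0].weight.copy_(new_W1)
```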
Problem 2 - the constraints are non-convex, so the batch-level proximal problem looks hard to solve at first glance
Solution 2 - a new "Hier-Prox" algorithm solves the above problem efficiently and exactly
Taken from the LassoNet paper
→ Hier-Prox's complexity is in $\mathcal{O}(N \log N)$
Taken from the LassoNet paper
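For readers who prefer code to pseudo-code, here is a hedged transcription of how I read the Hier-Prox algorithm: it operates independently on each input feature $j$, sorting the magnitudes of $W_j^{(1)}$ and picking one of $K+1$ candidate thresholds. Treat it as a sketch; the paper and the authors' released implementation are the authoritative references.

```python
import torch

def hier_prox(theta: torch.Tensor, W1: torch.Tensor, lam: float, M: float):
    """Hedged sketch of the Hier-Prox operator (assumes M > 0; M = 0 is just the Lasso).

    theta: skip-connection weights, shape (d,)
    W1:    first-layer weights,     shape (K, d), column j playing the role of W_j^(1)

    For each feature j, solves
        argmin_{t, w}  1/2 (t - theta_j)^2 + 1/2 ||w - W_j^(1)||^2 + lam * |t|
        subject to     ||w||_inf <= M * |t|
    """
    K, d = W1.shape
    new_theta = torch.empty_like(theta)
    new_W1 = torch.empty_like(W1)
    for j in range(d):
        b, w = theta[j], W1[:, j]
        # Sort |W_(j,1)| >= ... >= |W_(j,K)| and build partial sums for m = 0..K.
        w_sorted, _ = torch.sort(w.abs(), descending=True)
        partial = torch.cat([torch.zeros(1, dtype=w.dtype), torch.cumsum(w_sorted, dim=0)])
        m = torch.arange(K + 1, dtype=w.dtype)
        # Candidate thresholds: M / (1 + m M^2) * soft_threshold(|theta_j| + M * partial_m, lam).
        cand = M / (1.0 + m * M ** 2) * torch.clamp(b.abs() + M * partial - lam, min=0.0)
        # Keep a candidate lying between the m-th and (m+1)-th sorted magnitudes
        # (conventions |W_(j,0)| = +inf, |W_(j,K+1)| = 0); the paper shows one always exists.
        upper = torch.cat([torch.full((1,), float("inf"), dtype=w.dtype), w_sorted])
        lower = torch.cat([w_sorted, torch.zeros(1, dtype=w.dtype)])
        w_tilde = cand[(lower <= cand) & (cand <= upper)][0]
        # Recover the per-feature solution: theta is rescaled, W_j^(1) is clipped to the threshold.
        new_theta[j] = torch.sign(b) * w_tilde / M
        new_W1[:, j] = torch.sign(w) * torch.minimum(w.abs(), w_tilde)
    return new_theta, new_W1
```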
<aside> 💡 Outcome: an end-to-end procedure which trains neural networks on the penalized objective, with increasing penalties $\lambda$, producing a "path" of selected variable sets $k \subset [d]$ (a code sketch of this path follows the aside)
Start by fitting a neural network parameterized by $W$ on the dataset with no $\lambda$ penalty
For increasing $\lambda$ penalties:
a. Apply the "LassoNet" learning algorithm (proximal stochastic gradient descent with the Hier-Prox constraint step) to get optimal $W, \theta$ for the current $\lambda$
b. Record the selected variables $j$, i.e. those with $\theta_j \neq 0$
All along the path, "warm-start" the neural network trainings with the previous optimal parameters!
</aside>
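Putting the pieces together, the dense-to-sparse path might look like the sketch below. The helpers `LassoNetModel` and `train_one_epoch` are the hypothetical ones from the earlier sketches, and the initial penalty, geometric growth factor, batch size, and epoch budget are illustrative assumptions, not the paper's defaults.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def lassonet_path(X, y, loss_fn, alpha=1e-3, M=10.0, lam0=1e-2, factor=1.02, epochs=10):
    """Hedged sketch of the dense-to-sparse lambda path with warm starts.

    Relies on the hypothetical LassoNetModel and train_one_epoch sketched above;
    all hyperparameter values here are illustrative.
    """
    loader = DataLoader(TensorDataset(X, y), batch_size=64, shuffle=True)
    model = LassoNetModel(d=X.shape[1])

    # Step 0: unpenalized dense fit, used as a warm start for the whole path.
    opt = torch.optim.SGD(model.parameters(), lr=alpha)
    for _ in range(epochs):
        for xb, yb in loader:
            opt.zero_grad()
            loss_fn(model(xb), yb).backward()
            opt.step()

    # Step 1: grow lambda geometrically, warm-starting each fit from the previous optimum.
    path, lam = [], lam0
    while True:
        for _ in range(epochs):
            train_one_epoch(model, loader, loss_fn, alpha, lam, M)
        selected = model.theta.weight.squeeze(0) != 0
        path.append((lam, selected.clone()))
        if not selected.any():   # stop once every feature has been eliminated
            break
        lam *= factor
    return path
```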
Proof walkthrough with scribbled notes as support