Author: Arnaud Autef

Contents

Session 1/2

<aside> 💡 In this reading group session: we review the LassoNet method, from its motivation to its theoretical justification.

In the next session: we will review experimental results with LassoNet and discuss potential applications and extensions of the algorithm.

</aside>

The LassoNet paper, by Ismael Lemhadri, Feng Ruan, Louis Abraham and Rob Tibshirani, is to be published in JMLR. The paper website is available at https://lassonet.ml, and most images throughout this presentation are taken from the paper.

Setup

<aside> 💡 Task: Find a minimal set of features $k \subset [d]$ to model $y$

Problem: The relationship between the response $y_i$ and the inputs $x_i \in \mathbb{R}^d$ is non-linear, but well modelled by a neural network function approximator.

Goal: Can we efficiently select input features (coordinates of $x_i$) when the mapping from $x_i$ to $y_i$ is a neural network?

</aside>
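To make the setup concrete, here is a toy example (not from the paper, and with purely illustrative names and dimensions): a synthetic regression problem where only a few of many input features drive the response, through a non-linear map.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy illustration of the setup: d = 50 input features, but the response
# only depends, non-linearly, on features 0, 1 and 2.
n, d = 1000, 50
X = rng.normal(size=(n, d))
y = np.sin(X[:, 0]) + X[:, 1] * X[:, 2] + 0.1 * rng.normal(size=n)

# Goal of LassoNet: recover the small subset {0, 1, 2} of relevant features
# while still fitting y with a neural network.
```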

Approach

The LassoNet procedure that the authors propose can be broken down into 3 steps:

  1. Augment the neural network model space with a skip connection

  2. Define a sparsity inducing loss function to select variables in those models

    $$ \min_{\theta,W} \quad L(\theta,~W) + \lambda ||\theta||_1\\ \text{ }\\ \text{s.t.}~\quad \forall 1 \le j \le d,\quad ||W_j^{(1)}||_{\infty} \le M |\theta_j| $$

    With,

    $$ L(\theta,~W) = \frac{1}{n} \sum_{1 \le i \le n} l(f_{\theta,W}(x_i), y_i) $$

    Discussion

    1. Penalize skip connection weights $\theta$ with a Lasso-like L1 penalty

      → Only a subset $k \subset [d]$ of input variables will have a non-zero weight $\theta_j$ in the skip connection, for high enough $\lambda$.

    2. Enforce the following constraint

      <aside> 💡 If input feature $j \in [d]$ is not "Lasso-selected" in the skip connection, zero out $j$ from the network inputs.

      In mathematical terms: $\theta_j =0 \Rightarrow W_j^{(1)} = 0$

      </aside>

      → This is enforced by the constraint

      $$ \forall 1 \le j \le d,\quad ||W_j^{(1)}||_{\infty} \le M |\theta_j| $$

    → More generally, the parameter $M \ge 0$ controls how large $W_j^{(1)}$ can be relative to $\theta_j$, for any input feature $j$
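    Below is a minimal PyTorch sketch of the resulting model space: a standard feed-forward network $g_W$ plus a linear skip connection parameterized by $\theta$, so that $f_{\theta,W}(x) = \theta^\top x + g_W(x)$. The layer sizes and the single-output assumption are illustrative choices, not values from the paper.

    ```python
    import torch
    import torch.nn as nn

    class LassoNetStyleModel(nn.Module):
        """Feed-forward network augmented with a linear skip connection.

        f(x) = theta^T x + g_W(x): `skip.weight` plays the role of theta, and
        `net[0].weight` is the first-layer matrix W^(1) that the LassoNet
        constraint ties to theta via ||W_j^(1)||_inf <= M |theta_j|.
        """

        def __init__(self, d_in: int, hidden: int = 32, d_out: int = 1):
            super().__init__()
            self.skip = nn.Linear(d_in, d_out, bias=False)   # theta
            self.net = nn.Sequential(                        # g_W
                nn.Linear(d_in, hidden),                     # W^(1)
                nn.ReLU(),
                nn.Linear(hidden, d_out),
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.skip(x) + self.net(x)
    ```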

  3. Design an efficient learning algorithm to minimize this loss

    1. Starting Idea: look for something close to Stochastic Gradient Descent (SGD), which works well for neural networks

      • SGD in a nutshell
    2. Problem 1 - non-differentiability

      $$ L(\theta,~W) + \lambda ||\theta||_1 $$

      • $L(\theta, W)$ is differentiable but $|| \theta ||_1$ is only sub-differentiable

      Solution 1 - use proximal gradient descent

      • Proximal gradient descent applies to problems of the form $f(x) = g(x) + h(x)$ with $g$ differentiable and $h$ non-differentiable. Here,

        • $g$ = $L(\theta,~W)$
        • $h = \lambda ||\theta||_1$
      • Replace Gradient Descent

        $$ x_{t+1} \leftarrow x_t - \alpha \nabla g(x_t) = \argmin_z \frac{1}{2\alpha}||z - (x_t - \alpha \nabla g(x_t))||_2^2 $$

      • By Proximal Gradient Descent

        $$ x_{t+1} \leftarrow \argmin_z \frac{1}{2\alpha}||z - (x_t - \alpha \nabla g(x_t))||_2^2 + h(z) $$

      • Quick reference for curious readers
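      • A concrete example (independent of LassoNet, with an illustrative quadratic $g$): the proximal operator of $h = \lambda||\cdot||_1$ has a closed form, elementwise soft-thresholding, which is what makes proximal gradient descent practical for Lasso-type penalties.

      ```python
      import numpy as np

      def soft_threshold(v, t):
          # prox of t * ||.||_1 evaluated at v: shrink each entry towards 0 by t
          return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)

      def prox_grad_step(x, grad_g, alpha, lam):
          # one proximal gradient step on g(x) + lam * ||x||_1
          return soft_threshold(x - alpha * grad_g(x), alpha * lam)

      # Illustrative problem: g(x) = 0.5 * ||A x - b||_2^2 with a weak second signal
      A = np.diag([1.0, 2.0])
      b = np.array([1.0, 0.1])
      x = np.zeros(2)
      for _ in range(200):
          x = prox_grad_step(x, lambda z: A.T @ (A @ z - b), alpha=0.1, lam=0.5)
      print(x)  # ~[0.5, 0.0]: the L1 prox drives the weak coordinate exactly to zero
      ```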

    3. Where do ideas 1. and 2. get us?

      "Stochastic Proximal Gradient Descent"

      • Iterate for $E$ epochs

        • Iterate over batches $b = (x_i,y_i)_i$ of size $B$

          Solve the batch-level proximal optimization problem

          $$ \begin{aligned} \argmin_{\theta',~W'} \quad &\frac{1}{2\alpha}||\theta' - (\theta - \alpha \nabla_\theta L_B(\theta, W))||_2^2\\ +~&\frac{1}{2\alpha}||W' - (W - \alpha \nabla_W L_B(\theta, W))||_2^2\\ +~&\lambda ||\theta'||_1\\ \text{ }\\ \text{s.t.}~\quad &\forall 1 \le j \le d,\quad ||{W'}_j^{(1)}||_{\infty} \le M |\theta'_j| \end{aligned} $$

          With the batch-level loss

          $$ L_B(\theta,~W) = \frac{1}{B} \sum_{i \in b} l(f_{\theta,W}(x_i), y_i) $$

      • Problem 2 - the constraints are non-convex, so the problem looks hard to solve at first glance

    4. Solution 2 - a new "Hier-Prox" algorithm solves the above problem efficiently and exactly

      (Figures taken from the LassoNet paper)

      • If we write $N := Kd + d$ for the total number of first-layer and skip-connection weights, where $K$ is the number of units in the first hidden layer and $d$ the number of input features,

      → Hier-Prox's complexity is in $\mathcal{O}(N \log N)$ (a NumPy sketch of this per-feature update is given right after this list)
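Before putting everything together, here is a NumPy sketch of the per-feature update that Hier-Prox computes, reconstructed from the constrained proximal problem above: fix a feature $j$, note that the optimal $W_j^{(1)}$ is the clipping of its post-gradient value to the box $[-M|\theta_j|, M|\theta_j|]$, and minimize the resulting one-dimensional piecewise-quadratic function of $|\theta_j|$. This follows the structure of the paper's algorithm, but treat it as an illustration and check the exact formulas against the paper.

```python
import numpy as np

def soft_threshold(v, t):
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)

def hier_prox_single(theta_j, W_j, lam, M):
    """Solve, for a single input feature j, the subproblem

        min_{b, w}  0.5*(b - theta_j)^2 + 0.5*||w - W_j||^2 + lam*|b|
        s.t.        ||w||_inf <= M*|b|

    where (theta_j, W_j) are the values after the plain gradient step and
    lam already includes the step size (i.e. lam = alpha * lambda), M > 0.
    """
    K = W_j.shape[0]
    abs_sorted = np.sort(np.abs(W_j))[::-1]                  # |W_(1)| >= ... >= |W_(K)|
    cumsum = np.concatenate([[0.0], np.cumsum(abs_sorted)])  # prefix sums, length K+1

    for m in range(K + 1):
        # Candidate bound on ||w||_inf if exactly the m largest |W_j,i| get clipped
        w_m = (M / (1.0 + m * M**2)) * soft_threshold(np.abs(theta_j) + M * cumsum[m], lam)
        upper = np.inf if m == 0 else abs_sorted[m - 1]
        lower = 0.0 if m == K else abs_sorted[m]
        if lower <= w_m <= upper:                            # candidate is consistent
            b = np.copysign(w_m / M, theta_j)                # new theta_j
            w = np.sign(W_j) * np.minimum(np.abs(W_j), w_m)  # clipped first-layer weights
            return b, w
    return 0.0, np.zeros_like(W_j)  # unreachable for a convex subproblem

def hier_prox(theta, W1, lam, M):
    """Apply the per-feature update to every input feature.

    theta: shape (d,), skip-connection weights after the gradient step
    W1:    shape (d, K), row j = first-layer weights out of feature j
    """
    new_theta, new_W1 = np.empty_like(theta), np.empty_like(W1)
    for j in range(theta.shape[0]):
        new_theta[j], new_W1[j] = hier_prox_single(theta[j], W1[j], lam, M)
    return new_theta, new_W1
```

The per-feature sort is what gives the $\mathcal{O}(N \log N)$ overall cost quoted above.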

Putting it all together

(Figures taken from the LassoNet paper)

<aside> 💡 Outcome: an end-to-end procedure which trains neural networks on the penalized objective for a sequence of increasing penalties $\lambda$, producing a "path" of selected variable subsets $k \subset [d]$

  1. Start by fitting a neural network parameterized by $W$ on the dataset with no $\lambda$ penalty

  2. For increasing $\lambda$ penalties

    a. Apply the "LassoNet" learning algorithm → proximal stochastic gradient descent with the Hier-Prox step, to get optimal $W, \theta$ for the current $\lambda$

    b. Observe the selected variables $j$, i.e. the ones with $\theta_j \neq 0$

All along the path, "warm-start" the neural network trainings with the previous optimal parameters!

</aside>
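A compact sketch of the whole path procedure is given below, reusing the LassoNetStyleModel and hier_prox sketches above. The toy data, optimizer settings and the geometric $\lambda$ grid are illustrative assumptions (the grid would need tuning to sweep from a dense to a fully sparse model), not the paper's defaults; in practice, the authors' open-source implementation linked from https://lassonet.ml should be preferred.

```python
import numpy as np
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy data (illustrative): only the first 3 of 20 features matter.
n, d = 512, 20
X = torch.randn(n, d)
y = (torch.sin(X[:, 0]) + X[:, 1] * X[:, 2]).unsqueeze(1) + 0.1 * torch.randn(n, 1)

model = LassoNetStyleModel(d_in=d)   # sketch defined earlier
M, lr = 10.0, 1e-2

def train(model, lam, epochs=20, batch_size=64):
    """Proximal stochastic gradient descent for a fixed lambda."""
    for _ in range(epochs):
        perm = torch.randperm(n)
        for s in range(0, n, batch_size):
            idx = perm[s:s + batch_size]
            loss = F.mse_loss(model(X[idx]), y[idx])
            model.zero_grad()
            loss.backward()
            with torch.no_grad():
                for p in model.parameters():     # plain SGD step on all weights
                    p -= lr * p.grad
                theta = model.skip.weight        # shape (1, d)
                W1 = model.net[0].weight         # shape (K, d)
                th_new, W1_new = hier_prox(      # constrained proximal step
                    theta.detach()[0].numpy(), W1.detach().T.numpy(), lam=lr * lam, M=M
                )
                theta.copy_(torch.as_tensor(th_new, dtype=theta.dtype).unsqueeze(0))
                W1.copy_(torch.as_tensor(W1_new.T, dtype=W1.dtype))

# 1. Dense fit with no penalty, then 2. warm-started path over increasing lambdas.
train(model, lam=0.0, epochs=50)
for lam in np.geomspace(1e-3, 1.0, num=10):
    train(model, lam=lam)
    selected = (model.skip.weight[0].abs() > 0).nonzero().flatten().tolist()
    print(f"lambda={lam:.4f}  selected features: {selected}")
```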

Understanding Hier-Prox

Proof walkthrough with scribbled notes as support

lassonetproofwalkthrough.pdf