

Plenary speaker in Workshop: OPT 2023: Optimization for Machine Learning

Provable Feature Learning in Gradient Descent, Jason Lee

Jason Lee


**Abstract:** We focus on the task of learning a single-index model $\sigma(w^\star \cdot x)$ with respect to the isotropic Gaussian distribution in $d$ dimensions, including the special case when $\sigma$ is the $k$th-order Hermite polynomial, which corresponds to the Gaussian analog of parity learning. Prior work has shown that the sample complexity of learning $w^\star$ is governed by the *information exponent* $k^\star$ of the link function $\sigma$, which is defined as the index of the first nonzero Hermite coefficient of $\sigma$. Prior upper bounds have shown that $n \gtrsim d^{k^\star - 1}$ samples suffice for learning $w^\star$ and that this is tight for online SGD (Ben Arous et al., 2020). However, the CSQ lower bound for gradient-based methods only shows that $n \gtrsim d^{k^\star/2}$ samples are necessary. In this work, we close the gap between the upper and lower bounds by showing that online SGD on a smoothed loss learns $w^\star$ with $n \gtrsim d^{k^\star/2}$ samples.

Next, we turn to the problem of learning multi-index models $f(x) = g(Ux)$, where $U$ encodes a latent representation of low dimension. Significant prior work has established that neural networks trained by gradient descent behave like kernel methods, despite the significantly worse empirical performance of kernel methods. However, in this work we demonstrate that for this large class of functions there is a large gap between kernel methods and gradient descent on a two-layer neural network, by showing that gradient descent learns representations relevant to the target task. We also demonstrate that these representations allow for efficient transfer learning, which is impossible in the kernel regime. Specifically, we consider the problem of learning polynomials which depend on only a few relevant directions, i.e. of the form $f^\star(x) = g(Ux)$ where $U$ maps $\mathbb{R}^d$ to $\mathbb{R}^r$. When the degree of $f^\star$ is $p$, it is known that $n \asymp d^p$ samples are necessary to learn $f^\star$ in the kernel regime.
Our primary result is that gradient descent learns a representation of the data which depends only on the directions relevant to $f^\star$. This results in an improved sample complexity of $n \asymp d^2 r + d r^p$. Furthermore, in a transfer learning setup where the data distributions in the source and target domains share the same representation $U$ but have different polynomial heads, we show that a popular heuristic for transfer learning has a target sample complexity independent of $d$.
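
The information exponent mentioned in the abstract can be made concrete with a small numerical sketch. The code below is an illustration only, not code from the talk: it estimates the normalized Hermite coefficients $c_k = \mathbb{E}[\sigma(z)\,\mathrm{He}_k(z)]/\sqrt{k!}$ for $z \sim \mathcal{N}(0,1)$ by Gauss-Hermite quadrature and returns the first index with a nonzero coefficient. The function names, the tolerance `tol`, the cutoff `max_degree`, and the choice to skip the constant term $k = 0$ are all assumptions made for this example.

```python
import math

import numpy as np
from numpy.polynomial.hermite_e import hermegauss, hermeval


def hermite_coefficients(sigma, max_degree, quad_points=64):
    """Return c_0..c_max_degree, where c_k = E[sigma(z) He_k(z)] / sqrt(k!), z ~ N(0, 1)."""
    # Gauss-Hermite nodes and weights for the weight function exp(-z^2 / 2).
    # The weights sum to sqrt(2*pi), so rescaling gives an expectation under N(0, 1).
    nodes, weights = hermegauss(quad_points)
    weights = weights / math.sqrt(2.0 * math.pi)
    values = sigma(nodes)

    coeffs = []
    for k in range(max_degree + 1):
        basis = np.zeros(k + 1)
        basis[k] = 1.0  # coefficient vector that selects the polynomial He_k
        he_k = hermeval(nodes, basis)
        raw = float(np.sum(weights * values * he_k))
        coeffs.append(raw / math.sqrt(math.factorial(k)))
    return np.array(coeffs)


def information_exponent(sigma, max_degree=10, tol=1e-8):
    """Smallest k >= 1 with |c_k| > tol, or None if none is found up to max_degree.

    Assumption for this sketch: the constant term (k = 0) is skipped.
    """
    coeffs = hermite_coefficients(sigma, max_degree)
    for k in range(1, max_degree + 1):
        if abs(coeffs[k]) > tol:
            return k
    return None


if __name__ == "__main__":
    # He_3(z) = z^3 - 3z: every Hermite coefficient below degree 3 vanishes.
    print(information_exponent(lambda z: z**3 - 3 * z))  # 3
    # tanh has a nonzero linear component.
    print(information_exponent(np.tanh))  # 1
```

Under the bounds quoted in the abstract, a link function with information exponent 3 such as $\mathrm{He}_3$ would need on the order of $d^{2}$ samples by the earlier upper bound and on the order of $d^{3/2}$ with the smoothed loss.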
