GD’s Implicit Bias on Separable Data

TLDR: This post goes through the intuition behind why using gradient descent to fit a linear model to linearly separable data will learn a maximum-margin decision boundary. It’s a clean example of an implicit bias of gradient descent, and also extends (in some form) to more complicated settings (like homogeneous deep networks).

Often, we use neural networks to solve optimization problems where many different solutions minimize the training objective. In these cases, the particular minimum we learn (or approach) is a consequence of how we go about finding a minimum; this is known as an implicit bias of our optimization algorithm. Since different global minima can behave very differently outside of the training set, these implicit biases can have major effects on how our models generalize. Much about the implicit biases of different optimization algorithms is still not well understood. Indeed, understanding these biases, and why they so consistently lead to surprisingly well-generalizing solutions, is a fundamental open problem in deep learning theory.

This post will talk about a particular setting where we do understand how this implicit bias works. The setting is interesting not just as a toy case to develop intuition, but also because similar versions of these results seem to hold in much more complicated settings. It’s also useful because it forces us to confront the fact that inductive biases play a role in learning; thinking only in terms of loss minimization doesn’t explain the observed behavior.

I often find discussions of implicit bias unsatisfying or confused; understanding this example well will, I hope, clear things up. Most of the math here is based on “The Implicit Bias of Gradient Descent on Separable Data” by Soudry et al.

The Setting

Consider a classification data set $\{\vec{x}_n, y_n\}_{n=1}^N$, with $d$-dimensional real inputs $\vec{x}_n \in \mathbb{R}^d$ and binary labels $y_n \in \{-1, 1\}$. We use gradient descent to find a vector $\vec{w}$ which minimizes $\mathcal{L}(\vec{w})$, defined by:

$$
\begin{aligned}
\mathcal{L}(\vec{w}) &= \sum_{n=1}^N \ell\big(y_n (\vec{w} \cdot \vec{x}_n)\big) \\
\ell(u) &= e^{-u}
\end{aligned}
$$

Our loss encourages high values of $\vec{w}\cdot\vec{x}_n$ which match the sign of $y_n$. We classify points according to the sign of $\vec{w}\cdot\vec{x}_n$. We assume our data is separable, that is, $\exists \vec{w}_*$ such that $\forall n$, $y_n(\vec{w}_* \cdot \vec{x}_n) > 0$.

[Figure: exponential loss, $\ell(u) = e^{-u}$.]
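To make the setting concrete, here is a minimal pure-Python sketch. The data set `X`, `y` is a hypothetical toy example (five labelled 2D points, not from the post) chosen so that the x-axis separates it. The helper names (`exp_loss`, `separates`) are illustrative. The sketch evaluates $\mathcal{L}$ and the separability condition, and shows that scaling up a separating $\vec{w}_*$ keeps lowering the loss without ever reaching zero.

```python
import math

# Hypothetical toy data (not from the post): 2D inputs with labels in {-1, +1}.
X = [(1.0, 1.0), (0.0, 3.0), (2.0, 5.0), (1.0, -1.0), (2.0, -4.0)]
y = [1, 1, 1, -1, -1]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def exp_loss(w):
    """L(w) = sum_n exp(-y_n (w . x_n))."""
    return sum(math.exp(-yn * dot(w, xn)) for xn, yn in zip(X, y))


def separates(w):
    """True if y_n (w . x_n) > 0 for every training point."""
    return all(yn * dot(w, xn) > 0 for xn, yn in zip(X, y))


w_star = (0.0, 1.0)  # decision boundary = the x-axis
scaled_losses = [exp_loss((s * w_star[0], s * w_star[1])) for s in (1, 2, 4, 8)]
print(separates(w_star), separates((1.0, 0.0)))
print(scaled_losses)  # strictly decreasing, but every entry is still > 0
```

Each term $e^{-u}$ is strictly positive, so the loss can approach zero but never equal it at any finite $\vec{w}$.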

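With these conditions, we know that the infimum of the loss is zero (scaling up a separating $\vec{w}_*$ drives every term towards zero), but that this can’t be achieved by any finite $\vec{w}$, since every term $e^{-u}$ is strictly positive. We use gradient descent to minimize $\mathcal{L}(\vec{w})$, with updates of the form: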

$$\vec{w}_{t+1} = \vec{w}_t - \eta \nabla \mathcal{L}(\vec{w}_t)$$

With a sufficiently small step size $\eta$, we can prove that gradient descent drives the loss to its infimum of zero as $t \to \infty$ (there is no finite minimizer to converge to). We make no assumptions about the initialization of our weight vector $\vec{w}$.
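Here is a minimal sketch of the update rule above. It uses hypothetical toy data, a hand-written gradient, and an arbitrarily chosen $\eta = 0.1$. The run starts from $\vec{w}_0 = (1, 0)$, which misclassifies one point, to illustrate that the initialization doesn’t need to separate the data. The loss falls while $\|\vec{w}_t\|$ keeps growing.

```python
import math

# Hypothetical toy data (not from the post).
X = [(1.0, 1.0), (0.0, 3.0), (2.0, 5.0), (1.0, -1.0), (2.0, -4.0)]
y = [1, 1, 1, -1, -1]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def exp_loss(w):
    return sum(math.exp(-yn * dot(w, xn)) for xn, yn in zip(X, y))


def grad(w):
    """grad L(w) = -sum_n y_n exp(-y_n (w . x_n)) x_n."""
    g = [0.0, 0.0]
    for xn, yn in zip(X, y):
        c = yn * math.exp(-yn * dot(w, xn))
        g[0] -= c * xn[0]
        g[1] -= c * xn[1]
    return g


def gradient_descent(w0, eta=0.1, steps=1000):
    """Plain GD, w_{t+1} = w_t - eta * grad L(w_t); returns every iterate."""
    w = tuple(w0)
    trajectory = [w]
    for _ in range(steps):
        g = grad(w)
        w = (w[0] - eta * g[0], w[1] - eta * g[1])
        trajectory.append(w)
    return trajectory


traj = gradient_descent((1.0, 0.0))  # (1, 0) puts the point (1, -1) on the wrong side
for t in (0, 10, 100, 1000):
    print(t, exp_loss(traj[t]), math.hypot(*traj[t]))
```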

The Task

In this setting, approaching zero loss requires the norm of $\vec{w}$ to diverge ($\|\vec{w}\| \to \infty$). However, since the sign of $\vec{w} \cdot \vec{x}_n$ determines its classification, blowing up the norm of $\vec{w}$ has no effect on the functional behavior of our model (and thus can’t affect things like the generalization behavior we care about).
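Here is a quick check of this claim on hypothetical toy data, with an arbitrary direction $\vec{w} = (0.3, 1)$. Multiplying $\vec{w}$ by larger and larger constants leaves every predicted label unchanged, while the loss keeps shrinking.

```python
import math

# Hypothetical toy data (not from the post).
X = [(1.0, 1.0), (0.0, 3.0), (2.0, 5.0), (1.0, -1.0), (2.0, -4.0)]
y = [1, 1, 1, -1, -1]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def exp_loss(w):
    return sum(math.exp(-yn * dot(w, xn)) for xn, yn in zip(X, y))


def predictions(w):
    """Predicted labels sign(w . x_n), with ties sent to +1."""
    return [1 if dot(w, xn) >= 0 else -1 for xn in X]


w = (0.3, 1.0)
results = []
for s in (1.0, 10.0, 100.0):
    ws = (s * w[0], s * w[1])
    results.append((predictions(ws), exp_loss(ws)))
    print(s, *results[-1])
```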

On the other hand, the direction of our weight vector does determine classification, and thus is relevant to generalization behavior. Accordingly, we want to characterize the behavior of $\frac{\vec{w}_t}{\|\vec{w}_t\|}$ as $t \to \infty$. This distinction is important, and has interesting consequences like the potential for test loss to increase even as our classifier gets more accurate (we’ll talk about this in ‘Extensions’).

It’s worth focusing on the fact that a $\vec{w}$ which separates the data doesn’t need to change its direction to approach a global minimum. Scaling its magnitude is enough. Accordingly, we could imagine that $\frac{\vec{w}_t}{\|\vec{w}_t\|}$ doesn’t converge to anything in particular, and instead just sticks with the first separating direction it reaches while scaling up its norm. This isn’t what happens; instead, $\frac{\vec{w}_t}{\|\vec{w}_t\|}$ does converge. I focus on this to emphasize that the result isn’t obvious, and isn’t something we arrive at by only reasoning about reaching a global minimum.
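This can be seen in a small simulation. It is a sketch on hypothetical toy data; $\eta = 0.1$ and the step count are arbitrary choices. We start from $\vec{w}_0 = (0.5, 1)$, which already separates the data, so pure scaling would suffice to drive the loss to zero. Gradient descent nevertheless rotates $\vec{w}_t$. Its direction ends up close to $(0, 1)$, which for this data is the max-margin direction, rather than staying at $\vec{w}_0/\|\vec{w}_0\| \approx (0.447, 0.894)$.

```python
import math

# Hypothetical toy data (not from the post).
X = [(1.0, 1.0), (0.0, 3.0), (2.0, 5.0), (1.0, -1.0), (2.0, -4.0)]
y = [1, 1, 1, -1, -1]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def grad(w):
    """grad L(w) = -sum_n y_n exp(-y_n (w . x_n)) x_n for the exponential loss."""
    g = [0.0, 0.0]
    for xn, yn in zip(X, y):
        c = yn * math.exp(-yn * dot(w, xn))
        g[0] -= c * xn[0]
        g[1] -= c * xn[1]
    return g


def unit(w):
    n = math.hypot(*w)
    return (w[0] / n, w[1] / n)


w = (0.5, 1.0)
assert all(yn * dot(w, xn) > 0 for xn, yn in zip(X, y))  # already separating
u_start = unit(w)
for _ in range(5000):
    g = grad(w)
    w = (w[0] - 0.1 * g[0], w[1] - 0.1 * g[1])
u_end = unit(w)
print("start direction:", u_start)
print("end direction:  ", u_end, "norm:", math.hypot(*w))
```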

2D Intuitive Model

There’s a really natural way to think about our setup with two-dimensional points. Say we have a set of 2D points $\{\vec{x}_n = (a_n, b_n)\}_{n=1}^N$ with corresponding labels $y_n \in \{-1, 1\}$.

Now say $\vec{w}_0 = \begin{bmatrix} 0 & 1 \end{bmatrix}$. In this case, our decision boundary is the x-axis, since $\vec{w}_0 \cdot \vec{x}_n = 0 \iff b_n = 0$. Here, $y_n(\vec{w}_0 \cdot \vec{x}_n) = y_n b_n$ is the distance between a point and the x-axis, made negative for points on the wrong side of the decision boundary. This value is then passed through our exponential loss and summed to get $\mathcal{L}(\vec{w}_0)$.

What’s cool is that we can use this same intuition for every nonzero weight vector $\vec{w}$, since we can always decompose $\vec{w}$ into a rotation ($\mathcal{R}$) and a scaling ($s$) applied to $\vec{w}_0$. And since $\vec{w} \cdot \vec{x}_n = (s\mathcal{R}\vec{w}_0) \cdot \vec{x}_n = \vec{w}_0 \cdot (s\mathcal{R}^T\vec{x}_n)$, we can equivalently think of a new value of $\vec{w}$ as rotating and scaling our data points. We then use $\vec{w}_0 = \begin{bmatrix} 0 & 1 \end{bmatrix}$ to score those points as a function of their distance from the x-axis (visualized below).

Values of $\vec{w}$ which separate the data correspond to rotations which bring all points with a positive label above the x-axis, and all points with a negative label below the x-axis. Scaling doesn’t change the classification, but it does change what’s plugged into $\ell(u)$, and thus our loss $\mathcal{L}(\vec{w})$.
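The rotation-and-scaling decomposition can be checked numerically. In this sketch the helper names are hypothetical. A nonzero 2D $\vec{w}$ is written as $s\mathcal{R}\vec{w}_0$, with $s = \|\vec{w}\|$ and $\mathcal{R}$ a rotation by an angle $\varphi$. We then score the rotated and scaled point $s\mathcal{R}^T\vec{x}$ by its height, i.e. its second coordinate, which is its dot product with $\vec{w}_0 = (0, 1)$. This height reproduces $\vec{w} \cdot \vec{x}$.

```python
import math


def decompose(w):
    """Write a nonzero 2D w as s * R(phi) applied to w0 = (0, 1)."""
    s = math.hypot(w[0], w[1])
    # R(phi) = [[cos, -sin], [sin, cos]] maps (0, 1) to (-sin phi, cos phi) = w / s.
    phi = math.atan2(-w[0], w[1])
    return s, phi


def rotate_transpose(phi, x):
    """Apply R(phi)^T to x, i.e. rotate x by -phi."""
    c, sn = math.cos(phi), math.sin(phi)
    return (c * x[0] + sn * x[1], -sn * x[0] + c * x[1])


w = (1.5, 2.0)
s, phi = decompose(w)
checks = []
for x in [(1.0, 1.0), (2.0, -4.0), (-3.0, 0.5)]:
    direct = w[0] * x[0] + w[1] * x[1]
    xr = rotate_transpose(phi, x)
    via_rotation = s * xr[1]  # w0 . (s R^T x): height of the transformed point
    checks.append((direct, via_rotation))
    print(x, direct, via_rotation)
```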

Here, any separating value of $\vec{w}$ would give you zero loss if its norm were scaled to infinity. That is, every separating rotation corresponds to an (unreachable) global minimum achieved by sending our scale $s$ to infinity. The question of implicit bias, here, is which of the uncountably many separating rotations we get, since all of them can tend towards global minima via scaling.

The Solution

So, what happens when we use gradient descent? How do we select between separating solutions?


The key insight comes from the form of the negative gradient, $-\nabla \mathcal{L}(\vec{w})=\sum_{n=1}^N y_n\exp\left(-y_n(\vec{w}\cdot\vec{x}_n)\right)\vec{x}_n$. As $\|\vec{w}_t\|$ diverges to infinity (as $t \to \infty$), only the terms with the largest (least negative) exponents meaningfully contribute to the gradient. For a fixed direction, the ratio between any two points’ coefficients grows exponentially in $\|\vec{w}\|$. These dominant terms are the ones with the smallest margin $y_n(\vec{w} \cdot \vec{x}_n)$, which are the points closest to the decision boundary. The gradient, then, is dominated by the closest points, each weighted by $y_n \exp(-y_n(\vec{w} \cdot \vec{x}_n))$. This means that, as the norm of $\vec{w}$ diverges to infinity, our negative gradient $-\nabla \mathcal{L}(\vec{w})$ effectively becomes a linear combination of the closest points (the support vectors) only.
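The concentration of the gradient on the closest points is easy to see numerically. This sketch uses hypothetical toy data and helper names. It fixes a separating unit direction $\hat{u} \propto (0.5, 1)$ and scales it by increasing $s$. For each point it reports the normalized coefficient $e^{-y_n(\vec{w}\cdot\vec{x}_n)}$ in $-\nabla\mathcal{L}$. As $s$ grows, the point $(1, -1)$, which is closest to this boundary, takes over almost the entire gradient.

```python
import math

# Hypothetical toy data (not from the post).
X = [(1.0, 1.0), (0.0, 3.0), (2.0, 5.0), (1.0, -1.0), (2.0, -4.0)]
y = [1, 1, 1, -1, -1]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def gradient_shares(w):
    """Coefficients exp(-y_n (w . x_n)) of -grad L, normalised to sum to 1."""
    coeffs = [math.exp(-yn * dot(w, xn)) for xn, yn in zip(X, y)]
    total = sum(coeffs)
    return [c / total for c in coeffs]


norm = math.sqrt(1.25)
u = (0.5 / norm, 1.0 / norm)  # separating unit direction; closest point is (1, -1)
shares = {}
for s in (1.0, 5.0, 20.0):
    shares[s] = gradient_shares((s * u[0], s * u[1]))
    print(s, [round(p, 4) for p in shares[s]])
```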

At this point, I think it becomes intuitive that we would converge to the decision boundary preferred by these closest points, since they have total control over our updates. The decision boundary preferred by these closest points, $\operatorname*{arg\,max}_{\|\hat{w}\| = 1} \min_n y_n(\hat{w} \cdot \vec{x}_n)$, is the max-margin solution.
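In 2D the max-margin direction can be found by brute force. Here is a minimal sketch, assuming hypothetical toy data. It scans unit vectors $\hat{w} = (\cos\theta, \sin\theta)$ on a 0.1° grid and keeps the one with the largest worst-case margin $\min_n y_n(\hat{w}\cdot\vec{x}_n)$. For this data the search lands on $\theta = 90°$, i.e. the x-axis as decision boundary. The resulting margin is 1, set by the points $(1, 1)$ and $(1, -1)$.

```python
import math

# Hypothetical toy data (not from the post).
X = [(1.0, 1.0), (0.0, 3.0), (2.0, 5.0), (1.0, -1.0), (2.0, -4.0)]
y = [1, 1, 1, -1, -1]


def min_margin(w):
    """Worst-case margin min_n y_n (w . x_n); negative if some point is misclassified."""
    return min(yn * (w[0] * xn[0] + w[1] * xn[1]) for xn, yn in zip(X, y))


def unit_vector(theta):
    return (math.cos(theta), math.sin(theta))


# Brute-force arg max over unit directions on a 0.1 degree grid.
angles = [2 * math.pi * k / 3600 for k in range(3600)]
best_theta = max(angles, key=lambda th: min_margin(unit_vector(th)))
w_hat = unit_vector(best_theta)
print(math.degrees(best_theta), w_hat, min_margin(w_hat))
```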

In our 2D framing, this means selecting the rotation (without scaling) which maximizes the minimum height of the points on their correct side of the x-axis, so that every point is still classified correctly. This is different from other plausible answers, like minimizing the sum of distances from the x-axis across all points.

One thing that’s cool is that this turns out to be the same as the direction of the $\hat{w}$ with the smallest $\ell_2$ norm which maps all points to at least distance one from the decision boundary on the correct side, in the sense that $y_n(\hat{w} \cdot \vec{x}_n) \geq 1$ for every $n$:

$$\vec{w}^* = \operatorname*{arg\,min}_{\vec{w}} \|\vec{w}\|^2 \quad \text{s.t.} \quad y_n(\vec{w}\cdot \vec{x}_n) \geq 1 \;\; \forall n \tag{1}$$

Again, this is intuitive in our 2D model: as we rotate and then try to scale down as much as possible, we are stopped by the points closest to the x-axis. The rotation that allows the most down-scaling is the one which places these closest points furthest from the x-axis, which is the max-margin solution.

[Figure: our down-scaling is blocked by the closest points intersecting with $y = \pm 1$.]
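The equivalence with eq. 1 can be illustrated on hypothetical toy data. In this sketch, each separating unit direction $\hat{u}$ is scaled down until its closest point satisfies $y_n(\vec{w}\cdot\vec{x}_n) = 1$. This gives $\|\vec{w}\| = 1/\gamma(\hat{u})$ with $\gamma(\hat{u}) = \min_n y_n(\hat{u}\cdot\vec{x}_n)$. The max-margin direction $(0, 1)$ yields the smallest norm; the two tilted directions are arbitrary comparisons.

```python
import math

# Hypothetical toy data (not from the post).
X = [(1.0, 1.0), (0.0, 3.0), (2.0, 5.0), (1.0, -1.0), (2.0, -4.0)]
y = [1, 1, 1, -1, -1]


def min_margin(w):
    return min(yn * (w[0] * xn[0] + w[1] * xn[1]) for xn, yn in zip(X, y))


def scale_to_unit_margin(u):
    """Shrink a separating direction u until its closest point has margin exactly 1."""
    gamma = min_margin(u)
    if gamma <= 0:
        raise ValueError("u does not separate the data")
    return (u[0] / gamma, u[1] / gamma)


directions = {
    "max-margin (0, 1)": (0.0, 1.0),
    "tilted right": (0.5 / math.sqrt(1.25), 1.0 / math.sqrt(1.25)),
    "tilted left": (-0.2 / math.sqrt(1.04), 1.0 / math.sqrt(1.04)),
}
norms = {}
for name, u in directions.items():
    w = scale_to_unit_margin(u)
    norms[name] = math.hypot(*w)
    print(f"{name}: min margin {min_margin(w):.3f}, norm {norms[name]:.3f}")
```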

It’s important to note that we only get this solution because of the exponential tail of our loss function, since this is what made the gradients from all but the support vectors negligible.


To be (a bit) more formal, suppose $\frac{\vec{w}_t}{\|\vec{w}_t\|}$ converges to some value $\vec{w}_\infty$. Then $\vec{w}_\infty$ must itself be a linear combination of its support vectors (the points with the smallest margin under $\vec{w}_\infty$), since those points dominate every late gradient update. The part of $\vec{w}_t$ which does not come from these support-vector updates (e.g., the initialization) becomes negligible after normalizing, since $\|\vec{w}_t\| \to \infty$. This $\vec{w}_\infty$ is proportional to $\hat{w} = \frac{\vec{w}_{\infty}}{\min_n y_n(\vec{w}_\infty \cdot \vec{x}_n)}$, which has the properties:

$$
\begin{gathered}
\hat{w}=\sum_{n=1}^{N} \alpha_{n} y_n\vec{x}_{n}, \\
\forall n:\ \left(\alpha_{n} \geq 0 \text{ and } y_n(\hat{w} \cdot \vec{x}_n)=1\right) \text{ OR } \left(\alpha_{n}=0 \text{ and } y_n(\hat{w} \cdot \vec{x}_{n})>1\right)
\end{gathered}
$$

That is, $\hat{w}$ is a combination of points at distance one from the decision boundary, each weighted by a nonnegative coefficient times its label (so the sign matches the point’s class), and all other points are further from the decision boundary. These turn out to be the KKT conditions for eq. 1. Since eq. 1 is a convex problem, satisfying them means $\hat{w}$ is its solution. Since $\vec{w}_{\infty}$ is proportional to $\hat{w}$, we have that $\frac{\vec{w}_t}{\|\vec{w}_t\|}$ does indeed converge in direction to the max-margin solution.
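The conditions above can be checked directly on hypothetical toy data. This is a sketch, not a general solver. We take $\hat{w} = (0, 1)$, which for this data has margin exactly 1 on the points $(1, 1)$ and $(1, -1)$ and a larger margin on every other point. With two support vectors in 2D, the coefficients $\alpha_n$ follow from a 2×2 linear system (solved here by Cramer’s rule), and both come out nonnegative.

```python
# Hypothetical toy data (not from the post), with candidate solution w_hat = (0, 1).
X = [(1.0, 1.0), (0.0, 3.0), (2.0, 5.0), (1.0, -1.0), (2.0, -4.0)]
y = [1, 1, 1, -1, -1]
w_hat = (0.0, 1.0)

margins = [yn * (w_hat[0] * xn[0] + w_hat[1] * xn[1]) for xn, yn in zip(X, y)]
support = [n for n, m in enumerate(margins) if abs(m - 1.0) < 1e-9]
others_ok = all(m > 1.0 for n, m in enumerate(margins) if n not in support)

# Solve w_hat = alpha_i * y_i x_i + alpha_j * y_j x_j by Cramer's rule
# (this unpacking assumes exactly two support vectors, as in this 2D example).
i, j = support
zi = (y[i] * X[i][0], y[i] * X[i][1])
zj = (y[j] * X[j][0], y[j] * X[j][1])
det = zi[0] * zj[1] - zj[0] * zi[1]
alpha_i = (w_hat[0] * zj[1] - zj[0] * w_hat[1]) / det
alpha_j = (zi[0] * w_hat[1] - w_hat[0] * zi[1]) / det
print("support vectors:", support, "alphas:", alpha_i, alpha_j, "others > 1:", others_ok)
```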

Extensions

This result has been extended to a broader class of loss functions (including multi-class classification with cross-entropy loss). It’s also been extended to stochastic gradient descent, and to positively homogeneous deep networks like ReLU MLPs without bias terms. Soudry et al. expected this to be true; the figure below is taken from their paper, and shows a convolutional neural network trained on CIFAR10 using SGD.

This figure initially looks weird, since the validation loss goes up as we train past separation, even as the validation accuracy improves. In light of these results, though, this is intuitive. We are converging towards the max-margin direction, which should help accuracy, but we are also blowing up our weight norm, which amplifies the loss on any mistakes we make. Sometimes we mistakenly read this increase in validation loss as a sign that our model is generalizing worse and worse, and stop training accordingly (whereas using validation accuracy would correctly lead us to keep training).
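Here is a toy illustration of the mechanism, using hypothetical held-out points and a fixed direction. Real training also changes the direction, which this sketch ignores. One of the three points is misclassified, so accuracy stays at 2/3 at every scale. Meanwhile the average exponential loss grows quickly as the weight norm increases, driven by that misclassified point.

```python
import math

# Hypothetical held-out points; the direction u below gets the last one wrong.
X_val = [(1.0, 2.0), (-1.0, -2.0), (1.0, -0.5)]
y_val = [1, -1, 1]
u = (0.0, 1.0)


def val_metrics(w):
    """Accuracy and mean exponential loss of sign(w . x) on the held-out points."""
    margins = [yn * (w[0] * xn[0] + w[1] * xn[1]) for xn, yn in zip(X_val, y_val)]
    accuracy = sum(m > 0 for m in margins) / len(margins)
    mean_loss = sum(math.exp(-m) for m in margins) / len(margins)
    return accuracy, mean_loss


scales = (1.0, 5.0, 25.0)
metrics = [val_metrics((s * u[0], s * u[1])) for s in scales]
for s, (acc, loss) in zip(scales, metrics):
    print(f"scale {s:>4}: accuracy {acc:.3f}, mean loss {loss:.3f}")
```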

What about adding a bias term to our original setup, learning according to $\ell(y_n(\vec{w} \cdot \vec{x}_n + b))$? One way to do this is to append a one to each input, forming $\vec{x}_n' = \begin{bmatrix} x_{n, 1} & \dots & x_{n, d} & 1\end{bmatrix}$. Our same results hold for $\vec{x}_n'$, meaning we get the max-margin solution, but in the new input space. There, $b$ is part of the weight vector whose norm eq. 1 minimizes, so this is not quite the usual max-margin classifier with an unpenalized bias. In our 2D case, we can think of this as converting all of our points into 3D points with $z = 1$, and then allowing 3D rotations instead of just 2D ones. We then get the max-margin solution in this 3D space.
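Here is a minimal sketch of the augmentation trick; the helper names are hypothetical. Appending a constant 1 to each input turns $\vec{w}\cdot\vec{x}_n + b$ into a single dot product $\vec{w}'\cdot\vec{x}_n'$, with $\vec{w}' = [\vec{w}, b]$. The bias-free analysis then applies unchanged in $d+1$ dimensions.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def augment(x):
    """Append a constant 1 to an input, so a bias becomes one more weight."""
    return tuple(x) + (1.0,)


w, b = (0.5, -1.0), 0.25
w_aug = tuple(w) + (b,)  # w' = [w, b]
pairs = []
for x in [(1.0, 1.0), (2.0, -4.0), (0.0, 3.0)]:
    pairs.append((dot(w, x) + b, dot(w_aug, augment(x))))
    print(x, pairs[-1])
```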

Other Optimizers

Momentum, acceleration, and stochasticity don’t change this implicit bias. Interestingly, though, it turns out that adaptive optimizers like Adam and AdaGrad don’t converge to the $\ell_2$ max-margin predictor. Instead, the limit direction can depend on the initial point and step size (though they do still converge to zero loss). This might be part of the reason these optimizers are thought to find solutions which don’t generalize as well as those found by SGD.

Conclusion

If we understood what sorts of solutions gradient descent converges to, and what generalization properties those solutions have, we could be more confident in predicting properties of our learned models. In particular, if we knew our model would converge to the max-margin solution of the training data, and we knew that max-margin solutions tend to generalize in particularly nice ways, we would expect our model to also generalize in nice ways (in the limit). I think this has ramifications relevant to AI safety (more info in a subsequent post), and is part of the reason I think studying the science of deep learning is useful.

Thanks to Sam Marks, Davis Brown, Hannah Erlebach, Max Nadeau, Dmitrii Krasheninnikov, and Lauro Langosco for thoughtful comments.