Your regression may only need one gradient step. Really.
I’ve been rethinking gradient descent over the weekend. It struck me that calculating the gradient is typically way more expensive than taking the step that follows it. I ran the numbers and found that about 80% of the training loop is spent calculating a gradient.
This led me to some fun hacking, and I want to demonstrate the findings in this document. In particular, I would like to highlight some ideas that had insightful results.
These results are based on artificial data, so they deserve to be taken with a grain of salt, but the ideas are interesting nonetheless.
Suppose that I want to optimise a function, say \(f(x)\). You could calculate the gradient and take a step, but after taking this step you’d need to calculate the gradient again.
It would work, but it may take a while, especially if calculating the gradient is expensive (it usually is). So how about we do some extra work to take a calculated step instead?
A calculated step might be more expensive to calculate, but this is offset by all the small steps we would otherwise do.
This is where calculus can help us. In particular, Taylor series! Suppose that we have some value \(x\) and we’d like to estimate \(f(x + t)\); then we can approximate it by:
\[f\left(x+t\right) \approx f\left(x\right)+f^{\prime}\left(x\right) t+\frac{1}{2} f^{\prime \prime}\left(x\right) t^{2}\]
If I know the derivatives of the function \(f\) then I can approximate what \(f(x+t)\) might be. But we can go a step further. Suppose now that I am interested in finding a minimum for this function. Then we can rewrite the expression to represent iteration.
\[f\left(x_{k}+t_k\right) \approx f\left(x_{k}\right)+f^{\prime}\left(x_{k}\right) t_k+\frac{1}{2} f^{\prime \prime}\left(x_{k}\right) t_k^{2}\]
Here \(x_k\) represents \(x\) at iteration \(k\), and the next value \(x_{k+1}\) will be \(x_{k} + t_k\). The question now becomes: how can I choose \(t_k\) such that we travel to the minimum as fast as possible? It turns out to be:
\[ x_{k+1}=x_{k}+t_k=x_{k}-\frac{f^{\prime}\left(x_{k}\right)}{f^{\prime \prime}\left(x_{k}\right)} \]
This formula can be used for functions with a single parameter, but with some linear algebra tricks we can also extend it to functions with many inputs using a Hessian matrix.
If the second derivative is positive, the quadratic approximation is a convex function of \(t_k\), and its minimum can be found by setting the derivative to zero. Note that \[ 0=\frac{\mathrm{d}}{\mathrm{d} t_k}\left(f\left(x_{k}\right)+f^{\prime}\left(x_{k}\right) t_k+\frac{1}{2} f^{\prime \prime}\left(x_{k}\right) t_k^{2}\right)=f^{\prime}\left(x_{k}\right)+f^{\prime \prime}\left(x_{k}\right)t_k\]
Thus the minimum is achieved for \(t_k=-\frac{f^{\prime}\left(x_{k}\right)}{f^{\prime \prime}\left(x_{k}\right)}\).
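A minimal sketch of this update rule in plain Python, assuming we can write \(f'\) and \(f''\) down by hand; the function name newton_1d and the two test functions are made up for illustration. On a quadratic the second-order approximation is exact, so one step lands on the minimum; on a quartic the approximation only holds locally and it takes a few iterations.

```python
def newton_1d(f_prime, f_double_prime, x0, n_iter=10):
    """Iterate x_{k+1} = x_k - f'(x_k) / f''(x_k) and return every iterate."""
    xs = [x0]
    x = x0
    for _ in range(n_iter):
        x = x - f_prime(x) / f_double_prime(x)
        xs.append(x)
    return xs


# f(x) = 3(x - 2)^2 + 1 is quadratic, so a single step lands on the minimum x = 2.
quad_path = newton_1d(lambda x: 6 * (x - 2), lambda x: 6.0, x0=10.0, n_iter=1)
print(quad_path)  # [10.0, 2.0]

# f(x) = x^4 / 4 - x is not quadratic (f' = x^3 - 1, f'' = 3x^2, minimum at x = 1),
# so each step only uses a local quadratic approximation and several steps are needed.
quartic_path = newton_1d(lambda x: x**3 - 1, lambda x: 3 * x**2, x0=3.0, n_iter=6)
print(quartic_path)
```

Newton's method only looks for a point where \(f'\) is zero; it is the positive second derivative in both examples that makes that point a minimum.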
Putting everything together, Newton’s method performs the iteration:
\[ x_{k+1}=x_{k}+t_k=x_{k}-\frac{f^{\prime}\left(x_{k}\right)}{f^{\prime \prime}\left(x_{k}\right)} \] Now, this formula works for single-parameter functions, but we can also express this result in linear algebra terms.
\[ x_{k+1}=x_{k}+t_k=x_{k}- [f^{\prime \prime}\left(x_{k}\right)]^{-1} f^{\prime}\left(x_{k}\right) \]
Here \(x_k\) is a vector, \(f^{\prime \prime}\left(x_{k}\right)\) is the Hessian matrix (so \([f^{\prime \prime}\left(x_{k}\right)]^{-1}\) is its inverse) and \(f^{\prime}\left(x_{k}\right)\) is the gradient vector.
This, to me, was a great excuse to play with JAX. It has a couple of likeable features, but a main one is that it is an autograd library that can also compute Hessians. You can just-in-time compile derivative functions, and they will also run on GPUs and TPUs.
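To make the vector form concrete, here is a hedged sketch of a single multivariate Newton step for a hypothetical two-parameter quadratic, with the gradient and Hessian written out by hand (the helpers solve_2x2, newton_step, grad_f and hess_f are illustrative and not part of JAX). Instead of forming \([f^{\prime \prime}(x_k)]^{-1}\) explicitly, it solves the linear system \(f^{\prime \prime}(x_k)\, t_k = -f^{\prime}(x_k)\), which gives the same step.

```python
def solve_2x2(H, b):
    """Solve H t = b for a 2x2 matrix H using Cramer's rule."""
    (a, c), (d, e) = H
    det = a * e - c * d
    if det == 0:
        raise ValueError("Hessian is singular")
    return [(b[0] * e - c * b[1]) / det, (a * b[1] - d * b[0]) / det]


def newton_step(x, grad, hess):
    """One multivariate Newton step: x + t with H(x) t = -grad(x)."""
    g = grad(x)
    t = solve_2x2(hess(x), [-g[0], -g[1]])
    return [x[0] + t[0], x[1] + t[1]]


# hypothetical f(x) = 2*x0^2 + x0*x1 + x1^2 - 5*x0 - 3*x1, minimum at (1, 1)
def grad_f(x):
    return [4 * x[0] + x[1] - 5, x[0] + 2 * x[1] - 3]


def hess_f(x):
    return [[4.0, 1.0], [1.0, 2.0]]  # constant, because f is quadratic


print(newton_step([0.0, 0.0], grad_f, hess_f))  # [1.0, 1.0]
```

Solving the system rather than inverting the matrix is the usual numerical choice; with NumPy or JAX that would be a call to linalg.solve.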
This is how you might implement the loss and its gradient for a linear regression:
import jax.numpy as jnp
from jax import grad, jit

def predict(params, inputs):
    return inputs @ params

def mse(params, inputs, targets):
    preds = predict(params, inputs)
    return jnp.mean((preds - targets)**2)

grad_fun = jit(grad(mse))  # compiled gradient evaluation function
The grad_fun is now a compiled function that takes the same arguments as mse (params, inputs and targets) and returns the gradient of the mse function with respect to params. That means that I can use it in a learning loop. So here’s an implementation of linear regression:
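For intuition about what grad_fun returns, here is a sketch in plain Python (no JAX) of the same mean squared error with its gradient written out by hand, \(\frac{2}{n}X^T(Xw - y)\), checked against central finite differences on a tiny made-up dataset. mse_grad and numerical_grad are hypothetical helpers; deriving them by hand is exactly the work that autograd saves us.

```python
def predict(params, inputs):
    """Linear model: one dot product per row of inputs."""
    return [sum(w * x for w, x in zip(params, row)) for row in inputs]


def mse(params, inputs, targets):
    preds = predict(params, inputs)
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)


def mse_grad(params, inputs, targets):
    """Gradient of the MSE with respect to params: (2 / n) * X^T (X w - y)."""
    n = len(targets)
    residuals = [p - t for p, t in zip(predict(params, inputs), targets)]
    return [
        2.0 / n * sum(r * row[j] for r, row in zip(residuals, inputs))
        for j in range(len(params))
    ]


def numerical_grad(f, params, eps=1e-6):
    """Central finite differences, only used to sanity-check mse_grad."""
    grads = []
    for j in range(len(params)):
        up = list(params)
        up[j] += eps
        down = list(params)
        down[j] -= eps
        grads.append((f(up) - f(down)) / (2 * eps))
    return grads


X = [[1.0, 0.5], [1.0, -1.0], [1.0, 2.0]]  # first column of ones acts as the intercept
y = [1.0, 0.0, 3.0]
w = [0.2, -0.3]

analytic = mse_grad(w, X, y)
numeric = numerical_grad(lambda p: mse(p, X, y), w)
print(analytic, numeric)
```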
import tqdm
import numpy as np
import matplotlib.pylab as plt

# generate random regression data
n, k = 1_000_000, 10
both = [np.ones((n, 1)), np.random.normal(0, 1, (n, k))]
X = np.concatenate(both, axis=1)
true_w = np.random.normal(0, 5, (k + 1,))
y = X @ true_w

np.random.seed(42)
W = np.random.normal(0, 1, (k + 1,))

stepsize = 0.02
n_step = 100
hist_gd = np.zeros((n_step,))
for i in tqdm.tqdm(range(n_step)):
    hist_gd[i] = mse(W, inputs=X, targets=y)
    dW = grad_fun(W, inputs=X, targets=y)
    W -= dW*stepsize
This is what the mean squared error looks like over the epochs.
Let’s now do the same thing, but use the calculus trick.
from jax import hessian

# use same data, but reset the found weights
np.random.seed(42)
W = np.random.normal(0, 1, (k + 1,))

n_step = 100
hist_hess = np.zeros((n_step,))
for i in tqdm.tqdm(range(n_step)):
    hist_hess[i] = mse(W, inputs=X, targets=y)
    inv_hessian = np.linalg.inv(hessian(mse)(W, X, y))
    dW = inv_hessian @ grad_fun(W, inputs=X, targets=y)
    W -= dW
Want to see something cool? This is the new result.
The shocking thing is that this graph always has the same shape, no matter the number of rows or columns. When I first ran this I could barely believe it. By using the Hessian trick we predict how big a step we need to make, and it hits the bullseye.
There’s a reason for this bullseye, but it is a bit mathematical.
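Let’s rewrite the loss for linear regression in matrix terms. The factor \(\frac{1}{2}\) is only there for convenience: scaling the loss by a constant scales the gradient and the Hessian alike, so it does not change the Newton step.
\[ L(\beta) = \frac{1}{2}(y - X \beta)^T (y - X \beta) \] If we differentiate, the gradient vector is:
\[ \nabla L (\beta) = - X^T (y - X \beta) \] And the Hessian matrix is:
\[ \nabla^2 L (\beta) = X^T X \] Let’s remind ourselves of Newton’s method.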
\[ x_{k+1}=x_{k}+t_k=x_{k}-\frac{f^{\prime}\left(x_{k}\right)}{f^{\prime \prime}\left(x_{k}\right)} \]
That means that our step (see earlier derivation) needs to be:
\[ \begin{split} t_k & = -[\nabla^2 L (\beta_k)]^{-1} \nabla L (\beta_k) \\ & = (X^TX)^{-1}X^T(y - X \beta_k) \\ & = (X^TX)^{-1}X^Ty - (X^TX)^{-1}X^TX \beta_k \\ & = (X^TX)^{-1}X^Ty - \beta_k \end{split} \] When we start our regression with some \(\beta_k\), the update rule becomes:
\[ \begin{split} \beta_{k+1} & = \beta_k + t_k\\ & = \beta_k + (X^TX)^{-1}X^Ty - \beta_k \\ & = (X^TX)^{-1}X^Ty \end{split} \] This is not really a coincidence: \((X^TX)^{-1}X^Ty\) is the closed-form solution for linear regression, and because the loss is exactly quadratic in \(\beta\), the second-order Taylor expansion behind Newton’s method is exact rather than an approximation. This means that a single iteration of Newton’s method on standard linear regression is equivalent to using the closed-form solution, no matter where \(\beta_k\) starts.
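The one-step claim is easy to check numerically. The sketch below, in plain Python with made-up data, fits an intercept and a slope: it takes one Newton step using the gradient \(-X^T(y - X\beta)\) and the Hessian \(X^TX\) from two different starting points and compares the result to \((X^TX)^{-1}X^Ty\). The helper names are hypothetical.

```python
def solve_2x2(A, b):
    """Solve A t = b for a 2x2 matrix A using Cramer's rule."""
    (a, c), (d, e) = A
    det = a * e - c * d
    return [(b[0] * e - c * b[1]) / det, (a * b[1] - d * b[0]) / det]


def design_matrix(xs):
    """Rows [1, x]: an intercept column plus one feature."""
    return [[1.0, x] for x in xs]


def gram(X):
    """X^T X, which is also the Hessian of the least-squares loss."""
    return [[sum(r[i] * r[j] for r in X) for j in range(2)] for i in range(2)]


def newton_step_ols(beta, xs, ys):
    """One Newton step beta + t, with t = -H^{-1} grad and grad = -X^T (y - X beta)."""
    X = design_matrix(xs)
    residuals = [y - (beta[0] * r[0] + beta[1] * r[1]) for r, y in zip(X, ys)]
    grad = [-sum(res * r[i] for res, r in zip(residuals, X)) for i in range(2)]
    t = solve_2x2(gram(X), [-g for g in grad])
    return [beta[0] + t[0], beta[1] + t[1]]


xs = [0.0, 1.0, 2.0, 3.0]
ys = [1.1, 2.9, 5.2, 6.8]  # made-up, slightly noisy data

X = design_matrix(xs)
Xty = [sum(r[i] * y for r, y in zip(X, ys)) for i in range(2)]
closed_form = solve_2x2(gram(X), Xty)  # (X^T X)^{-1} X^T y

for start in ([0.0, 0.0], [-50.0, 12.0]):
    print(start, "->", newton_step_ols(start, xs, ys), "closed form:", closed_form)
```

Both starting points should print the same weights as the closed form, up to floating-point rounding.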
This does not mean that this is the fastest way to perform linear regression. You can benchmark it yourself; scikit-learn is faster. We should not expect something similar to happen with neural networks, though: their loss is not quadratic in the parameters, so a single Newton step will not land on the minimum.
This made me wonder, can we do something similar in spirit for neural networks? Well, maybe we should go for the other extreme. Instead of doing few steps, let’s do many!
Consider what we usually do.
If we briefly ignore the details of Adam and momentum, then gradient descent does two things: calculating a gradient (thick dot) and moving a step (line). But what if we don’t stop stepping?
A dash here represents moving forward without re-evaluating the gradient. Once we notice that a step makes the score worse, we stop and check the gradient again.
It is quite possible that the general direction you’re moving in is a good one, so do you really need to stop moving? Do we really need to calculate a new gradient? Or can we just keep on stepping? Checking whether the next step makes things worse takes a forward pass, not a backward one. If calculating the gradient takes about 80% of the compute for training, then this might be a neat idea.
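Here’s a framework-free sketch of that loop on a hypothetical two-parameter quadratic loss, just to show the bookkeeping: one gradient per round, then repeated steps along it for as long as the forward pass reports an improvement. The names keep_stepping and gradient_descent are made up for this illustration, and unlike the PyTorch optimiser in this post, this version rejects the step that makes the loss worse instead of keeping it.

```python
def keep_stepping(loss, grad, w, lr, max_steps, n_rounds):
    """One gradient per round, then keep stepping along it while the loss improves.

    Returns the final weights plus how many gradient and loss evaluations were used.
    """
    n_grad, n_loss = 0, 0
    for _ in range(n_rounds):
        g = grad(w)
        n_grad += 1
        current = loss(w)
        n_loss += 1
        for _ in range(max_steps):
            candidate = [wi - lr * gi for wi, gi in zip(w, g)]
            new = loss(candidate)  # forward pass only
            n_loss += 1
            if new >= current:  # the step made things worse: go get a fresh gradient
                break
            w, current = candidate, new
    return w, n_grad, n_loss


def gradient_descent(grad, w, lr, n_steps):
    """Plain gradient descent: one gradient evaluation per step."""
    for _ in range(n_steps):
        g = grad(w)
        w = [wi - lr * gi for wi, gi in zip(w, g)]
    return w


# hypothetical toy loss with its minimum at (3, -1)
def loss(w):
    return (w[0] - 3) ** 2 + 10 * (w[1] + 1) ** 2


def grad(w):
    return [2 * (w[0] - 3), 20 * (w[1] + 1)]


w_ks, n_grad, n_loss = keep_stepping(loss, grad, [0.0, 0.0], lr=0.01, max_steps=50, n_rounds=5)
w_gd = gradient_descent(grad, [0.0, 0.0], lr=0.01, n_steps=5)
print("keep stepping:", loss(w_ks), "gradients:", n_grad, "loss evaluations:", n_loss)
print("gradient descent:", loss(w_gd), "gradients:", 5)
```

The trade-off shows up in the counters: with the same number of gradient evaluations the keep-stepping loop gets much closer to the minimum, and it pays for that with extra forward passes.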
There are two hyperparameters to this idea:
- the learning rate, i.e. the size of each step;
- the maximum number of steps, n, that we take before we check for another gradient.
To check the merits of this idea I figured it’d be fun to write my own optimiser for PyTorch.
KeepStepping Optimizer
import math
import torch
from torch.optim.optimizer import Optimizer, required

class KeepStepping(Optimizer):
    """
    KeepStepping - PyTorch Optimizer

    Inputs:
        lr = learning rate, ie. the minimum stepsize
        max_steps = the maximum number of steps that will be performed
                    before calculating another gradient
        scale_i = to what degree do we scale our impatience
    """
    def __init__(self, params, lr=required, max_steps=20, scale_i=0):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        defaults = dict(lr=lr)
        super().__init__(params, defaults)
        self.max_steps = max_steps
        self.lr_orig = lr
        self.scale_i = scale_i

    def mini_step(self, i):
        # move every parameter along its (stale) gradient; with scale_i > 0 the
        # step grows with sqrt(i), so lr remains the minimum stepsize
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                scale = group['lr'] * self.scale_i * math.sqrt(i)
                p.data.add_(d_p, alpha=-group['lr'] - scale)

    @torch.no_grad()
    def step(self, closure):
        """Performs a single optimization step."""
        # the closure only needs a forward pass, so no autograd graph is built
        old_loss = closure()
        i = 0
        self.mini_step(i)
        new_loss = closure()
        # keep stepping along the same gradient while the loss improves
        while new_loss < old_loss and i < self.max_steps:
            self.mini_step(i)
            old_loss = new_loss
            new_loss = closure()
            i += 1
        return new_loss
Before testing this, I considered combining this idea with the Newton idea from before. I am computing fewer gradients, sure, but I am still taking lots of steps. Could I instead calculate how big the stepsize should be? Perhaps there’s something adaptive we can do here.
Given a direction that we’re supposed to move in, we’re back in a one-dimensional domain again, and we merely need to find the right stepsize.
So I made an implementation that takes this direction, numerically estimates \(f'(x_{\text{direction}})\) and \(f''(x_{\text{direction}})\) and tries to adaptively estimate an appropriate stepsize.
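To illustrate the one-dimensional view, here’s a hedged sketch of one way to estimate a stepsize from forward passes only: evaluate the loss at three points along the direction, estimate the first and second derivative there with central differences, and take a Newton step. This is a swapped-in, textbook finite-difference scheme, not the secant-based update used in KeepVaulting below; newton_stepsize_along is a hypothetical name.

```python
def newton_stepsize_along(loss_along, h):
    """Suggest how far to move along a fixed direction, using three forward passes.

    loss_along(s) is the loss after moving a distance s along the direction.
    Central differences around s = h estimate the first and second derivative,
    and a Newton step from s = h gives the suggested distance.
    """
    f0, f1, f2 = loss_along(0.0), loss_along(h), loss_along(2 * h)
    first = (f2 - f0) / (2 * h)           # estimate of d loss / ds at s = h
    second = (f2 - 2 * f1 + f0) / h ** 2  # estimate of d^2 loss / ds^2 at s = h
    if second <= 0:
        return None  # not convex along this direction, so no Newton step
    return h - first / second


# hypothetical loss along a direction: a parabola with its minimum at s = 3
print(newton_stepsize_along(lambda s: (s - 3) ** 2 + 1, h=0.1))  # close to 3.0
```

For a parabola the estimate is exact up to rounding; for a real network the loss along a direction is only approximately quadratic, and the differences can get noisy when h is small, which is one plausible source of the numerical sensitivity mentioned later.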
KeepVaulting Optimizer
import torch
from torch.optim.optimizer import Optimizer, required

class KeepVaulting(Optimizer):
    """
    KeepVaulting - PyTorch Optimizer

    Inputs:
        lr = learning rate, ie. the minimum stepsize
        max_steps = the maximum number of steps that will be
                    performed before calculating another gradient
    """
    def __init__(self, params, lr=required, max_steps=20):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        defaults = dict(lr=lr)
        super().__init__(params, defaults)
        self.max_steps = max_steps
        self.lr_orig = lr

    def mini_step(self, jumpsize=1):
        # move every parameter along its (stale) gradient, scaled by jumpsize
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                p.data.add_(d_p, alpha=-(group['lr'] * float(jumpsize)))

    @torch.no_grad()
    def step(self, closure):
        """Performs a single optimization step."""
        old_loss = closure()
        i = 0
        self.mini_step()
        new_loss = closure()
        losses = [old_loss.item(), new_loss.item()]
        while new_loss < old_loss and i < self.max_steps:
            # we're using the secant method here to
            # approximate the second order derivative
            # http://acme.byu.edu/wp-content/uploads/2019/08/1dOptimization19.pdf
            first_order_grad1 = (losses[-1] - losses[-2])/self.lr_orig
            second_order_grad = (losses[-1] - first_order_grad1)/self.lr_orig
            stepsize = -second_order_grad/first_order_grad1 * self.lr_orig
            self.mini_step(stepsize)
            old_loss = new_loss
            new_loss = closure()
            losses.append(new_loss.item())
            i += 1
        return new_loss
A meaningful benchmark was hard to come up with, so I just generated an artificial regression task with some deep layers. I don’t want to suggest that the following counts as “general performance”, but the results are interesting to think about. I’ll list a few of them below.
generate_new_dataset
from time import time

def generate_new_dataset(n_row, dim_in, dim_hidden, n_layers):
    torch.manual_seed(0)
    x = torch.randn(n_row, dim_in)
    y = x.sum(axis=1).reshape(-1, 1)

    model = torch.nn.Sequential(
        torch.nn.Linear(dim_in, dim_hidden),
        *[torch.nn.Linear(dim_hidden, dim_hidden) for _ in range(n_layers)],
        torch.nn.Linear(dim_hidden, 1),
    )

    loss_fn = torch.nn.MSELoss(reduction='mean')

    def loss_closure():
        # forward pass only; the optimisers call this to check the loss
        y_pred = model(x)
        return loss_fn(y_pred, y)

    return model, loss_fn, loss_closure, x, y

results = {}
learning_rate = 1e-3
optimisers = {
    'KS_50_0': lambda p: KeepStepping(p, lr=learning_rate, max_steps=50, scale_i=0),
    'KS_50_2': lambda p: KeepStepping(p, lr=learning_rate, max_steps=50, scale_i=2),
    'KS_10_0': lambda p: KeepStepping(p, lr=learning_rate, max_steps=10, scale_i=0),
    'KS_10_2': lambda p: KeepStepping(p, lr=learning_rate, max_steps=10, scale_i=2),
    'KV_10': lambda p: KeepVaulting(p, lr=learning_rate, max_steps=10),
    'KV_50': lambda p: KeepVaulting(p, lr=learning_rate, max_steps=50),
    'SGD': lambda p: torch.optim.SGD(p, lr=learning_rate),
    'ADAM': lambda p: torch.optim.Adam(p, lr=learning_rate),
}

for name, alg in optimisers.items():
    # pass n_row, dim_in, dim_hidden and n_layers here; the values used for the runs are not listed
    model, loss_fn, loss_closure, x, y = generate_new_dataset()
    # 100 outer iterations for KeepStepping/KeepVaulting, 1000 for SGD and Adam
    n_steps = 1000 if 'K' not in name else 100
    results[name] = np.zeros((n_steps, 2))

    optimizer = alg(model.parameters())

    tic = time()
    for t in tqdm.tqdm(range(n_steps)):
        y_pred = model(x)
        loss = loss_fn(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step(loss_closure)
        results[name][t, :] = [loss.item(), time() - tic]

plt.figure(figsize=(16, 4))
for name, hist in results.items():
    score, times = hist[:, 0], hist[:, 1]
    plt.plot(times, score, label=name)
plt.xlabel("time (s)")
plt.ylabel("mean squared error")
plt.legend()
It seems the impatient approach with a max step of 50 is beating Adam when it comes to convergence speed. The idea of vaulting is not performing as well as I had hoped but this may be due to numerical sensitivity. The idea seems to have merit to it but this work should not be seen as a proper benchmark.
Also … that last run is just a linear regression. The fastest way to optimise that is to use the Hessian trick that we started with.
So what does all of this mean?
Well … it suggests that there may be a valid trade-off between doing more work per step and needing fewer gradient evaluations. Either you spend more time preparing a step such that you reduce the number of steps needed (like the Hessian approach for the linear regression), or you just take a whole lot more steps without checking the gradient all the time.
I’d be interested in hearing stories from folks who benchmark this, so feel free to try it out and let me know if it does (or does not) work on your dataset.
For attribution, please cite this work as
Warmerdam (2020, April 10). koaning.io: More Descent, Less Gradient. Retrieved from https://koaning.io/posts/more-descent-less-gradient/
BibTeX citation
@misc{warmerdam2020more, author = {Warmerdam, Vincent}, title = {koaning.io: More Descent, Less Gradient}, url = {https://koaning.io/posts/more-descent-less-gradient/}, year = {2020} }