Gradient Descent
Getting there by little steps
What if we wanted to convert inches to centimeters, but didn’t know that the formula is inches * 2.54? What we did have was the following table of belt sizes from the Gap!
Waist Size | Belt Length (in) | Belt Length (cm) |
---|---|---|
28 | 30.5 | 77 |
30 | 32.5 | 83 |
32 | 34.5 | 88 |
34 | 36.5 | 93 |
36 | 38.5 | 98 |
38 | 40.5 | 103 |
40 | 42.5 | 108 |
42 | 44.5 | 113 |
44 | 46.5 | 118 |
46 | 48.5 | 123 |
What we could do is guess the multiplier, and see how wrong it is.
```python
import numpy as np

belt_in = np.array([30.5, 32.5, 34.5, 36.5, 38.5, 40.5, 42.5, 44.5, 46.5, 48.5])
belt_cm = np.array([77, 83, 88, 93, 98, 103, 108, 113, 118, 123])

multiplier_guess = 1.5
cm_guess = belt_in * multiplier_guess

# If our guess was right, this should all be 0
cm_guess - belt_cm
```
```
array([-31.25, -34.25, -36.25, -38.25, -40.25, -42.25, -44.25, -46.25,
       -48.25, -50.25])
```
Our guess wasn’t a great guess. With this multiplier, our guesses are all too small. Let’s describe how bad our guess was with one number, and call it the “loss.” The usual loss function for data like this is the Mean Squared Error.
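Written out, the mean squared error for n data points is the average of the squared differences between each actual value y_i (here, a belt length in cm) and our guess ŷ_i for it:

$$
\text{MSE} = \frac{1}{n}\sum_{i=1}^{n}\left(y_i - \hat{y}_i\right)^2
$$

Squaring keeps positive and negative misses from cancelling each other out, and it punishes big misses more than small ones.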
```python
def mse(actual, guess):
    """
    Given the actual target outcomes and the outcomes we guessed,
    calculate the mean squared error.
    """
    error = actual - guess
    squared_error = np.power(error, 2)
    mean_squared_error = np.mean(squared_error)
    return mean_squared_error

mse(belt_cm, cm_guess)
```
```
1728.2125
```
If we made our multiplier guess a little closer to what it ought to be, though, our mean squared error, or loss, should get smaller.
```python
multiplier_guess += 0.2
cm_guess = belt_in * multiplier_guess
mse(belt_cm, cm_guess)
```
```
1128.2125
```
One thing we could try is making a long list of possible multipliers and trying them all to see which one has the smallest loss. This is known as a “grid search.” I’ll have to rewrite the loss function so that it calculates the loss for a specific multiplier.
```python
# This gives us 50 evenly spaced numbers between 0 and 5
possible_mults = np.linspace(start=0., stop=5., num=50)

def mse_loss(multiplier, inches, cm):
    """
    Given a multiplier and a set of training data
    (inches and their equivalent centimeters), return the
    mean squared error obtained by using the given multiplier.
    """
    cm_guess = inches * multiplier
    loss = mse(cm, cm_guess)
    return loss

# The loss for every candidate multiplier
losses = np.array([mse_loss(m, belt_in, belt_cm) for m in possible_mults])
```
It’s probably best to visualize the relationship between the multiplier and the loss in a graph.
If we get the index of the smallest loss and get the associated multiplier, we can see that we’re not too far off!
```python
possible_mults[losses.argmin()]
```
```
2.5510204081632653
```
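One way to feel the cost of grid search is to make the grid finer. The sketch below re-implements the loss in one line (equivalent to mse_loss() above) and tries grids of 50, 500 and 5,000 multipliers between 0 and 5. Each finer grid lands closer to the best multiplier, but needs ten times as many loss evaluations as the last.

```python
import numpy as np

belt_in = np.array([30.5, 32.5, 34.5, 36.5, 38.5, 40.5, 42.5, 44.5, 46.5, 48.5])
belt_cm = np.array([77, 83, 88, 93, 98, 103, 108, 113, 118, 123])

def mse_loss(multiplier, inches, cm):
    """Mean squared error of converting `inches` to cm with `multiplier`."""
    return np.mean((inches * multiplier - cm) ** 2)

best = {}
for num in [50, 500, 5000]:
    # One loss evaluation per grid point
    grid = np.linspace(start=0., stop=5., num=num)
    losses = np.array([mse_loss(m, belt_in, belt_cm) for m in grid])
    best[num] = grid[losses.argmin()]
    print(f"{num:>5} candidates -> best multiplier {best[num]:.4f}")
```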
Why not always just do grid search?
One thing that will stay the same no matter how complicated the models get is that the measure of how well they’ve done, the loss, gets boiled down to one number. But in real modelling situations, like neural networks, the number of parameters gets huge. Here we have only one parameter, but with even just 5 parameters and a grid of 50 evenly spaced values for each, we’d already need 50^5 = 312,500,000 combinations. With 50 parameters and only 5 candidate values each, the number of combinations becomes truly intractable:
```python
f"{(5 ** 50):,}"
```
```
'88,817,841,970,012,523,233,890,533,447,265,625'
```
Without seeing the whole map, we can tell which way is the right direction.
Let’s look at the plot of our parameter vs the loss again:
There are a few really important features of this loss function:
- As the estimate gets further away from the ideal value in either direction, the loss increases.
- The increase is “monotonic”, meaning it isn’t bumpy, sometimes going up and sometimes going down.
- The further away the guess gets from the optimal value, the steeper the “walls” of the curve get.
Let’s say we only had these two points here, and we couldn’t “see” the whole curve, but we knew all three of these features were true. With that in hand, plus knowledge of how the loss function is calculated, we can get the slope of the function at each point (indicated by the arrows).
If we were able to update our parameter by an amount proportional to the slope of the loss, in the opposite direction of the slope, then we would gradually get closer and closer to the optimal value. The updates would be large at first, while the parameter value is far from the optimal value, and would then get smaller and smaller as we home in on it, because the slopes get shallower the closer we get.
The slope of the loss function at any given point is the gradient, and this process of gradually descending downwards is called gradient descent.
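In symbols, one step of gradient descent on our single multiplier m looks like this, where L is the loss and η is a small positive step size (the “learning rate”):

$$
m_{\text{new}} = m - \eta \, \frac{dL}{dm}
$$

When the slope is negative (the guess is too small), subtracting it pushes m up; when the slope is positive, m goes down.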
Gradient Descent
“But Joe!” you exclaim, “How do you calculate the slope of the loss for a single point without seeing the whole distribution?”
The answer to that question used to be “with calculus.” But nowadays, people do it with “autograd” or “autodiff” (automatic differentiation), which basically means “we let the computer figure it out.” There isn’t autograd functionality in numpy, but there is in a closely related library called Jax, which is being developed by Google. Jax has a module called jax.numpy that is designed to work just like numpy.
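For this particular loss, the calculus route is short. With x_i the belt lengths in inches, y_i the matching lengths in centimeters, and n = 10 belts, differentiating the mean squared error with respect to the multiplier m gives:

$$
L(m) = \frac{1}{n}\sum_{i=1}^{n}\left(m x_i - y_i\right)^2
\quad\Longrightarrow\quad
\frac{dL}{dm} = \frac{2}{n}\sum_{i=1}^{n} x_i\left(m x_i - y_i\right)
$$

Autograd tools compute this same quantity for us, which matters once a model has far too many parameters to differentiate by hand.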
```python
import jax.numpy as jnp
from jax import grad
```
I’m going to rewrite the inches-to-centimeters functions, this time using jax functions throughout so that jax can differentiate them.
```python
def inch_to_cm_jax(multiplier, inches):
    """
    A function that converts inches to cm.
    """
    cm = jnp.dot(inches, multiplier)
    return cm

def cm_loss_jax(multiplier, inches, cm):
    """
    Estimate the mismatch between the estimated and the actual
    centimeters as the mean squared error.
    """
    est = inch_to_cm_jax(multiplier, inches)
    diff = est - cm
    sq_err = jnp.power(diff, 2)
    mean_sq_err = jnp.mean(sq_err)
    return mean_sq_err
```
Then we pass the new loss function to a jax function called grad() to create a new gradient function.
```python
# argnums=0: take the derivative with respect to the first argument, the multiplier
cm_loss_grad_jax = grad(cm_loss_jax, argnums=0)
```
Where cm_loss_jax() will give us the mean squared error for a specific multiplier, cm_loss_grad_jax() will give us the slope for that multiplier, automatically.
```python
print(multiplier_guess)
```
```
1.7
```
```python
# This is the mean squared error
print(cm_loss_jax(multiplier_guess, belt_in, belt_cm))
```
```
1128.2124
```
```python
# This is the slope
print(cm_loss_grad_jax(multiplier_guess, belt_in, belt_cm))
```
```
-2681.3499
```
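Because this model has just one parameter, we can double-check jax’s slope against the hand-derived formula dL/dm = (2/n) Σ x_i (m x_i − y_i) in plain numpy. analytic_slope() is a hypothetical helper written only for this check. The tiny difference from jax’s -2681.3499 comes from jax using 32-bit floats by default, while numpy uses 64-bit floats here.

```python
import numpy as np

belt_in = np.array([30.5, 32.5, 34.5, 36.5, 38.5, 40.5, 42.5, 44.5, 46.5, 48.5])
belt_cm = np.array([77, 83, 88, 93, 98, 103, 108, 113, 118, 123])

def analytic_slope(multiplier, inches, cm):
    """Hypothetical helper: dL/dm for L(m) = mean((m * inches - cm)**2)."""
    return 2 * np.mean(inches * (multiplier * inches - cm))

slope_at_guess = analytic_slope(1.7, belt_in, belt_cm)
print(slope_at_guess)  # about -2681.35
```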
Learning Rates and “Epochs”
Now we can write a for-loop to iteratively update our multiplier guess, changing it by a small amount proportional to the gradient. There are two “hyperparameters” we need to choose here.
- The “learning rate”. We can’t just subtract the gradient itself from the multiplier: the gradient right now is in the thousands, and we’re trying to nudge 1.7 to 2.54. So we pick a “learning rate”, a very small number to multiply the gradient by before we subtract it from the parameter. Let’s start at 1/100,000.
- The number of “epochs.” We need to decide how many times we’ll go through the loop before we call it and check on how the learning has gone. Let’s go for 1000.
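To see why the learning rate has to be small, here’s a sketch in plain numpy that uses the hand-derived slope instead of jax (slope() and run() are hypothetical helpers). For this loss, each update multiplies the distance to the best multiplier by (1 − η · (2/n) · Σ x_i²), so with these belt sizes any learning rate above roughly 6.3 × 10⁻⁴ makes the guess overshoot further on every step instead of settling down.

```python
import numpy as np

belt_in = np.array([30.5, 32.5, 34.5, 36.5, 38.5, 40.5, 42.5, 44.5, 46.5, 48.5])
belt_cm = np.array([77, 83, 88, 93, 98, 103, 108, 113, 118, 123])

def slope(m):
    """Hand-derived gradient of the MSE with respect to the multiplier m."""
    return 2 * np.mean(belt_in * (m * belt_in - belt_cm))

def run(learning_rate, epochs=200, start=0.0):
    """Plain gradient descent on the multiplier; returns the final guess."""
    m = start
    for _ in range(epochs):
        m -= learning_rate * slope(m)
    return m

small_step = run(1e-5)  # creeps toward the best multiplier, about 2.54
big_step = run(1e-3)    # overshoots further on every step and blows up
print(small_step, big_step)
```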
```python
learning_rate = 1/100_000
epochs = 1000

# I want to be able to plot everything after, so I'm going to create collectors.
epoch_list = []
param_list = []
loss_list = []
gradient_list = []

multiplier_guess = 0.
for i in range(epochs):
    # append the current epoch
    epoch_list.append(i)
    # append the current guess
    param_list.append(multiplier_guess)

    loss = cm_loss_jax(multiplier_guess, belt_in, belt_cm)
    loss_list.append(loss)

    gradient = cm_loss_grad_jax(multiplier_guess, belt_in, belt_cm)
    gradient_list.append(gradient)

    # step against the gradient, scaled down by the learning rate
    multiplier_guess += -(gradient * learning_rate)

print(f"The final guess was {multiplier_guess:.3f}")
```
```
The final guess was 2.541
```
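Since this loss is a simple parabola in the multiplier, we can also check the answer without any iteration: setting dL/dm to zero gives m* = Σ x_i y_i / Σ x_i². The sketch below computes that directly, and again with numpy’s least-squares solver np.linalg.lstsq as a cross-check. Both come out around 2.541, matching the gradient descent result; the small gap from the true 2.54 is presumably because the table’s centimeter values are rounded.

```python
import numpy as np

belt_in = np.array([30.5, 32.5, 34.5, 36.5, 38.5, 40.5, 42.5, 44.5, 46.5, 48.5])
belt_cm = np.array([77, 83, 88, 93, 98, 103, 108, 113, 118, 123])

# Setting dL/dm = 0 and solving for m
m_star = np.sum(belt_in * belt_cm) / np.sum(belt_in ** 2)

# Cross-check: least squares with a single column and no intercept
m_lstsq = np.linalg.lstsq(belt_in[:, None], belt_cm, rcond=None)[0][0]

print(f"{m_star:.3f} {m_lstsq:.3f}")
```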
This will all work with more parameters
Let’s try estimating the body mass of penguins from their bill length again.
```python
import pandas as pd
from palmerpenguins import load_penguins

penguins = load_penguins()
```
Here, I grab columns for the bill length and body mass as numpy arrays.
```python
bill_length = np.array(penguins.dropna()["bill_length_mm"])
body_mass = np.array(penguins.dropna()["body_mass_g"])
```
And then I “normalize” the data, by subtracting the mean and dividing by the standard deviation. Understanding this part isn’t crucial. It’ll just make the parameter estimation go more smoothly.
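Here’s what that standardization does on a few made-up numbers: the result has a mean of 0 and a standard deviation of 1. Note that numpy’s .std() divides by n by default (ddof=0), i.e. it is the population standard deviation.

```python
import numpy as np

# Made-up values, just to illustrate standardizing
values = np.array([35.0, 40.0, 45.0, 50.0])

# Subtract the mean, divide by the (population) standard deviation
values_z = (values - values.mean()) / values.std()

print(values_z)
print(values_z.mean(), values_z.std())
```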
```python
bill_length_z = (bill_length - bill_length.mean()) / bill_length.std()
body_mass_z = (body_mass - body_mass.mean()) / body_mass.std()
```
I’ll set up a prediction matrix with one row for each penguin, one column full of just 1s, and one column with the bill length data.
```python
bill_length_X = np.stack([np.ones(bill_length_z.size), bill_length_z], axis=1)
bill_length_X[0:10, ]
```
```
array([[ 1.        , -0.89604189],
       [ 1.        , -0.82278787],
       [ 1.        , -0.67627982],
       [ 1.        , -1.33556603],
       [ 1.        , -0.85941488],
       [ 1.        , -0.9326689 ],
       [ 1.        , -0.87772838],
       [ 1.        , -0.52977177],
       [ 1.        , -0.98760942],
       [ 1.        , -1.72014965]])
```
The return of Dot Products
I’ve added this column of 1s so that we can have just one vector of parameters, the first value being the intercept (it multiplies the column of 1s) and the second being the slope (it multiplies the bill length).
```python
fake_param = np.array([2, 4])
```
Now we can do element-wise multiplication for the data for any given penguin…
```python
bill_length_X[1, ] * fake_param
```
```
array([ 2.        , -3.29115147])
```
…and then sum it up to get the estimated body mass for the penguin
```python
(bill_length_X[1, ] * fake_param).sum()
```
```
-1.2911514669828215
```
A.K.A. a Dot Product
```python
np.dot(bill_length_X[1, ], fake_param)
```
```
-1.2911514669828215
```
Dot product with the whole matrix
In fact, we can get the estimated body mass for all penguins with just a single dot product.
```python
body_mass_est = np.dot(bill_length_X, fake_param)
body_mass_est
```
```
array([-1.58416756, -1.29115147, -0.70511928, -3.34226411, -1.43765951,
-1.73067561, -1.51091354, -0.1190871 , -1.95043767, -4.8805986 ,
-3.41551813, -1.87718365, 0.90646922, -5.02710664, 3.47036003,
-2.53646986, -2.60972388, -3.9282963 , -2.24345377, -1.80392963,
-4.36782043, -0.48535721, -0.55861124, -2.46321584, -0.55861124,
-1.29115147, -2.975994 , -1.29115147, -0.26559514, -3.56202618,
-1.51091354, -1.80392963, 0.68670715, -2.6829779 , -1.0713894 ,
-3.48877216, -0.33884917, -3.85504227, 2.07853359, -3.12250204,
-1.21789744, -0.1190871 , -3.85504227, 0.75996118, -1.21789744,
-0.85162733, -4.5875825 , 0.54019911, -4.95385262, 0.10067497,
-1.65742158, -0.48535721, -3.48877216, -2.6829779 , -4.07480434,
0.02742095, -2.6829779 , -0.1190871 , -3.56202618, 0.24718302,
-4.22131239, -0.1190871 , -3.9282963 , 0.39369106, -5.68639285,
-1.14464342, -1.21789744, 3.32385198, -4.22131239, 1.12623129,
-0.26559514, -2.975994 , -3.70853423, 0.61345313, -4.8805986 ,
1.19948532, -3.34226411, -4.51432848, -2.90273997, 0.02742095,
-3.6352802 , -3.19575607, -2.17019974, -1.73067561, -4.07480434,
-0.1190871 , -5.32012273, -1.21789744, -3.70853423, -0.33884917,
-2.31670779, -0.70511928, -5.97940894, 1.41924739, -4.5875825 ,
-0.19234112, -2.60972388, -2.53646986, -2.46321584, -1.14464342,
-1.95043767, -2.24345377, -2.31670779, 1.41924739, -2.31670779,
3.17734394, -1.14464342, 0.68670715, -1.21789744, 1.05297727,
-1.95043767, -2.90273997, -4.07480434, -0.1190871 , -3.70853423,
-2.60972388, -0.77837331, 0.10067497, -4.44107446, -0.48535721,
-1.80392963, 0.17392899, -1.65742158, 2.07853359, -2.0236917 ,
1.34599336, -3.26901009, -2.75623193, -2.31670779, -0.1190871 ,
-4.14805837, -0.77837331, -3.12250204, -1.14464342, -0.77837331,
-0.48535721, -6.71194917, -0.41210319, -2.90273997, -1.65742158,
-1.51091354, -3.41551813, -3.85504227, -2.53646986, -3.85504227,
0.17392899, 3.54361405, 6.40052095, 5.44821865, 6.40052095,
4.6424244 , 3.83663014, 3.03083589, 3.98313819, 1.49250141,
4.05639221, -0.26559514, 5.66798072, 3.10408991, 5.22845658,
3.32385198, 5.88774279, 0.54019911, 5.81448877, 3.61686808,
5.44821865, 6.547029 , 2.81107382, 3.83663014, 3.6901221 ,
1.19948532, 3.54361405, 4.78893244, 5.08194854, 6.40052095,
4.42266233, 1.12623129, 2.81107382, 13.43290716, 5.74123474,
5.22845658, 0.97972325, 2.29829566, 2.00527957, 5.44821865,
1.05297727, 6.10750486, 2.95758187, 6.10750486, 6.76679107,
1.71226348, 3.10408991, 6.76679107, 2.66456578, 2.88432785,
3.90988417, 5.30171061, 2.81107382, 6.47377497, 3.83663014,
2.7378198 , 1.85877152, 3.10408991, 1.41924739, 6.69353704,
2.95758187, 3.61686808, 3.25059796, 9.55044394, 3.32385198,
6.2540129 , 6.03425084, 1.63900945, 6.91329911, 4.71567842,
3.76337612, 5.08194854, 3.83663014, 3.76337612, 5.37496463,
4.56917038, 7.2063152 , 2.88432785, 2.88432785, 5.74123474,
8.23187153, 4.49591635, 6.40052095, 2.66456578, 6.98655314,
1.56575543, 7.35282325, 4.56917038, 7.93885543, 4.56917038,
8.01210946, 3.10408991, 6.03425084, 2.37154968, 6.98655314,
5.96099681, 4.12964624, 5.22845658, 7.2063152 , 5.30171061,
10.72250831, 4.34940831, 5.74123474, 4.05639221, 0.32043704,
8.89115773, 1.49250141, 5.00869451, 6.76679107, 6.2540129 ,
1.63900945, 7.4993313 , 3.61686808, 10.13647613, 5.52147267,
4.34940831, 4.05639221, 6.69353704, 2.88432785, 6.32726693,
3.83663014, 6.40052095, 7.35282325, 3.03083589, 8.37837957,
2.88432785, 3.54361405, 7.35282325, 3.47036003, 7.35282325,
3.90988417, 7.64583934, 4.20290026, 7.86560141, 3.39710601,
6.76679107, 6.62028302, 12.26084279, 3.76337612, 5.81448877,
0.8332152 , 5.30171061, 1.41924739, 6.84004509, 3.98313819,
7.86560141, 6.76679107, 6.03425084, 3.76337612, 8.4516336 ,
-0.26559514, 9.47718992, 0.90646922, 7.13306118, 6.18075888,
4.56917038, 4.6424244 , 7.86560141, 4.12964624, 8.96441176,
5.66798072, 3.61686808, 7.05980716, 3.10408991, 7.05980716,
6.98655314, 6.47377497, 5.66798072, 7.4993313 , 6.2540129 ,
5.00869451, 7.42607727, 3.25059796, 6.91329911, 0.90646922,
8.01210946, 2.88432785, 5.88774279, 6.547029 , 3.17734394,
7.79234739, 4.05639221, 3.25059796, 10.64925429, 1.63900945,
6.10750486, 6.98655314, 6.547029 ])
```
This is our first foray into matrix multiplication.
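A small made-up example shows that multiplying a matrix by a parameter vector is the same as taking the dot product of each row with that vector, one row at a time. With params = [2, 4] as [intercept, slope], the three rows give 2 + 4·(−0.5) = 0, 2 + 4·0 = 2 and 2 + 4·1.5 = 8.

```python
import numpy as np

# A small made-up design matrix: a column of 1s and one predictor column
X = np.array([[1.0, -0.5],
              [1.0,  0.0],
              [1.0,  1.5]])
params = np.array([2.0, 4.0])  # [intercept, slope]

# One dot product per row...
row_by_row = np.array([np.dot(row, params) for row in X])

# ...or all rows at once (same as np.dot(X, params) for a 2-D X and 1-D params)
all_at_once = X @ params

print(row_by_row)
print(all_at_once)
```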
Doing linear regression with gradient descent
If we start off with some (bad) guesses for the slope and intercept, we can get the estimated body mass for every penguin:
```python
param_guess = np.array([-2., 0.])
mass_guess = np.dot(bill_length_X, param_guess)
```
And then we can again get the mean squared error, or the loss, which is a single value describing how bad we’re doing at predicting body mass with this intercept and slope.
```python
np.mean(np.power(body_mass_z - mass_guess, 2))
```
```
5.0
```
Loss function, now in two dimensions
Just like we plotted the loss as it related to the single multiplier above, we can plot the loss as it relates to these two parameters.
And again, we’ve got a bowl-shaped surface, and we can update both the intercept and the slope incrementally using the negative of the gradient, gradually arriving close to the best values.
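For linear regression the gradient has a compact matrix form. With X the design matrix (bill_length_X), **y** the standardized body masses (body_mass_z), **β** the parameter vector [intercept, slope] and η the learning rate, the mean squared error, its gradient and the update step are:

$$
L(\boldsymbol{\beta}) = \frac{1}{n}\left\lVert X\boldsymbol{\beta} - \mathbf{y}\right\rVert^2,
\qquad
\nabla L(\boldsymbol{\beta}) = \frac{2}{n} X^{\top}\left(X\boldsymbol{\beta} - \mathbf{y}\right),
\qquad
\boldsymbol{\beta}_{\text{new}} = \boldsymbol{\beta} - \eta\,\nabla L(\boldsymbol{\beta})
$$

With a single column x and a single parameter m, the gradient reduces to the one-parameter derivative (2/n) Σ x_i (m x_i − y_i).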
These functions are basically the same as the single parameter case from above:
```python
def fit_mass(params, X):
    """
    Given some values and parameters,
    guess the outcome.
    """
    est = jnp.dot(X, params)
    return est

def fit_loss(params, X, actual):
    """
    Return the loss of the params.
    """
    est = fit_mass(params, X)
    err = est - actual
    sq_err = jnp.power(err, 2)
    mse = jnp.mean(sq_err)
    return mse

fit_grad = grad(fit_loss, argnums=0)
```
And this is the setup to run the for-loop that gradually updates our parameters.
```python
epoch_list = []
param_list = []
loss_list = []
gradient_list = []

param_guess = np.array([-2., 0.])
learning_rate = 0.01

for i in range(1000):
    # append the current epoch and the current guess
    epoch_list.append(i)
    param_list.append(param_guess)

    loss = fit_loss(param_guess, bill_length_X, body_mass_z)
    loss_list.append(loss)

    gradient = fit_grad(param_guess, bill_length_X, body_mass_z)
    gradient_list.append(gradient)

    # Step against the gradient. Rebinding (rather than updating in place)
    # keeps each entry of param_list as its own array for plotting later.
    param_guess = param_guess - gradient * learning_rate

print(f"Final param guess was {param_guess}")
```
```
Final param guess was [4.4573799e-08 5.8944976e-01]
```
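As a sanity check, the answer from gradient descent should match the closed-form least-squares fit. The sketch below runs the same kind of loop with the hand-derived gradient (2/n) Xᵀ(Xβ − y) on made-up standardized data (the real penguin data isn’t loaded here), then compares it against np.linalg.lstsq. Because both variables are standardized, the least-squares intercept is 0 and the slope equals the Pearson correlation between them, which fits with the intercept of essentially 0 in the penguin result above.

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up data standing in for bill length (x) and body mass (y)
x = rng.normal(size=200)
y = 0.6 * x + rng.normal(scale=0.8, size=200)

# Standardize both, as in the penguin example
x_z = (x - x.mean()) / x.std()
y_z = (y - y.mean()) / y.std()

# Design matrix: a column of 1s (intercept) and the predictor (slope)
X = np.stack([np.ones(x_z.size), x_z], axis=1)

def mse_gradient(params, X, y):
    """Gradient of mean((X @ params - y)**2) with respect to params."""
    return 2 / y.size * X.T @ (X @ params - y)

params = np.array([-2., 0.])
learning_rate = 0.01
for _ in range(1000):
    params = params - learning_rate * mse_gradient(params, X, y_z)

# Closed-form least squares and the correlation, for comparison
exact, *_ = np.linalg.lstsq(X, y_z, rcond=None)
r = np.corrcoef(x_z, y_z)[0, 1]
print(params, exact, r)
```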
Here’s an animation of how the line we treat as the best estimate changes over training epochs.
More advanced options still operate on the same principles
In all of the examples we’ve looked at here, we’ve done two things:
- Calculated the loss for all data points all at once.
- Updated the parameters by multiplying the negative gradient by some small “learning rate” number.
There are more complicated and nuanced ways to go about this process, but they all operate on the same principles.
“Stochastic Gradient Descent”
Sometimes it’s not possible, or it’s too computationally intensive, to calculate the loss and its gradient for all data points in one go. There are a few ways of dealing with this, one of which is to chunk the data up into randomized batches and get the loss and gradient one batch at a time. This is called “Stochastic Gradient Descent”.
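A minimal sketch of mini-batch stochastic gradient descent, assuming plain numpy, made-up data, and a hand-derived gradient: shuffle the rows once per epoch, then take one update step per batch. The batch size of 50 and learning rate of 0.05 are arbitrary choices for this toy problem. Because each batch only gives a noisy estimate of the full gradient, the final parameters jitter slightly around the exact least-squares answer rather than landing on it exactly.

```python
import numpy as np

rng = np.random.default_rng(42)

# Made-up data: y = 1.5 + 0.5 * x plus a little noise
n = 500
x = rng.normal(size=n)
y = 1.5 + 0.5 * x + rng.normal(scale=0.1, size=n)
X = np.stack([np.ones(n), x], axis=1)

def mse_gradient(params, X, y):
    """Gradient of mean((X @ params - y)**2) with respect to params."""
    return 2 / y.size * X.T @ (X @ params - y)

params = np.zeros(2)
learning_rate = 0.05
batch_size = 50

for epoch in range(200):
    # Shuffle once per epoch, then walk through the data in batches
    order = rng.permutation(n)
    for start in range(0, n, batch_size):
        batch = order[start:start + batch_size]
        params = params - learning_rate * mse_gradient(params, X[batch], y[batch])

# Full-data least squares, for comparison
exact, *_ = np.linalg.lstsq(X, y, rcond=None)
print(params, exact)
```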
“Optimizers”
There is also a whole array of gradient descent “optimizers”. Some of them gradually change the learning rate. Others introduce the concept of “momentum” into the process. One of the most popular ones I see people use, when I’m reading blogs by neural network people, is called Adam.
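Here’s a sketch of one common form of momentum, often called classical or “heavy-ball” momentum, on the belt-size example from earlier, using the hand-derived slope instead of jax. Instead of stepping by the current gradient alone, each step uses a running “velocity” that keeps 90% of the previous step’s direction. With the same learning rate and only 100 epochs, the momentum version ends up closer to the best multiplier than plain gradient descent. The names beta and velocity are illustrative, not any particular library’s API, and Adam itself is not shown here.

```python
import numpy as np

belt_in = np.array([30.5, 32.5, 34.5, 36.5, 38.5, 40.5, 42.5, 44.5, 46.5, 48.5])
belt_cm = np.array([77, 83, 88, 93, 98, 103, 108, 113, 118, 123])

def slope(m):
    """Hand-derived gradient of the MSE with respect to the multiplier m."""
    return 2 * np.mean(belt_in * (m * belt_in - belt_cm))

learning_rate = 1e-5
epochs = 100

# Plain gradient descent: step by the current gradient only
m_plain = 0.0
for _ in range(epochs):
    m_plain -= learning_rate * slope(m_plain)

# Classical ("heavy-ball") momentum: step by a running velocity that
# keeps a fraction `beta` of the previous velocity plus the current gradient
beta = 0.9
velocity = 0.0
m_momentum = 0.0
for _ in range(epochs):
    velocity = beta * velocity + slope(m_momentum)
    m_momentum -= learning_rate * velocity

print(f"plain: {m_plain:.3f}  momentum: {m_momentum:.3f}")
```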