Regularization and the Bias-Variance Tradeoff in Machine Learning: Part 1
When we fit a supervised learning algorithm, we have to strike a balance between how well we learn the structure in our training data and how well our algorithm can generalize to data it has never seen before. This can be seen as a tradeoff between error due to model bias and error due to model variance, which is also often discussed in terms of over- and underfitting our data.
A highly biased model is one which has not learned the structure of the training data. This is generally referred to as underfitting, meaning that our model is not adequately capturing the relationship between our predictors and our target. Underfitting is usually diagnosable by high training error.
Let’s say we have a (massively over-)simplified regression problem: we want to predict a person’s salary based on their age. We collect some data and it looks something like this:
This is, of course, a situation where we can pretty well eyeball that the relationship between age and income can be described by a curve: it starts fairly low, peaks somewhere around middle age, and then starts to decline as people in our sample exit the workforce.
We might start out by training a very naive “model”: just guess the average salary for everyone! The plot below shows the same training data, with the results of our naive regression line in red.
Obviously this model has captured very little information about the true relationship; it’s a classic case of underfitting. It’s worth noting that it can sometimes be a nice idea to fit this kind of naive model as a baseline against which to compare your more sophisticated techniques: if they don’t beat the underfitting model, that’s often a red flag!
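To make the baseline idea concrete, here is a minimal sketch in Python with NumPy. It assumes a made-up synthetic age/salary dataset: the data-generating curve, the noise level, and the helper `mse` are hypothetical and are not the data behind the plots above. It computes the training error of the predict-the-mean model, which works out to the variance of the training salaries, and compares it with a simple quadratic fit.

```python
import numpy as np

# Hypothetical synthetic data: ages 20-79 with a salary curve that rises,
# peaks in middle age, and declines, plus Gaussian noise. These numbers are
# made up for illustration; they are not the data plotted in this post.
rng = np.random.default_rng(0)
age = rng.uniform(20, 79, size=200)
salary = 20_000 + 60_000 * np.exp(-((age - 50) / 15) ** 2) + rng.normal(0, 8_000, size=200)

def mse(y_true, y_pred):
    """Mean squared error."""
    return float(np.mean((y_true - y_pred) ** 2))

# The naive "model": predict the training-set mean salary for everyone.
baseline_prediction = np.full_like(salary, salary.mean())
baseline_mse = mse(salary, baseline_prediction)

# A slightly less naive candidate: a quadratic in age, fit by least squares.
coeffs = np.polyfit(age, salary, deg=2)
quadratic_mse = mse(salary, np.polyval(coeffs, age))

print(f"Baseline training MSE:  {baseline_mse:,.0f}")
print(f"Quadratic training MSE: {quadratic_mse:,.0f}")
```

Because the quadratic fit includes a constant term, it can never do worse than the mean baseline on the training data; a candidate model that fails to beat the baseline by a clear margin is the red flag described above.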
Clearly a biased model is doing us very little good. However, a model with high variance error can be equally problematic. Variance in this case refers to how sensitive our model is to slight perturbations in the training data. It would be very easy to just connect all the dots in our plot above (we don’t even need a “learning” algorithm to do that), and reduce our training error to zero!
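As a rough sketch of the connect-the-dots "model", assuming hypothetical synthetic data with one observation per age (so the dots can be joined into a single line), `np.interp` does the joining by linear interpolation. The helper name `connect_the_dots` is made up for illustration. The sketch checks that training error is zero, then nudges a single training salary to show the sensitivity that "variance" refers to.

```python
import numpy as np

# Hypothetical synthetic data (not the post's data). Ages are drawn without
# replacement so every x value is unique and "connecting the dots" is well defined.
rng = np.random.default_rng(1)
age = np.sort(rng.choice(np.arange(20, 80), size=30, replace=False)).astype(float)
salary = 20_000 + 60_000 * np.exp(-((age - 50) / 15) ** 2) + rng.normal(0, 8_000, size=age.size)

def connect_the_dots(x_train, y_train):
    """Return a predictor that linearly interpolates between training points.

    np.interp requires x_train to be increasing; outside the training range it
    simply repeats the first or last training value.
    """
    return lambda x_new: np.interp(x_new, x_train, y_train)

predict = connect_the_dots(age, salary)
train_mse = float(np.mean((salary - predict(age)) ** 2))
print(f"Training MSE: {train_mse}")

# Sensitivity: nudge one training salary by +10,000 and refit.
salary_nudged = salary.copy()
salary_nudged[10] += 10_000
predict_nudged = connect_the_dots(age, salary_nudged)
shift = predict_nudged(age[10]) - predict(age[10])

# For comparison, the mean baseline only moves by 10,000 / 30.
baseline_shift = salary_nudged.mean() - salary.mean()
print(f"Interpolator moved by {shift:,.0f}; mean baseline moved by {baseline_shift:,.0f}")
```

The interpolator passes the full nudge straight through to its prediction at that age, while the mean baseline absorbs it across all 30 points.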
But what happens when we show our model some new data? Unless our training data is a perfect encapsulation of all the information we’ll ever need (in which case, why train a model at all?) our zero-bias model is going to choke as soon as it sees something it’s never seen before.
When an observation comes along at the far right or far left of this plot, does it seem likely that our line will accurately extrapolate to the right target value? What about the next 79-year-old we try to predict salary for? That one rough retiree could wreak havoc with our model.
Variance problems are usually diagnosed by assessing performance on a validation data set: a held-out piece of our data that is not presented to the model at training time, but on which we measure model performance while tuning hyperparameters. If training error is substantially lower than validation error, we’ve got an overfitting problem and need to reduce model error due to variance.
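Here is a minimal sketch of that diagnostic, assuming hypothetical synthetic data and a simple random 40/20 train/validation split (both choices are illustrative, not from the original analysis). It fits the connect-the-dots interpolator on the training rows only and compares its error on the training and validation rows.

```python
import numpy as np

# Hypothetical synthetic data (not the post's data): 60 unique ages, shuffled.
rng = np.random.default_rng(2)
age = rng.permutation(np.arange(20, 80)).astype(float)
salary = 20_000 + 60_000 * np.exp(-((age - 50) / 15) ** 2) + rng.normal(0, 8_000, size=age.size)

# The rows are already in random order, so hold out the last 20 for validation.
x_tr, y_tr = age[:40], salary[:40]
x_val, y_val = age[40:], salary[40:]

def mse(y_true, y_pred):
    """Mean squared error."""
    return float(np.mean((y_true - y_pred) ** 2))

# Connect-the-dots model fit on training rows only; np.interp needs sorted x values.
order = np.argsort(x_tr)

def interp_predict(x):
    return np.interp(x, x_tr[order], y_tr[order])

train_err = mse(y_tr, interp_predict(x_tr))
val_err = mse(y_val, interp_predict(x_val))
print(f"train MSE = {train_err:,.0f}, validation MSE = {val_err:,.0f}")
```

A training error of zero next to a clearly positive validation error is the overfitting signature described above.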
So What Do I Do??
We’ve seen some examples of the extreme ends of error due to bias and error due to variance. The former results in a model that cannot learn any patterns from our data, and the latter results in a model that cannot generalize to new data effectively. There is almost always a balance to be struck between these two error sources in order to build a model that will perform well when deployed. Regularization is a powerful tool for controlling this tradeoff.
Regularization is a general term for a number of methods that allow us to adjust for models that are displaying symptoms of over- or underfitting. Regularization takes many flavors depending on the algorithm in question, but the common thread is that regularization introduces a small amount of bias in exchange for a reduction in variance.
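Written as an equation, one common form of this trade is penalized least squares with a ridge (L2) penalty. This is used here purely as an illustration; the notation is assumed rather than taken from later parts of this series, and the features and target are assumed centered so there is no intercept term:

```latex
\hat{\beta}_\lambda
  = \arg\min_{\beta} \; \sum_{i=1}^{n} \bigl( y_i - x_i^\top \beta \bigr)^2
    + \lambda \, \lVert \beta \rVert_2^2 ,
  \qquad \lambda \ge 0,
\qquad
\hat{\beta}_\lambda = \bigl( X^\top X + \lambda I \bigr)^{-1} X^\top y .
```

Setting λ = 0 recovers ordinary least squares (when X^T X is invertible). Increasing λ shrinks the coefficients toward zero, accepting some bias in return for predictions that move less when the training data changes.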
Put another way, regularization is a way of telling our model to do its best to fit our training data under the constraint that it should learn the simplest possible model. Why is simplicity good? Looking at our second plot above, it’s clear that the equation for drawing that red line is pretty gnarly. If we let our model get too complex without reining it in, we’ll get something like that. If, however, we put some guardrails on complexity while still encouraging the lowest bias possible, we can (hopefully) get something like this:
Notice that this curve isn’t perfect: there are certain training cases for which its predictions are way off, and a different modeling technique or tuning approach could likely improve our in-sample performance further. But this model does a good job of capturing the overall signal in our data, which means when we show it new data it has a much better shot of hitting the mark.
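Here is a minimal sketch of the "guardrails" idea. It assumes ridge (L2) regularization on degree-8 polynomial features of a made-up synthetic dataset. The degree, the λ grid, the 80/40 split, and the helper names `fit_ridge` and `mse` are hypothetical choices, not the method behind the curve above. As λ grows, the coefficient norm shrinks and training error can only rise: that is the "small amount of bias" half of the trade. Whether validation error improves depends on the data, so read that column rather than assuming it.

```python
import numpy as np

# Hypothetical synthetic data (not the data behind this post's plots).
rng = np.random.default_rng(3)
age = rng.uniform(20, 79, size=120)
salary = 20_000 + 60_000 * np.exp(-((age - 50) / 15) ** 2) + rng.normal(0, 8_000, size=age.size)

# Degree-8 polynomial features of standardized age (no intercept column).
z = (age - age.mean()) / age.std()
X = np.column_stack([z ** d for d in range(1, 9)])

# Random split: 80 training rows, 40 validation rows.
idx = rng.permutation(age.size)
tr, va = idx[:80], idx[80:]

def fit_ridge(X, y, lam):
    """Closed-form ridge: solve (Xc^T Xc + lam*I) beta = Xc^T yc.

    X and y are centered first so the intercept is not penalized.
    """
    x_mean, y_mean = X.mean(axis=0), y.mean()
    Xc, yc = X - x_mean, y - y_mean
    beta = np.linalg.solve(Xc.T @ Xc + lam * np.eye(X.shape[1]), Xc.T @ yc)
    return beta, y_mean - x_mean @ beta

def mse(y_true, y_pred):
    """Mean squared error."""
    return float(np.mean((y_true - y_pred) ** 2))

results = {}
for lam in [0.0, 1.0, 10.0, 100.0]:
    beta, b0 = fit_ridge(X[tr], salary[tr], lam)
    results[lam] = (
        float(np.linalg.norm(beta)),
        mse(salary[tr], X[tr] @ beta + b0),
        mse(salary[va], X[va] @ beta + b0),
    )
    norm, tr_err, va_err = results[lam]
    print(f"lambda={lam:>6}: ||beta||={norm:,.0f}  train MSE={tr_err:,.0f}  val MSE={va_err:,.0f}")
```

In practice λ is a hyperparameter, chosen by comparing validation error across a grid like this one.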
Ultimately that’s our goal any time we train a supervised machine learning algorithm: capture enough of the structure in our training data to learn something meaningful, but not so much of the additional noise that we can’t generalize to new data. Regularization is a powerful tool to control this tradeoff and optimize model performance. Next time, we’ll get deeper into how regularization actually works in a variety of supervised learning algorithms.