Neural Networks: Overfitting and Regularization
Congratulations, you made a neural network! Now you can train it and use it to classify stuff. If you followed a popular course you likely made one that classifies handwritten digits, or maybe one that determines whether a picture is a cat or a dog. Real useful stuff. You probably even added some extras to help it train better, maybe data normalization or shuffling, maybe improvements to gradient descent. That’s not what we’re going to talk about here, because those extras may not be helping you towards the real goal of your network. But isn’t the goal to train faster or with less error? Those things all help with that, yes, but the error they reduce is training error. Reducing training error is good, but only as long as it also reduces the error when you run the network on new data.
Before we get into the details we should go over some stuff from a broad perspective. A neural network takes in data (e.g. a handwritten digit or a picture that may or may not be a cat) and produces some prediction about that data (e.g. what number the digit is or whether the picture is indeed a cat). In order to make accurate predictions you must train the network. Training is done by taking already classified data, called training data (e.g. a handwritten digit and what digit it is, or a picture labelled as either cat or not cat), and running an iterative process of forward and backward propagation with it. Forward propagation takes the training data and makes a prediction based on the network so far, just like the network should be doing when fully trained. Then back propagation updates the network’s variables so that the prediction is more accurate. The first iteration of this process might have a lot of error (lots of dogs getting called cats and cats getting called dogs). That is the training error. As we continue training and iterate more, the training error should shrink and hopefully get very low. The network “learns” the training data and can correctly identify it with high accuracy.
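As a rough illustration of that loop, here is a minimal sketch that trains a single neuron (logistic regression) on made-up 2D points instead of real pictures. The toy data, the learning rate and every function name are assumptions chosen for the example; a real network has many more variables, but the forward/backward rhythm is the same.

```python
import math
import random


def sigmoid(z):
    """Squash any number into the range (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


def make_toy_data(n, seed=0):
    """Made-up stand-in for pictures: points with x1 + x2 > 1 get label 1, others 0."""
    rng = random.Random(seed)
    points = [(rng.random(), rng.random()) for _ in range(n)]
    labels = [1 if x1 + x2 > 1.0 else 0 for x1, x2 in points]
    return points, labels


def train(points, labels, iterations=200, learning_rate=0.5):
    """Train one neuron; return its weights, bias and the training cost per iteration."""
    w1, w2, b = 0.0, 0.0, 0.0
    m = len(points)
    cost_history = []
    for _ in range(iterations):
        cost = grad_w1 = grad_w2 = grad_b = 0.0
        for (x1, x2), y in zip(points, labels):
            # Forward propagation: predict with the variables as they are now.
            p = sigmoid(w1 * x1 + w2 * x2 + b)
            cost -= y * math.log(p) + (1 - y) * math.log(1 - p)
            # Backward propagation: how the cost changes with each variable.
            grad_w1 += (p - y) * x1
            grad_w2 += (p - y) * x2
            grad_b += p - y
        cost_history.append(cost / m)
        # Nudge every variable in the direction that lowers the cost.
        w1 -= learning_rate * grad_w1 / m
        w2 -= learning_rate * grad_w2 / m
        b -= learning_rate * grad_b / m
    return (w1, w2), b, cost_history
```

Calling `train(*make_toy_data(100))` returns a cost history whose last entry is lower than its first, which is the shrinking training error described above.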
At that point your network is trained and you can use it on real data. But how well is it doing with real data? Well, if you have more data that comes pre-labelled you can use it as validation data. It will be similar to the training data set, just new to the network. Feed the network your validation data and it will produce predictions, and because the data comes labelled you can calculate the error. That’s the validation error. This is what we really want to minimize: the error the network has when making predictions about data that it has never seen before. Often, even when the training error is really low, the validation error can be much higher. That discrepancy can be caused by overfitting, a big problem in neural networks that should be addressed.
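Holding back validation data and measuring the error on it can be sketched in a few lines. The helper names `split_data` and `error_rate`, and the default 20% split, are assumptions made for this example rather than anything standard.

```python
import random


def split_data(examples, validation_fraction=0.2, seed=0):
    """Shuffle labelled (input, label) pairs and hold some back as validation data."""
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)
    n_validation = int(len(shuffled) * validation_fraction)
    # The network only ever trains on the first part; the rest stays unseen.
    return shuffled[n_validation:], shuffled[:n_validation]


def error_rate(predict, examples):
    """Fraction of (input, label) pairs that predict() gets wrong."""
    wrong = sum(1 for x, label in examples if predict(x) != label)
    return wrong / len(examples)
```

With a trained model wrapped in a `predict` function, `error_rate(predict, training)` gives the training error and `error_rate(predict, validation)` gives the validation error, so the two can be compared directly.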
Overfitting in statistics is “the production of an analysis that corresponds too closely or exactly to a particular set of data, and may therefore fail to fit additional data”. Another way to look at it is that the model gives too much importance to the noise and outliers specific to the training data, and that skews the results when it is used in practice. It helps to have a visual representation of overfitting. Even a relatively “simple” network, like one that predicts handwritten digits, works on data with hundreds of variables, but it’s easier to visualize with two.
The red and green dots are data points in the training set, for instance pictures of animals, where red are dogs and green are cats. When you train, the neural network places the blue line to make a prediction. If a data point is below or left of the blue line it’s identified as a cat; if it’s above or right of the line it’s predicted to be a dog. There are two data points that are on the “wrong” side of the line. They may be outliers, for instance a cat that looks like a dog. Or they might just be noise, like a blurry picture. If that picture were the training data we would have pretty low error, with only two points misidentified. Regardless of the reason, if we train more we can reduce that training error even further.
Here is a classic example of overfitting. We captured those last two data points, but at the cost of much more complication in the blue line (compare it to the one above). Having a more complicated prediction line isn’t necessarily a bad thing, and when dealing with complicated data like pictures you may expect it, but it often means that too much value is being placed on irregular data such as outliers. This model, though it has no training error, would have more validation error when you try to actually use the predictions. It fits the training data well but won’t generalize as well.
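To see the same effect in numbers, here is a small sketch on made-up data. It does not use a neural network: as a stand-in for an overfit model it uses a “memorizer” that copies the label of the nearest training point, which is about as wiggly as a boundary can get. The data generator, the 15% label noise and all the names are assumptions for the example.

```python
import random


def make_noisy_data(n, seed, noise=0.15):
    """Made-up 2D points: label 1 ("dog") if x + y > 1, else 0 ("cat"), some labels flipped."""
    rng = random.Random(seed)
    data = []
    for _ in range(n):
        x, y = rng.random(), rng.random()
        label = 1 if x + y > 1.0 else 0
        if rng.random() < noise:
            label = 1 - label  # simulate an outlier or a noisy example
        data.append(((x, y), label))
    return data


def straight_line_model(point):
    """The simple model: one straight boundary, like the first blue line."""
    return 1 if point[0] + point[1] > 1.0 else 0


def make_memorizing_model(training_data):
    """An extreme overfitter: copy the label of the closest training point."""
    def predict(point):
        nearest = min(
            training_data,
            key=lambda item: (item[0][0] - point[0]) ** 2 + (item[0][1] - point[1]) ** 2,
        )
        return nearest[1]
    return predict


def error_rate(predict, data):
    """Fraction of (point, label) pairs that predict() gets wrong."""
    return sum(1 for point, label in data if predict(point) != label) / len(data)


training = make_noisy_data(200, seed=1)
validation = make_noisy_data(200, seed=2)
memorizer = make_memorizing_model(training)

print("straight line, training error:  ", error_rate(straight_line_model, training))
print("straight line, validation error:", error_rate(straight_line_model, validation))
print("memorizer, training error:      ", error_rate(memorizer, training))
print("memorizer, validation error:    ", error_rate(memorizer, validation))
```

The memorizer gets every training point right, noisy ones included, because each point is its own nearest neighbour. On data like this you would expect that to cost it on the validation set, where the flipped labels it memorized no longer line up with anything.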
So your network is overfitting and you’re getting more error than you want. How can you reduce that? We’ll look at several different methods. Each has pros and cons, but each can help your model predict better. The most basic way is usually not an option: make your training set bigger. Get more labelled pictures. It can definitely help eliminate the problems that arise from overfitting, but it’s not usually feasible. If you don’t have more pictures ready to go it can be hard or expensive to get more, especially if that data doesn’t already exist. So we really want to focus on things we can do with just changes to our model.
Well, if we can’t get more data we can pretend to. A simple method to simulate more data is Data Augmentation. Data augmentation supplements your data with more points, but instead of getting whole new data points you just manipulate the training data you’re already using. For instance you can multiply the size of your set by four if you include copies of each picture rotated by 90, 180 and 270 degrees. This obviously isn’t as good as having new data, but it helps reduce the impact of noise in the data, which in practice reduces overfitting. Other data augmentation options are to scale the picture differently or to move the picture left or right within the frame. There are many other options, but which ones are safe depends on what the data is, so augmentation doesn’t always generalize well. For instance, with a picture of a cat you could mirror the image as augmentation, but with digits, if you rotate a 6 it becomes a 9, which could give even worse results than no augmentation at all.
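The rotation and mirroring tricks are only a few lines each if a picture is stored as a grid of pixel values. This is a minimal sketch that assumes an image is a plain list of rows; the function names are made up for the example.

```python
def rotate_90(image):
    """Rotate a 2D image (a list of rows) 90 degrees clockwise."""
    return [list(row) for row in zip(*image[::-1])]


def mirror(image):
    """Flip a 2D image left to right."""
    return [row[::-1] for row in image]


def augment_with_rotations(dataset):
    """Return every (image, label) pair plus its 90, 180 and 270 degree rotations."""
    augmented = []
    for image, label in dataset:
        current = image
        for _ in range(4):
            # The label stays the same; only the picture is manipulated.
            augmented.append((current, label))
            current = rotate_90(current)
    return augmented
```

For example, `augment_with_rotations([(picture, "cat")])` returns four labelled pictures. Note that the label is copied blindly, which is exactly why this would be a poor choice for a data set that contains both 6s and 9s.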
We would prefer methods that can be used without worrying as much about the specifics of the data. For that we use a set of techniques called Regularization. “Regularization is a technique which makes slight modifications to the learning algorithm such that the model generalizes better.” There are many types of regularization. We’ll look at four that are common and useful for neural networks.
The first two methods, L1 and L2 regularization, are often grouped together because both fall under the category of weight decay. These methods adjust the cost function that’s used during back propagation to update the variables in your network. The main variables that get updated are the weights and biases that are applied to the incoming data (and then again at each layer if you use a multi-layer network). The updates are driven by how far the network’s predictions are from the labels in the training data. That measure of error is usually called the cost (be aware that other names, such as loss, are sometimes used). Ordinarily the cost depends only on the predictions, which in turn come from the current weights, biases and activation functions (the functions that activate the neurons; we won’t go over them now because you need to know what a neuron is to have built a neural network). For L1 and L2 regularization we instead add an extra term to the cost that grows with the size of the weights. Minimizing the new cost pulls the weights towards zero, so in a sense the weights are “penalized” and any single one will have a smaller effect on the final predictions.
For L1 regularization we use the equation
cost = (cost without regularization) + λ/(2m) * Σ|w|
Here m is the number of data points used and Σ|w| is the sum of the absolute values of all the weights (the L1 norm of the weight vector). λ (lambda) is a parameter that we choose to decide how strongly we want to decay the weights. A value around 0.1 is a reasonable place to start, though the best value depends on the problem.
L2 regularization is similar, but it uses the sum of the squared weights (the squared L2 norm).
cost = (cost without regularization) + λ/(2m) * Σw²
Squaring the weights has a couple of purposes. One is that the decay affects some weights more than others based on how large they are: larger weights are pushed towards zero harder, while weights that are already small are barely touched. Another has to do with back propagation in our model. During back propagation we end up using the derivative of the cost with respect to each weight. It’s not pivotal that you understand that in detail now, but it is important to know the derivative is used, and the L2 term has a very simple one: (λ/m) * w.
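Reading the two penalty terms as the sum of absolute weights (L1) and the sum of squared weights (L2), each multiplied by λ/(2m), they can be sketched as below. The function names are made up for the example and the weights are assumed to be a flat list of numbers; in a real network you would sum over the weights of every layer.

```python
def l1_penalty(weights, lam, m):
    """L1 term added to the cost: lambda / (2m) times the sum of |w|."""
    return lam / (2 * m) * sum(abs(w) for w in weights)


def l2_penalty(weights, lam, m):
    """L2 term added to the cost: lambda / (2m) times the sum of w squared."""
    return lam / (2 * m) * sum(w * w for w in weights)


def l2_penalty_gradient(weights, lam, m):
    """Derivative of the L2 term with respect to each weight: (lambda / m) * w."""
    return [lam / m * w for w in weights]


def regularized_cost(plain_cost, weights, lam, m, kind="l2"):
    """Cost without regularization plus the chosen penalty term."""
    penalty = l2_penalty if kind == "l2" else l1_penalty
    return plain_cost + penalty(weights, lam, m)
```

With weights `[3.0, -4.0]`, λ = 0.1 and m = 10, the L1 term is 0.1/20 * 7 and the L2 term is 0.1/20 * 25. The L2 gradient is proportional to each weight, which is why large weights get pulled towards zero harder than small ones.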
Both L1 and L2 are weight decay regularizations, and they act on the weights that feed the neurons. But we don’t have to use weight decay. If we want, we can accomplish a similar result by repeatedly training with only part of the network, in what’s called Dropout. During each step, instead of decaying all the weights, we randomly choose some neurons to disregard for that iteration. Dropping a neuron means setting its output to 0, which is the same as zeroing the weights attached to it, since the two get multiplied together. Either way it goes unused for that step. The reason this works so well is that during ordinary training without dropout some variables change based on each other, and that’s part of what can cause overfitting. An outlier or a bit of noise spreads to other variables and changes them too, and that’s what causes the model to work less well on data that’s not in the training set. By dropping out different random neurons during each iteration, the quirks we want to eliminate can’t get incorporated into other variables as easily because they are not reliably present, whereas the patterns we do want come from multiple examples and so are still present even with dropout. In practice dropout is useful and even easier to implement than weight decay, but since we drop neurons on every iteration the effect can be too large if we don’t have a sufficiently large network to begin with. In many cases dropout is used simultaneously with weight decay (usually L2) to great effect.
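A minimal sketch of one dropout step, applied to the outputs of a single layer, might look like this. It uses the common “inverted dropout” variant, which also scales the surviving outputs up by 1/keep_prob so the layer’s expected output stays the same; that scaling is an addition not described above. The names `dropout` and `keep_prob` are assumptions for the example.

```python
import random


def dropout(activations, keep_prob, rng):
    """Inverted dropout: zero each output with probability 1 - keep_prob,
    and scale the survivors by 1 / keep_prob."""
    if not 0.0 < keep_prob <= 1.0:
        raise ValueError("keep_prob must be in (0, 1]")
    return [
        a / keep_prob if rng.random() < keep_prob else 0.0
        for a in activations
    ]


# A fresh random mask is drawn on every call, so different neurons are
# dropped on every training iteration.
rng = random.Random(0)
layer_output = [0.2, 0.9, 0.5, 0.7]
print(dropout(layer_output, keep_prob=0.5, rng=rng))
```

Dropout is only applied during training. When the trained network is used for predictions, every neuron is kept.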
When you begin training a neural network the error starts quite high (usually we initialize randomly, so the predictions start out no better than chance), but as you train, the training error drops. It drops rapidly at first as the network updates the weights of the most important features, then it slows and approaches some ideal level of training error for your chosen model size.
Validation error follows a similar pattern, but as training progresses past an ideal point the validation error can rise again as the model passes the ideal fit and moves into overfitting. In this sense overtraining is a direct cause of overfitting and therefore of bad generalization. That’s the rationale for our last regularization method, aptly referred to as Early Stopping. It is exactly what it sounds like, and at least in theory it is the simplest of the regularization methods. Stop training when you’re no longer fitting the network to useful data and the network starts fitting to outliers and noise instead. That doesn’t mean bad data like a bad example, for instance a handwritten digit with bad handwriting, but rather a deep feature that the model is giving weight to, often deep in the network and therefore very hard to pinpoint and account for directly. But because we know it tends to happen late in training, we can sidestep it entirely by not training that long. In practice the difficult task is deciding when to stop. A lot of research has gone into it. Basic methods involve defining a new parameter for a minimum rate of improvement, and if the model is no longer improving at that rate (the error is still decreasing, but too slowly) you stop training right then and there. Then you use the model as it is and get better results than if you’d kept training! My example graph with its blue and red curves makes it look like a simple case. Unfortunately real curves are often far less smooth and can be quite ugly, so it can be quite challenging to decide when to stop: stop too early or too late and the error remains high.
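One simple stopping rule can be sketched as follows. Rather than the exact rate-based rule described above, it uses a closely related and widely used variant: watch the validation error and stop once it has gone `patience` checks in a row without improving by at least `min_improvement`. Both parameter names, and the function name, are assumptions for the example.

```python
def early_stopping_epoch(validation_errors, patience=3, min_improvement=0.0):
    """Return (stop_epoch, best_epoch) for a sequence of per-epoch validation errors.

    Training stops at the first epoch that is `patience` epochs past the best
    one seen so far; the model saved at best_epoch is the one to keep.
    """
    best_error = float("inf")
    best_epoch = -1
    for epoch, error in enumerate(validation_errors):
        if error < best_error - min_improvement:
            # Still improving fast enough: remember this epoch as the best.
            best_error = error
            best_epoch = epoch
        elif epoch - best_epoch >= patience:
            # Too long without a real improvement: stop here.
            return epoch, best_epoch
    # Never triggered: training ran for every epoch provided.
    return len(validation_errors) - 1, best_epoch
```

For the validation errors `[0.9, 0.6, 0.4, 0.35, 0.36, 0.38, 0.41, 0.5]` and the default patience of 3, this stops at epoch 6 and reports epoch 3 as the best one. A larger `patience` is more forgiving of bumpy curves, at the price of training longer than necessary.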
Data augmentation, L1 and L2 regularization, dropout and early stopping are all powerful tools that work towards the same goal of combatting overfitting and getting more accurate predictions from a neural network. However, none of them is the sole answer. Often several are used together, and it is useful to be familiar with each of them to get the best results out of any model. Overfitting is a real problem with neural networks, but it is a problem that can be solved.