Most people effortlessly recognize those digits as 504192. That ease is deceptive. In each hemisphere of our brain, humans have a primary visual cortex, also known as V1, containing 140 million neurons, with tens of billions of connections between them. And yet human vision involves not just V1, but an entire series of visual cortices - V2, V3, V4, and V5 - doing progressively more complex image processing. We carry in our heads a supercomputer, tuned by evolution over hundreds of millions of years, and superbly adapted to understand the visual world. Recognizing handwritten digits isn't easy. Rather, we humans are stupendously, astoundingly good at making sense of what our eyes show us. But nearly all that work is done unconsciously. And so we don't usually appreciate how tough a problem our visual systems solve.
The difficulty of visual pattern recognition becomes apparent if you attempt to write a computer program to recognize digits like those above. What seems easy when we do it ourselves suddenly becomes extremely difficult. Simple intuitions about how we recognize shapes - "a 9 has a loop at the top, and a vertical stroke in the bottom right" - turn out to be not so simple to express algorithmically. When you try to make such rules precise, you quickly get lost in a morass of exceptions and caveats and special cases. It seems hopeless.
Neural networks approach the problem in a different way. The idea is to take a large number of handwritten digits, known as training examples,
and then develop a system which can learn from those training examples. In other words, the neural network uses the examples to automatically infer rules for recognizing handwritten digits. Furthermore, by increasing the number of training examples, the network can learn more about handwriting, and so improve its accuracy. So while I've shown just 100 training digits above, perhaps we could build a better handwriting recognizer by using thousands or even millions or billions of training examples.
In this chapter we'll write a computer program implementing a neural network that learns to recognize handwritten digits. The program is just 74 lines long, and uses no special neural network libraries. But this short program can recognize digits with an accuracy over 96 percent, without human intervention. Furthermore, in later chapters we'll develop ideas which can improve accuracy to over 99 percent. In fact, the best commercial neural networks are now so good that they are used by banks to process cheques, and by post offices to recognize addresses.
We're focusing on handwriting recognition because it's an excellent prototype problem for learning about neural networks in general. As a prototype it hits a sweet spot: it's challenging - it's no small feat to recognize handwritten digits - but it's not so difficult as to require an extremely complicated solution, or tremendous computational power. Furthermore, it's a great way to develop more advanced techniques, such as deep learning. And so throughout the book we'll return repeatedly to the problem of handwriting recognition. Later in the book, we'll discuss how these ideas may be applied to other problems in computer vision, and also in speech, natural language processing, and other domains.
Of course, if the point of the chapter was only to write a computer program to recognize handwritten digits, then the chapter would be much shorter! But along the way we'll develop many key ideas about neural networks, including two important types of artificial neuron (the perceptron and the sigmoid neuron), and the standard learning algorithm for neural networks, known as stochastic gradient descent. Throughout, I focus on explaining why things are done the way they are, and on building your neural networks intuition. That requires a lengthier discussion than if I just presented the basic mechanics of what's going on, but it's worth it for the deeper understanding you'll attain. Amongst the payoffs, by the end of the chapter we'll be in position to understand what deep learning is, and why it matters.
What is a neural network? To get started, I'll explain a type of artificial neuron called a perceptron. Perceptrons were developed in the 1950s and 1960s by the scientist Frank Rosenblatt, inspired by earlier work by Warren McCulloch and Walter Pitts. Today, it's more common to use other models of artificial neurons - in this book, and in much modern work on neural networks, the main neuron model used is one called the sigmoid neuron. We'll get to sigmoid neurons shortly. But to understand why sigmoid neurons are defined the way they are, it's worth taking the time to first understand perceptrons.
So how do perceptrons work? A perceptron takes several binary inputs, $x_1, x_2, \ldots$, and produces a single binary output:
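Here, for reference, is the rule a perceptron computes, written out in the standard way; it uses the weights $w_1, w_2, \ldots$ and the threshold value that are discussed just below:

$$
\text{output} =
\begin{cases}
0 & \text{if } \sum_j w_j x_j \le \text{threshold} \\
1 & \text{if } \sum_j w_j x_j > \text{threshold}
\end{cases}
$$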
That's the basic mathematical model. A way you can think about the perceptron is that it's a device that makes decisions by weighing up evidence. Let me give an example. It's not a very realistic example, but it's easy to understand, and we'll soon get to more realistic examples. Suppose the weekend is coming up, and you've heard that there's going to be a cheese festival in your city. You like cheese, and are trying to decide whether or not to go to the festival. You might make your decision by weighing up three factors:

1. Is the weather good?
2. Does your boyfriend or girlfriend want to accompany you?
3. Is the festival near public transit? (You don't own a car.)
Now, suppose you absolutely adore cheese, so much so that you're happy to go to the festival even if your boyfriend or girlfriend is uninterested and the festival is hard to get to. But perhaps you really loathe bad weather, and there's no way you'd go to the festival if the weather is bad. You can use perceptrons to model this kind of decision-making. One way to do this is to choose a weight $w_1 = 6$ for the weather, and $w_2 = 2$ and $w_3 = 2$ for the other conditions. The larger value of $w_1$ indicates that the weather matters a lot to you, much more than whether your boyfriend or girlfriend joins you, or the nearness of public transit. Finally, suppose you choose a threshold of $5$ for the perceptron. With these choices, the perceptron implements the desired decision-making model, outputting $1$ whenever the weather is good, and $0$ whenever the weather is bad. It makes no difference to the output whether your boyfriend or girlfriend wants to go, or whether public transit is nearby.
By varying the weights and the threshold, we can get different models of decision-making. For example, suppose we instead chose a threshold of $3$. Then the perceptron would decide that you should go to the festival whenever the weather was good or when both the festival was near public transit and your boyfriend or girlfriend was willing to join you. In other words, it'd be a different model of decision-making. Dropping the threshold means you're more willing to go to the festival.
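To make the arithmetic concrete, here is a small Python sketch of this decision model (a sketch only; the weights 6, 2, 2 and the two thresholds are the values assumed above):

```python
def perceptron(x, w, threshold):
    """Return 1 if the weighted sum of the inputs exceeds the threshold, else 0."""
    total = sum(wj * xj for wj, xj in zip(w, x))
    return 1 if total > threshold else 0

# Inputs: [good weather?, partner wants to go?, near public transit?]
weights = [6, 2, 2]

# With threshold 5, only the weather matters.
print(perceptron([1, 0, 0], weights, threshold=5))  # 1: good weather, go
print(perceptron([0, 1, 1], weights, threshold=5))  # 0: bad weather, stay home

# With threshold 3, you also go if both other conditions hold.
print(perceptron([0, 1, 1], weights, threshold=3))  # 1: partner + transit, go
```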
Obviously, the perceptron isn't a complete model of human decision-making! But what the example illustrates is how a perceptron can weigh up different kinds of evidence in order to make decisions. And it should seem plausible that a complex network of perceptrons could make quite subtle decisions:
Incidentally, when I defined perceptrons I said that a perceptron has just a single output. In the network above the perceptrons look like they have multiple outputs. In fact, they're still single output. The multiple output arrows are merely a useful way of indicating that the output from a perceptron is being used as the input to several other perceptrons. It's less unwieldy than drawing a single output line which then splits.
Let's simplify the way we describe perceptrons. The condition $\sum_j w_j x_j > \text{threshold}$ is cumbersome, and we can make two notational changes to simplify it. The first change is to write $\sum_j w_j x_j$ as a dot product, $w \cdot x \equiv \sum_j w_j x_j$, where $w$ and $x$ are vectors whose components are the weights and inputs, respectively. The second change is to move the threshold to the other side of the inequality, and to replace it by what's known as the perceptron's bias, $b \equiv -\text{threshold}$. Using the bias instead of the threshold, the perceptron rule can be rewritten:

$$
\text{output} =
\begin{cases}
0 & \text{if } w \cdot x + b \le 0 \\
1 & \text{if } w \cdot x + b > 0
\end{cases}
$$
I've described perceptrons as a method for weighing evidence to make decisions. Another way perceptrons can be used is to compute the elementary logical functions we usually think of as underlying computation, functions such as NAND. For example, suppose we have a perceptron with two inputs, each with weight $-2$, and an overall bias of $3$. Here's our perceptron. The inputs $00$ produce output $1$, since $(-2) \cdot 0 + (-2) \cdot 0 + 3 = 3$ is positive. Similar calculations show that the inputs $01$ and $10$ also produce output $1$. But the inputs $11$ produce output $0$, since $(-2) \cdot 1 + (-2) \cdot 1 + 3 = -1$ is negative. And so our perceptron implements a NAND gate!

The NAND example shows that we can use perceptrons to compute simple logical functions.
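As a quick check, here is a short Python sketch verifying that a perceptron with weights $-2, -2$ and bias $3$ reproduces the NAND truth table (the function names are purely illustrative):

```python
def perceptron_output(x, w, b):
    """Perceptron rule: output 1 if w.x + b > 0, else 0."""
    return 1 if sum(wj * xj for wj, xj in zip(w, x)) + b > 0 else 0

def nand(x1, x2):
    """A NAND gate built from a single perceptron."""
    return perceptron_output([x1, x2], w=[-2, -2], b=3)

for x1 in (0, 1):
    for x2 in (0, 1):
        print(x1, x2, "->", nand(x1, x2))
# Prints 1 for inputs 00, 01, 10 and 0 for 11, i.e., the NAND truth table.
```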
In fact, we can use networks of perceptrons to compute any logical function at all. The reason is that the NAND gate is universal for computation, that is, we can build any computation up out of NAND gates. For example, we can use NAND gates to build a circuit which adds two bits, $x_1$ and $x_2$. This requires computing the bitwise sum, $x_1 \oplus x_2$, as well as a carry bit which is set to $1$ when both $x_1$ and $x_2$ are $1$, i.e., the carry bit is just the bitwise product $x_1 x_2$:
To get an equivalent network of perceptrons we replace all the NAND gates by perceptrons with two inputs, each with weight $-2$, and an overall bias of $3$. Here's the resulting network. Note that I've moved the perceptron corresponding to the bottom right NAND gate a little, just to make it easier to draw the arrows on the diagram:
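For concreteness, here is a small Python sketch of the adder, built entirely out of the perceptron-based NAND gate described above. The wiring follows the standard NAND-only half adder, which I believe matches the circuit in the diagram; the helper names are illustrative:

```python
def nand(x1, x2):
    """NAND gate as a perceptron: weights -2, -2 and bias 3."""
    return 1 if (-2) * x1 + (-2) * x2 + 3 > 0 else 0

def add_two_bits(x1, x2):
    """Half adder wired from NAND gates: returns (bitwise sum, carry bit)."""
    m = nand(x1, x2)
    bit_sum = nand(nand(x1, m), nand(x2, m))  # x1 XOR x2
    carry = nand(m, m)                        # x1 AND x2
    return bit_sum, carry

for x1 in (0, 1):
    for x2 in (0, 1):
        print(x1, x2, "->", add_two_bits(x1, x2))
# 1 + 1 gives bitwise sum 0 with carry 1; the other cases have carry 0.
```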
The adder example demonstrates how a network of perceptrons can be
used to simulate a circuit containing many
NAND gates. And because NAND gates are universal for computation, it follows
that perceptrons are also universal for computation.
The computational universality of perceptrons is simultaneously
reassuring and disappointing. It's reassuring because it tells us
that networks of perceptrons can be as powerful as any other computing
device. But it's also disappointing, because it makes it seem as
though perceptrons are merely a new type of NAND gate.
That's hardly big news!
However, the situation is better than this view suggests. It turns
out that we can devise learning
algorithms which can
automatically tune the weights and biases of a network of artificial
neurons. This tuning happens in response to external stimuli, without
direct intervention by a programmer. These learning algorithms enable
us to use artificial neurons in a way which is radically different to
conventional logic gates. Instead of explicitly laying out a circuit of
NAND and other gates, our neural networks can simply learn
to solve problems, sometimes problems where it would be extremely
difficult to directly design a conventional circuit.
Learning algorithms sound terrific. But how can we devise such algorithms for a neural network? Suppose we have a network of perceptrons that we'd like to use to learn to solve some problem. For example, the inputs to the network might be the raw pixel data from a scanned, handwritten image of a digit. And we'd like the network to learn weights and biases so that the output from the network correctly classifies the digit. To see how learning might work, suppose we make a small change in some weight (or bias) in the network. What we'd like is for this small change in weight to cause only a small corresponding change in the output from the network. As we'll see in a moment, this property will make learning possible. Schematically, here's what we want (obviously this network is too simple to do handwriting recognition!):
If it were true that a small change in a weight (or bias) causes only a small change in output, then we could use this fact to modify the weights and biases to get our network to behave more in the manner we want. For example, suppose the network was mistakenly classifying an image as an "8" when it should be a "9". We could figure out how to make a small change in the weights and biases so the network gets a little closer to classifying the image as a "9". And then we'd repeat this, changing the weights and biases over and over to produce better and better output. The network would be learning.
The problem is that this isn't what happens when our network contains perceptrons. In fact, a small change in the weights or bias of any single perceptron in the network can sometimes cause the output of that perceptron to completely flip, say from $0$ to $1$. That flip may then cause the behaviour of the rest of the network to completely change in some very complicated way. So while your "9" might now be classified correctly, the behaviour of the network on all the other images is likely to have completely changed in some hard-to-control way. That makes it difficult to see how to gradually modify the weights and biases so that the network gets closer to the desired behaviour. Perhaps there's some clever way of getting around this problem. But it's not immediately obvious how we can get a network of perceptrons to learn.
We can overcome this problem by introducing a new type of artificial neuron called a sigmoid neuron. Sigmoid neurons are similar to perceptrons, but modified so that small changes in their weights and bias cause only a small change in their output. That's the crucial fact which will allow a network of sigmoid neurons to learn.
Okay, let me describe the sigmoid neuron. We'll depict sigmoid neurons in the same way we depicted perceptrons: just like a perceptron, the sigmoid neuron has inputs, $x_1, x_2, \ldots$, weights for each input, $w_1, w_2, \ldots$, and an overall bias, $b$. But the inputs can take on any values between $0$ and $1$, not just $0$ or $1$, and the output is not simply $0$ or $1$ either. Instead, the output is $\sigma(w \cdot x + b)$, where $\sigma$ is called the sigmoid function and is defined by

$$
\sigma(z) \equiv \frac{1}{1 + e^{-z}}.
$$
At first sight, sigmoid neurons appear very different to perceptrons. The algebraic form of the sigmoid function may seem opaque and forbidding if you're not already familiar with it. In fact, there are many similarities between perceptrons and sigmoid neurons, and the algebraic form of the sigmoid function turns out to be more of a technical detail than a true barrier to understanding.
To understand the similarity to the perceptron model, suppose $z \equiv w \cdot x + b$ is a large positive number. Then $e^{-z} \approx 0$ and so $\sigma(z) \approx 1$. In other words, when $z = w \cdot x + b$ is large and positive, the output from the sigmoid neuron is approximately $1$, just as it would have been for a perceptron. Suppose on the other hand that $z = w \cdot x + b$ is very negative. Then $e^{-z} \rightarrow \infty$, and $\sigma(z) \approx 0$. So when $z = w \cdot x + b$ is very negative, the behaviour of a sigmoid neuron also closely approximates a perceptron. It's only when $w \cdot x + b$ is of modest size that there's much deviation from the perceptron model.
What about the algebraic form of $\sigma$? How can we understand that? In fact, the exact form of $\sigma$ isn't so important - what really matters is the shape of the function when plotted. Here's the shape:
This shape is a smoothed out version of a step function:
If $\sigma$ had in fact been a step function, then the sigmoid neuron would be a perceptron, since the output would be $1$ or $0$ depending on whether $w \cdot x + b$ was positive or negative* *Actually, when $w \cdot x + b = 0$ the perceptron outputs $0$, while the step function outputs $1$. So, strictly speaking, we'd need to modify the step function at that one point. But you get the idea.. By using the actual $\sigma$ function we get, as already implied above, a smoothed out perceptron. Indeed, it's the smoothness of the $\sigma$ function that is the crucial fact, not its detailed form. The smoothness of $\sigma$ means that small changes $\Delta w_j$ in the weights and $\Delta b$ in the bias will produce a small change $\Delta \text{output}$ in the output from the neuron. In fact, calculus tells us that $\Delta \text{output}$ is well approximated by

$$
\Delta \text{output} \approx \sum_j \frac{\partial\, \text{output}}{\partial w_j} \Delta w_j + \frac{\partial\, \text{output}}{\partial b} \Delta b,
$$

where the sum is over all the weights, $w_j$.
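To make the comparison concrete, here is a minimal Python sketch of the sigmoid function alongside the step function it smooths out (numpy is assumed purely for convenience):

```python
import numpy as np

def sigmoid(z):
    """The sigmoid function, sigma(z) = 1 / (1 + e^(-z))."""
    return 1.0 / (1.0 + np.exp(-z))

def step(z):
    """The step function the sigmoid smooths out: 1 if z > 0, else 0."""
    return np.where(z > 0, 1.0, 0.0)

z = np.array([-10.0, -1.0, 0.0, 1.0, 10.0])
print(sigmoid(z))  # close to 0 for very negative z, close to 1 for very positive z
print(step(z))     # jumps abruptly from 0 to 1 at z = 0
```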
If it's the shape of $\sigma$ which really matters, and not its exact form, then why use this particular algebraic form for $\sigma$? One reason is that exponentials have lovely properties when differentiated, which will simplify the algebra when we compute partial derivatives later in the book; in any case, $\sigma$ is commonly used in work on neural nets, and it's the activation function we'll use most often.
How should we interpret the output from a sigmoid neuron? Obviously, one big difference between perceptrons and sigmoid neurons is that sigmoid neurons don't just output $0$ or $1$. They can have as output any real number between $0$ and $1$, so values such as $0.173$ and $0.689$ are legitimate outputs. This can be useful, for example, if we want to use the output value to represent the average intensity of the pixels in an image input to a neural network. But sometimes it can be a nuisance. Suppose we want the output from the network to indicate either "the input image is a 9" or "the input image is not a 9". Obviously, it'd be easiest to do this if the output was a $0$ or a $1$, as in a perceptron. But in practice we can set up a convention to deal with this, for example, by deciding to interpret any output of at least $0.5$ as indicating a "9", and any output less than $0.5$ as indicating "not a 9". I'll always explicitly state when we're using such a convention, so it shouldn't cause any confusion.
In the next section I'll introduce a neural network that can do a pretty good job classifying handwritten digits. In preparation for that, it helps to explain some terminology that lets us name different parts of a network. Suppose we have the network:
The leftmost layer in such a network is called the input layer, and the neurons within it are called input neurons. The rightmost or output layer contains the output neurons. The layers in between are called hidden layers, since the neurons in those layers are neither inputs nor outputs. With that terminology in place, the design of the input and output layers in a network is often straightforward. For example, suppose we're trying to determine whether a handwritten image depicts a "9" or not. A natural way to design the network is to encode the intensities of the image pixels into the input neurons. If the image is a $64$ by $64$ greyscale image, then we'd have $4{,}096 = 64 \times 64$ input neurons, with the intensities scaled appropriately between $0$ and $1$. The output layer will contain just a single neuron, with output values of less than $0.5$ indicating "input image is not a 9", and values greater than $0.5$ indicating "input image is a 9".
While the design of the input and output layers of a neural network is often straightforward, there can be quite an art to the design of the hidden layers. In particular, it's not possible to sum up the design process for the hidden layers with a few simple rules of thumb. Instead, neural networks researchers have developed many design heuristics for the hidden layers, which help people get the behaviour they want out of their nets. For example, such heuristics can be used to help determine how to trade off the number of hidden layers against the time required to train the network. We'll meet several such design heuristics later in this book.
Up to now, we've been discussing neural networks where the output from one layer is used as input to the next layer. Such networks are called feedforward neural networks. This means there are no loops in the network - information is always fed forward, never fed back. If we did have loops, we'd end up with situations where the input to the $\sigma$ function depended on the output. That'd be hard to make sense of, and so we don't allow such loops.
However, there are other models of artificial neural networks in which feedback loops are possible. These models are called recurrent neural networks. The idea in these models is to have neurons which fire for some limited duration of time, before becoming quiescent. That firing can stimulate other neurons, which may fire a little while later, also for a limited duration. That causes still more neurons to fire, and so over time we get a cascade of neurons firing. Loops don't cause problems in such a model, since a neuron's output only affects its input at some later time, not instantaneously.
Recurrent neural nets have been less influential than feedforward networks, in part because the learning algorithms for recurrent nets are (at least to date) less powerful. But recurrent networks are still extremely interesting. They're much closer in spirit to how our brains work than feedforward networks. And it's possible that recurrent networks can solve important problems which can only be solved with great difficulty by feedforward networks. However, to limit our scope, in this book we're going to concentrate on the more widely-used feedforward networks.
Having defined neural networks, let's return to handwriting recognition. We can split the problem of recognizing handwritten digits into two sub-problems. First, we'd like a way of breaking an image containing many digits into a sequence of separate images, each containing a single digit. For example, we'd like to break the image
into six separate images,
We humans solve this segmentation problem with ease, but it's challenging for a computer program to correctly break up the image. Once the image has been segmented, the program then needs to classify each individual digit. So, for instance, we'd like our program to recognize that the first digit above,
is a 5.
We'll focus on writing a program to solve the second problem, that is, classifying individual digits. We do this because it turns out that the segmentation problem is not so difficult to solve, once you have a good way of classifying individual digits. There are many approaches to solving the segmentation problem. One approach is to trial many different ways of segmenting the image, using the individual digit classifier to score each trial segmentation. A trial segmentation gets a high score if the individual digit classifier is confident of its classification in all segments, and a low score if the classifier is having a lot of trouble in one or more segments. The idea is that if the classifier is having trouble somewhere, then it's probably having trouble because the segmentation has been chosen incorrectly. This idea and other variations can be used to solve the segmentation problem quite well. So instead of worrying about segmentation we'll concentrate on developing a neural network which can solve the more interesting and difficult problem, namely, recognizing individual handwritten digits.
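The scoring idea in the previous paragraph is easy to sketch in code. The snippet below is purely illustrative: `classifier_confidence` and `candidate_segmentations` are hypothetical functions standing in for a trained digit classifier and a generator of trial segmentations, not anything defined in this chapter:

```python
def score_segmentation(segments, classifier_confidence):
    """Score a trial segmentation by how confident the digit classifier is on
    each segment; one low-confidence segment drags the whole score down."""
    return min(classifier_confidence(segment) for segment in segments)

def best_segmentation(image, candidate_segmentations, classifier_confidence):
    """Pick the trial segmentation whose segments the classifier is happiest with."""
    return max(candidate_segmentations(image),
               key=lambda segments: score_segmentation(segments, classifier_confidence))
```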
To recognize individual digits we will use a three-layer neural network:
The input layer of the network contains neurons encoding the values of the input pixels. As discussed in the next section, our training data for the network will consist of many $28$ by $28$ pixel images of scanned handwritten digits, and so the input layer contains $784 = 28 \times 28$ neurons. For simplicity I've omitted most of the $784$ input neurons in the diagram above. The input pixels are greyscale, with a value of $0.0$ representing white, a value of $1.0$ representing black, and in between values representing gradually darkening shades of grey.
The second layer of the network is a hidden layer. We denote the number of neurons in this hidden layer by $n$, and we'll experiment with different values for $n$. The example shown illustrates a small hidden layer, containing just $n = 15$ neurons.
The output layer of the network contains 10 neurons. If the first neuron fires, i.e., has an output $\approx 1$, then that will indicate that the network thinks the digit is a $0$. If the second neuron fires then that will indicate that the network thinks the digit is a $1$. And so on. A little more precisely, we number the output neurons from $0$ through $9$, and figure out which neuron has the highest activation value. If that neuron is, say, neuron number $6$, then our network will guess that the input digit was a $6$. And so on for the other output neurons.
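As a rough sketch of how this three-layer network is used at classification time - not the training code developed later in the chapter; the weight and bias arrays here are just randomly initialized placeholders:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Layer sizes: 784 input pixels, n = 15 hidden neurons, 10 output neurons.
rng = np.random.default_rng(0)
w_hidden, b_hidden = rng.standard_normal((15, 784)), rng.standard_normal(15)
w_output, b_output = rng.standard_normal((10, 15)), rng.standard_normal(10)

def classify(pixels):
    """Feed a 784-dimensional image through the network and return the
    index (0-9) of the output neuron with the highest activation."""
    hidden = sigmoid(w_hidden @ pixels + b_hidden)
    output = sigmoid(w_output @ hidden + b_output)
    return int(np.argmax(output))

print(classify(rng.random(784)))  # some digit 0-9; meaningless until trained
```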
You might wonder why we use $10$ output neurons. After all, the goal of the network is to tell us which digit ($0, 1, 2, \ldots, 9$) corresponds to the input image. A seemingly natural way of doing that is to use just $4$ output neurons, treating each neuron as taking on a binary value, depending on whether the neuron's output is closer to $0$ or to $1$. Four neurons are enough to encode the answer, since $2^4 = 16$ is more than the 10 possible values for the input digit. Why should our network use $10$ neurons instead? Isn't that inefficient? The ultimate justification is empirical: we can try out both network designs, and it turns out that, for this particular problem, the network with $10$ output neurons learns to recognize digits better than the network with $4$ output neurons. But that leaves us wondering why using $10$ output neurons works better. Is there some heuristic that would tell us in advance that we should use the $10$-output encoding instead of the $4$-output encoding?
To understand why we do this, it helps to think about what the neural network is doing from first principles. Consider first the case where we use $10$ output neurons. Let's concentrate on the first output neuron, the one that's trying to decide whether or not the digit is a $0$. It does this by weighing up evidence from the hidden layer of neurons. What are those hidden neurons doing? Well, just suppose for the sake of argument that the first neuron in the hidden layer detects whether or not an image like the following is present:
It can do this by heavily weighting input pixels which overlap with the image, and only lightly weighting the other inputs. In a similar way, let's suppose for the sake of argument that the second, third, and fourth neurons in the hidden layer detect whether or not the following images are present:
As you may have guessed, these four images together make up the $0$ image that we saw in the line of digits shown earlier:
So if all four of these hidden neurons are firing then we can conclude that the digit is a $0$. Of course, that's not the only sort of evidence we can use to conclude that the image was a $0$ - we could legitimately get a $0$ in many other ways (say, through translations of the above images, or slight distortions). But it seems safe to say that at least in this case we'd conclude that the input was a $0$.
Supposing the neural network functions in this way, we can give a plausible explanation for why it's better to have $10$ outputs from the network, rather than $4$. If we had $4$ outputs, then the first output neuron would be trying to decide what the most significant bit of the digit was. And there's no easy way to relate that most significant bit to simple shapes like those shown above. It's hard to imagine that there's any good historical reason the component shapes of the digit will be closely related to (say) the most significant bit in the output.
Now, with all that said, this is all just a heuristic. Nothing says that the three-layer neural network has to operate in the way I described, with the hidden neurons detecting simple component shapes. Maybe a clever learning algorithm will find some assignment of weights that lets us use only $4$ output neurons. But as a heuristic the way of thinking I've described works pretty well, and can save you a lot of time in designing good neural network architectures.
Now that we have a design for our neural network, how can it learn to recognize digits? The first thing we'll need is a data set to learn from - a so-called training data set. We'll use the MNIST data set, which contains tens of thousands of scanned images of handwritten digits, together with their correct classifications. MNIST's name comes from the fact that it is a modified subset of two data sets collected by NIST, the United States' National Institute of Standards and Technology. Here's a few images from MNIST:
As you can see, these digits are, in fact, the same as those shown at the beginning of this chapter as a challenge to recognize. Of course, when testing our network we'll ask it to recognize images which aren't in the training set!
The MNIST data comes in two parts. The first part contains 60,000 images to be used as training data. These images are scanned handwriting samples from 250 people, half of whom were US Census Bureau employees, and half of whom were high school students. The images are greyscale and 28 by 28 pixels in size. The second part of the MNIST data set is 10,000 images to be used as test data. Again, these are 28 by 28 greyscale images. We'll use the test data to evaluate how well our neural network has learned to recognize digits. To make this a good test of performance, the test data was taken from a different set of 250 people than the original training data (albeit still a group split between Census Bureau employees and high school students). This helps give us confidence that our system can recognize digits from people whose writing it didn't see during training.
We'll use the notation $x$ to denote a training input. It'll be convenient to regard each training input $x$ as a $28 \times 28 = 784$-dimensional vector. Each entry in the vector represents the grey value for a single pixel in the image. We'll denote the corresponding desired output by $y = y(x)$, where $y$ is a $10$-dimensional vector. For example, if a particular training image, $x$, depicts a $6$, then $y(x) = (0, 0, 0, 0, 0, 0, 1, 0, 0, 0)^T$ is the desired output from the network. Note that $T$ here is the transpose operation, turning a row vector into an ordinary (column) vector.
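Here is a tiny Python sketch of that encoding, just to fix the convention (the helper name `vectorized_result` mirrors the kind of utility used later with MNIST, but is written here only for illustration):

```python
import numpy as np

def vectorized_result(digit):
    """Return a 10-dimensional column vector with a 1.0 in the position
    corresponding to `digit` and zeros elsewhere."""
    y = np.zeros((10, 1))
    y[digit] = 1.0
    return y

print(vectorized_result(6).ravel())  # [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
```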
What we'd like is an algorithm which lets us find weights and biases so that the output from the network approximates $y(x)$ for all training inputs $x$. To quantify how well we're achieving this goal we define a cost function* *Sometimes referred to as a loss or objective function. We use the term cost function throughout this book, but you should note the other terminology, since it's often used in research papers and other discussions of neural networks.:

$$
C(w, b) \equiv \frac{1}{2n} \sum_x \| y(x) - a \|^2 .
$$

Here, $w$ denotes the collection of all weights in the network, $b$ all the biases, $n$ is the total number of training inputs, $a$ is the vector of outputs from the network when $x$ is input, and the sum is over all training inputs, $x$. We'll call $C$ the quadratic cost function; it's also sometimes known as the mean squared error or just MSE.
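A minimal numpy sketch of the quadratic cost, averaged over a training set; this is for illustration only, with the network's output supplied by some hypothetical `feedforward` function:

```python
import numpy as np

def quadratic_cost(training_data, feedforward):
    """C(w, b) = (1 / 2n) * sum over x of ||y(x) - a||^2,
    where a = feedforward(x) is the network's output for input x."""
    n = len(training_data)
    return sum(np.linalg.norm(y - feedforward(x)) ** 2
               for x, y in training_data) / (2.0 * n)
```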
Why introduce the quadratic cost? After all, aren't we primarily interested in the number of images correctly classified by the network? Why not try to maximize that number directly, rather than minimizing a proxy measure like the quadratic cost? The problem with that is that the number of images correctly classified is not a smooth function of the weights and biases in the network. For the most part, making small changes to the weights and biases won't cause any change at all in the number of training images classified correctly. That makes it difficult to figure out how to change the weights and biases to get improved performance. If we instead use a smooth cost function like the quadratic cost it turns out to be easy to figure out how to make small changes in the weights and biases so as to get an improvement in the cost. That's why we focus first on minimizing the quadratic cost, and only after that will we examine the classification accuracy.
Even given that we want to use a smooth cost function, you may still wonder why we choose the quadratic function above. Isn't this a rather ad hoc choice? It's a valid concern, and later in the book we'll revisit the cost function and make some modifications. However, the quadratic cost works perfectly well for understanding the basics of learning in neural networks, so we'll stick with it for now.
Recapping, our goal in training a neural network is to find weights and biases which minimize the quadratic cost function $C(w, b)$. This is a well-posed problem, but it's got a lot of distracting structure as currently posed - the interpretation of $w$ and $b$ as weights and biases, the $\sigma$ function lurking in the background, the choice of network architecture, MNIST, and so on. It turns out that we can understand a tremendous amount by ignoring most of that structure, and just concentrating on the minimization aspect. So for now we're going to forget all about the specific form of the cost function, the connection to neural networks, and so on. Instead, we're going to imagine that we've simply been given a function of many variables and we want to minimize that function. We're going to develop a technique called gradient descent which can be used to solve such minimization problems. Then we'll come back to the specific function we want to minimize for neural networks.
Okay, let's suppose we're trying to minimize some function, $C(v)$. This could be any real-valued function of many variables, $v = v_1, v_2, \ldots$. Note that I've replaced the $w$ and $b$ notation by $v$ to emphasize that this could be any function - we're not specifically thinking in the neural networks context any more. To minimize $C(v)$ it helps to imagine $C$ as a function of just two variables, which we'll call $v_1$ and $v_2$:
What we'd like is to find where $C$ achieves its global minimum. Now, of course, for the function plotted above, we can eyeball the graph and find the minimum. In that sense, I've perhaps shown slightly too simple a function! A general function, $C$, may be a complicated function of many variables, and it won't usually be possible to just eyeball the graph to find the minimum.
One way of attacking the problem is to use calculus to try to find the minimum analytically. We could compute derivatives and then try using them to find places where $C$ is an extremum. With some luck that might work when $C$ is a function of just one or a few variables. But it'll turn into a nightmare when we have many more variables. And for neural networks we'll often want far more variables - the biggest neural networks have cost functions which depend on billions of weights and biases in an extremely complicated way. Using calculus to minimize that just won't work!
(After asserting that we'll gain insight by imagining $C$ as a function of just two variables, I've turned around twice in two paragraphs and said, "hey, but what if it's a function of many more than two variables?" Sorry about that. Please believe me when I say that it really does help to imagine $C$ as a function of two variables. It just happens that sometimes that picture breaks down, and the last two paragraphs were dealing with such breakdowns. Good thinking about mathematics often involves juggling multiple intuitive pictures, learning when it's appropriate to use each picture, and when it's not.)
Okay, so calculus doesn't work. Fortunately, there is a beautiful analogy which suggests an algorithm which works pretty well. We start by thinking of our function as a kind of a valley. If you squint just a little at the plot above, that shouldn't be too hard. And we imagine a ball rolling down the slope of the valley. Our everyday experience tells us that the ball will eventually roll to the bottom of the valley. Perhaps we can use this idea as a way to find a minimum for the function? We'd randomly choose a starting point for an (imaginary) ball, and then simulate the motion of the ball as it rolled down to the bottom of the valley. We could do this simulation simply by computing derivatives (and perhaps some second derivatives) of $C$ - those derivatives would tell us everything we need to know about the local "shape" of the valley, and therefore how our ball should roll.
Based on what I've just written, you might suppose that we'll be trying to write down Newton's equations of motion for the ball, considering the effects of friction and gravity, and so on. Actually, we're not going to take the ball-rolling analogy quite that seriously - we're devising an algorithm to minimize $C$, not developing an accurate simulation of the laws of physics! The ball's-eye view is meant to stimulate our imagination, not constrain our thinking. So rather than get into all the messy details of physics, let's simply ask ourselves: if we were declared God for a day, and could make up our own laws of physics, dictating to the ball how it should roll, what law or laws of motion could we pick that would make it so the ball always rolled to the bottom of the valley?
To make this question more precise, let's think about what happens when we move the ball a small amount $\Delta v_1$ in the $v_1$ direction, and a small amount $\Delta v_2$ in the $v_2$ direction. Calculus tells us that $C$ changes as follows:

$$
\Delta C \approx \frac{\partial C}{\partial v_1} \Delta v_1 + \frac{\partial C}{\partial v_2} \Delta v_2 .
$$
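To fix notation for what follows, define the gradient of $C$ to be the vector of partial derivatives, and choose the move $\Delta v$ to be a small step against the gradient. This is the standard gradient descent set-up, with $\eta$ the learning rate referred to below:

$$
\nabla C \equiv \left( \frac{\partial C}{\partial v_1}, \frac{\partial C}{\partial v_2} \right)^T, \qquad
\Delta v = -\eta \nabla C, \qquad
v \rightarrow v' = v - \eta \nabla C,
$$

where $\eta$ is a small, positive parameter known as the learning rate. With this choice, $\Delta C \approx \nabla C \cdot \Delta v = -\eta \|\nabla C\|^2 \le 0$, so $C$ always decreases (to the accuracy of the approximation), which is exactly the behaviour we wanted from the ball.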
Summing up, the way the gradient descent algorithm works is to repeatedly compute the gradient $\nabla C$, and then to move in the opposite direction, "falling down" the slope of the valley. We can visualize it like this:
Notice that with this rule gradient descent doesn't reproduce real physical motion. In real life a ball has momentum, and that momentum may allow it to roll across the slope, or even (momentarily) roll uphill. It's only after the effects of friction set in that the ball is guaranteed to roll down into the valley. By contrast, our rule for choosing $\Delta v$ just says "go down, right now". That's still a pretty good rule for finding the minimum!
To make gradient descent work correctly, we need to choose the learning rate $\eta$ to be small enough that the approximation above remains good; if we don't, we might end up with $\Delta C > 0$, which obviously would not be good! At the same time, we don't want $\eta$ to be too small, since that will make the changes $\Delta v$ tiny, and the gradient descent algorithm will work very slowly.
I've explained gradient descent when $C$ is a function of just two variables. But, in fact, everything works just as well even when $C$ is a function of many more variables. Suppose in particular that $C$ is a function of $m$ variables, $v_1, \ldots, v_m$. Then the change $\Delta C$ in $C$ produced by a small change $\Delta v = (\Delta v_1, \ldots, \Delta v_m)^T$ is

$$
\Delta C \approx \nabla C \cdot \Delta v,
$$

where the gradient $\nabla C$ is now the $m$-dimensional vector of partial derivatives, $\nabla C \equiv \left( \frac{\partial C}{\partial v_1}, \ldots, \frac{\partial C}{\partial v_m} \right)^T$. Just as in the two-variable case, we can choose $\Delta v = -\eta \nabla C$, and repeatedly apply the update rule $v \rightarrow v' = v - \eta \nabla C$ to follow the gradient down to a minimum.
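Here's a minimal numpy sketch of the many-variable update rule $v \rightarrow v - \eta \nabla C$, applied to a toy quadratic cost whose minimum is at the origin (the cost and its gradient are invented purely for illustration):

```python
import numpy as np

def C(v):
    """Toy cost function: a simple bowl with its minimum at v = 0."""
    return 0.5 * np.sum(v ** 2)

def grad_C(v):
    """Gradient of the toy cost function above."""
    return v

eta = 0.1                       # learning rate
v = np.array([2.0, -3.0, 1.0])  # starting point
for step in range(100):
    v = v - eta * grad_C(v)     # the gradient descent update rule

print(v, C(v))  # v is now very close to the minimum at the origin
```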
Indeed, there's even a sense in which gradient descent is the optimal strategy for searching for a minimum. Let's suppose that we're trying to make a move $\Delta v$ in position so as to decrease $C$ as much as possible. This is equivalent to minimizing $\Delta C \approx \nabla C \cdot \Delta v$. We'll constrain the size of the move so that $\| \Delta v \| = \epsilon$ for some small fixed $\epsilon > 0$. In other words, we want a move that is a small step of a fixed size, and we're trying to find the movement direction which decreases $C$ as much as possible. It can be proved that the choice of $\Delta v$ which minimizes $\nabla C \cdot \Delta v$ is $\Delta v = -\eta \nabla C$, where $\eta = \epsilon / \| \nabla C \|$ is determined by the size constraint $\| \Delta v \| = \epsilon$. So gradient descent can be viewed as a way of taking small steps in the direction which does the most to immediately decrease $C$.
People have investigated many variations of gradient descent, including variations that more closely mimic a real physical ball. These ball-mimicking variations have some advantages, but also have a major disadvantage: it turns out to be necessary to compute second partial derivatives of $C$, and this can be quite costly. To see why it's costly, suppose we want to compute all the second partial derivatives $\partial^2 C / \partial v_j \partial v_k$. If there are a million such $v_j$ variables then we'd need to compute something like a trillion (i.e., a million squared) second partial derivatives* *Actually, more like half a trillion, since $\partial^2 C / \partial v_j \partial v_k = \partial^2 C / \partial v_k \partial v_j$. Still, you get the point.! That's going to be computationally costly. With that said, there are tricks for avoiding this kind of problem, and finding alternatives to gradient descent is an active area of investigation. But in this book we'll use gradient descent (and variations) as our main approach to learning in neural networks.
How can we apply gradient descent to learn in a neural network? The idea is to use gradient descent to find the weights $w_k$ and biases $b_l$ which minimize the quadratic cost function defined above. In other words, our "position" now has components $w_k$ and $b_l$, and the gradient vector $\nabla C$ has corresponding components $\partial C / \partial w_k$ and $\partial C / \partial b_l$.
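Written out componentwise, the gradient descent update rule for the network's weights and biases takes the standard form (stated here for reference; it follows directly from the rule $v \rightarrow v - \eta \nabla C$ above):

$$
w_k \rightarrow w_k' = w_k - \eta \frac{\partial C}{\partial w_k}, \qquad
b_l \rightarrow b_l' = b_l - \eta \frac{\partial C}{\partial b_l}.
$$

By repeatedly applying this update rule we can "roll down the hill", and hopefully find a minimum of the cost function.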
There are a number of challenges in applying the gradient descent rule. We'll look into those in depth in later chapters. But for now I just want to mention one problem. To understand what the problem is, let's look back at the quadratic cost defined above. Notice that it has the form $C = \frac{1}{n} \sum_x C_x$, that is, it's an average over costs $C_x \equiv \frac{\| y(x) - a \|^2}{2}$ for individual training inputs. In practice, to compute the gradient $\nabla C$ we need to compute the gradients $\nabla C_x$ separately for each training input, $x$, and then average them. Unfortunately, when the number of training inputs is very large this can take a long time, and learning thus occurs slowly.
An idea called stochastic gradient descent can be used to speed up learning. The idea is to estimate the gradient $\nabla C$ by computing $\nabla C_x$ for a small sample of randomly chosen training inputs. By averaging over this small sample it turns out that we can quickly get a good estimate of the true gradient $\nabla C$, and this helps speed up gradient descent, and thus learning.
To make these ideas more precise, stochastic gradient descent works by randomly picking out a small number $m$ of randomly chosen training inputs. We'll label those random training inputs $X_1, X_2, \ldots, X_m$, and refer to them as a mini-batch. Provided the sample size $m$ is large enough we expect that the average value of the $\nabla C_{X_j}$ will be roughly equal to the average over all $\nabla C_x$, that is,

$$
\frac{1}{m} \sum_{j=1}^{m} \nabla C_{X_j} \approx \frac{1}{n} \sum_x \nabla C_x = \nabla C,
$$

where the second sum is over the entire set of training data.
To connect this explicitly to learning in neural networks, suppose $w_k$ and $b_l$ denote the weights and biases in our neural network. Then stochastic gradient descent works by picking out a randomly chosen mini-batch of training inputs, and training with those,

$$
w_k \rightarrow w_k' = w_k - \frac{\eta}{m} \sum_j \frac{\partial C_{X_j}}{\partial w_k}, \qquad
b_l \rightarrow b_l' = b_l - \frac{\eta}{m} \sum_j \frac{\partial C_{X_j}}{\partial b_l},
$$

where the sums are over all the training examples $X_j$ in the current mini-batch. Then we pick out another randomly chosen mini-batch and train with those, and so on, until we've exhausted the training inputs, which is said to complete an epoch of training. At that point we start over with a new training epoch.
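A compact Python sketch of one epoch of this procedure, with the per-example gradients supplied by a hypothetical `grad_Cx(x, y, w, b)` function (in the actual program developed later, that role is played by backpropagation):

```python
import random

def sgd_epoch(training_data, m, eta, w, b, grad_Cx):
    """Run one epoch of stochastic gradient descent.

    training_data: list of (x, y) pairs; m: mini-batch size; eta: learning rate;
    w, b: lists of numpy arrays holding the weights and biases;
    grad_Cx: function returning (dC_dw, dC_db) for a single training example."""
    random.shuffle(training_data)
    mini_batches = [training_data[k:k + m] for k in range(0, len(training_data), m)]
    for batch in mini_batches:
        grads_w = [0.0 * wk for wk in w]
        grads_b = [0.0 * bl for bl in b]
        for x, y in batch:
            dC_dw, dC_db = grad_Cx(x, y, w, b)
            grads_w = [gw + dw for gw, dw in zip(grads_w, dC_dw)]
            grads_b = [gb + db for gb, db in zip(grads_b, dC_db)]
        # Update rule: w_k -> w_k - (eta / m) * sum_j dC_Xj/dw_k, similarly for b.
        w = [wk - (eta / len(batch)) * gw for wk, gw in zip(w, grads_w)]
        b = [bl - (eta / len(batch)) * gb for bl, gb in zip(b, grads_b)]
    return w, b
```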
Incidentally, it's worth noting that conventions vary about scaling of
the cost function and of mini-batch updates to the weights and biases.
We can think of stochastic gradient descent as being like political polling: it's much easier to sample a small mini-batch than it is to apply gradient descent to the full batch, just as carrying out a poll is easier than running a full election. For example, if we have a training set of size $n = 60{,}000$, as in MNIST, and choose a mini-batch size of (say) $m = 10$, this means we'll get a factor of $6{,}000$ speedup in estimating the gradient! Of course, the estimate won't be perfect - there will be statistical fluctuations - but it doesn't need to be perfect: all we really care about is moving in a general direction that will help decrease $C$, and that means we don't need an exact computation of the gradient. In practice, stochastic gradient descent is a commonly used and powerful technique for learning in neural networks, and it's the basis for most of the learning techniques we'll develop in this book.