A Gentle Intro to Transfer Learning

Slav Ivanov
Nov 16, 2017

Nowadays most applications of Deep Learning rely on Transfer Learning. This is especially true in the domain of Computer Vision. We will explore what Transfer Learning is, how to do it and what the potential pitfalls are. To do this, we’ll go on a little startup quest.

The Startup

On an otherwise uneventful day, a friend approaches us with a revolutionary idea. He wants to build the next big thing in social media: Facebook for pets. People would upload pictures of their pets there and enjoy other people’s pet pictures. Naturally, we jump on board.

Who wouldn’t want to look at dog pictures all day long?

We quickly raise funding and recruit people to help us build it. After a round of customer interviews, we realize that cat people don’t really enjoy looking at dog photos. It also turns out that parrot owners plainly hate cats, and people who have a goldfish only like to look at other fish.

Since this is 2017, we don’t want to make people tell us manually what is in the picture they are uploading. Instead, we will use the power of Machine Learning to recognize the animal in each image, so that we can group the pictures into separate sections of our app.

The Convolutional Neural Network

We talk about our project to a friend who recently got his PhD in Machine Learning (ML). He advises us to use Convolutional Neural Networks (CNNs), a particular ML architecture pioneered by Yann LeCun, who is currently head of AI at Facebook. He explains that Alex Krizhevsky et al. popularized CNNs in 2012, when their network crushed the competition in the high-profile “ImageNet Large Scale Visual Recognition Challenge” (ILSVRC). Ever since then, CNNs have dominated the field of Computer Vision, and they are also making forays into related Deep Learning domains.

For more info on Deep Learning in general, I highly recommend Parts 1 and 2 of the excellent “Practical Deep Learning for Coders” course, which is available for free from Fast.AI.

Convolutional Neural Networks have an input, which is the image we are trying to classify. Then there are the learnable weights, organized in groups called filters or kernels. The filters are arranged in layers stacked on top of each other. At the end of the network, we have our output. It tells us whether the image we fed in is a cat, dog, bird, etc.

We take the Fast.AI course and teach ourselves to build CNNs. We know that deep learning needs a lot of data, so we download the ImageNet dataset. It contains about a million images of, among other things, cats, dogs, and goldfish.

We start training the model on this data, optimizing it with Stochastic Gradient Descent (SGD).
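At its core, each SGD step nudges every weight a little against its gradient: w ← w − lr · ∂L/∂w. The step size lr is called the Learning Rate. Below is a minimal pure-Python sketch on a made-up one-weight loss, L(w) = (w − 3)², rather than a real CNN. In real SGD the "stochastic" part comes from estimating the gradient on a random mini-batch of images. The toy gradient here is exact, so strictly speaking this is plain gradient descent.

```python
def loss_grad(w):
    # Gradient of the toy loss L(w) = (w - 3) ** 2, which is smallest at w = 3.
    return 2 * (w - 3)


def gradient_descent(w, lr, steps):
    """Repeatedly apply the update w <- w - lr * dL/dw."""
    for _ in range(steps):
        w = w - lr * loss_grad(w)
    return w


# Start far from the optimum and compare two learning rates.
for lr in (0.1, 0.001):
    final_w = gradient_descent(0.0, lr=lr, steps=20)
    print(f"lr={lr}: w after 20 steps = {final_w:.3f}")
```

With the larger Learning Rate, w ends up close to 3 within 20 steps. With the tiny one, it barely moves.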

Sure enough, our CNN soon learns to tell cats, dogs, parrots, and more apart. We call it the Pet-cognizer®.

In reality, people have already done the hard work of training many CNN models on ImageNet, and such pretrained networks are distributed freely. See, for example, the pretrained models in Keras and in TorchVision for PyTorch.

The filters in the first layer of the CNN have learned to recognize edges, gradients, and solid color areas. The second layer group combines the outputs of the first layer to detect stripes, right angles, circles, and more.
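To make "a filter that recognizes edges" concrete, here is a minimal pure-Python sketch of what a single filter does. It slides over the image and computes a weighted sum at each position. Mathematically this is a 2D cross-correlation, which deep learning libraries call convolution. The 3×3 weights are a hand-picked, Sobel-style vertical-edge detector, used here purely as an illustrative assumption. A real CNN learns its own filter weights during training.

```python
# Hand-picked vertical-edge filter (illustrative; a CNN would learn its weights).
EDGE_FILTER = [
    [-1, 0, 1],
    [-2, 0, 2],
    [-1, 0, 1],
]


def conv2d(image, kernel):
    """Valid (no padding, stride 1) 2D cross-correlation, as in a CNN layer."""
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    out = []
    for i in range(out_h):
        row = []
        for j in range(out_w):
            # Weighted sum of the image patch under the filter.
            total = 0
            for di in range(kh):
                for dj in range(kw):
                    total += image[i + di][j + dj] * kernel[di][dj]
            row.append(total)
        out.append(row)
    return out


# A tiny 5x6 grayscale "image": dark on the left, bright on the right.
image = [[0, 0, 0, 9, 9, 9] for _ in range(5)]
response = conv2d(image, EDGE_FILTER)
for row in response:
    print(row)
```

The response is large only in the columns where dark meets bright, and it is zero over the flat regions. That is what "this filter detects vertical edges" means.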

By the time we reach the third layer group, the network is already recognizing people, birds, and car wheels. This process continues for many layers.

Finally, the output layer of the network tells us the probability of each class for this image.
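The output layer typically produces one raw score (a "logit") per class. A softmax function then turns those scores into probabilities that are positive and sum to 1. Here is a minimal pure-Python sketch using made-up scores for four of our pet classes.

```python
import math


def softmax(logits):
    """Turn raw output scores into probabilities that sum to 1."""
    m = max(logits)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


classes = ["cat", "dog", "parrot", "goldfish"]
logits = [1.0, 3.0, 0.5, -1.0]  # made-up scores for illustration
probs = softmax(logits)
for name, p in zip(classes, probs):
    print(f"{name}: {p:.3f}")
```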

Our Pet-cognizer® works!

The New Task

Just as we are ready to launch our project, someone suggests that it would be great if we could further split the dog section into breeds. We oppose the idea, but the team strong-arms us into it. Huffing and puffing, we go back to the model. We call it the Breed-cognizer® and get to work.

The problem is that we don’t have a million pictures of dogs to train our new model. Our dataset has 120 breeds with about 100 images for each. We try to train our CNN from scratch with this tiny dataset, but the results are not great.

Results from training from scratch. All of these are incorrect

The Transfer Learning

We remember a discussion about Transfer Learning on Reddit.

The idea is to take the knowledge learned in a model and apply it to another task.

Transfer learning sounds like what we want to do. We decide to reuse the already trained Pet-cognizer®.

Scenario 1: New dataset is similar to initial dataset

Our new breed dataset is close to the ImageNet dataset we first trained on. It’s similar in the sense that they both contain pictures of the “real world” (as opposed to images of documents or medical scans). Thus the filters in the CNN can be reused, and we don’t have to learn them again.

For the Breed-cognizer®, we swap only the last layer. Instead of telling us what animal is in the picture, it will give us probabilities for each dog breed.

Because the weights of the new last layer are initially random, we have to train them. But training these random weights might also change the great filters in the earlier layers. To avoid this, we freeze all layers except the last one. Freezing means that the layer’s weights won’t be updated during training.

Rather miraculously, this works to an extent: we achieve 80% accuracy across 120 classes just by training the last layer.

Results from training the last layer only

Fine-tuning

But some of what the later layers have learned might not be useful for the current task. For example, the model surely has weights that help distinguish between a cat and a goldfish, which are probably irrelevant to dog breed classification. To improve our network further, we unfreeze the last few layers of the CNN and retrain them.

How many layers should we unfreeze? It depends, so we start with one and add more until there is no significant improvement.

Because these layers still contain useful knowledge, we use only a very small Learning Rate for the unfrozen filters. This helps preserve what is useful while changing what is not needed.

The Learning Rate controls how much we update weights on each Gradient Descent iteration.

Fine-tuning brings our Breed-cognizer® accuracy to 90% (without data augmentation).

Fine-tuned network results

Scenario 2: New dataset is not similar to initial dataset

Word of our awesome Computer Vision abilities spreads and people start contacting us for all kinds of projects. One of them involves satellite imagery and another is for a medical startup.

Satellite imagery. Source: Planet: Understanding the Amazon from Space

Neither of these is similar to the dataset we used to train our Pet-cognizer®. This is especially true of medical images, such as those obtained via CT scans.

CT scan of a lung. Source: Data Science Bowl 2017

In such cases, it is still a good idea to start with pretrained weights, but we unfreeze all layers and train on the new dataset with a normal Learning Rate. Stanford’s CS231n also has a good discussion of this.

Transfer Learning is used on almost all Computer Vision tasks nowadays. It’s rare to train from scratch unless you have a massive dataset. This primer should have given you some intuition on how and why it works.

Did I miss anything? Is anything wrong? Let me know by leaving a reply below.
