31 July 2019 / By Abdulkadir Pir
Deep Down in CNN: Does it Sound too Complicated? Let’s Observe Weights with a Heatmap!
We all use deep learning tools to build applications for our projects. Building a computer vision application is much easier today than it used to be, but these tools can also keep us from learning how the underlying algorithms actually work.
We believe the best way to understand how something works is to observe it while it works!
So we will experiment with how a CNN works and how it makes a classification, using a heatmap. This will give us a better understanding of the convolutions used in neural networks and of the importance of the weights in our algorithms.
We will follow the instructions in the Fastai DL-1 Course in this example. A big thanks to them for making neural nets uncool again!
Contents of the Post:
- Choosing a Dataset
- Exploratory Data Analysis and Data Augmentation
- Creating a Model and Learner
- Analyzing the Result Using a Heatmap
1. Choosing a Dataset
When experimenting with new things, the best approach is to use a pre-cleaned dataset, because we don't want to lose our motivation while cleaning and preparing the data for the model. The best place to find good-quality datasets for trying new things is Kaggle!
In this example, I used an image classification dataset named Plant Seedlings Classification. I chose it because it is a well-balanced classification problem and the aim of the competition suits my algorithm.
In this competition, we are asked to build a model that classifies images of plants belonging to 12 different species. The dataset contains images of approximately 960 unique plants, photographed at several growth stages.
Another reason I chose this dataset is that categorizing plant types from these images seemed hard to me. The seedlings look very similar, and since they all share the same background (soil), I thought the algorithm would struggle to identify the plant categories. I really wondered how convolution would work on this specific problem.
2. Exploratory Data Analysis and Data Augmentation
Before deciding on the library, architecture, and other project requirements, the most important step is to explore the dataset. Although Kaggle competition datasets are usually well prepared, as a data scientist you always need to understand the problem before creating the solution.
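from fastai.vision import *

path = Path('data/plant-seedlings/train')  # hypothetical location: one subfolder per species

# Augmentation: flips (vertical too), slight lighting and zoom changes, warping disabled
tfms = get_transforms(flip_vert=True, max_lighting=0.1, max_zoom=1.05, max_warp=0.)
data = ImageDataBunch.from_folder(
    path = path,
    valid_pct = 0.2,   # hold out 20% of the images for validation
    bs = 16,
    size = 224,
    ds_tfms = tfms,
).normalize(imagenet_stats)  # the pretrained ResNet expects ImageNet-normalized inputs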
After exploring the images in the dataset, I prepared the data for training. Here is a brief explanation of what I did in this code block:
- First, I decided on the data augmentation settings. I enabled vertical flips (in addition to fastai's default horizontal flips) and turned warping off (max_warp=0.). Since some images are darker and contain tiny plants, I also used a slight lighting change and zoom.
- I set aside 20% of the dataset for validation so I could monitor training.
- Since I used transfer learning for this task, I normalized the images with the ImageNet statistics and resized them to 224×224 pixels, the input size commonly used for ImageNet models.
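As a sketch of what normalizing with ImageNet statistics means, assuming fastai v1's imagenet_stats values (per-channel RGB mean [0.485, 0.456, 0.406] and standard deviation [0.229, 0.224, 0.225] for pixels scaled to [0, 1]): each channel is shifted by its mean and divided by its standard deviation. The NumPy version below works on height × width × channel arrays for readability, while fastai applies the same arithmetic to channel-first tensors; the function names are illustrative, not fastai API.

```python
import numpy as np

# Assumed ImageNet channel statistics (RGB, pixel values scaled to [0, 1]).
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])

def normalize_imagenet(img):
    """Normalize an (H, W, 3) float image in [0, 1], channel by channel."""
    return (img - IMAGENET_MEAN) / IMAGENET_STD

def denormalize_imagenet(x):
    """Undo normalize_imagenet, e.g. before displaying an image."""
    return x * IMAGENET_STD + IMAGENET_MEAN

# A 224 x 224 image filled with the ImageNet mean colour normalizes to zeros.
img = np.ones((224, 224, 3)) * IMAGENET_MEAN
x = normalize_imagenet(img)
print(x.shape, np.abs(x).max())
```

Because the pretrained weights were learned on inputs normalized this way, feeding the seedling images through the same transformation keeps them on the scale the network expects.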
3. Creating a Model and Learner
fastai is a great deep learning library that is easy to use and understand. In this experiment, I used a CNN learner to train my model, which the library provides through the cnn_learner() function. But we don't want to be tool users who just copy and paste code; we want to understand what is going on under the hood. The library has another great function, doc(), so let's use it to understand cnn_learner():
doc(cnn_learner)
Result:
cnn_learner(data:DataBunch, base_arch:Callable, cut:Union[int, Callable]=None, pretrained:bool=True, lin_ftrs:Optional[Collection[int]]=None, ps:Floats=0.5, custom_head:Optional[Module]=None, split_on:Union[Callable, Collection[ModuleList], NoneType]=None, bn_final:bool=False, init='kaiming_normal_', concat_pool:bool=True, **kwargs:Any)
Now we can investigate every parameter of this function!
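One parameter worth a closer look is concat_pool=True. As far as I understand fastai v1, the learner cuts off ResNet-34's own pooling and classification layers, and with concat_pool the new head pools the body's final feature map twice, once with global max pooling and once with global average pooling, and concatenates the results. For a 224×224 input the ResNet-34 body outputs 512 channels on a 7×7 grid, so the head receives 1,024 features; with the default lin_ftrs=None, these then pass through a hidden layer of 512 units to one output per class (12 here). A NumPy sketch of the pooling step, with a hypothetical function name and random numbers standing in for real features:

```python
import numpy as np

def concat_pool(feature_map):
    """Global max pool and global average pool over the spatial axes of a
    (channels, height, width) array, concatenated into one 2 * channels vector."""
    max_pooled = feature_map.max(axis=(1, 2))
    avg_pooled = feature_map.mean(axis=(1, 2))
    return np.concatenate([max_pooled, avg_pooled])

# Random numbers standing in for the ResNet-34 body output on one 224x224 image.
features = np.random.default_rng(0).standard_normal((512, 7, 7))
pooled = concat_pool(features)
print(pooled.shape)  # (1024,)
```

Keeping both the strongest response (max) and the overall level (average) of each channel gives the head a little more information than the single global average pooling used by the original ResNet.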
After that, we can create our learner. In this problem, I am not focused on accuracy: I want a good result, but I won't spend much time on fine-tuning and other steps. I chose a pretrained model, ResNet-34, which has 34 layers and was trained on the ImageNet dataset. Since my task involves natural images of plants, I believed its pretrained features would be good enough.
Let’s create the learner:
learn = cnn_learner(data, models.resnet34, metrics=accuracy)
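Note that metrics=accuracy only controls what is reported on the validation set after each epoch; the model itself is trained by minimizing a loss (cross-entropy, fastai's default for single-label classification). A minimal NumPy sketch of what an accuracy metric computes, using made-up scores for four images and three classes:

```python
import numpy as np

def accuracy(logits, targets):
    """Fraction of samples whose highest-scoring class equals the target label.
    logits: (n_samples, n_classes) model outputs; targets: (n_samples,) int labels."""
    predictions = logits.argmax(axis=1)
    return (predictions == targets).mean()

logits = np.array([[2.0, 0.1, 0.3],
                   [0.2, 1.5, 0.1],
                   [0.9, 0.8, 0.1],   # predicted class 0, but the label is 1
                   [0.1, 0.2, 3.0]])
targets = np.array([0, 1, 1, 2])
print(accuracy(logits, targets))  # 0.75
```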
We have prepared our model and learner. Now it is time to choose the learning rate, and fastai has another great function for that:
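The post doesn't show the call itself; in fastai v1 the learning rate finder is run with learn.lr_find() and its loss-versus-learning-rate curve is drawn with learn.recorder.plot(). Roughly, it trains for a short while (100 iterations by default) while increasing the learning rate exponentially from 1e-7 to 10, recording the loss, and stops early if the loss blows up. The NumPy sketch below reproduces only the schedule and one common way to read the plot, picking the learning rate where the loss falls most steeply; the function names and the toy loss values are illustrative assumptions.

```python
import numpy as np

def lr_range(start_lr=1e-7, end_lr=10.0, num_it=100):
    """Learning rates for an LR range test, spaced evenly on a log scale."""
    steps = np.arange(num_it) / (num_it - 1)  # 0.0 ... 1.0
    return start_lr * (end_lr / start_lr) ** steps

def steepest_drop(lrs, losses):
    """Learning rate where the loss decreases fastest with respect to log10(lr)."""
    slopes = np.gradient(np.asarray(losses, dtype=float), np.log10(lrs))
    return lrs[np.argmin(slopes)]

# Toy example: 7 learning rates and made-up losses that fall, flatten, then blow up.
toy_lrs = lr_range(num_it=7)
toy_losses = [1.0, 0.99, 0.9, 0.5, 0.45, 0.44, 2.0]
print(steepest_drop(toy_lrs, toy_losses))
```

The steepest-slope point is only a heuristic; many practitioners instead pick a value somewhat below the point where the loss is lowest.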
So let’s train this model:
lr = 0.01
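The post doesn't show the training call that uses lr. Assuming it is fastai v1's learn.fit_one_cycle(n_epochs, lr) as in the course (the number of epochs is not given in the post), lr acts as the maximum learning rate of a one-cycle schedule: with the default settings the rate warms up from lr / 25 to lr over the first 30% of iterations and then anneals to a tiny value, both along cosine curves (fastai also cycles momentum between 0.95 and 0.85, omitted here). A NumPy sketch of that shape, written from those defaults rather than taken from fastai's code:

```python
import numpy as np

def annealing_cos(start, end, pct):
    """Cosine interpolation: returns start at pct=0 and end at pct=1."""
    return end + (start - end) / 2 * (1 + np.cos(np.pi * pct))

def one_cycle_lrs(lr_max, n_iter, div_factor=25.0, pct_start=0.3, final_div=None):
    """Per-iteration learning rates for a one-cycle schedule (sketch)."""
    if final_div is None:
        final_div = div_factor * 1e4
    n_warm = int(n_iter * pct_start)   # iterations spent warming up
    n_down = n_iter - n_warm           # iterations spent annealing down
    warm = [annealing_cos(lr_max / div_factor, lr_max, i / n_warm) for i in range(n_warm)]
    down = [annealing_cos(lr_max, lr_max / final_div, i / n_down) for i in range(n_down)]
    return np.array(warm + down)

lrs = one_cycle_lrs(0.01, n_iter=100)
print(lrs[0], lrs.max(), lrs[-1])
```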
We got 93% accuracy, which is enough for me. If I were participating in this competition, I would use a more complex architecture with fine-tuning and also work on data preprocessing, but since we are focused on understanding what is going on, this result is more than enough.
4. Analyzing the Result Using a Heatmap
To prepare the heatmap, we used another great fastai feature: its callbacks and hooks modules. With them, we can look inside the model while it runs, capturing intermediate activations in the forward pass and gradients during backpropagation.
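The post doesn't show how the heatmap hm itself is computed. With fastai v1, the hooks part would typically use hook_output from fastai.callbacks.hooks on the model's convolutional body, once to store its activations and once with grad=True to store the gradients of the predicted class score during backpropagation. Given those two arrays, each shaped channels × height × width, the sketch below shows two common ways to collapse them into a single map: averaging the activations over channels, or Grad-CAM (Selvaraju et al.), which weights each channel by its average gradient and keeps only the positive part. Which of these the author used is an assumption; the function names are hypothetical.

```python
import numpy as np

def mean_activation_map(activations):
    """Average the (channels, height, width) activations over channels."""
    return activations.mean(axis=0)

def grad_cam(activations, gradients):
    """Grad-CAM: weight each activation channel by the spatial mean of its gradient
    (gradient of the chosen class score), sum over channels, keep positive values."""
    channel_weights = gradients.mean(axis=(1, 2))             # shape (channels,)
    cam = np.tensordot(channel_weights, activations, axes=1)  # shape (height, width)
    return np.maximum(cam, 0)

# Random stand-ins for hooked activations and gradients of a 512 x 7 x 7 layer.
rng = np.random.default_rng(0)
acts = rng.standard_normal((512, 7, 7))
grads = rng.standard_normal((512, 7, 7))
hm = grad_cam(acts, grads)
print(hm.shape, hm.min() >= 0)
```

The gradient weighting is what makes the map class-specific: channels that push the chosen species' score up count positively, while channels that argue against it are suppressed.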
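def show_heatmap(hm):  # function name assumed; hm is the 2-D heatmap array
    _,ax = plt.subplots()
    xb_im.show(ax)  # draw the input image first (xb_im: hypothetical name, defined elsewhere)
    ax.imshow(hm, alpha=0.6, extent=(0,224,224,0), cmap='magma')  # overlay stretched to 224x224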
I created the heatmap function for our 224×224-pixel images, overlaying the heatmap with an alpha (transparency) of 0.6 and the magma colormap. Let's try it on an image:
The result is better than I expected. The soil in the upper left of the image has almost no effect on the result, which is great because there is no part of the plant in that area. We can also clearly see that the center of the plant has the biggest effect on the result, followed by the outer parts of the plant located in the upper right and lower left of the image.
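Why are the highlighted regions so coarse? Assuming the heatmap is taken from the output of the ResNet-34 body, it has only 7×7 cells for a 224×224 image, so each cell is drawn over a 32×32-pixel patch (224 / 7 = 32), although the cell's receptive field in the original image is larger than that. The extent=(0,224,224,0) argument in the plotting code is what stretches the small grid over the whole image. A nearest-neighbour NumPy version of that stretch (function name hypothetical):

```python
import numpy as np

def upsample_nearest(heatmap, size=224):
    """Stretch an (h, w) heatmap to (size, size) by repeating each cell in a block."""
    h, w = heatmap.shape
    assert size % h == 0 and size % w == 0, "size must be a multiple of the grid size"
    return heatmap.repeat(size // h, axis=0).repeat(size // w, axis=1)

hm = np.arange(49, dtype=float).reshape(7, 7)  # toy 7x7 heatmap
big = upsample_nearest(hm)
print(big.shape)  # (224, 224)
```

Smoother interpolation (for example bilinear) gives the softer blobs usually seen in such heatmaps, but the spatial detail is still limited by the 7×7 grid.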
Heatmaps are great visualization tools for understanding what is going on inside the convolutional layers, and we can apply them in other research areas too.
This post aimed to explain the weights in convolutions in an enjoyable, hands-on way. I will keep writing about other examples, so please share your thoughts with me!