Multiclass Image Classification Using Transfer Learning

Introduction

Image classification is one of the most common deep learning tasks on image data. Interest in it has grown in the research community with the development of new, high-performing machine learning frameworks. Classification can be binary, where the images belong to one of two classes, or multiclass, where there are more than two classes. In this article, we explore transfer learning for multiclass image classification.

Multiclass Image Classification

Advances in artificial neural networks, and in particular the development of Convolutional Neural Networks (CNNs), have made complex operations on images much easier and have driven progress in tasks such as multiclass image classification, image segmentation, and object detection.

Multiclass image classification is one of the most basic yet powerful computer vision tasks that can be performed with CNNs. In this setting, the dataset contains more than two classes of images, and each image is labeled with its category (e.g., CIFAR, Fashion MNIST).

For classification, we can either prepare our own labeled dataset or download an existing image dataset such as CIFAR-10. Preprocessing may include tasks such as image augmentation, which adds variation to the dataset when the number of images per class is low.
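To make the idea of augmentation concrete, here is a minimal NumPy sketch of one common transformation, a random horizontal flip. The function name `random_horizontal_flip`, the flip probability, and the random stand-in batch are assumptions made for illustration; they are not part of the article's Keras pipeline.

```python
import numpy as np


def random_horizontal_flip(images, rng, p=0.5):
    """Flip each image left-to-right with probability p.

    images: array of shape (N, H, W, C).
    Returns a new array; the input is not modified.
    """
    images = np.asarray(images)
    # One independent coin toss per image in the batch.
    flip_mask = rng.random(len(images)) < p
    out = images.copy()
    # Reversing the width axis mirrors the selected images.
    out[flip_mask] = out[flip_mask, :, ::-1, :]
    return out


rng = np.random.default_rng(0)
# Stand-in batch with the same shape as CIFAR-10 images (random pixels).
batch = rng.integers(0, 256, size=(8, 32, 32, 3), dtype=np.uint8)
augmented = random_horizontal_flip(batch, rng)
print(augmented.shape)
```

Because each call draws new random flips, the model can see a slightly different version of the same image in every epoch, which is the source of the added variation.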

To train the model, we can either build an architecture from scratch using a deep learning framework such as TensorFlow or PyTorch, or use a readily available backbone architecture such as VGG16 or ResNet. The advantage of the latter is that we do not have to build the architecture from scratch; we only need to fine-tune the model or change the last one or two layers to suit our use case. This is where transfer learning, a very intuitive technique for training image models, comes into play.

What is Transfer Learning and why is it important?

Transfer learning is an area of machine learning concerned with storing the knowledge gained while solving one problem and applying it to a different but related problem. For example, the knowledge gained while learning to recognize cats could apply when trying to recognize cheetahs. In deep learning, transfer learning is a technique in which a neural network is first trained on a problem similar to the one being solved. It can reduce training time and can result in lower generalization error. We can take a model pretrained on a large dataset such as ImageNet and modify its last layers to suit our task, which saves the time, effort, and resources needed to train a model from scratch. Such pretrained models have already learned a large number of image patterns through extensive training.
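As a rough illustration of "modifying the last layers", the following Keras sketch keeps a pretrained VGG19 convolutional base fixed and trains only a small new classifier on top of it. Freezing the base, the size of the hidden layer (256), the Adam learning rate, and the function name `build_transfer_model` are assumptions chosen for this sketch, not settings taken from this article.

```python
from keras import Sequential
from keras.applications import VGG19
from keras.layers import Dense, Flatten
from keras.optimizers import Adam


def build_transfer_model(num_classes=10, weights="imagenet"):
    """Frozen VGG19 feature extractor with a small trainable classifier."""
    # include_top=False drops the original 1000-class ImageNet classifier.
    base = VGG19(include_top=False, weights=weights, input_shape=(32, 32, 3))
    # Assumption: keep the pretrained convolutional filters fixed.
    base.trainable = False

    model = Sequential([
        base,
        Flatten(),
        Dense(256, activation="relu"),             # assumed head size
        Dense(num_classes, activation="softmax"),  # one probability per class
    ])
    model.compile(
        optimizer=Adam(learning_rate=1e-3),  # assumed learning rate
        loss="categorical_crossentropy",     # expects one-hot encoded labels
        metrics=["accuracy"],
    )
    return model
```

A model built this way can then be trained with `model.fit` on one-hot encoded labels. With the base frozen, only the two dense layers are updated, which is what makes training faster than starting from scratch.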

Code Implementation

In this example, we use the CIFAR-10 dataset for multiclass classification. We also use the VGG19 network and modify it for transfer learning.

Dataset used

The dataset is CIFAR-10, which is available from the Canadian Institute for Advanced Research (CIFAR). It consists of 60000 32×32 color images in 10 classes, with 6000 images per class. The 10 classes are airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks. There are 50000 training images and 10000 test images. The dataset can be imported from Keras.
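The labels in CIFAR-10 are integers from 0 to 9, so a small lookup table is handy when inspecting predictions. The sketch below uses the standard CIFAR-10 class order; the helper name `label_to_name` and the constant names are hypothetical and introduced only for illustration.

```python
# Standard CIFAR-10 class order: index i is the name for integer label i.
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]

# Dataset sizes as described above.
IMAGES_PER_CLASS = 6000
TRAIN_IMAGES = 50000
TEST_IMAGES = 10000


def label_to_name(label):
    """Map an integer class label (0-9) to its class name."""
    return CIFAR10_CLASSES[int(label)]


print(label_to_name(3))
# The per-class count and the train/test split describe the same total.
print(IMAGES_PER_CLASS * len(CIFAR10_CLASSES) == TRAIN_IMAGES + TEST_IMAGES)
```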

Implementation using Keras API

Example

<div class="code-mirror  language-python" contenteditable="plaintext-only" spellcheck="false" style="outline: none; overflow-wrap: break-word; overflow-y: auto; white-space: pre-wrap;"><span class="token keyword">import</span> numpy <span class="token keyword">as</span> np
<span class="token keyword">import</span> pandas <span class="token keyword">as</span> pd
<span class="token keyword">from</span> sklearn<span class="token punctuation">.</span>utils<span class="token punctuation">.</span>multiclass <span class="token keyword">import</span> unique_labels
<span class="token keyword">import</span> os
<span class="token keyword">import</span> matplotlib<span class="token punctuation">.</span>pyplot <span class="token keyword">as</span> plt
<span class="token keyword">import</span> matplotlib<span class="token punctuation">.</span>image <span class="token keyword">as</span> mpimg
<span class="token keyword">import</span> seaborn <span class="token keyword">as</span> sns
<span class="token keyword">import</span> itertools
<span class="token keyword">from</span> keras<span class="token punctuation">.</span>datasets <span class="token keyword">import</span> cifar10
<span class="token keyword">from</span> sklearn<span class="token punctuation">.</span>model_selection <span class="token keyword">import</span> train_test_split
<span class="token keyword">from</span> sklearn<span class="token punctuation">.</span>metrics <span class="token keyword">import</span> confusion_matrix
<span class="token keyword">from</span> keras <span class="token keyword">import</span> Sequential
<span class="token keyword">from</span> keras<span class="token punctuation">.</span>applications <span class="token keyword">import</span> VGG19 <span class="token comment">#For Transfer Learning</span>
<span class="token keyword">from</span> keras<span class="token punctuation">.</span>preprocessing<span class="token punctuation">.</span>image <span class="token keyword">import</span> ImageDataGenerator
<span class="token keyword">from</span> keras<span class="token punctuation">.</span>optimizers <span class="token keyword">import</span> SGD<span class="token punctuation">,</span>Adam
<span class="token keyword">from</span> keras<span class="token punctuation">.</span>callbacks <span class="token keyword">import</span> ReduceLROnPlateau
<span class="token keyword">from</span> keras<span class="token punctuation">.</span>layers <span class="token keyword">import</span> Flatten<span class="token punctuation">,</span>Dense<span class="token punctuation">,</span>BatchNormalization<span class="token punctuation">,</span>Activation<span class="token punctuation">,</span>Dropout
<span class="token keyword">from</span> keras<span class="token punctuation">.</span>utils <span class="token keyword">import</span> to_categorical

<span class="token comment"># Download the CIFAR dataset</span>
<span class="token punctuation">(</span>x_train<span class="token punctuation">,</span>y_train<span class="token punctuation">)</span><span class="token punctuation">,</span><span class="token punctuation">(</span>x_test<span class="token punctuation">,</span>y_test<span class="token punctuation">)</span> <span class="token operator">=</span> cifar10<span class="token punctuation">.</span>load_data<span class="token punctuation">(</span><span class="token punctuation">)</span>

<span class="token comment">#Splitting the training data into training and validation sets</span>
x_train<span class="token punctuation">,</span>x_val<span class="token punctuation">,</span>y_train<span class="token punctuation">,</span>y_val<span class="token operator">=</span>train_test_split<span class="token punctuation">(</span>x_train<span class="token punctuation">,</span>y_train<span class="token punctuation">,</span>test_size<span class="token operator">=</span><span class="token number">.3</span><span class="token punctuation">)</span>

<span class="token comment">#Dimension of the dataset</span>
<span class="token keyword">print</span><span class="token punctuation">(</span><span class="token punctuation">(</span>x_train<span class="token punctuation">.</span>shape<span class="token punctuation">,</span>y_train<span class="token punctuation">.</span>shape<span class="token punctuation">)</span><span class="token punctuation">)</span>
<span class="token keyword">print</span><span class="token punctuation">(</span><span class="token punctuation">(</span>x_val<span class="token punctuation">.</span>shape<span class="token punctuation">,</span>y_val<span class="token punctuation">.</span>shape<span class="token punctuation">)</span><span class="token punctuation">)</span>
<span class="token keyword">print</span><span class="token punctuation">(</span><span class="token punctuation">(</span>x_test<span class="token punctuation">.</span>shape<span class="token punctuation">,</span>y_test<span class="token punctuation">.</span>shape<span class="token punctuation">)</span><span class="token punctuation">)</span>

<span class="token comment">#One Hot Encoding</span>
y_train<span class="token operator">=</span>to_categorical<span class="token punctuation">(</span>y_train<span class="token punctuation">)</span>
y_val<span class="token operator">=</span>to_categorical<span class="token punctuation">(</span>y_val<span class="token punctuation">)</span>
y_test<span class="token operator">=</span>to_categorical<span class="token punctuation">(</span>y_test<span class="token punctuation">)</span>

<span class="token comment">#Verifying the dimension after one hot encoding</span>
<span class="token keyword">print</span><span class="token punctuation">(</span><span class="token punctuation">(</span>x_train<span class="token punctuation">.</span>shape<span class="token punctuation">,</span>y_train<span class="token punctuation">.</span>shape<span class="token punctuation">)</span><span class="token punctuation">)</span>
<span class="token keyword">print</span><span class="token punctuation">(</span><span class="token punctuation">(</span>x_val<span class="token punctuation">.</span>shape<span class="token punctuation">,</span>y_val<span class="token punctuation">.</span>shape<span class="token punctuation">)</span><span class="token punctuation">)</span>
<span class="token keyword">print</span><span class="token punctuation">(</span><span class="token punctuation">(</span>x_test<span class="token punctuation">.</span>shape<span class="token punctuation">,</span>y_test<span class="token punctuation">.</span>shape<span class="token punctuation">)</span><span class="token punctuation">)</span>

<span class="token comment">#Image Data Augmentation (applied to the training data only; validation and test images are left unmodified)</span>
train_generator <span class="token operator">=</span> ImageDataGenerator<span class="token punctuation">(</span>rotation_range<span class="token operator">=</span><span class="token number">2</span><span class="token punctuation">,</span> horizontal_flip<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">,</span> zoom_range<span class="token operator">=</span><span class="token number">.1</span><span class="token punctuation">)</span>
val_generator <span class="token operator">=</span> ImageDataGenerator<span class="token punctuation">(</span><span class="token punctuation">)</span>
test_generator <span class="token operator">=</span> ImageDataGenerator<span class="token punctuation">(</span><span class="token punctuation">)</span>

<span class="token comment">#fit() only computes dataset statistics for featurewise normalization or ZCA whitening, neither of which is enabled here, so these calls are optional</span>
train_generator<span class="token punctuation">.</span>fit<span class="token punctuation">(</span>x_train<span class="token punctuation">)</span>
val_generator<span class="token punctuation">.</span>fit<span class="token punctuation">(</span>x_val<span class="token punctuation">)</span>
test_generator<span class="token punctuation">.</span>fit<span class="token punctuation">(</span>x_test<span class="token punctuation">)</span>

<span class="token comment">#Learning Rate Annealer (monitors 'val_accuracy', the name Keras reports when the model is compiled with metrics=['accuracy'])</span>
lrr<span class="token operator">=</span> ReduceLROnPlateau<span class="token punctuation">(</span>monitor<span class="token operator">=</span><span class="token string">'val_accuracy'</span><span class="token punctuation">,</span> factor<span class="token operator">=</span><span class="token number">.01</span><span class="token punctuation">,</span> patience<span class="token operator">=</span><span class="token number">3</span><span class="token punctuation">,</span> min_lr<span class="token operator">=</span><span class="token number">1e-5</span><span class="token punctuation">)</span>

<span class="token comment">#Defining the VGG19 convolutional base with ImageNet weights (include_top=False drops the original classifier, so no classes argument is needed)</span>
base_model <span class="token operator">=</span> VGG19<span class="token punctuation">(</span>include_top <span class="token operator">=</span> <span class="token boolean">False</span><span class="token punctuation">,</span> weights <span class="token operator">=</span> <span class="token string">'imagenet'</span><span class="token punctuation">,</span> input_shape <span class="token operator">=</span> <span class="token punctuation">(</span><span class="token number">32</span><span class="token punctuation">,</span><span class="token number">32</span><span class="token punctuation">,</span><span class="token number">3</span><span class="token punctuation">)</span><span class="token punctuation">)</span>

<span class="token comment">#Adding the final layers to the above base models where the actual classification is done in the dense layers</span>
model<span class="token operator">=</span> Sequential<span class="token punctuation">(</span><span class="token punctuation">)</span>
model<span class="token punctuation">.</span>add<span class="token punctuation">(</span>base_model<span class="token punctuation">)</span>
model<span class="token punctuation">.</span>add<span class="token punctuation">(</span>Flatten<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">)</span>

<span class="token comment">#Model summary</span>
model<span class="token punctuation">.</span>summary<span class="token punctuation">(</span><span class="token punctuation">)</span>

<span class="token comment">#Adding the Dense layers with ReLU activations and dropout (the input size is inferred from the Flatten layer)</span>
model<span class="token punctuation">.</span>add<span class="token punctuation">(</span>Dense<span class="token punctuation">(</span><span class="token number">1024</span><span class="token punctuation">,</span>activation<span class="token operator">=</span><span class="token punctuation">(</span><span class="token string">'relu'</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
model<span class="token punctuation">.</span>add<span class="token punctuation">(</span>Dense<span class="token punctuation">(</span><span class="token number">512</span><span class="token punctuation">,</span>activation<span class="token operator">=</span><span class="token punctuation">(</span><span class="token string">'relu'</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
model<span class="token punctuation">.</span>add<span class="token punctuation">(</span>Dense<span class="token punctuation">(</span><span class="token number">256</span><span class="token punctuation">,</span>activation<span class="token operator">=</span><span class="token punctuation">(</span><span class="token string">'relu'</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
model<span class="token punctuation">.</span>add<span class="token punctuation">(</span>Dropout<span class="token punctuation">(</span><span class="token number">.3</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
model<span class="token punctuation">.</span>add<span class="token punctuation">(</span>Dense<span class="token punctuation">(</span><span class="token number">128</span><span class="token punctuation">,</span>activation<span class="token operator">=</span><span class="token punctuation">(</span><span class="token string">'relu'</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">)</span>

<span class="token comment">#model.add(Dropout(.2))</span>
model<span class="token punctuation">.</span>add<span class="token punctuation">(</span>Dense<span class="token punctuation">(</span><span class="token number">10</span><span class="token punctuation">,</span>activation<span class="token operator">=</span><span class="token punctuation">(</span><span class="token string">'softmax'</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">)</span>

<span class="token comment">#Checking the final model summary</span>
model<span class="token punctuation">.</span>summary<span class="token punctuation">(</span><span class="token punctuation">)</span>

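<span class="token comment">#Making prediction (the model is not compiled or trained in this listing, so these predictions come from the ImageNet-initialized VGG19 base and randomly initialized dense layers)</span>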
predict_y <span class="token operator">=</span> model<span class="token punctuation">.</span>predict<span class="token punctuation">(</span>x_test<span class="token punctuation">)</span>
y_pred<span class="token operator">=</span>np<span class="token punctuation">.</span>argmax<span class="token punctuation">(</span>predict_y<span class="token punctuation">,</span>axis<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span>
y_true<span class="token operator">=</span>np<span class="token punctuation">.</span>argmax<span class="token punctuation">(</span>y_test<span class="token punctuation">,</span>axis<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span>
</div>

Output

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170498071/170498071 [==============================] - 2s 0us/step
((35000, 32, 32, 3), (35000, 1))
((15000, 32, 32, 3), (15000, 1))
((10000, 32, 32, 3), (10000, 1))
((35000, 32, 32, 3), (35000, 10))
((15000, 32, 32, 3), (15000, 10))
((10000, 32, 32, 3), (10000, 10))
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg19/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5
80134624/80134624 [==============================] - 1s 0us/step
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 vgg19 (Functional)          (None, 1, 1, 512)         20024384
 flatten (Flatten)           (None, 512)               0
=================================================================
Total params: 20,024,384
Trainable params: 20,024,384
Non-trainable params: 0
_________________________________________________________________
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 vgg19 (Functional)          (None, 1, 1, 512)         20024384
 flatten (Flatten)           (None, 512)               0
 dense (Dense)               (None, 1024)              525312
 dense_1 (Dense)             (None, 512)               524800
 dense_2 (Dense)             (None, 256)               131328
 dropout (Dropout)           (None, 256)               0
 dense_3 (Dense)             (None, 128)               32896
 dense_4 (Dense)             (None, 10)                1290
=================================================================
Total params: 21,240,010
Trainable params: 21,240,010
Non-trainable params: 0
_________________________________________________________________
313/313 [==============================] - 158s 503ms/step

Conclusion

Multiclass image classification has proved highly beneficial to the deep learning community. As one of the most important basic tasks in computer vision, it is widely used in the AI industry as a building block even for complex computer vision applications such as image segmentation, object detection, and visual recognition.

Updated on: 2022-12-01T06:56:42+05:30
