Multiclass image classification using Transfer learning
Introduction
Image classification is one of the most common deep learning tasks on image data, and the development of new, high-performing machine learning frameworks has made it an active research area. Classification can be binary, where two classes of images are present, or multiclass, where there are more than two. This article explores multiclass image classification with transfer learning.
Multiclass Image Classification
Advances in artificial neural networks, and in Convolutional Neural Networks (CNNs) in particular, have made complex operations on images much easier and have driven progress in tasks such as multiclass image classification, image segmentation, and object detection.
Multiclass image classification is one of the most basic yet powerful computer vision tasks that CNNs can perform. The dataset contains more than two classes of images, and each image is labeled with its category (e.g. CIFAR, Fashion MNIST).
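As a rough illustration of how such category labels are typically handled, the sketch below one-hot encodes integer labels and scores made-up class outputs with softmax and categorical cross-entropy. The function names, the three-class setup, and the score values are all assumptions chosen for this example, not part of the article's implementation.

```python
import numpy as np

def one_hot(labels, num_classes):
    """Turn integer class labels into one-hot rows."""
    encoded = np.zeros((labels.size, num_classes), dtype=np.float32)
    encoded[np.arange(labels.size), labels] = 1.0
    return encoded

def softmax(logits):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

def categorical_cross_entropy(probs, targets, eps=1e-12):
    """Mean cross-entropy between predicted probabilities and one-hot targets."""
    return float(-np.mean(np.sum(targets * np.log(probs + eps), axis=1)))

# Hypothetical example: 4 images, 3 classes, made-up raw scores (logits).
labels = np.array([0, 2, 1, 2])
logits = np.array([[2.0, 0.5, 0.1],
                   [0.2, 0.3, 3.0],
                   [0.1, 1.5, 0.4],
                   [1.0, 1.0, 1.0]])

targets = one_hot(labels, num_classes=3)
probs = softmax(logits)
predicted = probs.argmax(axis=1)   # class with the highest probability per image
loss = categorical_cross_entropy(probs, targets)

print(targets)
print(predicted)
print(loss)
```

The predicted class is simply the index of the largest probability in each row, which is the same idea the article's code applies later with `np.argmax`.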
For classification, we can either prepare our own labeled dataset or download an existing image dataset such as CIFAR-10. Preprocessing may include image augmentation, which adds variation to the dataset when the number of images per class is low.
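To make the idea of augmentation concrete, here is a minimal NumPy sketch that produces varied copies of one image by random horizontal flipping and small shifts. The helper name `random_flip_and_shift`, the `max_shift` parameter, and the random stand-in image are assumptions for illustration; this is not how any particular library implements augmentation.

```python
import numpy as np

def random_flip_and_shift(image, rng, max_shift=2):
    """Return a randomly flipped and shifted copy of an (H, W, C) image."""
    out = image
    if rng.random() < 0.5:
        out = out[:, ::-1, :]          # horizontal flip
    # Draw a vertical and a horizontal offset in [-max_shift, max_shift].
    dy, dx = rng.integers(-max_shift, max_shift + 1, size=2)
    # Pad with edge pixels, then crop back to the original size at the offset.
    pad = ((max_shift, max_shift), (max_shift, max_shift), (0, 0))
    padded = np.pad(out, pad, mode="edge")
    h, w = image.shape[:2]
    top, left = max_shift + dy, max_shift + dx
    return padded[top:top + h, left:left + w, :]

rng = np.random.default_rng(0)
# Stand-in for a single 32x32 RGB image (random pixels, not real data).
image = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)

augmented = [random_flip_and_shift(image, rng) for _ in range(4)]
print([a.shape for a in augmented])
```

Each call returns an image of the same shape and dtype as the input, so the augmented copies can be fed to the same model as the originals.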
To train the model, we can either build an architecture from scratch in a deep learning framework such as TensorFlow or PyTorch, or use a readily available backbone architecture such as VGG16 or ResNet. The advantage of the latter is that we do not have to build the architecture from scratch; we only need to fine-tune the model or change the last one or two layers to suit our use case. This is where transfer learning, an intuitive technique for training image models, comes into play.
What is Transfer Learning and why is it important?
Transfer learning is a research problem in the field of machine learning. It focuses on storing the knowledge gained while solving one problem and applying it to a different but related problem. For example, the knowledge gained while learning to recognize cats could apply when trying to recognize cheetahs. In deep learning, transfer learning is a technique in which a neural network is first trained on a problem similar to the one being solved. It has the advantage of reducing training time and can result in lower generalization error. We can take a model pretrained on another dataset, such as ImageNet, and modify its last layers to suit our task. This saves the time, effort, and resources needed to train a model from scratch. Such pretrained models have already learned a large number of image patterns through extensive training.
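The core mechanic, keeping earlier layers fixed and training only a new final layer, can be sketched with a toy NumPy example. Everything here is an assumption made for illustration: the "pretrained" backbone is just a fixed random projection with ReLU, the data are synthetic Gaussian blobs, and the learning rate and step count are arbitrary. A real setup would reuse weights learned on a large dataset such as ImageNet.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for a pretrained backbone: a fixed projection followed by ReLU.
# These weights are never updated below, which is what "frozen" means.
W_frozen = rng.normal(size=(20, 16)) / np.sqrt(20)
W_frozen_before = W_frozen.copy()

def extract_features(x):
    """Frozen feature extractor."""
    return np.maximum(x @ W_frozen, 0.0)

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

# Synthetic 3-class dataset: one Gaussian blob per class in 20 dimensions.
num_classes, n = 3, 300
centers = rng.normal(scale=2.0, size=(num_classes, 20))
y = rng.integers(0, num_classes, size=n)
x = centers[y] + rng.normal(size=(n, 20))
targets = np.eye(num_classes)[y]

# Features are computed once because the backbone never changes.
features = extract_features(x)

# New classification head: the only trainable parameters.
W_head = np.zeros((16, num_classes))
b_head = np.zeros(num_classes)

lr = 0.02
losses = []
for step in range(300):
    probs = softmax(features @ W_head + b_head)
    losses.append(float(-np.mean(np.log(probs[np.arange(n), y] + 1e-12))))
    grad_logits = (probs - targets) / n       # gradient of mean cross-entropy
    W_head -= lr * (features.T @ grad_logits)
    b_head -= lr * grad_logits.sum(axis=0)

accuracy = float((softmax(features @ W_head + b_head).argmax(axis=1) == y).mean())
print(losses[0], losses[-1], accuracy)
```

Because only `W_head` and `b_head` are updated, training touches far fewer parameters than the full network, which is where the savings in time and resources come from.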
Code Implementation
In this example, we use the CIFAR-10 dataset for multiclass classification. We also use the VGG19 network and modify it for transfer learning.
Dataset used
The dataset is CIFAR-10, which is available from the Canadian Institute for Advanced Research (CIFAR). It consists of 60,000 32×32 color images in 10 classes, with 6,000 images per class. The 10 classes represent airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks. There are 50,000 training images and 10,000 test images in this dataset. The dataset can be imported from Keras.
Implementation using Keras API
Example
```python
import os
import itertools

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import seaborn as sns
from sklearn.utils.multiclass import unique_labels
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from keras import Sequential
from keras.datasets import cifar10
from keras.applications import VGG19  # For transfer learning
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import SGD, Adam
from keras.callbacks import ReduceLROnPlateau
from keras.layers import Flatten, Dense, BatchNormalization, Activation, Dropout
from keras.utils import to_categorical

# Download the CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Hold out 30% of the training images as a validation set
x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=.3)

# Dimensions of the dataset
print((x_train.shape, y_train.shape))
print((x_val.shape, y_val.shape))
print((x_test.shape, y_test.shape))

# One-hot encode the labels
y_train = to_categorical(y_train)
y_val = to_categorical(y_val)
y_test = to_categorical(y_test)

# Verify the dimensions after one-hot encoding
print((x_train.shape, y_train.shape))
print((x_val.shape, y_val.shape))
print((x_test.shape, y_test.shape))

# Image data augmentation
train_generator = ImageDataGenerator(rotation_range=2, horizontal_flip=True, zoom_range=.1)
val_generator = ImageDataGenerator(rotation_range=2, horizontal_flip=True, zoom_range=.1)
test_generator = ImageDataGenerator(rotation_range=2, horizontal_flip=True, zoom_range=.1)

# Fit the generators to the data (fit() only computes statistics needed for
# feature-wise normalization or ZCA whitening, which are not enabled here)
train_generator.fit(x_train)
val_generator.fit(x_val)
test_generator.fit(x_test)

# Learning rate annealer. The monitor name must match the metric name logged
# during training (for example 'val_accuracy' when compiling with
# metrics=['accuracy']). It is defined here but not used in this listing.
lrr = ReduceLROnPlateau(monitor='val_acc', factor=.01, patience=3, min_lr=1e-5)

# Define the VGG19 convolutional base with ImageNet weights.
# With include_top=False the classes argument has no effect.
base_model = VGG19(include_top=False, weights='imagenet',
                   input_shape=(32, 32, 3), classes=y_train.shape[1])

# Add the final layers on top of the base model; the actual classification
# is done in the dense layers
model = Sequential()
model.add(base_model)
model.add(Flatten())

# Model summary so far
model.summary()

# Add the dense layers with activations and dropout
model.add(Dense(1024, activation='relu', input_dim=512))
model.add(Dense(512, activation='relu'))
model.add(Dense(256, activation='relu'))
model.add(Dropout(.3))
model.add(Dense(128, activation='relu'))
# model.add(Dropout(.2))
model.add(Dense(10, activation='softmax'))

# Check the final model summary
model.summary()

# Make predictions. The model is not compiled or trained in this listing, so
# the new dense layers still have their initial random weights.
predict_y = model.predict(x_test)
y_pred = np.argmax(predict_y, axis=1)
y_true = np.argmax(y_test, axis=1)
```
Output
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170498071/170498071 [==============================] - 2s 0us/step
((35000, 32, 32, 3), (35000, 1))
((15000, 32, 32, 3), (15000, 1))
((10000, 32, 32, 3), (10000, 1))
((35000, 32, 32, 3), (35000, 10))
((15000, 32, 32, 3), (15000, 10))
((10000, 32, 32, 3), (10000, 10))
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg19/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5
80134624/80134624 [==============================] - 1s 0us/step
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 vgg19 (Functional)          (None, 1, 1, 512)         20024384
 flatten (Flatten)           (None, 512)               0
=================================================================
Total params: 20,024,384
Trainable params: 20,024,384
Non-trainable params: 0
_________________________________________________________________
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 vgg19 (Functional)          (None, 1, 1, 512)         20024384
 flatten (Flatten)           (None, 512)               0
 dense (Dense)               (None, 1024)              525312
 dense_1 (Dense)             (None, 512)               524800
 dense_2 (Dense)             (None, 256)               131328
 dropout (Dropout)           (None, 256)               0
 dense_3 (Dense)             (None, 128)               32896
 dense_4 (Dense)             (None, 10)                1290
=================================================================
Total params: 21,240,010
Trainable params: 21,240,010
Non-trainable params: 0
_________________________________________________________________
313/313 [==============================] - 158s 503ms/step
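The parameter counts in the model summary can be checked by hand: a dense layer has one weight per input-output pair plus one bias per output unit, while Flatten and Dropout add no parameters. The short sketch below redoes that arithmetic. The helper `dense_params` is a hypothetical name used only here, and the VGG19 figure is copied from the summary rather than recomputed.

```python
def dense_params(inputs, units):
    """Weights plus biases of a fully connected layer."""
    return inputs * units + units

# Flattened VGG19 output (512) followed by the five dense layers of the head.
layer_sizes = [512, 1024, 512, 256, 128, 10]
head_params = [dense_params(i, o) for i, o in zip(layer_sizes, layer_sizes[1:])]

vgg19_params = 20_024_384  # taken from the model summary
total = vgg19_params + sum(head_params)

print(head_params)  # [525312, 524800, 131328, 32896, 1290]
print(total)        # 21240010
```

The new head adds about 1.2 million parameters, a small fraction of the roughly 20 million in the VGG19 base.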
Conclusion
Multiclass image classification has proved highly valuable to the deep learning community. As one of the most important basic tasks in computer vision, it is widely adopted in the AI industry as a building block for more complex applications such as image segmentation, object detection, and other visual recognition tasks.
