Multiclass Image Classification Using Transfer Learning


Introduction

One of the most common tasks in deep learning on image data is image classification. It has attracted growing research interest thanks to the development of new, high-performing machine learning frameworks. Classification can be binary, where only two classes of images are present, or multiclass, where there are more than two. In this article, we explore multiclass image classification with transfer learning.

Multiclass Image Classification

With the advancement of artificial neural networks and the development of convolutional neural networks (CNNs), complex operations on images have become easy and have contributed to the growth of tasks like multiclass image classification, image segmentation, and object detection.

Multiclass image classification is one of the most fundamental yet powerful computer vision tasks that can be performed with CNNs. In this setting, we have more than two classes of images, each labeled according to its category (e.g., CIFAR, Fashion MNIST).
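For training, the integer category labels of such a dataset are usually one-hot encoded, so that each class corresponds to one output neuron. A minimal sketch in Keras (the label values here are illustrative):

import numpy as np
from keras.utils import to_categorical

# Hypothetical integer labels for three images (0 = airplane, 3 = cat, 5 = dog in CIFAR-10)
labels = np.array([0, 3, 5])

# Convert to one-hot vectors: each row has a single 1 at the class index
one_hot = to_categorical(labels, num_classes=10)
print(one_hot.shape)  # (3, 10)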

For classification, we can either prepare our own labeled dataset or download an already available image dataset like CIFAR-10. Preprocessing may include tasks like image augmentation, which brings variation into the dataset when the number of images per class is low.
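As a rough sketch of what such augmentation looks like with Keras' ImageDataGenerator (the parameter values below are illustrative):

from keras.preprocessing.image import ImageDataGenerator

# Randomly rotate, flip, and zoom images to create variation in a small dataset
augmenter = ImageDataGenerator(rotation_range=15,
                               horizontal_flip=True,
                               zoom_range=.1)

# augmenter.flow(x_train, y_train, batch_size=32) then yields augmented batches during training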

For training the model, we can either build an architecture from scratch using a deep learning framework like TensorFlow or PyTorch, or use a readily available backbone architecture like VGG16 or ResNet. The advantage of the latter is that we do not have to build the architecture from scratch and can focus on fine-tuning the model or changing only the last one or two layers for our use case. This is where transfer learning comes into play: a very intuitive technique for training image models.

What is Transfer Learning and Why is it Important?

Transfer learning is a research problem in the field of machine learning: knowledge gained while solving one problem is stored and applied to a different but related problem. For example, knowledge gained while learning to recognize cats could apply when trying to recognize cheetahs. In deep learning, transfer learning is a technique whereby a neural network model is first trained on a problem similar to the one being solved. It decreases training time and can result in lower generalization error. We can take a pretrained model trained on a dataset like ImageNet and modify its last layers to suit our task, saving the time, effort, and resources of training a model from scratch; such models have already learned a rich set of image patterns through rigorous training.
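A minimal sketch of this idea in Keras, assuming a VGG16 backbone and a 10-class target task (the head layer sizes are illustrative; the full implementation below fine-tunes all layers instead of freezing the base):

from keras import Sequential
from keras.applications import VGG16
from keras.layers import Flatten, Dense

# Pretrained ImageNet backbone, without its original 1000-class classifier head
base = VGG16(include_top=False, weights='imagenet', input_shape=(32, 32, 3))

# Freeze the pretrained weights so only the new head is trained
base.trainable = False

# Attach a new classification head for our 10 classes
model = Sequential([base,
                    Flatten(),
                    Dense(256, activation='relu'),
                    Dense(10, activation='softmax')])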

Code Implementation

In this example, we are going to use the CIFAR-10 dataset to do multiclass classification. We will use the VGG19 network and modify it to do transfer learning.

Dataset Used

The dataset is CIFAR-10, made available by the Canadian Institute for Advanced Research (CIFAR). It consists of 60000 32×32 color images in 10 classes, with 6000 images per class. The 10 different classes represent airplanes, automobiles, birds, cats, deer, dogs, frogs, horses, ships, and trucks. There are 50000 training images and 10000 test images in this dataset. The dataset can be imported directly from Keras.
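Loading the dataset and mapping its integer labels to human-readable class names looks like this (the ordering below is CIFAR-10's standard one):

from keras.datasets import cifar10

# Downloads the dataset on first call, then loads it from the local cache
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print(x_train.shape, x_test.shape)  # (50000, 32, 32, 3) (10000, 32, 32, 3)

# Integer labels 0-9 correspond to these classes
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']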

Implementation Using Keras API

Example

import numpy as np
from keras.datasets import cifar10
from sklearn.model_selection import train_test_split
from keras import Sequential
from keras.applications import VGG19  # pretrained backbone for transfer learning
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import Adam
from keras.callbacks import ReduceLROnPlateau
from keras.layers import Flatten, Dense, Dropout
from keras.utils import to_categorical

# Download the CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Split off a validation set from the training data
x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=.3)

# Dimensions of the dataset
print((x_train.shape, y_train.shape))
print((x_val.shape, y_val.shape))
print((x_test.shape, y_test.shape))

# One-hot encoding of the integer labels
y_train = to_categorical(y_train)
y_val = to_categorical(y_val)
y_test = to_categorical(y_test)

# Verifying the dimensions after one-hot encoding
print((x_train.shape, y_train.shape))
print((x_val.shape, y_val.shape))
print((x_test.shape, y_test.shape))

# Image data augmentation, applied to the training set only
train_generator = ImageDataGenerator(rotation_range=2,
                                     horizontal_flip=True,
                                     zoom_range=.1)
train_generator.fit(x_train)  # only needed for statistics-based transforms; harmless here

# Learning rate annealer: shrink the learning rate when validation accuracy plateaus
lrr = ReduceLROnPlateau(monitor='val_accuracy', factor=.01, patience=3, min_lr=1e-5)

# VGG19 convolutional base pretrained on ImageNet, without its original classifier head
base_model = VGG19(include_top=False, weights='imagenet', input_shape=(32, 32, 3))

# Stack the base model and a new classification head
model = Sequential()
model.add(base_model)
model.add(Flatten())

# Summary of the base model plus the flatten layer
model.summary()

# Dense layers with dropout; the final softmax outputs the 10 class probabilities
model.add(Dense(1024, activation='relu'))
model.add(Dense(512, activation='relu'))
model.add(Dense(256, activation='relu'))
model.add(Dropout(.3))
model.add(Dense(128, activation='relu'))
model.add(Dense(10, activation='softmax'))

# Checking the final model summary
model.summary()

# Compile and fine-tune the full network (the epoch count here is illustrative)
model.compile(optimizer=Adam(learning_rate=1e-4),
              loss='categorical_crossentropy',
              metrics=['accuracy'])
model.fit(train_generator.flow(x_train, y_train, batch_size=64),
          validation_data=(x_val, y_val),
          epochs=10,
          callbacks=[lrr])

# Making predictions on the test set
predict_y = model.predict(x_test)
y_pred = np.argmax(predict_y, axis=1)
y_true = np.argmax(y_test, axis=1)

Output

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170498071/170498071 [==============================] - 2s 0us/step
((35000, 32, 32, 3), (35000, 1))
((15000, 32, 32, 3), (15000, 1))
((10000, 32, 32, 3), (10000, 1))
((35000, 32, 32, 3), (35000, 10))
((15000, 32, 32, 3), (15000, 10))
((10000, 32, 32, 3), (10000, 10))
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg19/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5
80134624/80134624 [==============================] - 1s 0us/step
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
vgg19 (Functional) (None, 1, 1, 512) 20024384
flatten (Flatten) (None, 512) 0
=================================================================
Total params: 20,024,384
Trainable params: 20,024,384
Non-trainable params: 0
_________________________________________________________________
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
vgg19 (Functional) (None, 1, 1, 512) 20024384
flatten (Flatten) (None, 512) 0
dense (Dense) (None, 1024) 525312
dense_1 (Dense) (None, 512) 524800
dense_2 (Dense) (None, 256) 131328
dropout (Dropout) (None, 256) 0
dense_3 (Dense) (None, 128) 32896
dense_4 (Dense) (None, 10) 1290
=================================================================
Total params: 21,240,010
Trainable params: 21,240,010
Non-trainable params: 0
_________________________________________________________________
313/313 [==============================] - 158s 503ms/step
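As a follow-up sketch, the predicted and true labels from the script above can be compared with an accuracy score and a confusion matrix (the plotting choices are illustrative):

from sklearn.metrics import accuracy_score, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# Overall fraction of correctly classified test images
print('Test accuracy:', accuracy_score(y_true, y_pred))

# 10x10 matrix of true classes (rows) vs. predicted classes (columns)
cm = confusion_matrix(y_true, y_pred)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted class')
plt.ylabel('True class')
plt.show()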

Conclusion

Multiclass image classification has proved highly beneficial to the deep learning community. As one of the most fundamental tasks in computer vision, it is widely adopted in the AI industry as a base task, even for complex applications like image segmentation, object detection, and visual recognition.
