How can transfer learning be implemented in Python using Keras?


TensorFlow is an open-source machine learning framework developed by Google. It is used with Python to implement machine learning algorithms, deep learning applications, and much more, both in research and in production.

A tensor is the data structure used in TensorFlow. Tensors form the edges of a flow diagram known as the ‘data flow graph’, carrying data between its operations. A tensor is essentially a multidimensional array or a list.
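As a small illustration of a tensor as a multidimensional array, the sketch below builds one from a nested Python list and inspects it. The variable name `matrix` and the values are chosen only for illustration.

```python
import tensorflow as tf

# A rank-2 tensor (a 2 x 3 matrix) built from a nested Python list.
matrix = tf.constant([[1, 2, 3], [4, 5, 6]])

print(matrix.shape)             # the size of each dimension
print(matrix.dtype)             # the element type
print(tf.rank(matrix).numpy())  # the number of dimensions
```

A scalar would have rank 0 and a plain list of numbers rank 1; a batch of images is typically held in a rank-4 tensor.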

Keras means ‘horn’ in Greek. Keras was developed as part of the research effort of project ONEIROS (Open-ended Neuro-Electronic Intelligent Robot Operating System). Keras is a high-level deep learning API written in Python, with a productive interface that helps solve machine learning problems.

It runs on top of the TensorFlow framework and was built to enable fast experimentation. It provides the abstractions and building blocks needed to develop and ship machine learning solutions.
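To give a rough idea of what such a building block looks like, the sketch below creates a single `Dense` layer and applies it to a small input. The layer size and the all-ones input are arbitrary choices made for illustration.

```python
import tensorflow as tf
from tensorflow.keras import layers

# A fully connected layer with 4 output units.
layer = layers.Dense(4, activation="relu")

# A batch of 2 samples with 3 features each.
x = tf.ones((2, 3))

# Calling the layer creates its weights on first use and returns the output.
y = layer(x)

print(y.shape)             # one row per sample, one column per unit
print(len(layer.weights))  # a kernel matrix and a bias vector
```

Models such as `keras.Sequential` are assembled by stacking layers of this kind.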

It is highly scalable and comes with cross-platform abilities. This means Keras can run on TPUs or on clusters of GPUs. Keras models can also be exported to run in a web browser or on a mobile phone.

Keras is already included in the TensorFlow package. It can be accessed using the lines of code below.

import tensorflow
from tensorflow import keras

We are using Google Colaboratory to run the code below. Google Colab (Colaboratory) runs Python code in the browser, requires zero configuration, and provides free access to GPUs (Graphics Processing Units). Colaboratory is built on top of Jupyter Notebook. Following is the code snippet −

Example

from tensorflow import keras
from tensorflow.keras import layers

# Rebuild the same architecture as the pre-trained model
model = keras.Sequential([
   keras.Input(shape=(784,)),
   layers.Dense(32, activation='relu'),
   layers.Dense(32, activation='relu'),
   layers.Dense(32, activation='relu'),
   layers.Dense(10),
])
# The '...' arguments below are placeholders; supply your own values
print("Load the pre-trained weights")
model.load_weights(...)
print("Freeze all the layers except the last layer")
for layer in model.layers[:-1]:
   layer.trainable = False
print("Recompile the model and train it")
print("The last layer weights will be updated")
model.compile(...)
model.fit(...)

Code credit − https://www.tensorflow.org/guide/keras/sequential_model

Output

Load the pre-trained weights
Freeze all the layers except the last layer
Recompile the model and train it
The last layer weights will be updated

Explanation

  • In this example, transfer learning means freezing the bottom layers of a pre-trained model and training only the top layer.

  • The sequential model is built.

  • The pre-trained weights of the old model are loaded into this model.

  • All the layers except the last one are frozen.

  • This is done by iterating over the layers and setting ‘layer.trainable’ to ‘False’ for every layer except the last one.

  • The model is recompiled so that the change takes effect, and then it is fit to the data.
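The steps above can be sketched end to end as follows. This is a minimal illustration under stated assumptions, not the original guide's code: the data is random, the "pre-trained" model is a freshly initialised stand-in, and the weights are copied in memory with `get_weights()` and `set_weights()` instead of being read from a file with `load_weights()`. The helper `build_model`, the optimizer, and the loss are likewise choices made for this sketch.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers


def build_model():
    # Hypothetical helper: same architecture as the example above.
    return keras.Sequential([
        keras.Input(shape=(784,)),
        layers.Dense(32, activation="relu"),
        layers.Dense(32, activation="relu"),
        layers.Dense(32, activation="relu"),
        layers.Dense(10),
    ])


# Assumption: random stand-in data, 64 samples with 10 classes.
rng = np.random.default_rng(0)
x = rng.random((64, 784)).astype("float32")
y = rng.integers(0, 10, size=(64,))

# Stand-in for a pre-trained model; real weights would come from a saved file.
source_model = build_model()

# New model with the same architecture, initialised from the source weights.
model = build_model()
model.set_weights(source_model.get_weights())

# Freeze every layer except the last one.
for layer in model.layers[:-1]:
    layer.trainable = False

# Snapshot weights so the effect of training can be checked afterwards.
frozen_before = [w.copy() for w in model.layers[0].get_weights()]
last_before = [w.copy() for w in model.layers[-1].get_weights()]

# Compile after changing `trainable`, then train briefly.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-2),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
model.fit(x, y, epochs=2, batch_size=16, verbose=0)

frozen_unchanged = all(
    np.array_equal(a, b)
    for a, b in zip(frozen_before, model.layers[0].get_weights())
)
last_changed = any(
    not np.array_equal(a, b)
    for a, b in zip(last_before, model.layers[-1].get_weights())
)

print("Trainable weight tensors:", len(model.trainable_weights))
print("Frozen layer unchanged:", frozen_unchanged)
print("Last layer updated:", last_changed)
```

Only the kernel and bias of the final `Dense` layer remain trainable, so training adapts the output layer while the frozen layers keep the weights they were given.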

Updated on: 18-Jan-2021
