How can TensorFlow be used to save and load weights for the MNIST dataset?
TensorFlow is an open-source machine learning framework from Google. It is used with Python to implement machine learning algorithms, deep learning applications, and much more, and is employed both in research and in production.
The 'tensorflow' package can be installed using the following pip command:
pip install tensorflow
Keras is a deep learning API written in Python that runs on top of TensorFlow. It provides a high-level interface for building and training neural network models. When training deep learning models, it's crucial to save model weights periodically to avoid losing progress and enable model deployment.
Loading and Preprocessing MNIST Data
First, let's load the MNIST dataset and preprocess it for training:
import tensorflow as tf
from tensorflow import keras
import os
print("TensorFlow version:", tf.__version__)
# Load MNIST dataset
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
# Use subset for faster execution
train_labels = train_labels[:1000]
test_labels = test_labels[:1000]
# Normalize pixel values to 0-1 range
train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0
print("Training data shape:", train_images.shape)
print("Test data shape:", test_images.shape)
TensorFlow version: 2.x.x
Training data shape: (1000, 784)
Test data shape: (1000, 784)
Creating and Training the Model
Now let's create a simple neural network model and train it:
# Create a simple sequential model
def create_model():
    model = keras.Sequential([
        keras.layers.Dense(512, activation='relu', input_shape=(784,)),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(10, activation='softmax')
    ])
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model
# Create and summarize the model
model = create_model()
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense (Dense)                (None, 512)               401920
dropout (Dropout)            (None, 512)               0
dense_1 (Dense)              (None, 10)                5130
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
Saving Model Weights During Training
Use ModelCheckpoint callback to automatically save weights during training:
# Create checkpoint callback
checkpoint_path = "training_checkpoints/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)
# Train the model with checkpoint callback
model.fit(train_images, train_labels,
          epochs=3,
          validation_data=(test_images, test_labels),
          callbacks=[cp_callback])
Epoch 1/3
32/32 [==============================] - 2s 45ms/step - loss: 0.7123 - accuracy: 0.7930 - val_loss: 0.4652 - val_accuracy: 0.8580
Epoch 00001: saving model to training_checkpoints/cp.ckpt
Epoch 2/3
32/32 [==============================] - 1s 32ms/step - loss: 0.4441 - accuracy: 0.8750 - val_loss: 0.4096 - val_accuracy: 0.8850
Epoch 00002: saving model to training_checkpoints/cp.ckpt
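A common variation on the callback above is to embed the epoch number in the checkpoint filename, so each epoch's weights are kept as a separate file. The following is a minimal, self-contained sketch: it trains on small random data (a stand-in for MNIST, so it runs without a download) and uses a hypothetical epoch_checkpoints/ directory; the .weights.h5 suffix keeps it compatible with recent Keras releases.

```python
import os
import numpy as np
from tensorflow import keras

# Tiny random stand-in for MNIST so this sketch runs without a download
x = np.random.rand(64, 784).astype("float32")
y = np.random.randint(0, 10, size=(64,))

model = keras.Sequential([
    keras.layers.Dense(32, activation="relu", input_shape=(784,)),
    keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

# Hypothetical directory; {epoch:02d} is filled in by the callback each epoch
os.makedirs("epoch_checkpoints", exist_ok=True)
ckpt_pattern = "epoch_checkpoints/cp-{epoch:02d}.weights.h5"
cb = keras.callbacks.ModelCheckpoint(filepath=ckpt_pattern,
                                     save_weights_only=True,
                                     verbose=0)

model.fit(x, y, epochs=2, batch_size=32, callbacks=[cb], verbose=0)
print(sorted(os.listdir("epoch_checkpoints")))  # one file per epoch, e.g. cp-01.weights.h5
```

Keeping one file per epoch lets you roll back to any earlier state, at the cost of more disk usage than a single overwritten checkpoint.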
Loading Saved Weights
Create a new model and load the previously saved weights:
# Create a fresh model
new_model = create_model()
# Evaluate untrained model
loss, acc = new_model.evaluate(test_images, test_labels, verbose=0)
print("Untrained model accuracy: {:5.2f}%".format(100 * acc))
# Load weights from checkpoint
new_model.load_weights(checkpoint_path)
# Re-evaluate the model
loss, acc = new_model.evaluate(test_images, test_labels, verbose=0)
print("Restored model accuracy: {:5.2f}%".format(100 * acc))
Untrained model accuracy: 11.40%
Restored model accuracy: 88.50%
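One point worth noting: load_weights() restores weights into an existing model, so the new model's architecture must match the one that saved them. A small standalone sketch (with hypothetical layer sizes and file name) of what happens when it does not:

```python
import numpy as np
from tensorflow import keras

# Source model with a hypothetical 4-unit layer saves its weights
source = keras.Sequential([keras.layers.Dense(4, input_shape=(3,))])
source.save_weights("arch_demo.weights.h5")

# A model with the same architecture loads them without issue
same = keras.Sequential([keras.layers.Dense(4, input_shape=(3,))])
same.load_weights("arch_demo.weights.h5")
print("matching architecture: loaded")

# A model with a different layer size cannot load them
mismatch = keras.Sequential([keras.layers.Dense(5, input_shape=(3,))])
try:
    mismatch.load_weights("arch_demo.weights.h5")
except Exception as err:
    print("mismatched architecture:", type(err).__name__)
```

This is why the example above defines create_model() once and reuses it: rebuilding the exact same architecture is a precondition for restoring weights.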
Manual Weight Saving Options
You can also save weights manually at specific points:
# Save weights manually (TensorFlow checkpoint format)
model.save_weights('./manual_checkpoint')
# Save in HDF5 format (selected by the .h5 extension)
model.save_weights('./weights.h5')
# Load weights from HDF5 file
new_model_h5 = create_model()
new_model_h5.load_weights('./weights.h5')
# Verify loaded weights work correctly
loss, acc = new_model_h5.evaluate(test_images, test_labels, verbose=0)
print("Model with loaded HDF5 weights accuracy: {:5.2f}%".format(100 * acc))
Model with loaded HDF5 weights accuracy: 88.50%
Conclusion
TensorFlow provides multiple ways to save and load model weights: a ModelCheckpoint callback for automatic saving during training, or save_weights() for saving manually at specific points. This lets you preserve trained models, resume training, or deploy models for inference without losing training progress.
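The resume-training workflow mentioned above can be sketched end to end as follows. This is a standalone illustration using small random data and hypothetical names; note that save_weights() stores only the layer weights, not the optimizer state, so optimizers like Adam restart with fresh moment estimates.

```python
import numpy as np
from tensorflow import keras

def build_model():
    # Same small architecture rebuilt each time (hypothetical sizes)
    m = keras.Sequential([
        keras.layers.Dense(16, activation="relu", input_shape=(8,)),
        keras.layers.Dense(2, activation="softmax"),
    ])
    m.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
    return m

x = np.random.rand(32, 8).astype("float32")
y = np.random.randint(0, 2, size=(32,))

# First session: train briefly and save
first = build_model()
first.fit(x, y, epochs=1, verbose=0)
first.save_weights("resume_demo.weights.h5")

# Later session: rebuild the model, restore, and keep training
second = build_model()
second.load_weights("resume_demo.weights.h5")

# The restored weights are identical before any further training
assert all(np.array_equal(a, b)
           for a, b in zip(first.get_weights(), second.get_weights()))
second.fit(x, y, epochs=1, verbose=0)  # training resumes from the saved state
print("resumed training OK")
```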
