How can Keras be used to load weights from checkpoint and re-evaluate the model using Python?
TensorFlow is a machine learning framework provided by Google for implementing algorithms, deep learning applications, and neural networks. When training complex models, saving and loading weights from checkpoints is essential for resuming training or deploying models.
Keras, the high-level API built into TensorFlow, provides simple methods to save model weights during training and load them later for evaluation or inference.
Installing TensorFlow
Install TensorFlow using pip:
pip install tensorflow
Loading Weights from Checkpoint
Use the load_weights() method to restore saved model weights from a checkpoint file:
def restore_and_evaluate(model, checkpoint_path, test_images, test_labels):
    """Load saved weights into a compiled model and report its test accuracy.

    `model` must have the same architecture as the model that saved the weights,
    and it must already be compiled so that evaluate() can compute the metrics.
    """
    print("Loading weights from checkpoint...")
    model.load_weights(checkpoint_path)

    print("Re-evaluating the model...")
    loss, accuracy = model.evaluate(test_images, test_labels, verbose=2)
    print("Restored model accuracy: {:.3f}%".format(100 * accuracy))
    return loss, accuracy
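A quick way to confirm that load_weights() actually restored the saved parameters is to compare the weight arrays of the two models before and after loading. The sketch below is self-contained and makes a few assumptions: `build_model` and `weights_match` are hypothetical helper names, the tiny 4-input architecture is arbitrary, and the checkpoint is written to a temporary directory with a `.weights.h5` suffix, which both Keras 3 and older tf.keras accept.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras


def build_model():
    # Hypothetical small architecture used only for this demonstration
    model = keras.Sequential([
        keras.Input(shape=(4,)),
        keras.layers.Dense(8, activation="relu"),
        keras.layers.Dense(3),
    ])
    model.compile(
        optimizer="adam",
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )
    return model


def weights_match(model_a, model_b):
    """Return True if every weight array of the two models is identical."""
    wa, wb = model_a.get_weights(), model_b.get_weights()
    return len(wa) == len(wb) and all(np.array_equal(a, b) for a, b in zip(wa, wb))


original = build_model()
restored = build_model()  # same architecture, different random initialization

# Hypothetical temporary path; Keras 3 requires the ".weights.h5" suffix
checkpoint_path = os.path.join(tempfile.mkdtemp(), "demo.weights.h5")
original.save_weights(checkpoint_path)

print("Before loading:", weights_match(original, restored))
restored.load_weights(checkpoint_path)
print("After loading:", weights_match(original, restored))
```

Before loading, the randomly initialized kernels differ, so the first check should print False. After load_weights() the arrays are identical.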
Complete Example with Model Creation
Here's a complete example showing how to create a model, save weights, and reload them:
from tensorflow import keras

# Load sample data (MNIST) and scale pixel values to [0, 1]
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()
train_images = train_images / 255.0
test_images = test_images / 255.0


def create_model():
    """Build and compile a simple classifier; both models use this same architecture."""
    model = keras.Sequential([
        keras.Input(shape=(28, 28)),
        keras.layers.Flatten(),
        keras.layers.Dense(128, activation='relu'),
        keras.layers.Dense(10)
    ])
    model.compile(optimizer='adam',
                  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])
    return model


# Train briefly and save weights
# Keras 3 requires weight files to end in ".weights.h5"; older tf.keras also accepts this name
checkpoint_path = "model_checkpoint.weights.h5"
model = create_model()
model.fit(train_images[:1000], train_labels[:1000], epochs=1)
model.save_weights(checkpoint_path)

# Create a new model with the same architecture
new_model = create_model()

# Load weights from checkpoint
print("Loading weights from checkpoint...")
new_model.load_weights(checkpoint_path)

# Re-evaluate the model
print("Re-evaluating the model...")
loss, accuracy = new_model.evaluate(test_images, test_labels, verbose=2)
print("Restored model accuracy: {:.3f}%".format(100 * accuracy))
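To see that the restored model really uses the saved weights, you can evaluate the fresh model once before loading (random weights) and once after, and compare both results with the original trained model. The sketch below makes some assumptions to keep it small and self-contained. It uses synthetic two-feature data instead of MNIST. `create_model` is a hypothetical helper, and the checkpoint goes to a temporary directory.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras


def create_model():
    # Hypothetical small classifier for two-feature synthetic data
    model = keras.Sequential([
        keras.Input(shape=(2,)),
        keras.layers.Dense(16, activation="relu"),
        keras.layers.Dense(2),
    ])
    model.compile(optimizer="adam",
                  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=["accuracy"])
    return model


# Synthetic stand-in data: label is 1 when x0 + x1 > 0
rng = np.random.default_rng(0)
x = rng.normal(size=(500, 2)).astype("float32")
y = (x.sum(axis=1) > 0).astype("int32")

trained = create_model()
trained.fit(x, y, epochs=5, verbose=0)
checkpoint_path = os.path.join(tempfile.mkdtemp(), "baseline.weights.h5")
trained.save_weights(checkpoint_path)

fresh = create_model()
_, untrained_acc = fresh.evaluate(x, y, verbose=0)  # random initial weights
fresh.load_weights(checkpoint_path)
_, restored_acc = fresh.evaluate(x, y, verbose=0)   # weights from the checkpoint
_, trained_acc = trained.evaluate(x, y, verbose=0)

print(f"Untrained: {100 * untrained_acc:.1f}%")
print(f"Restored:  {100 * restored_acc:.1f}%")
print(f"Original:  {100 * trained_acc:.1f}%")
```

Because the restored and original models hold identical weights and are evaluated on the same data, their accuracies should match exactly. The untrained figure is only a baseline for comparison.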
Key Points
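Checkpoint Path − Specify the file path where weights are saved. Keras 3 (the default Keras since TensorFlow 2.16) requires the name to end in .weights.h5; older tf.keras versions also accept a TensorFlow checkpoint path such as model_checkpoint.ckpt
Model Architecture − The model must have the same architecture when loading weights
Evaluation − Use model.evaluate() to test performance on validation/test data
Compilation − load_weights() restores the weights but not the compile settings (optimizer, loss, metrics), so compile the model before calling evaluate() or continuing training with fit()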
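The architecture requirement can be checked directly. In the sketch below, weights are saved from a model with a 128-unit hidden layer and then loaded into one with 64 units. The shapes no longer match, so Keras raises a ValueError instead of loading the file. The `make_model` helper and the temporary file name are hypothetical and used only for illustration.

```python
import os
import tempfile

from tensorflow import keras


def make_model(hidden_units):
    # Same layout as the MNIST example, but with a configurable hidden width
    return keras.Sequential([
        keras.Input(shape=(28, 28)),
        keras.layers.Flatten(),
        keras.layers.Dense(hidden_units, activation="relu"),
        keras.layers.Dense(10),
    ])


path = os.path.join(tempfile.mkdtemp(), "arch.weights.h5")
make_model(128).save_weights(path)

mismatched = make_model(64)  # hidden layer width differs from the saved model
try:
    mismatched.load_weights(path)
    loaded = True
except ValueError as err:
    loaded = False
    print("Could not load weights:", type(err).__name__)
```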
Automatic Checkpoint Saving
Use the ModelCheckpoint callback to save weights automatically during training:
# Continues from the complete example above: reuses model, checkpoint_path and the MNIST arrays
# Save weights automatically during training
checkpoint_callback = keras.callbacks.ModelCheckpoint(
filepath=checkpoint_path,
save_weights_only=True,
save_best_only=True,
monitor='val_accuracy',
verbose=1
)
model.fit(train_images, train_labels,
epochs=5,
validation_data=(test_images, test_labels),
callbacks=[checkpoint_callback])
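With save_best_only=True, the checkpoint file ends up holding the weights from the epoch with the highest monitored value. You can reload those weights into a fresh model for evaluation. The sketch below shows this end to end under stated assumptions. It uses synthetic data with a separate validation split instead of MNIST. `create_model` is a hypothetical helper, and the file goes to a temporary directory. The explicit mode="max" states the direction that Keras would otherwise infer for an accuracy metric.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras


def create_model():
    # Hypothetical small classifier for two-feature synthetic data
    model = keras.Sequential([
        keras.Input(shape=(2,)),
        keras.layers.Dense(16, activation="relu"),
        keras.layers.Dense(2),
    ])
    model.compile(optimizer="adam",
                  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=["accuracy"])
    return model


# Synthetic data: label is 1 when x0 > x1; the last 100 rows are held out for validation
rng = np.random.default_rng(1)
x = rng.normal(size=(600, 2)).astype("float32")
y = (x[:, 0] > x[:, 1]).astype("int32")
x_train, y_train, x_val, y_val = x[:500], y[:500], x[500:], y[500:]

best_path = os.path.join(tempfile.mkdtemp(), "best.weights.h5")
checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=best_path,
    save_weights_only=True,
    save_best_only=True,
    monitor="val_accuracy",
    mode="max",
)
model = create_model()
history = model.fit(x_train, y_train, epochs=5,
                    validation_data=(x_val, y_val),
                    callbacks=[checkpoint_callback], verbose=0)

# Reload the best weights into a fresh model and re-evaluate on the validation set
best_model = create_model()
best_model.load_weights(best_path)
_, best_acc = best_model.evaluate(x_val, y_val, verbose=0)
print(f"Best recorded val_accuracy: {max(history.history['val_accuracy']):.3f}")
print(f"Reloaded model val_accuracy: {best_acc:.3f}")
```

The reloaded accuracy should equal the best val_accuracy recorded in the training history, because the callback wrote the weights from that epoch.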
Conclusion
Loading weights from checkpoints allows you to restore trained models and continue evaluation or training. Use load_weights() to restore saved parameters and evaluate() to test model performance on new data.
