How can Keras be used to save weights for model after specific number of epochs in Python?
Keras is a high-level deep learning API that runs on top of TensorFlow. When training neural networks, it's crucial to save model weights periodically to prevent loss of progress and enable model recovery. Keras provides the ModelCheckpoint callback to automatically save weights after specific intervals.
Setting Up Model Checkpointing
The ModelCheckpoint callback allows you to save weights at regular intervals during training. Here's how to configure it:
import tensorflow as tf
import os
# Define checkpoint path with epoch number formatting
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
# Create directory if it doesn't exist
os.makedirs(checkpoint_dir, exist_ok=True)
print("Checkpoint directory created:", checkpoint_dir)
Output
Checkpoint directory created: training_2
Creating the ModelCheckpoint Callback
Configure the callback to save weights every few epochs:
import tensorflow as tf
# Create a simple model for demonstration
def create_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model
# Set up checkpoint callback
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
batch_size = 32
print("Creating callback to save model weights every 5 epochs")
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    verbose=1,
    save_weights_only=True,
    # save_freq counts batches (training steps), not epochs.
    # 5 * 32 = 160 batches equals about 5 epochs here only because the
    # sample data below (1000 samples / batch_size 32) gives roughly
    # 32 steps per epoch; in general use 5 * steps_per_epoch.
    save_freq=5 * batch_size
)
# Create model instance
model = create_model()
print("Model created successfully")
Output
Creating callback to save model weights every 5 epochs
Model created successfully
Alternative: Save Every N Epochs
You can also save weights at specific epoch intervals using the period parameter (deprecated in recent TensorFlow/Keras releases in favor of save_freq):
import tensorflow as tf
# Alternative approach: save every 3 epochs.
# Note: `period` is deprecated in recent TensorFlow/Keras versions;
# prefer save_freq for new code.
cp_callback_epochs = tf.keras.callbacks.ModelCheckpoint(
    filepath="training_2/cp-{epoch:04d}.ckpt",
    verbose=1,
    save_weights_only=True,
    period=3  # Save every 3 epochs
)
print("Callback configured to save every 3 epochs")
Output
Callback configured to save every 3 epochs
Training with Checkpointing
Here's how to use the callback during training:
# Generate sample data for demonstration
import numpy as np
x_train = np.random.random((1000, 784))
y_train = np.random.randint(0, 10, (1000,))
# Train the model with checkpointing
history = model.fit(x_train, y_train,
                    epochs=10,
                    batch_size=32,
                    callbacks=[cp_callback],
                    validation_split=0.2,
                    verbose=1)
Key Parameters
| Parameter | Description | Example |
|---|---|---|
| filepath | Path to save weights (supports epoch formatting) | "cp-{epoch:04d}.ckpt" |
| save_freq | Save frequency in batches, or 'epoch' to save every epoch | 5 * batch_size |
| period | Save every N epochs (deprecated; use save_freq) | 3 |
| save_weights_only | Save only weights, not the full model | True |
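Because period is deprecated, the modern way to express an every-N-epochs schedule is to convert it into a batch count for save_freq. A minimal sketch, assuming 1000 training samples and a batch size of 32 to match the sample data above (the variable names here are illustrative, not part of the Keras API):

```python
import tensorflow as tf

# Hypothetical sizes matching the sample data in this article.
num_samples = 1000
batch_size = 32
save_every_epochs = 5

# save_freq counts batches, so convert epochs -> steps.
steps_per_epoch = -(-num_samples // batch_size)   # ceiling division: 32
save_freq = save_every_epochs * steps_per_epoch   # 160 batches = 5 epochs

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath="training_2/cp-{epoch:04d}.weights.h5",  # weights-only file
    save_weights_only=True,
    save_freq=save_freq,
)
print("Saving every", save_freq, "batches")
```

This avoids the coincidence in the earlier example, where 5 * batch_size happened to equal 5 * steps_per_epoch.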
Loading Saved Weights
To restore weights from a checkpoint:
# Create a new model instance
new_model = create_model()
# Load weights from a specific checkpoint
checkpoint_path = "training_2/cp-0005.ckpt"
new_model.load_weights(checkpoint_path)
print(f"Weights loaded from {checkpoint_path}")
Output
Weights loaded from training_2/cp-0005.ckpt
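When the callback writes TF-format checkpoints, you don't have to hard-code an epoch number: tf.train.latest_checkpoint reads the directory's index file and returns the newest checkpoint prefix. A self-contained sketch (it uses a bare tf.train.Checkpoint in a temporary directory just to have checkpoint files to search):

```python
import os
import tempfile
import tensorflow as tf

# Create a throwaway directory containing two TF-format checkpoints.
ckpt_dir = tempfile.mkdtemp()
ckpt = tf.train.Checkpoint(step=tf.Variable(1))
ckpt.save(os.path.join(ckpt_dir, "cp"))  # writes cp-1.*
ckpt.save(os.path.join(ckpt_dir, "cp"))  # writes cp-2.*

# Returns the path prefix of the most recent checkpoint in the directory.
latest = tf.train.latest_checkpoint(ckpt_dir)
print(latest)
```

With the training setup above, you would call tf.train.latest_checkpoint(checkpoint_dir) and pass the result to new_model.load_weights(...).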
Conclusion
The ModelCheckpoint callback in Keras provides flexible options for saving model weights during training. Use save_freq to control how often weights are written, in batches or with 'epoch' for every epoch; the older period argument for epoch-level intervals is deprecated in recent releases. Either way, your training progress is preserved automatically.
