How can Keras be used to train a model with new callback in Python?
TensorFlow is an open-source machine learning framework provided by Google. It is used with Python to implement algorithms, deep learning applications and much more. Keras is a high-level deep learning API that runs on top of TensorFlow and provides the essential abstractions for building machine learning models, including callbacks for customizing training.
The 'tensorflow' package can be installed on Windows using the line of code below −
pip install tensorflow
Keras is already present within the TensorFlow package and can be accessed using −
import tensorflow as tf
from tensorflow import keras
Understanding Callbacks in Keras
Callbacks are functions that are called at specific points during model training. They allow you to customize training behavior, save model checkpoints, monitor metrics, and implement early stopping.
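Besides the built-in callbacks, a new callback can be written by subclassing tf.keras.callbacks.Callback and overriding hooks such as on_epoch_end. Below is a minimal sketch; the LossHistory class name and the tiny random dataset are illustrative choices, not part of the Keras API −

```python
import numpy as np
import tensorflow as tf

# A minimal custom callback that records the training loss after
# every epoch. LossHistory is an illustrative name, not a built-in.
class LossHistory(tf.keras.callbacks.Callback):
    def __init__(self):
        super().__init__()
        self.losses = []

    def on_epoch_end(self, epoch, logs=None):
        # 'logs' holds the metrics computed for this epoch
        logs = logs or {}
        self.losses.append(logs.get('loss'))

# A tiny model and random data, just to exercise the callback
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1)
])
model.compile(optimizer='adam', loss='mse')

x = np.random.random((32, 4))
y = np.random.random((32, 1))

history_cb = LossHistory()
model.fit(x, y, epochs=3, callbacks=[history_cb], verbose=0)
```

After fit() returns, history_cb.losses contains one loss value per epoch, which can be inspected or plotted.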
Creating a ModelCheckpoint Callback
The most commonly used callback is ModelCheckpoint, which saves the model's weights during training −
import tensorflow as tf
from tensorflow import keras
import os

# Create a directory for the checkpoint files
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

# Create a ModelCheckpoint callback that saves weights after every epoch
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True,
    verbose=1
)
Training Model with Callback
Here's how to train a model using the checkpoint callback −
# Assuming a compiled model and train_images, train_labels,
# test_images, test_labels already exist
print("The model is trained with new callback")
model.fit(train_images,
          train_labels,
          epochs=50,
          callbacks=[cp_callback],
          validation_data=(test_images, test_labels),
          verbose=0)

# List the contents of the checkpoint directory
print("\nCheckpoint files:")
for file in os.listdir(checkpoint_dir):
    print(file)

# Get the path of the latest checkpoint
print("The latest checkpoint being updated")
latest = tf.train.latest_checkpoint(checkpoint_dir)
print(latest)
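Saved weights are only useful if they can be loaded back into a model of the same architecture. A minimal self-contained sketch follows; the build_model helper and the 'demo.weights.h5' file name are illustrative choices −

```python
import numpy as np
import tensorflow as tf

# Helper that builds identical model architectures, so saved weights
# from one instance can be loaded into another.
def build_model():
    m = tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(1)
    ])
    m.compile(optimizer='adam', loss='mse')
    return m

x = np.random.random((16, 4))
y = np.random.random((16, 1))

# Train briefly and save the weights to a file
model = build_model()
model.fit(x, y, epochs=1, verbose=0)
model.save_weights('demo.weights.h5')

# Load the weights into a fresh model of the same architecture
restored = build_model()
restored.load_weights('demo.weights.h5')
```

In the checkpoint workflow above, you would instead pass the path returned by tf.train.latest_checkpoint(checkpoint_dir) to model.load_weights(), rather than a fixed file name.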
Complete Working Example
Here's a complete example with a simple neural network −
import tensorflow as tf
from tensorflow import keras
import numpy as np
import os

# Create sample data
train_images = np.random.random((1000, 28, 28))
train_labels = np.random.randint(0, 10, (1000,))
test_images = np.random.random((200, 28, 28))
test_labels = np.random.randint(0, 10, (200,))

# Create a simple model
model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Set up the checkpoint callback
checkpoint_dir = './training_checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True,
    verbose=1
)

# Train with the callback
model.fit(train_images,
          train_labels,
          epochs=5,
          callbacks=[cp_callback],
          validation_data=(test_images, test_labels))
Other Useful Callbacks
Keras provides several built-in callbacks for different purposes −
# EarlyStopping: stop training when the monitored metric stops improving
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True
)

# LearningRateScheduler: decay the learning rate by 10x every 10 epochs
lr_scheduler = tf.keras.callbacks.LearningRateScheduler(
    lambda epoch: 0.01 * (0.1 ** (epoch // 10))
)

# CSVLogger: write per-epoch metrics to a log file
csv_logger = tf.keras.callbacks.CSVLogger('training.log')

# Multiple callbacks can be passed as a list
callbacks_list = [cp_callback, early_stop, lr_scheduler, csv_logger]
model.fit(train_images, train_labels,
          epochs=50,
          callbacks=callbacks_list,
          validation_data=(test_images, test_labels))
Key Benefits of Using Callbacks
| Callback Type | Purpose | Benefit |
|---|---|---|
| ModelCheckpoint | Save model weights | Prevent loss of training progress |
| EarlyStopping | Stop training when no improvement | Prevent overfitting |
| LearningRateScheduler | Adjust learning rate | Better convergence |
| CSVLogger | Log training metrics | Track training history |
Conclusion
Callbacks in Keras provide powerful hooks for customizing model training. Use ModelCheckpoint to save progress, EarlyStopping to prevent overfitting, and custom callbacks to monitor and control the training process effectively.
