How can Keras be used to train a model with a new callback in Python?

TensorFlow is an open-source machine learning framework developed by Google. It is used with Python to implement algorithms, deep learning applications, and much more. Keras is a high-level deep learning API that runs on top of TensorFlow. It provides the building blocks for defining and training models, including callbacks that customize the training loop.

The 'tensorflow' package can be installed (on Windows and other platforms) with pip using the following command:

pip install tensorflow

Keras ships with the TensorFlow package and can be accessed using:

import tensorflow as tf
from tensorflow import keras

Understanding Callbacks in Keras

Callbacks are objects whose methods Keras calls at specific points during training, such as the start or end of an epoch or batch. They let you customize training behavior, save model checkpoints, monitor metrics, and implement early stopping.
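Because the question is about training with a *new* callback, it may help to see how one is written. The following is a minimal sketch, not the only way to do it. It subclasses `keras.callbacks.Callback` and overrides `on_epoch_end` to record the loss Keras passes in `logs`. The class name `LossHistory` and the tiny random dataset are hypothetical and are used only for illustration.

```python
import numpy as np
from tensorflow import keras


class LossHistory(keras.callbacks.Callback):
    """Hypothetical custom callback that records the training loss after every epoch."""

    def on_train_begin(self, logs=None):
        # Reset the record at the start of each call to fit()
        self.epoch_losses = []

    def on_epoch_end(self, epoch, logs=None):
        # Keras passes the epoch's metrics in `logs`, e.g. {"loss": ...}
        logs = logs or {}
        self.epoch_losses.append(logs.get("loss"))
        print(f"Epoch {epoch + 1}: loss = {logs.get('loss'):.4f}")


# Tiny random dataset, used only to exercise the callback
x = np.random.random((64, 4)).astype("float32")
y = np.random.randint(0, 2, (64,))

model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(8, activation="relu"),
    keras.layers.Dense(2, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

loss_history = LossHistory()
model.fit(x, y, epochs=3, batch_size=16, callbacks=[loss_history], verbose=0)
print(loss_history.epoch_losses)
```

The same pattern works for other hooks such as `on_batch_end` or `on_train_end`. Override only the methods you need.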

Creating a ModelCheckpoint Callback

One of the most commonly used callbacks is ModelCheckpoint. It saves the model's weights (or the full model) during training, at the end of every epoch by default:

import tensorflow as tf
from tensorflow import keras
import glob
import os

# Create the checkpoint directory
checkpoint_dir = './training_checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# One weights file per epoch, e.g. ckpt_01.weights.h5
# (Keras 3 / TensorFlow 2.16+ requires the ".weights.h5" suffix when save_weights_only=True)
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch:02d}.weights.h5")

# Create ModelCheckpoint callback
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True,
    verbose=1
)

Training a Model with the Callback

Pass the callback to fit() through its callbacks argument:

# Assumes a compiled model and train_images, train_labels, test_images, test_labels
print("Training the model with the new callback")

model.fit(train_images,
    train_labels,
    epochs=50,
    callbacks=[cp_callback],
    validation_data=(test_images, test_labels),
    verbose=0)

# List the checkpoint directory contents
print("\nCheckpoint files:")
for file in sorted(os.listdir(checkpoint_dir)):
    print(file)

# Get the latest checkpoint; zero-padded epoch numbers make sorted order match epoch order
# (tf.train.latest_checkpoint reads TensorFlow-format checkpoint state and does not see .weights.h5 files)
print("The latest checkpoint:")
latest = sorted(glob.glob(os.path.join(checkpoint_dir, "ckpt_*.weights.h5")))[-1]
print(latest)
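A saved checkpoint is usually restored by rebuilding the same architecture and calling `load_weights()` on it. The following self-contained sketch makes a few assumptions:

- Keras can save `.weights.h5` files. This holds for Keras 3 and for older tf.keras with HDF5 support.
- Checkpoints go to a throwaway temporary directory.
- `build_model` is a hypothetical helper used only for illustration.

```python
import glob
import os
import tempfile

import numpy as np
from tensorflow import keras


def build_model():
    # Hypothetical helper: the same architecture must be rebuilt before loading weights
    model = keras.Sequential([
        keras.Input(shape=(4,)),
        keras.layers.Dense(8, activation="relu"),
        keras.layers.Dense(2, activation="softmax"),
    ])
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
    return model


x = np.random.random((64, 4)).astype("float32")
y = np.random.randint(0, 2, (64,))

checkpoint_dir = tempfile.mkdtemp()  # throwaway directory for this sketch
cp_callback = keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(checkpoint_dir, "ckpt_{epoch:02d}.weights.h5"),
    save_weights_only=True,
)

trained = build_model()
trained.fit(x, y, epochs=3, verbose=0, callbacks=[cp_callback])

# Zero-padded epoch numbers make lexical order match epoch order
latest = sorted(glob.glob(os.path.join(checkpoint_dir, "ckpt_*.weights.h5")))[-1]

restored = build_model()
restored.load_weights(latest)

# The last checkpoint was written after the final epoch, so predictions should match
same = np.allclose(trained.predict(x, verbose=0), restored.predict(x, verbose=0))
print(os.path.basename(latest), same)
```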

Complete Working Example

Here's a complete example with a simple neural network:

import tensorflow as tf
from tensorflow import keras
import numpy as np
import os

# Create sample data
train_images = np.random.random((1000, 28, 28))
train_labels = np.random.randint(0, 10, (1000,))
test_images = np.random.random((200, 28, 28))
test_labels = np.random.randint(0, 10, (200,))

# Create a simple model
model = keras.Sequential([
    keras.Input(shape=(28, 28)),
    keras.layers.Flatten(),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Set up checkpoint callback
checkpoint_dir = './training_checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)
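# Keras 3 (TensorFlow 2.16+) requires the ".weights.h5" suffix when save_weights_only=True
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch:02d}.weights.h5")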

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True,
    verbose=1
)

# Train with callback
model.fit(train_images,
    train_labels,
    epochs=5,
    callbacks=[cp_callback],
    validation_data=(test_images, test_labels))
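Saving every epoch can fill a directory quickly. ModelCheckpoint can instead keep only the best weights seen so far, judged by a monitored metric. The sketch below assumes:

- `val_loss` is the metric to minimize.
- A smaller random dataset stands in for real data.
- The single file `best.weights.h5` lives in a temporary directory.

Both file names are illustrative only.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras

x_train = np.random.random((200, 28, 28)).astype("float32")
y_train = np.random.randint(0, 10, (200,))
x_val = np.random.random((50, 28, 28)).astype("float32")
y_val = np.random.randint(0, 10, (50,))

model = keras.Sequential([
    keras.Input(shape=(28, 28)),
    keras.layers.Flatten(),
    keras.layers.Dense(32, activation="relu"),
    keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

best_path = os.path.join(tempfile.mkdtemp(), "best.weights.h5")
best_cp = keras.callbacks.ModelCheckpoint(
    filepath=best_path,
    monitor="val_loss",     # quantity to watch
    mode="min",             # lower val_loss is better
    save_best_only=True,    # overwrite the file only when val_loss improves
    save_weights_only=True,
    verbose=1,
)
history = model.fit(x_train, y_train, epochs=4, verbose=0,
                    validation_data=(x_val, y_val), callbacks=[best_cp])
print(os.path.exists(best_path), min(history.history["val_loss"]))
```

Because the first epoch always improves on the initial "best" value, at least one file is written.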

Other Useful Callbacks

Keras provides several built-in callbacks for different purposes:

# Early Stopping callback
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True
)

# Learning Rate Scheduler
lr_scheduler = tf.keras.callbacks.LearningRateScheduler(
    lambda epoch: 0.01 * (0.1 ** (epoch // 10))
)

# CSV Logger
csv_logger = tf.keras.callbacks.CSVLogger('training.log')

# Use multiple callbacks
callbacks_list = [cp_callback, early_stop, lr_scheduler, csv_logger]

model.fit(train_images, train_labels,
    epochs=50,
    callbacks=callbacks_list,
    validation_data=(test_images, test_labels))
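The learning-rate lambda above is a step decay: it starts at 0.01 and divides by 10 every 10 epochs. The sketch below names that rule `step_decay` (a hypothetical name), previews it in plain Python, and then trains a small model on random data with EarlyStopping and LearningRateScheduler together. Whether and when EarlyStopping fires depends on the random data, so the run length is not fixed.

```python
import numpy as np
from tensorflow import keras


def step_decay(epoch):
    """Same rule as the lambda above: start at 0.01, divide by 10 every 10 epochs."""
    return 0.01 * (0.1 ** (epoch // 10))


# Preview the schedule without training anything
for epoch in (0, 9, 10, 20):
    print(epoch, step_decay(epoch))

x = np.random.random((128, 8)).astype("float32")
y = np.random.randint(0, 2, (128,))
x_val = np.random.random((32, 8)).astype("float32")
y_val = np.random.randint(0, 2, (32,))

model = keras.Sequential([
    keras.Input(shape=(8,)),
    keras.layers.Dense(16, activation="relu"),
    keras.layers.Dense(2, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

early_stop = keras.callbacks.EarlyStopping(monitor="val_loss", patience=2,
                                           restore_best_weights=True)
history = model.fit(x, y, epochs=30, verbose=0,
                    validation_data=(x_val, y_val),
                    callbacks=[early_stop,
                               keras.callbacks.LearningRateScheduler(step_decay)])

epochs_run = len(history.history["loss"])
# stopped_epoch stays 0 if patience was never exhausted
print("epochs run:", epochs_run, "stopped_epoch:", early_stop.stopped_epoch)
```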

Key Benefits of Using Callbacks

| Callback | Purpose | Benefit |
|---|---|---|
| ModelCheckpoint | Save model weights | Prevents loss of training progress |
| EarlyStopping | Stop training when the monitored metric stops improving | Helps prevent overfitting |
| LearningRateScheduler | Adjust the learning rate each epoch | Can improve convergence |
| CSVLogger | Log per-epoch metrics to a CSV file | Tracks training history |
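To make use of the history CSVLogger records, the log can be read back with Python's standard `csv` module. This is a minimal sketch that makes three assumptions:

- CSVLogger uses its default comma separator.
- It writes a header row followed by one row per epoch.
- The log goes to a temporary file rather than `training.log`.

```python
import csv
import os
import tempfile

import numpy as np
from tensorflow import keras

x = np.random.random((64, 4)).astype("float32")
y = np.random.randint(0, 2, (64,))

model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(8, activation="relu"),
    keras.layers.Dense(2, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

log_path = os.path.join(tempfile.mkdtemp(), "training.log")
# append=False (the default) overwrites any existing log
model.fit(x, y, epochs=3, verbose=0,
          callbacks=[keras.callbacks.CSVLogger(log_path)])

# Header row first, then one row per epoch (epochs are numbered from 0)
with open(log_path, newline="") as f:
    rows = list(csv.DictReader(f))

for row in rows:
    print(row["epoch"], row["loss"], row["accuracy"])
```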

Conclusion

Callbacks in Keras provide powerful customization for model training. Use ModelCheckpoint to save progress, EarlyStopping to prevent overfitting, and other callbacks to monitor and control the training process effectively.

Updated on: 2026-03-25T15:41:53+05:30
