keras.fit() and keras.fit_generator()

Keras provides two powerful methods for training neural networks: fit() and fit_generator(). The fit() method is ideal for smaller datasets that can fit in memory, while fit_generator() handles large datasets by processing data in batches dynamically.

Understanding Keras Training Methods

Keras is a high-level neural networks API that simplifies deep learning model development. When training models, you need efficient methods to handle different dataset sizes and memory constraints. These two methods provide flexibility for various training scenarios.

The fit() Method

The fit() method is Keras' standard approach for model training. It loads the entire dataset into memory and processes it in specified batch sizes across multiple epochs.

Syntax

model.fit(x=x_train, y=y_train, batch_size=batch_size, epochs=epochs, validation_data=(x_val, y_val))

Key Parameters

  • x_train: Training input data

  • y_train: Training target labels

  • batch_size: Number of samples processed before each weight update

  • epochs: Number of complete passes through the dataset

  • validation_data: Data for evaluating model performance after each epoch
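The batch_size and epochs parameters together determine how many weight updates training performs. As a quick sanity check (using the same sizes as the example below, 1000 samples and a batch size of 32), the number of batches per epoch is the dataset size divided by the batch size, rounded up:

```python
import math

# With fit(), each epoch processes the whole dataset in batches,
# so the number of weight updates per epoch is
# ceil(num_samples / batch_size).
num_samples = 1000
batch_size = 32
epochs = 5

steps_per_epoch = math.ceil(num_samples / batch_size)
total_updates = steps_per_epoch * epochs

print(steps_per_epoch)  # 32 batches per epoch
print(total_updates)    # 160 weight updates in total
```

This matches the "32/32" step counter Keras prints in the training log below.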

Example Using fit()

import tensorflow as tf
from tensorflow import keras
import numpy as np

# Create a simple neural network
model = keras.Sequential([
    keras.layers.Dense(units=64, activation='relu', input_dim=20),
    keras.layers.Dense(units=1, activation='sigmoid')
])

# Compile the model
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# Generate sample data
x_train = np.random.random((1000, 20))
y_train = np.random.randint(2, size=(1000, 1))

# Train using fit()
history = model.fit(x_train, y_train, epochs=5, batch_size=32, verbose=1)

Output:

Epoch 1/5
32/32 [==============================] - 1s 2ms/step - loss: 0.6954 - accuracy: 0.4980
Epoch 2/5
32/32 [==============================] - 0s 2ms/step - loss: 0.6936 - accuracy: 0.5010
Epoch 3/5
32/32 [==============================] - 0s 2ms/step - loss: 0.6924 - accuracy: 0.5030
Epoch 4/5
32/32 [==============================] - 0s 2ms/step - loss: 0.6915 - accuracy: 0.5050
Epoch 5/5
32/32 [==============================] - 0s 2ms/step - loss: 0.6908 - accuracy: 0.5070

The fit_generator() Method

The fit_generator() method processes data in batches using a generator function. This approach is memory-efficient for large datasets and enables real-time data augmentation.

Syntax

model.fit_generator(generator=train_generator, steps_per_epoch=steps_per_epoch, epochs=epochs, validation_data=val_generator, validation_steps=val_steps)

Key Parameters

  • generator: Generator yielding batches of training data

  • steps_per_epoch: Number of batches per epoch

  • validation_data: Generator for validation data

  • validation_steps: Number of validation batches per epoch

Example Using fit_generator()

import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

# Create sample data
x_train = np.random.random((1000, 20))
y_train = np.random.randint(2, size=(1000, 1))

# Define a simple data generator
def data_generator(x, y, batch_size):
    while True:
        indices = np.random.choice(len(x), batch_size)
        yield x[indices], y[indices]

# Create the model
model = Sequential([
    Dense(64, activation='relu', input_dim=20),
    Dense(1, activation='sigmoid')
])

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Create generator
train_gen = data_generator(x_train, y_train, batch_size=32)

# Train using fit_generator()
history = model.fit_generator(
    generator=train_gen,
    steps_per_epoch=30,
    epochs=3,
    verbose=1
)

Output:

Epoch 1/3
30/30 [==============================] - 1s 3ms/step - loss: 0.6942 - accuracy: 0.5021
Epoch 2/3
30/30 [==============================] - 0s 2ms/step - loss: 0.6928 - accuracy: 0.5052
Epoch 3/3
30/30 [==============================] - 0s 2ms/step - loss: 0.6919 - accuracy: 0.5073
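Note that the generator above loops forever; it is steps_per_epoch that tells Keras when an epoch ends. The sketch below (plain NumPy, no TensorFlow required) mimics how Keras consumes one epoch from such an infinite generator, using itertools.islice to stop after steps_per_epoch batches:

```python
import itertools
import numpy as np

def data_generator(x, y, batch_size):
    # Same infinite generator as above: loops forever,
    # yielding one random batch per call to next().
    while True:
        indices = np.random.choice(len(x), batch_size)
        yield x[indices], y[indices]

x = np.random.random((1000, 20))
y = np.random.randint(2, size=(1000, 1))

gen = data_generator(x, y, batch_size=32)

# Keras stops after steps_per_epoch batches; we can mimic
# one epoch by slicing the infinite stream.
steps_per_epoch = 30
one_epoch = list(itertools.islice(gen, steps_per_epoch))

print(len(one_epoch))          # 30 batches
print(one_epoch[0][0].shape)   # (32, 20)
```

Because batches are drawn at random with replacement, 30 steps of 32 samples cover roughly, but not exactly, the 1000-sample dataset per epoch; that trade-off is typical of simple sampling generators.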

Comparison

Aspect             | fit()                    | fit_generator()
Memory Usage       | Loads entire dataset     | Processes batches dynamically
Best For           | Small to medium datasets | Large datasets, data augmentation
Data Augmentation  | Limited                  | Real-time augmentation
Setup Complexity   | Simple                   | Requires generator function

Important Note

fit_generator() was deprecated in TensorFlow 2.1 and has been removed in later releases. Pass the generator to model.fit() directly, which accepts Python generators and keras.utils.Sequence objects:

# Modern approach (TensorFlow 2.1+)
model.fit(train_gen, steps_per_epoch=30, epochs=3)
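In current TensorFlow, the recommended replacement for a hand-written infinite generator is a keras.utils.Sequence subclass (or a tf.data.Dataset). The sketch below shows the two methods the Sequence interface requires, __len__ and __getitem__, implemented with plain NumPy so the batching logic is visible without TensorFlow installed; in real code the class would subclass tf.keras.utils.Sequence:

```python
import math
import numpy as np

class BatchSequence:
    """Sketch of the keras.utils.Sequence interface. In real
    code this would subclass tf.keras.utils.Sequence; only the
    two required methods are shown here."""

    def __init__(self, x, y, batch_size):
        self.x, self.y = x, y
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch (what steps_per_epoch
        # would otherwise specify).
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        # Return batch number idx as an (inputs, targets) pair.
        lo = idx * self.batch_size
        hi = lo + self.batch_size
        return self.x[lo:hi], self.y[lo:hi]

x_train = np.random.random((1000, 20))
y_train = np.random.randint(2, size=(1000, 1))

seq = BatchSequence(x_train, y_train, batch_size=32)
print(len(seq))          # 32 batches (ceil(1000 / 32))
print(seq[0][0].shape)   # (32, 20)
print(seq[31][0].shape)  # (8, 20) -- last, partial batch
```

With a real Sequence subclass, model.fit(seq, epochs=3) works without steps_per_epoch, because Keras reads the batch count from __len__.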

Conclusion

Choose fit() for smaller datasets that fit in memory. Use generators with fit() for large datasets or real-time data augmentation. The generator approach provides memory efficiency and flexibility for complex training scenarios.

Updated on: 2026-03-27T15:04:45+05:30
