How can Keras be used to plot the model as a graph and display input and output shapes using Python?
Keras can render a neural network model as a graph image. The plot_model utility in keras.utils draws each layer as a node and each connection between layers as an edge. It can optionally label the nodes with tensor shapes and data types, which makes complex architectures easier to understand.
Prerequisites
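Before plotting models, install the required dependencies. plot_model builds the graph with the pydot package and renders it with the Graphviz dot program. Graphviz is a system package, so installing it through pip alone is not enough:
# Install the required Python packages
pip install tensorflow pydot
# Install the Graphviz binaries (choose the command for your platform)
sudo apt-get install graphviz  # Debian/Ubuntu
brew install graphviz  # macOS with Homebrew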
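Before running the examples, a quick standard-library check can confirm that the dependencies are reachable. This is a minimal sketch: the helper name check_plot_dependencies is hypothetical, and it assumes plot_model needs the pydot package plus the Graphviz dot executable on PATH.

```python
import importlib.util
import shutil


def check_plot_dependencies():
    """Report whether the pieces plot_model relies on appear to be available (hypothetical helper)."""
    return {
        # Python packages, located by import name without actually importing them
        "tensorflow": importlib.util.find_spec("tensorflow") is not None,
        "pydot": importlib.util.find_spec("pydot") is not None,
        # Assumption: the Graphviz 'dot' executable must be on PATH to render images
        "graphviz_dot": shutil.which("dot") is not None,
    }


if __name__ == "__main__":
    for name, ok in check_plot_dependencies().items():
        print(f"{name:>13}: {'found' if ok else 'MISSING'}")
```

If any entry reports MISSING, fix it before calling plot_model, since rendering fails without both pydot and Graphviz.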
Creating a Sample Model
Let's create a simple Sequential model to demonstrate the plotting features:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
# Create a simple sequential model
model = keras.Sequential([
layers.Dense(64, activation='relu', input_shape=(32,)),
layers.Dropout(0.3),
layers.Dense(32, activation='relu'),
layers.Dense(10, activation='softmax')
])
print("Model created successfully")
print(f"Total parameters: {model.count_params()}")
Model created successfully
Total parameters: 4522
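The total of 4522 can be checked by hand: a Dense layer has one weight per input-unit pair plus one bias per unit, and Dropout adds no parameters. The plain-Python sketch below needs no TensorFlow; dense_params is a hypothetical helper used only for this arithmetic.

```python
def dense_params(inputs: int, units: int) -> int:
    """Weights (inputs x units) plus one bias per unit."""
    return inputs * units + units


# (input features, units) for each Dense layer in the model above;
# Dropout has no trainable weights, so it contributes nothing.
dense_layers = [(32, 64), (64, 32), (32, 10)]

per_layer = [dense_params(i, u) for i, u in dense_layers]
total = sum(per_layer)

print(per_layer)  # [2112, 2080, 330]
print(total)      # 4522
```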
Basic Model Plotting
Call plot_model with the model and an output file path to generate a basic diagram of the model structure. The image format is taken from the file extension:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
# Create model
model = keras.Sequential([
layers.Dense(64, activation='relu', input_shape=(32,)),
layers.Dropout(0.3),
layers.Dense(32, activation='relu'),
layers.Dense(10, activation='softmax')
])
# Plot basic model structure
keras.utils.plot_model(model, "basic_model.png")
print("Basic model plot saved as 'basic_model.png'")
Basic model plot saved as 'basic_model.png'
Displaying Input and Output Shapes
Pass show_shapes=True to annotate every layer with its input and output shapes:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
# Create model
model = keras.Sequential([
layers.Dense(64, activation='relu', input_shape=(32,)),
layers.Dropout(0.3),
layers.Dense(32, activation='relu'),
layers.Dense(10, activation='softmax')
])
# Plot model with shapes
keras.utils.plot_model(
model,
"model_with_shapes.png",
show_shapes=True,
show_layer_names=True
)
print("Model plot with shapes saved as 'model_with_shapes.png'")
Model plot with shapes saved as 'model_with_shapes.png'
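With show_shapes=True, each node shows its input and output shape, where None stands for the unspecified batch size. As a rough plain-Python sketch (the helper dense_stack_shapes and the layer labels are hypothetical, and it only models Dense and Dropout), the shapes for this model propagate like this:

```python
def dense_stack_shapes(input_dim, stack):
    """Return (label, input_shape, output_shape) per layer; None is the batch axis."""
    rows = []
    shape = (None, input_dim)
    for label, units in stack:
        # Dense replaces the last axis with `units`; Dropout (units=None) keeps the shape
        out = shape if units is None else shape[:-1] + (units,)
        rows.append((label, shape, out))
        shape = out
    return rows


model_stack = [("Dense(64)", 64), ("Dropout(0.3)", None), ("Dense(32)", 32), ("Dense(10)", 10)]
rows = dense_stack_shapes(32, model_stack)
for label, inp, out in rows:
    print(f"{label:<13} input: {inp}  output: {out}")
```

The last row ends at (None, 10), matching the 10-unit softmax output layer.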
Advanced Plotting Options
Further arguments add data types, change the layout direction, expand nested models, and set the image resolution:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
# Create a functional API model for demonstration
inputs = keras.Input(shape=(32,))
x = layers.Dense(64, activation='relu')(inputs)
x = layers.Dropout(0.3)(x)
x = layers.Dense(32, activation='relu')(x)
outputs = layers.Dense(10, activation='softmax')(x)
model = keras.Model(inputs=inputs, outputs=outputs)
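# Advanced plotting with all options
keras.utils.plot_model(
model,
"detailed_model.png",
show_shapes=True,  # input/output shapes
show_dtype=True,  # tensor data types, e.g. float32
show_layer_names=True,  # layer names
rankdir='TB',  # Top to Bottom direction; 'LR' gives left to right
expand_nested=True,  # expand nested models (this model has none, so no visible effect)
dpi=150  # image resolution in dots per inch
)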
print("Detailed model plot saved")
print("Plot shows layer names, shapes, and data types")
Detailed model plot saved
Plot shows layer names, shapes, and data types
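plot_model converts the model into a Graphviz graph (through pydot) and renders it with the dot program, which is why rankdir accepts Graphviz direction values. The sketch below is not Keras code: it is a hypothetical, standard-library stack_to_dot helper that writes comparable DOT text for a linear stack of layers, to illustrate where rankdir and the shape labels end up.

```python
def stack_to_dot(stack, rankdir="TB"):
    """Build Graphviz DOT text for a linear stack of (label, output_shape) pairs."""
    if rankdir not in {"TB", "BT", "LR", "RL"}:
        raise ValueError(f"unsupported rankdir: {rankdir!r}")
    lines = ["digraph model {", f"  rankdir={rankdir};", "  node [shape=record];"]
    for i, (label, shape) in enumerate(stack):
        # One record node per layer: label in the first field, output shape in the second
        lines.append(f'  n{i} [label="{label}|{shape}"];')
    for i in range(len(stack) - 1):
        # Edges follow the forward pass from each layer to the next
        lines.append(f"  n{i} -> n{i + 1};")
    lines.append("}")
    return "\n".join(lines)


stack = [("InputLayer", (None, 32)), ("Dense", (None, 64)), ("Dropout", (None, 64)),
         ("Dense", (None, 32)), ("Dense", (None, 10))]
dot_text = stack_to_dot(stack, rankdir="LR")
print(dot_text)
# To render, save dot_text as model.dot and run the Graphviz CLI:
#   dot -Tpng model.dot -o model.png
```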
Plotting Options Summary
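| Parameter | Description | Default (tf.keras, TensorFlow 2.x) |
|---|---|---|
| show_shapes | Display input/output shapes | False |
| show_dtype | Display data types | False |
| show_layer_names | Display layer names | True |
| rankdir | Layout direction: 'TB' (top to bottom), 'BT', 'LR' (left to right), 'RL' | 'TB' |
| expand_nested | Expand nested models into their layers | False |
| dpi | Image resolution (dots per inch) | 96 |

In Keras 3, show_layer_names defaults to False and dpi to 200, so pass these explicitly if the diagram should look the same across versions.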
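Because defaults can change between Keras releases, one option is to pass every setting explicitly. The hypothetical plot_options helper below (not part of Keras) starts from the tf.keras 2.x defaults, applies overrides, and rejects unknown names or an invalid rankdir before the dictionary is unpacked into plot_model.

```python
VALID_RANKDIRS = ("TB", "BT", "LR", "RL")


def plot_options(**overrides):
    """Return an explicit, complete set of plot_model keyword arguments (hypothetical helper)."""
    options = {
        "show_shapes": False,
        "show_dtype": False,
        "show_layer_names": True,
        "rankdir": "TB",
        "expand_nested": False,
        "dpi": 96,
    }
    # Catch typos such as show_shape=True instead of silently ignoring them
    unknown = set(overrides) - set(options)
    if unknown:
        raise TypeError(f"unknown plot option(s): {sorted(unknown)}")
    options.update(overrides)
    if options["rankdir"] not in VALID_RANKDIRS:
        raise ValueError(f"rankdir must be one of {VALID_RANKDIRS}")
    return options


opts = plot_options(show_shapes=True, rankdir="LR")
print(opts)
# With TensorFlow installed and a model defined:
#   keras.utils.plot_model(model, "model_lr.png", **opts)
```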
Conclusion
The Keras plot_model utility is a convenient way to visualize neural network architectures. Use show_shapes=True to display tensor dimensions, and combine it with options such as show_dtype, rankdir, and dpi to produce clearer diagrams for documenting and understanding your models.
