How can the functional API be used to work with residual connections in Python?

The Keras functional API provides powerful tools for building complex neural network architectures with residual connections. Unlike sequential models, the functional API lets you create models with skip connections that bypass layers and merge branches at different points in the graph.

What are Residual Connections?

Residual connections (skip connections) allow the output of an earlier layer to be added directly to the output of a later layer. This helps solve the vanishing gradient problem in deep networks and enables training of very deep models like ResNet.

Diagram: Input → Conv2D → Conv2D → Add → Output, with a skip connection carrying the input directly to the Add operation.
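The idea can be sketched in plain NumPy before bringing in Keras (here `F` is a toy stand-in for whatever transformation the convolutional layers learn):

```python
import numpy as np

def residual_block(x, F):
    """Apply a transformation F, then add the original input back (skip connection)."""
    return F(x) + x

x = np.array([1.0, 2.0, 3.0])

# If F learns nothing useful (outputs zeros), the block reduces to the identity:
identity_out = residual_block(x, lambda t: np.zeros_like(t))
print(identity_out)  # [1. 2. 3.]

# With a non-trivial F, the original signal is still carried through:
scaled_out = residual_block(x, lambda t: 0.5 * t)
print(scaled_out)  # [1.5 3.  4.5]
```

This is exactly what `layers.add()` does in the Keras model below, just expressed on tensors inside a computation graph instead of NumPy arrays.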

Setting Up Keras

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

Building a ResNet Model with Residual Connections

Here's how to create a toy ResNet model for CIFAR-10 using the functional API:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

print("Toy ResNet model for CIFAR10")
print("Layers generated for model")

# Input layer
inputs = keras.Input(shape=(32, 32, 3), name="img")

# Initial convolution layers
x = layers.Conv2D(32, 3, activation="relu")(inputs)
x = layers.Conv2D(64, 3, activation="relu")(x)
block_1_output = layers.MaxPooling2D(3)(x)

# First residual block
x = layers.Conv2D(64, 3, activation="relu", padding="same")(block_1_output)
x = layers.Conv2D(64, 3, activation="relu", padding="same")(x)
block_2_output = layers.add([x, block_1_output])  # Skip connection

# Second residual block
x = layers.Conv2D(64, 3, activation="relu", padding="same")(block_2_output)
x = layers.Conv2D(64, 3, activation="relu", padding="same")(x)
block_3_output = layers.add([x, block_2_output])  # Skip connection

# Output layers
x = layers.Conv2D(64, 3, activation="relu")(block_3_output)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(256, activation="relu")(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(10)(x)

# Create model
model = keras.Model(inputs, outputs, name="toy_resnet")
print("More information about the model")
model.summary()
Output

Toy ResNet model for CIFAR10
Layers generated for model
More information about the model
Model: "toy_resnet"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
img (InputLayer)             [(None, 32, 32, 3)]       0         
_________________________________________________________________
conv2d (Conv2D)              (None, 30, 30, 32)        896       
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 28, 28, 64)        18496     
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 9, 9, 64)          0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 9, 9, 64)          36928     
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 9, 9, 64)          36928     
_________________________________________________________________
add (Add)                    (None, 9, 9, 64)          0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 9, 9, 64)          36928     
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 9, 9, 64)          36928     
_________________________________________________________________
add_1 (Add)                  (None, 9, 9, 64)          0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 7, 7, 64)          36928     
_________________________________________________________________
global_average_pooling2d (Gl (None, 64)                0         
_________________________________________________________________
dense (Dense)                (None, 256)               16640     
_________________________________________________________________
dropout (Dropout)            (None, 256)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 10)                2570      
=================================================================
Total params: 223,242
Trainable params: 223,242
Non-trainable params: 0
_________________________________________________________________
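Once the model is defined, training follows the usual compile/fit workflow. The sketch below uses a miniature model with the same residual pattern, and random tensors as a stand-in for CIFAR-10 (swap in `keras.datasets.cifar10.load_data()` for real training); the optimizer and learning rate are illustrative choices:

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# A miniature residual model following the same pattern as the toy ResNet above.
inputs = keras.Input(shape=(32, 32, 3))
x = layers.Conv2D(8, 3, activation="relu", padding="same")(inputs)
block_out = layers.add([layers.Conv2D(8, 3, activation="relu", padding="same")(x), x])
x = layers.GlobalAveragePooling2D()(block_out)
outputs = layers.Dense(10)(x)
model = keras.Model(inputs, outputs)

# from_logits=True because the final Dense layer has no softmax activation.
model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

# Random data stands in for the CIFAR-10 images and labels here.
x_train = np.random.rand(16, 32, 32, 3).astype("float32")
y_train = np.random.randint(0, 10, size=(16,))
history = model.fit(x_train, y_train, epochs=1, verbose=0)
print(history.history["loss"][0])
```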

Key Components of Residual Connections

  • Skip Connections: The layers.add() function combines the output of the current layer with the output from an earlier layer

  • Identity Mapping: The skip connection preserves the original input, allowing gradients to flow more easily during backpropagation

  • Non-linear Topology: Unlike sequential models, residual networks have branches that merge at specific points

  • Same Padding: Ensures that feature maps have compatible dimensions for addition operations
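The shapes in the summary above follow directly from standard convolution arithmetic; a quick check (assuming stride 1 for the convolutions and the default pool stride equal to the pool size):

```python
def conv_out(size, kernel, padding):
    """Spatial output size of a stride-1 convolution."""
    return size if padding == "same" else size - kernel + 1

def pool_out(size, pool):
    """Output size of MaxPooling2D with its default stride == pool size."""
    return size // pool

s = conv_out(32, 3, "valid")   # first Conv2D:  32 -> 30
s = conv_out(s, 3, "valid")    # second Conv2D: 30 -> 28
s = pool_out(s, 3)             # MaxPooling2D(3): 28 -> 9
print(s)  # 9

# Inside the residual blocks, padding="same" keeps the size at 9,
# so layers.add() can combine tensors of identical shape.
print(conv_out(s, 3, "same"))  # 9
```

This is why the residual-block convolutions use `padding="same"` while the others do not: without it, the two inputs to `layers.add()` would have mismatched spatial dimensions and the addition would fail.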

Why Use Residual Connections?

Residual connections solve several problems in deep learning:

  • Vanishing Gradient Problem: Skip connections provide direct paths for gradients to flow backward

  • Deeper Networks: Enable training of networks with hundreds of layers

  • Identity Function: If a layer learns nothing useful, the skip connection preserves the input unchanged

  • Faster Convergence: Networks typically train faster and achieve better performance
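The gradient argument can be made concrete: for y = F(x) + x, the derivative is F'(x) + 1, so even when F'(x) is nearly zero the gradient never vanishes. A small finite-difference check, with a toy F chosen purely for illustration:

```python
def finite_diff(f, x, eps=1e-6):
    """Numerical derivative of f at x via central differences."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)

# A transformation whose own gradient is nearly zero at x = 0.
F = lambda x: 0.001 * x ** 3

plain = finite_diff(F, 0.0)                      # gradient through F alone
residual = finite_diff(lambda x: F(x) + x, 0.0)  # gradient through F(x) + x

print(plain)     # ~0.0 -- this path alone would starve earlier layers
print(residual)  # ~1.0 -- the skip connection keeps the gradient flowing
```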

Conclusion

The Keras functional API makes it straightforward to implement residual connections using layers.add(). These skip connections enable deeper networks to train effectively and achieve better performance than traditional sequential architectures.

Updated on: 2026-03-25T14:50:17+05:30
