How to apply a 2D Max Pooling in PyTorch?

We can apply 2D Max Pooling over an input image composed of several input planes using the torch.nn.MaxPool2d() module. The input to a 2D Max Pool layer must be of size [N,C,H,W], or [C,H,W] for a single unbatched sample, where N is the batch size, C is the number of channels, and H and W are the height and width of the input image, respectively.
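As a quick check of the accepted input layouts, the following minimal sketch passes a batched and an unbatched tensor through the same layer. The tensor sizes and variable names are arbitrary choices for illustration, and the unbatched [C,H,W] form assumes a reasonably recent PyTorch release.

```python
import torch
import torch.nn as nn

# 2x2 window; when stride is omitted it defaults to kernel_size
pool = nn.MaxPool2d(kernel_size=2)

batched = torch.rand(2, 3, 8, 8)   # [N, C, H, W]
unbatched = torch.rand(3, 8, 8)    # [C, H, W]

out_batched = pool(batched)
out_unbatched = pool(unbatched)

print("Batched output size:", out_batched.size())
print("Unbatched output size:", out_unbatched.size())
```

In both cases only the last two dimensions (H and W) shrink; the batch and channel dimensions pass through unchanged.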

The main parameters of a Max Pool operation are the filter (kernel) size and the stride.




  • kernel_size – The size of the window to take a max over.

Along with this required parameter, there are optional parameters such as stride, padding, and dilation. The Python examples below show kernel_size and stride in use.
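To make the effect of the optional parameters concrete, here is a minimal sketch that compares the layer's output size against the usual formula, H_out = floor((H + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1, and likewise for W. The helper pooled_size is a hypothetical name introduced only for this illustration, and the sketch assumes the default ceil_mode=False.

```python
import torch
import torch.nn as nn

def pooled_size(size, kernel_size, stride=None, padding=0, dilation=1):
    """Hypothetical helper: output length along one spatial dimension."""
    if stride is None:
        stride = kernel_size  # MaxPool2d uses kernel_size when stride is omitted
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

x = torch.rand(1, 1, 9, 9)  # one 9x9 plane, values are irrelevant here

configs = [
    dict(kernel_size=2),
    dict(kernel_size=3, stride=2, padding=1),
    dict(kernel_size=3, stride=1, dilation=2),
]

results = []
for cfg in configs:
    out = nn.MaxPool2d(**cfg)(x)
    expected = pooled_size(9, **cfg)
    results.append((tuple(out.shape[-2:]), expected))
    print(cfg, "-> output size:", tuple(out.shape), "| formula:", expected)
```

Padding adds implicit border values before pooling, while dilation spreads the window elements apart, so both change how many window positions fit along each dimension.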


You can use the following steps to apply 2D Max Pooling:

  • Import the required library. In all the following examples, the required Python library is torch. Make sure you have already installed it. To apply 2D Max Pooling on images we need torchvision and Pillow as well.

import torch
import torchvision
from PIL import Image
  • Define input tensor or read the input image. If an input is an image, then we first convert it into a torch tensor.

  • Define kernel_size, stride and other parameters.

  • Next, define a Max Pooling layer by passing the above-defined parameters to torch.nn.MaxPool2d().

pooling = torch.nn.MaxPool2d(kernel_size)
  • Apply the Max Pooling layer to the input tensor or the image tensor.

output = pooling(input)
  • Next, print the tensor after Max Pooling. If the input was an image tensor, first convert the tensor obtained after Max Pooling to a PIL image, and then visualize the image.
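The steps above can be sketched end to end on a small tensor. This is a minimal illustration rather than a full program: the 4x4 input, the window size, and the variable names are assumptions chosen so that the result is easy to check by eye.

```python
import torch
import torch.nn as nn

# Step 1: define the input tensor, here [N, C, H, W] = [1, 1, 4, 4]
# with values 0..15 laid out row by row
input = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)

# Step 2: define kernel_size and stride
kernel_size = 2
stride = 2

# Step 3: define the Max Pooling layer
pooling = nn.MaxPool2d(kernel_size, stride=stride)

# Step 4: apply it to the input tensor
output = pooling(input)

# Step 5: print the result
print("Output Tensor:\n", output)
print("Output Size:", output.size())
```

Each non-overlapping 2x2 block of the input is reduced to its largest value, so the 4x4 plane becomes a 2x2 plane.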

Let's take a couple of examples to have a better understanding of how it works.

Input Image

We will use the following image as the input file in Example 2.

Example 1

In the following Python example, we perform 2D Max Pooling on an input tensor. We apply two combinations of kernel_size and stride: a square window and a non-square window.

# Python 3 program to perform 2D Max Pooling
# Import the required libraries
import torch
import torch.nn as nn

'''input of size = [N,C,H, W] or [C,H, W]
N==>batch size,
C==> number of channels,
H==> height of input planes in pixels,
W==> width in pixels.
'''
input = torch.empty(3, 4, 4).random_(256)
print("Input Tensor:\n", input)
print("Input Size:", input.size())

# pool of square window of size=3, stride=1
pooling1 = nn.MaxPool2d(3, stride=1)

# Perform Max Pool
output = pooling1(input)
print("Output Tensor:\n", output)
print("Output Size:", output.size())

# pool of non-square window
pooling2 = nn.MaxPool2d((2, 1), stride=(1, 2))

# Perform Max Pool
output = pooling2(input)
print("Output Tensor:\n", output)
print("Output Size:", output.size())


Input Tensor:
   tensor([[[129., 61., 166., 156.],
      [130., 5., 15., 73.],
      [ 73., 173., 146., 11.],
      [ 62., 103., 118., 50.]],

      [[ 35., 147., 95., 127.],
      [ 79., 15., 109., 27.],
      [105., 51., 157., 137.],
      [142., 187., 95., 240.]],

      [[ 60., 36., 195., 167.],
      [181., 207., 244., 71.],
      [172., 242., 13., 228.],
      [144., 238., 222., 174.]]])
Input Size: torch.Size([3, 4, 4])
Output Tensor:
   tensor([[[173., 173.],
      [173., 173.]],

      [[157., 157.],
      [187., 240.]],

      [[244., 244.],
      [244., 244.]]])
Output Size: torch.Size([3, 2, 2])
Output Tensor:
   tensor([[[130., 166.],
      [130., 146.],
      [ 73., 146.]],

      [[ 79., 109.],
      [105., 157.],
      [142., 157.]],

      [[181., 244.],
      [181., 244.],
      [172., 222.]]])
Output Size: torch.Size([3, 3, 2])
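To see where these values come from, the following sketch recomputes the first channel of both outputs by hand, using the first channel of the input tensor printed above. The slicing loops are only an illustration of what the layer computes, not how PyTorch implements it.

```python
import torch
import torch.nn as nn

# First channel of the input tensor printed above
channel0 = torch.tensor([[129.,  61., 166., 156.],
                         [130.,   5.,  15.,  73.],
                         [ 73., 173., 146.,  11.],
                         [ 62., 103., 118.,  50.]])

# 3x3 window, stride 1: each output value is the max of one 3x3 patch
manual_square = torch.stack([
    torch.stack([channel0[i:i + 3, j:j + 3].max() for j in range(2)])
    for i in range(2)
])

# (2, 1) window, stride (1, 2): max over two vertically adjacent values,
# moving one row down and two columns across at a time
manual_rect = torch.stack([
    torch.stack([channel0[i:i + 2, 2 * j:2 * j + 1].max() for j in range(2)])
    for i in range(3)
])

# Compare with the layers used in Example 1 (add a channel dimension first)
layer_square = nn.MaxPool2d(3, stride=1)(channel0.unsqueeze(0)).squeeze(0)
layer_rect = nn.MaxPool2d((2, 1), stride=(1, 2))(channel0.unsqueeze(0)).squeeze(0)

print(manual_square)
print(manual_rect)
print(torch.equal(manual_square, layer_square), torch.equal(manual_rect, layer_rect))
```

The value 173 falls inside every 3x3 window of this channel, which is why it fills the whole first output plane.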

Example 2

In the following Python example, we perform 2D Max Pooling on an input image. To apply 2D Max Pooling, we first convert the image to a torch tensor and, after Max Pooling, convert it back to a PIL image for visualization.

# Python 3 program to perform 2D Max Pooling on image
# Import the required libraries
import torch
import torchvision
from PIL import Image
import torchvision.transforms as T
import torch.nn.functional as F

# read the input image
img = Image.open('elephant.jpg')

# convert the image to torch tensor
img = T.ToTensor()(img)
print("Original size of Image:", img.size()) #Size([3, 466, 700])

# unsqueeze to make 4D
img = img.unsqueeze(0)

# define max pool with square window of size=4, stride=1
pool = torch.nn.MaxPool2d(4, 1)
img = pool(img)
img = img.squeeze(0)
print("Size after MaxPool:",img.size())
img = T.ToPILImage()(img)

# display the image after Max Pooling
img.show()


Original size of Image: torch.Size([3, 466, 700])
Size after MaxPool: torch.Size([3, 463, 697])
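The reported sizes can be reproduced without the image file. The sketch below uses a random tensor with the same shape as the image tensor as a stand-in (an assumption made only so the snippet runs on its own) and also shows, for comparison, what happens when the stride is left at its default.

```python
import torch
import torch.nn as nn

# Stand-in for the image tensor: [N, C, H, W] = [1, 3, 466, 700]
img = torch.rand(1, 3, 466, 700)

# Same layer as in Example 2: 4x4 window, stride 1
out_stride1 = nn.MaxPool2d(4, 1)(img)

# For comparison: stride omitted, so it defaults to the kernel size (4)
out_default = nn.MaxPool2d(4)(img)

print("Stride 1:", out_stride1.squeeze(0).size())
print("Default stride:", out_default.squeeze(0).size())
```

With stride 1 the window slides one pixel at a time, so each dimension shrinks by only kernel_size - 1 = 3 pixels; with the default stride the image is downsampled by roughly a factor of 4.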

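Note that Max Pooling has no learnable weights or biases, so the output for a given input image is the same on every run. In Example 1, the input tensor is filled with random values, so you may see different numbers each time you run it.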