PyTorch – How to normalize an image with mean and standard deviation?


The Normalize() transform normalizes a tensor image using a given mean and standard deviation for each channel. It is part of the torchvision.transforms module, which provides many transforms for performing different kinds of manipulations on image data.

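Normalize() accepts only tensor images, not PIL images. A tensor image is a floating-point torch tensor of shape (C, H, W), where C is the number of channels, or a batch of such images of shape (B, C, H, W). The height and width can be any size. Normalize() normalizes each channel separately, computing output[channel] = (input[channel] - mean[channel]) / std[channel], so mean and std must each contain one value per channel.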

Because this transform supports only tensor images, a PIL image must first be converted to a torch tensor. After applying Normalize(), we convert the normalized torch tensor back to a PIL image so that it can be displayed.

Steps

We can use the following steps to normalize an image with mean and standard deviation:

  • Import the required libraries. In all the following examples, the required Python libraries are torch, Pillow, and torchvision. Make sure you have already installed them.

import torch
import torchvision
import torchvision.transforms as T
from PIL import Image
  • Read the input image. The input image can be a PIL image or a torch tensor. If it is a PIL image, convert it to a torch tensor.

img = Image.open('sunset.jpg')
# convert image to torch tensor
imgTensor = T.ToTensor()(img)
  • Define a transform to normalize the image with mean and standard deviation. Here, we use the per-channel (R, G, B) mean and std of the ImageNet dataset.

transform = T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
  • Apply the above-defined transform to the input tensor to normalize the image.

normalized_imgTensor = transform(imgTensor)
  • Convert the normalized tensor image to a PIL image.

normalized_img = T.ToPILImage()(normalized_imgTensor)
  • Show the normalized image.

normalized_img.show()

Input Image

The image file sunset.jpg is used as the input in all the following examples.

Example 1

The following Python program normalizes the input image using the mean and standard deviation of the ImageNet dataset.

# import required libraries
import torch
import torchvision.transforms as T
from PIL import Image

# Read the input image
img = Image.open('sunset.jpg')

# convert image to torch tensor
imgTensor = T.ToTensor()(img)

# define a transform to normalize the tensor
transform = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

# normalize the converted tensor using above defined transform
normalized_imgTensor = transform(imgTensor)

# convert the normalized tensor to PIL image
normalized_img = T.ToPILImage()(normalized_imgTensor)

# display the normalized PIL image
normalized_img.show()

Output

It will display the normalized image.

Example 2

In this example, we define a Compose transform that performs three transformations:

  • Convert the PIL image to a tensor image.

  • Normalize the tensor image.

  • Convert the normalized image tensor back to a PIL image.

# import required libraries
import torch
import torchvision.transforms as T
from PIL import Image

# read the input image
img = Image.open('sunset.jpg')

# define a transform to:
# convert the PIL image to tensor
# normalize the tensor
# convert the tensor to PIL image
transform = T.Compose([
    T.ToTensor(),
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    T.ToPILImage()
])

# apply the above transform to the input image
img = transform(img)
img.show()

Output

It will display the normalized image.

Updated on: 06-Jan-2022
