PyTorch – How to convert an image to grayscale?


To convert an image to grayscale, we apply the Grayscale() transform. It is one of the transforms provided by the torchvision.transforms module, which contains many transformations for manipulating image data.

The Grayscale() transform accepts a PIL image, a tensor image, or a batch of tensor images. A tensor image is a PyTorch tensor with shape [3, H, W], where H is the image height and W is the image width. A batch of tensor images is a torch tensor with shape [B, 3, H, W], where B is the number of images in the batch.

Syntax

torchvision.transforms.Grayscale()(img)

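It returns a grayscale image of the same kind as the input (a PIL image for a PIL input, a tensor for a tensor input). By default the output has a single channel; passing 3 to Grayscale(), as in Example 2, produces three identical channels.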

Steps

We could use the following steps to convert an image to grayscale −

  • Import the required libraries. In all the following examples, the required Python libraries are torch, Pillow, and torchvision. Make sure you have already installed them.

import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
  • Read the input image. The input image is a PIL image or a torch tensor.

img = Image.open('laptop.jpg')
  • Define a transform to convert the original input image to grayscale.

transform = transforms.Grayscale()
  • Apply the transform defined above to the input image to convert it to grayscale.

img = transform(img)
  • Visualize the grayscale image.

img.show()

Input Image

The following image ('laptop.jpg') is used as the input in both examples.

Example 1

The following Python3 program converts the input PIL image to grayscale.

# import required libraries
import torch
import torchvision.transforms as transforms
from PIL import Image

# Read the image
img = Image.open('laptop.jpg')

# define a transform to convert the image to grayscale
transform = transforms.Grayscale()

# apply the above transform on the image
img = transform(img)

# display the image
img.show()

# number of output channels = 1, so the mode is 'L'
print(img.mode)

Output

It will display the grayscale image and print the following output −

L

Note that the mode of the grayscale image is L. The grayscale image has a single channel.

Example 2

The following Python3 program converts the input PIL image to a grayscale image with three output channels.

# Python program to convert an image to grayscale
# import required libraries
import torch
import torchvision.transforms as transforms
from PIL import Image

# Read the image
img = Image.open('laptop.jpg')

# define a transform to convert the image to grayscale
transform = transforms.Grayscale(3)

# apply the above transform on the image
img = transform(img)

# display the image
img.show()

# number of output channels = 3 with R = G = B, so the image still looks gray
print(img.mode)

Output

It will display the grayscale image and print the following output −

RGB

Note that the mode of the output image is RGB. It has three channels (red, green, and blue), but because all three hold the same values, it still appears as a grayscale image.

Updated on: 06-Jan-2022
