How to find mean across the image channels in PyTorch?


RGB images have three channels: Red, Green, and Blue. To compute the mean pixel value of each channel, we use the method torch.mean(). This method takes a PyTorch tensor as input, so we first convert the image to a PyTorch tensor and then apply the method. Called without a dim argument, torch.mean() returns the mean of all the elements in the tensor. The converted image tensor has shape [C, H, W] (channels, height, width), so setting the parameter dim = [1,2] averages over height and width and returns one mean value per channel.
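To see what dim = [1,2] does, here is a minimal sketch on a small synthetic tensor. The values are made up for illustration and do not come from an image. It compares the mean over all elements with the per-channel mean, and shows that keepdim=True keeps the reduced dimensions with size 1.

```python
import torch

# Synthetic "image" tensor: 3 channels, height 2, width 2 (shape [C, H, W]).
# torch.mean() needs a floating-point tensor, so it is created as float32.
img_tensor = torch.arange(12, dtype=torch.float32).reshape(3, 2, 2)

# Without dim: a single mean over all 12 elements.
overall_mean = torch.mean(img_tensor)

# With dim=[1, 2]: average over height and width, one value per channel.
channel_means = torch.mean(img_tensor, dim=[1, 2])

# keepdim=True keeps the reduced dimensions with size 1 (shape [3, 1, 1]).
channel_means_kept = torch.mean(img_tensor, dim=[1, 2], keepdim=True)

print(overall_mean)              # tensor(5.5000)
print(channel_means)             # tensor([1.5000, 5.5000, 9.5000])
print(channel_means_kept.shape)  # torch.Size([3, 1, 1])
```

Channel 0 holds the values 0 to 3, channel 1 holds 4 to 7, and channel 2 holds 8 to 11, so the per-channel means are 1.5, 5.5, and 9.5.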

Steps

  • Import the required libraries. The following Python examples use torch, torchvision, Pillow, and OpenCV. Make sure you have already installed them.

  • Read the input image using Image.open() and assign it to a variable "img".

  • Define a transform to convert the PIL image to a PyTorch tensor.

  • Convert the image "img" to a PyTorch tensor using the transform defined above and assign this tensor to "imgTensor".

  • Compute torch.mean(imgTensor, dim = [1,2]). It returns a tensor of three values: the mean values of the R, G, and B channels. You can unpack these three values into three new variables "R_mean", "G_mean", and "B_mean".

  • Print the three channel means "R_mean", "G_mean", and "B_mean".

Input Image

Both examples use the image file opera.jpg as input.

Example 1

# Python program to find mean across the image channels
# import necessary libraries
import torch
from PIL import Image
import torchvision.transforms as transforms

# Read the input image; convert('RGB') ensures exactly three channels
img = Image.open('opera.jpg').convert('RGB')

# Define transform to convert the image to PyTorch Tensor
transform = transforms.ToTensor()

# Convert image to PyTorch Tensor (Image Tensor)
imgTensor = transform(img)
print("Shape of Image Tensor:\n", imgTensor.shape)

# Compute the mean of each channel (R, G, B) over height and width
R_mean, G_mean, B_mean = torch.mean(imgTensor, dim = [1,2])

# Print the mean of each image channel
print("Mean across Red channel:", R_mean)
print("Mean across Green channel:", G_mean)
print("Mean across Blue channel:", B_mean)

Output

Shape of Image Tensor:
 torch.Size([3, 447, 640])
Mean across Red channel: tensor(0.1487)
Mean across Green channel: tensor(0.1607)
Mean across Blue channel: tensor(0.2521)

Example 2

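We can also read the image using OpenCV. cv2.imread() returns a numpy.ndarray with the channels in BGR order, so we convert it to RGB with cv2.cvtColor() before creating the tensor; ToTensor() accepts this array just as it accepts a PIL image. This example also uses a different way to calculate the mean: the tensor method imgTensor.mean(), which gives the same result as torch.mean(imgTensor). Have a look at the following example.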

# Python program to find mean across the image channels
# import necessary libraries
import torch
import cv2
import torchvision.transforms as transforms

# Read the input image using OpenCV (channels are loaded in BGR order)
img = cv2.imread('opera.jpg')
# Convert from BGR to RGB so the channels are in R, G, B order
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# Define transform to convert the image to PyTorch Tensor
transform = transforms.ToTensor()

# Convert image to PyTorch Tensor (Image Tensor)
imgTensor = transform(img)
print("Shape of Image Tensor:\n", imgTensor.shape)

# Compute the mean of each channel (R, G, B) over height and width,
# this time using the tensor method mean()
R_mean, G_mean, B_mean = imgTensor.mean(dim = [1,2])

# Print the mean of each image channel
print("Mean across Red channel:", R_mean)
print("Mean across Green channel:", G_mean)
print("Mean across Blue channel:", B_mean)

Output

Shape of Image Tensor:
 torch.Size([3, 447, 640])
Mean across Red channel: tensor(0.1487)
Mean across Green channel: tensor(0.1607)
Mean across Blue channel: tensor(0.2521)
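OpenCV stores channels in BGR order, which is why Example 2 calls cv2.cvtColor(img, cv2.COLOR_BGR2RGB) before computing the means; without that call, the first value returned would be the blue mean. The sketch below assumes only NumPy and PyTorch and simulates this with a hand-made 2x2 array in BGR order, standing in for what cv2.imread() returns. The helper to_chw_tensor is hypothetical: it reproduces what ToTensor() does for an 8-bit HxWx3 array (reorder to [C, H, W] and scale to [0, 1]). The snippet also checks that the tensor method and the function form give the same result.

```python
import numpy as np
import torch

def to_chw_tensor(arr):
    # Hypothetical stand-in for transforms.ToTensor() on an H x W x 3 uint8 array:
    # reorder to [C, H, W] and scale to [0, 1].
    return torch.from_numpy(arr).permute(2, 0, 1).float() / 255

# A 2x2 pure-blue image laid out in BGR order, as cv2.imread() would return it.
bgr = np.zeros((2, 2, 3), dtype=np.uint8)
bgr[..., 0] = 255  # channel 0 is Blue in BGR order

# Reversing the channel axis turns BGR into RGB (what COLOR_BGR2RGB does for a
# 3-channel image). ascontiguousarray copies the reversed view, because
# torch.from_numpy() does not accept arrays with negative strides.
rgb = np.ascontiguousarray(bgr[..., ::-1])

wrong = to_chw_tensor(bgr).mean(dim=[1, 2])  # BGR read as RGB: first value is 1.0
right = to_chw_tensor(rgb).mean(dim=[1, 2])  # correct order: blue mean is 1.0

print(wrong)  # tensor([1., 0., 0.])
print(right)  # tensor([0., 0., 1.])

# The tensor method and the function form give the same result.
t = to_chw_tensor(rgb)
print(torch.equal(t.mean(dim=[1, 2]), torch.mean(t, dim=[1, 2])))  # True
```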

Updated on: 06-Nov-2021
