Functional Transforms for Computer Vision using PyTorch


Computer vision tasks often require preprocessing and augmentation of image data to improve model performance and generalization. PyTorch, a popular deep learning framework, is accompanied by the torchvision package, whose torchvision.transforms module offers a wide range of predefined transforms for data augmentation and preprocessing. However, in some cases, predefined transforms may not be sufficient, and we need to apply custom transformations to our image data. In this blog post, we will explore the concept of functional transforms in PyTorch and demonstrate how to create and apply custom transforms for computer vision tasks.

Understanding Transforms in PyTorch

Transforms in PyTorch are operations that can be applied to input data, such as images, to modify their appearance or properties. These transformations can be categorized into two types: class transforms and functional transforms. Class transforms are implemented as classes that define both the transformation operation and the parameters associated with it. On the other hand, functional transforms are implemented as functions that perform the transformation operation on the input data.

Functional transforms offer more flexibility than class transforms because every parameter is passed explicitly on each call and the result is an ordinary function call on the input. This lets us define custom operations with plain PyTorch tensor functions, and it makes functional transforms particularly useful when we need to apply complex or parameterized transformations to our image data.

Creating Custom Functional Transforms

To create a custom functional transform, we need to define a function that accepts an input tensor and performs the desired transformation operation. Let's say we want to create a custom transform called grayscale, which converts an RGB image to grayscale. Here's an example implementation:

import torch

def grayscale(img):
   """Converts an RGB image to grayscale.
    
   Args:
       img (Tensor): Input RGB image tensor of shape (C, H, W).
        
   Returns:
       Tensor: Grayscale image tensor of shape (1, H, W).
   """
   if img.size(0) != 3:
       raise ValueError("Input image must have 3 channels (RGB).")
        
   # Apply grayscale transformation
   grayscale_img = torch.mean(img, dim=0, keepdim=True)
    
   return grayscale_img

In this example, we define the grayscale function that takes an input RGB image tensor img of shape (C, H, W), where C represents the number of channels (3 for RGB images), and H and W represent the height and width of the image, respectively. The function first checks if the input image has the correct number of channels (3 in this case), and then applies the grayscale transformation by calculating the mean value across the channel dimension. The resulting grayscale image tensor is returned with shape (1, H, W), where the grayscale image has a single channel.
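The channel mean treats red, green, and blue equally. A common alternative weights the channels by perceived brightness. The sketch below is one possible variant, not part of the original example: the function name luminance_grayscale is hypothetical, the weights are the ITU-R BT.601 luma coefficients, and the input is assumed to be a floating-point tensor.

```python
import torch

def luminance_grayscale(img):
    """Weighted RGB-to-grayscale conversion for a float tensor of shape (3, H, W)."""
    if img.dim() != 3 or img.size(0) != 3:
        raise ValueError("Input image must have shape (3, H, W).")
    # BT.601 luma weights for R, G, B, reshaped to broadcast over H and W.
    # Assumption: img is a floating-point tensor, so the weights keep their values.
    weights = torch.tensor([0.299, 0.587, 0.114], dtype=img.dtype, device=img.device)
    weights = weights.view(3, 1, 1)
    return (img * weights).sum(dim=0, keepdim=True)

# A pure-red image maps to the red weight at every pixel.
red = torch.zeros(3, 2, 2)
red[0] = 1.0
gray = luminance_grayscale(red)
```

Because the three weights sum to 1, a float image with values in [0, 1] stays in that range after conversion.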

Applying Functional Transforms

Once we have defined our custom functional transform, we apply it by calling it directly on a tensor. The torchvision.transforms.functional module supplies the helpers used around that call, such as to_tensor and to_pil_image for converting between PIL images and tensors. Here's an example of applying the grayscale transform to an image:

from torchvision.transforms import functional as F
from PIL import Image

# Load the image using PIL and make sure it has three RGB channels
image = Image.open("image.jpg").convert("RGB")

# Convert PIL image to PyTorch tensor
tensor_image = F.to_tensor(image)

# Apply the custom grayscale transform
grayscale_image = grayscale(tensor_image)

# Convert the grayscale tensor back to PIL image
grayscale_pil_image = F.to_pil_image(grayscale_image)

# Save the grayscale image
grayscale_pil_image.save("grayscale_image.jpg")

In this example, we first load an image using the PIL library and convert it to a PyTorch tensor using the F.to_tensor function. We then apply our grayscale transform to the tensor image, which returns a grayscale image tensor. Finally, we convert the grayscale tensor back to a PIL image using the F.to_pil_image function and save it as a JPEG file.

Integrating Custom Functional Transforms into the Data Pipeline

To use custom functional transforms effectively in computer vision tasks, we need to integrate them into the data pipeline. PyTorch provides the torchvision.transforms.Compose class, which allows us to chain multiple transforms together and apply them sequentially to our image data. Because Compose calls each entry in turn, a plain Python function can sit alongside the predefined transforms. Here's an example:

from torchvision.transforms import Compose, RandomCrop, ToTensor

# Create a custom transform pipeline
custom_transforms = Compose([
    ToTensor(),          # Predefined transform: PIL image to tensor of shape (C, H, W)
    RandomCrop(224),     # Predefined transform: random 224x224 crop
    grayscale            # Custom transform: expects a 3-channel tensor
])

# Apply the transform pipeline to the image data
transformed_image = custom_transforms(image)

In this example, we create a transform pipeline using Compose and include both predefined and custom transforms. ToTensor comes first because our grayscale function operates on tensors rather than PIL images. The RandomCrop transform then randomly crops the image to a size of 224x224 (the input must be at least that large in both dimensions), and the grayscale transform converts the cropped tensor to a single channel. By integrating our custom functional transform into the data pipeline, we can apply it along with other transforms and keep the preprocessing workflow consistent.

Random Transforms

Random transforms are useful for introducing variation into the data augmentation process. In torchvision, the random transforms are classes in torchvision.transforms, such as RandomCrop, RandomRotation, and RandomHorizontalFlip. The functions in torchvision.transforms.functional are deterministic and take every parameter explicitly, so a custom functional transform draws its own random parameters and passes them to the functional API. For example, we can create a custom transform called random_resize_crop that randomly resizes and crops an image:

import random
from torchvision.transforms import RandomCrop
from torchvision.transforms import functional as F

def random_resize_crop(img):
    """Randomly resizes and crops the image.

    Args:
        img (Tensor): Input image tensor.

    Returns:
        Tensor: Randomly resized and cropped image tensor.
    """
    # Randomly resize the image so that its shorter side is 256 to 512 pixels
    size = random.randint(256, 512)
    img = F.resize(img, size)

    # Sample a random 224x224 crop window, then crop the image
    i, j, h, w = RandomCrop.get_params(img, output_size=(224, 224))
    img = F.crop(img, i, j, h, w)

    return img

By combining randomly drawn parameters with functional transforms, we can introduce variations in image scale and crop position. The same pattern extends to rotation and flipping, which helps make our models more robust to diverse inputs.

Custom Parameterized Transforms

In certain scenarios, we may want to create parameterized transforms that can be adjusted based on specific requirements. For example, we might need a custom transform called contrast_adjustment that allows us to control the contrast of an image:

def contrast_adjustment(img, factor):
   """Adjusts the contrast of the image.
    
   Args:
       img (Tensor): Input image tensor.
       factor (float): Contrast adjustment factor.
        
   Returns:
       Tensor: Image tensor with adjusted contrast.
   """
   return F.adjust_contrast(img, factor)

Here, the contrast_adjustment transform accepts an additional parameter, factor, which determines the extent of contrast adjustment applied to the image. This parameterized approach enables fine-grained control over the transformation process.

Conclusion

In this article, we learned how to create custom functional transforms and apply them to image data. Functional transforms offer more flexibility than class transforms, allowing us to define custom operations using PyTorch tensors and functions. With them, we can design transformations tailored to a specific computer vision task, combine them with predefined transforms in a single pipeline, and control their random or tunable parameters directly.

Updated on: 14-Aug-2023
