How to squeeze and unsqueeze a tensor in PyTorch?


To squeeze a tensor, we use the torch.squeeze() method. It returns a new tensor that keeps all the dimensions of the input tensor except those of size 1, which are removed. For example, if the input tensor has the shape (M × 1 × N × 1 × P), then the squeezed tensor will have the shape (M × N × P).
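To make the shape rule concrete, here is a minimal sketch that picks M = 2, N = 3 and P = 4. These sizes are arbitrary and chosen only for illustration.

```python
import torch

# Arbitrary sizes standing in for M, N and P
M, N, P = 2, 3, 4

# A tensor of shape (M x 1 x N x 1 x P)
x = torch.zeros(M, 1, N, 1, P)

# With no dim argument, torch.squeeze() drops every dimension of size 1
y = torch.squeeze(x)

print(x.shape)  # torch.Size([2, 1, 3, 1, 4])
print(y.shape)  # torch.Size([2, 3, 4])
```

The number of elements does not change; only the size-1 dimensions disappear.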

To unsqueeze a tensor, we use the torch.unsqueeze() method. It returns a new tensor with a dimension of size 1 inserted at the specified position.
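The position is given by the dim argument. The sketch below, using small tensors chosen only for illustration, shows where the new dimension lands for a few values of dim. A negative dim counts from the end, so for a 1-D input dim=-1 means the same as dim=1.

```python
import torch

# A 1-D tensor of shape (3,)
t = torch.tensor([1., 2., 3.])

front = torch.unsqueeze(t, 0)      # size-1 dimension before the existing one
back = torch.unsqueeze(t, 1)       # size-1 dimension after the existing one
back_neg = torch.unsqueeze(t, -1)  # same position as dim=1 for a 1-D input

# For a 2-D input of shape (2, 4), dim=1 inserts the new dimension in the middle
m = torch.zeros(2, 4)
middle = torch.unsqueeze(m, 1)

print(front.shape)     # torch.Size([1, 3])
print(back.shape)      # torch.Size([3, 1])
print(back_neg.shape)  # torch.Size([3, 1])
print(middle.shape)    # torch.Size([2, 1, 4])
```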

Steps

  • Import the required library. In all the following Python examples, the required Python library is torch. Make sure you have already installed it.

  • Create a tensor and print it.

  • Compute torch.squeeze(input). It removes the dimensions of size 1 and returns a tensor with all the other dimensions of the input tensor.

  • Compute torch.unsqueeze(input, dim). It inserts a new dimension of size 1 at the given dim and returns the tensor.

  • Print the squeezed and/or unsqueezed tensor.
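The steps above can be combined into one short sketch. It also illustrates two properties: squeezing the same position that was just unsqueezed gives back the original shape, and, as the PyTorch documentation notes, both functions return tensors that share storage with their input, so writing through the result also changes the original.

```python
import torch

# Create a tensor and print its shape
T = torch.arange(6.).reshape(2, 3)
print("T:", T.shape)                  # torch.Size([2, 3])

# Insert a size-1 dimension at position 1
U = torch.unsqueeze(T, 1)
print("unsqueezed:", U.shape)         # torch.Size([2, 1, 3])

# Squeezing that same position removes it again
S = torch.squeeze(U, 1)
print("squeezed back:", S.shape)      # torch.Size([2, 3])

# S is a view of U, which is a view of T: writing through S also changes T
S[0, 0] = 100.
print(T[0, 0])                        # tensor(100.)
```

If an independent copy is needed, call .clone() on the result.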

Example 1

# Python program to squeeze a tensor
# import the necessary library
import torch

# Create a tensor filled with ones
T = torch.ones(2, 1, 2)  # size 2x1x2
print("Original Tensor T:\n", T)
print("Size of T:", T.size())

# Squeeze the tensor: the dimension of size 1 is removed
squeezed_T = torch.squeeze(T)  # now size 2x2
print("Squeezed_T:\n", squeezed_T)
print("Size of Squeezed_T:", squeezed_T.size())

Output

Original Tensor T:
 tensor([[[1., 1.]],

        [[1., 1.]]])
Size of T: torch.Size([2, 1, 2])
Squeezed_T:
 tensor([[1., 1.],
        [1., 1.]])
Size of Squeezed_T: torch.Size([2, 2])
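torch.squeeze() also accepts an optional dim argument. The sketch below, with a shape chosen only for illustration (a batch holding a single sample), shows why passing it can be safer: without dim, every size-1 dimension is dropped, including a leading batch dimension that happens to be 1. Squeezing a dimension whose size is not 1 leaves the tensor unchanged rather than raising an error.

```python
import torch

# Illustrative shape (1, 1, 2): batch x channel x features, batch size 1
x = torch.ones(1, 1, 2)

# No dim: both size-1 dimensions are removed, including the batch dimension
all_removed = torch.squeeze(x)
print(all_removed.shape)  # torch.Size([2])

# dim=1: only the second dimension is removed, the batch dimension survives
one_removed = torch.squeeze(x, 1)
print(one_removed.shape)  # torch.Size([1, 2])

# dim=2 has size 2, so nothing is removed
unchanged = torch.squeeze(x, 2)
print(unchanged.shape)    # torch.Size([1, 1, 2])
```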

Example 2

# Python program to unsqueeze a tensor
# import the necessary library
import torch

# Create a 1-D tensor
T = torch.tensor([1., 2., 3.])  # size 3
print("Original Tensor T:\n", T)
print("Size of T:", T.size())

# Unsqueeze the tensor at dimension 0 (a new leading dimension)
unsqueezed_T = torch.unsqueeze(T, dim=0)  # now size 1x3, a row
print("Unsqueezed T (dim=0):\n", unsqueezed_T)
print("Size of Unsqueezed T:", unsqueezed_T.size())

# Unsqueeze the tensor at dimension 1 (a new trailing dimension)
unsqueezed_T = torch.unsqueeze(T, dim=1)  # now size 3x1, a column
print("Unsqueezed T (dim=1):\n", unsqueezed_T)
print("Size of Unsqueezed T:", unsqueezed_T.size())

Output

Original Tensor T:
 tensor([1., 2., 3.])
Size of T: torch.Size([3])
Unsqueezed T (dim=0):
 tensor([[1., 2., 3.]])
Size of Unsqueezed T: torch.Size([1, 3])
Unsqueezed T (dim=1):
 tensor([[1.],
        [2.],
        [3.]])
Size of Unsqueezed T: torch.Size([3, 1])
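A common reason to unsqueeze is to line up dimensions for broadcasting. In this minimal sketch (the pairwise-difference computation is an illustrative choice, not part of the example above), unsqueezing T at dim 1 and at dim 0 gives shapes 3x1 and 1x3, which broadcast to a 3x3 table of differences. Indexing with None inserts a size-1 dimension in the same way.

```python
import torch

T = torch.tensor([1., 2., 3.])

col = torch.unsqueeze(T, dim=1)  # shape (3, 1)
row = torch.unsqueeze(T, dim=0)  # shape (1, 3)

# Broadcasting (3, 1) against (1, 3) yields a (3, 3) result
# where diffs[i][j] == T[i] - T[j]
diffs = col - row
print(diffs.shape)  # torch.Size([3, 3])

# Indexing with None inserts a size-1 dimension just like unsqueeze
print(torch.equal(T[:, None], col))  # True
print(torch.equal(T[None, :], row))  # True
```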

Updated on: 06-Nov-2021
