How to squeeze and unsqueeze a tensor in PyTorch?
To squeeze a tensor, we use the torch.squeeze() method. It returns a new tensor that keeps all the dimensions of the input tensor except those of size 1, which are removed. For example, if the input tensor has shape (M × 1 × N × 1 × P), the squeezed tensor has shape (M × N × P).
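As a quick check of the shape rule above, the sketch below substitutes concrete sizes (M = 2, N = 3, P = 4, chosen only for illustration) and compares the shapes before and after squeezing.

```python
import torch

# Concrete sizes chosen only for illustration: M = 2, N = 3, P = 4
x = torch.zeros(2, 1, 3, 1, 4)  # shape (M x 1 x N x 1 x P)

# With no dim argument, every dimension of size 1 is removed
y = torch.squeeze(x)

print(x.shape)  # torch.Size([2, 1, 3, 1, 4])
print(y.shape)  # torch.Size([2, 3, 4])
```

Only the two size-1 dimensions disappear; M, N and P keep their order.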
To unsqueeze a tensor, we use the torch.unsqueeze() method. It returns a new tensor with a dimension of size 1 inserted at the specified position.
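The position is given by the dim argument. The sketch below inserts the new dimension at the front, in the middle and (with a negative dim, which counts from the end) at the back of a small 2-D tensor, then checks that squeezing the same dimension undoes the unsqueeze. The tensor sizes are arbitrary.

```python
import torch

x = torch.zeros(2, 3)  # a 2-D tensor

# Valid dim values run from -(x.dim() + 1) to x.dim(), i.e. -3 .. 2 here
a = torch.unsqueeze(x, 0)   # new leading dimension
b = torch.unsqueeze(x, 1)   # new dimension in the middle
c = torch.unsqueeze(x, -1)  # negative dim counts from the end: new trailing dimension

print(a.shape)  # torch.Size([1, 2, 3])
print(b.shape)  # torch.Size([2, 1, 3])
print(c.shape)  # torch.Size([2, 3, 1])

# Squeezing the same dim removes the inserted dimension again
print(torch.equal(torch.squeeze(b, 1), x))  # True
```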
Steps
Import the required library. In all the following Python examples, the required Python library is torch. Make sure you have already installed it.
Create a tensor and print it.
Compute torch.squeeze(input). It removes every dimension of size 1 and returns a tensor with the remaining dimensions of the input tensor.
Compute torch.unsqueeze(input, dim). It inserts a new dimension of size 1 at position dim and returns the resulting tensor.
Print the squeezed and/or unsqueezed tensor.
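The five steps above can be collected into one short script. This sketch uses the tensor-method forms T.squeeze() and S.unsqueeze(dim), which PyTorch provides as equivalents of torch.squeeze() and torch.unsqueeze(); the starting tensor (built with torch.arange and reshape) is chosen only for illustration.

```python
# Step 1: import the required library
import torch

# Step 2: create a tensor and print it
T = torch.arange(6.0).reshape(1, 2, 3)  # shape 1x2x3
print("T:", T.shape)

# Step 3: squeeze, removing every dimension of size 1
S = T.squeeze()     # same as torch.squeeze(T); shape 2x3

# Step 4: unsqueeze, inserting a size-1 dimension at dim=2
U = S.unsqueeze(2)  # same as torch.unsqueeze(S, 2); shape 2x3x1

# Step 5: print the results
print("squeezed:", S.shape)
print("unsqueezed:", U.shape)
```

Squeezing and unsqueezing only change the shape; the six values stored in T are the same in S and U.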
Example 1
```python
# Python program to squeeze a tensor
# import the necessary library
import torch

# Create a tensor of all ones
T = torch.ones(2, 1, 2)  # size 2x1x2
print("Original Tensor T:\n", T)
print("Size of T:", T.size())

# Squeeze the tensor: remove every dimension of size 1
squeezed_T = torch.squeeze(T)  # now size 2x2
print("Squeezed_T:\n", squeezed_T)
print("Size of Squeezed_T:", squeezed_T.size())
```
Output
```
Original Tensor T:
 tensor([[[1., 1.]],

        [[1., 1.]]])
Size of T: torch.Size([2, 1, 2])
Squeezed_T:
 tensor([[1., 1.],
        [1., 1.]])
Size of Squeezed_T: torch.Size([2, 2])
```
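Calling torch.squeeze() without a dim removes every size-1 dimension, which can also drop a dimension you meant to keep, such as a batch dimension that happens to hold a single item (a hypothetical scenario used here for illustration). Passing dim restricts squeezing to that one dimension; if that dimension does not have size 1, the tensor is returned with its shape unchanged.

```python
import torch

# Hypothetical batch holding one 1x2 item: shape (batch=1, rows=1, cols=2)
batch = torch.ones(1, 1, 2)

all_squeezed = torch.squeeze(batch)      # removes both size-1 dims -> shape (2,)
only_dim1 = torch.squeeze(batch, dim=1)  # removes only dim 1 -> shape (1, 2)
unchanged = torch.squeeze(batch, dim=2)  # dim 2 has size 2, so nothing is removed

print(all_squeezed.shape)  # torch.Size([2])
print(only_dim1.shape)     # torch.Size([1, 2])
print(unchanged.shape)     # torch.Size([1, 1, 2])
```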
Example 2
```python
# Python program to unsqueeze a tensor
# import the necessary library
import torch

# Create a 1-D tensor
T = torch.tensor([1.0, 2.0, 3.0])  # size 3
print("Original Tensor T:\n", T)
print("Size of T:", T.size())

# Unsqueeze at dimension 0: the result is a 1x3 row
unsqueezed_T = torch.unsqueeze(T, dim=0)  # now size 1x3
print("Unsqueezed T:\n", unsqueezed_T)
print("Size of Unsqueezed T:", unsqueezed_T.size())

# Unsqueeze at dimension 1: the result is a 3x1 column
unsqueezed_T = torch.unsqueeze(T, dim=1)  # now size 3x1
print("Unsqueezed T:\n", unsqueezed_T)
print("Size of Unsqueezed T:", unsqueezed_T.size())
```
Output
```
Original Tensor T:
 tensor([1., 2., 3.])
Size of T: torch.Size([3])
Unsqueezed T:
 tensor([[1., 2., 3.]])
Size of Unsqueezed T: torch.Size([1, 3])
Unsqueezed T:
 tensor([[1.],
        [2.],
        [3.]])
Size of Unsqueezed T: torch.Size([3, 1])
```
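A common reason to unsqueeze is broadcasting. Turning one 1-D tensor into a 3x1 column and another into a 1x2 row lets elementwise multiplication expand both to a 3x2 table. The sketch below builds an outer product this way and checks it against torch.outer; the values are arbitrary.

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])  # shape (3,)
b = torch.tensor([10.0, 20.0])     # shape (2,)

# a becomes a 3x1 column and b a 1x2 row; broadcasting expands both to 3x2
table = torch.unsqueeze(a, 1) * torch.unsqueeze(b, 0)

print(table.shape)                            # torch.Size([3, 2])
print(torch.equal(table, torch.outer(a, b)))  # True
print(table)
```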