How to squeeze and unsqueeze a tensor in PyTorch?
In PyTorch, you can modify tensor dimensions using torch.squeeze() and torch.unsqueeze() methods. The squeeze operation removes dimensions of size 1, while unsqueeze adds new dimensions of size 1 at specified positions.
Understanding Squeeze Operation
When called without a dim argument, torch.squeeze() removes every dimension of size 1 from a tensor. For example, if a tensor has shape (2 × 1 × 3 × 1), squeezing it produces shape (2 × 3).
Example
import torch
# Create a tensor with dimensions of size 1
tensor = torch.ones(2, 1, 2, 1)
print("Original tensor shape:", tensor.shape)
print("Original tensor:\n", tensor)
# Squeeze the tensor (removes all size-1 dimensions)
squeezed = torch.squeeze(tensor)
print("\nSqueezed tensor shape:", squeezed.shape)
print("Squeezed tensor:\n", squeezed)
Original tensor shape: torch.Size([2, 1, 2, 1])
Original tensor:
tensor([[[[1.],
          [1.]]],
        [[[1.],
          [1.]]]])
Squeezed tensor shape: torch.Size([2, 2])
Squeezed tensor:
tensor([[1., 1.],
        [1., 1.]])
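As a small additional sketch, the snippet below uses the (2 × 1 × 3 × 1) shape mentioned in the description above and also illustrates that the squeezed result shares storage with the original tensor. The variable names x and y are chosen for illustration only.

```python
import torch

# Shape taken from the description above: (2, 1, 3, 1)
x = torch.zeros(2, 1, 3, 1)
y = torch.squeeze(x)
print(y.shape)  # torch.Size([2, 3])

# The squeezed tensor is a view of x, so writing through it
# also changes the corresponding element of the original
y[0, 0] = 7.0
print(x[0, 0, 0, 0])  # tensor(7.)
```

If an independent copy is needed, calling .clone() on the squeezed tensor is one way to get it.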
Understanding Unsqueeze Operation
The torch.unsqueeze() method adds a new dimension of size 1 at the specified position. The dim parameter determines where the new dimension is inserted.
Example
import torch
# Create a 1D tensor
tensor = torch.tensor([1, 2, 3, 4])
print("Original tensor shape:", tensor.shape)
print("Original tensor:", tensor)
# Unsqueeze at dimension 0 (adds dimension at the beginning)
unsqueezed_0 = torch.unsqueeze(tensor, dim=0)
print("\nUnsqueezed at dim=0 shape:", unsqueezed_0.shape)
print("Unsqueezed at dim=0:\n", unsqueezed_0)
# Unsqueeze at dimension 1 (adds dimension at the end)
unsqueezed_1 = torch.unsqueeze(tensor, dim=1)
print("\nUnsqueezed at dim=1 shape:", unsqueezed_1.shape)
print("Unsqueezed at dim=1:\n", unsqueezed_1)
Original tensor shape: torch.Size([4])
Original tensor: tensor([1, 2, 3, 4])
Unsqueezed at dim=0 shape: torch.Size([1, 4])
Unsqueezed at dim=0:
tensor([[1, 2, 3, 4]])
Unsqueezed at dim=1 shape: torch.Size([4, 1])
Unsqueezed at dim=1:
tensor([[1],
        [2],
        [3],
        [4]])
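The dim argument can also be negative, and indexing with None gives the same result as unsqueeze. The following sketch reuses the same 1D values as the example above; the names v, col, and row are illustrative.

```python
import torch

v = torch.tensor([1, 2, 3, 4])

# dim=-1 counts from the end, so the new axis becomes the last one
col = torch.unsqueeze(v, dim=-1)
print(col.shape)  # torch.Size([4, 1])

# Indexing with None inserts a size-1 axis at that position
row = v[None, :]
print(row.shape)  # torch.Size([1, 4])

# Both spellings produce identical tensors
same_col = torch.equal(col, v[:, None])
same_row = torch.equal(row, torch.unsqueeze(v, dim=0))
print(same_col, same_row)  # True True
```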
Practical Use Cases
These operations are commonly used in deep learning to reshape tensors so that they match the input dimensions expected by neural network layers or mathematical operations:
import torch
# Example: Preparing data for batch processing
data = torch.tensor([1.0, 2.0, 3.0])
print("Original data:", data.shape)
# Add batch dimension (common in neural networks)
batched = torch.unsqueeze(data, dim=0)
print("With batch dimension:", batched.shape)
# Add feature dimension
features = torch.unsqueeze(batched, dim=2)
print("With feature dimension:", features.shape)
# Remove all size-1 dimensions (this also drops the batch dimension)
cleaned = torch.squeeze(features)
print("After squeezing:", cleaned.shape)
Original data: torch.Size([3])
With batch dimension: torch.Size([1, 3])
With feature dimension: torch.Size([1, 3, 1])
After squeezing: torch.Size([3])
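Another common case for the "mathematical operations" mentioned above is broadcasting. The sketch below is an illustrative example, not part of the original walkthrough: it uses unsqueeze to line up two 1D tensors (hypothetical names a and b) so that their pairwise differences can be computed in one expression.

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])  # shape (3,)
b = torch.tensor([10.0, 20.0])     # shape (2,)

# a becomes (3, 1) and b becomes (1, 2); subtraction broadcasts to (3, 2)
diff = a.unsqueeze(1) - b.unsqueeze(0)
print(diff.shape)  # torch.Size([3, 2])

# diff[i, j] holds a[i] - b[j]
print(diff[2, 1])  # tensor(-17.)
```

Without the two unsqueeze calls, subtracting a (3,) tensor from a (2,) tensor would fail because the shapes cannot be broadcast together.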
Key Parameters
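- torch.squeeze(input, dim=None): If dim is specified, only that dimension is removed, and only if it has size 1; otherwise the shape is left unchanged. If dim is omitted, all size-1 dimensions are removed
- torch.unsqueeze(input, dim): The dim parameter is required and specifies where to insert the new dimension. It accepts values from -(input.dim() + 1) to input.dim(), so negative values count from the end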
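The effect of the dim argument on squeeze is easiest to see side by side. This is a minimal sketch; the tensor name batch and its shape are assumptions chosen to mimic a batch containing a single sample.

```python
import torch

# Illustrative shape: a batch of one sample with a trailing size-1 axis
batch = torch.zeros(1, 3, 1)

# No dim: every size-1 axis is removed, including the batch axis
all_squeezed = torch.squeeze(batch)
print(all_squeezed.shape)  # torch.Size([3])

# dim=2: only the trailing axis is removed, the batch axis survives
last_squeezed = torch.squeeze(batch, dim=2)
print(last_squeezed.shape)  # torch.Size([1, 3])

# dim=1 has size 3, so the shape is left unchanged
unchanged = torch.squeeze(batch, dim=1)
print(unchanged.shape)  # torch.Size([1, 3, 1])
```

Passing dim explicitly is the safer choice in code that may receive a batch of size 1, since a bare squeeze() would silently drop the batch dimension.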
Conclusion
Use torch.squeeze() to remove size-1 dimensions and torch.unsqueeze() to insert a new size-1 dimension at a chosen position. These operations are essential for tensor manipulation in PyTorch, especially when preparing data for neural networks or mathematical operations.
