How to perform an expand operation in PyTorch?
Tensor.expand() is used to perform an expand operation. It returns a new tensor in which the singleton dimensions (dimensions of size 1) are expanded to a larger size.
Expanding a tensor only creates a new view of the original tensor; it doesn't make a copy of the original tensor.
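One way to confirm that no data is copied is to compare the memory address and the strides of the two tensors. The short sketch below uses the standard `Tensor.data_ptr()` and `Tensor.stride()` methods on the same (3, 1) tensor used in Example 1; the value 10 written into the original is only an illustration.

```python
import torch

# The same (3, 1) tensor used in Example 1
t = torch.tensor([[1], [2], [3]])
exp = t.expand(3, 2)

# Both tensors start at the same memory address: no data was copied
print(exp.data_ptr() == t.data_ptr())  # True

# The expanded dimension gets stride 0, so both columns read the same element
print(exp.stride())  # (1, 0)

# Because exp is a view, a change to the original shows up through it
t[0, 0] = 10
print(exp)  # first row is now 10, 10
```

A stride of 0 means that stepping along the expanded dimension does not move through memory at all, which is why the view needs no extra storage.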
If you pass -1 as the size of a particular dimension, the tensor is not expanded along that dimension; its size stays the same.
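As a small illustration of the -1 rule, the sketch below keeps the first dimension of a (3, 1) tensor unchanged and expands only the singleton dimension; `a` and `b` are just illustrative names.

```python
import torch

t = torch.tensor([[1], [2], [3]])  # size (3, 1)

# -1 keeps dimension 0 at its current size (3); dimension 1 grows from 1 to 4
a = t.expand(-1, 4)
print(a.size())  # torch.Size([3, 4])

# Same result as spelling the size out explicitly
b = t.expand(3, 4)
print(torch.equal(a, b))  # True
```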
For example, if we have a tensor of size (3,1), we can expand this tensor along the dimension of size 1.
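The size rules described above can be summarized in a few lines of plain Python. The function `expanded_shape` below is a hypothetical helper written for illustration only; it is not part of PyTorch and is not how PyTorch implements `expand()`. It assumes sizes are given as tuples of non-negative integers (plus -1), and it also models one documented behaviour not covered above: extra sizes create new leading dimensions, for which -1 is not allowed.

```python
def expanded_shape(old, new):
    """Hypothetical helper that mimics the size rules of Tensor.expand().

    old: current size of the tensor, e.g. (3, 1)
    new: requested size, which may contain -1
    Returns the resulting size as a tuple, or raises ValueError.
    """
    if len(new) < len(old):
        raise ValueError("cannot expand to fewer dimensions")
    extra = len(new) - len(old)  # number of new leading dimensions
    result = []
    for i, size in enumerate(new):
        if i < extra:
            # New leading dimension: its size must be given explicitly
            if size == -1:
                raise ValueError(f"-1 is not allowed for new leading dimension {i}")
            result.append(size)
            continue
        cur = old[i - extra]
        if size == -1:
            result.append(cur)  # -1 keeps the existing size
        elif cur == 1 or size == cur:
            result.append(size)  # a singleton dimension can take any size
        else:
            raise ValueError(
                f"size {size} must match existing size {cur} "
                f"at non-singleton dimension {i}"
            )
    return tuple(result)


print(expanded_shape((3, 1), (3, 2)))     # (3, 2)
print(expanded_shape((1, 3), (3, -1)))    # (3, 3)
print(expanded_shape((3, 1), (2, 3, 4)))  # (2, 3, 4)
```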
Steps
To expand a tensor, one could follow the steps given below −
Import the torch library. Make sure you have already installed it.
import torch
Define a tensor that has at least one singleton dimension (a dimension of size 1).
t = torch.tensor([[1],[2],[3]])
Expand the tensor along the singleton dimension. Expanding along a non-singleton dimension raises a RuntimeError (see Example 3).
t_exp = t.expand(3,2)
Display the expanded tensor.
print("Tensor after expand:\n", t_exp)
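Because the repeated entries of an expanded tensor all refer to the same memory location, the view is not a good target for in-place writes. When you need a writable tensor with independent data, a common pattern is to materialize the view with `clone()`, as in this sketch built on the tensor from the steps above; the value 100 is only an illustration.

```python
import torch

t = torch.tensor([[1], [2], [3]])
t_exp = t.expand(3, 2)

# The expanded view is not contiguous (its second stride is 0)
print(t_exp.is_contiguous())  # False

# clone() copies the expanded values into new memory
t_copy = t_exp.clone()

# Writing into the copy leaves the original tensor untouched
t_copy[0, 1] = 100
print(t)       # still [[1], [2], [3]]
print(t_copy)  # first row is now 1, 100
```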
Example 1
The following Python program shows how to expand a tensor of size (3,1) to a tensor of size (3,2). It expands the tensor along the dimension of size 1; the other dimension, of size 3, remains unchanged.
# import required libraries
import torch
# create a tensor
t = torch.tensor([[1],[2],[3]])
# display the tensor
print("Tensor:\n", t)
print("Size of Tensor:\n", t.size())
# expand the tensor
exp = t.expand(3,2)
print("Tensor after expansion:\n", exp)
Output
Tensor:
tensor([[1],
        [2],
        [3]])
Size of Tensor:
torch.Size([3, 1])
Tensor after expansion:
tensor([[1, 1],
        [2, 2],
        [3, 3]])
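`expand()` can also add dimensions: if you pass more sizes than the tensor has dimensions, the new ones are added at the front, and each must be given explicitly rather than as -1. Here is a minimal sketch using the tensor from Example 1; the target size (2, 3, 4) is an arbitrary choice.

```python
import torch

t = torch.tensor([[1], [2], [3]])  # size (3, 1)

# Three sizes for a 2-D tensor: a new dimension of size 2 is added at the front,
# and the existing singleton dimension is expanded from 1 to 4
exp3d = t.expand(2, 3, 4)
print(exp3d.size())  # torch.Size([2, 3, 4])

# Each slice along the new dimension is the same (3, 4) view
print(torch.equal(exp3d[1], t.expand(3, 4)))  # True
```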
Example 2
The following Python program expands a tensor of size (1,3) to a tensor of size (3,3). It expands the tensor along the dimension of size 1, while -1 keeps the second dimension at its current size of 3.
# import required libraries
import torch
# create a tensor
t = torch.tensor([[1,2,3]])
# display the tensor
print("Tensor:\n", t)
# size of tensor is [1,3]
print("Size of Tensor:\n", t.size())
# expand the tensor
expandedTensor = t.expand(3,-1)
print("Expanded Tensor:\n", expandedTensor)
print("Size of expanded tensor:\n", expandedTensor.size())
Output
Tensor:
tensor([[1, 2, 3]])
Size of Tensor:
torch.Size([1, 3])
Expanded Tensor:
tensor([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3]])
Size of expanded tensor:
torch.Size([3, 3])
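When the target size comes from another tensor, `Tensor.expand_as(other)` does the same thing as `expand(other.size())`. In this sketch, `target` is a hypothetical tensor whose values are never used; only its size (4, 3) matters.

```python
import torch

t = torch.tensor([[1, 2, 3]])  # size (1, 3)
target = torch.zeros(4, 3)     # hypothetical tensor; only its size is used

# expand_as(other) is equivalent to expand(other.size())
e = t.expand_as(target)
print(e.size())                        # torch.Size([4, 3])
print(torch.equal(e, t.expand(4, 3)))  # True
```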
Example 3
In the following Python program, we try to expand the tensor along a non-singleton dimension, so it raises a RuntimeError.
# import required libraries
import torch
# create a tensor
t = torch.tensor([[1,2,3]])
# display the tensor
print("Tensor:\n", t)
# size of tensor is [1,3]
print("Size of Tensor:\n", t.size())
t.expand(3,4)
Output
Tensor:
tensor([[1, 2, 3]])
Size of Tensor:
torch.Size([1, 3])
RuntimeError: The expanded size of the tensor (4) must match the existing size (3) at non-singleton dimension 1. Target sizes: [3, 4]. Tensor sizes: [1, 3]
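In a longer script you may prefer to catch this error rather than let it stop execution. The sketch below does that and, for comparison, uses `Tensor.repeat()`, a different function that tiles the tensor by copying its data, so it can also enlarge non-singleton dimensions (here to (3, 6), an arbitrary choice). Note that `repeat()` multiplies sizes, so a size of 4 along a dimension of size 3 is not reachable with it either.

```python
import torch

t = torch.tensor([[1, 2, 3]])  # size (1, 3)

# Try the invalid expansion from Example 3, but catch the error
failed = False
try:
    t.expand(3, 4)
except RuntimeError as err:
    failed = True
    print("expand() failed:", err)

# repeat() copies data; each argument is a multiplier,
# so (1, 3) becomes (1*3, 3*2) = (3, 6)
tiled = t.repeat(3, 2)
print(tiled.size())  # torch.Size([3, 6])
print(tiled)
```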
