How to perform an expand operation in PyTorch?


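The Tensor.expand() method is used to perform an expand operation. It returns a new view of the tensor in which singleton dimensions (dimensions of size 1) are expanded to a larger size. It can also add new dimensions at the front of the tensor.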
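As a small sketch of the "new dimensions" case, a 1-D tensor with no singleton dimension can still be expanded by adding dimensions at the front. The variable names are illustrative; the rule that new dimensions are placed at the front and cannot use -1 follows the PyTorch documentation for Tensor.expand().

```python
import torch

v = torch.tensor([1, 2, 3])      # size (3,): no singleton dimension
m = v.expand(2, 3)               # a new leading dimension of size 2 is added
print(m)
print(m.size())                  # torch.Size([2, 3])

# New dimensions go at the front and must be given explicitly;
# -1 is only allowed for dimensions the tensor already has.
b = v.expand(4, 2, -1)
print(b.size())                  # torch.Size([4, 2, 3])
```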

  • Expanding a tensor only creates a new view of the original tensor; it doesn't make a copy of the original tensor.

  • Passing -1 for a dimension keeps that dimension at its current size; the tensor is not expanded along it.

  • For example, a tensor of size (3,1) can be expanded along dimension 1 (the dimension of size 1) to, say, size (3,2), while dimension 0 must stay at size 3.
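The points above can be checked directly. This sketch assumes only standard Tensor methods (data_ptr(), stride(), size()); it shows that the expanded tensor shares memory with the original, that the expanded dimension gets a stride of 0, and that -1 leaves a dimension unchanged.

```python
import torch

t = torch.tensor([[1], [2], [3]])   # size (3, 1)
e = t.expand(-1, 4)                 # -1 keeps dimension 0 at size 3

print(e.size())                      # torch.Size([3, 4])
# No data is copied: the view points at the same memory as t.
print(e.data_ptr() == t.data_ptr())  # True
# The expanded dimension has stride 0, so every column reads the same element.
print(t.stride(), e.stride())        # (1, 1) (1, 0)

# Because e is a view, a change to t is visible through e.
t[0, 0] = 100
print(e[0].tolist())                 # [100, 100, 100, 100]
```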

Steps

To expand a tensor, you can follow the steps given below:

  • Import the torch library. Make sure you have already installed it.

import torch
  • Define a tensor that has at least one singleton dimension (a dimension of size 1).

t = torch.tensor([[1],[2],[3]])
  • Expand the tensor along the singleton dimension. Trying to expand a non-singleton dimension to a different size raises a RuntimeError (see Example 3).

t_exp = t.expand(3,2)
  • Display the expanded tensor.

print("Tensor after expand:\n", t_exp)
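If the tensor you start from has no singleton dimension where you need one, one way to create it is Tensor.unsqueeze(), which is not part of the steps above but is a standard PyTorch method. The sketch below assumes a 1-D starting tensor and turns it into the (3,1) tensor used in the steps before expanding it.

```python
import torch

v = torch.tensor([1, 2, 3])    # size (3,): no singleton dimension yet
t = v.unsqueeze(1)             # insert a singleton dimension -> size (3, 1)
t_exp = t.expand(3, 2)         # expand along the new singleton dimension

print(t.size())                # torch.Size([3, 1])
print(t_exp.tolist())          # [[1, 1], [2, 2], [3, 3]]
```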

Example 1

The following Python program shows how to expand a tensor of size (3,1) to a tensor of size (3,2). It expands the tensor along dimension 1, which has size 1. The other dimension, of size 3, remains unchanged.

# import required libraries
import torch

# create a tensor
t = torch.tensor([[1],[2],[3]])

# display the tensor
print("Tensor:\n", t)
print("Size of Tensor:\n", t.size())

# expand the tensor
exp = t.expand(3,2)
print("Tensor after expansion:\n", exp)

Output

Tensor:
 tensor([[1],
        [2],
        [3]])
Size of Tensor:
 torch.Size([3, 1])
Tensor after expansion:
 tensor([[1, 1],
        [2, 2],
        [3, 3]])
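Since the result of expand() is a view with a stride of 0 along the expanded dimension, it is not contiguous in memory. If you need an independent tensor that holds the repeated values, one option (an addition to the original example) is to call .contiguous(), which allocates new memory here. The name materialized is illustrative.

```python
import torch

t = torch.tensor([[1], [2], [3]])
exp = t.expand(3, 2)

# The expanded view is not contiguous (its second stride is 0).
print(exp.is_contiguous())                      # False

# .contiguous() copies the repeated values into new memory.
materialized = exp.contiguous()
print(materialized.is_contiguous())             # True
print(materialized.data_ptr() == t.data_ptr())  # False

# The copy can be modified without touching t.
materialized[0, 0] = 100
print(t[0, 0].item(), materialized[0].tolist()) # 1 [100, 1]
```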

Example 2

The following Python program expands a tensor of size (1,3) to a tensor of size (3,3). It expands the tensor along dimension 0, which has size 1, and passes -1 for dimension 1 so that its size (3) stays unchanged.

# import required libraries
import torch

# create a tensor
t = torch.tensor([[1,2,3]])

# display the tensor
print("Tensor:\n", t)

# size of tensor is [1,3]
print("Size of Tensor:\n", t.size())

# expand the tensor
expandedTensor = t.expand(3,-1)

print("Expanded Tensor:\n", expandedTensor)
print("Size of expanded tensor:\n", expandedTensor.size())

Output

Tensor:
 tensor([[1, 2, 3]])
Size of Tensor:
 torch.Size([1, 3])
Expanded Tensor:
 tensor([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3]])
Size of expanded tensor:
 torch.Size([3, 3])
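The -1 in Example 2 is only a shorthand. The sketch below compares it with spelling out the full size and with Tensor.expand_as(), a related PyTorch method (not used in the original example) that expands a tensor to the size of another tensor. The tensor target is an illustrative placeholder; only its shape matters.

```python
import torch

t = torch.tensor([[1, 2, 3]])          # size (1, 3)
target = torch.zeros(3, 3)             # any tensor with the desired shape

a = t.expand(3, -1)                    # -1 keeps dimension 1 at size 3
b = t.expand(3, 3)                     # spelling out the size gives the same result
c = t.expand_as(target)                # same as t.expand(target.size())

print(torch.equal(a, b), torch.equal(a, c))   # True True
print(c.size())                               # torch.Size([3, 3])
```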

Example 3

In the following Python program, we try to expand the tensor along a non-singleton dimension (dimension 1 has size 3 and we ask for size 4), so it raises a RuntimeError.

# import required libraries
import torch

# create a tensor
t = torch.tensor([[1,2,3]])

# display the tensor
print("Tensor:\n", t)

# size of tensor is [1,3]
print("Size of Tensor:\n", t.size())
# dimension 1 has size 3, not 1, so it cannot be expanded to 4
t.expand(3,4)

Output

Tensor:
 tensor([[1, 2, 3]])
Size of Tensor:
 torch.Size([1, 3])


RuntimeError: The expanded size of the tensor (4) must match the existing size (3) at non-singleton dimension 1. Target sizes: [3, 4]. Tensor sizes: [1, 3]
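In a larger program you may prefer to catch this error, or check a target size before calling expand(). The sketch below does both. can_expand is a hypothetical helper written for this illustration, not a PyTorch API; it only covers target sizes made of non-negative integers and -1, and it is not a complete reimplementation of PyTorch's checks.

```python
import torch

def can_expand(shape, target):
    """Hypothetical helper (not part of PyTorch): checks the expand() rule
    for targets made of non-negative sizes and -1."""
    if len(target) < len(shape):
        return False
    offset = len(target) - len(shape)
    # New leading dimensions must be given explicitly (no -1).
    if any(size == -1 for size in target[:offset]):
        return False
    # Existing dimensions must be -1, keep their size, or be singletons.
    for old, new in zip(shape, target[offset:]):
        if new != -1 and old != 1 and new != old:
            return False
    return True

t = torch.tensor([[1, 2, 3]])  # size (1, 3)

results = {}
for target in [(3, 3), (3, -1), (3, 4)]:
    try:
        results[target] = tuple(t.expand(*target).size())
    except RuntimeError:
        results[target] = None  # expand() refused this target size
    print(target, "->", results[target],
          "| helper predicts:", can_expand(tuple(t.size()), target))
```

With this tensor, (3, 3) and (3, -1) both succeed with size (3, 3), while (3, 4) is rejected, matching the error shown above.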
raja
Published on 06-Dec-2021 10:59:35