How to resize a tensor in PyTorch?


To resize a PyTorch tensor, we use the .view() method. It can increase or decrease the number of dimensions of the tensor, but the total number of elements must be the same before and after the resize.
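The element-count rule can be checked directly. The short sketch below reuses the six-element tensor from Example 1 (the variable names are only for illustration). It adds dimensions, removes them again, and then asks for a 4x2 shape, which needs 8 elements and is therefore rejected with a RuntimeError.

```python
import torch

T = torch.Tensor([1, 2, 3, 4, 5, 6])  # 6 elements

# Increase the number of dimensions: 6 -> 2x3 -> 1x2x3 (each shape holds 6 elements)
x2d = T.view(2, 3)
x3d = T.view(1, 2, 3)
print(x2d.shape, x3d.shape)

# Decrease it again: flatten back to a 1-D tensor of 6 elements
flat = x3d.view(6)
print(flat.shape)

# A target shape with a different element count (4 x 2 = 8, not 6) is rejected
mismatch_rejected = False
try:
    T.view(4, 2)
except RuntimeError as err:
    mismatch_rejected = True
    print("RuntimeError:", err)
```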

Steps

  • Import the required library. All the following examples use the Python library torch, so make sure you have already installed it.

  • Create a PyTorch tensor and print it.

  • Resize the tensor created above using .view() and assign the result to a variable. .view() does not change the original tensor; as its name suggests, it returns a view: a new tensor with the requested shape that shares the same underlying data.

  • Finally, print the tensor after the resize.
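Because .view() returns a view rather than a copy, the original tensor keeps its shape while the two tensors share data. The minimal sketch below (variable names chosen for illustration) shows that a write through the view also appears in the original tensor.

```python
import torch

T = torch.Tensor([1, 2, 3, 4, 5, 6])
x = T.view(2, 3)

# T keeps its original 1-D shape; only x has the new 2x3 shape
print(T.shape)
print(x.shape)

# x and T share the same underlying data, so writing through x changes T as well
x[0, 0] = 100
print(T)

# Both tensors start at the same memory address
print(x.data_ptr() == T.data_ptr())
```

If you need an independent tensor instead, call .clone() on the result of .view().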

Example 1

# Python program to resize a tensor in PyTorch
# Import the library
import torch

# Create a tensor
T = torch.Tensor([1, 2, 3, 4, 5, 6])
print(T)

# Resize T to 2x3
x = T.view(2,3)
print("Tensor after resize:\n",x)

# Another way to resize T to 2x3: -1 lets PyTorch infer the number of rows
x = T.view(-1,3)
print("Tensor after resize:\n",x)

# A third way to resize T to 2x3: -1 lets PyTorch infer the number of columns
x = T.view(2,-1)
print("Tensor after resize:\n",x)

Output

When you run the above Python 3 code, it will produce the following output:

tensor([1., 2., 3., 4., 5., 6.])
Tensor after resize:
tensor([[1., 2., 3.],
         [4., 5., 6.]])
Tensor after resize:
tensor([[1., 2., 3.],
         [4., 5., 6.]])
Tensor after resize:
tensor([[1., 2., 3.],
         [4., 5., 6.]])
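The -1 used above is not a real size: PyTorch fills it in as the total number of elements divided by the product of the other sizes. The sketch below, again using the tensor from Example 1, checks that inference by hand and shows that at most one dimension can be left as -1.

```python
import torch

T = torch.Tensor([1, 2, 3, 4, 5, 6])

# -1 is inferred as numel / (product of the other sizes) = 6 / 3 = 2
x = T.view(-1, 3)
inferred = T.numel() // 3
print(x.shape, inferred)

# Only one dimension may be -1; with two unknowns the shape is ambiguous
two_unknowns_rejected = False
try:
    T.view(-1, -1)
except RuntimeError as err:
    two_unknowns_rejected = True
    print("RuntimeError:", err)
```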

Example 2

# Import the library
import torch

# Create a tensor of shape 4x3
T = torch.Tensor([[1,2,3],[2,1,3],[2,3,5],[5,6,4]])
print(T)

# Resize T to 3x4
x = T.view(-1,4)
print("Tensor after resize:\n",x)

# Another way to resize T to 3x4
x = T.view(3,-1)
print("Tensor after resize:\n",x)

# Resize T to 2x6
x = T.view(2,-1)
print("Tensor after resize:\n",x)

Output

When you run the above Python 3 code, it will produce the following output:

tensor([[1., 2., 3.],
         [2., 1., 3.],
         [2., 3., 5.],
         [5., 6., 4.]])
Tensor after resize:
tensor([[1., 2., 3., 2.],
         [1., 3., 2., 3.],
         [5., 5., 6., 4.]])
Tensor after resize:
tensor([[1., 2., 3., 2.],
         [1., 3., 2., 3.],
         [5., 5., 6., 4.]])
Tensor after resize:
tensor([[1., 2., 3., 2., 1., 3.],
         [2., 3., 5., 5., 6., 4.]])
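The output of Example 2 shows that .view() does not transpose or rearrange anything: it reads the elements row by row and regroups them into the new shape. The sketch below checks this for the 4x3 tensor from Example 2 by comparing the flattened elements before and after the resize.

```python
import torch

T = torch.Tensor([[1, 2, 3], [2, 1, 3], [2, 3, 5], [5, 6, 4]])

# The flattened (row-major) sequence of elements is identical before and after the resize
flat_before = T.view(-1)
flat_after = T.view(3, 4).view(-1)
print(torch.equal(flat_before, flat_after))

# Row i of the 3x4 result holds elements 4*i to 4*i+3 of the flattened tensor
x = T.view(3, 4)
print(torch.equal(x[1], flat_before[4:8]))
```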

Updated on: 06-Nov-2021
