How to perform a permute operation in PyTorch?


The torch.permute() method is used to perform a permute operation on a PyTorch tensor. It returns a view of the input tensor with its dimensions reordered; it does not copy the original tensor's data.
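A short sketch can check the view behaviour described above. It compares the storage pointers of the two tensors and then writes through the permuted tensor. It uses only standard tensor methods (data_ptr(), stride(), is_contiguous()), and the tensor values are illustrative.

```python
import torch

# A 3x2 tensor with the same values used later in this article
t = torch.tensor([[1, 2], [3, 4], [5, 6]])
t1 = torch.permute(t, (1, 0))

# The permuted tensor points at the same memory: no data was copied
print(t1.data_ptr() == t.data_ptr())  # True

# Writing through the view also changes the original tensor
t1[0, 1] = 100
print(t[1, 0])  # tensor(100)

# Only the strides are reordered, so the view is not contiguous
print(t.stride(), t1.stride())  # (2, 1) (1, 2)
print(t1.is_contiguous())  # False
```

If later code needs a contiguous layout, calling t1.contiguous() produces a copy with that layout.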

For example, a tensor of size [2, 3] can be permuted to size [3, 2]. The same operation is also available as the tensor method Tensor.permute().
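The following sketch shows the function form and the method form side by side on a [2, 3] tensor. The method form accepts the new order either as a tuple or as separate integers, and all three calls give the same [3, 2] result.

```python
import torch

# A [2, 3] tensor holding 0..5
t = torch.arange(6).reshape(2, 3)

# Function form: dims passed as a tuple
a = torch.permute(t, (1, 0))

# Method form: dims passed as a tuple or as separate ints
b = t.permute((1, 0))
c = t.permute(1, 0)

print(a.shape)  # torch.Size([3, 2])
print(torch.equal(a, b), torch.equal(b, c))  # True True
```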

Syntax

torch.permute(input, dims)

Parameters

  • input – the input PyTorch tensor.

  • dims – a tuple of ints giving the desired order of dimensions. It must list every dimension index of input exactly once.
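The sketch below shows how dims is read: output dimension i is input dimension dims[i]. It uses a [2, 3, 4] tensor and the illustrative order (2, 0, 1). It checks both the resulting size and the index mapping for a single element.

```python
import torch

# A [2, 3, 4] tensor holding 0..23
x = torch.arange(24).reshape(2, 3, 4)
dims = (2, 0, 1)
y = torch.permute(x, dims)

# Output dimension i has the size of input dimension dims[i]
print(y.shape)  # torch.Size([4, 2, 3])
print(y.shape == tuple(x.shape[d] for d in dims))  # True

# For this dims, y[i, j, k] is the element x[j, k, i]
print(y[3, 1, 2].item() == x[1, 2, 3].item())  # True
```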

Steps

  • Import the torch library. Make sure you have it already installed.

import torch

  • Create a PyTorch tensor, then print the tensor and its size.

t = torch.tensor([[1,2],[3,4],[5,6]])
print("Tensor:\n", t)
print("Size of tensor:", t.size()) # size 3x2
  • Compute torch.permute(input, dims) and assign the result to a variable. This does not modify the original tensor, input.

t1 = torch.permute(t, (1,0))
  • Print the resultant tensor and its size after the permute operation.

print("Tensor after Permuting:\n", t1)
print("Size after permuting:", t1.size())
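The steps above can be wrapped into a small reusable function. The name permute_and_report is hypothetical and not part of PyTorch. The final print confirms the point from the steps that the input tensor keeps its original size.

```python
import torch

def permute_and_report(t, dims):
    """Hypothetical helper: print the size, permute, print the new size."""
    print("Size of tensor:", t.size())
    t1 = torch.permute(t, dims)
    print("Size after permuting:", t1.size())
    return t1

t = torch.tensor([[1, 2], [3, 4], [5, 6]])
t1 = permute_and_report(t, (1, 0))

# torch.permute does not modify its input: t still has size [3, 2]
print(t.size())  # torch.Size([3, 2])
```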

Example 1

In the following Python program, the input tensor has size [3, 2]. We use dims = (1, 0) to swap its two dimensions, so the permuted tensor has size [2, 3].

# import the torch library
import torch

# create a tensor
t = torch.tensor([[1,2],[3,4],[5,6]])

# print the created tensor
print("Tensor:\n", t)
print("Size of tensor:", t.size())

# perform permute operation
t1 = torch.permute(t,(1,0))

# print the permuted tensor
print("Tensor after Permuting:\n", t1)
print("Size after permuting:", t1.size())

Output

Tensor:
 tensor([[1, 2],
        [3, 4],
        [5, 6]])
Size of tensor: torch.Size([3, 2])
Tensor after Permuting:
 tensor([[1, 3, 5],
        [2, 4, 6]])
Size after permuting: torch.Size([2, 3])
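For a 2-D tensor, dims = (1, 0) is simply a transpose. The sketch below checks that the result of Example 1 matches t.T and torch.transpose(). It also confirms that element (i, j) of the result is element (j, i) of the input.

```python
import torch

t = torch.tensor([[1, 2], [3, 4], [5, 6]])
t1 = torch.permute(t, (1, 0))

# Swapping the only two dimensions is an ordinary transpose
print(torch.equal(t1, t.T))  # True
print(torch.equal(t1, torch.transpose(t, 0, 1)))  # True

# Element (i, j) of the result is element (j, i) of the input
print(t1[0, 2].item(), t[2, 0].item())  # 5 5
```

For tensors with more than two dimensions, torch.transpose() swaps only one pair of dimensions, while torch.permute() can reorder all of them at once.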

Example 2

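In the following Python code, the input tensor has size [2, 3, 1]. We use dims = (0, 2, 1), which swaps the last two dimensions and gives a view of the input tensor with size [2, 1, 3]. Because torch.randn() fills the tensor with random values, the values you see will differ from the output shown.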

# import torch library
import torch

# create a tensor
t = torch.randn(2,3,1)

# print the created tensor
print("Tensor:\n", t)
print("Size of tensor:", t.size())

# perform permute
t1 = torch.permute(t, (0,2,1))

# print the resultant tensor
print("Tensor after Permuting:\n", t1)
print("Size after permuting:", t1.size())

Output

Tensor:
 tensor([[[ 1.5285],
         [-0.2401],
         [ 0.2378]],

        [[ 0.4733],
         [-1.7317],
         [ 0.7557]]])
Size of tensor: torch.Size([2, 3, 1])
Tensor after Permuting:
 tensor([[[ 1.5285, -0.2401,  0.2378]],

        [[ 0.4733, -1.7317,  0.7557]]])
Size after permuting: torch.Size([2, 1, 3])
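A permute can be undone by permuting again with the inverse order, where inverse[dims[i]] = i. The helper name inverse_permutation below is hypothetical, not a PyTorch function. The sketch applies it to the dims from Example 2, which happens to be its own inverse, and to the non-trivial order (2, 0, 1).

```python
import torch

def inverse_permutation(dims):
    """Hypothetical helper: return inv such that inv[dims[i]] == i."""
    inv = [0] * len(dims)
    for i, d in enumerate(dims):
        inv[d] = i
    return tuple(inv)

t = torch.randn(2, 3, 1)

for dims in [(0, 2, 1), (2, 0, 1)]:
    t1 = torch.permute(t, dims)
    restored = torch.permute(t1, inverse_permutation(dims))
    # Permuting with the inverse restores the original size and values
    print(dims, tuple(t1.shape), tuple(restored.shape), torch.equal(restored, t))
```

This prints (2, 1, 3) and (1, 2, 3) as the permuted sizes, with (2, 3, 1) and True for both restored tensors.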

Updated on: 06-Dec-2021
