How to find the transpose of a tensor in PyTorch?


To transpose a tensor, we specify the two dimensions to be swapped. If a tensor is 0-D or 1-D, its transpose is the tensor itself, returned as is. For a 2-D tensor, the transpose swaps dimensions 0 and 1, which is the same as transpose(input, 0, 1).
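The short sketch below checks these three cases with torch.t(). The tensors s, v and m are illustrative names, not part of the original steps.

```python
import torch

# 0-D tensor (a scalar): t() returns it as is
s = torch.tensor(5)
print(torch.t(s))    # tensor(5)

# 1-D tensor (a vector): t() also returns it as is
v = torch.tensor([1, 2, 3])
print(torch.t(v))    # tensor([1, 2, 3])

# 2-D tensor: t() swaps dimensions 0 and 1, like transpose(m, 0, 1)
m = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(torch.t(m))
print(torch.equal(torch.t(m), torch.transpose(m, 0, 1)))    # True
```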

Syntax

To find the transpose of a scalar, a vector, or a matrix, we can use the first syntax below.

For a tensor with any number of dimensions, we can use the second syntax.

  • For <= 2D tensors,

Tensor.t()
torch.t(input)
  • For a tensor with any number of dimensions,

Tensor.transpose(dim0, dim1) or
torch.transpose(input, dim0, dim1)
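As a quick check that the four spellings listed above agree on a 2-D tensor, here is a minimal sketch. The tensor x, built with torch.arange, is illustrative and not taken from the original examples.

```python
import torch

x = torch.arange(6).reshape(2, 3)    # [[0, 1, 2], [3, 4, 5]]

# The four spellings from the Syntax section, applied to the same 2-D tensor
results = [
    x.t(),
    torch.t(x),
    x.transpose(0, 1),
    torch.transpose(x, 0, 1),
]

# Each line prints: torch.Size([3, 2]) [[0, 3], [1, 4], [2, 5]]
for r in results:
    print(r.shape, r.tolist())
```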

Parameters

  • input – the PyTorch tensor to be transposed. For t(), it must have at most 2 dimensions.

  • dim0 – the first of the two dimensions to be swapped.

  • dim1 – the second of the two dimensions to be swapped.
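The order of dim0 and dim1 does not matter, and both accept negative indices that count back from the last dimension. A minimal sketch on an illustrative 3-D tensor x (not one from the original examples):

```python
import torch

x = torch.arange(24).reshape(2, 3, 4)

# dim0 and dim1 can be given in either order
a = torch.transpose(x, 0, 2)
b = torch.transpose(x, 2, 0)
print(a.shape, torch.equal(a, b))    # torch.Size([4, 3, 2]) True

# Negative indices count from the end: -1 is the last dim, -2 the one before it
c = x.transpose(-2, -1)              # same as x.transpose(1, 2)
print(c.shape)                       # torch.Size([2, 4, 3])
```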

Steps

  • Import the torch library. Make sure you have already installed it.

import torch
  • Create a PyTorch tensor and print the tensor. Here, we have created a 3×3 tensor.

t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print("Tensor:\n", t)
  • Find the transpose of the defined tensor using either of the syntaxes above, and optionally assign the result to a new variable.

transposedTensor = torch.transpose(t, 0, 1)
  • Print the transposed tensor.

print("Transposed Tensor:\n", transposedTensor)
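One detail worth knowing when assigning the result to a new variable: torch.transpose() returns a view that shares its underlying storage with the input, so writing into transposedTensor also changes t. The sketch below reuses the 3×3 tensor from the steps; the name independent is illustrative.

```python
import torch

t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
transposedTensor = torch.transpose(t, 0, 1)

# transpose() returns a view, so both tensors share the same storage
transposedTensor[0, 1] = 100    # row 0, column 1 of the transpose ...
print(t[1, 0])                  # ... is row 1, column 0 of t: tensor(100)

# The view only swaps strides, so it is not contiguous in memory
print(transposedTensor.is_contiguous())    # False

# clone() makes an independent copy if the original must stay untouched
independent = torch.transpose(t, 0, 1).clone()
```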

Example 1

# Python program to find the transpose of a 2D tensor
# import the torch library
import torch

# define a 2D tensor of random values (the printed values differ on each run)
A = torch.rand(2,3)
print(A)

# compute the transpose of the above tensor
print(A.t())
# or print(torch.t(A))

print(A.transpose(0, 1))
# or print(torch.transpose(A, 0, 1))

Output

tensor([[0.0676, 0.2984, 0.6766],
        [0.6200, 0.5874, 0.4150]])
tensor([[0.0676, 0.6200],
        [0.2984, 0.5874],
        [0.6766, 0.4150]])
tensor([[0.0676, 0.6200],
        [0.2984, 0.5874],
        [0.6766, 0.4150]])
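To see what A.t() does element by element, the sketch below checks the definition B[i][j] = A[j][i] with an explicit loop. B is an illustrative name, and since torch.rand gives different values on every run, no values are shown.

```python
import torch

A = torch.rand(2, 3)
B = A.t()

# By definition, element (i, j) of the transpose is element (j, i) of the original
rows, cols = A.shape
for i in range(cols):
    for j in range(rows):
        assert B[i, j] == A[j, i]

# prints: shape before: torch.Size([2, 3]) after: torch.Size([3, 2])
print("shape before:", A.shape, "after:", B.shape)
```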

Example 2

# Python program to find the transpose of a 3D tensor
# import the torch library
import torch

# create a 3D tensor
A = torch.tensor([[[1,2,3],[3,4,5]],
   [[5,6,7],[1,2,2]],
   [[1,2,4],[1,2,5]]])
print("Original Tensor A:\n",A)
print("Size of tensor:",A.size())

# print(A.t()) --> RuntimeError, since t() accepts only tensors with at most 2 dimensions
# compute the transpose of the tensor
transposeA = torch.transpose(A, 0,1)
# other way to compute the transpose
# transposeA = A.transpose(0,1)

print("Transposed Tensor:\n",transposeA)
print("Size after transpose:",transposeA.size())

Output

Original Tensor A:
tensor([[[1, 2, 3],
         [3, 4, 5]],

        [[5, 6, 7],
         [1, 2, 2]],

        [[1, 2, 4],
         [1, 2, 5]]])
Size of tensor: torch.Size([3, 2, 3])
Transposed Tensor:
tensor([[[1, 2, 3],
         [5, 6, 7],
         [1, 2, 4]],

        [[3, 4, 5],
         [1, 2, 2],
         [1, 2, 5]]])
Size after transpose: torch.Size([2, 3, 3])
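The same rule carries over to higher dimensions: with transpose(A, 0, 1), element [i, j, k] of the result is A[j, i, k], and only the two chosen sizes are exchanged. The sketch below rebuilds the tensor from Example 2, swaps a different pair (dims 1 and 2), and confirms that t() raises a RuntimeError for a 3-D tensor. The names T01, T12 and t_failed are illustrative.

```python
import torch

# Same tensor as Example 2, size [3, 2, 3]
A = torch.tensor([[[1, 2, 3], [3, 4, 5]],
                  [[5, 6, 7], [1, 2, 2]],
                  [[1, 2, 4], [1, 2, 5]]])

# transpose(A, 0, 1): element [i, j, k] of the result equals A[j, i, k]
T01 = torch.transpose(A, 0, 1)
print(T01[1, 2, 0] == A[2, 1, 0])    # tensor(True)

# Swapping dims 1 and 2 instead exchanges the last two sizes
T12 = torch.transpose(A, 1, 2)
print(T12.size())                    # torch.Size([3, 3, 2])

# t() only accepts tensors with at most 2 dimensions
t_failed = False
try:
    A.t()
except RuntimeError as err:
    t_failed = True
    print("RuntimeError:", err)
```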
raja
Updated on 06-Dec-2021 12:20:28
