# How to sort the elements of a tensor in PyTorch?

To sort the elements of a tensor in PyTorch, we can use the torch.sort() method. It returns two tensors, packed as a named tuple. The first tensor holds the sorted values, and the second holds the indices of those elements in the original tensor. By default the values are sorted in ascending order. For 2D tensors, we can sort the elements row-wise or column-wise.
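As a quick check of the two return values, the sketch below (assuming a recent PyTorch release) unpacks the result of torch.sort() and confirms that indexing the original tensor with the returned indices reproduces the sorted values. The variable names are illustrative only.

```python
import torch

# Same data as Example 1 below
T = torch.tensor([2.334, 4.433, -4.33, -0.433, 5.0, 4.443])

# torch.sort returns a named tuple; it can be unpacked or read by field name
result = torch.sort(T)
values, indices = result

# Field access and positional access refer to the same tensors
print(torch.equal(result.values, result[0]))   # True
print(torch.equal(result.indices, result[1]))  # True

# Indexing the original tensor with the indices reproduces the sorted values
print(torch.equal(T[indices], values))         # True
print(indices)                                 # tensor([2, 3, 0, 1, 5, 4])
```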

## Steps

• Import the required library. In all the following Python examples, the required Python library is torch. Make sure you have already installed it.

• Create a PyTorch tensor and print it.

• To sort the elements of the above-created tensor, compute torch.sort(input, dim) and assign the result to a new variable "v". Here, input is the input tensor and dim is the dimension along which the elements are sorted. To sort the elements row-wise, set dim to 1; to sort them column-wise, set dim to 0. If dim is omitted, the last dimension is used.

• The tensor with the sorted values can be accessed as v[0] (or v.values), and the tensor of indices of the sorted elements as v[1] (or v.indices).

• Print the tensor with the sorted values and the tensor with the indices of the sorted values.
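The effect of dim can be checked with a small sketch. It uses torch.gather(), which is not part of the steps above, to rebuild the sorted tensor from the original tensor and the returned indices along the same dimension. The 3x3 tensor is a made-up example with distinct values so the order is unambiguous.

```python
import torch

# Hypothetical 3x3 tensor with distinct values
T = torch.tensor([[3., 1., 2.],
                  [9., 7., 8.],
                  [5., 6., 4.]])

# dim=1: each row is sorted independently (row-wise)
row_vals, row_idx = torch.sort(T, dim=1)
# row_idx[i, j] is the column in row i that holds the j-th smallest value
print(torch.equal(torch.gather(T, 1, row_idx), row_vals))  # True

# dim=0: each column is sorted independently (column-wise)
col_vals, col_idx = torch.sort(T, dim=0)
# col_idx[i, j] is the row in column j that holds the i-th smallest value
print(torch.equal(torch.gather(T, 0, col_idx), col_vals))  # True
```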

## Example 1

The following Python program shows how to sort the elements of a 1D tensor.

```python
# Python program to sort elements of a tensor
# import necessary library
import torch

# Create a tensor
T = torch.Tensor([2.334, 4.433, -4.33, -0.433, 5, 4.443])
print("Original Tensor:\n", T)

# sort the tensor T
# it sorts the tensor in ascending order
v = torch.sort(T)

# print tensor of sorted values
print("Tensor with sorted value:\n", v[0])

# print indices of sorted values
print("Indices of sorted value:\n", v[1])
```

## Output

```
Original Tensor:
 tensor([ 2.3340,  4.4330, -4.3300, -0.4330,  5.0000,  4.4430])
Tensor with sorted value:
 tensor([-4.3300, -0.4330,  2.3340,  4.4330,  4.4430,  5.0000])
Indices of sorted value:
 tensor([2, 3, 0, 1, 5, 4])
```
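Example 1 sorts in ascending order, which is the default. torch.sort() also accepts descending=True. The sketch below, assuming the same data as Example 1, shows that for distinct values the descending result is simply the ascending result reversed.

```python
import torch

T = torch.tensor([2.334, 4.433, -4.33, -0.433, 5.0, 4.443])

# descending=True sorts from largest to smallest
values, indices = torch.sort(T, descending=True)
print(indices)  # tensor([4, 5, 1, 0, 3, 2])

# With distinct values, this is the ascending result read back to front
asc_values, asc_indices = torch.sort(T)
print(torch.equal(values, torch.flip(asc_values, dims=[0])))  # True
```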

## Example 2

The following Python program shows how to sort the elements of a 2D tensor.

```python
# Python program to sort elements of a 2-D tensor
# import the library
import torch

# Create a 2-D tensor
T = torch.Tensor([[2, 3, -32],
                  [43, 4, -53],
                  [4, 37, -4],
                  [3, -75, 34]])
print("Original Tensor:\n", T)

# sort tensor T in ascending order; with no dim given,
# it sorts along the last dimension (row-wise for a 2-D tensor)
v = torch.sort(T)

# print tensor of sorted values
print("Tensor with sorted value:\n", v[0])

# print indices of sorted values
print("Indices of sorted value:\n", v[1])

# dim=0 sorts each column independently
print("Sort tensor Column-wise")
v = torch.sort(T, 0)

# print tensor of sorted values
print("Tensor with sorted value:\n", v[0])

# print indices of sorted values
print("Indices of sorted value:\n", v[1])

# dim=1 sorts each row independently
print("Sort tensor Row-wise")
v = torch.sort(T, 1)

# print tensor of sorted values
print("Tensor with sorted value:\n", v[0])

# print indices of sorted values
print("Indices of sorted value:\n", v[1])
```

## Output

```
Original Tensor:
 tensor([[  2.,   3., -32.],
        [ 43.,   4., -53.],
        [  4.,  37.,  -4.],
        [  3., -75.,  34.]])
Tensor with sorted value:
 tensor([[-32.,   2.,   3.],
        [-53.,   4.,  43.],
        [ -4.,   4.,  37.],
        [-75.,   3.,  34.]])
Indices of sorted value:
 tensor([[2, 0, 1],
        [2, 1, 0],
        [2, 0, 1],
        [1, 0, 2]])
Sort tensor Column-wise
Tensor with sorted value:
 tensor([[  2., -75., -53.],
        [  3.,   3., -32.],
        [  4.,   4.,  -4.],
        [ 43.,  37.,  34.]])
Indices of sorted value:
 tensor([[0, 3, 1],
        [3, 0, 0],
        [2, 1, 2],
        [1, 2, 3]])
Sort tensor Row-wise
Tensor with sorted value:
 tensor([[-32.,   2.,   3.],
        [-53.,   4.,  43.],
        [ -4.,   4.,  37.],
        [-75.,   3.,  34.]])
Indices of sorted value:
 tensor([[2, 0, 1],
        [2, 1, 0],
        [2, 0, 1],
        [1, 0, 2]])
```
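In the output above, the default result and the row-wise result are identical. That is because dim defaults to -1, the last dimension, which is dim=1 for a 2D tensor. The sketch below checks this and also shows torch.argsort(), an alternative not used in the original program, which returns only the index tensor.

```python
import torch

# Same 2D tensor as Example 2
T = torch.Tensor([[2, 3, -32],
                  [43, 4, -53],
                  [4, 37, -4],
                  [3, -75, 34]])

# No dim argument means dim=-1, the last dimension (dim=1 here)
default_vals, default_idx = torch.sort(T)
row_vals, row_idx = torch.sort(T, dim=1)
print(torch.equal(default_vals, row_vals) and torch.equal(default_idx, row_idx))  # True

# If only the indices are needed, torch.argsort returns them directly
col_idx = torch.argsort(T, dim=0)
print(torch.equal(col_idx, torch.sort(T, dim=0).indices))  # True
```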

Updated on: 06-Nov-2021