How to sort the elements of a tensor in PyTorch?
To sort the elements of a tensor in PyTorch, we can use the torch.sort() method. It returns a named tuple of two tensors. The first, values, holds the sorted elements. The second, indices, holds the position of each sorted element in the original tensor along the sorted dimension. For 2D tensors, we can sort row-wise or column-wise by specifying the dimension.
Syntax
torch.sort(input, dim=-1, descending=False)
Parameters
- input − The input tensor to be sorted
- dim − Dimension along which to sort (default: -1, the last dimension). For a 2D tensor, 0 sorts each column and 1 sorts each row
- descending − If True, sorts in descending order (default: False)
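The named-tuple return value and the default dim are easiest to see on a small tensor. The following is a minimal sketch. The tensor values are made up for illustration. It shows that the result can be read by field name (values, indices) or unpacked, and that omitting dim sorts the last dimension, which for a 2D tensor is the same as dim=1.

```python
import torch

# Small 2D tensor (values chosen only for illustration)
T = torch.tensor([[3.0, 1.0, 2.0],
                  [9.0, 7.0, 8.0]])

# torch.sort returns a named tuple whose fields can be read by name...
result = torch.sort(T)
print(result.values)   # sorted elements of each row
print(result.indices)  # original column position of each sorted element

# ...or unpacked like a regular tuple
values, indices = torch.sort(T)

# With no dim argument the last dimension (dim=-1) is sorted,
# which for a 2D tensor matches dim=1 (row-wise)
same_values, same_indices = torch.sort(T, dim=1)
print(torch.equal(values, same_values))    # True
print(torch.equal(indices, same_indices))  # True
```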
Example 1: Sorting a 1D Tensor
The following example shows how to sort the elements of a 1D tensor:
import torch
# Create a tensor
T = torch.Tensor([2.334, 4.433, -4.33, -0.433, 5, 4.443])
print("Original Tensor:")
print(T)
# Sort the tensor in ascending order
v = torch.sort(T)
# Print tensor of sorted values
print("Tensor with sorted values:")
print(v[0])
# Print indices of sorted values
print("Indices of sorted values:")
print(v[1])
Original Tensor:
tensor([ 2.3340,  4.4330, -4.3300, -0.4330,  5.0000,  4.4430])
Tensor with sorted values:
tensor([-4.3300, -0.4330,  2.3340,  4.4330,  4.4430,  5.0000])
Indices of sorted values:
tensor([2, 3, 0, 1, 5, 4])
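The indices tensor records where each sorted element came from, so indexing the original tensor with it should reproduce the sorted values. The following short check reuses the tensor from Example 1. It is only a sanity check and is not a required step when calling torch.sort().

```python
import torch

T = torch.Tensor([2.334, 4.433, -4.33, -0.433, 5, 4.443])
values, indices = torch.sort(T)

# Indexing the original tensor with the returned indices
# picks its elements in sorted order
reordered = T[indices]
print(torch.equal(reordered, values))  # True
print(indices.tolist())                # [2, 3, 0, 1, 5, 4]
```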
Example 2: Sorting a 2D Tensor
The following example demonstrates sorting a 2D tensor along different dimensions:
import torch
# Create a 2D tensor
T = torch.Tensor([[2, 3, -32],
                  [43, 4, -53],
                  [4, 37, -4],
                  [3, -75, 34]])
print("Original Tensor:")
print(T)
# Default sort (dim=-1, the last dimension; row-wise for a 2D tensor)
print("\nDefault Sort (Row-wise):")
v = torch.sort(T)
print("Sorted values:")
print(v[0])
print("Indices:")
print(v[1])
# Column-wise sort (dim=0)
print("\nColumn-wise Sort (dim=0):")
v = torch.sort(T, dim=0)
print("Sorted values:")
print(v[0])
print("Indices:")
print(v[1])
# Row-wise sort (dim=1)
print("\nRow-wise Sort (dim=1):")
v = torch.sort(T, dim=1)
print("Sorted values:")
print(v[0])
print("Indices:")
print(v[1])
Original Tensor:
tensor([[  2.,   3., -32.],
        [ 43.,   4., -53.],
        [  4.,  37.,  -4.],
        [  3., -75.,  34.]])

Default Sort (Row-wise):
Sorted values:
tensor([[-32.,   2.,   3.],
        [-53.,   4.,  43.],
        [ -4.,   4.,  37.],
        [-75.,   3.,  34.]])
Indices:
tensor([[2, 0, 1],
        [2, 1, 0],
        [2, 0, 1],
        [1, 0, 2]])

Column-wise Sort (dim=0):
Sorted values:
tensor([[  2., -75., -53.],
        [  3.,   3., -32.],
        [  4.,   4.,  -4.],
        [ 43.,  37.,  34.]])
Indices:
tensor([[0, 3, 1],
        [3, 0, 0],
        [2, 1, 2],
        [1, 2, 3]])

Row-wise Sort (dim=1):
Sorted values:
tensor([[-32.,   2.,   3.],
        [-53.,   4.,  43.],
        [ -4.,   4.,  37.],
        [-75.,   3.,  34.]])
Indices:
tensor([[2, 0, 1],
        [2, 1, 0],
        [2, 0, 1],
        [1, 0, 2]])
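For a 2D tensor, the indices are positions along the sorted dimension only. Plain indexing therefore does not rebuild the result. Passing them to torch.gather along the same dimension does. The following sketch reuses the tensor from Example 2 and checks both dimensions. It also confirms that every column of the dim=0 result and every row of the dim=1 result is non-decreasing.

```python
import torch

T = torch.Tensor([[2, 3, -32],
                  [43, 4, -53],
                  [4, 37, -4],
                  [3, -75, 34]])

for dim in (0, 1):
    values, indices = torch.sort(T, dim=dim)
    # gather picks T's elements along `dim` at the positions given by `indices`,
    # which rebuilds the sorted tensor
    rebuilt = torch.gather(T, dim, indices)
    print(dim, torch.equal(rebuilt, values))  # prints "0 True", then "1 True"

# Columns of the dim=0 result and rows of the dim=1 result are non-decreasing
col_sorted = torch.sort(T, dim=0).values
row_sorted = torch.sort(T, dim=1).values
print(bool((col_sorted[1:] >= col_sorted[:-1]).all()))        # True
print(bool((row_sorted[:, 1:] >= row_sorted[:, :-1]).all()))  # True
```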
Example 3: Descending Order Sort
You can also sort tensors in descending order by setting the descending parameter to True:
import torch
T = torch.Tensor([2.334, 4.433, -4.33, -0.433, 5, 4.443])
print("Original Tensor:")
print(T)
# Sort in descending order
v = torch.sort(T, descending=True)
print("Descending Sort:")
print("Sorted values:")
print(v[0])
print("Indices:")
print(v[1])
Original Tensor:
tensor([ 2.3340,  4.4330, -4.3300, -0.4330,  5.0000,  4.4430])
Descending Sort:
Sorted values:
tensor([ 5.0000,  4.4430,  4.4330,  2.3340, -0.4330, -4.3300])
Indices:
tensor([4, 5, 1, 0, 3, 2])
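The same call is also available as a tensor method. The following sketch reuses the tensor from Example 3 and uses T.sort(descending=True). It checks that the descending values match a reversed ascending sort. Only the values are compared because, when a tensor has repeated elements, the indices of equal elements are not guaranteed to come out in the same order.

```python
import torch

T = torch.Tensor([2.334, 4.433, -4.33, -0.433, 5, 4.443])

# Tensor.sort() is the method form of torch.sort()
desc_values, desc_indices = T.sort(descending=True)

# Reversing an ascending sort yields the same values as descending=True
asc_values, _ = torch.sort(T)
print(torch.equal(desc_values, torch.flip(asc_values, dims=[0])))  # True
print(desc_indices.tolist())  # [4, 5, 1, 0, 3, 2]
```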
Key Points
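- torch.sort() returns a named tuple (values, indices). You can unpack it like a regular tuple or index it as v[0] and v[1]
- Use dim=0 for column-wise sorting in 2D tensors
- Use dim=1 for row-wise sorting in 2D tensors. This is also the default for 2D tensors, because dim defaults to -1 (the last dimension)
- Set descending=True to sort in descending order
- The indices tensor gives the original position of each sorted element along the sorted dimension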
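Original positions are useful when another tensor is aligned element-for-element with the one being sorted. The following sketch uses hypothetical data: scores and item_ids are illustrative names, not part of any API. It sorts the scores from highest to lowest and applies the same permutation to the IDs so that each pair stays together.

```python
import torch

# Hypothetical data: one score per item, with item IDs aligned to the scores
scores = torch.tensor([0.2, 0.9, 0.5, 0.7])
item_ids = torch.tensor([10, 11, 12, 13])

# Sort scores from highest to lowest and keep the permutation
sorted_scores, order = torch.sort(scores, descending=True)

# Apply the same permutation to the aligned tensor so pairs stay matched
ranked_ids = item_ids[order]

print(sorted_scores)  # tensor([0.9000, 0.7000, 0.5000, 0.2000])
print(ranked_ids)     # tensor([11, 13, 12, 10])
```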
Conclusion
The torch.sort() function sorts PyTorch tensors along a chosen dimension. It returns both the sorted values and their original indices, which makes it useful for tracking where each element came from. Use the dim parameter to choose the dimension to sort along in multi-dimensional tensors, and set descending=True to reverse the order.
