How to find the k-th and the top "k" elements of a tensor in PyTorch?


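PyTorch provides the method torch.kthvalue() to find the k-th smallest element of a tensor. It returns the value of the k-th element of the tensor sorted in ascending order (k is 1-based, so k=1 gives the smallest element), along with the index of that element in the original tensor. For multi-dimensional tensors it works along the last dimension by default; pass dim to choose another dimension.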
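
As a quick sanity check, the sketch below compares torch.kthvalue() with an explicit ascending sort done by torch.sort(). It uses the same sample values as the examples, and the variable names are only illustrative. It also shows that the result is a named tuple with the fields .values and .indices.

```python
import torch

# Same sample values as the examples (illustrative)
T = torch.tensor([2.334, 4.433, -4.33, -0.433, 5.0, 4.443])
k = 3  # k is 1-based: k=1 would be the smallest element

# kthvalue returns a named tuple with fields .values and .indices
result = torch.kthvalue(T, k)

# Cross-check: the k-th entry of an ascending sort is the same element
sorted_vals, sorted_idx = torch.sort(T)

print(result.values, result.indices)    # tensor(2.3340) tensor(0)
print(sorted_vals[k - 1], sorted_idx[k - 1])
```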

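The torch.topk() method finds the top "k" elements, that is, the "k" largest elements of a tensor. Like torch.kthvalue(), it returns both the values and their indices in the original tensor, working along the last dimension by default. The values come back in descending order, and passing largest=False returns the "k" smallest elements instead.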
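
A minimal sketch of the largest option, again using the sample values from the examples (variable names are illustrative). With the default sorted=True, the largest elements come back in descending order and the smallest ones in ascending order.

```python
import torch

T = torch.tensor([2.334, 4.433, -4.33, -0.433, 5.0, 4.443])

# Default: the 2 largest values, largest first
top = torch.topk(T, 2)
# largest=False: the 2 smallest values, smallest first
bottom = torch.topk(T, 2, largest=False)

print(top.values, top.indices)        # tensor([5.0000, 4.4430]) tensor([4, 5])
print(bottom.values, bottom.indices)  # tensor([-4.3300, -0.4330]) tensor([2, 3])
```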

Steps

  • Import the required library. In all the following Python examples, the required Python library is torch. Make sure you have already installed it.

  • Create a PyTorch tensor and print it.

  • Compute torch.kthvalue(input, k), where input is a tensor and k is a 1-based integer. It returns two tensors; assign them to two new variables, "value" and "index".

  • Compute torch.topk(input, k). It returns two tensors: the first holds the values of the top "k" elements, and the second holds the indices of these elements in the original tensor. Assign them to the new variables "values" and "indices".

  • Print the value and index of the k-th element of the tensor, and the values and indices of the top "k" elements of the tensor.
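
The steps above can be condensed into one small script. The helper name kth_and_top_k is hypothetical (it is not part of PyTorch), and the sketch assumes a 1D input tensor so that the default dimension covers the whole tensor.

```python
import torch

def kth_and_top_k(t: torch.Tensor, k: int):
    """Hypothetical helper: k-th smallest element and the k largest elements of a 1D tensor."""
    value, index = torch.kthvalue(t, k)   # k-th smallest value and its index
    values, indices = torch.topk(t, k)    # k largest values (descending) and their indices
    return value, index, values, indices

T = torch.tensor([2.334, 4.433, -4.33, -0.433, 5.0, 4.443])
value, index, values, indices = kth_and_top_k(T, 2)

print("2nd smallest:", value, "at index", index)
print("Top 2:", values, "at indices", indices)
```

Note that the same k is used for both calls here only to mirror the steps; in practice the two calls are independent and can take different values of k.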

Example 1

This Python program shows how to find the k-th smallest element of a tensor.

# Python program to find the k-th smallest element of a tensor
# Import the necessary library
import torch

# Create a 1D tensor
T = torch.tensor([2.334, 4.433, -4.33, -0.433, 5.0, 4.443])
print("Original Tensor:\n", T)

# Find the 3rd smallest element (k=3). Conceptually, kthvalue sorts the
# tensor in ascending order, then returns the k-th value of the sorted
# tensor together with that element's index in the original tensor
value, index = torch.kthvalue(T, 3)

# Print the value and index of the 3rd smallest element
print("3rd element value:", value)
print("3rd element index:", index)

Output

Original Tensor:
 tensor([ 2.3340,  4.4330, -4.3300, -0.4330,  5.0000,  4.4430])
3rd element value: tensor(2.3340)
3rd element index: tensor(0)
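
Example 1 uses a 1D tensor. For a 2D tensor, torch.kthvalue() works along one dimension at a time (the last one by default) and returns one result per row or column. A short sketch with a small illustrative matrix M:

```python
import torch

# Illustrative 2x3 matrix
M = torch.tensor([[3.0, 1.0, 2.0],
                  [9.0, 7.0, 8.0]])

# 2nd smallest along the last dimension (default dim=-1): one result per row
vals, idx = torch.kthvalue(M, 2)
# 2nd smallest along dim=0: one result per column
col_vals, col_idx = torch.kthvalue(M, 2, dim=0)
# keepdim=True keeps the reduced dimension with size 1
kept = torch.kthvalue(M, 2, dim=1, keepdim=True)

print(vals, idx)            # tensor([2., 8.]) tensor([2, 2])
print(col_vals, col_idx)
print(kept.values.shape)    # torch.Size([2, 1])
```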

Example 2

The following Python program shows how to find the top "k" or largest "k" elements of a tensor.

# Python program to find the top k elements of a tensor
# Import the necessary library
import torch

# Create a 1D tensor
T = torch.tensor([2.334, 4.433, -4.33, -0.433, 5.0, 4.443])
print("Original Tensor:\n", T)

# Find the top k=2 elements, i.e. the 2 largest elements of the tensor.
# topk returns the 2 largest values (in descending order) and their
# indices in the original tensor
values, indices = torch.topk(T, 2)

# Print the values and indices of the top 2 elements
print("Top 2 element values:", values)
print("Top 2 element indices:", indices)

Output

Original Tensor:
 tensor([ 2.3340,  4.4330, -4.3300, -0.4330,  5.0000,  4.4430])
Top 2 element values: tensor([5.0000, 4.4430])
Top 2 element indices: tensor([4, 5])
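
torch.topk() also accepts dim, so on a 2D tensor it returns the top "k" elements of each row or column. In this sketch M is an illustrative matrix; the final check uses Tensor.gather() to confirm that the returned indices point at the returned values.

```python
import torch

# Illustrative 2x3 matrix
M = torch.tensor([[3.0, 1.0, 2.0],
                  [9.0, 7.0, 8.0]])

# Top 2 in each row (dim=-1 by default), largest first
row_top = torch.topk(M, 2)
# Top 1 in each column
col_top = torch.topk(M, 1, dim=0)

print(row_top.values, row_top.indices)
print(col_top.values, col_top.indices)

# The indices select the same values from the original matrix
assert torch.equal(M.gather(1, row_top.indices), row_top.values)
```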

Updated on: 06-Nov-2021
