How to find the k-th and the top "k" elements of a tensor in PyTorch?
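PyTorch provides the method torch.kthvalue() to find the k-th element of a tensor. It returns the k-th smallest element, that is, the element at position k (counting from 1) when the tensor is sorted in ascending order, together with the index of that element in the original tensor. For a multi-dimensional tensor, the search runs along the last dimension unless a dim argument is given.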
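A minimal sketch of what "the k-th element of the sorted tensor" means in practice: it compares the result of torch.kthvalue() with a full ascending sort from torch.sort(), and checks that the returned index points into the original, unsorted tensor. The tensor values are reused from the examples below; the cross-check itself is only an illustration, not how kthvalue is implemented internally.

```python
import torch

# Same values as in the examples; torch.tensor infers float32 here
T = torch.tensor([2.334, 4.433, -4.33, -0.433, 5.0, 4.443])
k = 3  # k is 1-based: k=1 would be the smallest element

# torch.kthvalue returns a named tuple with fields .values and .indices
result = torch.kthvalue(T, k)

# Cross-check against a full ascending sort (0-based position k - 1)
sorted_vals, _ = torch.sort(T)
assert torch.equal(result.values, sorted_vals[k - 1])

# The returned index refers to the ORIGINAL (unsorted) tensor
assert torch.equal(T[result.indices], result.values)

print("k-th value:", result.values.item())
print("index in original tensor:", result.indices.item())
```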
The torch.topk() method finds the top "k" elements. It returns the "k" largest elements of the tensor together with their indices in the original tensor.
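As a rough illustration of the two directions torch.topk() can search in, the sketch below uses the default (largest=True) and the largest=False option, which selects the "k" smallest elements instead. Both results are checked against torch.sort(); the tensor values are borrowed from the examples that follow.

```python
import torch

T = torch.tensor([2.334, 4.433, -4.33, -0.433, 5.0, 4.443])
k = 2

# Default behaviour: the k largest values, returned in descending order
largest = torch.topk(T, k)

# largest=False flips the selection to the k smallest values
smallest = torch.topk(T, k, largest=False)

# Cross-check both results against a full sort
desc_vals, _ = torch.sort(T, descending=True)
asc_vals, _ = torch.sort(T)
assert torch.equal(largest.values, desc_vals[:k])
assert torch.equal(smallest.values, asc_vals[:k])

print("largest: ", largest.values, largest.indices)
print("smallest:", smallest.values, smallest.indices)
```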
Steps
Import the required library. In all the following Python examples, the required Python library is torch. Make sure you have already installed it.
Create a PyTorch tensor and print it.
Compute torch.kthvalue(input, k), where input is a tensor and k is an integer. It returns two tensors; assign them to two new variables, "value" and "index".
Compute torch.topk(input, k). It returns two tensors: the first holds the values of the top "k" elements and the second holds the indices of these elements in the original tensor. Assign them to the new variables "values" and "indices".
Print the value and index of the k-th element of the tensor, and the values and indices of the top "k" elements of the tensor.
Example 1
This Python program shows how to find the k-th element of a tensor.
```python
# Python program to find the k-th element of a tensor
# import the necessary library
import torch

# Create a 1D tensor
T = torch.tensor([2.334, 4.433, -4.33, -0.433, 5, 4.443])
print("Original Tensor:\n", T)

# Find the 3rd element of the tensor sorted in ascending order
# (the 3rd smallest element). kthvalue returns its value and
# its index in the original tensor
value, index = torch.kthvalue(T, 3)

# print the 3rd element's value and index
print("3rd element value:", value)
print("3rd element index:", index)
```
Output
```
Original Tensor:
 tensor([ 2.3340,  4.4330, -4.3300, -0.4330,  5.0000,  4.4430])
3rd element value: tensor(2.3340)
3rd element index: tensor(0)
```
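Because torch.kthvalue() always counts from the smallest element, finding the k-th largest element takes one extra step. A minimal sketch, assuming a 1D tensor: the k-th largest of n elements is the (n - k + 1)-th smallest, and the same answer is the last entry returned by torch.topk().

```python
import torch

T = torch.tensor([2.334, 4.433, -4.33, -0.433, 5.0, 4.443])
k = 2  # we want the 2nd LARGEST element

# The k-th largest of n elements is the (n - k + 1)-th smallest
n = T.size(-1)
value, index = torch.kthvalue(T, n - k + 1)

# Same answer from topk: the last of the k largest values
top_values, top_indices = torch.topk(T, k)
assert torch.equal(value, top_values[-1])
assert torch.equal(index, top_indices[-1])

print("2nd largest value:", value.item(), "at index", index.item())
```

When the tensor contains ties, the two approaches may report different indices for equal values, so the index check is only reliable for distinct elements like these.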
Example 2
The following Python program shows how to find the top "k" or largest "k" elements of a tensor.
```python
# Python program to find the top k elements of a tensor
# import the necessary library
import torch

# Create a 1D tensor
T = torch.tensor([2.334, 4.433, -4.33, -0.433, 5, 4.443])
print("Original Tensor:\n", T)

# Find the top k=2 (2 largest) elements of the tensor;
# returns the 2 largest values and their indices in the original tensor
values, indices = torch.topk(T, 2)

# print the top 2 elements' values and indices
print("Top 2 element values:", values)
print("Top 2 element indices:", indices)
```
Output
```
Original Tensor:
 tensor([ 2.3340,  4.4330, -4.3300, -0.4330,  5.0000,  4.4430])
Top 2 element values: tensor([5.0000, 4.4430])
Top 2 element indices: tensor([4, 5])
```
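Both methods also accept a dim argument, which matters once the tensor has more than one dimension. The sketch below uses a small 2D tensor (its values are made up for illustration) and searches each row independently with dim=1; keepdim=True keeps the reduced dimension in the result's shape.

```python
import torch

# A hypothetical 2 x 3 tensor; each row is searched independently
M = torch.tensor([[3.0, 1.0, 2.0],
                  [0.5, 9.0, 4.0]])

# dim=1 searches along each row (also the default: the last dimension)
row_top = torch.topk(M, 2, dim=1)
row_kth = torch.kthvalue(M, 2, dim=1)

print(row_top.values)   # two largest values per row
print(row_top.indices)  # their column positions within each row
print(row_kth.values)   # 2nd smallest value per row
print(row_kth.indices)  # its column position within each row

# keepdim=True keeps the reduced dimension: shape (2, 1) instead of (2,)
print(torch.kthvalue(M, 2, dim=1, keepdim=True).values.shape)
```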