- Trending Categories
Data Structure
Networking
RDBMS
Operating System
Java
MS Excel
iOS
HTML
CSS
Android
Python
C Programming
C++
C#
MongoDB
MySQL
Javascript
PHP
Physics
Chemistry
Biology
Mathematics
English
Economics
Psychology
Social Studies
Fashion Studies
Legal Studies
- Selected Reading
- UPSC IAS Exams Notes
- Developer's Best Practices
- Questions and Answers
- Effective Resume Writing
- HR Interview Questions
- Computer Glossary
- Who is Who
torch.argmax() Method in Python PyTorch
To find the indices of the maximum value of the elements in an input tensor, we can apply the torch.argmax() function. It returns the indices only, not the element value. If the input tensor has multiple maximal values, then the function will return the index of the first maximal element. We can apply the torch.argmax() function to compute the indices of the maximum values of a tensor across a dimension..
Syntax
torch.argmax(input)
Steps
We could use the following steps to find the indices of the maximum values of all elements in the input tensor −
Import the required library. In all the following examples, the required Python library is torch. Make sure you have already installed it.
import torch
Define an input tensor input.
input = torch.randn(3,4)
Compute the indices of the maximum values of all the elements in the tensor input.
indices = torch.argmax(input)
Print the above computed tensor with indices.
print("Indices:
", indices)
Example 1
# Import the required library import torch # define an input tensor input = torch.tensor([0., -1., 2., 8.]) # print above defined tensor print("Input Tensor:
", input) # Compute indices of the maximum value indices = torch.argmax(input) # print the indices print("Indices:
", indices)
Output
Input Tensor: tensor([ 0., -1., 2., 8.]) Indices: tensor(3)
In the above Python example, we find the index of the maximum value of the element of an input 1D tensor. The maximum value in the input tensor is 8 and the index of this element is 3.
Example 2
In this program, we compute the condition number with respect to the different matrix norms.
# Import the required library import torch # define an input tensor input = torch.randn(4,4) # print above defined tensor print("Input Tensor:
", input) # Compute indices of the maximum value indices = torch.argmax(input) # print the indices print("Indices:
", indices) # Compute indices of the maximum value in dim 0 indices = torch.argmax(input, dim=0) # print the indices print("Indices in dim 0:
", indices) # Compute indices of the maximum value in dim 1 indices = torch.argmax(input, dim=1) # print the indices print("Indices in dim 1:
", indices)
Output
Input Tensor: tensor([[-1.6729, 1.2613, -1.2882, -0.8133], [ 0.9192, 0.9301, -0.2372, 0.0162], [-0.4669, 0.6604, -0.7982, 0.2621], [ 0.6436, 1.0328, 2.4573, 0.0606]]) Indices: tensor(14) Indices in dim 0: tensor([1, 0, 3, 2]) Indices in dim 1: tensor([1, 1, 1, 2])
In the above Python example, we find the indices of the maximum value of the element of an input 2D tensor in different dimensions. We generated the elements of the input tensor using the torch.randn() method, so you may notice getting different input tensor and indices.
- Related Articles
- torch.nn.Dropout() Method in Python PyTorch
- torch.polar() Method in Python PyTorch
- torch.normal() Method in Python PyTorch
- torch.rsqrt() Method in Python PyTorch
- Python – PyTorch clamp() method
- PyTorch – torch.linalg.solve() Method
- PyTorch – torch.log2() Method
- Python – PyTorch ceil() and floor() methods
- Python PyTorch – How to create an Identity Matrix?
- Activation Functions in Pytorch
- PyTorch – torch.linalg.cond()
- Creating a Tensor in Pytorch
- Class method vs static method in Python
- eval method in Python?
- dir() Method in Python
