How to rescale a tensor to the range [0, 1] so that it sums to 1 in PyTorch?


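We can rescale an n-dimensional input tensor so that its elements lie within the range [0, 1] and sum to 1 by applying the Softmax function, torch.nn.Softmax. Softmax exponentiates each element and divides it by the sum of the exponentials, so every result is non-negative and the results add up to 1. The rescaling is done along a particular dimension: each slice taken along that dimension sums to 1. The output tensor has the same shape as the input tensor.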
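Softmax maps each element x_i to exp(x_i) / Σ_j exp(x_j), where the sum runs over the chosen dimension. To make that formula concrete, here is a minimal hand-written sketch compared against torch.nn.Softmax. It assumes the common trick of subtracting the maximum first for numerical stability; the helper name manual_softmax is hypothetical and is not part of PyTorch.

```python
import torch

def manual_softmax(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Hypothetical helper for illustration; torch.nn.Softmax is the real API.
    # Subtracting the maximum along dim leaves the result unchanged
    # (softmax is shift-invariant) but keeps exp() from overflowing.
    shifted = x - x.max(dim=dim, keepdim=True).values
    exps = torch.exp(shifted)
    # Divide each exponential by the sum of the exponentials along dim.
    return exps / exps.sum(dim=dim, keepdim=True)

x = torch.tensor([-0.5654, -0.9031, -0.3060, -0.6847, -1.4268])
ours = manual_softmax(x, dim=0)
ref = torch.nn.Softmax(dim=0)(x)
print(ours)
print(torch.allclose(ours, ref))  # True

# Large inputs stay finite thanks to the max subtraction.
print(manual_softmax(torch.tensor([1000.0, 1001.0])))
```

Without the max subtraction, exp(1000.0) would overflow to inf in float32 and the division would produce nan.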

Syntax

torch.nn.Softmax(dim)

Parameters

  • dim – The dimension along which Softmax is computed; every slice along this dimension sums to 1.
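For a 2-D tensor, dim=0 normalizes each column and dim=1 (equivalently dim=-1, the last dimension) normalizes each row. A short sketch on a fixed example tensor (the values are arbitrary) makes the difference visible:

```python
import torch

x = torch.tensor([[1.0, 2.0, 3.0],
                  [1.0, 1.0, 1.0]])

# dim=0: each column becomes a probability distribution.
col_soft = torch.nn.Softmax(dim=0)(x)
print(col_soft.sum(dim=0))  # one sum per column, each close to 1

# dim=1: each row becomes a probability distribution.
row_soft = torch.nn.Softmax(dim=1)(x)
print(row_soft.sum(dim=1))  # one sum per row, each close to 1

# dim=-1 refers to the last dimension, so it matches dim=1 here.
print(torch.allclose(row_soft, torch.nn.Softmax(dim=-1)(x)))  # True
```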

Steps

We could use the following steps to rescale a tensor so that its elements lie in the range [0, 1] and sum to 1:

  • Import the required library. In all the following examples, the required Python library is torch. Make sure you have already installed it.

import torch
  • Define an n-dimensional input tensor, input.

input = torch.randn(5,2)
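  • Create the Softmax module, passing the dimension dim along which to normalize. dim defaults to None, but omitting it makes PyTorch choose the dimension implicitly and emit a warning that this is deprecated, so pass it explicitly.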

softmax = torch.nn.Softmax(dim = 1)
  • Apply the Softmax module defined above to the input tensor input.

output = softmax(input)
  • Print the tensor containing Softmax values.

print(output)
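The steps above use the module form, torch.nn.Softmax. PyTorch also provides a functional form, torch.nn.functional.softmax (also reachable as torch.softmax), which takes the dimension directly and is handy when a reusable module is not needed. A sketch, assuming a seeded 5×2 random input like the one in the steps:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # fixed seed so the run is reproducible
x = torch.randn(5, 2)

module_out = torch.nn.Softmax(dim=1)(x)
functional_out = F.softmax(x, dim=1)
tensor_out = torch.softmax(x, dim=1)

# All three forms compute the same values.
print(torch.allclose(module_out, functional_out))  # True
print(torch.allclose(module_out, tensor_out))      # True
```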

Example 1

The following Python program rescales a 1-D tensor to the range [0, 1] so that its elements sum to 1.

import torch
input = torch.randn(5)
print(input)

softmax = torch.nn.Softmax(dim = 0)
output = softmax(input)
print(output)

print(output.sum())

Output

tensor([-0.5654, -0.9031, -0.3060, -0.6847, -1.4268])
tensor([0.2315, 0.1651, 0.3001, 0.2055, 0.0978])
tensor(1.0000)

Notice that after rescaling, every element lies in the range [0, 1] and the elements of the rescaled tensor sum to 1.
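Because torch.randn draws new random values on every run, your numbers will differ from the output shown. Instead of inspecting the printout by eye, a sketch like the following (assuming a fixed seed for reproducibility) checks both properties with assertions; torch.isclose is used for the sum because floating-point rounding can leave it slightly away from exactly 1:

```python
import torch

torch.manual_seed(42)  # reproducible random input
x = torch.randn(5)
out = torch.nn.Softmax(dim=0)(x)

# Every element lies in [0, 1].
assert ((out >= 0) & (out <= 1)).all()
# The elements sum to 1 (up to floating-point tolerance).
assert torch.isclose(out.sum(), torch.tensor(1.0))
print("range and sum checks passed")
```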

Example 2

The following Python program rescales a 2-D tensor along dim=1, so that each row lies in the range [0, 1] and sums to 1.

# Import the required library
import torch
input = torch.randn(5,2)
print(input)

softmax = torch.nn.Softmax(dim = 1)
output = softmax(input)
print(output)
print(output[0])        # first row of the result
print(output[1].sum())  # sum of the second row

Output

tensor([[-0.5788,  0.9244],
        [-0.5172,  1.6231],
        [ 1.3032, -2.1107],
        [-0.4802,  0.1321],
        [-1.3219, -0.3570]])
tensor([[0.1819, 0.8181],
        [0.1052, 0.8948],
        [0.9681, 0.0319],
        [0.3515, 0.6485],
        [0.2759, 0.7241]])
tensor([0.1819, 0.8181])
tensor(1.)

Notice that after rescaling, every element lies in the range [0, 1] and, because Softmax was applied along dim=1, each row of the rescaled tensor sums to 1.
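With dim=1 the property holds row by row, so summing over dim=1 should give a vector of ones, while the sum over the whole tensor equals the number of rows. A sketch that checks every row at once, assuming a seeded 5×2 input:

```python
import torch

torch.manual_seed(0)
x = torch.randn(5, 2)
out = torch.nn.Softmax(dim=1)(x)

row_sums = out.sum(dim=1)  # one sum per row, shape (5,)
print(row_sums)
# Each row sums to 1 up to floating-point tolerance.
assert torch.allclose(row_sums, torch.ones(5))
# Summing the whole tensor gives the number of rows, not 1.
assert torch.isclose(out.sum(), torch.tensor(5.0))
```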

Updated on: 25-Jan-2022
