Python – PyTorch clamp() method


torch.clamp() clamps all the elements of an input tensor into the range [min, max]. It takes three parameters: the input tensor, min, and max. Values less than min are replaced by min, and values greater than max are replaced by max.

If min is not given, there is no lower bound. If max is not given, there is no upper bound. Suppose we set min=−0.5 and max=0.4; then values less than −0.5 are replaced by −0.5 and values greater than 0.4 are replaced by 0.4. Values between these bounds are not changed. torch.clamp() supports only real-valued inputs.
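The rule described above amounts to an element-wise maximum with min followed by an element-wise minimum with max. As a minimal sketch (the input values are illustrative, and the names lo, hi and manual are introduced here only for the comparison), the result of torch.clamp() can be checked against that manual composition:

```python
import torch

# Illustrative input; bounds follow the min=-0.5, max=0.4 case from the text
a = torch.tensor([0.73, 0.35, -0.39, -1.53])
lo, hi = -0.5, 0.4

clamped = torch.clamp(a, min=lo, max=hi)

# Manual equivalent: raise values below lo, then lower values above hi
manual = torch.minimum(torch.maximum(a, torch.tensor(lo)), torch.tensor(hi))

print(clamped)
print(torch.equal(clamped, manual))
```

Only the first and last elements fall outside the range here, so only those two are replaced.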

Syntax

torch.clamp(input, min=None, max=None)

Parameters

  • input - The input tensor.

  • min - Lower bound; a number or a tensor.

  • max - Upper bound; a number or a tensor.
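Since min and max may be tensors as well as numbers, each element can be given its own bounds. The following is a small sketch, assuming a PyTorch version recent enough to accept tensor bounds; the lower and upper tensors are hypothetical values chosen for illustration:

```python
import torch

a = torch.tensor([0.73, 0.35, -0.39, -1.53])

# Hypothetical per-element bounds with the same shape as the input
lower = torch.tensor([-1.0, 0.0, -0.2, -1.0])
upper = torch.tensor([0.5, 0.2, 0.2, 0.0])

# Element i is clamped into [lower[i], upper[i]]
t = torch.clamp(a, min=lower, max=upper)
print(t)
```

Each element is compared only against the bounds at its own position.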

It returns a new tensor with all elements of input clamped into the range [min, max].
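Because the result is a new tensor, the input itself is left unmodified. The sketch below checks this and, for contrast, uses the in-place tensor method clamp_(), which overwrites the tensor it is called on. The variable original is introduced here only for the comparison:

```python
import torch

a = torch.tensor([0.73, 0.35, -0.39, -1.53])
original = a.clone()  # copy kept only to compare against later

# Out-of-place: returns a new tensor, a is untouched
t1 = torch.clamp(a, min=-0.5, max=0.5)
print(torch.equal(a, original))

# In-place variant: modifies a directly
a.clamp_(min=-0.5, max=0.5)
print(torch.equal(a, t1))
```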

Steps

  • Import the required library. In all the following examples, the required Python library is torch. Make sure you have already installed it.

import torch

  • Create an input tensor and print it.

a = torch.tensor([0.73, 0.35, -0.39, -1.53])
print("input tensor:\n", a)

  • Clamp the elements of the input tensor. Here we use min=-0.5, max=0.5.

t1 = torch.clamp(a, min=-0.5, max=0.5)

  • Print the tensor obtained after clamping.

print(t1)
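The steps above use the function form torch.clamp(). As a brief aside, the same operation can also be written as a tensor method, and torch.clip() is an alias for torch.clamp(). A minimal sketch comparing the three spellings on the same input:

```python
import torch

a = torch.tensor([0.73, 0.35, -0.39, -1.53])

t1 = torch.clamp(a, min=-0.5, max=0.5)  # function form used in the steps
t2 = a.clamp(min=-0.5, max=0.5)         # tensor-method form
t3 = torch.clip(a, min=-0.5, max=0.5)   # alias of torch.clamp

# All three spellings give the same result
print(torch.equal(t1, t2), torch.equal(t1, t3))
```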

Example 1

In the following Python program, we clamp the elements of a 1D input tensor. Note how clamp() behaves when min or max is omitted (None), and that when min is greater than max, every element is set to max.

# Import the required library
import torch

# define a 1D tensor
a = torch.tensor([0.73, 0.35, -0.39, -1.53])
print("input tensor:\n", a)

print("clamp the tensor:")
print("into range [-0.5, 0.5]:")
t1 = torch.clamp(a, min=-0.5, max=0.5)
print(t1)

print("if min is None:")
t2 = torch.clamp(a, max=0.5)
print(t2)

print("if max is None:")
t3 = torch.clamp(a, min=0.5)
print(t3)

print("if min is greater than max:")
t4 = torch.clamp(a, min=0.6, max=0.5)
print(t4)

Output

input tensor:
   tensor([ 0.7300, 0.3500, -0.3900, -1.5300])
clamp the tensor:
into range [-0.5, 0.5]:
   tensor([ 0.5000, 0.3500, -0.3900, -0.5000])
if min is None:
   tensor([ 0.5000, 0.3500, -0.3900, -1.5300])
if max is None:
   tensor([0.7300, 0.5000, 0.5000, 0.5000])
if min is greater than max:
   tensor([0.5000, 0.5000, 0.5000, 0.5000])
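The last case in the output, where min is greater than max, can be checked directly. In this sketch the expected tensor is built with torch.full_like(), a helper chosen here purely for the comparison:

```python
import torch

a = torch.tensor([0.73, 0.35, -0.39, -1.53])

# min (0.6) is greater than max (0.5)
t4 = torch.clamp(a, min=0.6, max=0.5)

# Every element should equal max, whatever its original value
expected = torch.full_like(a, 0.5)
print(torch.equal(t4, expected))
```

This follows from applying the lower bound first and the upper bound second: every value is first raised to at least 0.6 and then capped at 0.5.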

Example 2

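In the following Python program, we clamp the elements of a 2D input tensor. Note how clamp() behaves when min or max is omitted (None). Because torch.randn() draws random values, the numbers in your output will differ from those shown below.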

# Import the required library
import torch

# define a 2D tensor of size [3, 4]
a = torch.randn(3, 4)
print("input tensor:\n", a)

print("clamp the tensor:")
print("into range [-0.6, 0.4]:")
t1 = torch.clamp(a, min=-0.6, max=0.4)
print(t1)

print("if min is None (max=0.4):")
t2 = torch.clamp(a, max=0.4)
print(t2)

print("if max is None (min=-0.6):")
t3 = torch.clamp(a, min=-0.6)
print(t3)

print("if min is greater than max (min=0.6, max=0.4):")
t4 = torch.clamp(a, min=0.6, max=0.4)
print(t4)

Output

input tensor:
   tensor([[ 1.2133, 0.2199, -0.0864, -0.1143],
      [ 0.4205, 1.0258, 0.4022, -1.3172],
      [ 1.5405, 0.8545, 0.7009, 0.5874]])
clamp the tensor:
into range [-0.6, 0.4]:
   tensor([[ 0.4000, 0.2199, -0.0864, -0.1143],
      [ 0.4000, 0.4000, 0.4000, -0.6000],
      [ 0.4000, 0.4000, 0.4000, 0.4000]])
if min is None (max=0.4):
   tensor([[ 0.4000, 0.2199, -0.0864, -0.1143],
      [ 0.4000, 0.4000, 0.4000, -1.3172],
      [ 0.4000, 0.4000, 0.4000, 0.4000]])
if max is None (min=-0.6):
   tensor([[ 1.2133, 0.2199, -0.0864, -0.1143],
      [ 0.4205, 1.0258, 0.4022, -0.6000],
      [ 1.5405, 0.8545, 0.7009, 0.5874]])
if min is greater than max (min=0.6, max=0.4):
   tensor([[0.4000, 0.4000, 0.4000, 0.4000],
      [0.4000, 0.4000, 0.4000, 0.4000],
      [0.4000, 0.4000, 0.4000, 0.4000]])
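Since the input in this example is random, one way to reason about the result without relying on specific numbers is to check its properties. The following is a rough sketch; the seed value is arbitrary and only makes the run repeatable, and the mask name inside is introduced here for illustration:

```python
import torch

torch.manual_seed(0)  # arbitrary seed, chosen only for repeatability
a = torch.randn(3, 4)

t1 = torch.clamp(a, min=-0.6, max=0.4)

# The shape is preserved and every element lies inside [-0.6, 0.4]
print(t1.shape)
print(bool((t1 >= -0.6).all()), bool((t1 <= 0.4).all()))

# Elements that were already inside the range are left unchanged
inside = (a >= -0.6) & (a <= 0.4)
print(torch.equal(t1[inside], a[inside]))
```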

Updated on: 20-Jan-2022
