What does "with torch.no_grad()" do in PyTorch?


torch.no_grad() is a context manager that disables gradient tracking. Every tensor produced by an operation inside a "with torch.no_grad():" block has requires_grad set to False, which means the result is detached from the current computational graph. We are no longer able to compute gradients with respect to such a tensor.

This only applies to computations performed inside the block. As soon as execution leaves the block, gradient tracking resumes, and new operations on tensors defined with requires_grad=True are again attached to the current graph, as the short sketch below shows.
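Here is a minimal sketch of this behavior (the tensor names x, y, and z are just illustrative):

# import torch library
import torch

x = torch.tensor(2., requires_grad = True)

# computed inside the block, so no graph is recorded
with torch.no_grad():
   y = x ** 2
print("y.requires_grad:", y.requires_grad)   # False

# computed after the block, so tracking is back on
z = x ** 2
print("z.requires_grad:", z.requires_grad)   # True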

Let's take a couple of examples for a better understanding of how it works.

Example 1

In this example, we create a tensor x with requires_grad = True. Next, we define a function y of this tensor x and compute it inside the with torch.no_grad(): block. Because the computation happens inside the block, the result y is not tracked by autograd, even though x itself still has requires_grad = True.

Since y is not part of the computational graph, gradients of y with respect to x cannot be computed. So, y.requires_grad returns False.

# import torch library
import torch

# define a torch tensor
x = torch.tensor(2., requires_grad = True)
print("x:", x)

# define a function y
with torch.no_grad():
   y = x ** 2
print("y:", y)

# check requires_grad for y
print("y.requires_grad:", y.requires_grad)

Output

x: tensor(2., requires_grad=True)
y: tensor(4.)
y.requires_grad: False
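Because y was computed under torch.no_grad(), it has no grad_fn, so backpropagating through it fails. As a hedged follow-on sketch, calling y.backward() on the tensor from the example above raises a RuntimeError:

# y does not require grad and has no grad_fn, so backward() fails
try:
   y.backward()
except RuntimeError as e:
   print("RuntimeError:", e)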

Example 2

In this example, we define the function y outside the with torch.no_grad(): block and the function z inside it. So, y.requires_grad returns True, while z.requires_grad returns False.

# import torch library
import torch

# define three tensors
x = torch.tensor(2., requires_grad = False)
w = torch.tensor(3., requires_grad = True)
b = torch.tensor(1., requires_grad = True)

print("x:", x)
print("w:", w)
print("b:", b)

# define a function y
y = w * x + b
print("y:", y)

# define a function z
with torch.no_grad():
   z = w * x + b

print("z:", z)

# check if requires grad is true or not
print("y.requires_grad:", y.requires_grad)
print("z.requires_grad:", z.requires_grad)

Output

x: tensor(2.)
w: tensor(3., requires_grad=True)
b: tensor(1., requires_grad=True)
y: tensor(7., grad_fn=<AddBackward0>)
z: tensor(7.)
y.requires_grad: True
z.requires_grad: False
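To make the contrast concrete, here is a short follow-on sketch using the tensors from Example 2: backpropagating through y works and fills in w.grad and b.grad, while z carries no grad_fn at all.

# y was computed outside the block, so backpropagation works
y.backward()
print("w.grad:", w.grad)   # dy/dw = x = tensor(2.)
print("b.grad:", b.grad)   # dy/db = tensor(1.)

# z was computed inside the block, so it is not connected to the graph
print("z.grad_fn:", z.grad_fn)   # None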
