What does "with torch.no_grad()" do in PyTorch?
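`torch.no_grad()` is a context manager, not a loop: the code inside a `with torch.no_grad():` block runs with gradient tracking disabled. While the block is active, autograd does not record operations in the computational graph. Every tensor produced inside the block therefore has `requires_grad=False` and no `grad_fn`, even when its inputs require gradients, and gradients cannot be computed through those results back to the inputs. Tensors that already existed before the block keep their own `requires_grad` flag; the block does not change it.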
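A minimal sketch that checks which tensors are affected inside the block. It assumes a default session where gradient mode is on; `torch.is_grad_enabled()` reports whether autograd is currently recording, and the tensor `y` created inside the block is compared with its input `x`.

```python
import torch

x = torch.tensor(2., requires_grad=True)

outside_enabled = torch.is_grad_enabled()  # True in a default session
print("outside, grad enabled:", outside_enabled)

with torch.no_grad():
    inside_enabled = torch.is_grad_enabled()  # False: recording is switched off
    y = x ** 2                                # this operation is not recorded
    print("inside, grad enabled:", inside_enabled)
    print("x.requires_grad:", x.requires_grad)  # still True: x itself is unchanged
    print("y.requires_grad:", y.requires_grad)  # False
    print("y.grad_fn:", y.grad_fn)              # None: no graph node was created
```

The flag on `x` is untouched; only the new result `y` comes out without gradient history.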
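This effect lasts only while execution is inside the `with` block. Once the block exits, gradient tracking returns to its previous state, and new operations on tensors that require gradients are recorded again. Tensors that were created inside the block, however, remain without gradient history.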
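A short sketch, using the same scalar `x = 2` as the examples, of recording resuming after the block exits: `y` is computed inside the block and keeps no history, while `z` is computed afterwards, is tracked again, and can be backpropagated to `x`. The names `y` and `z` are illustrative.

```python
import torch

x = torch.tensor(2., requires_grad=True)

with torch.no_grad():
    y = x ** 2          # computed while recording is off

z = x * 3               # computed after the block: recorded again

print("y.requires_grad:", y.requires_grad)  # False: y keeps no history
print("z.requires_grad:", z.requires_grad)  # True: z has a grad_fn

z.backward()            # dz/dx = 3
print("x.grad:", x.grad)
```

Leaving the block does not retroactively attach `y` to the graph; only operations performed after the block are tracked.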
Let's take a couple of examples for a better understanding of how it works.
In the first example, we create a tensor x with requires_grad=True. Next, we define y as a function of x inside the with torch.no_grad() block. Because y is computed while gradient tracking is disabled, the operation is not recorded in the graph (x itself still has requires_grad=True).
Since the operation was not recorded, the gradient of y with respect to x cannot be computed, so y.requires_grad returns False.
```python
# import torch library
import torch

# define a torch tensor that requires gradients
x = torch.tensor(2., requires_grad=True)
print("x:", x)

# define a function y inside the no_grad block
with torch.no_grad():
    y = x ** 2
print("y:", y)

# check whether y requires gradients
print("y.requires_grad:", y.requires_grad)
```
```
x: tensor(2., requires_grad=True)
y: tensor(4.)
y.requires_grad: False
```
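A small sketch, continuing the first example, of what happens if you try to backpropagate from `y` anyway. Because `y` has no `grad_fn`, PyTorch raises a `RuntimeError`; the exact message text may differ between versions, so the sketch only prints it rather than relying on it.

```python
import torch

x = torch.tensor(2., requires_grad=True)

with torch.no_grad():
    y = x ** 2

try:
    y.backward()          # y has no recorded graph to walk back through
    backward_failed = False
except RuntimeError as err:
    backward_failed = True
    print("backward() failed:", err)

print("x.grad:", x.grad)  # None: no gradient ever reached x
```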
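In the second example, y is defined outside the with torch.no_grad() block and z inside it, both using the same expression w * x + b. So y.requires_grad returns True (y carries a grad_fn), while z.requires_grad returns False.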
```python
# import torch library
import torch

# define three tensors
x = torch.tensor(2., requires_grad=False)
w = torch.tensor(3., requires_grad=True)
b = torch.tensor(1., requires_grad=True)
print("x:", x)
print("w:", w)
print("b:", b)

# define a function y outside the no_grad block (recorded by autograd)
y = w * x + b
print("y:", y)

# define a function z inside the no_grad block (not recorded)
with torch.no_grad():
    z = w * x + b
print("z:", z)

# check whether y and z require gradients
print("y.requires_grad:", y.requires_grad)
print("z.requires_grad:", z.requires_grad)
```
```
x: tensor(2.)
w: tensor(3., requires_grad=True)
b: tensor(1., requires_grad=True)
y: tensor(7., grad_fn=<AddBackward0>)
z: tensor(7.)
y.requires_grad: True
z.requires_grad: False
```
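A common practical use of the block is a hand-written gradient-descent step. The sketch below reuses x, w and b from the second example; the squared-error loss, the target value 5.0 and the learning rate 0.1 are arbitrary illustrative choices, not part of the example above. The in-place updates sit inside `torch.no_grad()` because modifying a leaf tensor that requires gradients in place is not allowed while autograd is recording.

```python
import torch

# Same tensors as the second example; target and lr are illustrative assumptions.
x = torch.tensor(2.)
w = torch.tensor(3., requires_grad=True)
b = torch.tensor(1., requires_grad=True)
target = torch.tensor(5.)
lr = 0.1

y = w * x + b                   # 7.0, recorded in the graph
loss = (y - target) ** 2        # (7 - 5)^2 = 4.0
loss.backward()                 # dloss/dw = 2*(y - target)*x = 8, dloss/db = 2*(y - target) = 4

with torch.no_grad():
    # Parameter updates should not become part of the graph.
    w -= lr * w.grad            # 3 - 0.1*8 = 2.2
    b -= lr * b.grad            # 1 - 0.1*4 = 0.6
    # Reset accumulated gradients before the next step.
    w.grad.zero_()
    b.grad.zero_()

print("w:", w)  # still a leaf tensor with requires_grad=True
print("b:", b)
```

After the block, w and b are still leaf tensors that require gradients, so the next forward pass is recorded as usual.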