What does "with torch.no_grad()" do in PyTorch?
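torch.no_grad() is a context manager, not a loop. Inside a "with torch.no_grad():" block, PyTorch's autograd engine stops recording operations. Any tensor computed inside the block has requires_grad set to False and no grad_fn, even if it was computed from tensors that have requires_grad=True. Such results are not attached to the computational graph, so we can no longer compute gradients through them. The input tensors themselves are not modified: a tensor created with requires_grad=True still reports requires_grad=True inside the block.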
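Here is a minimal sketch, assuming a recent PyTorch release, that checks this behavior directly. torch.is_grad_enabled() reports whether autograd is currently recording. The tensor names and values are arbitrary choices for illustration.

```python
import torch

# a leaf tensor that autograd should track
x = torch.tensor(2., requires_grad=True)

print(torch.is_grad_enabled())  # recording is on by default
with torch.no_grad():
    inside_enabled = torch.is_grad_enabled()  # recording is off inside the block
    y = x * 3                                 # this operation is not recorded
    # x keeps requires_grad=True; only the result y is untracked
    print(inside_enabled, x.requires_grad, y.requires_grad)
print(torch.is_grad_enabled())  # recording is back on after the block exits
```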
Gradient tracking is disabled only while execution is inside the block. As soon as the block exits, autograd resumes recording, so new operations on tensors with requires_grad=True are tracked again. Tensors that were computed inside the block, however, stay outside the graph.
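A short sketch of this, assuming standard PyTorch autograd behavior: the same expression is computed once inside the block and once after it, and only the second result can be backpropagated. The names y and z are illustrative.

```python
import torch

x = torch.tensor(2., requires_grad=True)

with torch.no_grad():
    y = x ** 2        # computed while recording is off: no graph

z = x ** 2            # computed after the block: tracked again

print(y.requires_grad, y.grad_fn)  # False None
print(z.requires_grad)             # True, and z.grad_fn is a backward node

z.backward()          # dz/dx = 2 * x = 4
print(x.grad)         # tensor(4.)
```

Note that y stays untracked even after the block ends; leaving the block does not retroactively attach it to the graph.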
Let's take a couple of examples for a better understanding of how it works.
In this example, we create a tensor x with requires_grad=True. Next, we compute a function y of x inside a with torch.no_grad() block. x itself still has requires_grad=True, but the operation that produces y is not recorded by autograd.
Because that operation is not recorded, y has no grad_fn, and the gradient of y with respect to x cannot be computed. So, y.requires_grad returns False.
```python
# import torch library
import torch

# define a torch tensor
x = torch.tensor(2., requires_grad = True)
print("x:", x)

# define a function y
with torch.no_grad():
    y = x ** 2
print("y:", y)

# check whether y requires grad
print("y.requires_grad:", y.requires_grad)
```
```
x: tensor(2., requires_grad=True)
y: tensor(4.)
y.requires_grad: False
```
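Continuing this example, here is a small sketch of what happens if we try to backpropagate through y anyway. It assumes PyTorch raises a RuntimeError in this case, which is the behavior in current releases; the exact error message can vary between versions, so the sketch only relies on the exception type.

```python
import torch

x = torch.tensor(2., requires_grad=True)
with torch.no_grad():
    y = x ** 2        # y has no grad_fn

try:
    y.backward()      # there is no graph to walk back through
    backward_failed = False
except RuntimeError as err:
    backward_failed = True
    print("backward() failed:", err)

print(x.grad)         # None: no gradient was ever accumulated into x
```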
In this example, we compute the same function twice: y outside the with torch.no_grad() block and z inside it. So, y.requires_grad returns True, while z.requires_grad returns False.
```python
# import torch library
import torch

# define three tensors
x = torch.tensor(2., requires_grad = False)
w = torch.tensor(3., requires_grad = True)
b = torch.tensor(1., requires_grad = True)
print("x:", x)
print("w:", w)
print("b:", b)

# define a function y
y = w * x + b
print("y:", y)

# define a function z
with torch.no_grad():
    z = w * x + b
print("z:", z)

# check if requires grad is true or not
print("y.requires_grad:", y.requires_grad)
print("z.requires_grad:", z.requires_grad)
```
```
x: tensor(2.)
w: tensor(3., requires_grad=True)
b: tensor(1., requires_grad=True)
y: tensor(7., grad_fn=<AddBackward0>)
z: tensor(7.)
y.requires_grad: True
z.requires_grad: False
```
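A common practical use of torch.no_grad() is updating parameters by hand after backward(). The sketch below reuses the values of x, w, and b from the example above; the target value 10 and the learning rate 0.1 are hypothetical choices for illustration. The update is wrapped in torch.no_grad() because an in-place operation on a leaf tensor with requires_grad=True raises a RuntimeError while autograd is recording. Inside the block, the update itself is simply not added to the graph.

```python
import torch

x = torch.tensor(2.)
w = torch.tensor(3., requires_grad=True)
b = torch.tensor(1., requires_grad=True)
target = torch.tensor(10.)  # hypothetical target value
lr = 0.1                    # hypothetical learning rate

y = w * x + b               # y = 7, recorded by autograd
loss = (y - target) ** 2    # (7 - 10)^2 = 9
loss.backward()             # dloss/dw = 2*(y - target)*x = -12, dloss/db = -6

with torch.no_grad():
    w -= lr * w.grad        # 3 - 0.1 * (-12) = 4.2
    b -= lr * b.grad        # 1 - 0.1 * (-6)  = 1.6

# clear accumulated gradients before any further backward() call
w.grad.zero_()
b.grad.zero_()
print(w, b)
```

After the update, w and b are still leaf tensors with requires_grad=True, so the next forward pass is tracked as usual.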