What does backward() do in PyTorch?
In PyTorch, the backward() method computes the gradient of a tensor (typically a scalar loss) during the backward pass through a neural network. Calling backward() traverses the computation graph and stores the resulting gradients in the tensors the function was built from; each gradient is then accessed through that tensor's .grad attribute.
If backward() is never called, no gradients are computed, and accessing .grad returns None.
Let's look at a couple of examples that demonstrate how this works.
Example 1
In this example, we attempt to access the gradients without calling the backward() method first. As a result, all the gradients are None.
# import torch library
import torch

# define three tensors
x = torch.tensor(2., requires_grad=False)
w = torch.tensor(3., requires_grad=True)
b = torch.tensor(1., requires_grad=True)
print("x:", x)
print("w:", w)
print("b:", b)

# define a function of the above defined tensors
y = w * x + b
print("y:", y)

# print the gradients w.r.t. the above tensors
print("x.grad:", x.grad)
print("w.grad:", w.grad)
print("b.grad:", b.grad)
Output
x: tensor(2.)
w: tensor(3., requires_grad=True)
b: tensor(1., requires_grad=True)
y: tensor(7., grad_fn=<AddBackward0>)
x.grad: None
w.grad: None
b.grad: None
Example 2
In this second example, the backward() method is called on y before the gradients are accessed. The gradient with respect to x, which was created with requires_grad=False, is still None. The gradients with respect to w and b are populated: since y = w * x + b, the gradient of y w.r.t. w is x = 2 and the gradient w.r.t. b is 1.
# import torch library
import torch

# define three tensors
x = torch.tensor(2., requires_grad=False)
w = torch.tensor(3., requires_grad=True)
b = torch.tensor(1., requires_grad=True)
print("x:", x)
print("w:", w)
print("b:", b)

# define a function y
y = w * x + b
print("y:", y)

# call backward() on y to compute the gradients
y.backward()

# print the gradients w.r.t. x, w, and b
print("x.grad:", x.grad)
print("w.grad:", w.grad)
print("b.grad:", b.grad)
Output
x: tensor(2.)
w: tensor(3., requires_grad=True)
b: tensor(1., requires_grad=True)
y: tensor(7., grad_fn=<AddBackward0>)
x.grad: None
w.grad: tensor(2.)
b.grad: tensor(1.)
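One detail worth keeping in mind: backward() accumulates gradients into .grad rather than overwriting them, so calling it repeatedly without resetting the gradients adds the new values to the old ones. A minimal sketch of this behavior, using the same kind of tensors as above:

```python
import torch

# a leaf tensor that requires a gradient, and a plain input
w = torch.tensor(3., requires_grad=True)
x = torch.tensor(2.)

# first backward pass: dy/dw = x = 2
y = w * x
y.backward()
print("after 1st backward:", w.grad)   # tensor(2.)

# second backward pass on a fresh graph: the new gradient
# is ADDED to the existing one, so w.grad becomes 2 + 2 = 4
y = w * x
y.backward()
print("after 2nd backward:", w.grad)   # tensor(4.)

# reset the gradient before the next pass
w.grad = None                          # or w.grad.zero_()
y = w * x
y.backward()
print("after reset + backward:", w.grad)   # tensor(2.)
```

This accumulation is why training loops call optimizer.zero_grad() (or otherwise clear .grad) before each backward pass.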
