What does backward() do in PyTorch?
In PyTorch, the backward() method computes the gradients of a tensor, typically a scalar loss, with respect to the leaf tensors used to produce it. This is the backward pass of a neural network. The computed gradients are stored on the respective tensors and accessed through their .grad attribute. If we never call backward(), no gradients are computed, and accessing .grad returns None.
Let's look at a couple of examples to see how this works.
Example 1
In this example, we try to access the gradients without calling the backward() method. Notice that all of them are None.
# import the torch library
import torch

# define three tensors
x = torch.tensor(2., requires_grad=False)
w = torch.tensor(3., requires_grad=True)
b = torch.tensor(1., requires_grad=True)
print("x:", x)
print("w:", w)
print("b:", b)

# define a function of the above tensors
y = w * x + b
print("y:", y)

# print the gradients w.r.t. the above tensors
print("x.grad:", x.grad)
print("w.grad:", w.grad)
print("b.grad:", b.grad)
Output
x: tensor(2.)
w: tensor(3., requires_grad=True)
b: tensor(1., requires_grad=True)
y: tensor(7., grad_fn=<AddBackward0>)
x.grad: None
w.grad: None
b.grad: None
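As a side note, if x had been created with requires_grad=True (or switched on afterwards with x.requires_grad_(True)), backward() would populate x.grad as well. A minimal sketch of that variant:

import torch

# same computation, but x now tracks gradients too
x = torch.tensor(2., requires_grad=True)
w = torch.tensor(3., requires_grad=True)
b = torch.tensor(1., requires_grad=True)

y = w * x + b
y.backward()

# dy/dx = w, so x.grad is tensor(3.)
print("x.grad:", x.grad)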
Example 2
In this second example, the backward() method is called on y before the gradients are accessed. Since y = w * x + b, the derivative of y with respect to w is x (= 2) and the derivative with respect to b is 1. The gradient with respect to x, which does not require grad, is still None, while w.grad and b.grad now hold these computed values.
# import the torch library
import torch

# define three tensors
x = torch.tensor(2., requires_grad=False)
w = torch.tensor(3., requires_grad=True)
b = torch.tensor(1., requires_grad=True)
print("x:", x)
print("w:", w)
print("b:", b)

# define a function y of the above tensors
y = w * x + b
print("y:", y)

# call backward() on y to compute the gradients
y.backward()

# print the gradients w.r.t. x, w, and b
print("x.grad:", x.grad)
print("w.grad:", w.grad)
print("b.grad:", b.grad)
Output
x: tensor(2.)
w: tensor(3., requires_grad=True)
b: tensor(1., requires_grad=True)
y: tensor(7., grad_fn=<AddBackward0>)
x.grad: None
w.grad: tensor(2.)
b.grad: tensor(1.)
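One more behavior worth knowing: backward() accumulates gradients into .grad instead of overwriting them, which is why training loops typically reset the gradients between iterations (for example with optimizer.zero_grad()). A minimal sketch of this accumulation:

import torch

x = torch.tensor(2.)
w = torch.tensor(3., requires_grad=True)
b = torch.tensor(1., requires_grad=True)

# first backward pass
y = w * x + b
y.backward()
print("w.grad after 1st backward:", w.grad)   # tensor(2.)

# rebuild the graph and call backward() again;
# the new gradient is added to the existing one
y = w * x + b
y.backward()
print("w.grad after 2nd backward:", w.grad)   # tensor(4.)

# reset the accumulated gradients in place
w.grad.zero_()
b.grad.zero_()

Also note that backward() can only be called without arguments on a scalar tensor; for a non-scalar output, a gradient tensor of matching shape must be passed, e.g. y.backward(torch.ones_like(y)).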