What does backward() do in PyTorch?


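The backward() method computes gradients during the backward pass of a neural network. When it is called on a tensor (typically a scalar, such as a loss), it walks back through the computation graph that autograd recorded during the forward pass. It then computes the gradient of that tensor with respect to each leaf tensor created with requires_grad=True.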

  • The gradients are computed only when this method is executed.

  • They are stored in the .grad attribute of the respective tensors, that is, the leaf tensors that have requires_grad=True.

  • For example, after y.backward(), the gradient of y with respect to w is available as w.grad.

  • If backward() is never called, no gradients are computed.

  • In that case, accessing .grad returns None. The same holds, even after backward(), for any tensor created with requires_grad=False.
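The points above can be checked with a short sketch. It uses a toy function, z = w² + 3b + x, chosen only for this illustration, so the expected gradients are dz/dw = 2w = 6 and dz/db = 3, while x (created without requires_grad) gets no gradient.

```python
import torch

# Leaf tensors: w and b track gradients, x does not
w = torch.tensor(3., requires_grad=True)
b = torch.tensor(1., requires_grad=True)
x = torch.tensor(2.)  # requires_grad defaults to False

# Toy function used only for this illustration
z = w ** 2 + 3 * b + x

# Before backward(): no gradients have been computed yet
print("before:", w.grad, b.grad, x.grad)  # before: None None None

# backward() fills in .grad for every leaf that requires grad
z.backward()
print("after:", w.grad, b.grad, x.grad)   # after: tensor(6.) tensor(3.) None
```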

Let's look at a couple of examples that demonstrate how it works.

Example 1

In this example, we attempt to access the gradients without calling the backward() method. We notice that all the gradients are None.

# import torch library
import torch

# define three scalar tensors; only w and b track gradients
x = torch.tensor(2., requires_grad=False)
w = torch.tensor(3., requires_grad=True)
b = torch.tensor(1., requires_grad=True)
print("x:", x)
print("w:", w)
print("b:", b)

# define a function of the tensors above
y = w * x + b
print("y:", y)

# print the gradients w.r.t. the tensors above (backward() has not been called)
print("x.grad:", x.grad)
print("w.grad:", w.grad)
print("b.grad:", b.grad)

Output

x: tensor(2.)
w: tensor(3., requires_grad=True)
b: tensor(1., requires_grad=True)
y: tensor(7., grad_fn=<AddBackward0>)
x.grad: None
w.grad: None
b.grad: None
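The None values do not mean that the graph is missing: the forward pass has already recorded it, as the grad_fn in y's output shows. The sketch below inspects the standard autograd attributes requires_grad, is_leaf, and grad_fn on the same tensors to make this visible.

```python
import torch

x = torch.tensor(2., requires_grad=False)
w = torch.tensor(3., requires_grad=True)
b = torch.tensor(1., requires_grad=True)
y = w * x + b

# The graph is recorded during the forward pass...
print("y.requires_grad:", y.requires_grad)  # True, because w and b require grad
print("y.is_leaf:", y.is_leaf)              # False, y was produced by an operation
print("y.grad_fn:", y.grad_fn)              # the AddBackward0 node that produced y
# ...but no gradients exist until backward() is called
print("w.grad:", w.grad)                    # None
```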

Example 2

In this second example, backward() is called on y before the gradients are accessed. The gradient of x, which does not require grad, is still None, while w and b, which do require grad, now hold their computed gradients.
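The expected gradients in this example can be worked out by hand from y = wx + b with x = 2, w = 3, b = 1:

```latex
y = w x + b, \qquad
\frac{\partial y}{\partial w} = x = 2, \qquad
\frac{\partial y}{\partial b} = 1, \qquad
\frac{\partial y}{\partial x} = w = 3
```

The derivative with respect to x is perfectly well defined; x.grad stays None only because x was created with requires_grad=False, so autograd never computes or stores it.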

# import torch library
import torch

# define three tensors
x = torch.tensor(2., requires_grad=False)
w = torch.tensor(3., requires_grad=True)
b = torch.tensor(1., requires_grad=True)
print("x:", x)
print("w:", w)
print("b:", b)

# define a function y
y = w * x + b
print("y:", y)

# backpropagate from y to compute dy/dw and dy/db
y.backward()
# print the gradients w.r.t. x, w, and b
print("x.grad:", x.grad)
print("w.grad:", w.grad)
print("b.grad:", b.grad)

Output

x: tensor(2.)
w: tensor(3., requires_grad=True)
b: tensor(1., requires_grad=True)
y: tensor(7., grad_fn=<AddBackward0>)
x.grad: None
w.grad: tensor(2.)
b.grad: tensor(1.)
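Because gradients are stored in .grad, a further call to backward() adds to the existing values rather than overwriting them. The sketch below reuses the same x, w, and b values and runs the forward and backward passes twice. It rebuilds y on each pass because backward() frees the graph by default. This accumulation is why training loops typically reset gradients between steps, for example with optimizer.zero_grad().

```python
import torch

x = torch.tensor(2.)
w = torch.tensor(3., requires_grad=True)
b = torch.tensor(1., requires_grad=True)

for step in range(2):
    # Rebuild the graph on every pass; backward() frees it by default
    y = w * x + b
    y.backward()
    print(f"step {step}: w.grad={w.grad.item()}, b.grad={b.grad.item()}")
# step 0: w.grad=2.0, b.grad=1.0
# step 1: w.grad=4.0, b.grad=2.0

# Gradients from both passes were summed (2 + 2 and 1 + 1)
accumulated = (w.grad.item(), b.grad.item())

# Reset before the next independent computation
w.grad = None
b.grad = None
```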

Updated on: 06-Dec-2021
