What does Tensor.detach() do in PyTorch?


Tensor.detach() is used to detach a tensor from the current computational graph. It returns a new tensor that shares the same data as the original but has requires_grad = False, so autograd does not track operations on it.
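Because the detached tensor shares its underlying storage with the original, an in-place change made through one is visible through the other. The short sketch below illustrates this; the variable names are only for illustration, and `.clone()` is shown as one common way to get an independent copy when that sharing is not wanted.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
d = x.detach()

# d and x point at the same memory
print(d.data_ptr() == x.data_ptr())  # True

# an in-place write through d also changes x
d[0] = 10.0
print(x[0].item())  # 10.0

# detach().clone() gives a copy that no longer shares memory with x
c = x.detach().clone()
c[1] = -1.0
print(x[1].item())  # 2.0
```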

  • When we don't need a tensor to be tracked for gradient computation, we detach it from the current computational graph.

  • We also need to detach a tensor that requires gradients before converting it to a NumPy array with .numpy(), for example after moving it from the GPU to the CPU with .cpu().
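A minimal sketch of the NumPy case. Calling .numpy() directly on a tensor with requires_grad = True raises a RuntimeError, while detaching first works. The device selection line is an assumption added so the sketch runs with or without a GPU.

```python
import torch

x = torch.rand(3, requires_grad=True)

# calling .numpy() directly on a tensor that requires grad raises an error
try:
    x.numpy()
except RuntimeError as err:
    print("RuntimeError:", err)

# assumption: use a GPU if one is available, otherwise stay on the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
x_dev = x.to(device)

# detach from the graph, move to the CPU, then convert to a NumPy array
arr = x_dev.detach().cpu().numpy()
print(type(arr), arr.shape)
```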

Syntax

Tensor.detach()

It returns a new tensor with requires_grad = False. Gradients will not flow back through this tensor to the original one.
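To see that gradients stop at a detached tensor, the sketch below mixes a tracked use of x with a detached use of the same value. Only the tracked term contributes to the gradient.

```python
import torch

x = torch.tensor(1.0, requires_grad=True)

# the first term is tracked by autograd, the second uses a detached copy of x
y = 2 * x + 3 * x.detach()
y.backward()

# only the tracked term contributes: dy/dx is 2, not 2 + 3
print(x.grad)  # tensor(2.)
```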

Steps

  • Import the torch library. Make sure you have it already installed.

import torch
  • Create a PyTorch tensor with requires_grad = True and print the tensor.

x = torch.tensor(2.0, requires_grad=True)
print("x:", x)
  • Call Tensor.detach() and optionally assign the result to a new variable.

x_detach = x.detach()
  • Print the tensor after the .detach() operation is performed.

print("Tensor with detach:", x_detach)
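Note that detach() does not modify the tensor it is called on. The sketch below shows that if the result is not assigned, nothing changes: x still requires gradients, and only the returned tensor is detached.

```python
import torch

x = torch.tensor(2.0, requires_grad=True)

# detach() is not in-place: the returned tensor is discarded here
x.detach()
print(x.requires_grad)  # True

# keep the returned tensor to work with the detached version
x_detach = x.detach()
print(x_detach.requires_grad)  # False
```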

Example 1

# import torch library
import torch

# create a tensor with requires_grad=True
x = torch.tensor(2.0, requires_grad=True)

# print the tensor
print("Tensor:", x)

# tensor.detach operation
x_detach = x.detach()
print("Tensor with detach:", x_detach)

Output

Tensor: tensor(2., requires_grad=True)
Tensor with detach: tensor(2.)

Notice that in the above output, the tensor after detach no longer shows requires_grad=True.
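The same difference can be checked through the tensor attributes. In this sketch y is a non-leaf tensor produced by an operation, so it has a grad_fn; its detached version has no grad_fn, does not require gradients, and is a leaf tensor.

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x * 3  # result of an operation, so autograd records a grad_fn

y_detach = y.detach()

print(y.grad_fn is not None)   # True
print(y_detach.grad_fn)        # None
print(y_detach.requires_grad)  # False
print(y_detach.is_leaf)        # True
```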

Example 2

# import torch library
import torch

# define a tensor with requires_grad=True
x = torch.rand(3, requires_grad=True)
print("x:", x)

# y is computed from x and stays in the graph; z is computed from x.detach() and does not
y = 3 + x
z = 3 * x.detach()

print("y:", y)
print("z:", z)

Output

x: tensor([0.5656, 0.8402, 0.6661], requires_grad=True)
y: tensor([3.5656, 3.8402, 3.6661], grad_fn=<AddBackward0>)
z: tensor([1.6968, 2.5207, 1.9984])
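In the output of Example 2, y carries a grad_fn while z does not. A minimal sketch of what that means for backpropagation: calling backward() through y fills in x.grad, while calling it through z raises a RuntimeError because z is not connected to the graph.

```python
import torch

x = torch.rand(3, requires_grad=True)
y = 3 + x
z = 3 * x.detach()

# y is still connected to x, so backpropagation reaches x
y.sum().backward()
print(x.grad)  # tensor([1., 1., 1.])

# z has no grad_fn, so backpropagating from it fails
try:
    z.sum().backward()
except RuntimeError as err:
    print("RuntimeError:", err)
```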
raja
Published on 06-Dec-2021 11:24:29
