PyTorch – How to compute the error function of a tensor?


To compute the error function of a tensor, we use the torch.special.erf() method. It accepts a tensor of any dimension, computes the error function of each element, and returns a new tensor of the same shape containing the results. The error function is also known as the Gauss error function.
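The error function is defined as erf(x) = (2/√π) ∫₀ˣ e^(−t²) dt. As a rough sanity check of that definition, the sketch below approximates the integral with the trapezoidal rule (torch.trapezoid) and compares it with torch.special.erf. The helper erf_by_integration and its steps parameter are hypothetical and exist only for illustration; this is not how PyTorch computes erf internally.

```python
import math
import torch

def erf_by_integration(x: float, steps: int = 10_001) -> float:
    """Hypothetical helper: approximate erf(x) = 2/sqrt(pi) * integral_0^x exp(-t^2) dt."""
    # Sample t evenly from 0 to x (a descending grid when x < 0 yields a negative integral)
    t = torch.linspace(0.0, x, steps, dtype=torch.float64)
    # Trapezoidal rule over the sampled integrand exp(-t^2)
    integral = torch.trapezoid(torch.exp(-t ** 2), t)
    return 2.0 / math.sqrt(math.pi) * integral.item()

for x in [-1.0, 0.5, 2.0]:
    approx = erf_by_integration(x)
    exact = torch.special.erf(torch.tensor(x, dtype=torch.float64)).item()
    print(f"x = {x:5.2f}   integral = {approx:.6f}   torch.special.erf = {exact:.6f}")
```

With a fine enough grid the two columns agree to about six decimal places.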

Steps

We could use the following steps to compute the error function of a tensor element-wise −

  • Import the required library. In all the following examples, the required Python library is torch. Make sure you have already installed it.

import torch
  • Define a torch tensor. Here we define a 3D tensor of random numbers.

tensor = torch.randn(2,3,3)
  • Compute the error function of the above-defined tensor using torch.special.erf(tensor). Optionally assign this value to a new variable.

err = torch.special.erf(tensor)
  • Print the computed error function.

print("Error function:", err)
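The PyTorch documentation lists torch.erf as an alias of torch.special.erf, and the same operation is also available as the Tensor.erf() method; torch.special.erf additionally accepts an optional out= tensor to write into. The sketch below runs the steps above with each form and checks that they give identical results. The call torch.manual_seed(0) is an assumption added only so that runs are reproducible.

```python
import torch

torch.manual_seed(0)               # fixed seed, only for reproducibility
tensor = torch.randn(2, 3, 3)      # 3D tensor of random numbers

a = torch.special.erf(tensor)      # the function used in this article
b = torch.erf(tensor)              # documented alias of torch.special.erf
c = tensor.erf()                   # Tensor method form

out = torch.empty_like(tensor)     # preallocated result tensor
torch.special.erf(tensor, out=out) # write the result into `out`

print(torch.equal(a, b), torch.equal(a, c), torch.equal(a, out))  # True True True
```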

Example 1

In this example, we compute the error function of a 1D tensor.

# import necessary libraries
import torch

# define a 1D tensor
tensor1 = torch.tensor([-1.0, 1.0, 3.0, 0.0, 0.5])

# print above created tensor
print("Tensor:", tensor1)

# compute the error function of the tensor
err = torch.special.erf(tensor1)

# Display the computed error function
print("Error :", err)

Output

Tensor: tensor([-1.0000, 1.0000, 3.0000, 0.0000, 0.5000])
Error : tensor([-0.8427, 0.8427, 1.0000, 0.0000, 0.5205])
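Two quick properties can be used to sanity-check the output of Example 1: erf is an odd function, so erf(−x) = −erf(x), and each element should match Python's scalar math.erf. The check below is a minimal sketch that reuses the input values from Example 1; the variable names are illustrative.

```python
import math
import torch

values = torch.tensor([-1.0, 1.0, 3.0, 0.0, 0.5])  # same inputs as Example 1
err = torch.special.erf(values)

# erf is odd: erf(-x) should equal -erf(x) (up to floating-point tolerance)
print(torch.allclose(torch.special.erf(-values), -err))  # True

# Compare each element with Python's scalar math.erf
expected = torch.tensor([math.erf(v) for v in values.tolist()])
print(torch.allclose(err, expected))  # True
```

Note that erf(3) is about 0.99998, which is why it prints as 1.0000 at four decimal places.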

Example 2

In this example, we compute the error function of a 3D tensor of random numbers.

# import necessary libraries
import torch

# define a tensor of random numbers
tensor1 = torch.randn(2,3,3)

# print above created tensor
print("Tensor:\n", tensor1)

# compute the error function of the tensor
err = torch.special.erf(tensor1)

# Display the computed error function
print("Error:\n", err)

Output

Tensor:
 tensor([[[-1.0724,  0.3955, -0.3472],
         [-0.7336, -0.8110,  1.2624],
         [ 0.2334, -0.9200, -0.9879]],

        [[ 0.8636,  0.3452, -0.4742],
         [-0.6868,  0.8436, -0.4195],
         [ 1.0410, -0.4681,  1.6284]]])
Error:
 tensor([[[-0.8706,  0.4241, -0.3766],
         [-0.7005, -0.7486,  0.9258],
         [ 0.2586, -0.8068, -0.8376]],

        [[ 0.7780,  0.3746, -0.4975],
         [-0.6686,  0.7671, -0.4470],
         [ 0.8590, -0.4921,  0.9787]]])
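Because torch.special.erf works element-wise, the result has the same shape and dtype as the input, and applying it to a flattened copy and reshaping back gives the same values. A minimal sketch of that check for the 3D case of Example 2 follows; the seed and the name flat are assumptions added only to make the run repeatable and readable.

```python
import torch

torch.manual_seed(0)               # fixed seed, only for reproducibility
tensor1 = torch.randn(2, 3, 3)     # 3D tensor of random numbers, as in Example 2
err = torch.special.erf(tensor1)

# Same shape and dtype as the input
print(err.shape, err.dtype)        # torch.Size([2, 3, 3]) torch.float32

# Element-wise: erf of the flattened tensor, reshaped back, matches the direct result
flat = torch.special.erf(tensor1.flatten()).reshape(tensor1.shape)
print(torch.allclose(err, flat))   # True
```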
raja
Updated on 07-Jan-2022 06:10:14
