# How to measure the mean squared error (squared L2 norm) in PyTorch?

The mean squared error (MSE) is the mean of the squared differences between the input and target (predicted and actual) values. To compute it in PyTorch, we use the MSELoss() class from the torch.nn module, which creates a criterion that measures the mean squared error. It is often referred to as the squared L2 norm. Strictly, though, the squared L2 norm of the difference is the sum of the squared differences, and the mean divides that sum by the number of elements.
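To see what the criterion computes, the sketch below recomputes the loss by hand as the mean of the squared element-wise differences and compares the result with nn.MSELoss. The tensor values are the same ones used later in Example 1.

```python
import torch
import torch.nn as nn

# Predicted (input) and actual (target) values
input = torch.tensor([0.10, 0.20, 0.40, 0.50])
target = torch.tensor([0.09, 0.2, 0.38, 0.52])

# Mean of the squared element-wise differences, written out explicitly
manual_mse = ((input - target) ** 2).mean()

# The same quantity computed by the built-in criterion
builtin_mse = nn.MSELoss()(input, target)

print("manual: ", manual_mse.item())
print("MSELoss:", builtin_mse.item())
```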

The input (predicted) and target (actual) values are torch tensors that should have the same shape. They may have any number of dimensions. With the default settings, the criterion returns a scalar (zero-dimensional) tensor. MSELoss is one of the loss functions provided by the torch.nn module. Loss functions are used to train deep neural networks: the network's parameters are optimized by minimizing the loss.
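The last point can be made concrete with a small, hypothetical regression problem. The sketch below fits a single nn.Linear layer to points on the line y = 2x + 1. On each step it computes nn.MSELoss and takes a gradient step with torch.optim.SGD. The data, learning rate, and number of steps are arbitrary choices for illustration, not recommended settings.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # make the random initial weights reproducible

# Hypothetical toy data: 20 points on the line y = 2x + 1
x = torch.linspace(-1, 1, 20).unsqueeze(1)  # shape (20, 1)
y = 2 * x + 1                               # shape (20, 1), same shape as the model output

model = nn.Linear(1, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

initial_loss = criterion(model(x), y).item()

for _ in range(500):
    optimizer.zero_grad()          # clear gradients from the previous step
    loss = criterion(model(x), y)  # scalar mean squared error
    loss.backward()                # gradients of the loss w.r.t. weight and bias
    optimizer.step()               # move the parameters to reduce the loss

final_loss = criterion(model(x), y).item()
print(f"loss before: {initial_loss:.4f}, after: {final_loss:.6f}")
print("weight:", model.weight.item(), "bias:", model.bias.item())
```

After training, the weight and bias should end up close to 2 and 1, and the loss should be close to zero.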

## Syntax

```python
torch.nn.MSELoss()
```
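MSELoss also accepts a reduction argument: 'mean' (the default), 'sum', or 'none'. The sketch below uses small made-up tensors to show all three options. It also checks that reduction='sum' equals the squared L2 norm of the difference (computed here with torch.linalg.vector_norm), while the default 'mean' divides that sum by the number of elements.

```python
import torch
import torch.nn as nn

input = torch.tensor([1.0, 2.0, 3.0])
target = torch.tensor([1.5, 2.0, 1.0])

per_element = nn.MSELoss(reduction="none")(input, target)  # squared difference per element
total = nn.MSELoss(reduction="sum")(input, target)          # sum of squared differences
mean_loss = nn.MSELoss()(input, target)                     # default: reduction="mean"

# Squared L2 norm of the difference vector
sq_l2 = torch.linalg.vector_norm(input - target) ** 2

print(per_element)  # tensor([0.2500, 0.0000, 4.0000])
print(total, sq_l2, mean_loss)
```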

## Steps

To measure the mean squared error, follow the steps below:

• Import the required libraries. In all the following examples, the required Python library is torch. Make sure it is installed.

```python
import torch
import torch.nn as nn
```
• Create the input and target tensors and print them.

```python
input = torch.tensor([0.10, 0.20, 0.40, 0.50])
target = torch.tensor([0.09, 0.2, 0.38, 0.52])
print("Input Tensor:\n", input)
print("Target Tensor:\n", target)
```
• Create a criterion to measure the mean squared error.

```python
mse = nn.MSELoss()
```
• Compute the mean squared error (loss) and print it.

```python
output = mse(input, target)
print("MSE loss:", output)
```
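PyTorch also provides the same computation as a function, torch.nn.functional.mse_loss. This can be convenient when you don't need a reusable criterion object. Here is a minimal sketch with the tensors from the steps above:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

input = torch.tensor([0.10, 0.20, 0.40, 0.50])
target = torch.tensor([0.09, 0.2, 0.38, 0.52])

# Criterion object (as in the steps above) vs. the functional form
criterion_loss = nn.MSELoss()(input, target)
functional_loss = F.mse_loss(input, target)  # reduction="mean" by default

print(criterion_loss, functional_loss)
```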

## Example 1

In this program, we measure the mean squared error between the input and target tensors. Both the input and target tensors are 1D torch tensors.

```python
# Import the required libraries
import torch
import torch.nn as nn

# define the input and target tensors
input = torch.tensor([0.10, 0.20, 0.40, 0.50])
target = torch.tensor([0.09, 0.2, 0.38, 0.52])

# print input and target tensors
print("Input Tensor:\n", input)
print("Target Tensor:\n", target)

# create a criterion to measure the mean squared error
mse = nn.MSELoss()

# compute the loss (mean squared error)
output = mse(input, target)

# print the loss
print("MSE loss:", output)
```

## Output

```
Input Tensor:
tensor([0.1000, 0.2000, 0.4000, 0.5000])
Target Tensor:
tensor([0.0900, 0.2000, 0.3800, 0.5200])
MSE loss: tensor(0.0002)
```

Notice that the mean squared error is returned as a scalar (zero-dimensional) tensor.
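The value can be checked by hand. The element-wise differences between the input and target are 0.01, 0, 0.02 and -0.02, so

$$\text{MSE} = \frac{1}{4}\left(0.01^2 + 0^2 + 0.02^2 + (-0.02)^2\right) = \frac{0.0009}{4} = 0.000225,$$

which PyTorch prints, rounded to four decimal places, as tensor(0.0002).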

## Example 2

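In this program, we measure the mean squared error between the input and target tensors. Both the input and target tensors are 2D torch tensors. Because they are created with torch.randn(), their values, and therefore the loss, will differ from run to run.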

```python
# Import the required libraries
import torch
import torch.nn as nn

# define the input and target tensors
input = torch.randn(3, 4)
target = torch.randn(3, 4)

# print input and target tensors
print("Input Tensor:\n", input)
print("Target Tensor:\n", target)

# create a criterion to measure the mean squared error
mse = nn.MSELoss()

# compute the loss (mean squared error)
output = mse(input, target)

# print the loss
print("MSE loss:", output)
```

## Output

```
Input Tensor:
tensor([[-1.6413,  0.8950, -1.0392,  0.2382],
        [-0.3868,  0.2483,  0.9811, -0.9260],
        [-0.0263, -0.0911, -0.6234,  0.6360]])
Target Tensor:
tensor([[-1.6068,  0.7233, -0.0925, -0.3140],
        [-0.4978,  1.3121, -1.4910, -1.4643],
        [-2.2589,  0.3073,  0.2038, -1.5656]])
MSE loss: tensor(1.6209)
```
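MSELoss averages over every element, so the number of dimensions does not change the result as long as input and target have the same shape. The sketch below checks that the loss on two 3×4 tensors equals the loss on their flattened 12-element versions. It adds an arbitrary seed via torch.manual_seed so the run is reproducible.

```python
import torch
import torch.nn as nn

torch.manual_seed(42)  # arbitrary seed so the random tensors are reproducible

input = torch.randn(3, 4)
target = torch.randn(3, 4)

mse = nn.MSELoss()
loss_2d = mse(input, target)
loss_flat = mse(input.flatten(), target.flatten())  # same 12 elements as 1D tensors

print(loss_2d, loss_flat)
```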

## Example 3

In this program, we measure the mean squared error between the input and target tensors. Both the input and target tensors are 2D torch tensors. The input tensor is created with requires_grad=True, so we also compute the gradient of the loss with respect to the input tensor.

```python
# Import the required libraries
import torch
import torch.nn as nn

# define the input and target tensors
input = torch.randn(4, 5, requires_grad=True)
target = torch.randn(4, 5)

# print input and target tensors
print("Input Tensor:\n", input)
print("Target Tensor:\n", target)

# create a criterion to measure the mean squared error
loss = nn.MSELoss()

# compute the loss (mean squared error) and its gradient w.r.t. the input
output = loss(input, target)
output.backward()
print("MSE loss:", output)
print("input.grad:\n", input.grad)
```

## Output

```
Input Tensor:
tensor([[ 0.1813,  0.4199,  1.1768, -0.7068,  0.2960],
        [ 0.7950,  0.0945, -0.0954, -1.0170, -0.1471],
        [ 1.2264,  1.7573,  0.9099,  1.3720, -0.9087],
        [-1.0122, -0.8649, -0.7797, -0.7787,  0.9944]], requires_grad=True)
...
        [-0.1218, -0.1341, -0.0849, -0.1993, -0.0158]])
```
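With the default reduction='mean', the loss is (1/N) Σ (x_i - y_i)², so its gradient with respect to each input element is 2(x_i - y_i)/N, where N is the number of elements. Here N = 20, so each entry of input.grad is (x_i - y_i)/10. The sketch below uses its own random tensors to confirm that input.grad matches this formula.

```python
import torch
import torch.nn as nn

input = torch.randn(4, 5, requires_grad=True)
target = torch.randn(4, 5)

output = nn.MSELoss()(input, target)
output.backward()

# Analytic gradient of the mean squared error: 2 * (x - y) / N
expected_grad = 2 * (input.detach() - target) / input.numel()

print(torch.allclose(input.grad, expected_grad))  # True
```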