PyTorch – torch.linalg.solve() Method


To solve a square system of linear equations that has a unique solution, we can use the torch.linalg.solve() method. This method takes two parameters −

  • first, the coefficient matrix A, and

  • second, the right-hand side tensor b.

Here, A is a square matrix and b is a vector. The solution is unique if A is invertible. The method can also solve a batch of systems of linear equations at once; in this case, A is a batch of square matrices and b is a batch of vectors.
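A minimal sketch of the batched case, assuming three illustrative 2x2 systems stacked along the first dimension (the specific matrices are made up for the example). Each system A[i] @ x = b[i] is solved independently and the results are stacked in the same way.

```python
import torch

# A batch of three 2x2 coefficient matrices, shape (3, 2, 2)
A = torch.tensor([[[2., 3.], [1., -2.]],
                  [[4., 1.], [2., 3.]],
                  [[1., 0.], [0., 2.]]])
# One right-hand side vector per system, shape (3, 2)
b = torch.tensor([[3., 0.],
                  [5., 5.],
                  [1., 4.]])

# Each system A[i] @ x = b[i] is solved independently
X = torch.linalg.solve(A, b)
print(X.shape)

# Batched check: treat each solution as a column vector for the matrix product
print(torch.allclose((A @ X.unsqueeze(-1)).squeeze(-1), b, atol=1e-6))
```

Here the batch size (3) differs from n (2), so b cannot be mistaken for a matrix of several right-hand sides; the PyTorch documentation for torch.linalg.solve spells out the exact shape rules.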

Syntax

torch.linalg.solve(A, b)

Parameters

  • A – Square matrix or batch of square matrices. It is the coefficient matrix of the system of linear equations.

  • b – Vector or batch of vectors. It is the right-hand side tensor of the linear system.

It returns a tensor containing the solution of the system of linear equations.
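Besides a vector, the right-hand side may also be a matrix B of shape (n, k), which solves k systems that share the same A in one call; the returned tensor then has the same shape as B. A small sketch, where the extra columns of B are illustrative values chosen for this example:

```python
import torch

A = torch.tensor([[2., 3.], [1., -2.]])
# Three right-hand sides stored as the columns of B, shape (2, 3)
B = torch.tensor([[3., 7., 0.],
                  [0., 0., 1.]])

# Column j of X solves A @ x = B[:, j]
X = torch.linalg.solve(A, B)
print(X.shape)

# A small absolute tolerance, because several entries of B are zero
print(torch.allclose(A @ X, B, atol=1e-6))
```

The first column of B is the same b used in the examples below, so the first column of X is the same solution.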

Note − This method assumes that the coefficient matrix A is invertible. If it is not invertible, a RuntimeError is raised.
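The sketch below triggers that error with a singular matrix whose second row is twice the first. In recent PyTorch versions the exception is torch.linalg.LinAlgError, which is a subclass of RuntimeError, so catching RuntimeError covers both older and newer releases. The check relies on the LU factorization hitting an exactly zero pivot, so a nearly singular matrix may return a large, inaccurate solution instead of raising.

```python
import torch

# Singular matrix: row 2 is twice row 1, so det(A) = 0
A = torch.tensor([[1., 2.], [2., 4.]])
b = torch.tensor([1., 2.])

solved = True
try:
    X = torch.linalg.solve(A, b)
except RuntimeError as err:
    # torch.linalg.LinAlgError (newer versions) is also caught here
    solved = False
    print("Could not solve:", err)

print("Solved:", solved)
```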

Steps

We can use the following steps to solve a square system of linear equations.

  • Import the required library. In all the following examples, the required Python library is torch. Make sure you have already installed it.

import torch
  • Define the coefficient matrix and the right-hand side tensor for the given square system of linear equations.

A = torch.tensor([[2., 3.],[1., -2.]])
b = torch.tensor([3., 0.])
  • Compute the unique solution using torch.linalg.solve(A, b). The coefficient matrix A must be invertible.

X = torch.linalg.solve(A, b)
  • Display the solution.

print("Solution:\n", X)
  • Check whether the computed solution is correct.

print(torch.allclose(A @ X, b))
# True for correct solution
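As an additional cross-check on the last step, the solution can be compared with the result of multiplying b by the explicit inverse, and the residual norm of A @ X - b can be printed. The PyTorch documentation notes that torch.linalg.solve is faster and more numerically stable than computing the inverse separately, so the inverse is used here only as a check on this tiny example.

```python
import torch

A = torch.tensor([[2., 3.], [1., -2.]])
b = torch.tensor([3., 0.])

X = torch.linalg.solve(A, b)

# Cross-check against the explicit inverse (acceptable for a 2x2 example,
# but solve() is preferred in practice)
X_inv = torch.linalg.inv(A) @ b
print(torch.allclose(X, X_inv))

# Residual norm: how far A @ X is from b (should be close to zero)
residual = torch.linalg.norm(A @ X - b)
print(residual)
```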

Example 1

Take a look at the following example −

# import required library
import torch

'''
Let's suppose our square system of linear equations is:
2x + 3y = 3
x - 2y = 0
'''

print("Linear equation:")
print("2x + 3y = 3")
print("x - 2y = 0")

# define the coefficient matrix A
A = torch.tensor([[2., 3.],[1., -2.]])
# define right hand side tensor b
b = torch.tensor([3., 0.])

# Solve the linear equation
X = torch.linalg.solve(A, b)

# print the solution of above linear equation
print("Solution:\n", X)

# check above solution to be true
print(torch.allclose(A @ X, b))

Output

It will produce the following output −

Linear equation:
2x + 3y = 3
x - 2y = 0
Solution:
 tensor([0.8571, 0.4286])
True
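Solving the same system by hand with substitution confirms the printed values:

```latex
\begin{aligned}
x - 2y = 0 &\;\Rightarrow\; x = 2y \\
2(2y) + 3y = 3 &\;\Rightarrow\; 7y = 3 \;\Rightarrow\; y = \tfrac{3}{7} \approx 0.4286 \\
x = 2y &\;\Rightarrow\; x = \tfrac{6}{7} \approx 0.8571
\end{aligned}
```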

Example 2

Let's take another example −

# import required library
import torch

# define the coefficient matrix A for a 3x3
# square system of linear equations
A = torch.randn(3,3)

# define right hand side tensor b
b = torch.randn(3)

# Solve the linear equation
X = torch.linalg.solve(A, b)

# print the solution of above linear equation
print("Solution:\n", X)

# check above solution to be true
print(torch.allclose(A @ X, b))

Output

It will produce output similar to the following; the values differ on every run because A and b are random −

Solution:
 tensor([-0.2867, -0.9850, 0.9938])
True
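To make a random example like this reproducible, the global random number generator can be seeded before creating A and b. The helper name solve_random_system below is hypothetical, introduced only for this sketch; torch.manual_seed is the standard PyTorch seeding call.

```python
import torch

def solve_random_system(seed: int, n: int = 3) -> torch.Tensor:
    # hypothetical helper: seed the global generator so torch.randn
    # returns the same A and b on every call with the same seed
    torch.manual_seed(seed)
    A = torch.randn(n, n)
    b = torch.randn(n)
    return torch.linalg.solve(A, b)

X1 = solve_random_system(seed=0)
X2 = solve_random_system(seed=0)

# Same seed, same inputs, same solution
print(torch.equal(X1, X2))
```

Seeding also makes the allclose check repeatable; an unlucky random A can be ill-conditioned, in which case that check may occasionally print False.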

Updated on: 07-Jan-2022
