# Strassen’s Matrix Multiplication

Strassen’s matrix multiplication is a divide-and-conquer approach to the matrix multiplication problem. The usual method multiplies each row of the first matrix with each column of the second matrix to obtain the product matrix. For n × n matrices this approach takes $O(n^3)$ time, because it uses three nested loops: each of the $n^2$ entries of the product is a dot product of length n. Strassen’s method was introduced to reduce the time complexity from $O(n^3)$ to $O(n^{\log_2 7}) \approx O(n^{2.81})$.

## Naïve Method

First, we discuss the naïve method and its complexity. Here, we are calculating $Z = X \times Y$. Using the naïve method, two matrices X and Y can be multiplied if their orders are p × q and q × r; the resulting matrix has order p × r. The following pseudocode describes the naïve multiplication −

Algorithm: Matrix-Multiplication (X, Y, Z)
for i = 1 to p do
   for j = 1 to r do
      Z[i,j] := 0
      for k = 1 to q do
         Z[i,j] := Z[i,j] + X[i,k] × Y[k,j]


### Complexity

Here, we assume that integer operations take O(1) time. The algorithm has three nested for loops, so the innermost statement runs p · q · r times. For n × n matrices that is $n^3$ times. Hence, the algorithm takes $O(n^3)$ time to execute.

## Strassen’s Matrix Multiplication Algorithm

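Strassen’s matrix multiplication algorithm improves on this running time. It splits each matrix into four blocks, as a simple divide-and-conquer method would. However, it needs only seven recursive block multiplications instead of eight. The extra work consists of additions and subtractions of blocks, which take only $O(n^2)$ time.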

In its basic form, Strassen’s matrix multiplication is performed on square matrices. Both matrices have order n × n, where n is a power of 2. Other matrices can be handled by padding them with zeros up to the next power of 2.

Divide X, Y and Z into four (n/2)×(n/2) matrices as represented below −

$Z = \begin{bmatrix}I & J \\K & L \end{bmatrix}$ $X = \begin{bmatrix}A & B \\C & D \end{bmatrix}$ and $Y = \begin{bmatrix}E & F \\G & H \end{bmatrix}$

Using Strassen’s Algorithm compute the following −

$$M_{1} \: \colon= (A+C) \times (E+F)$$

$$M_{2} \: \colon= (B+D) \times (G+H)$$

$$M_{3} \: \colon= (A-D) \times (E+H)$$

$$M_{4} \: \colon= A \times (F-H)$$

$$M_{5} \: \colon= (C+D) \times (E)$$

$$M_{6} \: \colon= (A+B) \times (H)$$

$$M_{7} \: \colon= D \times (G-E)$$

Then,

$$I \: \colon= M_{2} + M_{3} - M_{6} - M_{7}$$

$$J \: \colon= M_{4} + M_{6}$$

$$K \: \colon= M_{5} + M_{7}$$

$$L \: \colon= M_{1} - M_{3} - M_{4} - M_{5}$$

### Analysis

$$T(n)=\begin{cases}c & \text{if } n = 1\\ 7\,T\left(\frac{n}{2}\right)+d\,n^2 & \text{otherwise}\end{cases}\qquad \text{where } c \text{ and } d \text{ are constants}$$

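We can solve this recurrence with case 1 of the master theorem, because $d\,n^2 = O(n^{\log_2 7 - \epsilon})$ for some $\epsilon > 0$. This gives $T(n) = O(n^{\log_2 7}) \approx O(n^{2.81})$.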

Hence, the complexity of Strassen’s matrix multiplication algorithm is $O(n^{\log_2 7})$.

### Example

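Let us look at the implementation of Strassen's matrix multiplication in various programming languages: C, C++, Java and Python. The C, C++ and Java programs apply one level of the method to 2 × 2 matrices, so each block is a single number. The Python program applies the method recursively. The programs use the classic labelling of the seven products, which differs from the $M_1, \dots, M_7$ formulas above but is equivalent to them.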

#include<stdio.h>
/*
 * Demonstrates Strassen's method on a single 2x2 product z = x * y.
 * Each "block" is one scalar entry, so the seven products m1..m7 are plain
 * integer multiplications (instead of the naive eight), and z is assembled
 * from them using only additions and subtractions.
 * m1..m7 follow the classic Strassen labelling.
 */
int main(){
   int z[2][2];
   int i, j;
   int m1, m2, m3, m4, m5, m6, m7;
   int x[2][2] = {
      {12, 34},
      {22, 10}
   };
   int y[2][2] = {
      {3, 4},
      {2, 1}
   };
   printf("\nThe first matrix is\n");
   for(i = 0; i < 2; i++) {
      printf("\n");
      for(j = 0; j < 2; j++)
         printf("%d\t", x[i][j]);
   }
   printf("\nThe second matrix is\n");
   for(i = 0; i < 2; i++) {
      printf("\n");
      for(j = 0; j < 2; j++)
         printf("%d\t", y[i][j]);
   }

   /* The seven Strassen products. */
   m1 = (x[0][0] + x[1][1]) * (y[0][0] + y[1][1]);
   m2 = (x[1][0] + x[1][1]) * y[0][0];
   m3 = x[0][0] * (y[0][1] - y[1][1]);
   m4 = x[1][1] * (y[1][0] - y[0][0]);
   m5 = (x[0][0] + x[0][1]) * y[1][1];
   m6 = (x[1][0] - x[0][0]) * (y[0][0] + y[0][1]);
   m7 = (x[0][1] - x[1][1]) * (y[1][0] + y[1][1]);

   /* Combine the products into the four entries of z. */
   z[0][0] = m1 + m4 - m5 + m7;
   z[0][1] = m3 + m5;
   z[1][0] = m2 + m4;
   z[1][1] = m1 - m2 + m3 + m6;

   printf("\nProduct achieved using Strassen's algorithm \n");
   for(i = 0; i < 2; i++) {
      printf("\n");
      for(j = 0; j < 2; j++)
         printf("%d\t", z[i][j]);
   }
   printf("\n");
   return 0;
}



### Output

The first matrix is

12	34
22	10
The second matrix is

3	4
2	1
Product achieved using Strassen's algorithm

104	82
86	98

#include<iostream>
using namespace std;
// Demonstrates Strassen's method on a single 2x2 product z = x * y.
// Each "block" is one scalar entry, so the seven products m1..m7 are plain
// integer multiplications (instead of the naive eight), and z is assembled
// from them using only additions and subtractions.
// m1..m7 follow the classic Strassen labelling.
int main() {
   const int x[2][2] = {
      {12, 34},
      {22, 10}
   };
   const int y[2][2] = {
      {3, 4},
      {2, 1}
   };

   // Prints a 2x2 matrix: each row starts on a new line and every entry is
   // followed by a single space.
   auto print_matrix = [](const int (&m)[2][2]) {
      for (const auto& row : m) {
         std::cout << '\n';
         for (const int value : row)
            std::cout << value << ' ';
      }
   };

   std::cout << "\nThe first matrix is\n";
   print_matrix(x);
   std::cout << "\nThe second matrix is\n";
   print_matrix(y);

   // The seven Strassen products.
   const int m1 = (x[0][0] + x[1][1]) * (y[0][0] + y[1][1]);
   const int m2 = (x[1][0] + x[1][1]) * y[0][0];
   const int m3 = x[0][0] * (y[0][1] - y[1][1]);
   const int m4 = x[1][1] * (y[1][0] - y[0][0]);
   const int m5 = (x[0][0] + x[0][1]) * y[1][1];
   const int m6 = (x[1][0] - x[0][0]) * (y[0][0] + y[0][1]);
   const int m7 = (x[0][1] - x[1][1]) * (y[1][0] + y[1][1]);

   // Combine the products into the four entries of z.
   const int z[2][2] = {
      {m1 + m4 - m5 + m7, m3 + m5},
      {m2 + m4, m1 - m2 + m3 + m6}
   };

   std::cout << "\nProduct achieved using Strassen's algorithm \n";
   print_matrix(z);
   std::cout << '\n';
   return 0;
}


### Output

The first matrix is

12 34
22 10
The second matrix is

3 4
2 1
Product achieved using Strassen's algorithm

104 82
86 98

/**
 * Demonstrates Strassen's method on a single 2x2 product z = x * y.
 * Each "block" is one scalar entry, so the seven products m1..m7 are plain
 * integer multiplications (instead of the naive eight), and z is assembled
 * from them using only additions and subtractions.
 * m1..m7 follow the classic Strassen labelling.
 */
public class Strassens {
   public static void main(String[] args) {
      int[][] x = {{12, 34}, {22, 10}};
      int[][] y = {{3, 4}, {2, 1}};
      int[][] z = new int[2][2];

      System.out.print("The first matrix is: ");
      printMatrix(x);
      System.out.print("\nThe second matrix is: ");
      printMatrix(y);

      // The seven Strassen products.
      int m1 = (x[0][0] + x[1][1]) * (y[0][0] + y[1][1]);
      int m2 = (x[1][0] + x[1][1]) * y[0][0];
      int m3 = x[0][0] * (y[0][1] - y[1][1]);
      int m4 = x[1][1] * (y[1][0] - y[0][0]);
      int m5 = (x[0][0] + x[0][1]) * y[1][1];
      int m6 = (x[1][0] - x[0][0]) * (y[0][0] + y[0][1]);
      int m7 = (x[0][1] - x[1][1]) * (y[1][0] + y[1][1]);

      // Combine the products into the four entries of z.
      z[0][0] = m1 + m4 - m5 + m7;
      z[0][1] = m3 + m5;
      z[1][0] = m2 + m4;
      z[1][1] = m1 - m2 + m3 + m6;

      System.out.print("\nProduct achieved using Strassen's algorithm: ");
      printMatrix(z);
      System.out.println();
   }

   /**
    * Prints a matrix: each row starts on a new line and every entry is
    * followed by a single space.
    *
    * @param m the matrix to print
    */
   private static void printMatrix(int[][] m) {
      for (int[] row : m) {
         System.out.println(); // new line
         for (int value : row) {
            System.out.print(value + " ");
         }
      }
   }
}


### Output

The first matrix is: 
12 34
22 10
The second matrix is: 
3 4
2 1
Product achieved using Strassen's algorithm: 
104 82
86 98

# Strassen's matrix multiplication on square matrices whose order is a
# power of 2, using plain nested lists (no external libraries).

a = [[1,2,3,4],[2,3,4,5],[3,4,5,6],[4,5,6,7]]
b = [[5,5,5,5],[6,6,6,6],[7,7,7,7],[8,8,8,8]]


def new_m(p, q):  # create a matrix filled with 0s
    """Return a new matrix of q rows and p columns filled with zeros."""
    return [[0 for _ in range(p)] for _ in range(q)]


def split(matrix):  # split matrix into quarters
    """Split a square matrix of even order into its four quadrants.

    Returns (top-left, top-right, bottom-left, bottom-right) as new lists;
    the input matrix is not modified.
    """
    half = len(matrix) // 2
    top, bottom = matrix[:half], matrix[half:]
    return (
        [row[:half] for row in top],
        [row[half:] for row in top],
        [row[:half] for row in bottom],
        [row[half:] for row in bottom],
    )


def add_m(a, b):
    """Element-wise sum of two equally sized matrices (or of two ints)."""
    if type(a) == int:
        return a + b
    return [[x + y for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]


def sub_m(a, b):
    """Element-wise difference a - b of two equally sized matrices (or ints)."""
    if type(a) == int:
        return a - b
    return [[x - y for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]


def strassen(a, b, q=None):
    """Multiply two q x q matrices with Strassen's algorithm.

    a, b: square matrices (lists of lists) of order q.
    q: the order of the matrices; must be a positive power of 2.
       Defaults to len(a).
    Returns the product a x b as a new list of lists.
    Raises ValueError if q is not a positive power of 2.
    """
    q = len(a) if q is None else int(q)
    if q < 1 or q & (q - 1):
        # Splitting into equal quadrants only works all the way down to 1x1
        # when the order is a power of 2.
        raise ValueError("matrix order must be a positive power of 2, got %d" % q)

    # base case: 1x1 matrix
    if q == 1:
        return [[a[0][0] * b[0][0]]]

    half = q // 2
    # split matrices into quarters
    a11, a12, a21, a22 = split(a)
    b11, b12, b21, b22 = split(b)

    # Seven recursive multiplications instead of the naive eight.
    p1 = strassen(add_m(a11, a22), add_m(b11, b22), half)  # (a11+a22) * (b11+b22)
    p2 = strassen(add_m(a21, a22), b11, half)               # (a21+a22) * b11
    p3 = strassen(a11, sub_m(b12, b22), half)               # a11 * (b12-b22)
    p4 = strassen(a22, sub_m(b21, b11), half)               # a22 * (b21-b11)
    p5 = strassen(add_m(a11, a12), b22, half)               # (a11+a12) * b22
    p6 = strassen(sub_m(a21, a11), add_m(b11, b12), half)   # (a21-a11) * (b11+b12)
    p7 = strassen(sub_m(a12, a22), add_m(b21, b22), half)   # (a12-a22) * (b21+b22)

    c11 = add_m(sub_m(add_m(p1, p4), p5), p7)  # p1 + p4 - p5 + p7
    c12 = add_m(p3, p5)                        # p3 + p5
    c21 = add_m(p2, p4)                        # p2 + p4
    c22 = add_m(sub_m(add_m(p1, p3), p2), p6)  # p1 + p3 - p2 + p6

    # Stitch the four quadrants back into one q x q matrix.
    c = new_m(q, q)
    for i in range(half):
        for j in range(half):
            c[i][j] = c11[i][j]
            c[i][j + half] = c12[i][j]
            c[i + half][j] = c21[i][j]
            c[i + half][j + half] = c22[i][j]
    return c


if __name__ == "__main__":
    print("Output Product:")
    print(strassen(a, b, 4))


### Output

Output Product:
[[70, 70, 70, 70], [96, 96, 96, 96], [122, 122, 122, 122], [148, 148, 148, 148]] 