Program to find a valid matrix given row and column sums in Python


Suppose we have two arrays, rowSum and colSum, of non-negative values, where rowSum[i] is the sum of the elements in the ith row and colSum[j] is the sum of the elements in the jth column of a 2D matrix. We have to find any matrix of non-negative values, of size len(rowSum) x len(colSum), whose rows and columns add up to the given rowSum and colSum values.
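Before building the matrix, it helps to know when one can exist at all. Every element is counted once in some row sum and once in some column sum, so the totals of rowSum and colSum must be equal. When all sums are non-negative, equal totals are also enough, because a greedy fill like the one in this article then always succeeds. Here is a minimal sketch of that check. The function name has_valid_matrix is hypothetical and is not part of the original program:

```python
def has_valid_matrix(row_sum, col_sum):
    """Return True if some non-negative matrix has these row and column sums.

    Hypothetical helper for illustration: sums must be non-negative and
    both lists must add up to the same total.
    """
    if any(x < 0 for x in row_sum) or any(x < 0 for x in col_sum):
        return False
    # Each element contributes to exactly one row sum and one column sum.
    return sum(row_sum) == sum(col_sum)


print(has_valid_matrix([13, 14, 12], [9, 13, 17]))  # True  (39 == 39)
print(has_valid_matrix([3, 8], [4, 6]))             # False (11 != 10)
```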

For example, if the input is rowSum = [13,14,12] and colSum = [9,13,17], then one valid output is

9 4 0
0 9 5
0 0 12
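To confirm that this example output is valid, we can add up each row and each column and compare the results with rowSum and colSum. The variable names in this short sketch are only for illustration:

```python
# The example matrix from above, one list per row.
matrix = [
    [9, 4, 0],
    [0, 9, 5],
    [0, 0, 12],
]

# Sum each row directly; zip(*matrix) transposes the matrix, so each
# tuple it yields is one column.
row_totals = [sum(row) for row in matrix]
col_totals = [sum(col) for col in zip(*matrix)]

print(row_totals)  # [13, 14, 12]
print(col_totals)  # [9, 13, 17]
```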

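To solve this, we use a greedy approach. We repeatedly pick the smallest remaining row or column sum, place that whole amount in a single cell, and mark that row or column as finished. Because it is the smallest remaining sum, every unfinished column (or row) is large enough to absorb it. The steps are: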

  • matrix := a matrix of zeros with (size of r) rows and (size of c) columns
  • visited := a new set
  • Define a function minimum(). This will take r, c
    • min_total := infinity
    • type := blank string
    • for i in range 0 to size of r - 1, do
      • if r[i] < min_total, then
        • index := i
        • type := 'row'
        • min_total := r[i]
    • for i in range 0 to size of c - 1, do
      • if c[i] < min_total, then
        • min_total := c[i]
        • type := 'col'
        • index := i
    • if type is same as 'row', then
      • r[index] := infinity
      • for i in range 0 to size of c - 1, do
        • if c[i] is not infinity and c[i] >= min_total, then
          • c[i] := c[i] - min_total
          • matrix[index, i] := min_total
          • break out of the loop
    • if type is same as 'col', then
      • c[index] := infinity
      • for i in range 0 to size of r - 1, do
        • if r[i] is not infinity and r[i] >= min_total, then
          • r[i] := r[i] - min_total
          • matrix[i, index] := min_total
          • break out of the loop
    • insert pair (index, type) into visited
  • From the main method, do the following:
    • while size of visited is not same as size of r + size of c, do
      • minimum(r, c)
    • return matrix
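For comparison, here is a different and commonly used greedy that avoids rescanning all the sums. It is the "northwest corner" rule from the transportation problem, not the method described above. Walk a row pointer i and a column pointer j from the top-left cell, put min(rowSum[i], colSum[j]) in cell (i, j), subtract it from both sums, and advance past whichever one reached zero. This is a minimal sketch that assumes non-negative sums with equal totals. The name restore_matrix is hypothetical:

```python
def restore_matrix(row_sum, col_sum):
    """Fill a matrix using the northwest-corner greedy.

    Assumes all sums are non-negative and sum(row_sum) == sum(col_sum).
    The input lists are copied, so the caller's data is left unchanged.
    """
    rows, cols = list(row_sum), list(col_sum)
    matrix = [[0] * len(cols) for _ in range(len(rows))]
    i = j = 0
    while i < len(rows) and j < len(cols):
        # Put as much as both the current row and column still allow.
        value = min(rows[i], cols[j])
        matrix[i][j] = value
        rows[i] -= value
        cols[j] -= value
        # At least one of the two is now exhausted; move past it.
        if rows[i] == 0:
            i += 1
        else:
            j += 1
    return matrix


print(restore_matrix([13, 14, 12], [9, 13, 17]))
# [[9, 4, 0], [0, 9, 5], [0, 0, 12]]
```

Each iteration advances either i or j, so the loop runs at most len(row_sum) + len(col_sum) times. By contrast, each call to minimum() rescans every row and column sum. On this input, both approaches happen to produce the same matrix.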

Example

Let us see the following implementation to get a better understanding:

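def solve(row_sum, col_sum):
    # Work on copies so the caller's lists are not overwritten with inf.
    r = list(row_sum)
    c = list(col_sum)
    matrix = [[0] * len(c) for _ in range(len(r))]
    visited = set()  # (index, kind) pairs of finished rows and columns
    INF = float('inf')

    def minimum(r, c):
        # Find the smallest sum among the rows and columns not yet finished
        # (finished ones are set to INF, so they are never picked again).
        min_total = INF
        kind = ''
        for i in range(len(r)):
            if r[i] < min_total:
                index = i
                kind = 'row'
                min_total = r[i]

        for i in range(len(c)):
            if c[i] < min_total:
                min_total = c[i]
                kind = 'col'
                index = i

        if kind == 'row':
            # Finish this row: put its whole remaining sum into the first
            # unfinished column (every unfinished column is >= min_total).
            r[index] = INF
            for i in range(len(c)):
                if c[i] != INF and c[i] >= min_total:
                    c[i] -= min_total
                    matrix[index][i] = min_total
                    break

        elif kind == 'col':
            # Finish this column the same way, using the first unfinished row.
            c[index] = INF
            for i in range(len(r)):
                if r[i] != INF and r[i] >= min_total:
                    r[i] -= min_total
                    matrix[i][index] = min_total
                    break

        visited.add((index, kind))

    # Each call finishes exactly one row or column.
    while len(visited) != len(r) + len(c):
        minimum(r, c)

    return matrix

rowSum = [13, 14, 12]
colSum = [9, 13, 17]
print(solve(rowSum, colSum))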

Input

[13,14,12], [9,13,17]

Output

[[9, 4, 0], [0, 9, 5], [0, 0, 12]]

Updated on: 04-Oct-2021
