Program to find minimum cost to merge stones in Python


Suppose we have N piles of stones arranged in a row, where the i-th pile contains nums[i] stones. A move consists of merging exactly K consecutive piles into one pile, and the cost of the move is the total number of stones in those K piles. We have to find the minimum cost to merge all the piles into one pile. If this is not possible, return -1.

So, if the input is like nums = [3,2,4,1], K = 2, then the output will be 20. Initially we have [3, 2, 4, 1]. Merge [3, 2] with cost 5, which gives [5, 4, 1]. Then merge [4, 1] with cost 5, which gives [5, 5]. Finally merge [5, 5] with cost 10, which gives [10]. The total cost is 5 + 5 + 10 = 20, and this is the minimum possible.
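The walk-through above can be replayed mechanically. The following is a minimal sketch; the helper name replay_merges and its starts parameter (the index of the leftmost pile of each merge) are assumptions made for illustration and are not part of the original solution.

```python
def replay_merges(piles, K, starts):
    """Apply merges in order and return (final_piles, total_cost).

    Each entry of `starts` is the index of the leftmost pile among the
    K consecutive piles merged in that move (hypothetical helper).
    """
    piles = list(piles)          # work on a copy
    total = 0
    for s in starts:
        merged = sum(piles[s:s + K])
        total += merged          # the move costs the stones it merges
        piles[s:s + K] = [merged]
    return piles, total


# Merge [3, 2], then [4, 1], then [5, 5]
print(replay_merges([3, 2, 4, 1], 2, [0, 1, 0]))  # ([10], 20)
```

This only replays one given sequence of moves; it does not search for the cheapest one.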

To solve this, we will follow these steps −

  • n := size of nums

  • if (n-1) mod (K-1) is not 0, then

    • return -1

  • dp := an n x n array filled with 0

  • sums := an array of size (n+1) filled with 0

  • for i in range 1 to n, do

    • sums[i] := sums[i-1]+nums[i-1]

  • for length in range K to n, do

    • for i in range 0 to n-length, do

      • j := i+length-1

      • dp[i, j] := infinity

      • for t in range i to j-1, update in each step by K-1, do

        • dp[i, j] := minimum of dp[i, j] and (dp[i, t] + dp[t+1, j])

      • if (j-i) mod (K-1) is 0, then

        • dp[i, j] := dp[i, j] + sums[j+1]-sums[i]

  • return dp[0, n-1]
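Two building blocks of the steps above are the feasibility check and the prefix sums. Each merge replaces K piles with one, so the pile count drops by K-1 per move, and reaching a single pile from n piles requires (n-1) to be divisible by (K-1). The sketch below isolates both pieces; the function names are assumptions chosen for illustration, and K >= 2 is assumed.

```python
from itertools import accumulate


def can_merge_to_one(n, K):
    # Every merge removes K-1 piles, so n piles can shrink to 1
    # only if (n-1) is a multiple of (K-1). Assumes K >= 2.
    return (n - 1) % (K - 1) == 0


def prefix_sums(nums):
    # sums[i] holds nums[0] + ... + nums[i-1], so sums[0] is 0.
    return [0] + list(accumulate(nums))


def range_sum(sums, i, j):
    # Total stones in piles i..j inclusive, in O(1).
    return sums[j + 1] - sums[i]


sums = prefix_sums([3, 2, 4, 1])
print(can_merge_to_one(4, 2))   # True
print(can_merge_to_one(4, 3))   # False
print(sums)                     # [0, 3, 5, 9, 10]
print(range_sum(sums, 2, 3))    # 5
```

The same divisibility argument explains why t advances in steps of K-1: the left part dp[i, t] is only useful when piles i..t can themselves collapse into a single pile.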

Example

Let us see the following implementation to get a better understanding.

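def solve(nums, K):
   n = len(nums)
   # Each merge removes K-1 piles, so n piles can reach 1 only if
   # (n-1) is divisible by (K-1). Assumes K >= 2.
   if (n-1)%(K-1) != 0:
      return -1
   # dp[i][j] = minimum cost to merge piles i..j into as few piles as possible
   dp = [[0]*n for _ in range(n)]
   # sums[i] = total stones in the first i piles
   sums = [0]*(n+1)
   for i in range(1,n+1):
      sums[i] = sums[i-1]+nums[i-1]
   # Ranges shorter than K cannot be merged, so their cost stays 0
   for length in range(K,n+1):
      for i in range(n-length+1):
         j = i+length-1
         dp[i][j] = float('inf')
         # Stepping by K-1 keeps the left part i..t reducible to one pile
         for t in range(i,j,K-1):
            dp[i][j] = min(dp[i][j], dp[i][t]+dp[t+1][j])
         # If i..j can collapse to one pile, pay for the final merge
         if (j-i)%(K-1)==0:
            dp[i][j] += sums[j+1]-sums[i]
   return dp[0][n-1]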

nums = [3,2,4,1]
K = 2
print(solve(nums, K))

Input

[3,2,4,1], 2

Output

20
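To gain confidence in a result like this on small inputs, one option is an exhaustive search that tries every possible merge order. The sketch below is an independent check, not the dynamic programming method described above; the name brute_force is an assumption, and its running time grows very quickly, so it is only practical for a handful of piles.

```python
from functools import lru_cache


def brute_force(nums, K):
    """Try every merge order; return the minimum cost or -1 (small inputs only)."""

    @lru_cache(maxsize=None)
    def best(piles):
        if len(piles) == 1:
            return 0                      # already a single pile
        if len(piles) < K:
            return float('inf')           # stuck: too few piles to merge
        result = float('inf')
        for s in range(len(piles) - K + 1):
            merged = sum(piles[s:s + K])
            rest = piles[:s] + (merged,) + piles[s + K:]
            result = min(result, merged + best(rest))
        return result

    answer = best(tuple(nums))
    return -1 if answer == float('inf') else answer


print(brute_force([3, 2, 4, 1], 2))     # 20
print(brute_force([3, 2, 4, 1], 3))     # -1
print(brute_force([3, 5, 1, 2, 6], 3))  # 25
```

For [3, 2, 4, 1] with K = 3, one merge leaves two piles and no further move is possible, which is why the answer is -1.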

Updated on: 07-Oct-2021
