Program to Find K-Largest Sum Pairs in Python


Suppose we are given two lists of numbers, nums0 and nums1, and an integer k. Our goal is to find the k largest sum pairs, where each pair contains one integer from nums0 and one from nums1, and to return the total of those k pair sums.

So, if the input is like nums0 = [8, 6, 12], nums1 = [4, 6, 8], k = 2, then the output will be 38. The two largest pairs are [12, 8] and [12, 6], with sums 20 and 18, and 20 + 18 = 38.
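As a quick sanity check on the example, the answer can also be computed by brute force: build every pair sum and add up the k largest. The sketch below is only a baseline for small inputs, and the function name k_largest_pair_sum_bruteforce is a hypothetical one chosen for illustration.

```python
from heapq import nlargest

def k_largest_pair_sum_bruteforce(nums0, nums1, k):
    # Enumerate all len(nums0) * len(nums1) pair sums.
    pair_sums = [a + b for a in nums0 for b in nums1]
    # Keep the k largest sums and add them up.
    return sum(nlargest(k, pair_sums))

print(k_largest_pair_sum_bruteforce([8, 6, 12], [4, 6, 8], 2))  # 38
```

This touches every pair, so it is useful mainly for verifying a faster method on small lists.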

To solve this, we will follow these steps −

  • if k > size of nums0 * size of nums1, then

    • return 0

  • pq := a new min heap (sums are stored negated, so the largest sum is popped first)

  • ans := 0

  • sort the lists nums0 and nums1

  • i, j := size of nums0 − 1, size of nums1 − 1

  • visited := a new set

  • push into heap pq(−(nums0[i] + nums1[j]) , i, j)

  • repeat k times, do

    • sum, i, j := pop from heap pq

    • if i > 0 and (i − 1, j) is not in visited, then

      • x := nums0[i − 1] + nums1[j]

      • add (i − 1, j) to visited

      • push into heap pq(−x, i − 1, j)

    • if j > 0 and (i, j − 1) is not in visited, then

      • y := nums0[i] + nums1[j − 1]

      • add (i, j − 1) to visited

      • push into heap pq(−y, i, j − 1)

    • ans := ans + (−sum)

  • return ans
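The steps rely on a min heap even though the largest sums are wanted. A minimal sketch of the trick, assuming Python's heapq module: store each sum negated, so that the smallest stored key corresponds to the largest real sum. The entries below are (sum, i, j) triples taken from the sorted example lists [6, 8, 12] and [4, 6, 8].

```python
import heapq

# (pair_sum, i, j) entries, where i and j index the sorted lists
entries = [(14, 0, 2), (20, 2, 2), (18, 2, 1), (16, 1, 2)]

heap = []
for pair_sum, i, j in entries:
    # Negate the sum so the min heap pops the largest sum first.
    heapq.heappush(heap, (-pair_sum, i, j))

popped = []
while heap:
    neg_sum, i, j = heapq.heappop(heap)
    popped.append((-neg_sum, i, j))

print(popped)  # [(20, 2, 2), (18, 2, 1), (16, 1, 2), (14, 0, 2)]
```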

Let us see the following implementation to get a better understanding −

Python

from heapq import heappush, heappop

class Solution:
   def solve(self, nums0, nums1, k):
      # More pairs requested than exist
      if k > len(nums0) * len(nums1):
         return 0
      pq = []  # min heap of (-pair_sum, i, j); negation pops the largest sum first
      ans = 0
      nums0.sort()
      nums1.sort()
      # Start from the largest element of each sorted list
      i, j = len(nums0) - 1, len(nums1) - 1
      visited = set()
      heappush(pq, (-(nums0[i] + nums1[j]), i, j))
      for _ in range(k):
         neg_sum, i, j = heappop(pq)
         # Require i > 0 so that i - 1 never wraps around to index -1
         if i > 0 and (i - 1, j) not in visited:
            visited.add((i - 1, j))
            heappush(pq, (-(nums0[i - 1] + nums1[j]), i - 1, j))
         # Same guard for j
         if j > 0 and (i, j - 1) not in visited:
            visited.add((i, j - 1))
            heappush(pq, (-(nums0[i] + nums1[j - 1]), i, j - 1))
         ans += -neg_sum
      return ans

ob = Solution()
print(ob.solve([8, 6, 12], [4, 6, 8], 2))

Input

[8, 6, 12],[4, 6, 8],2

Output

38
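To gain more confidence than a single input gives, the heap approach can be compared against a brute-force answer on random small inputs. The following is a rough sketch rather than part of the original program: heap_based is a compact restatement of the method with explicit lower-bound checks on the indices, and both function names are hypothetical.

```python
import heapq
import random

def brute_force(nums0, nums1, k):
    # Reference answer: add up the k largest of all pair sums.
    return sum(heapq.nlargest(k, (a + b for a in nums0 for b in nums1)))

def heap_based(nums0, nums1, k):
    # Work on sorted copies so the caller's lists are left untouched.
    a, b = sorted(nums0), sorted(nums1)
    if k > len(a) * len(b):
        return 0
    start = (len(a) - 1, len(b) - 1)
    heap = [(-(a[-1] + b[-1]), *start)]
    seen = {start}
    total = 0
    for _ in range(k):
        neg_sum, i, j = heapq.heappop(heap)
        total -= neg_sum
        # Expand to the two next-smaller neighbours, staying in bounds.
        for ni, nj in ((i - 1, j), (i, j - 1)):
            if ni >= 0 and nj >= 0 and (ni, nj) not in seen:
                seen.add((ni, nj))
                heapq.heappush(heap, (-(a[ni] + b[nj]), ni, nj))
    return total

random.seed(0)
for _ in range(200):
    xs = [random.randint(-20, 20) for _ in range(random.randint(1, 6))]
    ys = [random.randint(-20, 20) for _ in range(random.randint(1, 6))]
    k = random.randint(0, len(xs) * len(ys))
    assert heap_based(xs, ys, k) == brute_force(xs, ys, k)
print("all checks passed")
```

Inputs such as nums0 = [1], nums1 = [1, 2], k = 2 are worth including in any such check, because they force the search to reach index 0 of a list.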

Updated on: 15-Dec-2020
