Program to find kpr sum for all queries for a given list of numbers in Python


Suppose we have a list of numbers nums and a list of queries, where each queries[i] contains three elements [K, P, R]. For each query we have to find kpr_sum, defined by the formula below, where A is nums indexed from 1 and ⊕ denotes bitwise XOR.

$$\mathrm{kpr\_sum} = \sum_{i=P}^{R-1} \sum_{j=i+1}^{R} \bigl(K \oplus (A[i] \oplus A[j])\bigr)$$

Since the sum can be very large, return it modulo 10^9 + 7.

So, if the input is like nums = [1,2,3] and queries = [[1,1,3],[2,1,3]], then the output will be [5, 4]. For the first query the sum is (1 XOR (1 XOR 2)) + (1 XOR (1 XOR 3)) + (1 XOR (2 XOR 3)) = 2 + 3 + 0 = 5, and for the second query it is (2 XOR (1 XOR 2)) + (2 XOR (1 XOR 3)) + (2 XOR (2 XOR 3)) = 1 + 0 + 3 = 4.
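To make the definition concrete, here is a direct translation of the formula. It is a minimal sketch that is only practical for small inputs, because each query costs O(d²) for d = R - P + 1 elements. The names kpr_sum_brute and solve_brute are hypothetical helpers, not part of the original solution.

```python
from itertools import combinations

MOD = 10**9 + 7

def kpr_sum_brute(nums, K, P, R):
    """Evaluate kpr_sum straight from the definition.

    P and R are 1-indexed and inclusive, as in the formula, so the
    range A[P..R] is nums[P-1:R] in Python slicing.
    """
    window = nums[P - 1:R]
    total = 0
    # combinations yields every pair (A[i], A[j]) with i < j exactly once
    for a, b in combinations(window, 2):
        total += K ^ (a ^ b)
    return total % MOD

def solve_brute(nums, queries):
    # Hypothetical wrapper: one brute-force evaluation per query
    return [kpr_sum_brute(nums, K, P, R) for K, P, R in queries]

print(solve_brute([1, 2, 3], [[1, 1, 3], [2, 1, 3]]))  # [5, 4]
```

A slow reference like this is mainly useful for checking a faster method on small random inputs.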

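XOR works on each bit independently, so we can handle one bit position at a time. For bit i, let n1 be the number of elements in A[P..R] with bit i set and n0 the number with it clear. If bit i of K is 0, a pair (a, b) produces a result with bit i set exactly when a and b differ in that bit, which gives n1 * n0 pairs; if bit i of K is 1, the bit is set when a and b agree, which gives n1(n1 - 1)/2 + n0(n0 - 1)/2 pairs. Each such pair adds 2^i to the sum, and prefix counts of every bit give n1 for any range in constant time. To solve this, we will follow these steps −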

  • m := 10^9 + 7
  • q_cnt := size of queries
  • C := a new list
  • res := a new list
  • for i in range 0 to 19 (20 bits, assuming every value in nums and every K is below 2^20), do
    • pre := an array with single element 0
    • t := 0
    • for each x in nums, do
      • t := t + ((x after shifting i times to the right) AND 1)
      • insert t at the end of pre
    • insert pre at the end of C
  • for j in range 0 to q_cnt - 1, do
    • (K, P, R) := queries[j]
    • d := R - P + 1
    • t := 0
    • for i in range 0 to 19, do
      • n1 := C[i, R] - C[i, P - 1]
      • n0 := d - n1
      • if (K after shifting i times to the right) AND 1 is non-zero, then
        • x := quotient of (n1 * (n1 - 1) + n0 * (n0 - 1)) / 2
      • otherwise,
        • x := n1 * n0
      • t := (t + (x after shifting i times to the left)) mod m
    • insert t at the end of res
  • return res
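To see how the per-bit counts add up to kpr_sum, the sketch below prints the breakdown for the first example query. It applies the same pair rule as the steps (n1 * n0 pairs when K's bit is 0, n1(n1 - 1)/2 + n0(n0 - 1)/2 when it is 1), but for brevity it counts n1 by scanning the range directly instead of using the prefix table C. bit_contributions is a hypothetical helper.

```python
def bit_contributions(nums, K, P, R, bits=20):
    """Hypothetical helper: per-bit breakdown of one kpr_sum query.

    Returns (bit, n1, n0, pairs, contribution) for every bit whose
    contribution is non-zero. Assumes all values fit in `bits` bits.
    """
    window = nums[P - 1:R]  # A[P..R], 1-indexed and inclusive
    rows = []
    for i in range(bits):
        n1 = sum((x >> i) & 1 for x in window)  # elements with bit i set
        n0 = len(window) - n1                   # elements with bit i clear
        if (K >> i) & 1:
            # K's bit is 1, so K ^ (a ^ b) has bit i set when a and b agree
            pairs = n1 * (n1 - 1) // 2 + n0 * (n0 - 1) // 2
        else:
            # K's bit is 0, so the result has bit i set when a and b differ
            pairs = n1 * n0
        if pairs:
            rows.append((i, n1, n0, pairs, pairs << i))
    return rows

for row in bit_contributions([1, 2, 3], K=1, P=1, R=3):
    print(row)
# (0, 2, 1, 1, 1)
# (1, 2, 1, 2, 4)
```

The contributions 1 + 4 give the expected 5 for the first query.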

Example

Let us see the following implementation to get a better understanding −

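def solve(nums, queries):
    m = 10**9 + 7
    BITS = 20  # assumes every value in nums and every K is below 2**20

    # C[i][k] = how many of the first k numbers have bit i set
    C = []
    for i in range(BITS):
        prefix = [0]
        t = 0
        for x in nums:
            t += (x >> i) & 1
            prefix.append(t)
        C.append(prefix)

    res = []
    for K, P, R in queries:
        d = R - P + 1  # size of the 1-indexed range A[P..R]
        t = 0
        for i in range(BITS):
            n1 = C[i][R] - C[i][P - 1]  # elements in range with bit i set
            n0 = d - n1                 # elements in range with bit i clear
            if (K >> i) & 1:
                # K has bit i set: the result has bit i set when both elements agree
                x = (n1 * (n1 - 1) + n0 * (n0 - 1)) >> 1
            else:
                # K has bit i clear: the result has bit i set when the elements differ
                x = n1 * n0
            t = (t + (x << i)) % m
        res.append(t)
    return res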

nums = [1,2,3]
queries = [[1,1,3],[2,1,3]]
print(solve(nums, queries))

Input

[1,2,3], [[1,1,3],[2,1,3]]

Output

[5, 4]
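The implementation above fixes the bit width at 20, so it assumes every value in nums and every K is below 2^20; higher bits are ignored. A possible variant, sketched below, derives the width from the largest value actually present (nums and all K values) and builds each prefix row with itertools.accumulate. solve_any_width and brute are hypothetical names, and the random cross-check against the direct O(d²) evaluation is only a sanity test, not a proof.

```python
import random
from itertools import accumulate, combinations

MOD = 10**9 + 7

def solve_any_width(nums, queries):
    """Same per-bit counting, with the bit width taken from the data."""
    # Bits above this width are 0 in every value and in every K,
    # so they contribute n1 * n0 = 0 * d = 0 and can be skipped.
    bits = max([*nums, *(q[0] for q in queries)], default=0).bit_length()
    # prefix[i][k] = number of the first k values with bit i set
    prefix = [[0, *accumulate((x >> i) & 1 for x in nums)] for i in range(bits)]
    res = []
    for K, P, R in queries:
        d = R - P + 1
        t = 0
        for i in range(bits):
            n1 = prefix[i][R] - prefix[i][P - 1]
            n0 = d - n1
            if (K >> i) & 1:
                x = (n1 * (n1 - 1) + n0 * (n0 - 1)) // 2
            else:
                x = n1 * n0
            t = (t + (x << i)) % MOD
        res.append(t)
    return res

def brute(nums, K, P, R):
    # Direct evaluation of the formula, used only for checking
    return sum(K ^ a ^ b for a, b in combinations(nums[P - 1:R], 2)) % MOD

# Randomized cross-check with values well above 2**20
random.seed(0)
for _ in range(200):
    nums = [random.randrange(1 << 25) for _ in range(random.randint(1, 8))]
    n = len(nums)
    queries = []
    for _ in range(5):
        P = random.randint(1, n)
        R = random.randint(P, n)
        queries.append([random.randrange(1 << 25), P, R])
    assert solve_any_width(nums, queries) == [brute(nums, *q) for q in queries]

print(solve_any_width([1, 2, 3], [[1, 1, 3], [2, 1, 3]]))  # [5, 4]
```

Taking the width from the data keeps the per-query cost at O(number of bits) while removing the fixed 2^20 limit.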

Updated on: 06-Oct-2021
