Program to find the number of good leaf node pairs using Python

Two different leaf nodes of a binary tree form a good leaf node pair when the shortest path between them, measured in edges, has length less than or equal to a given distance d. The task is to count all such pairs.

Consider this binary tree with distance d = 4:

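[Figure: the example tree. Root 1 has children 2 and 3. Node 2 has a single right child 4, whose children are leaves 8 and 7. Node 3 has leaves 5 and 6. Leaf nodes are drawn as green circles. Good pairs: (8,7) with path length 2 and (5,6) with path length 2.]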

The pairs (8,7) and (5,6) each have a path length of 2 (for example 8 → 4 → 7), which is ≤ 4. Every other pair has length 5, so those pairs are not good. For example, (7,5) goes 7 → 4 → 2 → 1 → 3 → 5, and 5 > 4.
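The path lengths above can be checked by brute force. The sketch below is not part of the original solution and is only an illustration:

- It treats the tree as an undirected graph.
- It runs one breadth-first search from every leaf and reads off the distance to each other leaf.
- The helper names `build_adjacency`, `bfs_distances` and `leaf_pair_distances` are hypothetical.
- It assumes node values are unique.
- One BFS per leaf costs O(leaves × n), so it is only meant for checking small trees.

```python
from collections import deque


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_adjacency(root):
    """Hypothetical helper: undirected adjacency lists keyed by node value."""
    adj = {root.val: []}
    leaves = []
    stack = [root]
    while stack:
        node = stack.pop()
        if not node.left and not node.right:
            leaves.append(node.val)
        for child in (node.left, node.right):
            if child:
                adj[node.val].append(child.val)
                adj[child.val] = [node.val]  # in a tree each child has exactly one parent
                stack.append(child)
    return adj, sorted(leaves)


def bfs_distances(adj, start):
    """Number of edges from start to every node in the tree."""
    dist = {start: 0}
    queue = deque([start])
    while queue:
        u = queue.popleft()
        for v in adj[u]:
            if v not in dist:
                dist[v] = dist[u] + 1
                queue.append(v)
    return dist


def leaf_pair_distances(root):
    """Map each unordered leaf pair (smaller value first) to its path length."""
    adj, leaves = build_adjacency(root)
    result = {}
    for i, a in enumerate(leaves):
        dist = bfs_distances(adj, a)  # one BFS per leaf
        for b in leaves[i + 1:]:
            result[(a, b)] = dist[b]
    return result


# The example tree from the figure
root = TreeNode(1)
root.left = TreeNode(2)
root.right = TreeNode(3)
root.left.right = TreeNode(4)
root.left.right.left = TreeNode(8)
root.left.right.right = TreeNode(7)
root.right.left = TreeNode(5)
root.right.right = TreeNode(6)

d = 4
pairs = leaf_pair_distances(root)
for (a, b), length in pairs.items():
    print(f"({a},{b}) path={length}", "good" if length <= d else "not good")
good_count = sum(1 for length in pairs.values() if length <= d)
print("Good pairs:", good_count)  # Good pairs: 2
```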

Algorithm

We use a recursive (post-order) approach. Each call returns the distances, in edges, from the current node to every leaf in its subtree:

  • For each node, collect the distance lists from its left and right subtrees
  • Increment distances by 1 as we move up the tree
  • Count pairs (one leaf from each subtree) whose path length through the current node is ≤ d
  • Return the combined distance list to the parent
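The per-node step in the bullets above can be isolated as a small pure function. This is a sketch for illustration only, and `merge_at_node` is a hypothetical name. It takes the distance lists returned by the two children, each measured from that child to its leaves, the same convention the implementation uses.

```python
def merge_at_node(left_distances, right_distances, d):
    """Hypothetical helper: one step of the bottom-up algorithm.

    left_distances / right_distances: edges from each child to the leaves below it.
    Returns (good pairs whose lowest common ancestor is this node,
             distances from this node to its leaves, for the parent).
    """
    # A leaf-to-leaf path through this node uses one extra edge on each side.
    pairs = sum(
        1
        for a in left_distances
        for b in right_distances
        if a + b + 2 <= d
    )
    # Move up one level and drop distances that have already reached d.
    upward = [x + 1 for x in left_distances + right_distances if x + 1 < d]
    return pairs, upward


# Node 4 of the example tree: both children are leaves (distance 0 from themselves).
print(merge_at_node([0], [0], 4))        # (1, [1, 1])
# Root 1: the left subtree reports [2, 2], the right subtree reports [1, 1].
print(merge_at_node([2, 2], [1, 1], 4))  # (0, [3, 3, 2, 2])
```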

Implementation

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

class Solution:
    def __init__(self):
        self.count = 0
    
    def solve(self, root, d):
        self.count = 0  # Reset so repeated calls on the same instance start from zero
        def get_distances(node):
            if not node:
                return []
            
            # If leaf node, return distance 0 from itself
            if not node.left and not node.right:
                return [0]
            
            # Get distances from left and right subtrees
            left_distances = get_distances(node.left)
            right_distances = get_distances(node.right)
            
            # Count good pairs between left and right subtrees
            for left_dist in left_distances:
                for right_dist in right_distances:
                    if left_dist + right_dist + 2 <= d:
                        self.count += 1
            
            # Return distances incremented by 1 (moving up one level)
            all_distances = []
            for dist in left_distances + right_distances:
                if dist + 1 < d:  # A leaf d or more edges away can never be in a good pair
                    all_distances.append(dist + 1)
            
            return all_distances
        
        get_distances(root)
        return self.count

# Create the binary tree
root = TreeNode(1)
root.left = TreeNode(2)
root.right = TreeNode(3)
root.left.right = TreeNode(4)
root.left.right.left = TreeNode(8)
root.left.right.right = TreeNode(7)
root.right.left = TreeNode(5)
root.right.right = TreeNode(6)

d = 4
solution = Solution()
result = solution.solve(root, d)
print(f"Number of good leaf node pairs: {result}")

Output

Number of good leaf node pairs: 2

How It Works

The algorithm works as follows:

  1. Base Case: A leaf node returns [0], its distance to itself. An empty subtree returns []
  2. Recursive Case: For each internal node, get the distance lists from both subtrees
  3. Count Pairs: Check every combination of a left distance and a right distance, adding 2 for the two edges that connect the current node to its children
  4. Propagate: Return the distances incremented by 1 to the parent. Drop any that reach d, because such a leaf can no longer be part of a good pair
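To see these four steps on the example tree, the sketch below reruns the same recursion and logs the distance list each node returns. It is an instrumented copy for illustration, not a replacement for the `Solution` class. The function name `count_good_pairs_traced` and its `log` parameter are hypothetical.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_good_pairs_traced(root, d, log=print):
    """Same recursion as Solution.solve, but reports what every node returns."""
    count = 0

    def get_distances(node):
        nonlocal count
        if not node:
            return []
        if not node.left and not node.right:
            log(f"leaf {node.val}: returns [0]")
            return [0]
        left = get_distances(node.left)
        right = get_distances(node.right)
        # Step 3: pairs whose lowest common ancestor is this node
        found = sum(1 for a in left for b in right if a + b + 2 <= d)
        count += found
        # Step 4: move up one level, dropping distances that reach d
        up = [x + 1 for x in left + right if x + 1 < d]
        log(f"node {node.val}: left={left} right={right} pairs={found} returns {up}")
        return up

    get_distances(root)
    return count


# The example tree: 1 -> (2 -> right 4 -> (8, 7)), (3 -> (5, 6))
root = TreeNode(1,
                TreeNode(2, None, TreeNode(4, TreeNode(8), TreeNode(7))),
                TreeNode(3, TreeNode(5), TreeNode(6)))
print("total:", count_good_pairs_traced(root, 4))
```

For d = 4 the last node line is `node 1: left=[2, 2] right=[1, 1] pairs=0 returns [3, 3, 2, 2]`. At the root, 2 + 1 + 2 = 5 > 4, so no cross pair is added there, and the total stays at 2.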

Time Complexity

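Let n be the number of nodes and h the height of the tree. Every pair of leaves is compared at most once, at its lowest common ancestor, so for a tree with L leaves the pair-counting loops perform at most L(L - 1)/2 comparisons. Building the distance lists adds O(n × min(h, d)), because each leaf is kept in the lists of at most min(h + 1, d) nodes on its path to the root. The worst case is O(n²), for example a balanced tree with d large enough that nothing is filtered.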
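A common way to get a bound that depends on d rather than on the number of leaves is to return per-distance counts instead of raw lists. This is a swapped-in variant, not the article's method.

- Each call returns an array `cnt` of length d, where `cnt[k]` is the number of leaves exactly k edges below the current node.
- Each merge then costs O(d²), so the traversal runs in O(n × d²).
- The sketch assumes d ≥ 1 and the same `TreeNode` shape.
- `count_good_pairs_buckets` is a hypothetical name.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_good_pairs_buckets(root, d):
    """Hypothetical variant: per-distance leaf counts instead of distance lists (d >= 1)."""
    total = 0

    def dfs(node):
        nonlocal total
        cnt = [0] * d  # cnt[k] = leaves at distance k from node, only k < d kept
        if not node:
            return cnt
        if not node.left and not node.right:
            cnt[0] = 1
            return cnt
        left, right = dfs(node.left), dfs(node.right)
        # Pair a leaf i edges below the left child with one j edges below the right
        # child: the path length is i + j + 2, so we need j <= d - i - 2.
        for i in range(d):
            for j in range(d - i - 1):
                total += left[i] * right[j]
        # Shift everything one edge further away for the parent, dropping distance d.
        for k in range(d - 1):
            cnt[k + 1] = left[k] + right[k]
        return cnt

    dfs(root)
    return total


root = TreeNode(1,
                TreeNode(2, None, TreeNode(4, TreeNode(8), TreeNode(7))),
                TreeNode(3, TreeNode(5), TreeNode(6)))
print(count_good_pairs_buckets(root, 4))  # 2
```

The trade-off: when many leaves share the same distance, the count array avoids enumerating every pair one by one. When d is large compared with the number of leaves, the list-based version can be cheaper.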

Conclusion

This solution counts good leaf node pairs by tracking distances from leaves to their ancestors and counting the valid combinations. The key insight is that the distance between two leaves equals the sum of their distances to their lowest common ancestor. The code adds 2 because each child's list stores distances measured from that child, which is one edge below the ancestor.
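The key insight can be checked with the equivalent depth formula dist(a, b) = depth(a) + depth(b) - 2 × depth(LCA(a, b)). The sketch below is a hypothetical helper, not from the article. It finds the lowest common ancestor by walking parent pointers and assumes node values are unique.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def parent_and_depth(root):
    """Record each node's parent value and depth (root has depth 0)."""
    parent, depth = {root.val: None}, {root.val: 0}
    stack = [root]
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child:
                parent[child.val] = node.val
                depth[child.val] = depth[node.val] + 1
                stack.append(child)
    return parent, depth


def leaf_distance(root, a, b):
    """Return (path length between a and b, value of their lowest common ancestor)."""
    parent, depth = parent_and_depth(root)
    x, y = a, b
    # Lift the deeper node until both sit at the same depth ...
    while depth[x] > depth[y]:
        x = parent[x]
    while depth[y] > depth[x]:
        y = parent[y]
    # ... then lift both together until they meet at the LCA.
    while x != y:
        x, y = parent[x], parent[y]
    lca = x
    return depth[a] + depth[b] - 2 * depth[lca], lca


root = TreeNode(1,
                TreeNode(2, None, TreeNode(4, TreeNode(8), TreeNode(7))),
                TreeNode(3, TreeNode(5), TreeNode(6)))
print(leaf_distance(root, 8, 7))  # (2, 4)
print(leaf_distance(root, 7, 5))  # (5, 1)
```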

Updated on: 2026-03-25T21:00:50+05:30
