Program to find the distance between two nodes in a binary tree in Python

Suppose we are given a binary tree and two of its nodes, and are asked to find the distance between them. The distance is the number of edges on the path that connects the two nodes, just as in a graph. Each node of the tree has the following structure −

data : <integer value>
right : <pointer to another node of the tree>
left : <pointer to another node of the tree>

So, if the input is like

        5
      /   \
     3     7
    / \   / \
   2   4 6   8

and the two nodes are 2 and 8, then the output will be 4.

The edges between the two nodes 2 and 8 are: (2, 3), (3, 5), (5, 7), and (7, 8). There are 4 edges in the path between them, so the distance is 4.
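
To make the edge list above concrete, here is a minimal sketch that recovers the edges on the path by walking parent links. It assumes node values are unique and that both values are present in the tree. `Node`, `parent_map`, and `path_edges` are illustrative names and are not part of the implementation given later in this article.

```python
class Node:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def parent_map(root):
    """Map each node value to its parent's value (the root maps to None)."""
    parents = {root.data: None}
    stack = [root]
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                parents[child.data] = node.data
                stack.append(child)
    return parents


def path_edges(root, a, b):
    """Return the edges on the path from a to b (assumes both values exist)."""
    parents = parent_map(root)
    # Walk up from a, recording a itself and every ancestor up to the root.
    up_from_a = []
    value = a
    while value is not None:
        up_from_a.append(value)
        value = parents[value]
    # Walk up from b until we reach a value that is also on a's upward path.
    up_from_b = []
    value = b
    while value not in up_from_a:
        up_from_b.append(value)
        value = parents[value]
    meeting_point = value  # this is the lowest common ancestor
    path = up_from_a[:up_from_a.index(meeting_point) + 1] + up_from_b[::-1]
    return list(zip(path, path[1:]))


# The example tree: 5 is the root, 3 and 7 are its children, 2, 4, 6, 8 are leaves.
root = Node(5, Node(3, Node(2), Node(4)), Node(7, Node(6), Node(8)))
edges = path_edges(root, 2, 8)
print(edges)       # [(2, 3), (3, 5), (5, 7), (7, 8)]
print(len(edges))  # 4
```

The number of edges returned is the distance, so this doubles as a way to cross-check any other distance routine on small trees.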

Algorithm

To solve this, we will follow these steps −

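  • Find the lowest common ancestor (LCA) of the two nodes, that is, the deepest node that has both of them as descendants (a node counts as a descendant of itself)
  • Calculate the distance from the LCA to the first node
  • Calculate the distance from the LCA to the second node
  • Return the sum of both distances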
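
The steps above are equivalent to the identity dist(a, b) = depth(a) + depth(b) - 2 * depth(LCA), where depth is measured in edges from the root. The following is a minimal sketch of that identity, assuming node values are unique. `Node`, `path_from_root`, and `distance` are hypothetical names used only for this illustration.

```python
class Node:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def path_from_root(root, target):
    """Return the list of values from root down to target, or None if absent."""
    if root is None:
        return None
    if root.data == target:
        return [root.data]
    for child in (root.left, root.right):
        sub = path_from_root(child, target)
        if sub is not None:
            return [root.data] + sub
    return None


def distance(root, a, b):
    """Edges between a and b, using depth(a) + depth(b) - 2 * depth(LCA)."""
    pa, pb = path_from_root(root, a), path_from_root(root, b)
    if pa is None or pb is None:
        return None  # at least one value is not in the tree
    # The shared prefix of the two root paths ends at the LCA.
    common = 0
    while common < min(len(pa), len(pb)) and pa[common] == pb[common]:
        common += 1
    # depth(x) = len(path to x) - 1, and depth(LCA) = common - 1.
    return (len(pa) - 1) + (len(pb) - 1) - 2 * (common - 1)


# The example tree: 5 is the root, 3 and 7 are its children, 2, 4, 6, 8 are leaves.
root = Node(5, Node(3, Node(2), Node(4)), Node(7, Node(6), Node(8)))
print(distance(root, 2, 8))  # 4
```

This variant never builds the LCA node explicitly; it only needs the length of the common prefix of the two root paths.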

Implementation

import collections

class TreeNode:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right

def insert(root, data):
    """Attach data at the first free child slot, scanning in level order."""
    # A None entry is stored as a placeholder node holding 0.
    new_node = TreeNode(data if data is not None else 0)
    queue = collections.deque([root])
    while queue:
        node = queue.popleft()
        if node.left is None:
            node.left = new_node
            return
        queue.append(node.left)

        if node.right is None:
            node.right = new_node
            return
        queue.append(node.right)

def make_tree(elements):
    tree = TreeNode(elements[0])
    for element in elements[1:]:
        insert(tree, element)
    return tree

def findLca(root, p, q):
    if root is None:
        return None
    if root.data in (p, q):
        return root
    
    left = findLca(root.left, p, q)
    right = findLca(root.right, p, q)
    
    if left and right:
        return root
    return left or right

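def findDist(root, data):
    """Return the number of edges from root down to the node holding data.

    Uses BFS. Returns None if root is None or data is not found below it.
    """
    if root is None:
        return None
    queue = collections.deque([(root, 0)])

    while queue:
        current, dist = queue.popleft()
        if current.data == data:
            return dist
        if current.left:
            queue.append((current.left, dist + 1))
        if current.right:
            queue.append((current.right, dist + 1))
    return None  # data is not in this subtree

def solve(root, p, q):
    """Return the distance in edges between p and q, or None if either is missing."""
    lca_node = findLca(root, p, q)
    dist_p = findDist(lca_node, p)
    dist_q = findDist(lca_node, q)
    if dist_p is None or dist_q is None:
        return None
    return dist_p + dist_q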

# Create the tree and find distance
root = make_tree([5, 3, 7, 2, 4, 6, 8])
distance = solve(root, 2, 8)
print("Distance between nodes 2 and 8:", distance)

Output

Distance between nodes 2 and 8: 4

How It Works

The algorithm works in two main steps:

  1. Find LCA: The findLca() function recursively searches for the lowest common ancestor of two nodes.
  2. Calculate Distance: The findDist() function uses BFS to find the distance from LCA to each target node.

For nodes 2 and 8 in our example:

  • LCA of 2 and 8 is node 5 (root)
  • Distance from 5 to 2: 5 → 3 → 2 = 2 edges
  • Distance from 5 to 8: 5 → 7 → 8 = 2 edges
  • Total distance: 2 + 2 = 4
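
The same breakdown can be printed for other pairs, including the case where one node is an ancestor of the other and is therefore its own LCA. This is a compact sketch under the assumption that values are unique and both are present. `Node`, `lca`, `depth_below`, and `breakdown` are illustrative names, not the article's functions.

```python
class Node:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def lca(root, p, q):
    """Return the lowest common ancestor node of values p and q."""
    if root is None or root.data in (p, q):
        return root
    left, right = lca(root.left, p, q), lca(root.right, p, q)
    return root if left and right else left or right


def depth_below(node, target, depth=0):
    """Edges from node down to target, or None if target is not below node."""
    if node is None:
        return None
    if node.data == target:
        return depth
    found = depth_below(node.left, target, depth + 1)
    if found is not None:
        return found
    return depth_below(node.right, target, depth + 1)


def breakdown(root, p, q):
    """Return (LCA value, edges from LCA to p, edges from LCA to q)."""
    ancestor = lca(root, p, q)
    return ancestor.data, depth_below(ancestor, p), depth_below(ancestor, q)


# The example tree: 5 is the root, 3 and 7 are its children, 2, 4, 6, 8 are leaves.
root = Node(5, Node(3, Node(2), Node(4)), Node(7, Node(6), Node(8)))
print(breakdown(root, 2, 8))  # (5, 2, 2) -> distance 4
print(breakdown(root, 2, 4))  # (3, 1, 1) -> distance 2
print(breakdown(root, 3, 2))  # (3, 0, 1) -> distance 1
```

In the last call, node 3 is the LCA of itself and node 2, so one of the two legs has length 0.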

Conclusion

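To find the distance between two nodes in a binary tree, we first locate their lowest common ancestor and then sum the distances from the LCA to each node. Because a tree has exactly one simple path between any two nodes, that sum is the distance. Both the LCA search and the BFS visit each node at most once, so the running time is O(n) for a tree with n nodes. The implementation identifies nodes by value, so it assumes the values in the tree are unique.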

Updated on: 2026-03-26T14:33:37+05:30
