Program to find out distance between two nodes in a binary tree in Python
Suppose we are given a binary tree and asked to find the distance between two of its nodes. As in a graph, the distance is the number of edges on the path that connects the two nodes. Each node of the tree has the following structure −
- data: <integer value>
- left: <pointer to another node of the tree>
- right: <pointer to another node of the tree>
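For example, consider the tree built level by level from the values [5, 3, 7, 2, 4, 6, 8]: 5 is the root, 3 and 7 are its children, 2 and 4 are the children of 3, and 6 and 8 are the children of 7. If the two nodes are 2 and 8, the output will be 4.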
The edges between the two nodes 2 and 8 are: (2, 3), (3, 5), (5, 7), and (7, 8). There are 4 edges in the path between them, so the distance is 4.
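Since a tree is a connected graph without cycles, the same distance can be found by treating every parent-child link as an undirected edge and running a breadth-first search from one node until the other is reached. The sketch below assumes node values are unique; the helper names `to_adjacency` and `graph_distance` are hypothetical and are not part of the implementation later in this article.

```python
from collections import deque


class TreeNode:
    # Same shape as the node described above: data, left, right
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def to_adjacency(root):
    """Turn parent-child links into an undirected adjacency list keyed by node value."""
    adj = {}
    stack = [root]
    while stack:
        node = stack.pop()
        adj.setdefault(node.data, [])
        for child in (node.left, node.right):
            if child is not None:
                # Record the edge in both directions
                adj[node.data].append(child.data)
                adj.setdefault(child.data, []).append(node.data)
                stack.append(child)
    return adj


def graph_distance(root, p, q):
    """Number of edges on the path from p to q, or -1 if either value is missing."""
    adj = to_adjacency(root)
    if p not in adj or q not in adj:
        return -1
    seen = {p}
    queue = deque([(p, 0)])
    while queue:
        value, dist = queue.popleft()
        if value == q:
            return dist
        for nxt in adj[value]:
            if nxt not in seen:
                seen.add(nxt)
                queue.append((nxt, dist + 1))
    return -1  # not reached in a connected tree


# The example tree: root 5, children 3 and 7, grandchildren 2, 4, 6, 8
root = TreeNode(5,
                TreeNode(3, TreeNode(2), TreeNode(4)),
                TreeNode(7, TreeNode(6), TreeNode(8)))
print(graph_distance(root, 2, 8))  # 4
```

This version always builds an adjacency list for the whole tree before searching, whereas the LCA-based approach below works directly on the node pointers.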
Algorithm
To solve this, we will follow these steps −
- Find the Lowest Common Ancestor (LCA) of the two nodes
- Calculate the distance from the LCA to the first node
- Calculate the distance from the LCA to the second node
- Return the sum of both distances
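Steps 2 and 3 can also be written in terms of node depths. Let d(x) be the number of edges from the root to node x (notation introduced here for illustration) and let a = LCA(p, q). Because a lies on the root-to-node path of both p and q, each partial distance is a difference of depths:

```latex
\begin{aligned}
\operatorname{dist}(p, q) &= \bigl(d(p) - d(a)\bigr) + \bigl(d(q) - d(a)\bigr) = d(p) + d(q) - 2\,d(a), \qquad a = \operatorname{LCA}(p, q) \\
\operatorname{dist}(2, 8) &= 2 + 2 - 2 \cdot 0 = 4
\end{aligned}
```

In the example, d(2) = 2 and d(8) = 2, and the LCA is the root 5 with d(5) = 0, which gives the expected distance of 4.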
Implementation
```python
import collections


class TreeNode:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def insert(root, data):
    # Level-order insertion: attach the new node at the first free child slot
    queue = collections.deque([root])
    while queue:
        temp = queue.popleft()
        if not temp.left:
            # A None placeholder is stored as 0
            temp.left = TreeNode(data if data is not None else 0)
            break
        queue.append(temp.left)
        if not temp.right:
            temp.right = TreeNode(data if data is not None else 0)
            break
        queue.append(temp.right)


def make_tree(elements):
    # Build a complete binary tree from values given in level order
    tree = TreeNode(elements[0])
    for element in elements[1:]:
        insert(tree, element)
    return tree


def findLca(root, p, q):
    # Return the lowest common ancestor of the nodes holding p and q
    if root is None:
        return None
    if root.data in (p, q):
        return root
    left = findLca(root.left, p, q)
    right = findLca(root.right, p, q)
    if left and right:
        return root  # p and q lie in different subtrees
    return left or right


def findDist(root, data):
    # BFS from root; return the number of edges down to the node holding data
    queue = collections.deque()
    queue.append((root, 0))
    while queue:
        current, dist = queue.popleft()
        if current.data == data:
            return dist
        if current.left:
            queue.append((current.left, dist + 1))
        if current.right:
            queue.append((current.right, dist + 1))
    return None  # data is not in this subtree


def solve(root, p, q):
    # Assumes node values are unique and both p and q are present
    lca_node = findLca(root, p, q)
    return findDist(lca_node, p) + findDist(lca_node, q)


# Create the tree and find distance
root = make_tree([5, 3, 7, 2, 4, 6, 8])
distance = solve(root, 2, 8)
print("Distance between nodes 2 and 8:", distance)
```
Output
Distance between nodes 2 and 8: 4
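Because findLca() is recursive, a very deep tree (for example, one in which every node has only a left child) can exceed CPython's default recursion limit of 1000 and raise RecursionError. A minimal iterative sketch avoids recursion: one BFS records each node's parent, then we collect the ancestors of p and climb from q until we meet one of them; that meeting node is the LCA. The function name `distance_with_parents` is hypothetical, and the sketch assumes node values are unique.

```python
from collections import deque


class TreeNode:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def distance_with_parents(root, p, q):
    """Iterative distance between values p and q; -1 if either is absent."""
    # One BFS records every node's parent and indexes nodes by value
    parent = {root: None}
    nodes = {}
    queue = deque([root])
    while queue:
        node = queue.popleft()
        nodes.setdefault(node.data, node)
        for child in (node.left, node.right):
            if child is not None:
                parent[child] = node
                queue.append(child)
    if p not in nodes or q not in nodes:
        return -1
    # Distance from p up to each of its ancestors (p itself included)
    up_from_p = {}
    node, steps = nodes[p], 0
    while node is not None:
        up_from_p[node] = steps
        node, steps = parent[node], steps + 1
    # Climb from q until we reach an ancestor of p: that node is the LCA
    node, steps = nodes[q], 0
    while node not in up_from_p:
        node, steps = parent[node], steps + 1
    return steps + up_from_p[node]


root = TreeNode(5,
                TreeNode(3, TreeNode(2), TreeNode(4)),
                TreeNode(7, TreeNode(6), TreeNode(8)))
print(distance_with_parents(root, 2, 8))  # 4

# A 5000-node chain of left children, deeper than the default recursion limit
chain = TreeNode(0)
tail = chain
for value in range(1, 5000):
    tail.left = TreeNode(value)
    tail = tail.left
print(distance_with_parents(chain, 0, 4999))  # 4999
```

The BFS costs O(n); the two climbs each take at most h steps, where h is the height of the tree.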
How It Works
The algorithm works in two main steps:
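- Find LCA: The findLca() function recursively searches for the lowest common ancestor of the two nodes. A node holding either value is returned as soon as it is reached; if the two values are found in different subtrees of a node, that node is the LCA.
- Calculate distance: The findDist() function runs a breadth-first search (BFS) from the LCA and returns the depth at which each target node is found.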
For nodes 2 and 8 in our example:
- The LCA of 2 and 8 is node 5 (the root)
- Distance from 5 to 2: 5 → 3 → 2 = 2 edges
- Distance from 5 to 8: 5 → 7 → 8 = 2 edges
- Total distance: 2 + 2 = 4
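The two steps can also be folded into a single post-order traversal. Each call reports how many edges separate the current node from a target found below it, and the first node that sees both targets (or is itself a target with the other one underneath) is the LCA, where the answer is recorded. This is a sketch of an alternative, not the article's method; `single_pass_distance` is a hypothetical name, and it assumes unique values and p != q.

```python
class TreeNode:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def single_pass_distance(root, p, q):
    """Distance between distinct values p and q; None if they are not both present."""
    answer = None

    def dfs(node):
        # Edges from node down to a target (p or q) in its subtree, or -1 if
        # there is none or the answer has already been recorded further down.
        nonlocal answer
        if node is None:
            return -1
        left = dfs(node.left)
        right = dfs(node.right)
        if node.data in (p, q):
            below = max(left, right)
            if below >= 0:  # the other target is underneath: node is the LCA
                answer = below + 1
                return -1
            return 0
        if left >= 0 and right >= 0:  # targets on both sides: node is the LCA
            answer = left + right + 2
            return -1
        if left >= 0:
            return left + 1
        if right >= 0:
            return right + 1
        return -1

    dfs(root)
    return answer


root = TreeNode(5,
                TreeNode(3, TreeNode(2), TreeNode(4)),
                TreeNode(7, TreeNode(6), TreeNode(8)))
print(single_pass_distance(root, 2, 8))  # 4
```

It visits each node once instead of running findLca() followed by two BFS passes, but it is still recursive, so its call depth grows with the tree height.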
Conclusion
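To find the distance between two nodes in a binary tree, we first locate their lowest common ancestor and then add the distances from the LCA to each node. Because the path between two nodes in a tree is unique, this sum is exactly the number of edges between them. findLca() and each findDist() call visit every node at most once, so the whole computation runs in O(n) time for a tree with n nodes, using O(n) extra space for the BFS queue and O(h) for the recursion in findLca(), where h is the height of the tree. The implementation assumes that node values are unique and that both values are present in the tree; otherwise solve() raises an exception.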
