Program to Remove All Nodes from a BST That Are Not in Range in Python

In this tutorial, we'll learn how to remove all nodes from a Binary Search Tree (BST) that are not within a given range [low, high]. This is a common tree pruning problem that uses the BST property to efficiently eliminate nodes.

Problem Statement

Given a BST and two values low and high, we need to delete every node whose value lies outside the inclusive range [low, high]. The resulting tree must still satisfy the BST property.
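
To make the two requirements concrete, here is a minimal sketch of a checker for a candidate result. The names inorder and is_valid_result are hypothetical helpers, not part of the tutorial's solution, and the sketch assumes the tree holds distinct keys, so that a BST's inorder traversal is strictly increasing.

```python
class TreeNode:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right

def inorder(root):
    """Return the node values in inorder (left, node, right) order."""
    if root is None:
        return []
    return inorder(root.left) + [root.data] + inorder(root.right)

def is_valid_result(root, low, high):
    """Hypothetical checker: every value is in [low, high] and the tree is a BST.

    Assumes distinct keys, so a valid BST has a strictly increasing inorder list.
    """
    values = inorder(root)
    in_range = all(low <= v <= high for v in values)
    strictly_increasing = all(a < b for a, b in zip(values, values[1:]))
    return in_range and strictly_increasing

# A hand-built tree whose values all lie in [7, 10]
trimmed = TreeNode(9, TreeNode(7, None, TreeNode(8)), TreeNode(10))
# A hand-built tree that still contains out-of-range values (5 and 1)
untrimmed = TreeNode(5, TreeNode(1), TreeNode(9))

print(inorder(trimmed))                   # [7, 8, 9, 10]
print(is_valid_result(trimmed, 7, 10))    # True
print(is_valid_result(untrimmed, 7, 10))  # False
```

A checker like this can be run on the output of any trimming routine to confirm that it meets the problem statement.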

[Figure: Original BST with nodes 5, 1, 9, 7, 10 and range [7, 10]; the nodes to keep are highlighted in yellow.]

Algorithm Approach

The solution uses a recursive approach that takes advantage of the BST property:

  • If the current node's value is less than low, every node in its left subtree is also less than low, so we discard the node and its left subtree and only need to check the right subtree
  • If the current node's value is greater than high, every node in its right subtree is also greater than high, so we discard the node and its right subtree and only need to check the left subtree
  • If the current node's value is within the range, we keep the node and recursively process both subtrees

Implementation

class TreeNode:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right

def print_tree(root):
    """Helper function to print tree in inorder traversal"""
    if root is not None:
        print_tree(root.left)
        print(root.data, end=', ')
        print_tree(root.right)

class Solution:
    def solve(self, root, low, high):
        if not root:
            return None
        
        # If current node is less than low, 
        # then all left subtree nodes are also less than low
        if low > root.data:
            return self.solve(root.right, low, high)
        
        # If current node is greater than high,
        # then all right subtree nodes are also greater than high
        if high < root.data:
            return self.solve(root.left, low, high)
        
        # Current node is in range, so recursively trim both subtrees
        root.right = self.solve(root.right, low, high)
        root.left = self.solve(root.left, low, high)
        
        return root

# Create the BST
root = TreeNode(5)
root.left = TreeNode(1)
root.right = TreeNode(9)
root.right.left = TreeNode(7)
root.right.right = TreeNode(10)
root.right.left.left = TreeNode(6)
root.right.left.right = TreeNode(8)

# Set range
low = 7
high = 10

# Remove nodes outside range
ob = Solution()
result = ob.solve(root, low, high)

print("Nodes within range [7, 10]:")
print_tree(result)

Output

Nodes within range [7, 10]:
7, 8, 9, 10, 

How It Works

The algorithm performs the following steps:

  1. Base case: If root is None, return None
  2. Node too small: If root.data < low, discard the node and its entire left subtree, then recurse on the right subtree
  3. Node too large: If root.data > high, discard the node and its entire right subtree, then recurse on the left subtree
  4. Node in range: Keep current node and recursively trim both left and right subtrees
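
The four steps above can be followed on a concrete tree with a small tracing sketch. The function trim_traced and its log parameter are hypothetical additions for illustration; the logic mirrors the steps listed above, except that this sketch trims the left subtree before the right one, which affects only the order of the log entries.

```python
class TreeNode:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right

def trim_traced(root, low, high, log):
    """Trim the tree to [low, high], recording the decision made at each visited node."""
    if root is None:                      # Step 1: base case
        return None
    if root.data < low:                   # Step 2: node too small
        log.append((root.data, "too small"))
        return trim_traced(root.right, low, high, log)
    if root.data > high:                  # Step 3: node too large
        log.append((root.data, "too large"))
        return trim_traced(root.left, low, high, log)
    log.append((root.data, "in range"))   # Step 4: keep node, trim both sides
    root.left = trim_traced(root.left, low, high, log)
    root.right = trim_traced(root.right, low, high, log)
    return root

# Same tree shape as the tutorial's example
root = TreeNode(5, TreeNode(1),
                TreeNode(9, TreeNode(7, TreeNode(6), TreeNode(8)), TreeNode(10)))

log = []
result = trim_traced(root, 7, 10, log)
for value, decision in log:
    print(value, decision)
# 5 too small
# 9 in range
# 7 in range
# 6 too small
# 8 in range
# 10 in range
```

Node 1 never appears in the log: once 5 is found to be below low, its whole left subtree is dropped without being visited.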

Time and Space Complexity

  • Time Complexity: O(n) in the worst case, where n is the number of nodes, since each node is visited at most once
  • Space Complexity: O(h) where h is the height of the tree (due to recursion stack)
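
The O(h) space cost comes only from the recursion stack. As a sketch of one possible alternative, the trimming can also be done with loops and O(1) extra space. This is a different technique from the recursive solution in this tutorial, and trim_iterative is a hypothetical name. It assumes distinct keys and low <= high.

```python
class TreeNode:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right

def inorder(root):
    """Return the node values in inorder order (used here only to show the result)."""
    if root is None:
        return []
    return inorder(root.left) + [root.data] + inorder(root.right)

def trim_iterative(root, low, high):
    """Iterative variant (an assumption, not the tutorial's method): no recursion stack."""
    # Walk down until the root itself lies inside [low, high].
    while root is not None and (root.data < low or root.data > high):
        root = root.right if root.data < low else root.left

    # Everything left of an in-range node is <= high, so only low needs checking.
    node = root
    while node is not None:
        while node.left is not None and node.left.data < low:
            node.left = node.left.right   # drop the small node and its left subtree
        node = node.left

    # Everything right of an in-range node is >= low, so only high needs checking.
    node = root
    while node is not None:
        while node.right is not None and node.right.data > high:
            node.right = node.right.left  # drop the large node and its right subtree
        node = node.right

    return root

root = TreeNode(5, TreeNode(1),
                TreeNode(9, TreeNode(7, TreeNode(6), TreeNode(8)), TreeNode(10)))
result = trim_iterative(root, 7, 10)
print(inorder(result))  # [7, 8, 9, 10]
```

The recursive version is usually easier to read; the iterative form mainly matters for very deep, unbalanced trees, where h approaches n and Python's recursion limit could be reached.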

Conclusion

This BST trimming algorithm removes nodes outside a given range by leveraging the BST property. Whenever a node falls outside the range, one of its subtrees is discarded without being visited, so the algorithm often inspects fewer than n nodes, although the worst case remains O(n).

Updated on: 2026-03-25T10:49:30+05:30
