Program to turn an almost-BST into an exact BST in Python

A binary search tree (BST) has a specific property: for each node, all values in the left subtree are smaller, and all values in the right subtree are larger. When two nodes are swapped in a BST, we can identify and fix them by performing an inorder traversal and finding the nodes that violate the BST property.

Understanding the Problem

In an inorder traversal of a valid BST, the values should be in ascending order. When exactly two nodes are swapped, we'll find one or two violations where a node's value is less than the previous node's value.

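Before (incorrect BST): root 3, left child 6, right child 8 with children 2 and 9.
Inorder: 6, 3, 2, 8, 9
Violations: 6 > 3 and 3 > 2

After (correct BST): root 3, left child 2, right child 8 with children 6 and 9.
Inorder: 2, 3, 6, 8, 9 (ascending, so the BST property holds)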
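The violation counts described above can be checked on plain lists of inorder values. The following is a minimal sketch; `inorder_violations` is a hypothetical helper name introduced only for illustration, and it works on a list rather than on a tree.

```python
def inorder_violations(values):
    """Return (previous, current) pairs where current < previous."""
    return [(prev, cur) for prev, cur in zip(values, values[1:]) if cur < prev]


# Non-adjacent swap (6 and 2 exchanged): two violations
print(inorder_violations([6, 3, 2, 8, 9]))  # [(6, 3), (3, 2)]

# Adjacent swap (3 and 6 exchanged): a single violation
print(inorder_violations([2, 6, 3, 8, 9]))  # [(6, 3)]

# Valid inorder sequence: no violations
print(inorder_violations([2, 3, 6, 8, 9]))  # []
```

A swap of two values that are neighbours in the inorder sequence produces one violation; any other swap produces two.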

Algorithm Approach

We perform an inorder traversal and track violations. When we find a node smaller than the previous node, we identify the swapped nodes as follows:

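  • First violation: Mark the larger node (prev_node) as max_node, and the current node as a provisional min_node
  • Second violation: Mark the smaller node (the current node) as min_node
  • Single violation: The two adjacent nodes are the swapped pair, so the provisional min_node is kept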
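The rules above can be sketched on an almost-sorted list before applying them to a tree. This is an illustrative simplification, not the article's implementation: `find_swapped` and `fix_swapped` are hypothetical names, and the input is assumed to be a sorted list of distinct values in which at most one pair has been swapped.

```python
def find_swapped(values):
    """Return indices (i, j) of the two swapped items, or None if sorted."""
    first = second = None
    for k in range(1, len(values)):
        if values[k] < values[k - 1]:
            if first is None:
                first = k - 1  # first violation: the larger, earlier item
            second = k         # latest violation: the smaller, later item
    if first is None:
        return None
    return first, second


def fix_swapped(values):
    """Return a copy of the list with the single swap undone."""
    fixed = list(values)
    found = find_swapped(fixed)
    if found is not None:
        i, j = found
        fixed[i], fixed[j] = fixed[j], fixed[i]
    return fixed


print(find_swapped([6, 3, 2, 8, 9]))  # (0, 2): two violations
print(find_swapped([2, 6, 3, 8, 9]))  # (1, 2): single violation
print(fix_swapped([6, 3, 2, 8, 9]))   # [2, 3, 6, 8, 9]
```

With a single violation, `second` is never overwritten, so the pair found at the first violation is the one that gets swapped back.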

Implementation

class TreeNode:
    def __init__(self, data, left=None, right=None):
        self.val = data
        self.left = left
        self.right = right

def print_tree(root):
    """Print inorder traversal of the tree"""
    if root is not None:
        print_tree(root.left)
        print(root.val, end=', ')
        print_tree(root.right)

def __iter__(self):
    """Inorder iterator for TreeNode: left subtree, this node, right subtree"""
    if self.left is not None:
        yield from self.left
    yield self
    if self.right is not None:
        yield from self.right

# Add iterator method to TreeNode class
setattr(TreeNode, "__iter__", __iter__)

class Solution:
    def solve(self, root):
        prev_node = None
        min_node = None
        max_node = None
        found_one = False
        
        # Traverse tree in inorder
        for node in root:
            if prev_node:
                if node.val < prev_node.val:
                    # Found a violation
                    if min_node is None or node.val < min_node.val:
                        min_node = node
                    if max_node is None or max_node.val < prev_node.val:
                        max_node = prev_node
                    
                    if found_one:
                        break  # Found both violations
                    else:
                        found_one = True
            
            prev_node = node
        
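        # Swap the values of the misplaced nodes.
        # If no violation was found, the tree is already a valid BST.
        if min_node is not None and max_node is not None:
            min_node.val, max_node.val = max_node.val, min_node.val
        return root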

# Create the incorrect BST
ob = Solution()
root = TreeNode(3)
root.left = TreeNode(6)
root.right = TreeNode(8)
root.right.left = TreeNode(2)
root.right.right = TreeNode(9)

print("Before correction:")
print_tree(root)
print("\n")

print("After correction:")
print_tree(ob.solve(root))
print()

Output

Before correction:
6, 3, 2, 8, 9, 

After correction:
2, 3, 6, 8, 9, 
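
To confirm that the corrected tree really satisfies the BST property, a range-based validity check can be run on both trees. The sketch below is self-contained: `Node` is a minimal stand-in for the article's TreeNode (same val/left/right fields), `is_valid_bst` is a hypothetical helper, and values are assumed to be distinct.

```python
class Node:
    """Minimal stand-in for the article's TreeNode (assumption)."""
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def is_valid_bst(node, low=None, high=None):
    """Check that every value lies strictly between its ancestors' bounds."""
    if node is None:
        return True
    if low is not None and node.val <= low:
        return False
    if high is not None and node.val >= high:
        return False
    return (is_valid_bst(node.left, low, node.val)
            and is_valid_bst(node.right, node.val, high))


# The two trees from the example: before and after the correction
before = Node(3, Node(6), Node(8, Node(2), Node(9)))
after = Node(3, Node(2), Node(8, Node(6), Node(9)))

print(is_valid_bst(before))  # False
print(is_valid_bst(after))   # True
```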

How It Works

The algorithm identifies violations during inorder traversal:

  1. Track previous node: Compare each node with the previous one
  2. Detect violations: When the current value is less than the previous value, at least one of the two nodes is misplaced
  3. Record candidates: Keep track of the minimum and maximum violating nodes
  4. Swap values: Exchange the values of the two misplaced nodes

Time and Space Complexity

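| Complexity | Value | Explanation |
|------------|-------|-------------|
| Time       | O(n)  | Single inorder traversal; each node is visited once |
| Space      | O(h)  | Recursion depth of the traversal (h = tree height) |

The O(n) figure counts node visits. With the recursive generator used above, each yielded node is passed up through one generator per ancestor, so the yielding overhead grows with the node's depth.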
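
One way to make these bounds concrete is an inorder traversal that uses an explicit stack instead of recursion. The sketch below is an alternative to the article's recursive iterator, not a restatement of it: `Node` is a minimal stand-in for TreeNode and `inorder_iterative` is a hypothetical name. Each node is pushed and popped exactly once, and the stack never holds more than one root-to-leaf path.

```python
class Node:
    """Minimal stand-in for the article's TreeNode (assumption)."""
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def inorder_iterative(root):
    """Yield nodes in inorder using an explicit stack."""
    stack = []
    node = root
    while stack or node is not None:
        # Walk down the left spine, remembering each ancestor
        while node is not None:
            stack.append(node)
            node = node.left
        node = stack.pop()   # leftmost unvisited node
        yield node
        node = node.right    # then continue in its right subtree


tree = Node(3, Node(2), Node(8, Node(6), Node(9)))
print([n.val for n in inorder_iterative(tree)])  # [2, 3, 6, 8, 9]
```

Because the generator is not nested, each node is yielded directly to the caller, which keeps the total work linear in the number of nodes.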

Conclusion

This algorithm efficiently fixes a BST with two swapped nodes by performing one inorder traversal to identify violations and then swapping the misplaced values. The approach works for both adjacent and non-adjacent swapped nodes.

Updated on: 2026-03-25T13:05:51+05:30
