Delete Nodes And Return Forest in Python

When working with binary trees, we sometimes need to delete specific nodes and return the remaining forest of trees. This problem requires us to remove nodes with values in a given list and return the roots of all remaining subtrees.

Problem Understanding

Given a binary tree and a list of values to delete, we need to:

  • Remove all nodes with values in the to_delete list
  • Return the roots of all remaining subtrees (the forest)
  • When a node is deleted, each of its children that is not itself deleted becomes the root of a new tree

Example (figure): the original tree [1, 2, 3, 4, 5, 6, 7] with nodes 3 and 5 marked for deletion, and the resulting forest with roots [1, 6, 7].

Algorithm

We use a recursive approach with the following steps:

  1. Convert to_delete to a set for O(1) lookup
  2. Use a helper function that tracks if the current node should be a root
  3. If a node is deleted, its children become potential new roots
  4. Return None for deleted nodes, otherwise return the node

Implementation

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

class Solution:
    def delNodes(self, root, to_delete):
        """
        Delete nodes and return forest
        :param root: TreeNode - root of binary tree
        :param to_delete: List[int] - values to delete
        :return: List[TreeNode] - roots of remaining trees
        """
        to_delete = set(to_delete)
        self.result = []
        self.solve(root, to_delete, True)
        return self.result
    
    def solve(self, node, to_delete, is_root):
        if not node:
            return None
        
        # Check if current node should be deleted
        should_delete = node.val in to_delete
        
        # If node is not deleted and is a root, add to result
        if not should_delete and is_root:
            self.result.append(node)
        
        # Recursively process children
        # If current node is deleted, children become new roots
        node.left = self.solve(node.left, to_delete, should_delete)
        node.right = self.solve(node.right, to_delete, should_delete)
        
        # Return None if node is deleted, otherwise return node
        return None if should_delete else node

# Example usage
def build_tree():
    """Build the example tree: [1,2,3,4,5,6,7]"""
    root = TreeNode(1)
    root.left = TreeNode(2)
    root.right = TreeNode(3)
    root.left.left = TreeNode(4)
    root.left.right = TreeNode(5)
    root.right.left = TreeNode(6)
    root.right.right = TreeNode(7)
    return root

def print_forest(forest):
    """Helper function to print the forest"""
    def inorder(node):
        if not node:
            return []
        return inorder(node.left) + [node.val] + inorder(node.right)
    
    for i, tree in enumerate(forest):
        print(f"Tree {i+1}: {inorder(tree)}")

# Test the solution
solution = Solution()
root = build_tree()
to_delete = [3, 5]
forest = solution.delNodes(root, to_delete)
print_forest(forest)

Output

Tree 1: [4, 2, 1]
Tree 2: [6]
Tree 3: [7]
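
One way to sanity-check a result like this is to confirm that no deleted value survives and that no other value was lost. The sketch below is only illustrative: `collect_values` and `check_forest` are hypothetical helper names, not part of the solution, and the check assumes that node values are unique. It runs against a hand-built copy of the expected forest so that it stays self-contained.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def collect_values(node):
    """Return every value in the tree rooted at node (pre-order)."""
    if node is None:
        return []
    return [node.val] + collect_values(node.left) + collect_values(node.right)


def check_forest(forest, original_values, to_delete):
    """Hypothetical checker; assumes all node values are unique."""
    banned = set(to_delete)
    remaining = [v for tree in forest for v in collect_values(tree)]
    # No deleted value may still appear anywhere in the forest.
    no_deleted_left = banned.isdisjoint(remaining)
    # Every value that was not deleted must still be present exactly once.
    expected = sorted(v for v in original_values if v not in banned)
    nothing_lost = sorted(remaining) == expected
    return no_deleted_left and nothing_lost


# Hand-built copy of the expected forest for the example: roots 1, 6 and 7.
forest = [TreeNode(1, TreeNode(2, TreeNode(4))), TreeNode(6), TreeNode(7)]
print(check_forest(forest, range(1, 8), [3, 5]))
```

The same call can be pointed at the list returned by `delNodes` instead of the hand-built forest.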

How It Works

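For the example tree with to_delete = [3, 5], the recursion visits the nodes depth-first, left subtree before right:

Step | Action | Result
1 | Visit node 1 (root) | Not deleted and is_root is True, so it is added to the result
2 | Visit nodes 2 and 4 | Not deleted; they stay attached to node 1
3 | Visit node 5 (right child of 2) | Deleted; it has no children, and node 2's right pointer becomes None
4 | Visit node 3 (right child of 1) | Deleted; its children 6 and 7 are visited with is_root set to True
5 | Visit nodes 6 and 7 | Not deleted; added to the result as new roots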
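
To see the walkthrough in action, the recursion can be instrumented to record each visit. The following is a minimal sketch, not the article's implementation: `del_nodes_traced` is a hypothetical function that mirrors the logic of `solve` and additionally logs a `(value, is_root, deleted)` tuple for every node it reaches.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def del_nodes_traced(root, to_delete):
    """Same logic as the recursive solution, plus a log of every visit."""
    to_delete = set(to_delete)
    forest, trace = [], []

    def solve(node, is_root):
        if node is None:
            return None
        deleted = node.val in to_delete
        trace.append((node.val, is_root, deleted))
        if is_root and not deleted:
            forest.append(node)
        # Children of a deleted node are visited as potential roots.
        node.left = solve(node.left, deleted)
        node.right = solve(node.right, deleted)
        return None if deleted else node

    solve(root, True)
    return forest, trace


root = TreeNode(1,
                TreeNode(2, TreeNode(4), TreeNode(5)),
                TreeNode(3, TreeNode(6), TreeNode(7)))
forest, trace = del_nodes_traced(root, [3, 5])
for val, is_root, deleted in trace:
    print(f"node {val}: is_root={is_root}, deleted={deleted}")
```

Because the left subtree is processed first, node 5 is reached before node 3 in this trace, and nodes 6 and 7 are the only non-root nodes visited with `is_root=True`.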

Time and Space Complexity

  • Time Complexity: O(n), where n is the number of nodes
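  • Space Complexity: O(h + d), not counting the returned list, where h is the tree height (the recursion stack) and d is the size of the to_delete set; for a completely skewed tree, h can be as large as n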
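
Since the recursion depth grows with the tree height, a very deep (skewed) tree can exceed Python's default recursion limit. As a hedged alternative, the sketch below swaps the recursion for an explicit stack. `del_nodes_iterative` is a hypothetical name, and this variant is an assumption about how one might restructure the solution rather than the method used above. It is still O(n) time, with the stack replacing the call stack.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def del_nodes_iterative(root, to_delete):
    """Explicit-stack variant; returns the roots of the remaining trees."""
    to_delete = set(to_delete)
    forest = []
    if root is None:
        return forest
    if root.val not in to_delete:
        forest.append(root)

    stack = [root]
    while stack:
        node = stack.pop()
        deleted = node.val in to_delete
        for side in ("left", "right"):
            child = getattr(node, side)
            if child is None:
                continue
            stack.append(child)
            if child.val in to_delete:
                # Detach a deleted child from its parent.
                setattr(node, side, None)
            elif deleted:
                # A surviving child of a deleted node starts a new tree.
                forest.append(child)
        if deleted:
            # Drop the deleted node's own links to its children.
            node.left = node.right = None
    return forest


root = TreeNode(1,
                TreeNode(2, TreeNode(4), TreeNode(5)),
                TreeNode(3, TreeNode(6), TreeNode(7)))
print([tree.val for tree in del_nodes_iterative(root, [3, 5])])
```

The order of roots in the returned list depends on the traversal order, so it may differ from the recursive version on other inputs even though the set of trees is the same.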

Conclusion

This solution efficiently deletes specified nodes and returns the forest using a recursive approach. The key insight is tracking when nodes become new roots after their parents are deleted.

Updated on: 2026-03-25T08:17:16+05:30
