Program to remove all nodes with only one child from a binary tree in Python

Suppose we have the root of a binary tree, and we have to remove every node that has only one child. A node with only one child has either a left child or a right child, but not both. Each removed node is replaced by its child, so the remaining nodes stay connected.
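In code, "exactly one child" is an exclusive-or on the two child pointers. The sketch below is a minimal illustration that assumes the same TreeNode class as the implementation later in this article. The helper name has_one_child is hypothetical and is not part of the original solution.

```python
class TreeNode:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def has_one_child(node):
    # Hypothetical helper: True only when exactly one of the two
    # child pointers is None (an exclusive-or on "child is missing")
    return (node.left is None) != (node.right is None)


leaf = TreeNode(6)                            # no children
only_right = TreeNode(4, right=leaf)          # right child only
only_left = TreeNode(2, left=only_right)      # left child only
both = TreeNode(5, TreeNode(7), TreeNode(8))  # two children

print(has_one_child(leaf))        # False
print(has_one_child(only_right))  # True
print(has_one_child(only_left))   # True
print(has_one_child(both))        # False
```

Leaves and nodes with two children both return False, which is why the algorithm below keeps both kinds of node.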

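For example, if the input is the following tree:

[Figure: input tree. Root 1 has children 2 and 3; 2 has a left child 4, which has a right child 6; 3 has a right child 5, which has children 7 and 8. Nodes with one child (2, 3 and 4) are to be removed.]

then the output will be:

[Figure: output tree. Root 1 has children 6 and 5; 5 has children 7 and 8.]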

Algorithm

To solve this, we will follow these steps:

  • Define a method called solve() that takes the tree root

  • if root is null, then return root

  • if left of root is null and right of root is null, then return root (leaf node)

  • if left of root is null, then return solve(right of root) (only right child)

  • if right of root is null, then return solve(left of root) (only left child)

  • left of root := solve(left of root)

  • right of root := solve(right of root)

  • return root

Implementation

class TreeNode:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right

def print_tree(root):
    if root is not None:
        print_tree(root.left)
        print(root.data, end=', ')
        print_tree(root.right)

class Solution:
    def solve(self, root):
        if not root:
            return root
        
        # Leaf node - keep it
        if not root.left and not root.right:
            return root
        
        # Node with only right child - remove it, return right subtree
        if not root.left:
            return self.solve(root.right)
        
        # Node with only left child - remove it, return left subtree
        if not root.right:
            return self.solve(root.left)
        
        # Node with both children - keep it, process subtrees
        root.left = self.solve(root.left)
        root.right = self.solve(root.right)
        
        return root

# Create the binary tree
ob = Solution()
root = TreeNode(1)
root.left = TreeNode(2)
root.right = TreeNode(3)
root.left.left = TreeNode(4)
root.right.right = TreeNode(5)
root.left.left.right = TreeNode(6)
root.right.right.left = TreeNode(7)
root.right.right.right = TreeNode(8)

print("Original tree (inorder):", end=" ")
print_tree(root)
print()

res = ob.solve(root)
print("After removing nodes with one child:", end=" ")
print_tree(res)
Output

Original tree (inorder): 4, 6, 2, 1, 3, 7, 5, 8, 
After removing nodes with one child: 6, 1, 7, 5, 8,
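The inorder output does not show the shape of the resulting tree. The self-contained sketch below prints the result in level order, which matches the expected output tree (1, 6, 5, 7, 8 when read level by level), and checks that no node with exactly one child is left. The helpers build_example, level_order and has_single_child_node are hypothetical additions for illustration; Solution.solve is the same logic as in the implementation above.

```python
from collections import deque


class TreeNode:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


class Solution:
    def solve(self, root):
        if not root:
            return root
        if not root.left and not root.right:  # leaf: keep
            return root
        if not root.left:                      # only right child: skip node
            return self.solve(root.right)
        if not root.right:                     # only left child: skip node
            return self.solve(root.left)
        root.left = self.solve(root.left)      # two children: keep, recurse
        root.right = self.solve(root.right)
        return root


def build_example():
    # Same tree as in the article: 2 has left child 4, 4 has right child 6,
    # 3 has right child 5, and 5 has children 7 and 8
    return TreeNode(1,
                    TreeNode(2, TreeNode(4, right=TreeNode(6))),
                    TreeNode(3, right=TreeNode(5, TreeNode(7), TreeNode(8))))


def level_order(root):
    # Breadth-first traversal using a FIFO queue
    values = []
    queue = deque([root] if root is not None else [])
    while queue:
        node = queue.popleft()
        values.append(node.data)
        if node.left is not None:
            queue.append(node.left)
        if node.right is not None:
            queue.append(node.right)
    return values


def has_single_child_node(root):
    # True if any node in the tree has exactly one child
    if root is None:
        return False
    if (root.left is None) != (root.right is None):
        return True
    return has_single_child_node(root.left) or has_single_child_node(root.right)


res = Solution().solve(build_example())
print(level_order(res))            # [1, 6, 5, 7, 8]
print(has_single_child_node(res))  # False
```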

How It Works

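The algorithm decides what to do with each node before processing its subtrees, based on how many children the node has:

  • Leaf nodes (no children) are kept, since they do not have exactly one child

  • Nodes with only one child are removed: solve() returns the result of processing that child, so the child (or whatever replaces it) takes the removed node's place. A chain of single-child nodes therefore collapses to the first node below it that has zero or two children

  • Nodes with both children are kept, and their left and right subtrees are processed recursively and reattached. Because solve() never returns None for a non-empty subtree, such a node still has two children afterwards, so the decision never has to be revisited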
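To see the collapsing behaviour on its own, the sketch below runs the same solve() logic on two small hypothetical trees: a chain in which every internal node has one child, and a single-child node whose child has two children. The node values are arbitrary and chosen only for illustration.

```python
class TreeNode:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


class Solution:
    def solve(self, root):
        if not root:
            return root
        if not root.left and not root.right:  # leaf: keep
            return root
        if not root.left:                      # only right child: skip node
            return self.solve(root.right)
        if not root.right:                     # only left child: skip node
            return self.solve(root.left)
        root.left = self.solve(root.left)      # two children: keep, recurse
        root.right = self.solve(root.right)
        return root


# Hypothetical chain: 10 (left) -> 20 (right) -> 30 (left) -> 40, a leaf
chain = TreeNode(10, left=TreeNode(20, right=TreeNode(30, left=TreeNode(40))))
res1 = Solution().solve(chain)
print(res1.data, res1.left, res1.right)  # 40 None None

# Hypothetical tree: 1 has only a right child 2, and 2 has children 3 and 4
tree = TreeNode(1, right=TreeNode(2, TreeNode(3), TreeNode(4)))
res2 = Solution().solve(tree)
print(res2.data, res2.left.data, res2.right.data)  # 2 3 4
```

The whole chain reduces to its leaf 40, while in the second tree only node 1 is removed, because node 2 has two children.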

Example Walkthrough

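In the given example, nodes 2, 3 and 4 have only one child each:

  • Node 2 has only a left child (4), so it is removed and replaced by the result of processing 4

  • Node 4 has only a right child (6), so it is removed and replaced by 6, which is a leaf; as a result, 6 becomes the left child of 1

  • Node 3 has only a right child (5), so it is removed and replaced by 5, which has two children and is kept; 5 becomes the right child of 1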
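The walkthrough can be checked by instrumenting the recursion. The function below, solve_traced, is a hypothetical variant with the same branches as Solution.solve, plus a list that records the value of every node it removes, in the order the recursion reaches them.

```python
class TreeNode:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def solve_traced(root, removed):
    # Same decisions as Solution.solve; `removed` collects skipped node values
    if root is None:
        return root
    if root.left is None and root.right is None:  # leaf: keep
        return root
    if root.left is None:                         # only right child: remove
        removed.append(root.data)
        return solve_traced(root.right, removed)
    if root.right is None:                        # only left child: remove
        removed.append(root.data)
        return solve_traced(root.left, removed)
    root.left = solve_traced(root.left, removed)  # two children: keep
    root.right = solve_traced(root.right, removed)
    return root


root = TreeNode(1,
                TreeNode(2, TreeNode(4, right=TreeNode(6))),
                TreeNode(3, right=TreeNode(5, TreeNode(7), TreeNode(8))))
removed = []
res = solve_traced(root, removed)
print(removed)                        # [2, 4, 3]
print(res.left.data, res.right.data)  # 6 5
```

Node 2 is reached before node 4 because the decision for a node is made before the recursion descends into its child.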

Conclusion

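This recursive solution removes every node with only one child by replacing it with the processed result of its single child. solve() is called once per node, so the time complexity is O(n), where n is the number of nodes. The extra space is O(h) for the recursion stack, where h is the height of the tree; for a completely skewed tree, h can be as large as n.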
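The complexity claims can be checked empirically on the example tree. This sketch adds a hypothetical call counter and depth tracker to the same recursion (CountingSolution, count_nodes and height are illustrative names, not part of the original). It assumes height is counted in nodes along the longest root-to-leaf path, which is 4 for the example tree.

```python
class TreeNode:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


class CountingSolution:
    # Hypothetical instrumented copy of Solution.solve
    def __init__(self):
        self.calls = 0      # number of times solve() is entered
        self.max_depth = 0  # deepest recursion level reached

    def solve(self, root, depth=1):
        self.calls += 1
        self.max_depth = max(self.max_depth, depth)
        if not root:
            return root
        if not root.left and not root.right:
            return root
        if not root.left:
            return self.solve(root.right, depth + 1)
        if not root.right:
            return self.solve(root.left, depth + 1)
        root.left = self.solve(root.left, depth + 1)
        root.right = self.solve(root.right, depth + 1)
        return root


def count_nodes(root):
    return 0 if root is None else 1 + count_nodes(root.left) + count_nodes(root.right)


def height(root):
    # Height measured in nodes on the longest root-to-leaf path
    return 0 if root is None else 1 + max(height(root.left), height(root.right))


root = TreeNode(1,
                TreeNode(2, TreeNode(4, right=TreeNode(6))),
                TreeNode(3, right=TreeNode(5, TreeNode(7), TreeNode(8))))
n, h = count_nodes(root), height(root)  # measured before solve() mutates the tree
ob = CountingSolution()
ob.solve(root)
print(n, ob.calls)      # 8 8
print(h, ob.max_depth)  # 4 4
```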

Updated on: 2026-03-25T12:12:23+05:30
