Program to find maximum sum of non-adjacent nodes of a tree in Python
Finding the maximum sum of non-adjacent nodes in a binary tree is a classic dynamic programming problem (it is the maximum-weight independent set problem restricted to trees). We need to select nodes so that no parent-child pair is included, while maximizing the total sum.
Problem Understanding
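Given a binary tree, we want the maximum sum of a set of nodes in which no two selected nodes are adjacent, that is, no selected node is the parent or child of another selected node. Siblings, or a node and its grandchild, may still be selected together. For each node, we have two choices: include it or exclude it.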
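To pin down exactly what is being maximized, the brute-force sketch below tries every subset of nodes and keeps the best subset that contains no parent-child pair. It is exponential in the number of nodes, so it only serves as a reference for tiny trees. The function name `brute_force_max_sum` is hypothetical, and the sketch assumes that selecting no nodes at all (sum 0) is allowed, which matches the `(0, 0)` base case of the solution.

```python
from itertools import combinations


class TreeNode:
    def __init__(self, data, left=None, right=None):
        self.val = data
        self.left = left
        self.right = right


def brute_force_max_sum(root):
    """Hypothetical reference solver: check every subset of nodes (exponential time)."""
    # Collect every node with an explicit stack
    nodes = []
    stack = [root] if root else []
    while stack:
        node = stack.pop()
        nodes.append(node)
        stack.extend(child for child in (node.left, node.right) if child)

    best = 0  # assumption: the empty selection (sum 0) is allowed
    for size in range(1, len(nodes) + 1):
        for chosen in combinations(nodes, size):
            picked = set(chosen)
            # Reject the subset if any chosen node also has a chosen child
            if any(child in picked for n in chosen for child in (n.left, n.right)):
                continue
            best = max(best, sum(n.val for n in chosen))
    return best


# The article's example tree: 1 -> (2 -> (4, 3), 10)
root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(3)), TreeNode(10))
print(brute_force_max_sum(root))  # 17, from nodes 4, 3 and 10
```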
Algorithm Approach
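For each node, we calculate two values:
- Include current node: the node's value plus the "exclude" values of both children, because neither child may be selected alongside it
- Exclude current node: for each child, the larger of its "include" and "exclude" values, summed over both children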
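The two rules above can be written as a recurrence. Here inc(v) and exc(v) are notation introduced for this sketch (the best sums for the subtree rooted at v with v included or excluded), and ℓ and r denote the left and right children of v; a missing child contributes 0 to both values:

```latex
\begin{aligned}
\mathrm{inc}(v) &= v.\mathrm{val} + \mathrm{exc}(\ell) + \mathrm{exc}(r) \\
\mathrm{exc}(v) &= \max\bigl(\mathrm{inc}(\ell), \mathrm{exc}(\ell)\bigr) + \max\bigl(\mathrm{inc}(r), \mathrm{exc}(r)\bigr) \\
\mathrm{inc}(\varnothing) &= \mathrm{exc}(\varnothing) = 0 \\
\text{answer} &= \max\bigl(\mathrm{inc}(\mathit{root}), \mathrm{exc}(\mathit{root})\bigr)
\end{aligned}
```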
Implementation
```python
class TreeNode:
    def __init__(self, data, left=None, right=None):
        self.val = data
        self.left = left
        self.right = right


def f(node):
    """Return (include, exclude) sums for the subtree rooted at node."""
    if not node:
        return 0, 0

    # Get results from left and right subtrees
    left_include, left_exclude = f(node.left)
    right_include, right_exclude = f(node.right)

    # Include current node: add its value + excluded children
    include_current = node.val + left_exclude + right_exclude

    # Exclude current node: take maximum from both children
    exclude_current = max(left_include, left_exclude) + max(right_include, right_exclude)

    return include_current, exclude_current


class Solution:
    def solve(self, root):
        include_root, exclude_root = f(root)
        return max(include_root, exclude_root)


# Create the tree:
#        1
#       / \
#      2   10
#     / \
#    4   3
ob = Solution()
root = TreeNode(1)
root.left = TreeNode(2)
root.right = TreeNode(10)
root.left.left = TreeNode(4)
root.left.right = TreeNode(3)
print(ob.solve(root))
```
Output
```
17
```
The best choice is nodes 4, 3 and 10 (4 + 3 + 10 = 17).
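Because `f()` recurses once per tree level, a very deep (for example, list-shaped) tree can exceed CPython's default recursion limit, which is typically 1000 frames. The sketch below is an assumed alternative, not part of the original solution: it computes the same include/exclude pairs bottom-up with an explicit stack. The function name `max_non_adjacent_sum_iterative` is hypothetical.

```python
class TreeNode:
    def __init__(self, data, left=None, right=None):
        self.val = data
        self.left = left
        self.right = right


def max_non_adjacent_sum_iterative(root):
    """Same include/exclude DP as f(), computed bottom-up without recursion."""
    if root is None:
        return 0

    # First pass: list nodes so that every parent appears before its children
    order = []
    stack = [root]
    while stack:
        node = stack.pop()
        order.append(node)
        if node.left:
            stack.append(node.left)
        if node.right:
            stack.append(node.right)

    # Second pass: walk the list backwards so children are processed before parents
    dp = {None: (0, 0)}  # node -> (include, exclude); None stands for a missing child
    for node in reversed(order):
        left_include, left_exclude = dp[node.left]
        right_include, right_exclude = dp[node.right]
        include_current = node.val + left_exclude + right_exclude
        exclude_current = max(left_include, left_exclude) + max(right_include, right_exclude)
        dp[node] = (include_current, exclude_current)

    return max(dp[root])


root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(3)), TreeNode(10))
print(max_non_adjacent_sum_iterative(root))  # 17

# A left-leaning chain 5000 nodes deep, deeper than the default recursion limit
chain = None
for _ in range(5000):
    chain = TreeNode(1, left=chain)
print(max_non_adjacent_sum_iterative(chain))  # 2500 (every other node)
```

Storing the pairs in a dictionary costs O(n) extra memory, in exchange for not depending on the interpreter's call stack.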
How It Works
The recursive function returns a tuple (include, exclude) for each node:
- include: Maximum sum when current node is included
- exclude: Maximum sum when current node is excluded
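The base case returns (0, 0) for a null node, since an empty subtree contributes nothing. Each node's pair is built from its children's pairs: the parent's exclude value takes the better of each child's two states, while the parent's include value is forced to use each child's exclude state. At the root, the answer is the larger of the two values.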
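The pairs also reveal which nodes make up the best sum. A minimal sketch, assuming ties are broken by excluding the node: after computing every pair, walk down from the root and take a node only if its parent was not taken and its include value beats its exclude value. The helper names `best_selection`, `pairs` and `collect` are hypothetical.

```python
class TreeNode:
    def __init__(self, data, left=None, right=None):
        self.val = data
        self.left = left
        self.right = right


def best_selection(root):
    """Return (best_sum, chosen_values) via the include/exclude DP plus a traceback."""
    memo = {}

    def pairs(node):
        # Same recurrence as the article's f(), with each result cached for the traceback
        if node is None:
            return 0, 0
        left_include, left_exclude = pairs(node.left)
        right_include, right_exclude = pairs(node.right)
        result = (node.val + left_exclude + right_exclude,
                  max(left_include, left_exclude) + max(right_include, right_exclude))
        memo[node] = result
        return result

    def collect(node, parent_taken, chosen):
        if node is None:
            return
        include, exclude = memo[node]
        # Take the node only if its parent was skipped and including it is strictly better
        take = not parent_taken and include > exclude
        if take:
            chosen.append(node.val)
        collect(node.left, take, chosen)
        collect(node.right, take, chosen)

    best = max(pairs(root))
    chosen = []
    collect(root, False, chosen)
    return best, chosen


root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(3)), TreeNode(10))
print(best_selection(root))  # (17, [4, 3, 10])
```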
Example Walkthrough
```python
# Trace the example tree step by step
class TreeNode:
    def __init__(self, data, left=None, right=None):
        self.val = data
        self.left = left
        self.right = right


def f_with_trace(node, level=0):
    """Same recurrence as f(), printing each step indented by recursion depth."""
    indent = "  " * level
    if not node:
        print(f"{indent}Null node: returning (0, 0)")
        return 0, 0

    print(f"{indent}Node {node.val}:")
    left_include, left_exclude = f_with_trace(node.left, level + 1)
    right_include, right_exclude = f_with_trace(node.right, level + 1)

    include_current = node.val + left_exclude + right_exclude
    exclude_current = max(left_include, left_exclude) + max(right_include, right_exclude)

    print(f"{indent}  Include {node.val}: {node.val} + {left_exclude} + {right_exclude} = {include_current}")
    print(f"{indent}  Exclude {node.val}: {max(left_include, left_exclude)} + {max(right_include, right_exclude)} = {exclude_current}")
    return include_current, exclude_current


# Create and solve
root = TreeNode(1)
root.left = TreeNode(2)
root.right = TreeNode(10)
root.left.left = TreeNode(4)
root.left.right = TreeNode(3)
result = f_with_trace(root)
print(f"\nFinal result: max({result[0]}, {result[1]}) = {max(result)}")
```
Output
```
Node 1:
  Node 2:
    Node 4:
      Null node: returning (0, 0)
      Null node: returning (0, 0)
      Include 4: 4 + 0 + 0 = 4
      Exclude 4: 0 + 0 = 0
    Node 3:
      Null node: returning (0, 0)
      Null node: returning (0, 0)
      Include 3: 3 + 0 + 0 = 3
      Exclude 3: 0 + 0 = 0
    Include 2: 2 + 0 + 0 = 2
    Exclude 2: 4 + 3 = 7
  Node 10:
    Null node: returning (0, 0)
    Null node: returning (0, 0)
    Include 10: 10 + 0 + 0 = 10
    Exclude 10: 0 + 0 = 0
  Include 1: 1 + 7 + 0 = 8
  Exclude 1: 7 + 10 = 17

Final result: max(8, 17) = 17
```
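To gain extra confidence in the recurrence beyond this one example, the sketch below compares `f()` against exhaustive search on small random trees. The helpers `random_tree` and `brute_force` are hypothetical test utilities, and the value range (which includes some negative values) is an arbitrary assumption; both methods treat selecting no nodes as a valid answer of 0.

```python
import random
from itertools import combinations


class TreeNode:
    def __init__(self, data, left=None, right=None):
        self.val = data
        self.left = left
        self.right = right


def f(node):
    # The article's recurrence: returns (include, exclude) for this subtree
    if not node:
        return 0, 0
    left_include, left_exclude = f(node.left)
    right_include, right_exclude = f(node.right)
    include_current = node.val + left_exclude + right_exclude
    exclude_current = max(left_include, left_exclude) + max(right_include, right_exclude)
    return include_current, exclude_current


def random_tree(n, rng):
    """Hypothetical helper: build a random binary tree with n >= 1 nodes."""
    nodes = [TreeNode(rng.randint(-5, 20))]
    while len(nodes) < n:
        parent = rng.choice(nodes)
        side = rng.choice(("left", "right"))
        if getattr(parent, side) is None:  # attach only to a free child slot
            child = TreeNode(rng.randint(-5, 20))
            setattr(parent, side, child)
            nodes.append(child)
    return nodes[0], nodes


def brute_force(nodes):
    """Hypothetical helper: best sum over all subsets with no parent-child pair."""
    best = 0  # the empty selection
    for size in range(1, len(nodes) + 1):
        for chosen in combinations(nodes, size):
            picked = set(chosen)
            if any(child in picked for n in chosen for child in (n.left, n.right)):
                continue
            best = max(best, sum(n.val for n in chosen))
    return best


rng = random.Random(0)  # fixed seed so the run is reproducible
for trial in range(200):
    root, nodes = random_tree(rng.randint(1, 10), rng)
    assert max(f(root)) == brute_force(nodes), f"mismatch on trial {trial}"
print("f() matched brute force on 200 random trees")
```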
Conclusion
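This dynamic programming solution visits each node exactly once and does constant work per node, so it runs in O(n) time for a tree with n nodes; the recursion uses O(h) stack space, where h is the height of the tree. The key insight is to keep both an "include" and an "exclude" value for every node, so that each parent can combine its children's results optimally.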
