Program to Find Out the Special Nodes in a Tree in Python

In a rooted tree, a node is special if every node in its subtree has a distinct color. Given an n-ary tree as an adjacency list, together with the color of each node, we need to count how many nodes are special. The tree is rooted at node 0.

For each node i:

  • tree[i] lists the neighbors of node i: its children and, for every node except the root, its parent

  • color[i] represents its color value

A node is "special" if all nodes in its subtree (including itself) have distinct colors.
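The definition can be checked directly, one subtree at a time. The sketch below is a hypothetical brute-force helper, `count_special_brute` (not part of the original solution). It assumes the tree is rooted at node 0, as the implementation does with `dfs(0, -1)`, collects the colors of every subtree, and compares the list length with the size of its set. It runs in O(n²) time, so it is only meant for cross-checking a faster solution on small inputs.

```python
def count_special_brute(tree, color, root=0):
    """Count special nodes straight from the definition (O(n^2) time)."""
    n = len(tree)
    # Root the undirected adjacency list by recording each node's parent
    parent = [-1] * n
    visited = [False] * n
    visited[root] = True
    stack = [root]
    while stack:
        node = stack.pop()
        for nb in tree[node]:
            if not visited[nb]:
                visited[nb] = True
                parent[nb] = node
                stack.append(nb)

    def subtree_colors(start):
        # Colors of `start` and all of its descendants
        found, todo = [], [start]
        while todo:
            cur = todo.pop()
            found.append(color[cur])
            todo.extend(nb for nb in tree[cur] if nb != parent[cur])
        return found

    count = 0
    for node in range(n):
        cols = subtree_colors(node)
        if len(cols) == len(set(cols)):  # no color repeats in this subtree
            count += 1
    return count


print(count_special_brute([[1, 2], [0], [0, 3], [2]], [1, 2, 1, 1]))  # 2
```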

Example

Consider this tree structure:

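  • Node 0 (color 1) is the root, with children 1 and 2

  • Node 1 (color 2) is a leaf

  • Node 2 (color 1) has one child, node 3

  • Node 3 (color 1) is a leaf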

Tree structure: tree = [[1,2], [0], [0,3], [2]], colors = [1, 2, 1, 1]
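Because each adjacency list mixes a node's parent with its children, the parent has to be filtered out once a root is chosen. The following sketch, using a hypothetical helper `children_lists` and assuming node 0 is the root, turns the example's adjacency list into per-node child lists with a breadth-first search:

```python
from collections import deque


def children_lists(tree, root=0):
    """Return children[i] for every node when the undirected tree is rooted at `root`."""
    n = len(tree)
    children = [[] for _ in range(n)]
    visited = [False] * n
    visited[root] = True
    queue = deque([root])
    while queue:
        node = queue.popleft()
        for nb in tree[node]:
            if not visited[nb]:  # an unvisited neighbor must be a child
                visited[nb] = True
                children[node].append(nb)
                queue.append(nb)
    return children


tree = [[1, 2], [0], [0, 3], [2]]
colors = [1, 2, 1, 1]
for node, kids in enumerate(children_lists(tree)):
    print(f"node {node} (color {colors[node]}): children {kids}")
```

For the example this reports children [1, 2] for node 0, [3] for node 2, and no children for nodes 1 and 3, matching the structure described above.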

Algorithm

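We use a depth-first search (DFS) from node 0 that returns, for every node, the set of colors in its subtree, or None if some color repeats:

  1. For each node, start a color set containing the node's own color

  2. Merge in each child's color set; if a child's subtree already contains a duplicate, or its colors overlap with the colors collected so far, the node isn't special

  3. If no color repeats after all children are merged, increment the special node count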
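Steps 2 and 3 boil down to combining the node's color set with each child's set. The sketch below isolates that step in a hypothetical helper, `merge_or_none`. It returns None ("a duplicate exists") if either input has already failed or the two sets share a color; otherwise it merges the smaller set into the larger one, the same small-to-large idea the implementation uses. Note that it mutates and returns one of its inputs.

```python
def merge_or_none(colors, child_colors):
    """Combine two subtree color sets, or return None if a color would repeat.

    None means "this subtree already contains a duplicate color".
    """
    if colors is None or child_colors is None:
        return None
    # sorted() with key=len compares only the lengths, never the sets themselves
    small, large = sorted((colors, child_colors), key=len)
    # Only the smaller set is scanned; each lookup in the larger set is O(1) on average
    if any(c in large for c in small):
        return None
    large |= small  # merge the smaller set into the larger one (mutates `large`)
    return large


print(merge_or_none({1}, {2}))      # {1, 2}
print(merge_or_none({1}, {1}))      # None
print(merge_or_none(None, {3, 4}))  # None
```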

Implementation

class Solution:
    def solve(self, tree, color):
        self.result = 0

        def check_intersection(colors, child_colors):
            # Scan the smaller set and look each color up in the larger one
            if len(colors) < len(child_colors):
                small, large = colors, child_colors
            else:
                small, large = child_colors, colors
            return any(c in large for c in small)

        def dfs(node, prev):
            # Colors in this node's subtree, or None once a repeated color is found
            colors = {color[node]}

            for child in tree[node]:
                if child == prev:
                    continue  # skip the edge back to the parent
                # Always recurse, so special nodes deeper in the tree are still counted
                child_colors = dfs(child, node)
                if colors is None or child_colors is None:
                    colors = None  # a duplicate already exists in this subtree
                elif check_intersection(colors, child_colors):
                    colors = None  # the child's subtree repeats a color seen here
                elif len(colors) < len(child_colors):
                    # Small-to-large: merge the smaller set into the larger one
                    child_colors |= colors
                    colors = child_colors
                else:
                    colors |= child_colors

            if colors is not None:
                self.result += 1
            return colors

        dfs(0, -1)
        return self.result

# Test the solution
tree = [
    [1,2],
    [0],
    [0,3],
    [2]
]
colors = [1, 2, 1, 1]

solution = Solution()
result = solution.solve(tree, colors)
print(f"Number of special nodes: {result}")

Output

Number of special nodes: 2
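The recursive dfs can exceed CPython's recursion limit (typically 1000 frames, see sys.getrecursionlimit()) on deep trees such as a long path. Below is a sketch of an iterative variant, assuming the tree is rooted at node 0. It uses the same small-to-large merging but processes nodes in reversed preorder from an explicit stack, so every child is finished before its parent. The name `count_special_iterative` is hypothetical.

```python
def count_special_iterative(tree, color, root=0):
    """Count special nodes without recursion, using an explicit stack."""
    n = len(tree)
    parent = [-1] * n
    order = []  # preorder: every descendant of a node is appended after it
    visited = [False] * n
    visited[root] = True
    stack = [root]
    while stack:
        node = stack.pop()
        order.append(node)
        for nb in tree[node]:
            if not visited[nb]:
                visited[nb] = True
                parent[nb] = node
                stack.append(nb)

    # sets[v]: colors in v's subtree so far, or None if a color repeats
    sets = [{color[v]} for v in range(n)]
    result = 0
    # Reversed preorder finishes every child before its parent
    for node in reversed(order):
        if sets[node] is not None:
            result += 1
        p = parent[node]
        if p == -1:
            continue  # the root has no parent to merge into
        a, b = sets[p], sets[node]
        if a is None or b is None:
            sets[p] = None
        else:
            small, large = (a, b) if len(a) < len(b) else (b, a)
            if any(c in large for c in small):
                sets[p] = None
            else:
                large |= small  # small-to-large merge
                sets[p] = large
    return result


# A 100,000-node path is far deeper than the default recursion limit
n = 100_000
path = [[] for _ in range(n)]
for i in range(n - 1):
    path[i].append(i + 1)
    path[i + 1].append(i)
print(count_special_iterative(path, list(range(n))))  # 100000
print(count_special_iterative([[1, 2], [0], [0, 3], [2]], [1, 2, 1, 1]))  # 2
```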

How It Works

The DFS starts at node 0 but settles each node only after all of its children, so the nodes are decided bottom-up:

  • Node 1: its subtree contains only color 2 → special

  • Node 3: its subtree contains only color 1 → special

  • Node 2: its subtree (nodes 2 and 3) contains color 1 twice → not special (duplicate)

  • Node 0: its subtree (the whole tree) contains colors 1, 2, 1, 1 → not special (color 1 repeats)
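To reproduce this walkthrough, the sketch below records, for every node, how often each color occurs in its subtree and whether the node is special. The function `trace_subtrees` is a hypothetical illustration, not part of the original Solution class. It counts colors with collections.Counter instead of merging sets, which is simpler to read but slower, so it is only meant for small examples. For the example tree it flags nodes 1 and 3 as special, matching the bullets above.

```python
from collections import Counter


def trace_subtrees(tree, color, root=0):
    """Return {node: (Counter of subtree colors, is_special)} for a tree rooted at `root`."""
    info = {}

    def dfs(node, parent):
        counts = Counter([color[node]])
        for child in tree[node]:
            if child != parent:
                counts += dfs(child, node)  # add the child's subtree color counts
        # Special means every color occurs exactly once in the subtree
        info[node] = (counts, all(k == 1 for k in counts.values()))
        return counts

    dfs(root, -1)
    return info


tree = [[1, 2], [0], [0, 3], [2]]
colors = [1, 2, 1, 1]
for node, (counts, special) in sorted(trace_subtrees(tree, colors).items()):
    print(node, dict(counts), "special" if special else "not special")
```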

Conclusion

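This solution uses DFS to traverse the tree and hash sets to detect duplicate colors. Because the smaller color set is always merged into the larger one (small-to-large merging), each color is copied at most O(log n) times. The expected running time is therefore O(n log n), where n is the number of nodes in the tree, rather than O(n).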
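With small-to-large merging, an element is copied only when its set is merged into one at least as large, so the set holding it at least doubles with each copy. It can therefore be copied at most log2(n) times, and the total merge work is bounded by n·log2(n) insertions. The rough sketch below counts those copies on a random tree where every color is distinct (the worst case, since no subtree is abandoned early). The helpers `random_tree` and `count_moves` are illustrative assumptions, not part of the original solution.

```python
import math
import random


def random_tree(n, seed=0):
    """Build an undirected adjacency list for a random tree on n nodes."""
    rng = random.Random(seed)
    tree = [[] for _ in range(n)]
    for v in range(1, n):
        p = rng.randrange(v)  # attach v to a random earlier node
        tree[v].append(p)
        tree[p].append(v)
    return tree


def count_moves(tree, root=0):
    """Merge {v} sets bottom-up, small into large; return how many elements were copied."""
    n = len(tree)
    parent = [-1] * n
    order, stack = [], [root]
    seen = [False] * n
    seen[root] = True
    while stack:
        node = stack.pop()
        order.append(node)
        for nb in tree[node]:
            if not seen[nb]:
                seen[nb] = True
                parent[nb] = node
                stack.append(nb)
    sets = [{v} for v in range(n)]  # all colors distinct: color[v] = v
    moves = 0
    for node in reversed(order):  # children are finished before parents
        p = parent[node]
        if p == -1:
            continue
        a, b = sets[p], sets[node]
        small, large = (a, b) if len(a) < len(b) else (b, a)
        moves += len(small)  # every element of the smaller set is copied
        large |= small
        sets[p] = large
    return moves


n = 50_000
moves = count_moves(random_tree(n))
print(moves, "copies; bound n*log2(n) =", round(n * math.log2(n)))
```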

Updated on: 2026-03-25T13:55:18+05:30
