# Program to Find Out the Special Nodes in a Tree in Python


Suppose we have a 2D list of values called 'tree', which represents an n-ary tree, and another list of values called 'color'. The tree is represented as an adjacency list, and its root is node 0 (described by tree[0]).

Each node i is described as follows −

• tree[i] lists its neighbors, i.e. its children and (for every node except the root) its parent.

• color[i] is its color.
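Because tree[i] mixes a node's children with its parent, any walk down the tree has to skip the parent. The sketch below, assuming root 0 and a symmetric (undirected) adjacency list, derives explicit child lists with a breadth-first search. `root_tree` is a hypothetical helper for illustration only.

```python
from collections import deque


def root_tree(tree, root=0):
    """Hypothetical helper: split an undirected adjacency list into child lists.

    tree[i] holds every neighbor of node i (children and parent), so each
    neighbor already visited from above is the parent and is not a child.
    """
    n = len(tree)
    children = [[] for _ in range(n)]
    parent = [-1] * n          # the root keeps -1: it has no parent
    seen = [False] * n
    seen[root] = True
    queue = deque([root])
    while queue:
        node = queue.popleft()
        for nb in tree[node]:
            if not seen[nb]:   # first time reached, so node is its parent
                seen[nb] = True
                parent[nb] = node
                children[node].append(nb)
                queue.append(nb)
    return children, parent


children, parent = root_tree([[1, 2], [0], [0, 3], [2]])
print(children)  # [[1, 2], [], [3], []]
print(parent)    # [-1, 0, 0, 2]
```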

We call a node N "special" if every node in the subtree rooted at N has a distinct color, that is, no color appears twice in that subtree. Given such a tree, we have to find the number of special nodes.

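So, if the input is like tree = [[1, 2], [0], [0, 3], [2]] and color = [1, 2, 1, 1], then the output will be 2. Nodes 1 and 3 are leaves, so they are special; the subtree of node 2 contains color 1 twice, and so does the whole tree rooted at node 0.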
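To make the definition concrete, here is a brute-force sketch that collects the colors of every subtree directly and checks them for duplicates. It is quadratic in the worst case, so it is only meant as a reference to test a faster method against; `subtree_colors` and `count_special_brute_force` are hypothetical names, not part of the solution below.

```python
def subtree_colors(tree, color, node, prev):
    """Hypothetical helper: list the colors of all nodes in node's subtree."""
    colors = [color[node]]
    for child in tree[node]:
        if child != prev:  # tree[node] also lists the parent; skip it
            colors.extend(subtree_colors(tree, color, child, node))
    return colors


def count_special_brute_force(tree, color, root=0):
    """Hypothetical reference: count special nodes straight from the definition."""
    n = len(tree)
    # Find every node's parent first, so each subtree walk knows what to skip.
    parent = [-1] * n
    visited = {root}
    stack = [root]
    while stack:
        node = stack.pop()
        for nb in tree[node]:
            if nb not in visited:
                visited.add(nb)
                parent[nb] = node
                stack.append(nb)
    count = 0
    for v in range(n):
        cols = subtree_colors(tree, color, v, parent[v])
        if len(cols) == len(set(cols)):  # no color repeats
            count += 1
    return count


tree = [[1, 2], [0], [0, 3], [2]]
color = [1, 2, 1, 1]
print(subtree_colors(tree, color, 2, 0))      # [1, 1]
print(count_special_brute_force(tree, color))  # 2
```

On the example, node 2's subtree yields [1, 1], so only the leaves 1 and 3 are counted.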

To solve this, we will follow these steps −

• result := 0

• dfs(0, -1), starting at the root, which has no parent

• return result

• Define a function check_intersection() . This will take colors, child_colors and return True if the two sets share a color

• if size of colors < size of child_colors, then for each c in colors, if c is in child_colors, return True

• otherwise, for each c in child_colors, if c is in colors, return True

• return False

• Define a function dfs() . This will take node, prev and return the set of colors in the subtree of node, or null if that subtree is not special

• colors := {color[node]}

• for each child in tree[node] other than prev (the parent), do

• child_colors := dfs(child, node)

• if colors and child_colors are both non-null, then

• if check_intersection(colors, child_colors) is True, then colors := null

• otherwise, merge the smaller set into the larger one: if size of colors < size of child_colors, then child_colors := child_colors union colors and colors := child_colors; otherwise colors := colors union child_colors

• otherwise (some subtree is already not special), colors := null

• after the loop, if colors is not null, then result := result + 1

• return colors
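The check_intersection and union steps above follow the small-to-large pattern: only the smaller set is scanned and copied, and the larger set is reused. Since a color only moves into a set at least twice the size of the one it came from, each color is moved O(log n) times, which keeps the total work roughly O(n log n) set operations. A minimal sketch of that single merge step, with `merge_if_disjoint` as a hypothetical name:

```python
def merge_if_disjoint(colors, child_colors):
    """Hypothetical helper: union two color sets, or return None on a shared color.

    The larger set is updated in place and returned; only the smaller set is
    iterated, both for the intersection test and for the union.
    """
    small, large = sorted((colors, child_colors), key=len)
    if any(c in large for c in small):  # a repeated color: not special
        return None
    large |= small  # in-place union, cost proportional to len(small)
    return large


a = {1}
b = {2, 3}
merged = merge_if_disjoint(a, b)
print(merged)                          # {1, 2, 3}
print(merged is b)                     # True: the larger set was reused
print(merge_if_disjoint({1}, {1, 2}))  # None
```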

## Example

Let us see the following implementation to get a better understanding −

```python
class Solution:
    def solve(self, tree, color):
        self.result = 0

        def dfs(node, prev):
            # Set of colors in node's subtree, or None if the subtree is not special.
            colors = {color[node]}
            for child in tree[node]:
                if child != prev:  # tree[node] also lists the parent; skip it
                    child_colors = dfs(child, node)
                    if colors and child_colors:
                        if self.check_intersection(colors, child_colors):
                            colors = None  # a color repeats, so node is not special
                        else:
                            # Small-to-large: merge the smaller set into the larger one.
                            if len(colors) < len(child_colors):
                                child_colors |= colors
                                colors = child_colors
                            else:
                                colors |= child_colors
                    else:
                        colors = None  # some child subtree is already not special
            if colors:
                self.result += 1
            return colors

        dfs(0, -1)
        return self.result

    def check_intersection(self, colors, child_colors):
        # Scan the smaller set and look each color up in the larger one.
        if len(colors) < len(child_colors):
            for c in colors:
                if c in child_colors:
                    return True
        else:
            for c in child_colors:
                if c in colors:
                    return True
        return False


ob = Solution()
print(ob.solve([[1, 2], [0], [0, 3], [2]], [1, 2, 1, 1]))
```

## Input

[[1, 2], [0], [0, 3], [2]], [1, 2, 1, 1]

## Output

2
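The recursive dfs in the Solution class can raise RecursionError on deep trees such as long paths, because CPython's default recursion limit is typically 1000 frames (see sys.getrecursionlimit()). The sketch below is one way to run the same small-to-large logic iteratively, assuming root 0; `count_special_nodes` is a hypothetical name, and this is an adaptation rather than the method shown above.

```python
def count_special_nodes(tree, color, root=0):
    """Hypothetical iterative variant of the small-to-large DFS."""
    n = len(tree)
    parent = [-1] * n
    seen = [False] * n
    seen[root] = True
    order = []    # preorder: every node appears after its parent
    stack = [root]
    while stack:
        node = stack.pop()
        order.append(node)
        for nb in tree[node]:
            if not seen[nb]:
                seen[nb] = True
                parent[nb] = node
                stack.append(nb)

    sets = [None] * n  # sets[v]: colors of v's subtree, or None if not special
    result = 0
    for node in reversed(order):  # children are processed before parents
        colors = {color[node]}
        for child in tree[node]:
            if child == parent[node]:
                continue
            child_colors = sets[child]
            sets[child] = None  # the child's set is merged or discarded now
            if colors is None or child_colors is None:
                colors = None
                continue
            small, large = sorted((colors, child_colors), key=len)
            if any(c in large for c in small):
                colors = None  # a color repeats in this subtree
            else:
                large |= small  # small-to-large union
                colors = large
        if colors is not None:
            result += 1
        sets[node] = colors
    return result


print(count_special_nodes([[1, 2], [0], [0, 3], [2]], [1, 2, 1, 1]))  # 2

# A 5,000-node path would exceed the default recursion limit in dfs().
n = 5000
path = [[j for j in (i - 1, i + 1) if 0 <= j < n] for i in range(n)]
print(count_special_nodes(path, list(range(n))))  # 5000
```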