Program to Find Out the Special Nodes in a Tree in Python


Suppose we have a 2D list of values called 'tree', which represents an n-ary tree, and another list of values called 'color'. The tree is represented as an adjacency list, and its root is node 0.

The characteristics of the i-th node are −

  • tree[i] lists its neighbors, that is, its children and its parent (the root has no parent).

  • color[i] is its color.
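Because each tree[i] mixes the parent in with the children, it can help to see how the two are told apart. The following is a minimal sketch, assuming the adjacency list is undirected (every edge appears in both endpoints' lists, as in the example below). The helper name `split_adjacency` is hypothetical and not part of the original solution. It walks outward from root 0 and records, for each node, its parent and its real children.

```python
from collections import deque


def split_adjacency(tree, root=0):
    """Split an undirected adjacency list into parent and children lists.

    parent[root] is -1; children[i] never contains i's parent.
    """
    n = len(tree)
    parent = [-1] * n
    children = [[] for _ in range(n)]
    visited = [False] * n
    visited[root] = True
    queue = deque([root])
    while queue:
        node = queue.popleft()
        for nxt in tree[node]:
            if not visited[nxt]:
                # The first time we reach nxt, we reach it from its parent.
                visited[nxt] = True
                parent[nxt] = node
                children[node].append(nxt)
                queue.append(nxt)
    return parent, children


parent, children = split_adjacency([[1, 2], [0], [0, 3], [2]])
print(parent)    # [-1, 0, 0, 2]
print(children)  # [[1, 2], [], [3], []]
```

The solution below never builds these lists explicitly. It passes the previous node (`prev`) down the recursion and skips that neighbor instead, which has the same effect.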

We call a node N "special" if no two nodes in the subtree rooted at N have the same color. Given this tree, we have to find the number of special nodes.

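So, if the input is like tree = [
   [1,2],
   [0],
   [0,3],
   [2]
]

and color = [1, 2, 1, 1], then the output will be 2, because only nodes 1 and 3 are special. The subtree of node 2 contains nodes 2 and 3, which both have color 1, and the subtree of node 0 contains every node.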
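To make the definition concrete before looking at the efficient approach, here is a naive sketch that checks it directly. For every node, it collects the colors of the whole subtree and tests whether any color repeats. The function name `count_special_naive` is hypothetical. This version is O(n^2) in the worst case and is meant only for checking answers on small inputs.

```python
def count_special_naive(tree, color):
    """Count special nodes by listing every subtree's colors directly."""
    n = len(tree)

    def subtree_colors(node, prev):
        # Colors of node and all of its descendants (prev is node's parent).
        out = [color[node]]
        for child in tree[node]:
            if child != prev:
                out.extend(subtree_colors(child, node))
        return out

    # Record each node's parent so we know which neighbor to skip.
    parent = [-1] * n
    stack = [0]
    seen = {0}
    while stack:
        node = stack.pop()
        for nxt in tree[node]:
            if nxt not in seen:
                seen.add(nxt)
                parent[nxt] = node
                stack.append(nxt)

    count = 0
    for node in range(n):
        cols = subtree_colors(node, parent[node])
        if len(cols) == len(set(cols)):  # no color appears twice
            count += 1
    return count


print(count_special_naive([[1, 2], [0], [0, 3], [2]], [1, 2, 1, 1]))  # 2
```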

To solve this, we will follow these steps −

  • result := 0

  • dfs(0, -1)

  • return result

  • Define a function check_intersection(). This will take colors and child_colors

    • if length of (colors) < length of (child_colors), then

      • for each c in colors, do

        • if c is present in child_colors, then

          • return True

    • otherwise,

      • for each c in child_colors, do

        • if c is present in colors, then

          • return True

    • return False

  • Define a function dfs(). This will take node and prev

    • colors := {color[node]}

    • for each child in tree[node], do

      • if child is not same as prev, then

        • child_colors := dfs(child, node)

        • if neither colors nor child_colors is null, then

          • if check_intersection(colors, child_colors) is True, then

            • colors := null (the subtree contains a repeated color)

          • otherwise,

            • if length of (colors) < length of (child_colors), then

              • child_colors := child_colors OR colors (set union)

              • colors := child_colors

            • otherwise,

              • colors := colors OR child_colors

        • otherwise,

          • colors := null

    • if colors is not null, then

      • result := result + 1

    • return colors
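The heart of dfs is the step that combines the node's current color set with one child's set. Below, that step is pulled out on its own as a minimal sketch. The helper name `merge_color_sets` is hypothetical, and it uses Python's built-in `set.isdisjoint` in place of check_intersection. Like the steps above, it keeps the larger set and folds the smaller one into it ("small-to-large" merging), so each merge only has to scan the smaller set. `None` plays the role of null, meaning a repeated color has already been found.

```python
def merge_color_sets(colors, child_colors):
    """Combine two subtree color sets, or return None if any color repeats.

    None on either side means that subtree already has a duplicate color.
    The larger set is reused and mutated; the smaller one is folded into it.
    """
    if colors is None or child_colors is None:
        return None
    small, large = sorted((colors, child_colors), key=len)
    if not small.isdisjoint(large):
        return None  # the two subtrees share a color
    large |= small
    return large


print(sorted(merge_color_sets({1, 2}, {3})))  # [1, 2, 3]
print(merge_color_sets({1}, {1, 4}))          # None
```

Reusing the larger set matters because each color is then copied only O(log n) times in total across the whole tree, rather than once per ancestor.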

Example 

Let us see the following implementation to get a better understanding −

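class Solution:
   def solve(self, tree, color):
      # Number of special nodes found so far.
      self.result = 0

      def dfs(node, prev):
         # Colors in the subtree rooted at node, or None once a
         # repeated color has been found anywhere in that subtree.
         colors = {color[node]}
         for child in tree[node]:
            if child != prev:  # skip the parent entry in the adjacency list
               child_colors = dfs(child, node)
               if colors and child_colors:
                  if self.check_intersection(colors, child_colors):
                     colors = None
                  else:
                     # Merge the smaller set into the larger one.
                     if len(colors) < len(child_colors):
                        child_colors |= colors
                        colors = child_colors
                     else:
                        colors |= child_colors
               else:
                  # This node or a child subtree already has a repeat.
                  colors = None
         # Only count the node after all of its children are processed.
         if colors:
            self.result += 1
         return colors

      dfs(0, -1)
      return self.result

   def check_intersection(self, colors, child_colors):
      # Scan the smaller set and look each color up in the larger one.
      if len(colors) < len(child_colors):
         for c in colors:
            if c in child_colors:
               return True
      else:
         for c in child_colors:
            if c in colors:
               return True
      return False

ob = Solution()
print(ob.solve([
   [1,2],
   [0],
   [0,3],
   [2]
], [1, 2, 1, 1]))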

Input

[
   [1,2],
   [0],
   [0,3],
   [2]
], [1, 2, 1, 1]

Output

2
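The recursive dfs above can fail on deep trees, because CPython limits recursion depth (see `sys.getrecursionlimit()`, usually 1000). A path-shaped tree with a few thousand nodes is then enough to raise `RecursionError`. The following is a minimal sketch of the same set-merging idea without recursion, assuming the root is node 0 and the adjacency list is undirected. The function name `count_special_iterative` is hypothetical. It records a preorder with an explicit stack, then processes nodes in reverse, so every child is finished before its parent.

```python
def count_special_iterative(tree, color, root=0):
    """Count special nodes using an explicit stack instead of recursion."""
    n = len(tree)
    if n == 0:
        return 0
    parent = [-1] * n
    order = []  # preorder; reversed, every child comes before its parent
    visited = [False] * n
    visited[root] = True
    stack = [root]
    while stack:
        node = stack.pop()
        order.append(node)
        for nxt in tree[node]:
            if not visited[nxt]:
                visited[nxt] = True
                parent[nxt] = node
                stack.append(nxt)

    # colors[v]: set of colors in v's subtree, or None if a color repeats.
    colors = [None] * n
    result = 0
    for node in reversed(order):
        current = {color[node]}
        for child in tree[node]:
            if child == parent[node]:
                continue
            child_colors = colors[child]
            if current is None or child_colors is None:
                current = None
            elif not current.isdisjoint(child_colors):
                current = None
            else:
                # Small-to-large: keep the bigger set, fold the smaller in.
                if len(current) < len(child_colors):
                    current, child_colors = child_colors, current
                current |= child_colors
            colors[child] = None  # the child's entry is no longer needed
        colors[node] = current
        if current is not None:
            result += 1
    return result


print(count_special_iterative([[1, 2], [0], [0, 3], [2]], [1, 2, 1, 1]))  # 2
```

An alternative is to raise the limit with `sys.setrecursionlimit`, but very deep recursion can still exhaust the C stack, so an explicit stack is usually the safer choice.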

Updated on: 23-Dec-2020
