I have two problems with understanding the result of decision tree from scikit-learn. For example, this is one of my decision trees:

[Image: the decision tree, with the boxes referred to below highlighted in green and red]
My question is: how do I use the tree?

First question: if a sample satisfies the condition, it goes to the LEFT branch (if one exists); otherwise it goes RIGHT. In my case, a sample with X[7] > 63521.3984 will go to the green box. Correct?

Second question: when a sample reaches a leaf node, how do I know which category it belongs to? In this example, I have three categories to classify. In the red box, 91, 212, and 113 samples, respectively, satisfy the condition. But how do I decide the category?
I know there is a function clf.predict(sample) that returns the category. Can I read it from the graph instead?
Many thanks.

The value line in each box tells you how many training samples at that node fall into each category, in order. That’s why, in each box, the numbers in value add up to the number shown in samples. For instance, in your red box, 91+212+113=416. So among the training samples that reached this node, 91 were in category 1, 212 in category 2, and 113 in category 3.

If you were going to predict the outcome for a new data point that reached that leaf in the decision tree, you would predict category 2, because that is the most common category for samples at that node.
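As a quick check of the arithmetic above, here is a minimal sketch in plain Python. The list is simply the value line from the red box typed in by hand; the variable names are illustrative, not anything scikit-learn provides.

```python
# Class counts from the "value" line of the red box, in class order.
value = [91, 212, 113]

n_samples = sum(value)                     # should match the "samples" line
majority_index = value.index(max(value))   # 0-based position of the largest count
fractions = [count / n_samples for count in value]  # share of each category

print(n_samples)            # 416
print(majority_index + 1)   # 2, i.e. the second category
```

The fractions list is roughly what you would read as the class probabilities for a sample that lands in this leaf.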

First question:
Yes, your logic is correct. The left branch is taken when the node's condition is True, and the right branch when it is False. This can be counter-intuitive, because the True branch may hold the smaller share of the samples.

Second question:
This is easiest to see by rendering the tree as a graph with pydotplus.
The ‘class_names’ parameter of tree.export_graphviz() adds a class line showing the majority class of each node. The code below was run in an IPython notebook.

from sklearn.datasets import load_iris
from sklearn import tree
import pydotplus
from IPython.display import Image

iris = load_iris()
clf2 = tree.DecisionTreeClassifier()
clf2 = clf2.fit(iris.data, iris.target)

# Optionally save the tree as a .dot file for use outside the notebook
with open("iris.dot", 'w') as f:
    tree.export_graphviz(clf2, out_file=f)

# Export the tree as a DOT string; class_names adds the majority class to each node
dot_data = tree.export_graphviz(clf2, out_file=None,
                     class_names=iris.target_names,
                     filled=True, rounded=True)  # leaves_parallel=True
graph2 = pydotplus.graph_from_dot_data(dot_data)

## Color of nodes: one pastel color per majority class
nodes = graph2.get_node_list()
colors = {0: (255, 255, 224), 1: (255, 224, 255), 2: (224, 255, 255)}

for node in nodes:
    label = node.get_label()
    if label:
        # Parse the per-class numbers out of the "value = [...]" part of the label
        values = [float(v) for v in label.split('value = [')[1].split(']')[0].split(',')]
        rgb = colors[values.index(max(values))]  # color of the largest class
        node.set_fillcolor("#{:02x}{:02x}{:02x}".format(*rgb))

Image(graph2.create_png())


As for determining the class at the leaf, your example doesn’t have leaves with a single class, as the iris data set does. Mixed leaves are common, and getting pure leaves may require over-fitting the model. A leaf that holds a distribution over classes is the best result for many cross-validated models.
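To see such a class distribution at the leaves, one option is predict_proba, which for a decision tree reports the fraction of training samples of each class in the leaf a sample lands in. The sketch below is only an illustration: it uses the iris data and an arbitrary max_depth=2 to force mixed leaves, and the name shallow is made up for this example.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
# max_depth=2 is an arbitrary choice that keeps some leaves impure.
shallow = DecisionTreeClassifier(max_depth=2, random_state=0)
shallow.fit(iris.data, iris.target)

proba = shallow.predict_proba(iris.data)  # one row per sample, one column per class
pred = shallow.predict(iris.data)         # majority class of the leaf

# The predicted class is the column with the largest fraction.
same = np.array_equal(shallow.classes_[proba.argmax(axis=1)], pred)
print(proba[:3])
print(same)
```

Rows where the largest entry is below 1 belong to leaves that contain more than one class.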

According to the book “Learning scikit-learn: Machine Learning in Python”, the decision tree represents a series of decisions based on the training data.


To classify an instance, we should answer the question at each node. For example, is sex<=0.5? (Are we talking about a woman?) If the answer is yes, you go to the left child node in the tree; otherwise you go to the right child node. You keep answering questions (was she in the third class? was she in the first class? was she below 13 years old?) until you reach a leaf. When you are there, the prediction corresponds to the target class that has the most instances.
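The procedure described above can be sketched by walking a fitted tree by hand. This is an illustration on the iris data rather than the book's example; predict_by_hand is a hypothetical helper written for this sketch, and it relies on the tree_ attribute arrays (children_left, children_right, feature, threshold, value) of a fitted scikit-learn classifier.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)

def predict_by_hand(clf, sample):
    """Walk from the root to a leaf, going left when the node's test is true."""
    t = clf.tree_
    x = np.asarray(sample, dtype=np.float32)  # scikit-learn compares float32 inputs
    node = 0                                  # the root is node 0
    while t.children_left[node] != -1:        # -1 marks a leaf
        if x[t.feature[node]] <= t.threshold[node]:
            node = t.children_left[node]      # condition true: go left
        else:
            node = t.children_right[node]     # condition false: go right
    weights = t.value[node][0]                # per-class weights stored at the leaf
    return clf.classes_[np.argmax(weights)]   # class with the most instances

manual = np.array([predict_by_hand(clf, row) for row in iris.data])
print(np.array_equal(manual, clf.predict(iris.data)))
```

If the hand-walked predictions agree with clf.predict, you can trust the same left/right reading when following the drawn graph.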

Add feature_names=X.columns to tree.export_graphviz, where X is the training data (a pandas DataFrame), so that each split shows the column name instead of X[i].

My code is as follows:

with open("lectureGini.txt", "w") as f:
    # export_graphviz writes to f; its return value is not needed here
    tree.export_graphviz(lectureGini, out_file=f, feature_names=X.columns)
# copy contents of file lectureGini.txt into WebGraphviz - http://webgraphviz.com/

lectureGini is my fitted DecisionTreeClassifier.

This is a simple method I discovered that could be added to all the web examples of the Gini index I had researched. All the web examples explained the method really well, but none showed how to find the categories.
I don’t have Graphviz installed yet, so I am exporting a text file from Jupyter and copying the text into WebGraphviz.
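If Graphviz is not available at all, a possible alternative is sklearn.tree.export_text, which prints the tree as indented text, including the predicted class at each leaf. The sketch below assumes a reasonably recent scikit-learn (export_text was added in version 0.21) and uses the iris data as a stand-in for your own X and classifier.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
# max_depth=2 only keeps the printed tree short for this example.
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
clf.fit(iris.data, iris.target)

# feature_names labels each split; with a DataFrame you could pass list(X.columns).
report = export_text(clf, feature_names=list(iris.feature_names))
print(report)
```

Each leaf line in the printed report names the class that would be predicted there, so no copying into an external site is needed.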