Lec 20: Decision Trees

by Josh Hug (Fall 2019)

In [1]:
import seaborn as sns
import pandas as pd
sns.set(font_scale=1.5)
import matplotlib.pyplot as plt
import numpy as np
from sklearn import linear_model
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_absolute_error
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler

from sklearn.pipeline import Pipeline
In [2]:
# set numpy random seed so that this notebook is deterministic
np.random.seed(23)

Linear Classification

In [3]:
iris_data = pd.read_csv("iris.csv")
iris_data.sample(5)
Out[3]:
sepal_length sepal_width petal_length petal_width species
139 6.9 3.1 5.4 2.1 virginica
125 7.2 3.2 6.0 1.8 virginica
67 5.8 2.7 4.1 1.0 versicolor
3 4.6 3.1 1.5 0.2 setosa
113 5.7 2.5 5.0 2.0 virginica
In [4]:
sns.scatterplot(data = iris_data, x = "petal_length", y="petal_width", hue="species");
#fig = plt.gcf()
#fig.savefig("iris_scatter_plot_with_petal_data.png", dpi=300, bbox_inches = "tight")
In [5]:
from sklearn.linear_model import LogisticRegression
logistic_regression_model = LogisticRegression(multi_class = 'ovr')
logistic_regression_model = logistic_regression_model.fit(iris_data[["petal_length", "petal_width"]], iris_data["species"])
In [6]:
from matplotlib.colors import ListedColormap
# First three seaborn palette colors, so region colors match the scatter hues.
sns_cmap = ListedColormap(np.array(sns.color_palette())[0:3, :])

# Dense grid of (petal_length, petal_width) points covering the plot area.
xx, yy = np.meshgrid(np.arange(0, 7, 0.02),
                     np.arange(0, 2.8, 0.02))

# Predict a class at every grid point, then map string labels to integer codes
# (contourf needs numeric Z values).
Z_string = logistic_regression_model.predict(np.c_[xx.ravel(), yy.ravel()])
categories, Z_int = np.unique(Z_string, return_inverse=True)
Z_int = Z_int.reshape(xx.shape)
# Contour fill intentionally disabled: this figure shows the raw data only.
#cs = plt.contourf(xx, yy, Z_int, cmap=sns_cmap)
sns.scatterplot(data = iris_data, x = "petal_length", y="petal_width", hue="species")
plt.xlim(0, 7)
plt.ylim(0, 2.8);
# plt.savefig("iris_decision_boundaries_logistic_regression_no_boundaries.png", dpi=300, bbox_inches = "tight")

Decision Tree Classification

In [7]:
from sklearn import tree
decision_tree_model = tree.DecisionTreeClassifier()
decision_tree_model = decision_tree_model.fit(iris_data[["petal_length", "petal_width"]], iris_data["species"])
In [8]:
four_random_rows = iris_data.sample(4)
four_random_rows
Out[8]:
sepal_length sepal_width petal_length petal_width species
148 6.2 3.4 5.4 2.3 virginica
64 5.6 2.9 3.6 1.3 versicolor
137 6.4 3.1 5.5 1.8 virginica
14 5.8 4.0 1.2 0.2 setosa
In [9]:
decision_tree_model.predict(four_random_rows[["petal_length", "petal_width"]])
Out[9]:
array(['virginica', 'versicolor', 'virginica', 'setosa'], dtype=object)
In [10]:
 tree.plot_tree(decision_tree_model);
In [11]:
# !pip install graphviz
In [12]:
import graphviz
In [13]:
dot_data = tree.export_graphviz(decision_tree_model, out_file=None, 
                      feature_names=["petal_length", "petal_width"],  
                      class_names=["setosa", "versicolor", "virginica"],  
                      filled=True, rounded=True,  
                      special_characters=True)  
graph = graphviz.Source(dot_data)
#graph.render(format="png", filename="iris_tree")
graph
Out[13]:
Tree 0 petal_width ≤ 0.8 gini = 0.667 samples = 150 value = [50, 50, 50] class = setosa 1 gini = 0.0 samples = 50 value = [50, 0, 0] class = setosa 0->1 True 2 petal_width ≤ 1.75 gini = 0.5 samples = 100 value = [0, 50, 50] class = versicolor 0->2 False 3 petal_length ≤ 4.95 gini = 0.168 samples = 54 value = [0, 49, 5] class = versicolor 2->3 12 petal_length ≤ 4.85 gini = 0.043 samples = 46 value = [0, 1, 45] class = virginica 2->12 4 petal_width ≤ 1.65 gini = 0.041 samples = 48 value = [0, 47, 1] class = versicolor 3->4 7 petal_width ≤ 1.55 gini = 0.444 samples = 6 value = [0, 2, 4] class = virginica 3->7 5 gini = 0.0 samples = 47 value = [0, 47, 0] class = versicolor 4->5 6 gini = 0.0 samples = 1 value = [0, 0, 1] class = virginica 4->6 8 gini = 0.0 samples = 3 value = [0, 0, 3] class = virginica 7->8 9 petal_length ≤ 5.45 gini = 0.444 samples = 3 value = [0, 2, 1] class = versicolor 7->9 10 gini = 0.0 samples = 2 value = [0, 2, 0] class = versicolor 9->10 11 gini = 0.0 samples = 1 value = [0, 0, 1] class = virginica 9->11 13 gini = 0.444 samples = 3 value = [0, 1, 2] class = virginica 12->13 14 gini = 0.0 samples = 43 value = [0, 0, 43] class = virginica 12->14
In [14]:
from matplotlib.colors import ListedColormap
sns_cmap = ListedColormap(np.array(sns.color_palette())[0:3, :])

xx, yy = np.meshgrid(np.arange(0, 7, 0.02),
                     np.arange(0, 2.8, 0.02))

# Color each grid point by the decision tree's predicted class; the axis-aligned
# rectangular regions are characteristic of decision-tree boundaries.
Z_string = decision_tree_model.predict(np.c_[xx.ravel(), yy.ravel()])
categories, Z_int = np.unique(Z_string, return_inverse=True)
Z_int = Z_int.reshape(xx.shape)
cs = plt.contourf(xx, yy, Z_int, cmap=sns_cmap)
sns.scatterplot(data = iris_data, x = "petal_length", y="petal_width", hue="species");
#fig = plt.gcf()
#fig.savefig("iris_decision_boundaries.png", dpi=300, bbox_inches = "tight")
In [15]:
from sklearn.metrics import accuracy_score
predictions = decision_tree_model.predict(iris_data[["petal_length", "petal_width"]])
accuracy_score(predictions, iris_data["species"])
Out[15]:
0.9933333333333333
In [16]:
iris_data.query("petal_length > 2.45 and petal_width > 1.75 and petal_length <= 4.85")
Out[16]:
sepal_length sepal_width petal_length petal_width species
70 5.9 3.2 4.8 1.8 versicolor
126 6.2 2.8 4.8 1.8 virginica
138 6.0 3.0 4.8 1.8 virginica

Overfitting

In [17]:
# Shuffle all 150 rows (sample(frac=1)), then split into 110 training rows
# and the remaining 40 test rows.
train_iris_data, test_iris_data = np.split(iris_data.sample(frac=1), [110])
In [18]:
#sort so that the color labels match what we had in the earlier part of lecture
train_iris_data = train_iris_data.sort_values(by="species")
test_iris_data = test_iris_data.sort_values(by="species")
In [19]:
len(train_iris_data)
Out[19]:
110
In [20]:
train_iris_data.head(5)
Out[20]:
sepal_length sepal_width petal_length petal_width species
12 4.8 3.0 1.4 0.1 setosa
37 4.9 3.1 1.5 0.1 setosa
5 5.4 3.9 1.7 0.4 setosa
34 4.9 3.1 1.5 0.1 setosa
24 4.8 3.4 1.9 0.2 setosa
In [21]:
# Refit the petal-feature tree on the 110-row training split only
# (earlier it was fit on all 150 rows).
from sklearn import tree
decision_tree_model = tree.DecisionTreeClassifier()
decision_tree_model = decision_tree_model.fit(train_iris_data[["petal_length", "petal_width"]], train_iris_data["species"])
In [22]:
from matplotlib.colors import ListedColormap
sns_cmap = ListedColormap(np.array(sns.color_palette())[0:3, :])

xx, yy = np.meshgrid(np.arange(0, 7, 0.02),
                     np.arange(0, 2.8, 0.02))

# Decision regions of the train-split model, with only the TRAINING points overlaid.
Z_string = decision_tree_model.predict(np.c_[xx.ravel(), yy.ravel()])
categories, Z_int = np.unique(Z_string, return_inverse=True)
Z_int = Z_int.reshape(xx.shape)
cs = plt.contourf(xx, yy, Z_int, cmap=sns_cmap)
sns.scatterplot(data = train_iris_data, x = "petal_length", y="petal_width", hue="species");
#fig = plt.gcf()
#fig.savefig("iris_decision_boundaries_model_train_test_split_training_only.png", dpi=300, bbox_inches = "tight")
In [23]:
from matplotlib.colors import ListedColormap
sns_cmap = ListedColormap(np.array(sns.color_palette())[0:3, :])

xx, yy = np.meshgrid(np.arange(0, 7, 0.02),
                     np.arange(0, 2.8, 0.02))

# Same train-split model, but with ALL 150 points overlaid.
Z_string = decision_tree_model.predict(np.c_[xx.ravel(), yy.ravel()])
categories, Z_int = np.unique(Z_string, return_inverse=True)
Z_int = Z_int.reshape(xx.shape)
cs = plt.contourf(xx, yy, Z_int, cmap=sns_cmap)
sns.scatterplot(data = iris_data, x = "petal_length", y="petal_width", hue="species");
#fig = plt.gcf()
#fig.savefig("iris_decision_boundaries_model_train_test_split.png", dpi=300, bbox_inches = "tight")
In [24]:
from matplotlib.colors import ListedColormap
sns_cmap = ListedColormap(np.array(sns.color_palette())[0:3, :])

xx, yy = np.meshgrid(np.arange(0, 7, 0.02),
                     np.arange(0, 2.8, 0.02))

# Same model, with only the held-out TEST points overlaid.
Z_string = decision_tree_model.predict(np.c_[xx.ravel(), yy.ravel()])
categories, Z_int = np.unique(Z_string, return_inverse=True)
Z_int = Z_int.reshape(xx.shape)
cs = plt.contourf(xx, yy, Z_int, cmap=sns_cmap)
sns.scatterplot(data = test_iris_data, x = "petal_length", y="petal_width", hue="species");
#fig = plt.gcf()
#fig.savefig("iris_decision_boundaries_model_train_test_split_test_only.png", dpi=300, bbox_inches = "tight")
In [25]:
accuracy_score(decision_tree_model.predict(train_iris_data[["petal_length", "petal_width"]]), train_iris_data["species"])
Out[25]:
0.990909090909091
In [26]:
predictions = decision_tree_model.predict(test_iris_data[["petal_length", "petal_width"]])
accuracy_score(predictions, test_iris_data["species"])
Out[26]:
0.95
In [27]:
from sklearn import tree
sepal_decision_tree_model = tree.DecisionTreeClassifier()
# BUG FIX: the original line called decision_tree_model.fit(...), which silently
# refit (clobbered) the petal-feature model and left the freshly created sepal
# classifier object unfitted. Fit the sepal model itself instead.
sepal_decision_tree_model = sepal_decision_tree_model.fit(train_iris_data[["sepal_length", "sepal_width"]], train_iris_data["species"])
In [28]:
sns.scatterplot(data = iris_data, x = "sepal_length", y="sepal_width", hue="species", legend=False);
# fig = plt.gcf()
# fig.savefig("iris_scatter_plot_with_petal_data_sepal_only.png", dpi=300, bbox_inches = "tight")
In [29]:
from matplotlib.colors import ListedColormap
sns_cmap = ListedColormap(np.array(sns.color_palette())[0:3, :])

# Grid now covers the sepal feature ranges rather than petal ranges.
xx, yy = np.meshgrid(np.arange(4, 8, 0.02),
                     np.arange(1.9, 4.5, 0.02))

# Decision regions of the sepal-feature model, with no data overlaid.
Z_string = sepal_decision_tree_model.predict(np.c_[xx.ravel(), yy.ravel()])
categories, Z_int = np.unique(Z_string, return_inverse=True)
Z_int = Z_int.reshape(xx.shape)
plt.contourf(xx, yy, Z_int, cmap=sns_cmap);
# fig = plt.gcf()
# fig.savefig("iris_sepal_decision_boundaries_no_data.png", dpi=300, bbox_inches = "tight")
In [30]:
from matplotlib.colors import ListedColormap
sns_cmap = ListedColormap(np.array(sns.color_palette())[0:3, :])

xx, yy = np.meshgrid(np.arange(4, 8, 0.02),
                     np.arange(1.9, 4.5, 0.02))

# Sepal-feature decision regions with the TRAINING points overlaid.
Z_string = sepal_decision_tree_model.predict(np.c_[xx.ravel(), yy.ravel()])
categories, Z_int = np.unique(Z_string, return_inverse=True)
Z_int = Z_int.reshape(xx.shape)
cs = plt.contourf(xx, yy, Z_int, cmap=sns_cmap)
sns.scatterplot(data = train_iris_data, x = "sepal_length", y="sepal_width", hue="species", legend=False);
# fig = plt.gcf()
# fig.savefig("iris_sepal_decision_boundaries_model_training_only.png", dpi=300, bbox_inches = "tight")
In [31]:
from matplotlib.colors import ListedColormap
sns_cmap = ListedColormap(np.array(sns.color_palette())[0:3, :])

xx, yy = np.meshgrid(np.arange(4, 8, 0.02),
                     np.arange(1.9, 4.5, 0.02))

# Sepal-feature decision regions with the held-out TEST points overlaid.
Z_string = sepal_decision_tree_model.predict(np.c_[xx.ravel(), yy.ravel()])
categories, Z_int = np.unique(Z_string, return_inverse=True)
Z_int = Z_int.reshape(xx.shape)
cs = plt.contourf(xx, yy, Z_int, cmap=sns_cmap)
sns.scatterplot(data = test_iris_data, x = "sepal_length", y="sepal_width", hue="species", legend=False);
# fig = plt.gcf()
# fig.savefig("iris_sepal_decision_boundaries_model_test_only.png", dpi=300, bbox_inches = "tight")
In [32]:
dot_data = tree.export_graphviz(sepal_decision_tree_model, out_file=None, 
                      feature_names=["sepal_length", "sepal_width"],  
                      class_names=["setosa", "versicolor", "virginica"],  
                      filled=True, rounded=True,  
                      special_characters=True)  
graph = graphviz.Source(dot_data)
# graph.render(format="png", filename="sepal_tree")
graph
Out[32]:
Tree 0 sepal_length ≤ 5.45 gini = 0.665 samples = 110 value = [34, 36, 40] class = virginica 1 sepal_width ≤ 2.85 gini = 0.245 samples = 36 value = [31, 4, 1] class = setosa 0->1 True 12 sepal_length ≤ 6.15 gini = 0.534 samples = 74 value = [3, 32, 39] class = virginica 0->12 False 2 sepal_length ≤ 4.95 gini = 0.56 samples = 5 value = [1, 3, 1] class = versicolor 1->2 7 sepal_length ≤ 5.35 gini = 0.062 samples = 31 value = [30, 1, 0] class = setosa 1->7 3 sepal_length ≤ 4.7 gini = 0.5 samples = 2 value = [1, 0, 1] class = setosa 2->3 6 gini = 0.0 samples = 3 value = [0, 3, 0] class = versicolor 2->6 4 gini = 0.0 samples = 1 value = [1, 0, 0] class = setosa 3->4 5 gini = 0.0 samples = 1 value = [0, 0, 1] class = virginica 3->5 8 gini = 0.0 samples = 27 value = [27, 0, 0] class = setosa 7->8 9 sepal_width ≤ 3.2 gini = 0.375 samples = 4 value = [3, 1, 0] class = setosa 7->9 10 gini = 0.0 samples = 1 value = [0, 1, 0] class = versicolor 9->10 11 gini = 0.0 samples = 3 value = [3, 0, 0] class = setosa 9->11 13 sepal_width ≤ 3.6 gini = 0.471 samples = 32 value = [3, 22, 7] class = versicolor 12->13 44 sepal_width ≤ 2.4 gini = 0.363 samples = 42 value = [0, 10, 32] class = virginica 12->44 14 sepal_length ≤ 5.75 gini = 0.366 samples = 29 value = [0, 22, 7] class = versicolor 13->14 43 gini = 0.0 samples = 3 value = [3, 0, 0] class = setosa 13->43 15 sepal_width ≤ 2.85 gini = 0.18 samples = 10 value = [0, 9, 1] class = versicolor 14->15 22 sepal_width ≤ 3.1 gini = 0.432 samples = 19 value = [0, 13, 6] class = versicolor 14->22 16 sepal_width ≤ 2.75 gini = 0.278 samples = 6 value = [0, 5, 1] class = versicolor 15->16 21 gini = 0.0 samples = 4 value = [0, 4, 0] class = versicolor 15->21 17 gini = 0.0 samples = 3 value = [0, 3, 0] class = versicolor 16->17 18 sepal_length ≤ 5.65 gini = 0.444 samples = 3 value = [0, 2, 1] class = versicolor 16->18 19 gini = 0.0 samples = 1 value = [0, 0, 1] class = virginica 18->19 20 gini = 0.0 samples = 2 value = [0, 2, 0] class = versicolor 
18->20 23 sepal_length ≤ 6.05 gini = 0.457 samples = 17 value = [0, 11, 6] class = versicolor 22->23 42 gini = 0.0 samples = 2 value = [0, 2, 0] class = versicolor 22->42 24 sepal_width ≤ 2.75 gini = 0.486 samples = 12 value = [0, 7, 5] class = versicolor 23->24 39 sepal_width ≤ 2.7 gini = 0.32 samples = 5 value = [0, 4, 1] class = versicolor 23->39 25 sepal_width ≤ 2.4 gini = 0.408 samples = 7 value = [0, 5, 2] class = versicolor 24->25 32 sepal_length ≤ 5.85 gini = 0.48 samples = 5 value = [0, 2, 3] class = virginica 24->32 26 gini = 0.5 samples = 2 value = [0, 1, 1] class = versicolor 25->26 27 sepal_width ≤ 2.65 gini = 0.32 samples = 5 value = [0, 4, 1] class = versicolor 25->27 28 gini = 0.0 samples = 1 value = [0, 1, 0] class = versicolor 27->28 29 sepal_length ≤ 5.9 gini = 0.375 samples = 4 value = [0, 3, 1] class = versicolor 27->29 30 gini = 0.444 samples = 3 value = [0, 2, 1] class = versicolor 29->30 31 gini = 0.0 samples = 1 value = [0, 1, 0] class = versicolor 29->31 33 gini = 0.0 samples = 1 value = [0, 0, 1] class = virginica 32->33 34 sepal_width ≤ 2.95 gini = 0.5 samples = 4 value = [0, 2, 2] class = versicolor 32->34 35 gini = 0.0 samples = 1 value = [0, 1, 0] class = versicolor 34->35 36 sepal_length ≤ 5.95 gini = 0.444 samples = 3 value = [0, 1, 2] class = virginica 34->36 37 gini = 0.5 samples = 2 value = [0, 1, 1] class = versicolor 36->37 38 gini = 0.0 samples = 1 value = [0, 0, 1] class = virginica 36->38 40 gini = 0.0 samples = 1 value = [0, 0, 1] class = virginica 39->40 41 gini = 0.0 samples = 4 value = [0, 4, 0] class = versicolor 39->41 45 gini = 0.0 samples = 2 value = [0, 2, 0] class = versicolor 44->45 46 sepal_length ≤ 7.05 gini = 0.32 samples = 40 value = [0, 8, 32] class = virginica 44->46 47 sepal_length ≤ 6.95 gini = 0.383 samples = 31 value = [0, 8, 23] class = virginica 46->47 68 gini = 0.0 samples = 9 value = [0, 0, 9] class = virginica 46->68 48 sepal_width ≤ 3.15 gini = 0.358 samples = 30 value = [0, 7, 23] class = 
virginica 47->48 67 gini = 0.0 samples = 1 value = [0, 1, 0] class = versicolor 47->67 49 sepal_length ≤ 6.45 gini = 0.42 samples = 20 value = [0, 6, 14] class = virginica 48->49 62 sepal_length ≤ 6.35 gini = 0.18 samples = 10 value = [0, 1, 9] class = virginica 48->62 50 sepal_width ≤ 2.6 gini = 0.198 samples = 9 value = [0, 1, 8] class = virginica 49->50 53 sepal_width ≤ 2.65 gini = 0.496 samples = 11 value = [0, 5, 6] class = virginica 49->53 51 gini = 0.5 samples = 2 value = [0, 1, 1] class = versicolor 50->51 52 gini = 0.0 samples = 7 value = [0, 0, 7] class = virginica 50->52 54 gini = 0.0 samples = 1 value = [0, 0, 1] class = virginica 53->54 55 sepal_width ≤ 2.9 gini = 0.5 samples = 10 value = [0, 5, 5] class = versicolor 53->55 56 gini = 0.0 samples = 2 value = [0, 2, 0] class = versicolor 55->56 57 sepal_width ≤ 3.05 gini = 0.469 samples = 8 value = [0, 3, 5] class = virginica 55->57 58 gini = 0.0 samples = 3 value = [0, 0, 3] class = virginica 57->58 59 sepal_length ≤ 6.8 gini = 0.48 samples = 5 value = [0, 3, 2] class = versicolor 57->59 60 gini = 0.444 samples = 3 value = [0, 2, 1] class = versicolor 59->60 61 gini = 0.5 samples = 2 value = [0, 1, 1] class = versicolor 59->61 63 sepal_width ≤ 3.35 gini = 0.375 samples = 4 value = [0, 1, 3] class = virginica 62->63 66 gini = 0.0 samples = 6 value = [0, 0, 6] class = virginica 62->66 64 gini = 0.5 samples = 2 value = [0, 1, 1] class = versicolor 63->64 65 gini = 0.0 samples = 2 value = [0, 0, 2] class = virginica 63->65
In [33]:
accuracy_score(sepal_decision_tree_model.predict(train_iris_data[["sepal_length", "sepal_width"]]), train_iris_data["species"])
Out[33]:
0.9363636363636364
In [34]:
accuracy_score(sepal_decision_tree_model.predict(test_iris_data[["sepal_length", "sepal_width"]]), test_iris_data["species"])
Out[34]:
0.7
In [35]:
decision_tree_model_4d = tree.DecisionTreeClassifier()
decision_tree_model_4d = decision_tree_model_4d.fit(train_iris_data[["petal_length", "petal_width", 
                                                                     "sepal_length", "sepal_width"]], train_iris_data["species"])
In [36]:
predictions = decision_tree_model_4d.predict(train_iris_data[["petal_length", "petal_width", "sepal_length", "sepal_width"]])
accuracy_score(predictions, train_iris_data["species"])
Out[36]:
1.0
In [37]:
predictions = decision_tree_model_4d.predict(test_iris_data[["petal_length", "petal_width", "sepal_length", "sepal_width"]])
accuracy_score(predictions, test_iris_data["species"])
Out[37]:
0.95
In [38]:
dot_data = tree.export_graphviz(decision_tree_model_4d, out_file=None, 
                      feature_names=["petal_length", "petal_width", "sepal_length", "sepal_width"],  
                      class_names=["setosa", "versicolor", "virginica"],  
                      filled=True, rounded=True,  
                      special_characters=True)  
graph = graphviz.Source(dot_data)
graph
Out[38]:
Tree 0 petal_length ≤ 2.45 gini = 0.665 samples = 110 value = [34, 36, 40] class = virginica 1 gini = 0.0 samples = 34 value = [34, 0, 0] class = setosa 0->1 True 2 petal_width ≤ 1.65 gini = 0.499 samples = 76 value = [0, 36, 40] class = virginica 0->2 False 3 petal_length ≤ 4.95 gini = 0.145 samples = 38 value = [0, 35, 3] class = versicolor 2->3 8 petal_length ≤ 4.85 gini = 0.051 samples = 38 value = [0, 1, 37] class = virginica 2->8 4 gini = 0.0 samples = 34 value = [0, 34, 0] class = versicolor 3->4 5 petal_width ≤ 1.55 gini = 0.375 samples = 4 value = [0, 1, 3] class = virginica 3->5 6 gini = 0.0 samples = 3 value = [0, 0, 3] class = virginica 5->6 7 gini = 0.0 samples = 1 value = [0, 1, 0] class = versicolor 5->7 9 sepal_width ≤ 3.1 gini = 0.375 samples = 4 value = [0, 1, 3] class = virginica 8->9 12 gini = 0.0 samples = 34 value = [0, 0, 34] class = virginica 8->12 10 gini = 0.0 samples = 3 value = [0, 0, 3] class = virginica 9->10 11 gini = 0.0 samples = 1 value = [0, 1, 0] class = versicolor 9->11
In [39]:
# graph.render(format="png", filename="iris_4d_tree")

Creating Decision Trees

In [40]:
def entropy(x):
    """Shannon entropy (in bits) of a collection of class counts.

    Parameters
    ----------
    x : array-like of non-negative counts (or proportions); they are
        normalized to sum to 1 before the entropy is computed.

    Returns
    -------
    float
        Entropy in bits. Zero-count classes contribute 0, following the
        convention lim_{p->0} p*log2(p) = 0 (the naive formula would
        produce NaN from 0 * log2(0)).
    """
    proportions = np.asarray(x) / np.sum(x)
    # Keep only nonzero proportions so log2 is well-defined; dropping
    # zeros is exact because 0 * log2(0) is taken to be 0.
    nonzero = proportions[proportions > 0]
    return float(np.sum(-nonzero * np.log2(nonzero)))
In [41]:
-np.log2(0.33)*0.33
Out[41]:
0.5278224832373695
In [42]:
-np.log2(0.36)*0.36
Out[42]:
0.5306152277996684
In [43]:
entropy([34, 36, 40])
Out[43]:
1.581649163979848
In [44]:
entropy([149, 1, 1])
Out[44]:
0.11485434496175385
In [45]:
entropy([50, 50])
Out[45]:
1.0
In [46]:
entropy([50, 50, 50])
Out[46]:
1.584962500721156
In [47]:
entropy([31, 4, 1])
Out[47]:
0.6815892897202809
In [48]:
#entropy([50, 46, 3])
#entropy([4, 47])
#entropy([41, 50])
#entropy([50, 50])
In [49]:
def weighted_average_entropy(x1, x2):
    """Entropy of a two-way split: each child's entropy weighted by its
    share of the total samples.

    Parameters
    ----------
    x1, x2 : array-like of class counts for the left and right child.

    Returns
    -------
    float
        (N1 * H(x1) + N2 * H(x2)) / (N1 + N2), where Ni = sum(xi).
    """
    N1 = sum(x1)
    N2 = sum(x2)
    # (The original also computed N = N1/(N1+N2) but never used it.)
    return (N1 * entropy(x1) + N2 * entropy(x2)) / (N1 + N2)
In [50]:
weighted_average_entropy([50, 46, 3], [4, 47])
Out[50]:
0.9033518322003758
In [51]:
weighted_average_entropy([50, 9], [41, 50])
Out[51]:
0.8447378399375686
In [52]:
weighted_average_entropy([2, 50, 50], [48])
Out[52]:
0.761345106024134
In [53]:
weighted_average_entropy([50, 50], [50])
Out[53]:
0.6666666666666666

Random Forests

In [54]:
ten_decision_tree_models = []
ten_training_sets = []
for _ in range(10):
    # Fresh random shuffle and 110/40 split for each model, so every tree
    # sees a different subset of the data.
    shuffled = iris_data.sample(frac=1)
    split_train, split_test = np.split(shuffled, [110])
    # Sort so the hue/color ordering matches earlier plots.
    split_train = split_train.sort_values("species")

    fitted_tree = tree.DecisionTreeClassifier().fit(
        split_train[["sepal_length", "sepal_width"]],
        split_train["species"])

    ten_decision_tree_models.append(fitted_tree)
    ten_training_sets.append(split_train)
In [55]:
def plot_decision_tree(decision_tree_model, data = None, disable_axes = False):
    """Plot a fitted classifier's decision regions over the sepal feature plane.

    Parameters
    ----------
    decision_tree_model : fitted classifier with a .predict method taking
        (sepal_length, sepal_width) pairs and returning string class labels.
    data : DataFrame or None
        If given, its points are scattered on top of the regions
        (expects sepal_length, sepal_width, species columns).
    disable_axes : bool
        If True, hide axes entirely (useful for small-multiple grids).
    """
    from matplotlib.colors import ListedColormap
    sns_cmap = ListedColormap(np.array(sns.color_palette())[0:3, :])

    # Dense grid covering the sepal_length x sepal_width plot area.
    xx, yy = np.meshgrid(np.arange(4, 8, 0.02),
                     np.arange(1.9, 4.5, 0.02))

    # Predict a class at every grid point; map string labels to integer
    # codes so contourf can color the regions.
    Z_string = decision_tree_model.predict(np.c_[xx.ravel(), yy.ravel()])
    categories, Z_int = np.unique(Z_string, return_inverse=True)
    Z_int = Z_int.reshape(xx.shape)
    plt.contourf(xx, yy, Z_int, cmap=sns_cmap)
    if data is not None:
        sns.scatterplot(data = data, x = "sepal_length", y="sepal_width", hue="species", legend=False)

    if disable_axes:
        plt.axis("off")
In [56]:
m_num = 0
plot_decision_tree(ten_decision_tree_models[m_num], ten_training_sets[m_num])
# plt.savefig("random_forest_model_1_example.png", dpi = 300, bbox_inches = "tight")
In [57]:
m_num = 7
plot_decision_tree(ten_decision_tree_models[m_num], ten_training_sets[m_num])
# plt.savefig("random_forest_model_2_example.png", dpi = 300, bbox_inches = "tight")
In [58]:
import matplotlib.gridspec as gridspec
# 3x3 grid of tightly packed subplots, one per bootstrapped model, to show
# how much the sepal-feature decision boundaries vary across random splits.
gs1 = gridspec.GridSpec(3, 3)
gs1.update(wspace=0.025, hspace=0.025) # set the spacing between axes. 

for i in range(0, 9):
    plt.subplot(gs1[i]) #3, 3, i)
    # Axes disabled (third argument True) so the small multiples stay readable.
    plot_decision_tree(ten_decision_tree_models[i], None, True)    
    
# plt.savefig("random_forest_model_9_examples.png", dpi = 300, bbox_inches = "tight")