Decision Trees in Machine Learning and Python Examples

A tree in real life has many branches and leaves, and the same concept is widely used in machine learning, covering both classification and regression. In decision analysis, decision trees can be used to represent and make decisions intuitively. A decision tree, as the name implies, is a tree-shaped decision model. Decision trees appear throughout data mining and machine learning; this article focuses on the algorithm itself and its implementation in Python.

How to Represent an Algorithm as a Tree

Let's consider a very basic example using the Titanic dataset (which can also be obtained through sklearn). The model uses 3 features from the dataset: gender, age, and siblings/spouses aboard (the number of siblings or spouses).

[Diagram: Titanic survival decision tree, from Wikipedia]

Decision trees are drawn upside down, with the root at the top. In the image above, the bold black text indicates conditions/internal nodes, and the tree splits on these conditions into branches/edges. Branches that do not split any further end in decisions/leaves; in this case, whether the passenger died or survived is shown in red and green, respectively.

Real datasets have many more features, and the tree above is only one branch of a much larger tree, but the simplicity of the algorithm is hard to ignore: the relative importance of the features is immediately visible. This approach is usually called learning a decision tree from data, and the tree above is a classification tree because the goal is to classify passengers as survived or dead. Regression trees are represented the same way but predict continuous values such as house prices. Collectively, these decision tree algorithms are called CART (Classification and Regression Trees).

So, what actually happens in the background? Growing a tree involves deciding which features to split on and what splitting criterion to use, as well as when to stop. Since a tree can grow arbitrarily deep, it usually needs to be pruned to reduce its size and keep the number of nodes moderate, which prevents overfitting.

Let’s start with common techniques used for splitting.

Splitting Attributes

First, a feature is used to divide the data into two groups, for example age split at 9.5 years: one group greater than 9.5, the other less than or equal to 9.5. During this process, all features are considered and different split points are tested using a cost function; the split with the lowest cost is selected. Consider the earlier tree learned from the Titanic dataset. At the first split, or root, all attributes/features are considered and the training data is divided into groups based on that split. We selected 3 features, so there are 3 candidate splits. Now we use a cost function to calculate how much each split would cost, and the split with the lowest cost is chosen, which in our example is passenger gender. The algorithm is essentially recursive because the same strategy is used to further subdivide the groups that were formed. This is also why the algorithm is called greedy: at each step it tries to reduce the cost as much as possible, which makes the root node the best predictor/classifier.
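As an illustration only (not the sklearn internals), the greedy search can be sketched like this: for every feature and every candidate threshold, score the two resulting groups with a cost function and keep the cheapest split. The group_cost argument below is a hypothetical placeholder for whichever cost function is used (the Gini and squared-error costs are defined in the next section).

import numpy as np

def best_split(X, y, group_cost):
    """Try every feature and threshold, return the cheapest split.
    X: numpy array of shape (n_samples, n_features); y: 1D array of targets.
    group_cost: callable that scores a list of target groups (lower is better)."""
    best = None
    for feature in range(X.shape[1]):
        for threshold in np.unique(X[:, feature]):
            left = y[X[:, feature] <= threshold]    # group satisfying the condition
            right = y[X[:, feature] > threshold]    # group failing the condition
            if len(left) == 0 or len(right) == 0:   # skip degenerate splits
                continue
            cost = group_cost([left, right])
            if best is None or cost < best[0]:
                best = (cost, feature, threshold)
    return best  # (cost, feature index, threshold)

To grow a full tree, the same function would be applied recursively to each of the two groups it returns, which is exactly the greedy, recursive behaviour described above.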

Split Cost

Let’s take a closer look at the cost functions used for classification and regression. In both cases, the cost function tries to find the most homogeneous branches or branches with groups of similar responses. This helps us be more confident that test data inputs will follow a specific path.

Regression: cost = sum((y - prediction)^2)

Suppose we are predicting house prices. The decision tree starts splitting by considering each feature in the training data. The mean of the training targets that fall into a particular group is taken as that group's prediction. The function above is applied to all data points to calculate the cost of every candidate split, and again the split with the lowest cost is selected.
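A minimal sketch of this regression cost, assuming each group is a plain numpy array of target values (the regression_cost helper and the example numbers are only for illustration):

import numpy as np

def regression_cost(groups):
    """Sum of squared errors when each group predicts its own mean."""
    cost = 0.0
    for g in groups:
        prediction = g.mean()                     # the group mean is its prediction
        cost += ((g - prediction) ** 2).sum()     # squared deviation from that mean
    return cost

# Example: splitting house prices [100, 120, 300, 320] into two tight groups
print(regression_cost([np.array([100, 120]), np.array([300, 320])]))  # 400.0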

Classification: G = sum(pk * (1 - pk))

The Gini index measures the degree of mixture of response categories within the groups created by the split, giving a measure of split quality. Here, pk is the proportion of inputs of the same class in a particular group. Perfect class purity occurs when a group contains inputs from only one class, in which case pk is 1 or 0 and G = 0. A node split with a 50–50 class distribution has the worst purity. For binary classification, this corresponds to pk = 0.5 and G = 0.5.
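A minimal sketch of the Gini calculation, assuming each group is a list or array of class labels. The gini and gini_split names are illustrative; combining the per-group scores by weighting with group size is the usual CART convention rather than something stated in the formula above.

import numpy as np

def gini(group):
    """G = sum(pk * (1 - pk)) over the classes present in one group."""
    _, counts = np.unique(group, return_counts=True)
    p = counts / counts.sum()
    return float((p * (1 - p)).sum())

def gini_split(groups):
    """Size-weighted Gini over all groups produced by a split (lower is purer)."""
    n = sum(len(g) for g in groups)
    return sum(len(g) / n * gini(g) for g in groups)

print(gini(["died", "died", "survived", "survived"]))  # worst case for binary: 0.5
print(gini(["died", "died", "died"]))                  # perfectly pure group: 0.0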

When to Stop Splitting

You might ask: when should we stop growing the tree? Real problems usually have many features, so many splits are possible and the resulting tree can become huge. Such a tree is complex and very likely to overfit, so we need to know when to stop. One way is to set a minimum number of training inputs per leaf. For example, we can require at least 10 passengers before making a decision (died or survived) and ignore any leaf that would contain fewer than 10. Another way is to set the maximum depth of the model, where maximum depth is the length of the longest path from root to leaf.
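In sklearn these two stopping rules map directly to the min_samples_leaf and max_depth parameters of DecisionTreeClassifier. A small illustrative configuration (the value 10 comes from the example above; the depth of 3 is just an arbitrary choice, not a recommendation):

from sklearn.tree import DecisionTreeClassifier

# Stop splitting when a leaf would hold fewer than 10 passengers,
# and never grow the tree deeper than 3 levels from the root.
clf = DecisionTreeClassifier(min_samples_leaf=10, max_depth=3)
# clf.fit(x_train, y_train)  # fit as usual; growth stops early per the settings above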

Pruning

Pruning can further improve tree performance. It involves removing branches that split on features of low importance, which reduces tree complexity and improves predictive ability by reducing overfitting.

Pruning can start from the root or from the leaves. The simplest method starts at the leaves and replaces each node with its most popular class; the change is kept if it does not reduce accuracy. This is called reduced error pruning. More sophisticated methods, such as cost complexity pruning, use a learning parameter (alpha) to decide whether a node can be removed based on the size of the subtree beneath it. This is also called weakest link pruning.
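sklearn exposes cost complexity (weakest link) pruning through the ccp_alpha parameter and the cost_complexity_pruning_path method. A hedged sketch of how the candidate alphas could be inspected and one of them applied; the toy data here is purely for illustration (the real Titanic example follows later):

from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=200, random_state=1)  # toy data for illustration

# Ask sklearn for the sequence of effective alphas at which subtrees get pruned
path = DecisionTreeClassifier(random_state=1).cost_complexity_pruning_path(X, y)
print(path.ccp_alphas)  # from 0 (no pruning) up to the alpha that prunes to the root

# Refit with a non-zero alpha: the larger the alpha, the more aggressive the pruning
pruned = DecisionTreeClassifier(random_state=1, ccp_alpha=path.ccp_alphas[-2]).fit(X, y)
print(pruned.get_n_leaves())  # far fewer leaves than the unpruned tree

In practice, alpha is usually chosen by cross-validating over the values returned in path.ccp_alphas.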

Advantages of CART

Decision tree methods include ID3, C4.5, C5.0, and CART. CART is the most frequently used, and the optimized implementation in sklearn is also based on CART, so only the advantages of CART are described here:

  • Easy to understand, interpret, and visualize.
  • Decision trees implicitly perform variable screening or feature selection.
  • Can handle numerical and categorical data. Also handles multi-output problems.
  • Requires little data preparation from users.
  • Nonlinear relationships between parameters do not affect tree performance.

Disadvantages of CART

  • Decision tree learning may create overly complex trees that do not generalize well. This is called overfitting.
  • Decision trees can be unstable: small changes in the data can produce a completely different tree. This is called variance, and it can be reduced with bagging and boosting methods (see the sketch after this list).
  • Greedy algorithms do not guarantee a globally optimal decision tree. This can be mitigated by training multiple trees, each on samples and features drawn with replacement.
  • Decision tree learners create biased trees if some classes dominate the dataset. It is therefore recommended to balance the dataset before fitting the decision tree.
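As a brief, hedged sketch of the bagging idea mentioned above: train many trees on bootstrap samples and random feature subsets and combine their votes, which is essentially what sklearn's RandomForestClassifier does. The toy data and the comparison below are only for illustration.

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=300, random_state=1)  # toy data for illustration

single_tree = DecisionTreeClassifier(random_state=1)
forest = RandomForestClassifier(n_estimators=100, random_state=1)  # bagged trees + random features

print(cross_val_score(single_tree, X, y, cv=5).mean())  # single tree: sensitive to the sample
print(cross_val_score(forest, X, y, cv=5).mean())       # forest: usually more stable and accurate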

Python Implementation

After introducing decision trees, here is a demonstration of implementing decision trees with sklearn.

import pandas as pd
import requests
from sklearn import tree
from sklearn.model_selection import train_test_split

# Download the Titanic training data and keep a local copy
text = requests.get("https://www.bobobk.com/wp-content/uploads/2020/01/train.csv").text
with open("train.csv", 'w') as fw:
    fw.write(text)
titanic = pd.read_csv("train.csv")

# Fill missing values: median age, and "S" (the most common port) for Embarked
titanic["Age"] = titanic["Age"].fillna(titanic["Age"].median())
titanic["Embarked"] = titanic["Embarked"].fillna("S")

# Drop identifier and high-cardinality columns that are not useful as predictors
candidate_train_predictors = titanic.drop(['PassengerId','Survived','Name','Ticket','Cabin'], axis=1)

# Keep low-cardinality categorical columns plus all numeric columns
categorical_cols = [cname for cname in candidate_train_predictors.columns if
                    candidate_train_predictors[cname].nunique() < 10 and
                    candidate_train_predictors[cname].dtype == "object"]
numeric_cols = [cname for cname in candidate_train_predictors.columns if
                candidate_train_predictors[cname].dtype in ['int64', 'float64']]
my_cols = categorical_cols + numeric_cols
train_predictors = candidate_train_predictors[my_cols]

# One-hot encode the categorical columns
dummy_encoded_train_predictors = pd.get_dummies(train_predictors)

# Build the feature matrix and target, holding out 25% of the rows for validation
y_target = titanic["Survived"].values
x_features_one = dummy_encoded_train_predictors.values
x_train, x_validation, y_train, y_validation = train_test_split(x_features_one, y_target, test_size=.25, random_state=1)
print(x_features_one)  # inspect the encoded feature matrix

# Fit a classification tree on the full dataset and score it on the same rows;
# note that this is training accuracy, so it is optimistically high
tree_one = tree.DecisionTreeClassifier()
tree_one = tree_one.fit(x_features_one, y_target)
tree_one_accuracy = round(tree_one.score(x_features_one, y_target), 4)
print("Accuracy: %0.4f" % (tree_one_accuracy))

#Accuracy: 0.9798
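Note that 0.9798 is training accuracy: the tree was fit and scored on the same rows, and the held-out split created above was never used. As a small follow-up sketch, the tree can instead be fit on x_train only and scored on the unseen validation rows (the exact number will vary, but it is typically well below the training accuracy):

# Fit on the training split only and evaluate on the held-out validation rows
tree_two = tree.DecisionTreeClassifier(random_state=1)
tree_two = tree_two.fit(x_train, y_train)
print("Validation accuracy: %0.4f" % tree_two.score(x_validation, y_validation))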
