Decision Trees

Sandun Dayananda
Apr 9, 2024


Mathematical and machine learning approach

(Figure: decision trees)

The decision tree is a supervised machine learning algorithm that can be used for both classification and regression tasks. A decision tree can be understood as a sequence of if-then statements that classify the records in a dataset based on their characteristics. Decision trees provide a visual representation of the decision-making process and are widely used because they are interpretable and can handle both categorical and numerical data. In this article, we will explore the fundamentals of decision trees, how they are constructed, and how they make predictions.

You should understand the following key terms before we go further:

(Figure: key terms)

There are several algorithms for building decision trees, introduced by different researchers. ID3, C4.5, C5.0, CART, CHAID, QUEST, and CRUISE are well-known and widely used examples.

  1. ID3 (Iterative Dichotomiser 3): Developed by Ross Quinlan, ID3 is one of the earliest and simplest decision tree algorithms. It builds the tree using a top-down, greedy approach and utilizes the concept of information gain to select the best attribute for splitting the data at each step.
  2. C4.5: Also developed by Ross Quinlan, C4.5 is an improved version of ID3. It uses information gain ratio instead of information gain, which helps to handle attributes with different numbers of distinct values more effectively. Additionally, C4.5 allows for handling continuous-valued attributes by converting them into discrete ranges.
  3. C5.0: This is the successor of C4.5, developed by Ross Quinlan. It improves the efficiency and accuracy of C4.5, using various optimization techniques. It can also handle larger datasets and offers better performance in terms of speed and memory usage.
  4. CART (Classification and Regression Trees): CART is a versatile algorithm introduced by Leo Breiman, Jerome Friedman, Richard Olshen, and Charles Stone. Unlike ID3 and C4.5, which focus mainly on classification tasks, CART can be used for both classification and regression. It builds binary trees and uses the Gini impurity for classification problems and the sum of squared errors for regression tasks to measure the quality of splits.
  5. CHAID (Chi-squared Automatic Interaction Detection): CHAID is designed to handle categorical target variables. It utilizes the chi-squared test to assess the relationship between attributes and the target, making it suitable for both binary and multi-class classification tasks.
  6. QUEST (Quick, Unbiased, and Efficient Statistical Trees): QUEST is another decision tree algorithm that utilizes statistical tests like the chi-squared test to identify significant attribute splits. It aims to build smaller trees by focusing on the most informative attributes.
  7. CRUISE (Classification Rule with Unbiased Interaction Selection and Estimation): Developed by Kim and Loh, CRUISE aims at unbiased split selection. It searches for more accurate splits by incorporating significance tests that also detect interactions between attributes, and it can split a node into more than two branches.

Each of these algorithms has its strengths and weaknesses, and their performance may vary depending on the characteristics of the dataset and the specific problem at hand. Data scientists and machine learning practitioners often experiment with these algorithms to determine which one suits their task best.

Impurity in Decision Tree

Impurity refers to the degree of disorder or uncertainty in a dataset. It is used to evaluate the homogeneity of the data within a node of the decision tree. Homogeneity, the opposite of impurity, is the degree of similarity or uniformity within a group of data: when a group is homogeneous, its elements share similar characteristics or belong to the same class or category. The main objective of decision tree algorithms is to create nodes with low impurity, leading to more accurate and informative splits.

In classification tasks, impurity is associated with how mixed or diverse the class labels are within a node. The more evenly distributed the class labels, the higher the impurity, while a node with only one class label is considered pure (impurity = 0).

In regression tasks, impurity refers to the variance of the target values within a node. A node with low impurity has target values that are tightly grouped together, while a node with high impurity has target values spread out more widely.

There are several metrics to measure impurity in decision trees:

For classification: Entropy and the Gini Index (Information Gain builds on entropy to measure how much a split reduces impurity)

For regression: Mean Squared Error (MSE)

During the construction of a decision tree, the algorithm looks for the splits that most reduce the impurity of the child nodes compared to the impurity of the parent node. By repeatedly making such splits, the nodes progressively become more homogeneous, leading to better and more accurate predictions. The impurity criteria (entropy, Information Gain, Gini index, or MSE) guide the decision tree learning process, helping it find the most informative features and optimal split points.

In decision tree algorithms, nodes are split into subsets based on the values of different features. The goal of the algorithm is to reduce the impurity as much as possible through the splitting process.

When the algorithms mentioned above (ID3, C4.5, C5.0, CART, CHAID, QUEST, CRUISE) build a classification tree, they typically rely on criteria such as entropy, Information Gain, or the Gini Index (CHAID and QUEST additionally use statistical tests such as the chi-squared test). When we build a decision tree for regression, we typically use Mean Squared Error (MSE).

The following are mathematical explanations of these criteria, which will help us understand them better.

Entropy:

Entropy is a measure of uncertainty or randomness in a dataset for classification tasks. It quantifies the impurity of a node as the probability-weighted average of −log2 P(c) over the class labels in that node. The formula to calculate entropy is as follows:

$$\text{Entropy} = -\sum_{c} P(c)\,\log_2 P(c)$$

where P(c) is the proportion of data points in the node that belong to class “c” (classes with P(c) = 0 contribute nothing). For a two-class problem, entropy ranges from 0 (completely pure, all data points belong to a single class) to 1 (maximum impurity, data points evenly split between the two classes); with k classes, the maximum is log2(k).
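As a quick check of the formula, here is a minimal Python sketch that computes the entropy of a node from a list of class labels. The helper name `entropy` is our own illustration, not a scikit-learn function, and it relies on NumPy only.

```python
import numpy as np

def entropy(labels):
    """Entropy (in bits) of the class labels in a node."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()                  # P(c) for every class present in the node
    return float(np.sum(p * np.log2(1.0 / p)))  # same as -sum(P(c) * log2(P(c)))

print(entropy(["apple"] * 4))                   # 0.0: pure node
print(entropy(["apple"] * 2 + ["orange"] * 2))  # 1.0: even two-class mix
print(entropy(["a", "b", "c", "d"]))            # 2.0: even four-class mix, log2(4)
```

The three calls reproduce the cases described above: a pure node, an even two-class mix, and an even four-class mix, which shows that the upper bound of 1 only holds for two classes.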

Information Gain:

Information Gain is a measure of the effectiveness of a feature in reducing uncertainty or impurity in classification tasks. It represents the difference between the entropy of the parent node and the weighted average of the entropies of its child nodes after a split. The formula for Information Gain is given by:

$$\text{Information Gain} = \text{Entropy}(\text{parent}) - \sum_{j} \frac{N_j}{N}\,\text{Entropy}(\text{child}_j)$$

where N_j is the number of data points in child node j and N is the total number of data points in the parent node.

The higher the Information Gain, the more informative the feature is for the split, as it leads to a more significant reduction in entropy and better separation of classes.
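The formula can be sketched in a few lines of Python. The helpers `entropy` and `information_gain` are hypothetical names written for this illustration, and the yes/no labels are made-up toy data.

```python
import numpy as np

def entropy(labels):
    """Entropy (in bits) of the class labels in a node."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return float(np.sum(p * np.log2(1.0 / p)))  # same as -sum(P(c) * log2(P(c)))

def information_gain(parent, children):
    """Entropy of the parent minus the size-weighted entropy of its child nodes."""
    n = len(parent)
    weighted = sum(len(child) / n * entropy(child) for child in children)
    return entropy(parent) - weighted

parent = ["yes"] * 5 + ["no"] * 5
perfect_split = [["yes"] * 5, ["no"] * 5]                          # children are pure
useless_split = [["yes"] * 3 + ["no"] * 3, ["yes"] * 2 + ["no"] * 2]  # still 50/50 mixes

print(information_gain(parent, perfect_split))  # 1.0
print(information_gain(parent, useless_split))  # 0.0
```

A split that produces pure children recovers the full parent entropy as gain, while a split whose children are as mixed as the parent gains nothing, so the algorithm would prefer the first.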

Gini Index:

The Gini Index is another measure of impurity in classification tasks (it can be viewed as the cost function minimized when choosing a split). It measures the probability of misclassifying a randomly chosen data point in a node if it were randomly labeled according to the class distribution of that node. The formula to calculate the Gini Index is as follows:

$$\text{Gini Index} = 1 - \sum_{c} P(c)^2$$

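where P(c) is again the proportion of data points in the node that belong to class “c.” The Gini Index is 0 for a completely pure node. Its maximum, reached when the classes are evenly distributed, is 1 − 1/k for k classes (0.5 for a binary problem), so it approaches but never reaches 1.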
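A minimal sketch of the Gini formula, using a hypothetical helper `gini` and toy labels, shows how its maximum depends on the number of classes:

```python
import numpy as np

def gini(labels):
    """Gini impurity of a node: 1 - sum(P(c)^2)."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return float(1.0 - np.sum(p ** 2))

print(gini(["apple"] * 4))                   # 0.0: pure node
print(gini(["apple"] * 2 + ["orange"] * 2))  # 0.5: 1 - 1/2, the two-class maximum
print(gini(["a", "b", "c", "d"]))            # 0.75: 1 - 1/4, the four-class maximum
```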

Mean Squared Error (MSE):

MSE is a measure of the quality of a split used in regression tasks. It calculates the average squared difference between the predicted target values and the actual target values within a node. The formula for Mean Squared Error is given by:

$$\text{MSE} = \frac{1}{n}\sum_{i=1}^{n} (y_i - \bar{y})^2$$

where “n” is the number of data points in the node, “y_i” represents the actual target value for data point i, and “ȳ” is the predicted target value for the node (usually the mean of the actual target values). The goal in regression decision trees is to minimize the MSE, resulting in more accurate predictions and a more homogeneous subset of data within each node.
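The following sketch, with hypothetical helpers `node_mse` and `split_mse` and toy target values, shows how a regression tree could score a candidate split by the size-weighted MSE of the two child nodes, where each node predicts the mean of its own targets:

```python
import numpy as np

def node_mse(y):
    """MSE of a node that predicts the mean of its target values."""
    y = np.asarray(y, dtype=float)
    return float(np.mean((y - y.mean()) ** 2))

def split_mse(left, right):
    """Size-weighted MSE of the two child nodes produced by a split."""
    n = len(left) + len(right)
    return len(left) / n * node_mse(left) + len(right) / n * node_mse(right)

y = [1.0, 2.0, 3.0, 10.0, 11.0, 12.0]
print(node_mse(y))              # parent node
print(split_mse(y[:3], y[3:]))  # split separating the small and the large values
```

Here the split lowers the MSE from roughly 20.92 in the parent to roughly 0.67 across the children, so it is a good split. A tree learner repeats this comparison over all features and candidate thresholds and keeps the split with the lowest weighted MSE.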

Hyperparameters in Decision Tree

The hyperparameters of a decision tree are parameters that are set before the learning process starts. They control the learning process and the structure of the tree. The following are some common decision tree hyperparameters:

  1. max_depth: The maximum depth of the decision tree. Setting this can prevent the tree from growing too deep and overfitting the data.
  2. min_samples_split: The minimum number of samples required to split an internal node. It can prevent further splitting if the number of samples falls below this threshold.
  3. min_samples_leaf: The minimum number of samples required to be at a leaf node. It ensures that each leaf has enough samples, potentially avoiding overfitting.
  4. max_features: The maximum number of features to consider when looking for the best split. Limiting this can make the tree less sensitive to noise in the data.
  5. criterion (splitting criterion): The function used to measure the quality of a split. It determines how the data is split at each decision node.
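To see how these hyperparameters are passed in practice, here is a sketch using scikit-learn's DecisionTreeClassifier on the built-in Iris dataset. The specific values are arbitrary choices for illustration, not recommendations.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

clf = DecisionTreeClassifier(
    max_depth=3,           # no root-to-leaf path longer than 3 splits
    min_samples_split=10,  # a node needs at least 10 samples to be split
    min_samples_leaf=5,    # every leaf must keep at least 5 samples
    max_features=None,     # consider all features at every split (the default)
    criterion="entropy",   # information gain instead of the default Gini index
    random_state=42,
)
clf.fit(X, y)
print("depth:", clf.get_depth(), "leaves:", clf.get_n_leaves())
```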

Hyperparameters must be set by the user before training the decision tree model, and their values can have a considerable impact on the performance and behaviour of the tree. Proper tuning of hyperparameters is critical to building a decision tree model that performs well and generalizes to new data.
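One common way to tune these hyperparameters is an exhaustive grid search with cross-validation. The sketch below uses scikit-learn's GridSearchCV on the Iris dataset; the grid values are illustrative assumptions, and a real project would choose ranges suited to its own data.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# every combination of these values is evaluated with 5-fold cross-validation
param_grid = {
    "max_depth": [2, 3, 4, None],
    "min_samples_leaf": [1, 5, 10],
    "criterion": ["gini", "entropy"],
}
search = GridSearchCV(DecisionTreeClassifier(random_state=42), param_grid, cv=5)
search.fit(X_train, y_train)  # refits the best combination on all training data

print("best hyperparameters:", search.best_params_)
print("cross-validated accuracy:", search.best_score_)
print("test accuracy:", search.score(X_test, y_test))
```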

In scikit-learn (which we use here), the algorithm used for decision tree-based models is an optimized version of CART (Classification and Regression Trees). Scikit-learn does not implement the exact versions of ID3, C4.5, C5.0, CHAID, QUEST, or CRUISE.

The decision tree implementation in scikit-learn is a variant of CART that supports binary and multi-class classification as well as regression tasks. By default, it uses the Gini impurity criterion for classification problems and the mean squared error criterion for regression tasks to decide on the best splits. If you want a different criterion, you can specify it with the criterion hyperparameter when creating a DecisionTreeClassifier or DecisionTreeRegressor. For example, a classifier can use criterion='entropy' (information gain), and a regressor can use the mean absolute error by setting criterion='absolute_error' (older scikit-learn versions called this option 'mae'). MSE (criterion='squared_error') remains the default and most commonly used criterion for regression trees, and it often works well in practice.

Here are the main classes in scikit-learn that utilize decision trees:

  1. DecisionTreeClassifier: This class is used for building decision tree models for classification tasks.
  2. DecisionTreeRegressor: This class is used for building decision tree models for regression tasks.

For example, to create a decision tree classifier (not a regressor) in scikit-learn, you would typically do something like this:

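from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# load example data (the Iris dataset stands in for your own training data)
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# create a decision tree classifier
clf = DecisionTreeClassifier(random_state=42)

# train the classifier on your training data
clf.fit(X_train, y_train)

# make predictions on new data
y_pred = clf.predict(X_test)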

Similarly, for regression tasks, you would use DecisionTreeRegressor.
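For completeness, here is a sketch of the regression counterpart. The noisy sine-wave data is synthetic and purely illustrative, and the criterion is left at its default (squared error, i.e. MSE).

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

# synthetic data for illustration only: a noisy sine wave
rng = np.random.default_rng(0)
X = np.sort(rng.uniform(0, 5, size=(200, 1)), axis=0)
y = np.sin(X).ravel() + rng.normal(scale=0.1, size=200)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# default criterion is "squared_error"; "absolute_error" is one alternative
reg = DecisionTreeRegressor(max_depth=4, random_state=42)
reg.fit(X_train, y_train)

y_pred = reg.predict(X_test)  # each prediction is the mean target of one leaf
print("R^2 on the test set:", reg.score(X_test, y_test))
```

Because every leaf predicts a single constant, a tree of depth 4 can output at most 16 distinct values, which is why regression trees produce step-shaped predictions.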

Keep in mind that scikit-learn provides a wide range of hyperparameters that you can tune to control the tree’s growth and prevent overfitting, among other things. This allows you to customize the decision tree model to suit your specific needs.

Components of the Decision Tree

A decision tree consists of several components that collectively form its structure and define how it makes decisions. Here are the main components of a decision tree:

  1. Root Node: The topmost node of the tree, representing the entire dataset. It is the starting point for decision making.
  2. Internal Nodes: Nodes in the tree that have child nodes. Each internal node represents a decision based on a specific feature and its split point.
  3. Branches: The edges connecting nodes in the tree. Each branch corresponds to a possible outcome of the decision made at its parent node.
  4. Leaf Nodes: Terminal nodes at the bottom of the tree where no further splits occur. They represent the final classifications or regression predictions.
(Figure: tree structure)
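scikit-learn exposes these components through the fitted model's `tree_` attribute. This sketch fits a small tree on the Iris dataset as a stand-in example and walks over every node, labelling it as the root, an internal node, or a leaf; a child index of -1 marks a node with no children, and each "yes"/"no" pointer corresponds to a branch.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=42).fit(iris.data, iris.target)
tree = clf.tree_  # low-level array representation of the fitted tree

for node in range(tree.node_count):
    left, right = tree.children_left[node], tree.children_right[node]
    if left == -1:  # leaves have no children
        kind = "leaf"
        detail = f"predicts {iris.target_names[tree.value[node].argmax()]}"
    else:
        kind = "root" if node == 0 else "internal"
        name = iris.feature_names[tree.feature[node]]
        detail = f"{name} <= {tree.threshold[node]:.2f}? yes -> node {left}, no -> node {right}"
    print(f"node {node} ({kind}, {tree.n_node_samples[node]} samples): {detail}")
```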

Let’s consider a binary classification problem to distinguish between two types of fruits: Apples and Oranges. The decision tree aims to classify fruits based on their weight and color.

  • The root node represents the entire dataset, which contains both apples and oranges.
  • The first internal node might split the data based on weight. It could ask, “Is the fruit’s weight less than 150 grams?” The tree branches into two child nodes: one for fruits with weight <150 grams and another for fruits with weight ≥150 grams.
  • Next, for each child node, another internal node may split the data based on color. For instance, “Is the fruit’s color yellow?” The tree branches again based on the color condition.
  • Eventually, the process continues, and we reach leaf nodes where the final classifications are made. For example, if a leaf node contains only apples, it means the decision tree predicts apples for fruits that meet the specific conditions leading to that node.

This example demonstrates how the components of a decision tree work together to make decisions based on features, creating a structure that efficiently separates fruits into their appropriate categories. By now, you should have an idea of the structure of a decision tree and how it works.
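The if-then view of a tree can also be printed directly. We do not have data for the fruit example, so this sketch fits a small tree on the Iris dataset instead and uses scikit-learn's export_text to show each path from the root to a leaf as nested conditions.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=42).fit(iris.data, iris.target)

# each indentation level is one more condition on the path from the root to a leaf
rules = export_text(clf, feature_names=list(iris.feature_names))
print(rules)
```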

Underfitting and Overfitting

Underfitting

Underfitting can occur in decision trees, just like in any other machine learning model. Underfitting happens when the model is too simple to capture the underlying patterns and relationships in the data, leading to poor performance on both the training and test datasets.

In the context of decision trees, underfitting typically occurs when the tree is shallow and fails to capture the complexity of the data. Some of the common reasons for underfitting in decision trees are:

  1. Insufficient Splits: The decision tree may not be able to create enough splits to distinguish different classes or make accurate predictions, resulting in poor generalization.
  2. High Minimum Samples per Leaf: Setting a high value for the min_samples_leaf hyperparameter can cause the tree to form leaves with a large number of samples, leading to oversimplified predictions.
  3. Low Maximum Tree Depth: Setting a low value for the max_depth hyperparameter can limit the depth of the tree, making it too shallow to capture intricate patterns in the data.
  4. Overgeneralization: When the tree is too generalized, it fails to adapt well to the training data, resulting in low accuracy on both the training and test sets.

To address underfitting in decision trees, you can try the following:

  1. Reduce min_samples_leaf: Decrease the minimum number of samples required to be at a leaf node, allowing the tree to split further and capture more information.
  2. Increase max_depth: Allow the tree to grow deeper by increasing the max_depth hyperparameter, allowing for more complex decision rules.
  3. Use Ensemble Techniques: Combine multiple decision trees using ensemble methods such as Gradient Boosting, which mainly reduces bias and therefore underfitting, or Random Forests, which mainly reduce variance; both can improve the model’s performance.
  4. Feature Engineering: Ensure that you have informative features and preprocess the data to provide the model with the necessary information to make accurate predictions.

By adjusting these hyperparameters and using appropriate techniques, you can mitigate underfitting in decision trees and achieve better performance on the training and test datasets.
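A simple way to spot underfitting is to compare training and test accuracy as the tree is allowed to grow deeper. This sketch uses the Iris dataset and a handful of arbitrary max_depth values; a depth-1 tree can make only one split, so it cannot separate all three species and scores poorly on both sets.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

results = {}  # max_depth -> (training accuracy, test accuracy)
for depth in [1, 2, 3, None]:  # None lets the tree grow until its leaves are pure
    clf = DecisionTreeClassifier(max_depth=depth, random_state=42).fit(X_train, y_train)
    results[depth] = (clf.score(X_train, y_train), clf.score(X_test, y_test))
    print(f"max_depth={depth}: train={results[depth][0]:.3f}, test={results[depth][1]:.3f}")
```

Low accuracy on both sets at small depths is the signature of underfitting; a large gap between training and test accuracy at large depths would instead point to overfitting.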

Overfitting

We use two major techniques to deal with overfitting: pruning (post-pruning) and early stopping (pre-pruning).

Pruning

Both techniques are forms of pruning. As the word ‘prune’ implies, we cut back parts of the tree to make it simpler.

Pruning (Post-Pruning)

In decision trees, pruning is a technique used to prevent overfitting, where the tree becomes too complex and may not generalize well to new data. After constructing the tree, pruning involves cutting back some branches to create a simpler and more generalized model. This is essential because a fully grown tree might memorize the training data exactly, but it can perform poorly on unseen data due to its complexity.

One commonly used pruning strategy is the “Minimum Error” approach, which prunes the tree back to the point where the cross-validated error is at its minimum. Cross-validation divides the training data into several subsets (folds): the tree is trained on all folds but one and evaluated on the held-out fold, and this is repeated so that each fold serves once as the evaluation set. By choosing the pruning point where the cross-validated error is minimized, we strike a balance between complexity and accuracy.

Let’s illustrate the Minimum Error pruning strategy with sample Python code using scikit-learn:

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.tree import DecisionTreeClassifier

# X = features, y = labels (a built-in dataset stands in for your own data)
X, y = load_breast_cancer(return_X_y=True)

# Split data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Create and fit an initial, fully grown decision tree
initial_tree = DecisionTreeClassifier(random_state=42)
initial_tree.fit(X_train, y_train)

# ccp_alpha is a small non-negative float (integers such as 1..20 would prune
# the tree down to a single node), so take the candidate alphas from the
# tree's cost-complexity pruning path instead of guessing them
path = initial_tree.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas = path.ccp_alphas[:-1]  # the last alpha prunes the tree to its root

# Perform cross-validation to find the optimal pruning point
cv_errors = []
for alpha in ccp_alphas:
    pruned_tree = DecisionTreeClassifier(random_state=42, ccp_alpha=alpha)
    cv_scores = cross_val_score(pruned_tree, X_train, y_train, cv=5)
    cv_errors.append(1 - cv_scores.mean())  # Calculate cross-validation error

# Find the alpha with the minimum error
optimal_alpha = ccp_alphas[cv_errors.index(min(cv_errors))]

# Create the final pruned tree using the optimal alpha
final_pruned_tree = DecisionTreeClassifier(random_state=42, ccp_alpha=optimal_alpha)
final_pruned_tree.fit(X_train, y_train)

# Evaluate the pruned tree on the test set
accuracy = final_pruned_tree.score(X_test, y_test)
print("Optimal alpha:", optimal_alpha)
print("Pruned Tree Accuracy:", accuracy)

In this example, we used cross-validation to identify the cost-complexity parameter (ccp_alpha) that gives the minimum cross-validated error. The final pruned tree is then constructed with the chosen alpha, giving a simpler, more interpretable model that often generalizes as well as or better than the fully grown tree.

The second pruning strategy, “Smallest Tree,” prunes slightly beyond the minimum error point: it selects the smallest tree whose cross-validation error is within one standard error of the minimum error. The benefit is a more interpretable and less complex tree that still maintains reasonable accuracy, since the increase in error compared to the minimum error point is small.
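The “Smallest Tree” rule can be sketched by recording the spread of the fold errors as well as their mean. This is an illustrative implementation of the one-standard-error rule on top of scikit-learn's cost-complexity pruning; the built-in breast cancer dataset stands in for your own data, and estimating the standard error from only 5 fold scores is a rough approximation.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.tree import DecisionTreeClassifier

# stand-in data; replace with your own features X and labels y
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# candidate alphas along the cost-complexity pruning path (in increasing order)
path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(X_train, y_train)
ccp_alphas = path.ccp_alphas[:-1]  # the last alpha prunes the tree to its root

k = 5  # number of cross-validation folds
mean_errors, std_errors = [], []
for alpha in ccp_alphas:
    model = DecisionTreeClassifier(random_state=42, ccp_alpha=alpha)
    errors = 1 - cross_val_score(model, X_train, y_train, cv=k)
    mean_errors.append(errors.mean())
    std_errors.append(errors.std(ddof=1) / np.sqrt(k))  # standard error of the mean

mean_errors = np.array(mean_errors)
best = int(np.argmin(mean_errors))  # "Minimum Error" choice
threshold = mean_errors[best] + std_errors[best]

# "Smallest Tree" choice: a larger alpha gives a smaller tree, so take the
# largest alpha whose mean error is still within one standard error of the minimum
one_se = int(np.max(np.where(mean_errors <= threshold)[0]))

min_error_tree = DecisionTreeClassifier(random_state=42, ccp_alpha=ccp_alphas[best]).fit(X_train, y_train)
smallest_tree = DecisionTreeClassifier(random_state=42, ccp_alpha=ccp_alphas[one_se]).fit(X_train, y_train)
print("minimum-error tree:", min_error_tree.get_n_leaves(), "leaves")
print("one-SE tree:", smallest_tree.get_n_leaves(), "leaves")
```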

Early stopping or pre-pruning

In decision trees, early stopping (or pre-pruning) is an alternative approach to prevent overfitting. Instead of building the tree to its maximum depth and then pruning it back, early stopping interrupts the tree-building process before creating leaves with very small sample sizes. This helps create a simpler model that generalizes better to unseen data.

During the tree-building process, at each stage of splitting, we calculate the cross-validation error. If the error does not decrease significantly enough, we stop the tree-building process early. The intuition is that if a split does not contribute significantly to reducing the error, it may not be beneficial to include it in the final tree. By stopping early, we avoid creating complex branches that may memorize noise in the training data.

However, there is a trade-off with early stopping. It might underfit the data if we stop the tree-building process too soon. In some cases, a particular split might not be helpful on its own, but it may pave the way for more meaningful splits later, leading to a more accurate model overall.

Early stopping can be used either on its own or in combination with pruning techniques. Post-pruning is generally considered more principled, because it judges each branch only after the full tree has been grown and is therefore less likely to discard a split that pays off further down. Early stopping, on the other hand, serves as a quick heuristic to prevent overfitting without requiring the extra computation of growing and then pruning a full tree.

Let’s demonstrate the early stopping approach with sample Python code using scikit-learn:

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# X = features, y = labels (a built-in dataset stands in for your own data)
X, y = load_breast_cancer(return_X_y=True)

# split data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Create a decision tree with early stopping: nodes with fewer than 10 samples are not split
early_stopping_tree = DecisionTreeClassifier(random_state=42, min_samples_split=10)
early_stopping_tree.fit(X_train, y_train)

# evaluate the early stopping tree on the test set
accuracy = early_stopping_tree.score(X_test, y_test)
print("Early Stopping Tree Accuracy:", accuracy)

In this example, we set the min_samples_split parameter to 10, which means that any node with fewer than 10 samples is not split further and becomes a leaf. This is one way to achieve early stopping, preventing the creation of branches with very small sample sizes. The resulting tree is simpler and more likely to generalize well to new data.
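min_samples_split is only one of several pre-pruning controls in scikit-learn. This sketch compares a fully grown tree with trees limited by max_depth, min_samples_leaf, and min_impurity_decrease on the built-in breast cancer dataset; the thresholds are arbitrary examples. Note that min_impurity_decrease stops a split when the weighted impurity decrease on the training data is too small, which approximates, but is not the same as, the cross-validation-based stopping rule described above.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# each entry is one pre-pruning setting passed to DecisionTreeClassifier
settings = {
    "fully grown": {},
    "max_depth=4": {"max_depth": 4},
    "min_samples_leaf=10": {"min_samples_leaf": 10},
    "min_impurity_decrease=0.01": {"min_impurity_decrease": 0.01},
}
models = {}
for name, params in settings.items():
    clf = DecisionTreeClassifier(random_state=42, **params).fit(X_train, y_train)
    models[name] = clf
    print(f"{name}: {clf.get_n_leaves()} leaves, test accuracy = {clf.score(X_test, y_test):.3f}")
```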

Now we have covered the most important concepts of decision trees. Let’s apply what we have learned to the famous Iris dataset, using scikit-learn’s decision tree classifier.

(You can find the full code on my GitHub.)

# Optional, in a notebook: !pip install graphviz pydotplus (only needed for Graphviz-based rendering; plot_tree below needs only Matplotlib)


import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt

# load the Iris dataset
iris = load_iris()
X, y = iris.data, iris.target

# convert the dataset into a pandas DataFrame for better visualization (this is optional)
df = pd.DataFrame(data=np.c_[iris['data'], iris['target']], columns=iris['feature_names'] + ['target'])
print(df.head())  # in a Jupyter notebook, display(df) shows the full table

# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Create the Decision Tree classifier with Gini index as the criterion
decision_tree = DecisionTreeClassifier(criterion='gini', random_state=42)

# Train the model using the training data
decision_tree.fit(X_train, y_train)

# Evaluate the model on the test set
accuracy = decision_tree.score(X_test, y_test)
print("Model Accuracy:", accuracy)

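# ccp_alpha is a small non-negative float (integers such as 1..20 would prune the
# tree down to a single node), so take the candidate alphas from the fully grown
# tree's cost-complexity pruning path
path = decision_tree.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas = path.ccp_alphas[:-1]  # the last alpha prunes the tree to its root

# Perform cross-validation to find the optimal pruning point
cv_errors = []
for alpha in ccp_alphas:
    pruned_tree = DecisionTreeClassifier(criterion='gini', ccp_alpha=alpha, random_state=42)
    cv_scores = cross_val_score(pruned_tree, X_train, y_train, cv=5)
    cv_errors.append(1 - cv_scores.mean())  # Calculate cross-validation error

# find the alpha with the minimum error
optimal_alpha = ccp_alphas[cv_errors.index(min(cv_errors))]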

# create final pruned tree using the optimal alpha
final_pruned_tree = DecisionTreeClassifier(criterion='gini', ccp_alpha=optimal_alpha, random_state=42)
final_pruned_tree.fit(X_train, y_train)

# Evaluate the pruned tree on the test set
pruned_accuracy = final_pruned_tree.score(X_test, y_test)
print("Pruned Tree Accuracy:", pruned_accuracy)

# Visualize the pruned decision tree
plt.figure(figsize=(12, 8))
plot_tree(final_pruned_tree, feature_names=iris.feature_names, class_names=iris.target_names, filled=True)
plt.show()

# making predictions on new data
new_data = np.array([[5.1, 3.5, 1.4, 0.2], # sample 1
[6.2, 2.9, 4.3, 1.3], # sample 2
[7.7, 3.0, 6.1, 2.3]]) # sample 3

predictions = final_pruned_tree.predict(new_data)
print("Predictions:", predictions)

# Plot the original (unpruned) decision tree for comparison
plt.figure(figsize=(12, 8))
plot_tree(decision_tree, feature_names=iris.feature_names, class_names=iris.target_names, filled=True)
plt.show()
(Figure: dataset and other metrics)
(Figure: prediction)
(Figure: generated decision tree)

Let’s walk through the code step-by-step and provide simple explanations for each part:

Import necessary libraries and load the Iris dataset:

  • We import the required libraries such as NumPy, Pandas, scikit-learn modules for decision trees, and Matplotlib for visualization.
  • We load the Iris dataset using load_iris() from scikit-learn, which provides features (X) and target labels (y).

Split the dataset into training and testing sets:

  • We use train_test_split() from scikit-learn to split the dataset into training and testing subsets. We use 80% of the data for training and 20% for testing.

Create and train the Decision Tree model:

  • We create a Decision Tree classifier using DecisionTreeClassifier() from scikit-learn, explicitly setting criterion='gini' (which is also the default) so that the Gini index measures impurity at each split.
  • We train the decision tree on the training data using the fit() method.

Evaluate the model’s performance on the test set:

  • We calculate the accuracy of the model on the test data using the score() method of the decision tree.

Perform cross-validation to find the optimal pruning point:

  • We use cross-validation to find the optimal pruning complexity parameter (alpha) that prevents overfitting.
  • For a range of candidate alpha values, we create pruned decision trees and calculate the cross-validation error of each using 5-fold cross-validation.
  • We store the cross-validation errors in a list.

Find the alpha with the minimum error and create the final pruned tree:

  • We identify the alpha value that minimizes the cross-validation error and choose it as the optimal pruning parameter.
  • We create a new decision tree with this optimal alpha, which will be the final pruned tree.

Evaluate the pruned tree’s accuracy on the test set:

  • We calculate the accuracy of the pruned tree on the test data using the score() method.

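Visualize the Decision Tree:

  • We use scikit-learn’s plot_tree() function, which draws with Matplotlib, to visualize the final pruned decision tree, showing the splits and leaf nodes. At the end of the script, we also plot the original, unpruned tree for comparison.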

Make predictions on new data:

  • We provide three new data points (samples) and use the final pruned tree to predict their target classes.

Overall, the code demonstrates the complete process of applying the Decision Tree algorithm to the Iris dataset, including data loading, model training, evaluation, pruning, and visualization. The final pruned tree is expected to have a more balanced trade-off between complexity and accuracy, making it a better model for generalization to new data.



Sandun Dayananda

Big Data Engineer with passion for Machine Learning and DevOps | MSc Industrial Analytics at Uppsala University