

How to Use Decision Trees in Julia

Overview

Introducing the DecisionTree.jl package, which implements decision trees in Julia.

Code

As an example, we use the iris dataset, a classic dataset bundled with R and available in Julia through the RDatasets.jl package. Our goal is to create a decision tree that uses the four variables SepalLength, SepalWidth, PetalLength, and PetalWidth to predict Species, and to evaluate its performance.

julia> using RDatasets

julia> iris = dataset("datasets", "iris")
150×5 DataFrame
 Row │ SepalLength  SepalWidth  PetalLength  PetalWidth  Speci ⋯
     │ Float64      Float64     Float64      Float64     Cat…  ⋯
─────┼──────────────────────────────────────────────────────────
   1 │         5.1         3.5          1.4         0.2  setos ⋯
   2 │         4.9         3.0          1.4         0.2  setos  
   3 │         4.7         3.2          1.3         0.2  setos  
   4 │         4.6         3.1          1.5         0.2  setos  
   5 │         5.0         3.6          1.4         0.2  setos ⋯
   6 │         5.4         3.9          1.7         0.4  setos  
   7 │         4.6         3.4          1.4         0.3  setos  
   8 │         5.0         3.4          1.5         0.2  setos  
   9 │         4.4         2.9          1.4         0.2  setos ⋯
  10 │         4.9         3.1          1.5         0.1  setos  
  11 │         5.4         3.7          1.5         0.2  setos  
  12 │         4.8         3.4          1.6         0.2  setos  
  13 │         4.8         3.0          1.4         0.1  setos ⋯
  14 │         4.3         3.0          1.1         0.1  setos  
  15 │         5.8         4.0          1.2         0.2  setos  
  ⋮  │      ⋮           ⋮            ⋮           ⋮           ⋮ ⋱
 136 │         7.7         3.0          6.1         2.3  virgi  
 137 │         6.3         3.4          5.6         2.4  virgi ⋯
 138 │         6.4         3.1          5.5         1.8  virgi  
 139 │         6.0         3.0          4.8         1.8  virgi  
 140 │         6.9         3.1          5.4         2.1  virgi  
 141 │         6.7         3.1          5.6         2.4  virgi ⋯
 142 │         6.9         3.1          5.1         2.3  virgi  
 143 │         5.8         2.7          5.1         1.9  virgi  
 144 │         6.8         3.2          5.9         2.3  virgi  
 145 │         6.7         3.3          5.7         2.5  virgi ⋯
 146 │         6.7         3.0          5.2         2.3  virgi  
 147 │         6.3         2.5          5.0         1.9  virgi  
 148 │         6.5         3.0          5.2         2.0  virgi  
 149 │         6.2         3.4          5.4         2.3  virgi ⋯
 150 │         5.9         3.0          5.1         1.8  virgi  
                                   1 column and 120 rows omitted

Model Creation

julia> using DecisionTree

julia> model = DecisionTreeClassifier(max_depth=2)
DecisionTreeClassifier
max_depth:                2
min_samples_leaf:         1
min_samples_split:        2
min_purity_increase:      0.0
pruning_purity_threshold: 1.0
n_subfeatures:            0
classes:                  nothing
root:                     nothing

Create the model. Hyperparameters for the decision tree can be passed as keyword arguments to DecisionTreeClassifier(); anything not specified keeps the default value shown in the printout above.
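All of the printed defaults can be overridden the same way max_depth was. A minimal sketch with purely illustrative values (these are not tuned recommendations):

```julia
using DecisionTree

# All values below are illustrative; tune them for your own data.
model = DecisionTreeClassifier(
    max_depth = 3,              # allow one more level than the example above
    min_samples_leaf = 5,       # each leaf must cover at least 5 samples
    min_samples_split = 4,      # a node needs 4 samples before it may split
    min_purity_increase = 0.01, # require a minimal purity gain per split
)
```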

Model Fitting

julia> features = Matrix(iris[:, Not(:Species)]);

julia> labels = iris.Species;

julia> fit!(model, features, labels)        
DecisionTreeClassifier
max_depth:                2
min_samples_leaf:         1
min_samples_split:        2
min_purity_increase:      0.0
pruning_purity_threshold: 1.0
n_subfeatures:            0
classes:                  ["setosa", "versicolor", "virginica"]
root:                     Decision Tree     
Leaves: 3
Depth:  2

Split the data into the features (independent variables, as a Matrix) and the labels (dependent variable), then train the model with the fit!() function.

Performance Check

julia> print_tree(model)
Feature 3 < 2.45 ?
├─ setosa : 50/50
└─ Feature 4 < 1.75 ?
    ├─ versicolor : 49/54
    └─ virginica : 45/46

After training, the structure of the tree can be inspected with the print_tree() function. Here Feature 3 and Feature 4 refer to the third and fourth columns of the feature matrix, i.e. PetalLength and PetalWidth.

julia> sum(labels .== predict(model, features)) / length(labels)
0.96

A quick check shows an accuracy of about 96%, which is quite decent. Note that this is measured on the training data itself, since we did not hold out a test set, so it says nothing about generalization.
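Besides hard class labels, the ScikitLearn-style API of DecisionTree.jl also exposes per-class probabilities (the class fractions at the matched leaf). A minimal sketch, reusing the model and features from above:

```julia
using DecisionTree

# Class probabilities for the first observation; the column order
# matches get_classes(model).
probs = predict_proba(model, features[1:1, :])
println(get_classes(model))  # the fitted class ordering
println(probs)               # one row of probabilities summing to 1
```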

Full Code

using RDatasets
iris = dataset("datasets", "iris")

using DecisionTree
model = DecisionTreeClassifier(max_depth=2)

features = Matrix(iris[:, Not(:Species)]);
labels = iris.Species;
fit!(model, features, labels)

print_tree(model)

sum(labels .== predict(model, features)) / length(labels)
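Since the accuracy above is computed on the training data, a held-out split gives a more honest estimate. A minimal sketch in plain Julia (the seed and 70/30 ratio are arbitrary choices):

```julia
using Random, RDatasets, DecisionTree

iris = dataset("datasets", "iris")
features = Matrix(iris[:, Not(:Species)])
labels = iris.Species

# Shuffle row indices and hold out 30% for testing.
Random.seed!(42)
idx = shuffle(1:length(labels))
ntrain = round(Int, 0.7 * length(labels))
train, test = idx[1:ntrain], idx[ntrain+1:end]

model = DecisionTreeClassifier(max_depth = 2)
fit!(model, features[train, :], labels[train])

# Accuracy on the unseen test rows.
test_acc = sum(labels[test] .== predict(model, features[test, :])) / length(test)
```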

Environment

  • OS: Windows
  • julia: v1.9.0