How to Use Decision Trees in Julia
Overview
Introducing the DecisionTree.jl package, which implements decision trees in Julia.
Code
As an example, we use the iris dataset, a classic dataset that ships with R and is available in Julia through the RDatasets.jl package. Our goal is to create a decision tree that uses the four variables SepalLength, SepalWidth, PetalLength, and PetalWidth to predict Species, and to evaluate its performance.
julia> using RDatasets

julia> iris = dataset("datasets", "iris")
150×5 DataFrame
Row │ SepalLength SepalWidth PetalLength PetalWidth Speci ⋯
│ Float64 Float64 Float64 Float64 Cat… ⋯
─────┼──────────────────────────────────────────────────────────
1 │ 5.1 3.5 1.4 0.2 setos ⋯
2 │ 4.9 3.0 1.4 0.2 setos
3 │ 4.7 3.2 1.3 0.2 setos
4 │ 4.6 3.1 1.5 0.2 setos
5 │ 5.0 3.6 1.4 0.2 setos ⋯
6 │ 5.4 3.9 1.7 0.4 setos
7 │ 4.6 3.4 1.4 0.3 setos
8 │ 5.0 3.4 1.5 0.2 setos
9 │ 4.4 2.9 1.4 0.2 setos ⋯
10 │ 4.9 3.1 1.5 0.1 setos
11 │ 5.4 3.7 1.5 0.2 setos
12 │ 4.8 3.4 1.6 0.2 setos
13 │ 4.8 3.0 1.4 0.1 setos ⋯
14 │ 4.3 3.0 1.1 0.1 setos
15 │ 5.8 4.0 1.2 0.2 setos
⋮ │ ⋮ ⋮ ⋮ ⋮ ⋮ ⋱
136 │ 7.7 3.0 6.1 2.3 virgi
137 │ 6.3 3.4 5.6 2.4 virgi ⋯
138 │ 6.4 3.1 5.5 1.8 virgi
139 │ 6.0 3.0 4.8 1.8 virgi
140 │ 6.9 3.1 5.4 2.1 virgi
141 │ 6.7 3.1 5.6 2.4 virgi ⋯
142 │ 6.9 3.1 5.1 2.3 virgi
143 │ 5.8 2.7 5.1 1.9 virgi
144 │ 6.8 3.2 5.9 2.3 virgi
145 │ 6.7 3.3 5.7 2.5 virgi ⋯
146 │ 6.7 3.0 5.2 2.3 virgi
147 │ 6.3 2.5 5.0 1.9 virgi
148 │ 6.5 3.0 5.2 2.0 virgi
149 │ 6.2 3.4 5.4 2.3 virgi ⋯
150 │ 5.9 3.0 5.1 1.8 virgi
1 column and 120 rows omitted
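Before modeling, it can help to summarize the table. DataFrames is reexported by RDatasets (which is why Not() and column access work below without a separate import), so its describe() function is already available; a quick sketch:

julia> describe(iris)

This prints per-column statistics (mean, min, max, element type, and so on) for the five columns.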
Model Creation
julia> using DecisionTree
julia> model = DecisionTreeClassifier(max_depth=2)
DecisionTreeClassifier
max_depth: 2
min_samples_leaf: 1
min_samples_split: 2
min_purity_increase: 0.0
pruning_purity_threshold: 1.0
n_subfeatures: 0
classes: nothing
root: nothing
Create the model. Hyperparameters for the decision tree are passed as keyword arguments to DecisionTreeClassifier().
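The printout above lists the remaining hyperparameters and their defaults. As an illustrative sketch (the values below are arbitrary, not tuned for iris), a more constrained tree could be configured through the same keywords:

julia> model = DecisionTreeClassifier(max_depth=3, min_samples_leaf=5, n_subfeatures=2)

Limiting the depth and the leaf size curbs overfitting, while n_subfeatures sets how many features are considered at each split (0, the default, means all of them).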
Model Fitting
julia> features = Matrix(iris[:, Not(:Species)]);
julia> labels = iris.Species;
julia> fit!(model, features, labels)
DecisionTreeClassifier
max_depth: 2
min_samples_leaf: 1
min_samples_split: 2
min_purity_increase: 0.0
pruning_purity_threshold: 1.0
n_subfeatures: 0
classes: ["setosa", "versicolor", "virginica"]
root: Decision Tree
Leaves: 3
Depth: 2
Split the data into the feature matrix (independent variables) and the label vector (dependent variable), then train the model with the fit!() function.
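The fitted model can classify new observations right away. A quick sketch, with made-up measurements:

julia> predict(model, [5.9, 3.0, 5.1, 1.9])
"virginica"

Given the tree printed in the next section, this sample has PetalLength ≥ 2.45 and PetalWidth ≥ 1.75, so it falls into the virginica leaf.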
Performance Check
julia> print_tree(model)
Feature 3 < 2.45 ?
├─ setosa : 50/50
└─ Feature 4 < 1.75 ?
├─ versicolor : 49/54
└─ virginica : 45/46
After training, the structure of the model can be inspected with the print_tree() function. Here "Feature 3" and "Feature 4" refer to the third and fourth columns of the feature matrix, i.e. PetalLength and PetalWidth.
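For deeper trees, print_tree() also accepts an optional depth limit that truncates the printout:

julia> print_tree(model, 5)

(This should stop after five levels; it is moot for our depth-2 tree, but handy for larger ones.)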
julia> sum(labels .== predict(model, features)) / length(labels)
0.96
A quick check of the accuracy shows 96%, which is quite decent. Note, however, that this is measured on the same rows the model was trained on.
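Because training-set accuracy can be optimistic, a held-out check is a better gauge. A minimal sketch using only the Random standard library (the seed and the 120/30 split are arbitrary choices, not from the original post):

julia> using Random

julia> Random.seed!(42);                    # illustrative seed

julia> perm = randperm(size(features, 1));  # shuffled row indices

julia> train, test = perm[1:120], perm[121:150];

julia> m = DecisionTreeClassifier(max_depth=2);

julia> fit!(m, features[train, :], labels[train]);

julia> sum(labels[test] .== predict(m, features[test, :])) / length(test)

DecisionTree.jl also ships k-fold cross-validation helpers such as nfoldCV_tree (for the native API) if a more systematic estimate is wanted.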
Full Code
using RDatasets
using DecisionTree

iris = dataset("datasets", "iris")           # load the iris dataset

model = DecisionTreeClassifier(max_depth=2)  # depth-limited tree
features = Matrix(iris[:, Not(:Species)])    # four measurement columns
labels = iris.Species                        # species labels

fit!(model, features, labels)                # train the model
print_tree(model)                            # inspect the learned splits

sum(labels .== predict(model, features)) / length(labels)  # training accuracy
Environment
- OS: Windows
- julia: v1.9.0