ジュリアで決定木を使う方法
概要
ジュリアで決定木decision Treeを実装したDecisionTree.jl
パッケージを紹介する1。
コード
例としては、Rの組み込みデータであるiris
データセットを使う。目標は、四つの変数SepalLength
, SepalWidth
, PetalLength
, PetalWidth
を用いてSpecies
を予測する決定木を作り、そのパフォーマンスを評価することである。
julia> iris = dataset("datasets", "iris")
150×5 DataFrame
Row │ SepalLength SepalWidth PetalLength PetalWidth Speci ⋯
│ Float64 Float64 Float64 Float64 Cat… ⋯
─────┼──────────────────────────────────────────────────────────
1 │ 5.1 3.5 1.4 0.2 setos ⋯
2 │ 4.9 3.0 1.4 0.2 setos
3 │ 4.7 3.2 1.3 0.2 setos
4 │ 4.6 3.1 1.5 0.2 setos
5 │ 5.0 3.6 1.4 0.2 setos ⋯
6 │ 5.4 3.9 1.7 0.4 setos
7 │ 4.6 3.4 1.4 0.3 setos
8 │ 5.0 3.4 1.5 0.2 setos
9 │ 4.4 2.9 1.4 0.2 setos ⋯
10 │ 4.9 3.1 1.5 0.1 setos
11 │ 5.4 3.7 1.5 0.2 setos
12 │ 4.8 3.4 1.6 0.2 setos
13 │ 4.8 3.0 1.4 0.1 setos ⋯
14 │ 4.3 3.0 1.1 0.1 setos
15 │ 5.8 4.0 1.2 0.2 setos
⋮ │ ⋮ ⋮ ⋮ ⋮ ⋮ ⋱
136 │ 7.7 3.0 6.1 2.3 virgi
137 │ 6.3 3.4 5.6 2.4 virgi ⋯
138 │ 6.4 3.1 5.5 1.8 virgi
139 │ 6.0 3.0 4.8 1.8 virgi
140 │ 6.9 3.1 5.4 2.1 virgi
141 │ 6.7 3.1 5.6 2.4 virgi ⋯
142 │ 6.9 3.1 5.1 2.3 virgi
143 │ 5.8 2.7 5.1 1.9 virgi
144 │ 6.8 3.2 5.9 2.3 virgi
145 │ 6.7 3.3 5.7 2.5 virgi ⋯
146 │ 6.7 3.0 5.2 2.3 virgi
147 │ 6.3 2.5 5.0 1.9 virgi
148 │ 6.5 3.0 5.2 2.0 virgi
149 │ 6.2 3.4 5.4 2.3 virgi ⋯
150 │ 5.9 3.0 5.1 1.8 virgi
1 column and 120 rows omitted
モデル作成
julia> using DecisionTree
julia> model = DecisionTreeClassifier(max_depth=2)
DecisionTreeClassifier
max_depth: 2
min_samples_leaf: 1
min_samples_split: 2
min_purity_increase: 0.0
pruning_purity_threshold: 1.0
n_subfeatures: 0
classes: nothing
root: nothing
モデルを作る。DecisionTreeClassifier()
を通じて、決定木で使われるパラメーターを指定できる。
モデルフィッティング
julia> features = Matrix(iris[:, Not(:Species)]);
julia> labels = iris.Species;
julia> fit!(model, features, labels)
DecisionTreeClassifier
max_depth: 2
min_samples_leaf: 1
min_samples_split: 2
min_purity_increase: 0.0
pruning_purity_threshold: 1.0
n_subfeatures: 0
classes: ["setosa", "versicolor", "virginica"]
root: Decision Tree
Leaves: 3
Depth: 2
データを独立変数と従属変数に分けて、fit!()
関数でモデルを学習させる。
パフォーマンス確認
julia> print_tree(model)
Feature 3 < 2.45 ?
├─ setosa : 50/50
└─ Feature 4 < 1.75 ?
├─ versicolor : 49/54
└─ virginica : 45/46
学習が終わったモデルは、print_tree()
関数を通じてどんな構造をしているか確認できる。
julia> sum(labels .== predict(model, features)) / length(labels)
0.96
簡単に正解率を確認した結果、96%程度でかなり良いことがわかった。
全コード
using RDatasets
iris = dataset("datasets", "iris")
using DecisionTree
model = DecisionTreeClassifier(max_depth=2)
features = Matrix(iris[:, Not(:Species)]);
labels = iris.Species;
fit!(model, features, labels)
print_tree(model)
sum(labels .== predict(model, features)) / length(labels)
環境
- OS: Windows
- julia: v1.9.0