logo

ジュリアで決定木を使う方法 📂ジュリア

ジュリアで決定木を使う方法

概要

ジュリアで決定木decision Treeを実装したDecisionTree.jlパッケージを紹介する1

コード

例としては、Rの組み込みデータであるirisデータセットを使う。目標は、四つの変数SepalLength, SepalWidth, PetalLength, PetalWidthを用いてSpeciesを予測する決定木を作り、そのパフォーマンスを評価することである。

julia> iris = dataset("datasets", "iris")
150×5 DataFrame
 Row │ SepalLength  SepalWidth  PetalLength  PetalWidth  Speci ⋯
     │ Float64      Float64     Float64      Float64     Cat…  ⋯
─────┼──────────────────────────────────────────────────────────
   1 │         5.1         3.5          1.4         0.2  setos ⋯
   2 │         4.9         3.0          1.4         0.2  setos  
   3 │         4.7         3.2          1.3         0.2  setos  
   4 │         4.6         3.1          1.5         0.2  setos  
   5 │         5.0         3.6          1.4         0.2  setos ⋯
   6 │         5.4         3.9          1.7         0.4  setos  
   7 │         4.6         3.4          1.4         0.3  setos  
   8 │         5.0         3.4          1.5         0.2  setos  
   9 │         4.4         2.9          1.4         0.2  setos ⋯
  10 │         4.9         3.1          1.5         0.1  setos  
  11 │         5.4         3.7          1.5         0.2  setos  
  12 │         4.8         3.4          1.6         0.2  setos  
  13 │         4.8         3.0          1.4         0.1  setos ⋯
  14 │         4.3         3.0          1.1         0.1  setos  
  15 │         5.8         4.0          1.2         0.2  setos  
  ⋮  │      ⋮           ⋮            ⋮           ⋮           ⋮ ⋱
 136 │         7.7         3.0          6.1         2.3  virgi  
 137 │         6.3         3.4          5.6         2.4  virgi ⋯
 138 │         6.4         3.1          5.5         1.8  virgi  
 139 │         6.0         3.0          4.8         1.8  virgi  
 140 │         6.9         3.1          5.4         2.1  virgi  
 141 │         6.7         3.1          5.6         2.4  virgi ⋯
 142 │         6.9         3.1          5.1         2.3  virgi  
 143 │         5.8         2.7          5.1         1.9  virgi  
 144 │         6.8         3.2          5.9         2.3  virgi  
 145 │         6.7         3.3          5.7         2.5  virgi ⋯
 146 │         6.7         3.0          5.2         2.3  virgi  
 147 │         6.3         2.5          5.0         1.9  virgi  
 148 │         6.5         3.0          5.2         2.0  virgi  
 149 │         6.2         3.4          5.4         2.3  virgi ⋯
 150 │         5.9         3.0          5.1         1.8  virgi  
                                   1 column and 120 rows omitted

モデル作成

julia> using DecisionTree

julia> model = DecisionTreeClassifier(max_depth=2)
DecisionTreeClassifier
max_depth:                2
min_samples_leaf:         1
min_samples_split:        2
min_purity_increase:      0.0
pruning_purity_threshold: 1.0
n_subfeatures:            0
classes:                  nothing
root:                     nothing

モデルを作る。DecisionTreeClassifier()を通じて、決定木で使われるパラメーターを指定できる。

モデルフィッティング

julia> features = Matrix(iris[:, Not(:Species)]);

julia> labels = iris.Species;

julia> fit!(model, features, labels)        
DecisionTreeClassifier
max_depth:                2
min_samples_leaf:         1
min_samples_split:        2
min_purity_increase:      0.0
pruning_purity_threshold: 1.0
n_subfeatures:            0
classes:                  ["setosa", "versicolor", "virginica"]
root:                     Decision Tree     
Leaves: 3
Depth:  2

データを独立変数と従属変数に分けて、fit!()関数でモデルを学習させる。

パフォーマンス確認

julia> print_tree(model)
Feature 3 < 2.45 ?
├─ setosa : 50/50
└─ Feature 4 < 1.75 ?
    ├─ versicolor : 49/54
    └─ virginica : 45/46

学習が終わったモデルは、print_tree()関数を通じてどんな構造をしているか確認できる。

julia> sum(labels .== predict(model, features)) / length(labels)
0.96

簡単に正解率を確認した結果、96%程度でかなり良いことがわかった。

全コード

using RDatasets
iris = dataset("datasets", "iris")

using DecisionTree
model = DecisionTreeClassifier(max_depth=2)

features = Matrix(iris[:, Not(:Species)]);
labels = iris.Species;
fit!(model, features, labels)

print_tree(model)

sum(labels .== predict(model, features)) / length(labels)

環境

  • OS: Windows
  • julia: v1.9.0