logo

줄리아에서 의사결정나무 사용하는 법 📂줄리아

줄리아에서 의사결정나무 사용하는 법

개요

줄리아에서 의사결정나무decision Tree를 구현한 DecisionTree.jl 패키지를 소개한다1.

코드

예시로는 대표적인 R 내장데이터iris 데이터를 사용한다. 우리의 목표는 네가지 변수 SepalLength, SepalWidth, PetalLength, PetalWidth를 사용해 Species를 예측하는 의사결정나무를 만들고 퍼포먼스를 평가하는 것이다.

julia> iris = dataset("datasets", "iris")
150×5 DataFrame
 Row │ SepalLength  SepalWidth  PetalLength  PetalWidth  Speci ⋯
     │ Float64      Float64     Float64      Float64     Cat…  ⋯
─────┼──────────────────────────────────────────────────────────
   1 │         5.1         3.5          1.4         0.2  setos ⋯
   2 │         4.9         3.0          1.4         0.2  setos  
   3 │         4.7         3.2          1.3         0.2  setos  
   4 │         4.6         3.1          1.5         0.2  setos  
   5 │         5.0         3.6          1.4         0.2  setos ⋯
   6 │         5.4         3.9          1.7         0.4  setos  
   7 │         4.6         3.4          1.4         0.3  setos  
   8 │         5.0         3.4          1.5         0.2  setos  
   9 │         4.4         2.9          1.4         0.2  setos ⋯
  10 │         4.9         3.1          1.5         0.1  setos  
  11 │         5.4         3.7          1.5         0.2  setos  
  12 │         4.8         3.4          1.6         0.2  setos  
  13 │         4.8         3.0          1.4         0.1  setos ⋯
  14 │         4.3         3.0          1.1         0.1  setos  
  15 │         5.8         4.0          1.2         0.2  setos  
  ⋮  │      ⋮           ⋮            ⋮           ⋮           ⋮ ⋱
 136 │         7.7         3.0          6.1         2.3  virgi  
 137 │         6.3         3.4          5.6         2.4  virgi ⋯
 138 │         6.4         3.1          5.5         1.8  virgi  
 139 │         6.0         3.0          4.8         1.8  virgi  
 140 │         6.9         3.1          5.4         2.1  virgi  
 141 │         6.7         3.1          5.6         2.4  virgi ⋯
 142 │         6.9         3.1          5.1         2.3  virgi  
 143 │         5.8         2.7          5.1         1.9  virgi  
 144 │         6.8         3.2          5.9         2.3  virgi  
 145 │         6.7         3.3          5.7         2.5  virgi ⋯
 146 │         6.7         3.0          5.2         2.3  virgi  
 147 │         6.3         2.5          5.0         1.9  virgi  
 148 │         6.5         3.0          5.2         2.0  virgi  
 149 │         6.2         3.4          5.4         2.3  virgi ⋯
 150 │         5.9         3.0          5.1         1.8  virgi  
                                   1 column and 120 rows omitted

모델 생성

julia> using DecisionTree

julia> model = DecisionTreeClassifier(max_depth=2)
DecisionTreeClassifier
max_depth:                2
min_samples_leaf:         1
min_samples_split:        2
min_purity_increase:      0.0
pruning_purity_threshold: 1.0
n_subfeatures:            0
classes:                  nothing
root:                     nothing

모델을 생성한다. DecisionTreeClassifier()를 통해 의사결정나무에서 사용될 파라미터를 줄 수 있다.

모델 피팅

julia> features = Matrix(iris[:, Not(:Species)]);

julia> labels = iris.Species;

julia> fit!(model, features, labels)        
DecisionTreeClassifier
max_depth:                2
min_samples_leaf:         1
min_samples_split:        2
min_purity_increase:      0.0
pruning_purity_threshold: 1.0
n_subfeatures:            0
classes:                  ["setosa", "versicolor", "virginica"]
root:                     Decision Tree     
Leaves: 3
Depth:  2

데이터를 독립변수와 종속변수로 나누어서 fit!() 함수로 모델을 학습시킨다.

퍼포먼스 확인

julia> print_tree(model)
Feature 3 < 2.45 ?
├─ setosa : 50/50
└─ Feature 4 < 1.75 ?
    ├─ versicolor : 49/54
    └─ virginica : 45/46

학습이 끝난 모델은 print_tree() 함수를 통해 어떤 구조를 가지고 있는지 확인할 수 있다.

julia> sum(labels .== predict(model, features)) / length(labels)
0.96

간단히 정분류율를 확인한 결과 96% 정도로 준수한 것을 볼 수 있었다.

전체코드

using RDatasets
iris = dataset("datasets", "iris")

using DecisionTree
model = DecisionTreeClassifier(max_depth=2)

features = Matrix(iris[:, Not(:Species)]);
labels = iris.Species;
fit!(model, features, labels)

print_tree(model)

sum(labels .== predict(model, features)) / length(labels)

환경

  • OS: Windows
  • julia: v1.9.0