줄리아에서 의사결정나무 사용하는 법
개요
줄리아에서 의사결정나무decision Tree를 구현한 DecisionTree.jl
패키지를 소개한다1.
코드
예시로는 대표적인 R 내장데이터인 iris
데이터를 사용한다. 우리의 목표는 네가지 변수 SepalLength
, SepalWidth
, PetalLength
, PetalWidth
를 사용해 Species
를 예측하는 의사결정나무를 만들고 퍼포먼스를 평가하는 것이다.
julia> iris = dataset("datasets", "iris")
150×5 DataFrame
Row │ SepalLength SepalWidth PetalLength PetalWidth Speci ⋯
│ Float64 Float64 Float64 Float64 Cat… ⋯
─────┼──────────────────────────────────────────────────────────
1 │ 5.1 3.5 1.4 0.2 setos ⋯
2 │ 4.9 3.0 1.4 0.2 setos
3 │ 4.7 3.2 1.3 0.2 setos
4 │ 4.6 3.1 1.5 0.2 setos
5 │ 5.0 3.6 1.4 0.2 setos ⋯
6 │ 5.4 3.9 1.7 0.4 setos
7 │ 4.6 3.4 1.4 0.3 setos
8 │ 5.0 3.4 1.5 0.2 setos
9 │ 4.4 2.9 1.4 0.2 setos ⋯
10 │ 4.9 3.1 1.5 0.1 setos
11 │ 5.4 3.7 1.5 0.2 setos
12 │ 4.8 3.4 1.6 0.2 setos
13 │ 4.8 3.0 1.4 0.1 setos ⋯
14 │ 4.3 3.0 1.1 0.1 setos
15 │ 5.8 4.0 1.2 0.2 setos
⋮ │ ⋮ ⋮ ⋮ ⋮ ⋮ ⋱
136 │ 7.7 3.0 6.1 2.3 virgi
137 │ 6.3 3.4 5.6 2.4 virgi ⋯
138 │ 6.4 3.1 5.5 1.8 virgi
139 │ 6.0 3.0 4.8 1.8 virgi
140 │ 6.9 3.1 5.4 2.1 virgi
141 │ 6.7 3.1 5.6 2.4 virgi ⋯
142 │ 6.9 3.1 5.1 2.3 virgi
143 │ 5.8 2.7 5.1 1.9 virgi
144 │ 6.8 3.2 5.9 2.3 virgi
145 │ 6.7 3.3 5.7 2.5 virgi ⋯
146 │ 6.7 3.0 5.2 2.3 virgi
147 │ 6.3 2.5 5.0 1.9 virgi
148 │ 6.5 3.0 5.2 2.0 virgi
149 │ 6.2 3.4 5.4 2.3 virgi ⋯
150 │ 5.9 3.0 5.1 1.8 virgi
1 column and 120 rows omitted
모델 생성
julia> using DecisionTree
julia> model = DecisionTreeClassifier(max_depth=2)
DecisionTreeClassifier
max_depth: 2
min_samples_leaf: 1
min_samples_split: 2
min_purity_increase: 0.0
pruning_purity_threshold: 1.0
n_subfeatures: 0
classes: nothing
root: nothing
모델을 생성한다. DecisionTreeClassifier()
를 통해 의사결정나무에서 사용될 파라미터를 줄 수 있다.
모델 피팅
julia> features = Matrix(iris[:, Not(:Species)]);
julia> labels = iris.Species;
julia> fit!(model, features, labels)
DecisionTreeClassifier
max_depth: 2
min_samples_leaf: 1
min_samples_split: 2
min_purity_increase: 0.0
pruning_purity_threshold: 1.0
n_subfeatures: 0
classes: ["setosa", "versicolor", "virginica"]
root: Decision Tree
Leaves: 3
Depth: 2
데이터를 독립변수와 종속변수로 나누어서 fit!()
함수로 모델을 학습시킨다.
퍼포먼스 확인
julia> print_tree(model)
Feature 3 < 2.45 ?
├─ setosa : 50/50
└─ Feature 4 < 1.75 ?
├─ versicolor : 49/54
└─ virginica : 45/46
학습이 끝난 모델은 print_tree()
함수를 통해 어떤 구조를 가지고 있는지 확인할 수 있다.
julia> sum(labels .== predict(model, features)) / length(labels)
0.96
간단히 정분류율를 확인한 결과 96% 정도로 준수한 것을 볼 수 있었다.
전체코드
using RDatasets
iris = dataset("datasets", "iris")
using DecisionTree
model = DecisionTreeClassifier(max_depth=2)
features = Matrix(iris[:, Not(:Species)]);
labels = iris.Species;
fit!(model, features, labels)
print_tree(model)
sum(labels .== predict(model, features)) / length(labels)
환경
- OS: Windows
- julia: v1.9.0