logo

ROCカーブを使用して最適なカットオフを見つける方法 📂機械学習

ROCカーブを使用して最適なカットオフを見つける方法

概要

ROC曲線を描くと、トレーニングデータで得たモデルがテストデータをどれくらいうまく説明しているか、一目でわかるから便利だ。しかし、この曲線はすべてのカットオフに対する分類率を計算して結んだものなので、結局「どのカットオフで0と1を分類するか」はわからない。これを解明するために、交差検証の方法論を応用してみよう。

20190208\_140726.png

検証データ

トレーニングデータで0と1を最もうまく分類する最適なカットオフを見つけても、それはトレーニングデータをうまく説明するカットオフにすぎない。当然、テストデータで分類率を計算してみても、その最適なカットオフはテストデータをうまく説明するカットオフに過ぎない。それゆえ、バリデーションデータvalidation Dataというものを別に作って、どのデータにも偏らないカットオフを探す。

20190208\_140734.png

  • ステップ 1.
    トレーニングデータを使ってモデルを学習させる。
  • ステップ 2.
    バリデーションデータに適用して最適なカットオフを探す。
  • ステップ 3.
    トレーニングデータで得たモデルをバリデーションデータで得たカットオフで分類し、テストデータでどれだけパフォーマンスが高いかを確認する。何種類かのモデルの候補があるなら、このパフォーマンスが最も高いモデルを最終モデルとして選ぶ。

この時、ステップ1で得たモデルは、ステップ2で得たカットオフを含んで初めてちゃんとしたモデルになる。

実践

(ROC曲線の描き方に続いて)

99C508455C4FA79537.png

上の図は、前のポストでデータをトレーニングデータとテストデータに分けた時のROC曲線だ。今度は最適なカットオフを求めるために、データを三分割する必要がある。この分割の比率はデータによって異なるが、特に問題がなければ3:1:1程度が無難だ。

DATANUM<-nrow(DATA)
numeric(DATANUM)
DATANUM*c(0.6,0.2,0.2)
  
slicing<-sample(1:DATANUM)
slicing[slicing>(DATANUM*0.8)]<-(-1)
slicing[slicing>(DATANUM*0.6)]<-(-2)
slicing[slicing>0]<-(-3)
slicing<-slicing+4
 
train<-DATA[slicing==1,]; head(train)
valid<-DATA[slicing==2,]; head(valid)
test<-DATA[slicing==3,]; head(test)

データの前処理を終えて上記のコードを実行すると、次のようにトレーニング、バリデーション、テストの三つのデータに分けられる。

20190208\_143314.png

ステップ 1.

トレーニングデータを学習させてモデルを作る。

out0<-glm(vote~.,family=binomial(),data=train); summary(out0)
vif(out0)
 
out1<-step(out0, direction="both"); summary(out1)
qchisq(0.95,df=1454)
vif(out1)

上記のコードを実行すると、次のように変数選択手順を経てモデルを作り、適合度検定多重共線性をチェックする。

20190208\_144026.png モデルを見ると、特に問題になるところはない。

ステップ 2.

バリデーションデータに適用して最適なカットオフを探す。

p <- predict(out1, newdata=valid, type="response")
 
pr <- prediction(p, valid$vote)
prf <- performance(pr, measure = "tpr", x.measure = "fpr")
win.graph(); plot(prf, main='ROC of Validation Data')

위의 코드를 실행하면 마치 밸리데이터 데이터를 테스트 데이터처럼 취급해서 다음과 같이 ROC 커브를 그려준다.

vp.png 최적의 컷오프는 데이터와 목적에 따라 다르게 결정될 수 있지만, 별도의 주안점이 없다면 왼쪽 위의 $(0,1)$ 에서 가장 가까운 점을 찾아 그 점의 컷오프를 최적의 컷오프로 삼는다. 거리를 계산해야하기 때문에 코드는 다소 복잡하다.

optid<-(1:length(prf@y.values[[1]][-1]))[((prf@x.values[[1]][-1])^2 + (1-prf@y.values[[1]][-11])^2)
                                         ==min((prf@x.values[[1]][-1])^2 + (1-prf@y.values[[1]][-1])^2)]
points(prf@x.values[[1]][-1][optid],prf@y.values[[1]][-1][optid], col='red', pch=15)
optcut<-prf@alpha.values[[1]][-1][optid]; optcut

위의 코드를 실행하면 위의 설명대로 $(0,1)$ 에서 가장 가까운 점을 빨간색으로 표시해주고 그 지점의 컷오프를 출력해준다.

20190208\_144843.png vp2.png

코드가 많이 복잡하지만 이해하려고 노력할 필요는 없다. 복잡한 것과 어려운 것은 다른 일이다. 위의 코드는 길기만하지 개념적으로는 전혀 어렵지 않다. 그냥 곡선 위의 모든 점과 $(0,1)$ 사이의 거리를 잰 후 그 거리가 가장 짧은 점을 선택한 것 뿐이다. 그 점에서의 `$alpha値を参照すれば、カットオフがわかる。こうして得られたカットオフは、データを最も適切にうまく分類するカットオフとして受け入れられる。(再度強調するが、これは絶対的な基準ではない。利用者の目的によって、「最適なカットオフ」そのものが完全に新しく定義されうる。)

この例で得られた最適なカットオフは$0.4564142$であり、これより高ければ1と判断し、低ければ0と判断するのが無難と受け取っても構わない。(これが三度目の強調だが、受け取ってもいいとされることだが、最もいいというわけではない。適切な解釈を出すことは、完全に分析者にかかっている。)

ステップ 3.

テストデータでどれだけうまく当てはまるかを確認する。

p <- predict(out1, newdata=test, type="response"); head(p,48)
table(test$vote, p>optcut)

위의 코드를 실행하면 테스트 데이터에서 확률을 계산해주고 최적 컷오프에 따른 오류행렬을 출력해준다.

20190208\_150337.png

위 오류행렬의 정분류율은 약 $81 \%$ 로써 꽤 쓸만하고, 분석자가 만족할만하다면 최종모형으로 받아들여봄직하다.눈치챘겠지만 엄밀히 말해 최적 컷오프를 구하는데 있어서 꼭 ROC 곡선을 그릴 필요는 없다. 어차피 계산을 위한 데이터는 데이터 프레임으로써 다 구해놨기 때문에 코드만 잘 돌려서 값만 얻어내도 전혀 상관 없다.

코드

아래는 예제 코드다.

install.packages("car")
install.packages("ResourceSelection")
install.packages("ROCR")
 
library(car)
library(ResourceSelection)
library(ROCR)
 
set.seed(150421)
 
?Chile
str(Chile)
nrow(Chile)
 
DATA<-na.omit(Chile)
DATA$vote[DATA$vote!='Y']<-'N'
DATA$vote<-factor(DATA$vote)
 
DATANUM<-nrow(DATA)
numeric(DATANUM)
DATANUM*c(0.6,0.2,0.2)
  
slicing<-sample(1:DATANUM)
slicing[slicing>(DATANUM*0.8)]<-(-1)
slicing[slicing>(DATANUM*0.6)]<-(-2)
slicing[slicing>0]<-(-3)
slicing<-slicing+4
 
train<-DATA[slicing==1,]; head(train)
valid<-DATA[slicing==2,]; head(valid)
test<-DATA[slicing==3,]; head(test)
 
out0<-glm(vote~.,family=binomial(),data=train); summary(out0)
vif(out0)
 
out1<-step(out0, direction="both"); summary(out1)
qchisq(0.95,df=1454)
vif(out1)
 
p <- predict(out1, newdata=valid, type="response")
 
pr <- prediction(p, valid$vote)
prf <- performance(pr, measure = "tpr", x.measure = "fpr")
win.graph(); plot(prf, main='ROC of Validation Data')
 
optid<-(1:length(prf@y.values[[1]][-1]))[((prf@x.values[[1]][-1])^2 + (1-prf@y.values[[1]][-11])^2)
                                         ==min((prf@x.values[[1]][-1])^2 + (1-prf@y.values[[1]][-1])^2)]
points(prf@x.values[[1]][-1][optid],prf@y.values[[1]][-1][optid], col='red', pch=15)
optcut<-prf@alpha.values[[1]][-1][optid]; optcut
 
p <- predict(out1, newdata=test, type="response"); head(p,48)
table(test$vote, p>optcut)