머신러닝에서 선형회귀모델의 경사하강법 학습
📂머신러닝머신러닝에서 선형회귀모델의 경사하강법 학습
개요
선형회귀모델의 학습 방법 중 하나인 경사하강법gradient descent을 이용한 방법을 소개한다.
설명
데이터 셋을 X={xi}i=1N, 레이블 셋을 Y={yi}i=1N라고 하자. 그리고 다음과 같은 선형회귀모델을 가정하자.
y^=j=0∑nwjxj=wTx
이때 x=[x0…xn]T, w=[w0…wn]T이다. 경사하강법이란 함수의 그래디언트가 의미하는 것이 함숫값이 가장 많이 증가하는 방향이라는 것을 이용한 학습법이다. 손실 함수 J를 다음과 같이 MSE로 두자.
J(w)=21i=1∑N(yi−wTxi)2
앞에 상수 21이 붙는 이유는 J를 미분할 때 2가 튀어나오는데 이를 약분해주기 위함이다. 어차피 J를 최소화하는 것이나 21J를 최소화하는 것이나 같다. 즉 목표는 주어진 X와 Y에 대해서 다음과 같은 최적해 w∗를 찾는 것이다.
w∗=wargmin(J(w)=21i=1∑N(yi−wTxi)2)
표기법
수학, 물리학, 공학 전반에서 스칼라 함수의 그래디언트를 다음과 같이 표기한다.
∇J(w)=[∂w0∂J(w)…∂wn∂J(w)]
머신러닝에서는 다음과 같은 노테이션도 많이 쓰인다.
∂w∂J(w)=[∂w0∂J(w)…∂wn∂J(w)]
다시말해 ∇J=∂w∂J이다. 본 글에서는 그래디언트 표기법으로 ∇J를 사용하겠다.
알고리즘
그래디언트의 성질에 의해, ∇J는 J가 가장 크게 증가하는 방향을 가리킨다. 그렇다면 반대로 −∇J는 J가 줄어드는 방향을 가리킨다. 우리의 목적은 J의 함숫값이 줄어드는 것이므로, 가중치 w를 −∇J 방향으로 이동시킨다. 다시말해 다음과 같이 업데이트한다.
w−α∇J→w
이때 α는 학습률learning rate이다. 실제로 ∇J를 계산해보면 다음과 같다. xi=[xi0…xin]라고 하자. ∂wj∂wTxi=∂wj∂∑kwkxik=xij이므로,
∇J=[21i=1∑N∂w1∂(yi−wTxi)2⋯21i=1∑N∂wn∂(yi−wTxi)2]=[i=1∑N(yi−wTxi)xi0⋯i=1∑N(yi−wTxi)xin]=i=1∑N(yi−wTxi)[xi0⋯xin]=i=1∑N(yi−wTxi)xi
따라서 구체적으로 w를 업데이트하는 수식은 다음과 같다.
w−αi∑(yi−wTxi)xi→w
이를 이를 위드로우-호프Widrow-Hoff 혹은 LMSLeast Mean Square 알고리즘이라고 한다. 위 식에서 괄호안은 오차 이므로, 오차가 크면 w가 많이 업데이트되고, 오차가 작으면 w가 조금 업데이트된다는 것을 알 수 있다.
업데이트 방법
가중치를 업데이트하는 방식에는 크게 두 종류가 있다. 두 방식 모두 적절한 러닝 레이트 α에 대해서, 최적해로 수렴함이 알려져있다.
배치 학습
배치 학습batch learning은 전체 데이터 셋 X의 오차에 대해서 한꺼번에 가중치를 수정하는 것을 말한다. 즉 위에서 설명한 것과 같다.
Repeat until convergence: w−αi∑(yi−wTxi)xi→w
온라인 학습
온라인 학습online learning은 각각의 데이터 xi의 오차에 대해서 가중치를 수정하는 것을 말한다.
Repeat until convergence: For i=1 to N:w−α(yi−wTxi)xi→w
같이보기