logo

머신러닝에서 선형회귀모델의 경사하강법 학습 📂머신러닝

머신러닝에서 선형회귀모델의 경사하강법 학습

개요1

선형회귀모델의 학습 방법 중 하나인 경사하강법gradient descent을 이용한 방법을 소개한다.

설명

데이터 셋을 X={xi}i=1NX = \left\{ \mathbf{x}_{i} \right\}_{i=1}^{N}, 레이블 셋을 Y={yi}i=1NY = \left\{ y_{i} \right\}_{i=1}^{N}라고 하자. 그리고 다음과 같은 선형회귀모델을 가정하자.

y^=j=0nwjxj=wTx \hat{y} = \sum\limits_{j=0}^{n} w_{j}x_{j} = \mathbf{w}^{T} \mathbf{x}

이때 x=[x0xn]T\mathbf{x} = \begin{bmatrix} x_{0} & \dots & x_{n} \end{bmatrix}^{T}, w=[w0wn]T\mathbf{w} = \begin{bmatrix} w_{0} & \dots & w_{n} \end{bmatrix}^{T}이다. 경사하강법이란 함수의 그래디언트가 의미하는 것이 함숫값이 가장 많이 증가하는 방향이라는 것을 이용한 학습법이다. 손실 함수 JJ를 다음과 같이 MSE로 두자.

J(w)=12i=1N(yiwTxi)2 J(\mathbf{w}) = \dfrac{1}{2} \sum\limits_{i=1}^{N} (y_{i} - \mathbf{w}^{T}\mathbf{x}_{i})^{2}

앞에 상수 12\frac{1}{2}이 붙는 이유는 JJ를 미분할 때 22가 튀어나오는데 이를 약분해주기 위함이다. 어차피 JJ를 최소화하는 것이나 12J\frac{1}{2}J를 최소화하는 것이나 같다. 즉 목표는 주어진 XXYY에 대해서 다음과 같은 최적해 w\mathbf{w}_{\ast}를 찾는 것이다.

w=arg minw(J(w)=12i=1N(yiwTxi)2) \mathbf{w}_{\ast} = \argmin\limits_{\mathbf{w}} \left( J(\mathbf{w}) = \dfrac{1}{2} \sum\limits_{i=1}^{N} (y_{i} - \mathbf{w}^{T}\mathbf{x}_{i})^{2} \right)

표기법

수학, 물리학, 공학 전반에서 스칼라 함수의 그래디언트를 다음과 같이 표기한다.

J(w)=[J(w)w0J(w)wn] \nabla J (\mathbf{w}) = \begin{bmatrix} \dfrac{\partial J(\mathbf{w})}{\partial w_{0}} & \dots & \dfrac{\partial J(\mathbf{w})}{\partial w_{n}} \end{bmatrix}

머신러닝에서는 다음과 같은 노테이션도 많이 쓰인다.

J(w)w=[J(w)w0J(w)wn] \dfrac{\partial J(\mathbf{w})}{\partial \mathbf{w}} = \begin{bmatrix} \dfrac{\partial J(\mathbf{w})}{\partial w_{0}} & \dots & \dfrac{\partial J(\mathbf{w})}{\partial w_{n}} \end{bmatrix}

다시말해 J=Jw\nabla J = \dfrac{\partial J}{\partial \mathbf{w}}이다. 본 글에서는 그래디언트 표기법으로 J\nabla J를 사용하겠다.

알고리즘

그래디언트의 성질에 의해, J\nabla JJJ가 가장 크게 증가하는 방향을 가리킨다. 그렇다면 반대로 J-\nabla JJJ가 줄어드는 방향을 가리킨다. 우리의 목적은 JJ의 함숫값이 줄어드는 것이므로, 가중치 w\mathbf{w}J-\nabla J 방향으로 이동시킨다. 다시말해 다음과 같이 업데이트한다.

wαJw \mathbf{w} - \alpha \nabla J \to \mathbf{w}

이때 α\alpha학습률learning rate이다. 실제로 J\nabla J를 계산해보면 다음과 같다. xi=[xi0xin]\mathbf{x}_{i} = \begin{bmatrix} x_{i0} & \dots & x_{in} \end{bmatrix}라고 하자. wTxiwj=kwkxikwj=xij\dfrac{\partial \mathbf{w}^{T}\mathbf{x}_{i}}{\partial w_{j}} = \dfrac{\partial \sum_{k}w_{k}x_{ik}}{\partial w_{j}} = x_{ij}이므로,

J=[12i=1Nw1(yiwTxi)212i=1Nwn(yiwTxi)2]=[i=1N(yiwTxi)xi0i=1N(yiwTxi)xin]=i=1N(yiwTxi)[xi0xin]=i=1N(yiwTxi)xi \begin{align*} \nabla J &= \begin{bmatrix} \dfrac{1}{2}\sum\limits_{i=1}^{N} \dfrac{\partial}{\partial w_{1}}(y_{i} - \mathbf{w}^{T}\mathbf{x}_{i})^{2} & \cdots & \dfrac{1}{2}\sum\limits_{i=1}^{N} \dfrac{\partial}{\partial w_{n}}(y_{i} - \mathbf{w}^{T}\mathbf{x}_{i})^{2} \end{bmatrix} \\ &= \begin{bmatrix} \sum\limits_{i=1}^{N} (y_{i} - \mathbf{w}^{T}\mathbf{x}_{i})x_{i0} & \cdots & \sum\limits_{i=1}^{N} (y_{i} - \mathbf{w}^{T}\mathbf{x}_{i})x_{in} \end{bmatrix} \\ &= \sum\limits_{i=1}^{N} (y_{i} - \mathbf{w}^{T}\mathbf{x}_{i}) \begin{bmatrix} x_{i0} & \cdots & x_{in} \end{bmatrix} \\ &= \sum\limits_{i=1}^{N} (y_{i} - \mathbf{w}^{T}\mathbf{x}_{i}) \mathbf{x}_{i} \end{align*}

따라서 구체적으로 w\mathbf{w}를 업데이트하는 수식은 다음과 같다.

wαi(yiwTxi)xiw \mathbf{w} - \alpha \sum\limits_{i}(y_{i} - \mathbf{w}^{T}\mathbf{x}_{i}) \mathbf{x}_{i} \to \mathbf{w}

이를 이를 위드로우-호프Widrow-Hoff 혹은 LMSLeast Mean Square 알고리즘이라고 한다. 위 식에서 괄호안은 오차 이므로, 오차가 크면 w\mathbf{w}가 많이 업데이트되고, 오차가 작으면 w\mathbf{w}가 조금 업데이트된다는 것을 알 수 있다.

업데이트 방법2

가중치를 업데이트하는 방식에는 크게 두 종류가 있다. 두 방식 모두 적절한 러닝 레이트 α\alpha에 대해서, 최적해로 수렴함이 알려져있다.

배치 학습

배치 학습batch learning은 전체 데이터 셋 XX의 오차에 대해서 한꺼번에 가중치를 수정하는 것을 말한다. 즉 위에서 설명한 것과 같다.

Repeat until convergence: wαi(yiwTxi)xiw \begin{align*} &\text{Repeat until convergence: }\\ &\quad \mathbf{w} - \alpha \sum\limits_{i}(y_{i} - \mathbf{w}^{T}\mathbf{x}_{i}) \mathbf{x}_{i} \to \mathbf{w} \end{align*}

온라인 학습

온라인 학습online learning은 각각의 데이터 xi\mathbf{x}_{i}의 오차에 대해서 가중치를 수정하는 것을 말한다.

Repeat until convergence: For i=1 to N:wα(yiwTxi)xiw \begin{align*} &\text{Repeat until convergence: } \\ &\quad \text{For } i = 1 \text{ to } N: \\ &\qquad \mathbf{w} - \alpha (y_{i} - \mathbf{w}^{T}\mathbf{x}_{i}) \mathbf{x}_{i} \to \mathbf{w} \end{align*}

같이보기


  1. Simon Haykin, Neural Networks and Learning Machines (3rd Edition, 2009), p91-96 ↩︎

  2. Simon Haykin, Neural Networks and Learning Machines (3rd Edition, 2009), p127 ↩︎