logo

머신러닝에서 선형회귀모델의 최소제곱법 학습 📂머신러닝

머신러닝에서 선형회귀모델의 최소제곱법 학습

개요1

선형회귀모델의 학습 방법 중 하나인 최소제곱법least squares을 이용한 방법을 소개한다.

설명

데이터 집합을 X={xi}i=1NX = \left\{ \mathbf{x}_{i} \right\}_{i=1}^{N}, 레이블 집합을 Y={yi}i=1NY = \left\{ y_{i} \right\}_{i=1}^{N}라고 하자. 그리고 다음과 같은 선형회귀모델을 가정하자.

y^=j=0n1wjxj=wTx \hat{y} = \sum\limits_{j=0}^{n-1} w_{j}x_{j} = \mathbf{w}^{T} \mathbf{x}

이때 x=[x0xn1]T\mathbf{x} = \begin{bmatrix} x_{0} & \dots & x_{n-1} \end{bmatrix}^{T}, w=[w0wn1]T\mathbf{w} = \begin{bmatrix} w_{0} & \dots & w_{n-1} \end{bmatrix}^{T}이다. 최소제곱법이란 모델의 예측치 y^i=wTxi\hat{y}_{i} = \mathbf{w}^{T} \mathbf{x}_{i}와 실제 정답인 레이블 yiy_{i}사이의 오차의 제곱을 최소화하여 모델을 학습시키는 방법이다. 즉 손실 함수 JJ를 다음과 같이 MSE로 둔다.

J(w)=12i=1N(yiwTxi)2=12yXw2 J(\mathbf{w}) = \dfrac{1}{2} \sum\limits_{i=1}^{N} (y_{i} - \mathbf{w}^{T}\mathbf{x}_{i})^{2} = \dfrac{1}{2} \left\| \mathbf{y} - \mathbf{X}\mathbf{w} \right\|^{2}

이때 y=[y1yN]T\mathbf{y} = \begin{bmatrix} y_{1} & \dots & y_{N} \end{bmatrix}^{T}, X=[x1xN]T\mathbf{X} = \begin{bmatrix} \mathbf{x}_{1} & \dots & \mathbf{x}_{N} \end{bmatrix}^{T}이다. 그러면 Xw\mathbf{X}\mathbf{w}다음과 같다.

Xw=[x1TwxNTw] \mathbf{X}\mathbf{w} = \begin{bmatrix} \mathbf{x}_{1}^{T} \mathbf{w} \\ \vdots \\ \mathbf{x}_{N}^{T} \mathbf{w} \end{bmatrix}

앞에 상수 12\frac{1}{2}이 붙는 이유는 JJ를 미분할 때 22가 튀어나오는데 이를 약분해주기 위함이다. 어차피 JJ를 최소화하는 것이나 12J\frac{1}{2}J를 최소화하는 것이나 같다. 즉 목표는 주어진 XXYY에 대해서 다음과 같은 최소제곱해 wLS\mathbf{w}_{\text{LS}}를 찾는 것이다.

wLS=arg minw(J(w)=12yXw2) \mathbf{w}_{\text{LS}} = \argmin\limits_{\mathbf{w}} \left( J(\mathbf{w}) = \dfrac{1}{2} \left\| \mathbf{y} - \mathbf{X}\mathbf{w} \right\|^{2} \right)

경사하강법을 이용한 학습에서는 반복적인iterative approach 알고리즘을 통해 가중치를 최적해에 수렴시키는 것임에 반해, 본 글에서 소개할 방법은 해석적으로 한 방에 최적해를 찾는 방법one-shot approach을 소개한다.

학습

경우 1: 정확하고 유일한 해가 존재

데이터의 수 NN과 가중치의 차원 nn, 그리고 X\mathbf{X}랭크가 모두 같다고 하자. 즉 X\mathbf{X}풀 랭크full rank인 경우다.

N=n=rank(X) N = n = \rank (\mathbf{X})

위와 같은 상황에서, X\mathbf{X}N×NN \times N 정사각 행렬이고 랭크가 NN이므로 역행렬을 갖는다. 따라서 연립 방정식 y=Xw\mathbf{y} = \mathbf{X} \mathbf{w}는 다음과 같은 유일한 해를 갖는다.

w=X1y \mathbf{w} = \mathbf{X}^{-1}\mathbf{y}

위와 같은 가중치 w\mathbf{w}에 대해서, 손실함수의 함숫값은 00이다.

J(X1y)=12yXX1y2=0 J(\mathbf{X}^{-1}\mathbf{y}) = \dfrac{1}{2} \left\| \mathbf{y} - \mathbf{X}\mathbf{X}^{-1}\mathbf{y}\right\|^{2} = 0

이렇게되면 경사하강법이고 뭐고 한 방에 최적의 가중치를 찾을 수 있지만, 실제 상황에서 이러한 가정은 거의 성립하지 않는다. 즉 이상적인 경우라고 할 수 있다.


경우 2: 풀 랭크와 과대결정

데이터의 수 NN과 가중치의 차원 nn, 그리고 rank(X)\rank (\mathbf{X})에 대해서 다음과 같다고 하자. 실제로 많은 경우에서 이러한 가정을 만족한다.

N>n=rank(X) N \gt n = \rank (\mathbf{X})

선형시스템 Xw=y\mathbf{X} \mathbf{w} = \mathbf{y}과도결정일 때는, 최소제곱문제의 해가 무수히 많이 존재한다. 우선 JJ를 풀어서 계산해보면 벡터의 놈은,

J(w)=12yXX1y2=12(yXw)T(yXw)=12(yTwTXT)(yXw)=12(yTyyTXwwTXTy+wTXTXw) \begin{align*} J(\mathbf{w}) &= \dfrac{1}{2} \left\| \mathbf{y} - \mathbf{X}\mathbf{X}^{-1}\mathbf{y}\right\|^{2} = \dfrac{1}{2} \left( \mathbf{y} - \mathbf{X}\mathbf{w} \right)^{T} \left( \mathbf{y} - \mathbf{X}\mathbf{w} \right) \\ &= \dfrac{1}{2} \left( \mathbf{y}^{T} - \mathbf{w}^{T}\mathbf{X}^{T} \right) \left( \mathbf{y} - \mathbf{X}\mathbf{w} \right) \\ &= \dfrac{1}{2} \left( \mathbf{y}^{T}\mathbf{y} - \mathbf{y}^{T}\mathbf{X}\mathbf{w} - \mathbf{w}^{T}\mathbf{X}^{T}\mathbf{y} + \mathbf{w}^{T}\mathbf{X}^{T}\mathbf{X}\mathbf{w} \right) \end{align*}

이때 wTXTy\mathbf{w}^{T}\mathbf{X}^{T}\mathbf{y}는 스칼라이므로,

wTXTy=(wTXTy)T=yTXw \mathbf{w}^{T}\mathbf{X}^{T}\mathbf{y} = (\mathbf{w}^{T}\mathbf{X}^{T}\mathbf{y})^{T} = \mathbf{y}^{T}\mathbf{X}\mathbf{w}

따라서 다음을 얻는다.

J(w)=12(yTy2yTXw+wTXTXw) J(\mathbf{w}) = \dfrac{1}{2} \left( \mathbf{y}^{T}\mathbf{y} - 2\mathbf{y}^{T}\mathbf{X}\mathbf{w} + \mathbf{w}^{T}\mathbf{X}^{T}\mathbf{X}\mathbf{w} \right)

수식에서 벡터와 행렬 밖에 보이지 않으나 위의 값은 스칼라이다. 헷갈리지 않게 주의하자. 위의 값이 언제 최소가 되는지를 찾기위해 그래디언트를 계산하면,

J(w)=J(w)w=12w(yTy2yTXw+wTXTXw)=12(2XTy+2XTXw)=XTXwXTy \begin{align*} \nabla J(\mathbf{w}) = \dfrac{\partial J(\mathbf{w})}{\partial \mathbf{w}} &= \dfrac{1}{2} \dfrac{\partial }{\partial \mathbf{w}}\left( \mathbf{y}^{T}\mathbf{y} - 2\mathbf{y}^{T}\mathbf{X}\mathbf{w} + \mathbf{w}^{T}\mathbf{X}^{T}\mathbf{X}\mathbf{w} \right) \\ &= \dfrac{1}{2} \left( -2\mathbf{X}^{T}\mathbf{y} + 2\mathbf{X}^{T}\mathbf{X}\mathbf{w} \right) \\ &= \mathbf{X}^{T}\mathbf{X}\mathbf{w} - \mathbf{X}^{T}\mathbf{y} \end{align*}

미분 계산 결과는 여기를 참고하자. 그러면 다음의 식을 얻는다.

J=0    XTXw=XTy \nabla J = \mathbf{0} \implies \mathbf{X}^{T}\mathbf{X}\mathbf{w} = \mathbf{X}^{T}\mathbf{y}

여기서 식 XTXw=XTy\mathbf{X}^{T}\mathbf{X}\mathbf{w} = \mathbf{X}^{T}\mathbf{y}정규 방정식normal equation이라 한다. 풀 랭크인 경우에 이 문제는 유일한 솔루션을 갖는다. XTX\mathbf{X}^{T} \mathbf{X}n×nn\times n 행렬이고, X\mathbf{X}와 같은 랭크를 가진다.

rank(XTX)=n \rank (\mathbf{X}^{T} \mathbf{X}) = n

따라서 역행렬이 존재하고, 다음과 같은 최소제곱해를 얻는다.

w=(XTX)1XTy \mathbf{w} = (\mathbf{X}^{T}\mathbf{X})^{-1}\mathbf{X}^{T}\mathbf{y}


경우 3: 풀 랭크와 과소결정

데이터의 수 NN과 가중치의 차원 nn, 그리고 rank(X)\rank (\mathbf{X})에 대해서 다음과 같다고 하자.

n>N=rank(X) n \gt N = \rank (\mathbf{X})

그러면 연립 방정식 Xw=y\mathbf{X}\mathbf{w} = \mathbf{y}과소결정계이다. 이 때도 경우 2와 마찬가지로 유일한 최소제곱해가 존재하지 않고, 무수히 많다. 해가 무수히 많으므로, 우리는 그 중에서도 놈이 가장 작은 w\mathbf{w}를 구하는 것을 목표로 하자.

arg minw12w2 \argmin\limits_{\mathbf{w}} \dfrac{1}{2} \left\| \mathbf{w} \right\|^{2}

이 문제의 풀이를 라그랑주 승수법으로 접근하자. 그러면 우리가 가지고 있는 제약조건 yXw\mathbf{y} - \mathbf{X}\mathbf{w}에 승수 λ=[λ1λN]T\boldsymbol{\lambda} = \begin{bmatrix} \lambda_{1} & \dots & \lambda_{N} \end{bmatrix}^{T}을 곱한 것을 더하여 다음과 같은 함수를 최소화하는 문제가 된다.

L(w,λ)=12w2+λT(yXw) L(\mathbf{w}, \boldsymbol{\lambda}) = \dfrac{1}{2} \left\| \mathbf{w} \right\|^{2} + \boldsymbol{\lambda}^{T}(\mathbf{y} - \mathbf{X}\mathbf{w})

LL의 그래디언트는 다음과 같다. 미분 계산은 여기를 참고하라.

L=Lw=wXTλ \nabla L = \dfrac{\partial L}{\partial \mathbf{w}} = \mathbf{w} - \mathbf{X}^{T}\boldsymbol{\lambda}

따라서 미분해서 0\mathbf{0}이 되는 가중치는 w=XTλ\mathbf{w} = \mathbf{X}^{T}\boldsymbol{\lambda}이다. 그러면 다음의 식을 얻는다.

y=Xw=XXTλ \mathbf{y} = \mathbf{X}\mathbf{w} = \mathbf{X}\mathbf{X}^{T}\boldsymbol{\lambda}

따라서 λ=(XXT)1y\boldsymbol{\lambda} = (\mathbf{X}\mathbf{X}^{T})^{-1}\mathbf{y}이고, 우리가 원하는 솔루션은 다음과 같다.

w=XT(XXT)1y \mathbf{w} = \mathbf{X}^{T}(\mathbf{X}\mathbf{X}^{T})^{-1}\mathbf{y}


경우 4: 랭크 디피션트

랭크 디피션트라고 가정하자.

overdetermined: N>n>rank(X)underdetermined: n>N>rank(X) \text{overdetermined: }N \gt n \gt \rank(\mathbf{X}) \\ \text{underdetermined: } n \gt N \gt \rank(\mathbf{X})

이 때도 마찬가지로 최소제곱해는 유일하지않고 무수히 많이 존재한다. 또한 X\mathbf{X}가 풀 랭크가 아니기 때문에, XTX\mathbf{X}^{T}\mathbf{X}의 역행렬이 존재하지 않는다. 따라서 경우 2경우 3에서처럼 풀 수 없다.


  1. Christoper M. Bishop, Pattern Recognition annd Machine Learning (2006), p140-147 ↩︎