머신러닝에서 선형회귀모델의 최소제곱법 학습
📂머신러닝머신러닝에서 선형회귀모델의 최소제곱법 학습
개요
선형회귀모델의 학습 방법 중 하나인 최소제곱법least squares을 이용한 방법을 소개한다.
설명
데이터 집합을 X={xi}i=1N, 레이블 집합을 Y={yi}i=1N라고 하자. 그리고 다음과 같은 선형회귀모델을 가정하자.
y^=j=0∑n−1wjxj=wTx
이때 x=[x0…xn−1]T, w=[w0…wn−1]T이다. 최소제곱법이란 모델의 예측치 y^i=wTxi와 실제 정답인 레이블 yi사이의 오차의 제곱을 최소화하여 모델을 학습시키는 방법이다. 즉 손실 함수 J를 다음과 같이 MSE로 둔다.
J(w)=21i=1∑N(yi−wTxi)2=21∥y−Xw∥2
이때 y=[y1…yN]T, X=[x1…xN]T이다. 그러면 Xw는 다음과 같다.
Xw=x1Tw⋮xNTw
앞에 상수 21이 붙는 이유는 J를 미분할 때 2가 튀어나오는데 이를 약분해주기 위함이다. 어차피 J를 최소화하는 것이나 21J를 최소화하는 것이나 같다. 즉 목표는 주어진 X와 Y에 대해서 다음과 같은 최소제곱해 wLS를 찾는 것이다.
wLS=wargmin(J(w)=21∥y−Xw∥2)
경사하강법을 이용한 학습에서는 반복적인iterative approach 알고리즘을 통해 가중치를 최적해에 수렴시키는 것임에 반해, 본 글에서 소개할 방법은 해석적으로 한 방에 최적해를 찾는 방법one-shot approach을 소개한다.
학습
경우 1: 정확하고 유일한 해가 존재
데이터의 수 N과 가중치의 차원 n, 그리고 X의 랭크가 모두 같다고 하자. 즉 X가 풀 랭크full rank인 경우다.
N=n=rank(X)
위와 같은 상황에서, X는 N×N 정사각 행렬이고 랭크가 N이므로 역행렬을 갖는다. 따라서 연립 방정식 y=Xw는 다음과 같은 유일한 해를 갖는다.
w=X−1y
위와 같은 가중치 w에 대해서, 손실함수의 함숫값은 0이다.
J(X−1y)=21y−XX−1y2=0
이렇게되면 경사하강법이고 뭐고 한 방에 최적의 가중치를 찾을 수 있지만, 실제 상황에서 이러한 가정은 거의 성립하지 않는다. 즉 이상적인 경우라고 할 수 있다.
경우 2: 풀 랭크와 과대결정
데이터의 수 N과 가중치의 차원 n, 그리고 rank(X)에 대해서 다음과 같다고 하자. 실제로 많은 경우에서 이러한 가정을 만족한다.
N>n=rank(X)
선형시스템 Xw=y가 과도결정일 때는, 최소제곱문제의 해가 무수히 많이 존재한다. 우선 J를 풀어서 계산해보면 벡터의 놈은,
J(w)=21y−XX−1y2=21(y−Xw)T(y−Xw)=21(yT−wTXT)(y−Xw)=21(yTy−yTXw−wTXTy+wTXTXw)
이때 wTXTy는 스칼라이므로,
wTXTy=(wTXTy)T=yTXw
따라서 다음을 얻는다.
J(w)=21(yTy−2yTXw+wTXTXw)
수식에서 벡터와 행렬 밖에 보이지 않으나 위의 값은 스칼라이다. 헷갈리지 않게 주의하자. 위의 값이 언제 최소가 되는지를 찾기위해 그래디언트를 계산하면,
∇J(w)=∂w∂J(w)=21∂w∂(yTy−2yTXw+wTXTXw)=21(−2XTy+2XTXw)=XTXw−XTy
미분 계산 결과는 여기를 참고하자. 그러면 다음의 식을 얻는다.
∇J=0⟹XTXw=XTy
여기서 식 XTXw=XTy를 정규 방정식normal equation이라 한다. 풀 랭크인 경우에 이 문제는 유일한 솔루션을 갖는다. XTX는 n×n 행렬이고, X와 같은 랭크를 가진다.
rank(XTX)=n
따라서 역행렬이 존재하고, 다음과 같은 최소제곱해를 얻는다.
w=(XTX)−1XTy
경우 3: 풀 랭크와 과소결정
데이터의 수 N과 가중치의 차원 n, 그리고 rank(X)에 대해서 다음과 같다고 하자.
n>N=rank(X)
그러면 연립 방정식 Xw=y는 과소결정계이다. 이 때도 경우 2와 마찬가지로 유일한 최소제곱해가 존재하지 않고, 무수히 많다. 해가 무수히 많으므로, 우리는 그 중에서도 놈이 가장 작은 w를 구하는 것을 목표로 하자.
wargmin21∥w∥2
이 문제의 풀이를 라그랑주 승수법으로 접근하자. 그러면 우리가 가지고 있는 제약조건 y−Xw에 승수 λ=[λ1…λN]T을 곱한 것을 더하여 다음과 같은 함수를 최소화하는 문제가 된다.
L(w,λ)=21∥w∥2+λT(y−Xw)
L의 그래디언트는 다음과 같다. 미분 계산은 여기를 참고하라.
∇L=∂w∂L=w−XTλ
따라서 미분해서 0이 되는 가중치는 w=XTλ이다. 그러면 다음의 식을 얻는다.
y=Xw=XXTλ
따라서 λ=(XXT)−1y이고, 우리가 원하는 솔루션은 다음과 같다.
w=XT(XXT)−1y
경우 4: 랭크 디피션트
랭크 디피션트라고 가정하자.
overdetermined: N>n>rank(X)underdetermined: n>N>rank(X)
이 때도 마찬가지로 최소제곱해는 유일하지않고 무수히 많이 존재한다. 또한 X가 풀 랭크가 아니기 때문에, XTX의 역행렬이 존재하지 않는다. 따라서 경우 2와 경우 3에서처럼 풀 수 없다.