logo

머신러닝에서 선형회귀모델의 최소제곱법 학습 📂機械学習

머신러닝에서 선형회귀모델의 최소제곱법 학습

概要1

線形回帰モデルの学習方法の一つである最小二乗法least squaresを用いた方法を紹介する。

説明

データ集合をX={xi}i=1NX = \left\{ \mathbf{x}_{i} \right\}_{i=1}^{N}、ラベル集合をY={yi}i=1NY = \left\{ y_{i} \right\}_{i=1}^{N}とする。そして次のような線形回帰モデルを仮定しよう。

y^=j=0n1wjxj=wTx \hat{y} = \sum\limits_{j=0}^{n-1} w_{j}x_{j} = \mathbf{w}^{T} \mathbf{x}

このときx=[x0xn1]T\mathbf{x} = \begin{bmatrix} x_{0} & \dots & x_{n-1} \end{bmatrix}^{T}w=[w0wn1]T\mathbf{w} = \begin{bmatrix} w_{0} & \dots & w_{n-1} \end{bmatrix}^{T}である。最小二乗法とは、モデルの予測値y^i=wTxi\hat{y}_{i} = \mathbf{w}^{T} \mathbf{x}_{i}と実際の正解であるラベルyiy_{i}との誤差の二乗を最小化することによってモデルを学習させる方法だ。すなわち損失関数JJを次のようにMSEとする。

J(w)=12i=1N(yiwTxi)2=12yXw2 J(\mathbf{w}) = \dfrac{1}{2} \sum\limits_{i=1}^{N} (y_{i} - \mathbf{w}^{T}\mathbf{x}_{i})^{2} = \dfrac{1}{2} \left\| \mathbf{y} - \mathbf{X}\mathbf{w} \right\|^{2}

このときy=[y1yN]T\mathbf{y} = \begin{bmatrix} y_{1} & \dots & y_{N} \end{bmatrix}^{T}X=[x1xN]T\mathbf{X} = \begin{bmatrix} \mathbf{x}_{1} & \dots & \mathbf{x}_{N} \end{bmatrix}^{T}である。するとXw\mathbf{X}\mathbf{w}次のようになる

Xw=[x1TwxNTw] \mathbf{X}\mathbf{w} = \begin{bmatrix} \mathbf{x}_{1}^{T} \mathbf{w} \\ \vdots \\ \mathbf{x}_{N}^{T} \mathbf{w} \end{bmatrix}

先頭の定数12\frac{1}{2}がつく理由は、JJを微分する際に22が現れるが、これを約分するためだ。どうせJJを最小化するのも、12J\frac{1}{2}Jを最小化するのも同じだからだ。つまり、目的は与えられたXXYYについて、次のような最小二乗解wLS\mathbf{w}_{\text{LS}}を見つけることだ。

wLS=arg minw(J(w)=12yXw2) \mathbf{w}_{\text{LS}} = \argmin\limits_{\mathbf{w}} \left( J(\mathbf{w}) = \dfrac{1}{2} \left\| \mathbf{y} - \mathbf{X}\mathbf{w} \right\|^{2} \right)

勾配降下法を用いた学習では反復的なiterative approachアルゴリズムを通じて重みを最適解に収束させるのに対して、ここで紹介する方法は解析的に一発で最適解を見つける方法one-shot approachを紹介する。

学習

場合1: 正確で唯一の解が存在

データの数NN、重みの次元nn、およびX\mathbf{X}ランクがすべて等しいとしよう。つまりX\mathbf{X}フルランクfull rankの場合だ。

N=n=rank(X) N = n = \rank (\mathbf{X})

上記の状況で、X\mathbf{X}N×NN \times Nの正方行列で、ランクがNNであるため逆行列を持つ。したがって、連立方程式y=Xw\mathbf{y} = \mathbf{X} \mathbf{w}は次のように唯一の解を持つ。

w=X1y \mathbf{w} = \mathbf{X}^{-1}\mathbf{y}

上記のような重みw\mathbf{w}に対して、損失関数の関数値は00である。

J(X1y)=12yXX1y2=0 J(\mathbf{X}^{-1}\mathbf{y}) = \dfrac{1}{2} \left\| \mathbf{y} - \mathbf{X}\mathbf{X}^{-1}\mathbf{y}\right\|^{2} = 0

これにより、勾配降下法などを用いずに一発で最適な重みを見つけることができるが、実際の状況ではこのような仮定はほとんど成立しない。つまり理想的な場合と言える。


場合2: フルランクと過剰決定

データの数NNと重みの次元nn、およびrank(X)\rank (\mathbf{X})について次のようにしよう。実際には多くの場合でこの仮定が満たされる。

N>n=rank(X) N \gt n = \rank (\mathbf{X})

線形システムXw=y\mathbf{X} \mathbf{w} = \mathbf{y}過剰決定である場合、最小二乗問題の解が無数に存在する。まずJJを展開して計算するとベクトルのノルムは、

J(w)=12yXX1y2=12(yXw)T(yXw)=12(yTwTXT)(yXw)=12(yTyyTXwwTXTy+wTXTXw) \begin{align*} J(\mathbf{w}) &= \dfrac{1}{2} \left\| \mathbf{y} - \mathbf{X}\mathbf{X}^{-1}\mathbf{y}\right\|^{2} = \dfrac{1}{2} \left( \mathbf{y} - \mathbf{X}\mathbf{w} \right)^{T} \left( \mathbf{y} - \mathbf{X}\mathbf{w} \right) \\ &= \dfrac{1}{2} \left( \mathbf{y}^{T} - \mathbf{w}^{T}\mathbf{X}^{T} \right) \left( \mathbf{y} - \mathbf{X}\mathbf{w} \right) \\ &= \dfrac{1}{2} \left( \mathbf{y}^{T}\mathbf{y} - \mathbf{y}^{T}\mathbf{X}\mathbf{w} - \mathbf{w}^{T}\mathbf{X}^{T}\mathbf{y} + \mathbf{w}^{T}\mathbf{X}^{T}\mathbf{X}\mathbf{w} \right) \end{align*}

このときwTXTy\mathbf{w}^{T}\mathbf{X}^{T}\mathbf{y}はスカラーなので、

wTXTy=(wTXTy)T=yTXw \mathbf{w}^{T}\mathbf{X}^{T}\mathbf{y} = (\mathbf{w}^{T}\mathbf{X}^{T}\mathbf{y})^{T} = \mathbf{y}^{T}\mathbf{X}\mathbf{w}

したがって次を得る。

J(w)=12(yTy2yTXw+wTXTXw) J(\mathbf{w}) = \dfrac{1}{2} \left( \mathbf{y}^{T}\mathbf{y} - 2\mathbf{y}^{T}\mathbf{X}\mathbf{w} + \mathbf{w}^{T}\mathbf{X}^{T}\mathbf{X}\mathbf{w} \right)

数学式でベクトルと行列しか見えないが、上の値はスカラーである。混乱しないように注意しよう。この値が最小になるときを見つけるために勾配を計算すると、

J(w)=J(w)w=12w(yTy2yTXw+wTXTXw)=12(2XTy+2XTXw)=XTXwXTy \begin{align*} \nabla J(\mathbf{w}) = \dfrac{\partial J(\mathbf{w})}{\partial \mathbf{w}} &= \dfrac{1}{2} \dfrac{\partial }{\partial \mathbf{w}}\left( \mathbf{y}^{T}\mathbf{y} - 2\mathbf{y}^{T}\mathbf{X}\mathbf{w} + \mathbf{w}^{T}\mathbf{X}^{T}\mathbf{X}\mathbf{w} \right) \\ &= \dfrac{1}{2} \left( -2\mathbf{X}^{T}\mathbf{y} + 2\mathbf{X}^{T}\mathbf{X}\mathbf{w} \right) \\ &= \mathbf{X}^{T}\mathbf{X}\mathbf{w} - \mathbf{X}^{T}\mathbf{y} \end{align*}

微分計算の結果はここを参照しよう。すると次の式を得る。

J=0    XTXw=XTy \nabla J = \mathbf{0} \implies \mathbf{X}^{T}\mathbf{X}\mathbf{w} = \mathbf{X}^{T}\mathbf{y}

ここで式XTXw=XTy\mathbf{X}^{T}\mathbf{X}\mathbf{w} = \mathbf{X}^{T}\mathbf{y}正規方程式normal equationと言う。フルランクの場合、この問題は唯一のソリューションを持つ。XTX\mathbf{X}^{T} \mathbf{X}n×nn\times n行列で、X\mathbf{X}と同じランクを持つ。

rank(XTX)=n \rank (\mathbf{X}^{T} \mathbf{X}) = n

したがって逆行列が存在し、次のような最小二乗解を得る。

w=(XTX)1XTy \mathbf{w} = (\mathbf{X}^{T}\mathbf{X})^{-1}\mathbf{X}^{T}\mathbf{y}


場合3: フルランクと不足決定

データの数NNと重みの次元nn、およびrank(X)\rank (\mathbf{X})について次のようにしよう。

n>N=rank(X) n \gt N = \rank (\mathbf{X})

すると連立方程式Xw=y\mathbf{X}\mathbf{w} = \mathbf{y}不足決定系である。この場合も場合2と同様に唯一の最小二乗解が存在せず、無数にある。解が無数にあるので、その中でもノルムが最小のw\mathbf{w}を求めることを目標としよう。

arg minw12w2 \argmin\limits_{\mathbf{w}} \dfrac{1}{2} \left\| \mathbf{w} \right\|^{2}

この問題の解法をラグランジュ乗数法でアプローチしよう。すると、我々が持っている制約条件yXw\mathbf{y} - \mathbf{X}\mathbf{w}に乗数λ=[λ1λN]T\boldsymbol{\lambda} = \begin{bmatrix} \lambda_{1} & \dots & \lambda_{N} \end{bmatrix}^{T}を掛けたものを足して、次のような関数を最小化する問題となる。

L(w,λ)=12w2+λT(yXw) L(\mathbf{w}, \boldsymbol{\lambda}) = \dfrac{1}{2} \left\| \mathbf{w} \right\|^{2} + \boldsymbol{\lambda}^{T}(\mathbf{y} - \mathbf{X}\mathbf{w})

LLの勾配は次のようになる。微分計算はここを参照しよう。

L=Lw=wXTλ \nabla L = \dfrac{\partial L}{\partial \mathbf{w}} = \mathbf{w} - \mathbf{X}^{T}\boldsymbol{\lambda}

したがって微分して0\mathbf{0}となる重みはw=XTλ\mathbf{w} = \mathbf{X}^{T}\boldsymbol{\lambda}である。すると次の式を得る。

y=Xw=XXTλ \mathbf{y} = \mathbf{X}\mathbf{w} = \mathbf{X}\mathbf{X}^{T}\boldsymbol{\lambda}

したがってλ=(XXT)1y\boldsymbol{\lambda} = (\mathbf{X}\mathbf{X}^{T})^{-1}\mathbf{y}であり、我々が求めるソリューションは次のとおりだ。

w=XT(XXT)1y \mathbf{w} = \mathbf{X}^{T}(\mathbf{X}\mathbf{X}^{T})^{-1}\mathbf{y}


場合4: ランクディフィシェント

ランクディフィシェントであると仮定しよう。

overdetermined: N>n>rank(X)underdetermined: n>N>rank(X) \text{overdetermined: }N \gt n \gt \rank(\mathbf{X}) \\ \text{underdetermined: } n \gt N \gt \rank(\mathbf{X})

この場合も同様に最小二乗解は唯一でなく無数に存在する。また、X\mathbf{X}がフルランクでないため、XTX\mathbf{X}^{T}\mathbf{X}の逆行列は存在しない。したがって場合2場合3のようには解けない。


  1. Christoper M. Bishop, Pattern Recognition annd Machine Learning (2006), p140-147 ↩︎