logo

トレースの行列微分 📂多変数ベクトル解析

トレースの行列微分

公式

  • X\mathbf{X}n×nn \times n行列としよう。X=X\dfrac{\partial }{\partial \mathbf{X}} = \nabla_{\mathbf{X}}行列勾配としよう。すると次のような公式が成り立つ。

    Tr(X)X=I,Tr(aX)X=aI(1) \dfrac{\partial \Tr(\mathbf{X})}{\partial \mathbf{X}} = I, \qquad \dfrac{\partial \Tr(a\mathbf{X})}{\partial \mathbf{X}} = aI \tag{1}

    ここで aRa \in \mathbb{R}は定数(スカラー)であり、II恒等行列である。

  • ARn×p\mathbf{A} \in \mathbb{R}^{n \times p}として、XRp×n\mathbf{X} \in \mathbb{R}^{p \times n}としよう。次が成り立つ。 Tr(AX)X=Tr(XA)X=AT(2) \dfrac{\partial \Tr(\mathbf{A}\mathbf{X})}{\partial \mathbf{X}} = \dfrac{\partial \Tr( \mathbf{X}\mathbf{A})}{\partial \mathbf{X}} = \mathbf{A}^{\mathsf{T}} \tag{2} Tr(AXT)X=Tr(XTA)X=A \dfrac{\partial \Tr(\mathbf{A}\mathbf{X}^{\mathsf{T}})}{\partial \mathbf{X}} = \dfrac{\partial \Tr( \mathbf{X}^{\mathsf{T}}\mathbf{A})}{\partial \mathbf{X}} = \mathbf{A}

    • 帰結として次が成り立つ。ARn×p\mathbf{A} \in \mathbb{R}^{n \times p}XRp×q\mathbf{X} \in \mathbb{R}^{p \times q}BRq×n\mathbf{B} \in \mathbb{R}^{q \times n}について次が成り立つ。 Tr(AXB)X=ATBT(3) \dfrac{\partial \Tr(\mathbf{A}\mathbf{X}\mathbf{B})}{\partial \mathbf{X}} = \mathbf{A}^{\mathsf{T}}\mathbf{B}^{\mathsf{T}} \tag{3}
  • ARn×n\mathbf{A} \in \mathbb{R}^{n \times n}として、XRm×n\mathbf{X} \in \mathbb{R}^{m \times n}としよう。次が成り立つ。 Tr(AXTX)X=Tr(XTXA)X=Tr(XAXT)X=X(AT+A)(4) \dfrac{\partial \Tr(\mathbf{A}\mathbf{X}^{\mathsf{T}}\mathbf{X})}{\partial \mathbf{X}} = \dfrac{\partial \Tr(\mathbf{X}^{\mathsf{T}}\mathbf{X}\mathbf{A})}{\partial \mathbf{X}} = \dfrac{\partial \Tr(\mathbf{X}\mathbf{A}\mathbf{X}^{\mathsf{T}})}{\partial \mathbf{X}} = \mathbf{X}(\mathbf{A}^{\mathsf{T}} + \mathbf{A}) \tag{4} A,XRn×n\mathbf{A}, \mathbf{X} \in \mathbb{R}^{n \times n}について、次が成り立つ。 Tr(AXX)X=XTAT+ATXT(5) \dfrac{\partial \Tr(\mathbf{A} \mathbf{X}\mathbf{X})}{\partial \mathbf{X}} = \mathbf{X}^{\mathsf{T}}\mathbf{A}^{\mathsf{T}} + \mathbf{A}^{\mathsf{T}}\mathbf{X}^{\mathsf{T}} \tag{5} A,B,XRn×n\mathbf{A}, \mathbf{B}, \mathbf{X} \in \mathbb{R}^{n \times n}について、次が成り立つ。 XTr(AXTBX)=Tr(AXTBX)X=BXA+BTXAT(6) \nabla_{\mathbf{X}} \Tr (\mathbf{A} \mathbf{X}^{\mathsf{T}} \mathbf{B} \mathbf{X}) = \dfrac{\partial \Tr (\mathbf{A} \mathbf{X}^{\mathsf{T}} \mathbf{B} \mathbf{X})}{\partial \mathbf{X}} = \mathbf{B} \mathbf{X} \mathbf{A} + \mathbf{B}^{\mathsf{T}} \mathbf{X} \mathbf{A}^{\mathsf{T}} \tag{6}

  • XRn×n\mathbf{X} \in \mathbb{R}^{n \times n}と自然数nnについて次が成り立つ。 Tr(Xn)X=n(Xn1)T(7) \dfrac{\partial \Tr(\mathbf{X}^{n})}{\partial \mathbf{X}} = n(\mathbf{X}^{n-1})^{\mathsf{T}} \tag{7}

説明

行列AAトレースTrA\Tr AとはAAのすべての対角成分の和を指す。これは単に値として考えることもできるが、行列をスカラーに変換する関数として考えることもできる。すると、トレースは次のように定義される関数である。

Tr:Rn×nR \Tr : \mathbb{R}^{n \times n} \to \mathbb{R}

関数であれば微分について話さないわけにはいかない。関数値を考えてみると、Tr(A)=i=1naii\Tr(A) = \sum\limits_{i=1}^{n} a_{ii}で単純な一次関数に過ぎないため、微分可能性については心配する必要がない。ただし、変数が行列である点が直感的でないかもしれない。詳しくは勾配行列の文書を読んでみよう。

上記の結果を見ると、スカラー微分とかなり似ていることが分かり、ここからトレースとは行列の微分を直感的に扱えるようにしてくれる道具と受け取ることができる。特に、A,B,X\mathbf{A}, \mathbf{B}, \mathbf{X}などすべての行列が対称行列であれば(行列という特異性を考慮する必要がなければ)多項式の微分とほぼ一致する結果を示している。

  • (1)(1): 恒等行列IIは行列の掛け算に対する単位元なので、スカラー微分dxdx=1\dfrac{d x}{d x} = 1と対応する結果である。
  • (2)(2)(3)(3): 一次関数の微分と対応する直感的な結果である。
  • (4)(4)(5)(5): 二次関数の微分と対応する結果である。
  • (7)(7): 多項式の微分と対応する結果である。

下記の証明では直接計算で示したが、🔒(25/07/17)トレーストリックと呼ばれる方法を使用するとより簡単に計算できる。X\mathbf{X}に対する任意の形や、X\mathbf{X}が多く含まれる式の場合、実際に直接計算するのは非常に困難でトレーストリックを使用する必要がある。

より多くの公式は🔒(25/07/19)スカラー関数の行列微分表で確認できる。

証明

(1)(1)

Tr(X)=i=1nxii\Tr (\mathbf{X}) = \sum\limits_{i=1}^{n} x_{ii}であるため次が成り立つ。

Tr(X)xij={1,i=j0,ij \dfrac{\partial \Tr (\mathbf{X})}{\partial x_{ij}} = \begin{cases} 1 &, i=j \\ 0 &, i \neq j \end{cases}

そのため、以下の結果を得る。

Tr(X)X=[Tr(X)x11Tr(X)x1nTr(X)xn1Tr(X)xnn]=[100010001]=I \dfrac{\partial \Tr (\mathbf{X})}{\partial \mathbf{X}} = \begin{bmatrix} \dfrac{\partial \Tr (\mathbf{X})}{\partial x_{11}} & \cdots & \dfrac{\partial \Tr (\mathbf{X})}{\partial x_{1n}} \\ \vdots & \ddots & \vdots \\ \dfrac{\partial \Tr (\mathbf{X})}{\partial x_{n1}} & \cdots & \dfrac{\partial \Tr (\mathbf{X})}{\partial x_{nn}} \end{bmatrix} = \begin{bmatrix} 1 & 0 & \cdots & 0 \\ 0 & 1 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & 1 \end{bmatrix} = I

(2)(2)

ARn×p\mathbf{A} \in \mathbb{R}^{n \times p}XRp×n\mathbf{X} \in \mathbb{R}^{p \times n}としよう。 まずトレースは循環性質を持つため、Tr(AX)=Tr(XA)\Tr (\mathbf{A}\mathbf{X}) = \Tr(\mathbf{X}\mathbf{A})である。

トレースの循環性質

Tr(AB)=Tr(BA) \Tr(AB) = \Tr(BA)

行列AX\mathbf{A} \mathbf{X}ijij成分はk=1paikxkj\sum\limits_{k=1}^{p} a_{ik} x_{kj}であるため、Tr(AX)=i=1nk=1paikxki\Tr (\mathbf{A} \mathbf{X}) = \sum\limits_{i=1}^{n}\sum\limits_{k=1}^{p} a_{ik} x_{ki}である。したがってTr(AX)xij=aji\dfrac{\partial \Tr (\mathbf{A}\mathbf{X})}{\partial x_{ij}} = a_{ji}であり、次が成り立つ。

Tr(AX)X=[Tr(AX)x11Tr(AX)x1nTr(AX)xp1Tr(AX)xpn]=[a11a21an1a12a22an2a1pa2panp]=AT \begin{align*} \dfrac{\partial \Tr (\mathbf{A}\mathbf{X})}{\partial \mathbf{X}} &=\begin{bmatrix} \dfrac{\partial \Tr (\mathbf{A} \mathbf{X})}{\partial x_{11}} & \cdots & \dfrac{\partial \Tr (\mathbf{A} \mathbf{X})}{\partial x_{1n}} \\ \vdots & \ddots & \vdots \\ \dfrac{\partial \Tr (\mathbf{A} \mathbf{X})}{\partial x_{p1}} & \cdots & \dfrac{\partial \Tr (\mathbf{A} \mathbf{X})}{\partial x_{pn}} \end{bmatrix} \\ &= \begin{bmatrix} a_{11} & a_{21} & \cdots & a_{n1} \\ a_{12} & a_{22} & \cdots & a_{n2} \\ \vdots & \vdots & \ddots & \vdots \\ a_{1p} & a_{2p} & \cdots & a_{np} \end{bmatrix} \\ &= \mathbf{A}^{\mathsf{T}} \end{align*}

この結果とトレースの循環性質を利用すれば、(3)(3)をすぐに得ることができる。

(4)(4)

行列の累乗形式の公式により、AXTX\mathbf{A}\mathbf{X}^{\mathsf{T}}\mathbf{X}ijij成分とトレース、偏微分は次の通りである。

[AXTX]ij=k=1n=1maikxkxj,Tr(AXTX)=s=1nk=1n=1maskxkxs [\mathbf{A}\mathbf{X}^{\mathsf{T}}\mathbf{X}]_{ij} = \sum\limits_{k=1}^{n}\sum\limits_{\ell=1}^{m} a_{ik} x_{\ell k}x_{\ell j}, \quad \Tr (\mathbf{A}\mathbf{X}^{\mathsf{T}}\mathbf{X}) = \sum\limits_{s=1}^{n}\sum\limits_{k=1}^{n}\sum\limits_{\ell=1}^{m} a_{sk} x_{\ell k}x_{\ell s}

Tr(AXTX)xij=k=1najkxik+s=1nasjxis=k=1nxikajk+k=1nxikakj=[XAT]ij+[XA]ij \begin{align*} \dfrac{\partial \Tr (\mathbf{A}\mathbf{X}^{\mathsf{T}}\mathbf{X})}{\partial x_{ij}} &= \sum\limits_{k=1}^{n}a_{jk}x_{ik} + \sum\limits_{s=1}^{n}a_{sj}x_{is} \\ &= \sum\limits_{k=1}^{n}x_{ik}a_{jk} + \sum\limits_{k=1}^{n}x_{ik}a_{kj} &= [\mathbf{X}\mathbf{A}^{\mathsf{T}}]_{ij} + [\mathbf{X}\mathbf{A}]_{ij} \end{align*}

したがって、

Tr(AXTX)X=XAT+XA=X(AT+A) \dfrac{\partial \Tr (\mathbf{A}\mathbf{X}^{\mathsf{T}}\mathbf{X})}{\partial \mathbf{X}} = \mathbf{X}\mathbf{A}^{\mathsf{T}} + \mathbf{X}\mathbf{A} = \mathbf{X}(\mathbf{A}^{\mathsf{T}} + \mathbf{A})

(5)(5)

[AXX]ij=k,s=1naikxksxsj[\mathbf{A} \mathbf{X} \mathbf{X}]_{ij} = \sum\limits_{k, s=1}^{n} a_{ik}x_{ks}x_{sj}であるため、

Tr(AXX)==1nk,s=1nakxksxs \Tr (\mathbf{A}\mathbf{X} \mathbf{X}) = \sum\limits_{\ell= 1}^{n} \sum\limits_{k, s=1}^{n} a_{\ell k}x_{ks}x_{s\ell}

したがって偏微分は以下の通りである。

Tr(AXX)xij==1naixj+k=1najkxki \dfrac{\partial \Tr (\mathbf{A}\mathbf{X} \mathbf{X})}{\partial x_{ij}} = \sum\limits_{\ell=1}^{n} a_{\ell i}x_{j\ell} + \sum\limits_{k=1}^{n} a_{jk}x_{ki}

よって次が得られる。

[Tr(AXX)X]ij==1naixj+k=1najkxki=[ATXT]ij+[XTAT] \left[ \dfrac{\partial \Tr (\mathbf{A}\mathbf{X} \mathbf{X})}{\partial \mathbf{X}} \right]_{ij} = \sum\limits_{\ell=1}^{n} a_{\ell i}x_{j\ell} + \sum\limits_{k=1}^{n} a_{jk}x_{ki} = [\mathbf{A}^{\mathsf{T}}\mathbf{X}^{\mathsf{T}}]_{ij} + [\mathbf{X}^{\mathsf{T}}\mathbf{A}^{\mathsf{T}}]

    Tr(AXX)X=ATXT+XTAT \implies \dfrac{\partial \Tr (\mathbf{A}\mathbf{X} \mathbf{X})}{\partial \mathbf{X}} = \mathbf{A}^{\mathsf{T}}\mathbf{X}^{\mathsf{T}} + \mathbf{X}^{\mathsf{T}}\mathbf{A}^{\mathsf{T}}

(6)(6)

上記の証明過程を無理なく追ってきたという仮定のもと、簡略に記述する。

[AXTBX]ij=k,s,=1naikxskbsxj \left[ \mathbf{A} \mathbf{X}^{\mathsf{T}} \mathbf{B} \mathbf{X} \right]_{ij} = \sum_{k,s,\ell = 1}^{n} a_{ik} x_{sk} b_{s\ell} x_{\ell j}

    Tr(AXTBX)=r=1nk,s,=1narkxskbsxr \implies \Tr (\mathbf{A} \mathbf{X}^{\mathsf{T}} \mathbf{B} \mathbf{X}) = \sum\limits_{r=1}^{n}\sum_{k,s,\ell = 1}^{n} a_{rk} x_{sk} b_{s\ell} x_{\ell r}

    [XTr(AXTBX)]ij=r,arjbixr+k,sajkxskbsi=r,bixrarj+k,sbsixskajk=[BXA]ij+[BTXAT]ij \begin{align*} \implies \left[ \nabla_{\mathbf{X}} \Tr (\mathbf{A} \mathbf{X}^{\mathsf{T}} \mathbf{B} \mathbf{X}) \right]_{ij} &= \sum_{r,\ell} a_{rj}b_{i\ell}x_{\ell r} + \sum_{k,s} a_{jk}x_{sk}b_{si} \\ &= \sum_{r,\ell} b_{i\ell}x_{\ell r}a_{rj} + \sum_{k,s} b_{si}x_{sk}a_{jk} \\ &= [\mathbf{B} \mathbf{X} \mathbf{A}]_{ij} + [\mathbf{B}^{\mathsf{T}} \mathbf{X} \mathbf{A}^{\mathsf{T}}]_{ij} \end{align*}

    XTr(AXTBX)=BXA+BTXAT \implies \nabla_{\mathbf{X}} \Tr (\mathbf{A} \mathbf{X}^{\mathsf{T}} \mathbf{B} \mathbf{X}) = \mathbf{B} \mathbf{X} \mathbf{A} + \mathbf{B}^{\mathsf{T}} \mathbf{X} \mathbf{A}^{\mathsf{T}}

(7)(7)

[Xn]ij=k(2),,k(n)=1nxik(2)xk(2)k(3)xk(n)j [\mathbf{X}^{n}]_{ij} = \sum\limits_{k_{(2)}, \dots, k_{(n)}=1}^{n} x_{ik_{(2)}} x_{k_{(2)}k_{(3)}} \cdots x_{k_{(n)}j}

    Tr(Xn)==1nk(2),,k(n)=1nxk(2)xk(2)k(3)xk(n) \implies \Tr(\mathbf{X}^{n}) = \sum\limits_{\ell=1}^{n} \sum\limits_{k_{(2)}, \dots, k_{(n)}=1}^{n} x_{\ell k_{(2)}} x_{k_{(2)}k_{(3)}} \cdots x_{k_{(n)}\ell}

    [XTr(Xn)]ij=nk(3),,k(n)=1nxjk(3)xk(3)k(4)xk(n)i=n[(Xn1)T]ij \implies [\nabla_{\mathbf{X}} \Tr(\mathbf{X}^{n})]_{ij} = n \sum\limits_{k_{(3)}, \dots, k_{(n)}=1}^{n} x_{j k_{(3)}} x_{k_{(3)}k_{(4)}} \cdots x_{k_{(n)}i} = n[(\mathbf{X}^{n-1})^{\mathsf{T}}]_{ij}

    XTr(Xn)=n(Xn1)T \implies \nabla_{\mathbf{X}} \Tr(\mathbf{X}^{n}) = n(\mathbf{X}^{n-1})^{\mathsf{T}}