logo

트레이스의 행렬 미분 📂다변수벡터해석

트레이스의 행렬 미분

공식

  • X\mathbf{X}n×nn \times n 행렬이라 하자. X=X\dfrac{\partial }{\partial \mathbf{X}} = \nabla_{\mathbf{X}}행렬 그래디언트라고 하자. 그러면 다음과 같은 공식이 성립한다.

    Tr(X)X=I,Tr(aX)X=aI(1) \dfrac{\partial \Tr(\mathbf{X})}{\partial \mathbf{X}} = I, \qquad \dfrac{\partial \Tr(a\mathbf{X})}{\partial \mathbf{X}} = aI \tag{1}

    여기서 aRa \in \mathbb{R}는 상수(스칼라)이고, II항등행렬이다.

  • ARn×p\mathbf{A} \in \mathbb{R}^{n \times p}이고 XRp×n\mathbf{X} \in \mathbb{R}^{p \times n}라고 하자. 다음이 성립한다. Tr(AX)X=Tr(XA)X=AT(2) \dfrac{\partial \Tr(\mathbf{A}\mathbf{X})}{\partial \mathbf{X}} = \dfrac{\partial \Tr( \mathbf{X}\mathbf{A})}{\partial \mathbf{X}} = \mathbf{A}^{\mathsf{T}} \tag{2} Tr(AXT)X=Tr(XTA)X=A \dfrac{\partial \Tr(\mathbf{A}\mathbf{X}^{\mathsf{T}})}{\partial \mathbf{X}} = \dfrac{\partial \Tr( \mathbf{X}^{\mathsf{T}}\mathbf{A})}{\partial \mathbf{X}} = \mathbf{A}

    • 따름정리로서 다음이 성립한다. ARn×p\mathbf{A} \in \mathbb{R}^{n \times p}, XRp×q\mathbf{X} \in \mathbb{R}^{p \times q}, BRq×n\mathbf{B} \in \mathbb{R}^{q \times n}에 대해 다음이 성립한다. Tr(AXB)X=ATBT(3) \dfrac{\partial \Tr(\mathbf{A}\mathbf{X}\mathbf{B})}{\partial \mathbf{X}} = \mathbf{A}^{\mathsf{T}}\mathbf{B}^{\mathsf{T}} \tag{3}
  • ARn×n\mathbf{A} \in \mathbb{R}^{n \times n}이고 XRm×n\mathbf{X} \in \mathbb{R}^{m \times n}라고 하자. 다음이 성립한다. Tr(AXTX)X=Tr(XTXA)X=Tr(XAXT)X=X(AT+A)(4) \dfrac{\partial \Tr(\mathbf{A}\mathbf{X}^{\mathsf{T}}\mathbf{X})}{\partial \mathbf{X}} = \dfrac{\partial \Tr(\mathbf{X}^{\mathsf{T}}\mathbf{X}\mathbf{A})}{\partial \mathbf{X}} = \dfrac{\partial \Tr(\mathbf{X}\mathbf{A}\mathbf{X}^{\mathsf{T}})}{\partial \mathbf{X}} = \mathbf{X}(\mathbf{A}^{\mathsf{T}} + \mathbf{A}) \tag{4} A,XRn×n\mathbf{A}, \mathbf{X} \in \mathbb{R}^{n \times n}에 대해서, 다음이 성립한다. Tr(AXX)X=XTAT+ATXT(5) \dfrac{\partial \Tr(\mathbf{A} \mathbf{X}\mathbf{X})}{\partial \mathbf{X}} = \mathbf{X}^{\mathsf{T}}\mathbf{A}^{\mathsf{T}} + \mathbf{A}^{\mathsf{T}}\mathbf{X}^{\mathsf{T}} \tag{5} A,B,XRn×n\mathbf{A}, \mathbf{B}, \mathbf{X} \in \mathbb{R}^{n \times n}에 대해서, 다음이 성립한다. XTr(AXTBX)=Tr(AXTBX)X=BXA+BTXAT(6) \nabla_{\mathbf{X}} \Tr (\mathbf{A} \mathbf{X}^{\mathsf{T}} \mathbf{B} \mathbf{X}) = \dfrac{\partial \Tr (\mathbf{A} \mathbf{X}^{\mathsf{T}} \mathbf{B} \mathbf{X})}{\partial \mathbf{X}} = \mathbf{B} \mathbf{X} \mathbf{A} + \mathbf{B}^{\mathsf{T}} \mathbf{X} \mathbf{A}^{\mathsf{T}} \tag{6}

  • XRn×n\mathbf{X} \in \mathbb{R}^{n \times n}와 자연수 nn에 대해서 다음이 성립한다. Tr(Xn)X=n(Xn1)T(7) \dfrac{\partial \Tr(\mathbf{X}^{n})}{\partial \mathbf{X}} = n(\mathbf{X}^{n-1})^{\mathsf{T}} \tag{7}

설명

행렬 AA트레이스 TrA\Tr AAA의 모든 대각성분의 합을 말한다. 이를 단순히 값으로 생각할 수도 있지만, 행렬을 스칼라로 변환하는 함수로 생각할 수도 있다. 그러면면 트레이스는 다음과 같이 정의되는 함수이다.

Tr:Rn×nR \Tr : \mathbb{R}^{n \times n} \to \mathbb{R}

함수라면 미분에 대해서 얘기하지 않을 수 없다. 함숫값을 생각해보면, Tr(A)=i=1naii\Tr(A) = \sum\limits_{i=1}^{n} a_{ii}라서 단순한 일차함수에 불과하므로 미분가능성에 대해서는 걱정할 필요가 없다. 다만, 변수가 행렬이라는 점이 직관적이지 않을 수 있다. 자세한 것은 그래디언트 행렬 문서를 읽어보자.

위의 결과들을 보면 스칼라 미분과 상당히 비슷한 것을 알 수 있는데, 이로부터 트레이스란 행렬의 미분을 직관적으로 다룰 수 있게 해주는 도구라고 받아들일 수 있다. 특히나 A,B,X\mathbf{A}, \mathbf{B}, \mathbf{X} 등 모든 행렬이 대칭행렬이면 (행렬이라는 특수성을 고려할 필요가 없으면) 다항함수의 미분과 거의 일치하는 결과를 보여준다.

  • (1)(1): 항등행렬 II는 행렬의 곱셈에 대한 항등원이므로, 스칼라 미분 dxdx=1\dfrac{d x}{d x} = 1과 대응되는 결과이다.
  • (2)(2), (3)(3): 일차함수의 미분과 대응되는 직관적인 결과이다.
  • (4)(4), (5)(5): 이차함수의 미분과 대응되는 결과이다.
  • (7)(7): 다항함수의 미분과 대응되는 결과이다.

아래의 증명에서는 직접 계산으로 보였지만, 트레이스 트릭이라고 불리는 방법을 사용하면 더 간단하게 계산할 수 있다. X\mathbf{X}에 대한 임의의 꼴이나, X\mathbf{X}가 많이 포함된 식의 경우 사실상 직접 계산하는 것은 너무 힘들고 트레이스 트릭을 사용하여야 한다.

더 많은 공식은 스칼라 함수의 행렬 미분 표에서 확인할 수 있다.

증명

(1)(1)

Tr(X)=i=1nxii\Tr (\mathbf{X}) = \sum\limits_{i=1}^{n} x_{ii}이므로 다음이 성립한다.

Tr(X)xij={1,i=j0,ij \dfrac{\partial \Tr (\mathbf{X})}{\partial x_{ij}} = \begin{cases} 1 &, i=j \\ 0 &, i \neq j \end{cases}

따라서 아래의 결과를 얻는다.

Tr(X)X=[Tr(X)x11Tr(X)x1nTr(X)xn1Tr(X)xnn]=[100010001]=I \dfrac{\partial \Tr (\mathbf{X})}{\partial \mathbf{X}} = \begin{bmatrix} \dfrac{\partial \Tr (\mathbf{X})}{\partial x_{11}} & \cdots & \dfrac{\partial \Tr (\mathbf{X})}{\partial x_{1n}} \\ \vdots & \ddots & \vdots \\ \dfrac{\partial \Tr (\mathbf{X})}{\partial x_{n1}} & \cdots & \dfrac{\partial \Tr (\mathbf{X})}{\partial x_{nn}} \end{bmatrix} = \begin{bmatrix} 1 & 0 & \cdots & 0 \\ 0 & 1 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & 1 \end{bmatrix} = I

(2)(2)

ARn×p\mathbf{A} \in \mathbb{R}^{n \times p}, XRp×n\mathbf{X} \in \mathbb{R}^{p \times n}라고 하자. 우선 트레이스는 순환성질이 있으므로 Tr(AX)=Tr(XA)\Tr (\mathbf{A}\mathbf{X}) = \Tr(\mathbf{X}\mathbf{A})이다.

트레이스의 순환 성질

Tr(AB)=Tr(BA) \Tr(AB) = \Tr(BA)

행렬 AX\mathbf{A} \mathbf{X}ijij 성분은 k=1paikxkj\sum\limits_{k=1}^{p} a_{ik} x_{kj}이므로, Tr(AX)=i=1nk=1paikxki\Tr (\mathbf{A} \mathbf{X}) = \sum\limits_{i=1}^{n}\sum\limits_{k=1}^{p} a_{ik} x_{ki}이다. 따라서 Tr(AX)xij=aji\dfrac{\partial \Tr (\mathbf{A}\mathbf{X})}{\partial x_{ij}} = a_{ji}이고 다음이 성립한다.

Tr(AX)X=[Tr(AX)x11Tr(AX)x1nTr(AX)xp1Tr(AX)xpn]=[a11a21an1a12a22an2a1pa2panp]=AT \begin{align*} \dfrac{\partial \Tr (\mathbf{A}\mathbf{X})}{\partial \mathbf{X}} &=\begin{bmatrix} \dfrac{\partial \Tr (\mathbf{A} \mathbf{X})}{\partial x_{11}} & \cdots & \dfrac{\partial \Tr (\mathbf{A} \mathbf{X})}{\partial x_{1n}} \\ \vdots & \ddots & \vdots \\ \dfrac{\partial \Tr (\mathbf{A} \mathbf{X})}{\partial x_{p1}} & \cdots & \dfrac{\partial \Tr (\mathbf{A} \mathbf{X})}{\partial x_{pn}} \end{bmatrix} \\ &= \begin{bmatrix} a_{11} & a_{21} & \cdots & a_{n1} \\ a_{12} & a_{22} & \cdots & a_{n2} \\ \vdots & \vdots & \ddots & \vdots \\ a_{1p} & a_{2p} & \cdots & a_{np} \end{bmatrix} \\ &= \mathbf{A}^{\mathsf{T}} \end{align*}

이 결과와 트레이스의 순환 성질을 이용하면 (3)(3)을 바로 얻는다.

(4)(4)

행렬의 거듭제곱꼴 공식에 의해 AXTX\mathbf{A}\mathbf{X}^{\mathsf{T}}\mathbf{X}ijij성분과 트레이스, 편미분은 다음과 같다.

[AXTX]ij=k=1n=1maikxkxj,Tr(AXTX)=s=1nk=1n=1maskxkxs [\mathbf{A}\mathbf{X}^{\mathsf{T}}\mathbf{X}]_{ij} = \sum\limits_{k=1}^{n}\sum\limits_{\ell=1}^{m} a_{ik} x_{\ell k}x_{\ell j}, \quad \Tr (\mathbf{A}\mathbf{X}^{\mathsf{T}}\mathbf{X}) = \sum\limits_{s=1}^{n}\sum\limits_{k=1}^{n}\sum\limits_{\ell=1}^{m} a_{sk} x_{\ell k}x_{\ell s}

Tr(AXTX)xij=k=1najkxik+s=1nasjxis=k=1nxikajk+k=1nxikakj=[XAT]ij+[XA]ij \begin{align*} \dfrac{\partial \Tr (\mathbf{A}\mathbf{X}^{\mathsf{T}}\mathbf{X})}{\partial x_{ij}} &= \sum\limits_{k=1}^{n}a_{jk}x_{ik} + \sum\limits_{s=1}^{n}a_{sj}x_{is} \\ &= \sum\limits_{k=1}^{n}x_{ik}a_{jk} + \sum\limits_{k=1}^{n}x_{ik}a_{kj} &= [\mathbf{X}\mathbf{A}^{\mathsf{T}}]_{ij} + [\mathbf{X}\mathbf{A}]_{ij} \end{align*}

따라서,

Tr(AXTX)X=XAT+XA=X(AT+A) \dfrac{\partial \Tr (\mathbf{A}\mathbf{X}^{\mathsf{T}}\mathbf{X})}{\partial \mathbf{X}} = \mathbf{X}\mathbf{A}^{\mathsf{T}} + \mathbf{X}\mathbf{A} = \mathbf{X}(\mathbf{A}^{\mathsf{T}} + \mathbf{A})

(5)(5)

[AXX]ij=k,s=1naikxksxsj[\mathbf{A} \mathbf{X} \mathbf{X}]_{ij} = \sum\limits_{k, s=1}^{n} a_{ik}x_{ks}x_{sj}이므로,

Tr(AXX)==1nk,s=1nakxksxs \Tr (\mathbf{A}\mathbf{X} \mathbf{X}) = \sum\limits_{\ell= 1}^{n} \sum\limits_{k, s=1}^{n} a_{\ell k}x_{ks}x_{s\ell}

따라서 편미분은 아래와 같다.

Tr(AXX)xij==1naixj+k=1najkxki \dfrac{\partial \Tr (\mathbf{A}\mathbf{X} \mathbf{X})}{\partial x_{ij}} = \sum\limits_{\ell=1}^{n} a_{\ell i}x_{j\ell} + \sum\limits_{k=1}^{n} a_{jk}x_{ki}

그러므로 다음을 얻는다.

[Tr(AXX)X]ij==1naixj+k=1najkxki=[ATXT]ij+[XTAT] \left[ \dfrac{\partial \Tr (\mathbf{A}\mathbf{X} \mathbf{X})}{\partial \mathbf{X}} \right]_{ij} = \sum\limits_{\ell=1}^{n} a_{\ell i}x_{j\ell} + \sum\limits_{k=1}^{n} a_{jk}x_{ki} = [\mathbf{A}^{\mathsf{T}}\mathbf{X}^{\mathsf{T}}]_{ij} + [\mathbf{X}^{\mathsf{T}}\mathbf{A}^{\mathsf{T}}]

    Tr(AXX)X=ATXT+XTAT \implies \dfrac{\partial \Tr (\mathbf{A}\mathbf{X} \mathbf{X})}{\partial \mathbf{X}} = \mathbf{A}^{\mathsf{T}}\mathbf{X}^{\mathsf{T}} + \mathbf{X}^{\mathsf{T}}\mathbf{A}^{\mathsf{T}}

(6)(6)

위의 증명 과정을 무리없이 따라왔다는 가정하에, 간략히 서술한다.

[AXTBX]ij=k,s,=1naikxskbsxj \left[ \mathbf{A} \mathbf{X}^{\mathsf{T}} \mathbf{B} \mathbf{X} \right]_{ij} = \sum_{k,s,\ell = 1}^{n} a_{ik} x_{sk} b_{s\ell} x_{\ell j}

    Tr(AXTBX)=r=1nk,s,=1narkxskbsxr \implies \Tr (\mathbf{A} \mathbf{X}^{\mathsf{T}} \mathbf{B} \mathbf{X}) = \sum\limits_{r=1}^{n}\sum_{k,s,\ell = 1}^{n} a_{rk} x_{sk} b_{s\ell} x_{\ell r}

    [XTr(AXTBX)]ij=r,arjbixr+k,sajkxskbsi=r,bixrarj+k,sbsixskajk=[BXA]ij+[BTXAT]ij \begin{align*} \implies \left[ \nabla_{\mathbf{X}} \Tr (\mathbf{A} \mathbf{X}^{\mathsf{T}} \mathbf{B} \mathbf{X}) \right]_{ij} &= \sum_{r,\ell} a_{rj}b_{i\ell}x_{\ell r} + \sum_{k,s} a_{jk}x_{sk}b_{si} \\ &= \sum_{r,\ell} b_{i\ell}x_{\ell r}a_{rj} + \sum_{k,s} b_{si}x_{sk}a_{jk} \\ &= [\mathbf{B} \mathbf{X} \mathbf{A}]_{ij} + [\mathbf{B}^{\mathsf{T}} \mathbf{X} \mathbf{A}^{\mathsf{T}}]_{ij} \end{align*}

    XTr(AXTBX)=BXA+BTXAT \implies \nabla_{\mathbf{X}} \Tr (\mathbf{A} \mathbf{X}^{\mathsf{T}} \mathbf{B} \mathbf{X}) = \mathbf{B} \mathbf{X} \mathbf{A} + \mathbf{B}^{\mathsf{T}} \mathbf{X} \mathbf{A}^{\mathsf{T}}

(7)(7)

[Xn]ij=k(2),,k(n)=1nxik(2)xk(2)k(3)xk(n)j [\mathbf{X}^{n}]_{ij} = \sum\limits_{k_{(2)}, \dots, k_{(n)}=1}^{n} x_{ik_{(2)}} x_{k_{(2)}k_{(3)}} \cdots x_{k_{(n)}j}

    Tr(Xn)==1nk(2),,k(n)=1nxk(2)xk(2)k(3)xk(n) \implies \Tr(\mathbf{X}^{n}) = \sum\limits_{\ell=1}^{n} \sum\limits_{k_{(2)}, \dots, k_{(n)}=1}^{n} x_{\ell k_{(2)}} x_{k_{(2)}k_{(3)}} \cdots x_{k_{(n)}\ell}

    [XTr(Xn)]ij=nk(3),,k(n)=1nxjk(3)xk(3)k(4)xk(n)i=n[(Xn1)T]ij \implies [\nabla_{\mathbf{X}} \Tr(\mathbf{X}^{n})]_{ij} = n \sum\limits_{k_{(3)}, \dots, k_{(n)}=1}^{n} x_{j k_{(3)}} x_{k_{(3)}k_{(4)}} \cdots x_{k_{(n)}i} = n[(\mathbf{X}^{n-1})^{\mathsf{T}}]_{ij}

    XTr(Xn)=n(Xn1)T \implies \nabla_{\mathbf{X}} \Tr(\mathbf{X}^{n}) = n(\mathbf{X}^{n-1})^{\mathsf{T}}