logo

転置畳み込み 📂機械学習

転置畳み込み

定義

カーネルが$K$である畳み込み行列表現を$C_{K}$とする。この転置行列で定義される次のような行列変換転置畳み込みtransposed convolutionという。

$$ \mathbf{y} = C_{K}^{\mathsf{T}} \mathbf{x} $$

説明

一般に畳み込みはプーリング層と共に使用され、入力データの次元を削減し情報を圧縮するのに使われる。一方で転置畳み込みは圧縮データの次元を増加させ、アップサンプリングupsamplingを行うのに使用される。基本的にカーネルを移動させながら入力画像と計算結果を得ることは畳み込みと同じである。畳み込みは画像の一部領域とカーネルの内積を出力するのに対し、転置畳み込みは画像のピクセルひとつとカーネルのスカラー乗算scalar multiplicationである点が異なる。また、畳み込みでは入力データがカーネルと計算される領域が重なるが、転置畳み込みではカーネルと計算された出力データの領域が重なる。計算方式は以下のGIFに詳しく示されている。

同じカーネルについて、畳み込みと転置畳み込みの行列表現はその名の通り転置の関係にある。以下のような$3 \times 3$カーネルがあるとする。

$$ \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9 \end{bmatrix} $$

このカーネルの$4 \times 4$画像に対する畳み込みと転置畳み込みの行列変換$A : \mathbb{R}^{4 \times 4} \to \mathbb{R}^{2 \times 2}$はそれぞれ以下の通りである。

$$ \text{convolution: } \begin{bmatrix} 1 & 2 & 3 & 0 & 4 & 5 & 6 & 0 & 7 & 8 & 9 & 0 & 0 & 0 & 0 & 0 \\ 0 & 1 & 2 & 3 & 0 & 4 & 5 & 6 & 0 & 7 & 8 & 9 & 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 1 & 2 & 3 & 0 & 4 & 5 & 6 & 0 & 7 & 8 & 9 & 0 \\ 0 & 0 & 0 & 0 & 0 & 1 & 2 & 3 & 0 & 4 & 5 & 6 & 0 & 7 & 8 & 9 \end{bmatrix} $$

これに対応する転置畳み込みは$\mathbb{R}^{2 \times 2}$行列を$\mathbb{R}^{4 \times 4}$行列に送る変換$B = A^{\mathsf{T}} : \mathbb{R}^{2 \times 2} \to \mathbb{R}^{4 \times 4}$で以下の通りである。

$$ \text{transposed convolution: } \begin{bmatrix} 1 & 0 & 0 & 0 \\ 2 & 1 & 0 & 0 \\ 3 & 2 & 0 & 0 \\ 0 & 3 & 0 & 0 \\ 4 & 0 & 1 & 0 \\ 5 & 4 & 2 & 1 \\ 6 & 5 & 3 & 2 \\ 0 & 6 & 0 & 3 \\ 7 & 0 & 4 & 0 \\ 8 & 7 & 5 & 4 \\ 9 & 8 & 6 & 5 \\ 0 & 9 & 0 & 6 \\ 0 & 0 & 7 & 0 \\ 0 & 0 & 8 & 7 \\ 0 & 0 & 9 & 8 \\ 0 & 0 & 0 & 9 \end{bmatrix} $$

二つの変換が相互に逆変換関係にないことに注意しよう。実際に二つの行列を掛け合わせても単位行列にはならない。

$$ AB = \begin{bmatrix} 285 & 186 & 154 & 94 \\ 186 & 285 & 106 & 154 \\ 154 & 106 & 285 & 186 \\ 94 & 154 & 186 & 285 \end{bmatrix} $$ $$ BA = \begin{bmatrix} 1 & 2 & 3 & 0 & 4 & 5 & 6 & 0 & 7 & 8 & 9 & 0 & 0 & 0 & 0 & 0 \\ 2 & 5 & 8 & 3 & 8 & 14 & 17 & 6 & 14 & 23 & 26 & 9 & 0 & 0 & 0 & 0 \\ 3 & 8 & 13 & 6 & 12 & 23 & 28 & 12 & 21 & 38 & 43 & 18 & 0 & 0 & 0 & 0 \\ \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots \\ 0 & 0 & 0 & 0 & 8 & 23 & 38 & 21 & 32 & 68 & 83 & 42 & 56 & 113 & 128 & 63 \\ 0 & 0 & 0 & 0 & 9 & 26 & 43 & 24 & 36 & 77 & 94 & 48 & 63 & 128 & 145 & 72 \\ 0 & 0 & 0 & 0 & 0 & 9 & 18 & 27 & 0 & 36 & 45 & 54 & 0 & 63 & 72 & 81 \end{bmatrix} $$