logo

전치 합성곱 📂머신러닝

전치 합성곱

정의

커널이 $K$인 합성곱행렬표현을 $C_{K}$라고 하자. 이의 전치행렬로 정의되는 아래와 같은 행렬변환전치 합성곱transposed convolution이라 한다.

$$ \mathbf{y} = C_{K}^{\mathsf{T}} \mathbf{x} $$

설명

일반적으로 합성곱은 풀링층과 같이 쓰여서 입력 데이터의 차원을 줄이고 정보를 압축하는데 쓰인다. 반면에 전치 합성곱은 압축된 데이터의 차원을 늘리고 업샘플링upsampling을 하는데 쓰인다. 기본적으로 커널을 이동시켜가면서 입력 이미지와의 계산 결과를 얻는 것은 합성곱과 같다. 합성곱은 이미지의 일부 영역과 커널의 내적을 출력하고, 전치합성곱은 이미지의 픽셀 하나와 커널의 상수곱scalar multiplication이라는 점이 다르다. 또한 합성곱은 입력 데이터가 커널과 계산되는 영역이 겹쳐지지만, 전치합성곱에서는 커널과 계산된 출력 데이터의 영역이 겹쳐진다. 계산 방식은 아래의 움짤에 잘 나와있다.

똑같은 커널에 대해서, 합성곱과 전치 합성곱의 행렬표현은 이름 그대로 전치 관계에 있다. 아래와 같은 $3 \times 3$ 커널이 있다고 하자.

$$ \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9 \end{bmatrix} $$

이 커널의 $4 \times 4$ 이미지에 대한 합성곱과 전치합성곱의 행렬변환 $A : \mathbb{R}^{4 \times 4} \to \mathbb{R}^{2 \times 2}$은 각각 아래와 같다.

$$ \text{convolution: } \begin{bmatrix} 1 & 2 & 3 & 0 & 4 & 5 & 6 & 0 & 7 & 8 & 9 & 0 & 0 & 0 & 0 & 0 \\ 0 & 1 & 2 & 3 & 0 & 4 & 5 & 6 & 0 & 7 & 8 & 9 & 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 1 & 2 & 3 & 0 & 4 & 5 & 6 & 0 & 7 & 8 & 9 & 0 \\ 0 & 0 & 0 & 0 & 0 & 1 & 2 & 3 & 0 & 4 & 5 & 6 & 0 & 7 & 8 & 9 \end{bmatrix} $$

이의 대응되는 전치합성곱은 $\mathbb{R}^{2 \times 2}$ 행렬을 $\mathbb{R}^{4 \times 4}$ 행렬로 보내는 변환 $B = A^{\mathsf{T}} : \mathbb{R}^{2 \times 2} \to \mathbb{R}^{4 \times 4}$이며 아래와 같다.

$$ \text{transposed convolution: } \begin{bmatrix} 1 & 0 & 0 & 0 \\ 2 & 1 & 0 & 0 \\ 3 & 2 & 0 & 0 \\ 0 & 3 & 0 & 0 \\ 4 & 0 & 1 & 0 \\ 5 & 4 & 2 & 1 \\ 6 & 5 & 3 & 2 \\ 0 & 6 & 0 & 3 \\ 7 & 0 & 4 & 0 \\ 8 & 7 & 5 & 4 \\ 9 & 8 & 6 & 5 \\ 0 & 9 & 0 & 6 \\ 0 & 0 & 7 & 0 \\ 0 & 0 & 8 & 7 \\ 0 & 0 & 9 & 8 \\ 0 & 0 & 0 & 9 \end{bmatrix} $$

두 변환이 서로 역변환 관계는 아니라는 것에 유의하자. 실제로 두 행렬을 곱해보면 단위행렬이 나오지 않는다.

$$ AB = \begin{bmatrix} 285 & 186 & 154 & 94 \\ 186 & 285 & 106 & 154 \\ 154 & 106 & 285 & 186 \\ 94 & 154 & 186 & 285 \end{bmatrix} $$ $$ BA = \begin{bmatrix} 1 & 2 & 3 & 0 & 4 & 5 & 6 & 0 & 7 & 8 & 9 & 0 & 0 & 0 & 0 & 0 \\ 2 & 5 & 8 & 3 & 8 & 14 & 17 & 6 & 14 & 23 & 26 & 9 & 0 & 0 & 0 & 0 \\ 3 & 8 & 13 & 6 & 12 & 23 & 28 & 12 & 21 & 38 & 43 & 18 & 0 & 0 & 0 & 0 \\ \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots \\ 0 & 0 & 0 & 0 & 8 & 23 & 38 & 21 & 32 & 68 & 83 & 42 & 56 & 113 & 128 & 63 \\ 0 & 0 & 0 & 0 & 9 & 26 & 43 & 24 & 36 & 77 & 94 & 48 & 63 & 128 & 145 & 72 \\ 0 & 0 & 0 & 0 & 0 & 9 & 18 & 27 & 0 & 36 & 45 & 54 & 0 & 63 & 72 & 81 \end{bmatrix} $$