줄리아에서 U-net 구현하기
概要
論文「U-Net: Convolutional networks for Biomedical Image Segmentation」で紹介されたU-Netをジュリアで実装する方法を紹介する。
コード
U-Netの構造は大きく二つの部分に分かれる。エンコーダが繰り返され、入力データを圧縮するcontracting pathと、圧縮されたデータを再び復元するexpansive pathである。下の図で左側の半分がcontracting path、右側がexpansive pathだ。
図で示されているようにスキップ接続が存在するため、シーケンシャルAPIでは実装できない。構造体と関数型APIで実装する方法をまとめる。まずは、畳み込みを二回適用することが繰り返されるため、以下のように関数を定義する。
using Flux
function conv_block(in_ch, out_ch)
return Chain(
Conv((3,3), in_ch=>out_ch, relu, pad=0),
Conv((3,3), out_ch=>out_ch, relu, pad=0),
)
end
次にU-Netの各部分を下の図のように分けよう。中間にスキップ接続が含まれているため、forward計算が容易でコードが整然とするように区分した。
今度は上の図の各要素を含むUNet
という構造体を定義し、これを学習可能なニューラルネットワークとして設定する。
struct UNet
encoder1; encoder2; encoder3; encoder4
bottleneck
decoder4; decoder3; decoder2; decoder1
end
Flux.@functor UNet
そして、UNet
のforward計算を以下のように定義する。contracting pathとbottleneck部分は難しいことはなく、expansive pathでは前の層の出力とcontracting pathの出力を結合して次の層の入力として使用する。論文の構造をそのまま実装したので、エンコーダの出力をトリミングする部分があるが、conv_block
の畳み込み層にパディングを与えればサイズが維持されるので、cat
関数でトリミングして入れなくてもよい。
function (m::UNet)(x)
# Contracting Path
enc1 = m.encoder1(x)
enc2 = m.encoder2(enc1)
enc3 = m.encoder3(enc2)
enc4 = m.encoder4(enc3)
# Bottleneck
bn = m.bottleneck(enc4)
# Expansive Path
dec4_input = cat(bn, enc4[1+4:end-4, 1+4:end-4, :, :], dims=3)
dec4 = m.decoder4(dec4_input)
dec3_input = cat(dec4, enc3[1+16:end-16, 1+16:end-16, :, :], dims=3)
dec3 = m.decoder3(dec3_input)
dec2_input = cat(dec3, enc2[1+40:end-40, 1+40:end-40, :, :], dims=3)
dec2 = m.decoder2(dec2_input)
dec1_input = cat(dec2, enc1[1+88:end-88, 1+88:end-88, :, :], dims=3)
dec1 = m.decoder1(dec1_input)
return dec1
end
次に、それぞれのエンコーダ、ボトルネック、デコーダを定義し、これらからU-Netを作成して入力データを試してみると、計算がうまくいくことを確認できる。
encoder1 = conv_block(1, 64)
encoder2 = Chain(MaxPool((2,2)), conv_block(64, 128))
encoder3 = Chain(MaxPool((2,2)), conv_block(128, 256))
encoder4 = Chain(MaxPool((2,2)), conv_block(256, 512))
bottleneck = Chain(MaxPool((2,2)), conv_block(512, 1024), ConvTranspose((2,2), 1024=>512, relu, stride=2))
decoder4 = Chain(conv_block(1024, 512), ConvTranspose((2,2), 512=>256, relu, stride=2))
decoder3 = Chain(conv_block(512, 256), ConvTranspose((2,2), 256=>128, relu, stride=2))
decoder2 = Chain(conv_block(256, 128), ConvTranspose((2,2), 128=>64, relu, stride=2))
decoder1 = Chain(conv_block(128, 64), Conv((1,1), 64=>2, relu, pad=0))
unet = UNet(encoder1, encoder2, encoder3, encoder4, bottleneck, decoder4, decoder3, decoder2, decoder1)
x = randn(Float32, 572, 572, 1, 1)
# 572×572×1×1 Array{Float32, 4}:
unet(x)
# 388×388×2×1 Array{Float32, 4}:
コード全文
using Flux
function conv_block(in_ch, out_ch)
return Chain(
Conv((3,3), in_ch=>out_ch, relu, pad=0),
Conv((3,3), out_ch=>out_ch, relu, pad=0),
)
end
struct UNet
encoder1; encoder2; encoder3; encoder4
bottleneck
decoder4; decoder3; decoder2; decoder1
end
Flux.@functor UNet
function (m::UNet)(x)
# Contracting Path
enc1 = m.encoder1(x)
enc2 = m.encoder2(enc1)
enc3 = m.encoder3(enc2)
enc4 = m.encoder4(enc3)
# Bottleneck
bn = m.bottleneck(enc4)
# Expansive Path
dec4_input = cat(bn, enc4[1+4:end-4, 1+4:end-4, :, :], dims=3)
dec4 = m.decoder4(dec4_input)
dec3_input = cat(dec4, enc3[1+16:end-16, 1+16:end-16, :, :], dims=3)
dec3 = m.decoder3(dec3_input)
dec2_input = cat(dec3, enc2[1+40:end-40, 1+40:end-40, :, :], dims=3)
dec2 = m.decoder2(dec2_input)
dec1_input = cat(dec2, enc1[1+88:end-88, 1+88:end-88, :, :], dims=3)
dec1 = m.decoder1(dec1_input)
return dec1
end
encoder1 = conv_block(1, 64)
encoder2 = Chain(MaxPool((2,2)), conv_block(64, 128))
encoder3 = Chain(MaxPool((2,2)), conv_block(128, 256))
encoder4 = Chain(MaxPool((2,2)), conv_block(256, 512))
bottleneck = Chain(MaxPool((2,2)), conv_block(512, 1024), ConvTranspose((2,2), 1024=>512, relu, stride=2))
decoder4 = Chain(conv_block(1024, 512), ConvTranspose((2,2), 512=>256, relu, stride=2))
decoder3 = Chain(conv_block(512, 256), ConvTranspose((2,2), 256=>128, relu, stride=2))
decoder2 = Chain(conv_block(256, 128), ConvTranspose((2,2), 128=>64, relu, stride=2))
decoder1 = Chain(conv_block(128, 64), Conv((1,1), 64=>2, relu, pad=0))
unet = UNet(encoder1, encoder2, encoder3, encoder4, bottleneck, decoder4, decoder3, decoder2, decoder1)
x = randn(Float32, 572, 572, 1, 1)
unet(x)
環境
- OS: Windows11
- Version: Julia 1.11.3, Flux v0.16.4