logo

줄리아에서 U-net 구현하기 📂機械学習

줄리아에서 U-net 구현하기

概要

論文「U-Net: Convolutional networks for Biomedical Image Segmentation」で紹介されたU-Netをジュリアで実装する方法を紹介する。

コード

U-Netの構造は大きく二つの部分に分かれる。エンコーダが繰り返され、入力データを圧縮するcontracting pathと、圧縮されたデータを再び復元するexpansive pathである。下の図で左側の半分がcontracting path、右側がexpansive pathだ。

図で示されているようにスキップ接続が存在するため、シーケンシャルAPIでは実装できない。構造体と関数型APIで実装する方法をまとめる。まずは、畳み込みを二回適用することが繰り返されるため、以下のように関数を定義する。

using Flux

function conv_block(in_ch, out_ch)
    return Chain(
        Conv((3,3), in_ch=>out_ch, relu, pad=0),
        Conv((3,3), out_ch=>out_ch, relu, pad=0),
    )
end

次にU-Netの各部分を下の図のように分けよう。中間にスキップ接続が含まれているため、forward計算が容易でコードが整然とするように区分した。

今度は上の図の各要素を含むUNetという構造体を定義し、これを学習可能なニューラルネットワークとして設定する。

struct UNet
    encoder1; encoder2; encoder3; encoder4
    bottleneck
    decoder4; decoder3; decoder2; decoder1
end

Flux.@functor UNet

そして、UNetのforward計算を以下のように定義する。contracting pathとbottleneck部分は難しいことはなく、expansive pathでは前の層の出力とcontracting pathの出力を結合して次の層の入力として使用する。論文の構造をそのまま実装したので、エンコーダの出力をトリミングする部分があるが、conv_block畳み込み層にパディングを与えればサイズが維持されるので、cat関数でトリミングして入れなくてもよい。

function (m::UNet)(x)
    # Contracting Path
    enc1 = m.encoder1(x)
    enc2 = m.encoder2(enc1)
    enc3 = m.encoder3(enc2)
    enc4 = m.encoder4(enc3)

    # Bottleneck
    bn = m.bottleneck(enc4)

    # Expansive Path
    dec4_input = cat(bn, enc4[1+4:end-4, 1+4:end-4, :, :], dims=3)
    dec4 = m.decoder4(dec4_input)

    dec3_input = cat(dec4, enc3[1+16:end-16, 1+16:end-16, :, :], dims=3)
    dec3 = m.decoder3(dec3_input)
  
    dec2_input = cat(dec3, enc2[1+40:end-40, 1+40:end-40, :, :], dims=3)
    dec2 = m.decoder2(dec2_input)

    dec1_input = cat(dec2, enc1[1+88:end-88, 1+88:end-88, :, :], dims=3)
    dec1 = m.decoder1(dec1_input)

    return dec1
end

次に、それぞれのエンコーダ、ボトルネック、デコーダを定義し、これらからU-Netを作成して入力データを試してみると、計算がうまくいくことを確認できる。

encoder1 = conv_block(1, 64)
encoder2 = Chain(MaxPool((2,2)), conv_block(64, 128))
encoder3 = Chain(MaxPool((2,2)), conv_block(128, 256))
encoder4 = Chain(MaxPool((2,2)), conv_block(256, 512))

bottleneck = Chain(MaxPool((2,2)), conv_block(512, 1024), ConvTranspose((2,2), 1024=>512, relu, stride=2))

decoder4 = Chain(conv_block(1024, 512), ConvTranspose((2,2), 512=>256, relu, stride=2))
decoder3 = Chain(conv_block(512, 256), ConvTranspose((2,2), 256=>128, relu, stride=2))
decoder2 = Chain(conv_block(256, 128), ConvTranspose((2,2), 128=>64, relu, stride=2))
decoder1 = Chain(conv_block(128, 64), Conv((1,1), 64=>2, relu, pad=0))

unet = UNet(encoder1, encoder2, encoder3, encoder4, bottleneck, decoder4, decoder3, decoder2, decoder1)

x = randn(Float32, 572, 572, 1, 1)
# 572×572×1×1 Array{Float32, 4}:

unet(x)
# 388×388×2×1 Array{Float32, 4}:

コード全文

using Flux

function conv_block(in_ch, out_ch)
    return Chain(
        Conv((3,3), in_ch=>out_ch, relu, pad=0),
        Conv((3,3), out_ch=>out_ch, relu, pad=0),
    )
end

struct UNet
    encoder1; encoder2; encoder3; encoder4
    bottleneck
    decoder4; decoder3; decoder2; decoder1
end

Flux.@functor UNet

function (m::UNet)(x)
    # Contracting Path
    enc1 = m.encoder1(x)
    enc2 = m.encoder2(enc1)
    enc3 = m.encoder3(enc2)
    enc4 = m.encoder4(enc3)

    # Bottleneck
    bn = m.bottleneck(enc4)

    # Expansive Path
    dec4_input = cat(bn, enc4[1+4:end-4, 1+4:end-4, :, :], dims=3)
    dec4 = m.decoder4(dec4_input)

    dec3_input = cat(dec4, enc3[1+16:end-16, 1+16:end-16, :, :], dims=3)
    dec3 = m.decoder3(dec3_input)
    
    dec2_input = cat(dec3, enc2[1+40:end-40, 1+40:end-40, :, :], dims=3)
    dec2 = m.decoder2(dec2_input)
    
    dec1_input = cat(dec2, enc1[1+88:end-88, 1+88:end-88, :, :], dims=3)
    dec1 = m.decoder1(dec1_input)

    return dec1
end

encoder1 = conv_block(1, 64)
encoder2 = Chain(MaxPool((2,2)), conv_block(64, 128))
encoder3 = Chain(MaxPool((2,2)), conv_block(128, 256))
encoder4 = Chain(MaxPool((2,2)), conv_block(256, 512))

bottleneck = Chain(MaxPool((2,2)), conv_block(512, 1024), ConvTranspose((2,2), 1024=>512, relu, stride=2))

decoder4 = Chain(conv_block(1024, 512), ConvTranspose((2,2), 512=>256, relu, stride=2))
decoder3 = Chain(conv_block(512, 256), ConvTranspose((2,2), 256=>128, relu, stride=2))
decoder2 = Chain(conv_block(256, 128), ConvTranspose((2,2), 128=>64, relu, stride=2))
decoder1 = Chain(conv_block(128, 64), Conv((1,1), 64=>2, relu, pad=0))

unet = UNet(encoder1, encoder2, encoder3, encoder4, bottleneck, decoder4, decoder3, decoder2, decoder1)

x = randn(Float32, 572, 572, 1, 1)
unet(x)

環境

  • OS: Windows11
  • Version: Julia 1.11.3, Flux v0.16.4