Implementing U-Net in Julia

Overview

This post shows how to implement the U-Net presented in the paper “U-Net: Convolutional Networks for Biomedical Image Segmentation” in Julia, using Flux.

Code

The structure of U-Net is divided into two main parts: a contracting path where the encoder compresses the input data, and an expansive path where the compressed data is restored. In the image below, the left half is the contracting path, and the right is the expansive path.

As the image shows, the network has skip connections, so it cannot be written as a purely sequential model (a single Chain). Instead, we implement it with a struct and a custom forward function. Because every stage applies two convolutions in a row, we first define a helper for that repeated block:

using Flux

function conv_block(in_ch, out_ch)
    return Chain(
        Conv((3,3), in_ch=>out_ch, relu, pad=0),
        Conv((3,3), out_ch=>out_ch, relu, pad=0),
    )
end
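Because conv_block uses pad=0, each 3×3 convolution shrinks the feature map by 2 pixels per spatial dimension. The following is a minimal sketch of the standard output-size formula in plain Julia; the helper name out_size is hypothetical and is not part of Flux.

```julia
# Output length along one spatial dimension of a convolution (or pooling) window.
# n: input length, k: kernel size, pad: padding on each side, stride: step size.
# `out_size` is a hypothetical helper for illustration, not a Flux function.
out_size(n, k; pad=0, stride=1) = (n + 2 * pad - k) ÷ stride + 1

# Two unpadded 3×3 convolutions, as in conv_block, applied to a 572-pixel side.
after_first  = out_size(572, 3)           # 570
after_second = out_size(after_first, 3)   # 568

# With pad=1, a 3×3 convolution keeps the size unchanged.
same_size = out_size(572, 3; pad=1)       # 572

println((after_first, after_second, same_size))
```

This is why a 572×572 input leaves the first encoder block as 568×568.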

Next, we divide the U-Net into parts as shown in the image below. Splitting the network at the points where the skip connections branch off makes the forward computation easy to write and keeps the code clean.

Now we define a struct called UNet that holds each component in the picture above, and register it with Flux so that it is treated as a trainable neural network.

struct UNet
    encoder1; encoder2; encoder3; encoder4
    bottleneck
    decoder4; decoder3; decoder2; decoder1
end

# Register UNet with Flux so that its fields are treated as trainable layers.
Flux.@layer UNet

The forward computation of UNet is defined as follows. The contracting path and the bottleneck are straightforward. In the expansive path, the output of the previous layer is concatenated along the channel dimension with the output of the corresponding stage of the contracting path, and the result is used as the input to the next layer. Because this implementation follows the paper as-is, with unpadded convolutions, each encoder output has to be trimmed to the size of the upsampled feature map before concatenation. If you instead add padding to the convolutional layers in conv_block so that the spatial size is preserved, the trimming is unnecessary and the encoder outputs can be passed to cat directly.
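The trimming described above is a center crop of the first two (spatial) dimensions. As a rough sketch, a hypothetical helper named center_crop (not part of Flux, and not used in the original code) could compute the margins from the target size instead of hard-coding them. It assumes the size difference is even.

```julia
# Hypothetical helper: crop the two spatial dimensions of a 4-D array
# (spatial, spatial, channels, batch) around the center to size h × w.
# Assumes the size differences are non-negative and even.
function center_crop(x::AbstractArray{T,4}, h::Integer, w::Integer) where {T}
    dh = (size(x, 1) - h) ÷ 2   # margin removed from each side of dimension 1
    dw = (size(x, 2) - w) ÷ 2   # margin removed from each side of dimension 2
    return x[1+dh:dh+h, 1+dw:dw+w, :, :]
end

# Small stand-ins for an encoder output (64×64) and an upsampled map (56×56).
skip = rand(Float32, 64, 64, 3, 1)
up   = rand(Float32, 56, 56, 3, 1)

cropped = center_crop(skip, size(up, 1), size(up, 2))
merged  = cat(up, cropped, dims=3)   # concatenate along the channel dimension

println(size(cropped))   # (56, 56, 3, 1)
println(size(merged))    # (56, 56, 6, 1)
```

For a 64×64 encoder output and a 56×56 upsampled map, this selects the same region as the slice `skip[1+4:end-4, 1+4:end-4, :, :]`.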

function (m::UNet)(x)
    # Contracting Path
    enc1 = m.encoder1(x)
    enc2 = m.encoder2(enc1)
    enc3 = m.encoder3(enc2)
    enc4 = m.encoder4(enc3)

    # Bottleneck
    bn = m.bottleneck(enc4)

    # Expansive Path
    dec4_input = cat(bn, enc4[1+4:end-4, 1+4:end-4, :, :], dims=3)
    dec4 = m.decoder4(dec4_input)

    dec3_input = cat(dec4, enc3[1+16:end-16, 1+16:end-16, :, :], dims=3)
    dec3 = m.decoder3(dec3_input)
  
    dec2_input = cat(dec3, enc2[1+40:end-40, 1+40:end-40, :, :], dims=3)
    dec2 = m.decoder2(dec2_input)

    dec1_input = cat(dec2, enc1[1+88:end-88, 1+88:end-88, :, :], dims=3)
    dec1 = m.decoder1(dec1_input)

    return dec1
end
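The crop margins 4, 16, 40, and 88 in the forward pass above follow from the feature-map sizes at each stage. The sketch below traces one spatial dimension through the network in plain Julia, assuming the layer configuration used in this post (unpadded 3×3 convolutions, 2×2 max pooling, and 2×2 transposed convolutions with stride 2). All function names are hypothetical.

```julia
# Size bookkeeping along one spatial dimension (all names are hypothetical).
conv_block_size(n) = n - 4   # two unpadded 3×3 convolutions, 2 pixels lost each
pool_size(n) = n ÷ 2         # 2×2 max pooling
up_size(n) = 2 * n           # 2×2 transposed convolution with stride 2

function trace_unet(input_size)
    # Contracting path
    enc1 = conv_block_size(input_size)
    enc2 = conv_block_size(pool_size(enc1))
    enc3 = conv_block_size(pool_size(enc2))
    enc4 = conv_block_size(pool_size(enc3))

    # Bottleneck: pool, conv block, then upsample
    up = up_size(conv_block_size(pool_size(enc4)))

    # Expansive path: record the margin to trim from each side of the encoder output
    margins = Int[]
    for enc in (enc4, enc3, enc2)
        push!(margins, (enc - up) ÷ 2)
        up = up_size(conv_block_size(up))   # decoder4, decoder3, decoder2
    end
    push!(margins, (enc1 - up) ÷ 2)

    # decoder1: conv block followed by a 1×1 convolution that keeps the size
    out = conv_block_size(up)
    return margins, out
end

margins, out = trace_unet(572)
println(margins)   # [4, 16, 40, 88]
println(out)       # 388
```

For a different input size the margins change, so the hard-coded indices in the forward pass are only valid for 572×572 inputs.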

Now we define each encoder, the bottleneck, and each decoder, assemble them into a U-Net, and test the forward computation on random input data.

encoder1 = conv_block(1, 64)
encoder2 = Chain(MaxPool((2,2)), conv_block(64, 128))
encoder3 = Chain(MaxPool((2,2)), conv_block(128, 256))
encoder4 = Chain(MaxPool((2,2)), conv_block(256, 512))

bottleneck = Chain(MaxPool((2,2)), conv_block(512, 1024), ConvTranspose((2,2), 1024=>512, relu, stride=2))

decoder4 = Chain(conv_block(1024, 512), ConvTranspose((2,2), 512=>256, relu, stride=2))
decoder3 = Chain(conv_block(512, 256), ConvTranspose((2,2), 256=>128, relu, stride=2))
decoder2 = Chain(conv_block(256, 128), ConvTranspose((2,2), 128=>64, relu, stride=2))
decoder1 = Chain(conv_block(128, 64), Conv((1,1), 64=>2, relu, pad=0))

unet = UNet(encoder1, encoder2, encoder3, encoder4, bottleneck, decoder4, decoder3, decoder2, decoder1)

x = randn(Float32, 572, 572, 1, 1)
# 572×572×1×1 Array{Float32, 4}:

unet(x)
# 388×388×2×1 Array{Float32, 4}:
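As mentioned in the discussion of the forward computation, padding the convolutions removes the need for trimming. The following is a minimal sketch of a single skip connection under that assumption. The name conv_block_same and the small channel counts are illustrative choices, not part of the original implementation; SamePad() is Flux's padding option that keeps the spatial size unchanged for stride-1 convolutions.

```julia
using Flux

# Hypothetical padded variant of conv_block: SamePad() preserves height and width.
function conv_block_same(in_ch, out_ch)
    return Chain(
        Conv((3,3), in_ch=>out_ch, relu, pad=SamePad()),
        Conv((3,3), out_ch=>out_ch, relu, pad=SamePad()),
    )
end

# One encoder stage and one "down, convolve, up" stage (small channel counts for speed).
enc = conv_block_same(1, 8)
down_up = Chain(
    MaxPool((2,2)),
    conv_block_same(8, 16),
    ConvTranspose((2,2), 16=>8, relu, stride=2),
)

x = randn(Float32, 64, 64, 1, 1)
e = enc(x)        # spatial size stays 64×64
u = down_up(e)    # pooled to 32×32, then upsampled back to 64×64

# The sizes already match, so no trimming is needed before concatenation.
merged = cat(u, e, dims=3)
println(size(merged))   # (64, 64, 16, 1)
```

With this approach the input height and width should be divisible by 2 at every pooling step (by 16 for the four-level network in this post), otherwise the upsampled map and the encoder output will differ in size.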

Full Code

using Flux

function conv_block(in_ch, out_ch)
    return Chain(
        Conv((3,3), in_ch=>out_ch, relu, pad=0),
        Conv((3,3), out_ch=>out_ch, relu, pad=0),
    )
end

struct UNet
    encoder1; encoder2; encoder3; encoder4
    bottleneck
    decoder4; decoder3; decoder2; decoder1
end

# Register UNet with Flux so that its fields are treated as trainable layers.
Flux.@layer UNet

function (m::UNet)(x)
    # Contracting Path
    enc1 = m.encoder1(x)
    enc2 = m.encoder2(enc1)
    enc3 = m.encoder3(enc2)
    enc4 = m.encoder4(enc3)

    # Bottleneck
    bn = m.bottleneck(enc4)

    # Expansive Path
    dec4_input = cat(bn, enc4[1+4:end-4, 1+4:end-4, :, :], dims=3)
    dec4 = m.decoder4(dec4_input)

    dec3_input = cat(dec4, enc3[1+16:end-16, 1+16:end-16, :, :], dims=3)
    dec3 = m.decoder3(dec3_input)
    
    dec2_input = cat(dec3, enc2[1+40:end-40, 1+40:end-40, :, :], dims=3)
    dec2 = m.decoder2(dec2_input)
    
    dec1_input = cat(dec2, enc1[1+88:end-88, 1+88:end-88, :, :], dims=3)
    dec1 = m.decoder1(dec1_input)

    return dec1
end

encoder1 = conv_block(1, 64)
encoder2 = Chain(MaxPool((2,2)), conv_block(64, 128))
encoder3 = Chain(MaxPool((2,2)), conv_block(128, 256))
encoder4 = Chain(MaxPool((2,2)), conv_block(256, 512))

bottleneck = Chain(MaxPool((2,2)), conv_block(512, 1024), ConvTranspose((2,2), 1024=>512, relu, stride=2))

decoder4 = Chain(conv_block(1024, 512), ConvTranspose((2,2), 512=>256, relu, stride=2))
decoder3 = Chain(conv_block(512, 256), ConvTranspose((2,2), 256=>128, relu, stride=2))
decoder2 = Chain(conv_block(256, 128), ConvTranspose((2,2), 128=>64, relu, stride=2))
decoder1 = Chain(conv_block(128, 64), Conv((1,1), 64=>2, relu, pad=0))

unet = UNet(encoder1, encoder2, encoder3, encoder4, bottleneck, decoder4, decoder3, decoder2, decoder1)

x = randn(Float32, 572, 572, 1, 1)
unet(x)

Environment

  • OS: Windows 11
  • Version: Julia 1.11.3, Flux v0.16.4