Implementing U-Net in Julia
Overview
This document shows how to implement in Julia the U-Net presented in the paper “U-Net: Convolutional Networks for Biomedical Image Segmentation”.
Code
U-Net consists of two main parts: a contracting path, in which the encoder repeatedly downsamples the input while increasing the number of feature channels, and an expansive path, in which the compressed feature maps are upsampled again to produce a per-pixel segmentation map. In the image below, the left half is the contracting path and the right half is the expansive path.
As the image shows, U-Net has skip connections that feed encoder outputs directly into the decoder, so it cannot be written as a single sequential Chain. Instead, we implement it with a custom struct and a hand-written forward function. Since every stage applies two 3×3 convolutions in a row, we first define this repeated operation as a function:
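using Flux

# One U-Net stage: two unpadded 3×3 convolutions, each followed by ReLU.
# Each convolution shrinks height and width by 2, so the block shrinks them by 4.
function conv_block(in_ch, out_ch)
    return Chain(
        Conv((3,3), in_ch=>out_ch, relu, pad=0),
        Conv((3,3), out_ch=>out_ch, relu, pad=0),
    )
end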
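To check what conv_block does to the spatial size, the standard output-size formula for a convolution or pooling window, out = (n + 2·pad − k) ÷ stride + 1, fits in a one-line Julia function. This is an illustrative sketch in plain Julia; the names conv_out and conv_block_out are hypothetical helpers, not Flux functions.

```julia
# Output size along one spatial dimension for a window of size k
# (valid for Conv and MaxPool): out = (n + 2pad - k) ÷ stride + 1
conv_out(n; k=3, pad=0, stride=1) = (n + 2pad - k) ÷ stride + 1

# conv_block applies two unpadded 3×3 convolutions in a row
conv_block_out(n) = conv_out(conv_out(n))

println(conv_out(572))                    # 570
println(conv_block_out(572))              # 568
println(conv_out(572; pad=1))             # 572 (padding keeps the size)
println(conv_out(568; k=2, stride=2))     # 284 (2×2 max pooling)
```

The last line shows that the same formula also covers MaxPool((2,2)), which halves an even-sized input.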
Next, we split the U-Net into the blocks shown in the image below. Because the skip connections attach at the boundaries between these blocks, this split makes the forward pass easy to write and keeps the code clean.
Now we define a struct called UNet that holds each of the components in the picture above, and register it with Flux so that its parameters can be trained.
struct UNet
    encoder1; encoder2; encoder3; encoder4
    bottleneck
    decoder4; decoder3; decoder2; decoder1
end

# Register UNet as a Flux layer so its parameters are collected for training
Flux.@layer UNet
The forward pass of UNet is defined as follows. The contracting path and the bottleneck are straightforward. In the expansive path, the output of the previous layer is concatenated along the channel dimension with the output of the corresponding encoder stage, and the result becomes the input to the next decoder stage. Because this implementation follows the paper exactly and uses unpadded convolutions, each encoder output is larger than the matching decoder input and has to be center-cropped before concatenation. If you instead give the convolutions in conv_block padding (pad=1), height and width are preserved (as long as the input size is divisible by 16), so no cropping is needed before calling cat.
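# Forward pass: makes a UNet instance callable, e.g. unet(x)
function (m::UNet)(x)
    # Contracting Path: keep every encoder output for the skip connections
    enc1 = m.encoder1(x)
    enc2 = m.encoder2(enc1)
    enc3 = m.encoder3(enc2)
    enc4 = m.encoder4(enc3)
    # Bottleneck: ends with an up-convolution back to 512 channels
    bn = m.bottleneck(enc4)
    # Expansive Path: center-crop each encoder output to the decoder size, then
    # concatenate along the channel dimension (dims=3 in Flux's WHCN layout).
    # Crop sizes in the comments assume a 572×572 input.
    dec4_input = cat(bn, enc4[1+4:end-4, 1+4:end-4, :, :], dims=3)        # 64 → 56
    dec4 = m.decoder4(dec4_input)
    dec3_input = cat(dec4, enc3[1+16:end-16, 1+16:end-16, :, :], dims=3)  # 136 → 104
    dec3 = m.decoder3(dec3_input)
    dec2_input = cat(dec3, enc2[1+40:end-40, 1+40:end-40, :, :], dims=3)  # 280 → 200
    dec2 = m.decoder2(dec2_input)
    dec1_input = cat(dec2, enc1[1+88:end-88, 1+88:end-88, :, :], dims=3)  # 568 → 392
    dec1 = m.decoder1(dec1_input)
    return dec1
end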
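The crop margins in the forward pass are hard-coded for a 572×572 input. A minimal sketch of a hypothetical center_crop helper that derives the margins from the decoder tensor instead, assuming Flux's WHCN layout (spatial dimensions first) and an even size difference on each side:

```julia
"""
    center_crop(x, ref)

Center-crop the spatial dimensions (dims 1 and 2, WHCN layout) of the 4-D
array `x` to the spatial size of `ref`. Hypothetical helper, not part of Flux.
"""
function center_crop(x::AbstractArray{<:Any,4}, ref::AbstractArray{<:Any,4})
    h, w = size(ref, 1), size(ref, 2)
    dh, dw = size(x, 1) - h, size(x, 2) - w
    # Margins must be non-negative and even so the crop is symmetric
    @assert dh >= 0 && dw >= 0 && iseven(dh) && iseven(dw)
    top, left = dh ÷ 2, dw ÷ 2
    return x[top+1:top+h, left+1:left+w, :, :]
end

# Stand-ins for enc4 (64×64) and bn (56×56); channel count reduced for brevity
enc4 = randn(Float32, 64, 64, 8, 1)
bn   = randn(Float32, 56, 56, 8, 1)
cropped = center_crop(enc4, bn)
size(cropped)   # (56, 56, 8, 1)
```

With such a helper, the first concatenation would read cat(bn, center_crop(enc4, bn), dims=3), and the same forward function would work for any input size that keeps all margins even.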
Now define each encoder, the bottleneck, and each decoder, build a U-Net from them, and check the computation by feeding in some data.
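# Contracting path: after the first stage, each stage halves H×W with max pooling and doubles the channels
encoder1 = conv_block(1, 64)
encoder2 = Chain(MaxPool((2,2)), conv_block(64, 128))
encoder3 = Chain(MaxPool((2,2)), conv_block(128, 256))
encoder4 = Chain(MaxPool((2,2)), conv_block(256, 512))
# Bottleneck: lowest resolution, then a 2×2 up-convolution that doubles H×W and halves the channels
bottleneck = Chain(MaxPool((2,2)), conv_block(512, 1024), ConvTranspose((2,2), 1024=>512, relu, stride=2))
# Expansive path: the concatenation with the skip connection doubles each decoder's input channels
decoder4 = Chain(conv_block(1024, 512), ConvTranspose((2,2), 512=>256, relu, stride=2))
decoder3 = Chain(conv_block(512, 256), ConvTranspose((2,2), 256=>128, relu, stride=2))
decoder2 = Chain(conv_block(256, 128), ConvTranspose((2,2), 128=>64, relu, stride=2))
# Final 1×1 convolution maps 64 features to 2 class scores; no activation (the paper applies softmax in the loss)
decoder1 = Chain(conv_block(128, 64), Conv((1,1), 64=>2))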
unet = UNet(encoder1, encoder2, encoder3, encoder4, bottleneck, decoder4, decoder3, decoder2, decoder1)
x = randn(Float32, 572, 572, 1, 1)
# 572×572×1×1 Array{Float32, 4}:
unet(x)
# 388×388×2×1 Array{Float32, 4}:
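To see where the output size 388 and the hard-coded crop margins 4, 16, 40 and 88 come from, the sketch below traces only the spatial size through the network in plain Julia (no Flux needed). It assumes the layer settings used above: unpadded 3×3 convolutions, 2×2 max pooling with stride 2, and 2×2 transposed convolutions with stride 2. All helper names are hypothetical.

```julia
# Spatial-size rules for the layers used above (square inputs, size per side)
after_conv3(n) = n - 2                           # 3×3 conv, pad=0, stride 1
after_block(n) = after_conv3(after_conv3(n))     # conv_block
after_pool(n)  = n ÷ 2                           # MaxPool((2,2))
after_up(n)    = 2n                              # ConvTranspose((2,2), stride=2)

function unet_sizes(n)
    enc = Int[]
    s = after_block(n); push!(enc, s)            # encoder1
    for _ in 1:3                                 # encoder2 to encoder4
        s = after_block(after_pool(s)); push!(enc, s)
    end
    s = after_up(after_block(after_pool(s)))     # bottleneck
    margins = Int[]
    for i in 4:-1:1                              # decoder4 down to decoder1
        push!(margins, (enc[i] - s) ÷ 2)         # crop per side for skip connection i
        s = after_block(s)
        if i > 1
            s = after_up(s)                      # decoder1 ends with a 1×1 conv instead
        end
    end
    return enc, margins, s
end

enc, margins, out = unet_sizes(572)
println(enc)      # [568, 280, 136, 64]
println(margins)  # [4, 16, 40, 88]
println(out)      # 388
```

The margins match the slices in the forward pass, and the final size matches the 388×388×2×1 output above.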
Full Code
using Flux
function conv_block(in_ch, out_ch)
    return Chain(
        Conv((3,3), in_ch=>out_ch, relu, pad=0),
        Conv((3,3), out_ch=>out_ch, relu, pad=0),
    )
end
struct UNet
    encoder1; encoder2; encoder3; encoder4
    bottleneck
    decoder4; decoder3; decoder2; decoder1
end
Flux.@layer UNet
function (m::UNet)(x)
    # Contracting Path
    enc1 = m.encoder1(x)
    enc2 = m.encoder2(enc1)
    enc3 = m.encoder3(enc2)
    enc4 = m.encoder4(enc3)
    # Bottleneck
    bn = m.bottleneck(enc4)
    # Expansive Path
    dec4_input = cat(bn, enc4[1+4:end-4, 1+4:end-4, :, :], dims=3)
    dec4 = m.decoder4(dec4_input)
    dec3_input = cat(dec4, enc3[1+16:end-16, 1+16:end-16, :, :], dims=3)
    dec3 = m.decoder3(dec3_input)
    dec2_input = cat(dec3, enc2[1+40:end-40, 1+40:end-40, :, :], dims=3)
    dec2 = m.decoder2(dec2_input)
    dec1_input = cat(dec2, enc1[1+88:end-88, 1+88:end-88, :, :], dims=3)
    dec1 = m.decoder1(dec1_input)
    return dec1
end
encoder1 = conv_block(1, 64)
encoder2 = Chain(MaxPool((2,2)), conv_block(64, 128))
encoder3 = Chain(MaxPool((2,2)), conv_block(128, 256))
encoder4 = Chain(MaxPool((2,2)), conv_block(256, 512))
bottleneck = Chain(MaxPool((2,2)), conv_block(512, 1024), ConvTranspose((2,2), 1024=>512, relu, stride=2))
decoder4 = Chain(conv_block(1024, 512), ConvTranspose((2,2), 512=>256, relu, stride=2))
decoder3 = Chain(conv_block(512, 256), ConvTranspose((2,2), 256=>128, relu, stride=2))
decoder2 = Chain(conv_block(256, 128), ConvTranspose((2,2), 128=>64, relu, stride=2))
decoder1 = Chain(conv_block(128, 64), Conv((1,1), 64=>2))
unet = UNet(encoder1, encoder2, encoder3, encoder4, bottleneck, decoder4, decoder3, decoder2, decoder1)
x = randn(Float32, 572, 572, 1, 1)
unet(x)
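For comparison, here is a minimal sketch of the padded alternative mentioned earlier: with pad=1, every 3×3 convolution keeps height and width unchanged, so an encoder output can be concatenated with the upsampled decoder tensor directly, without cropping. It shows only one encoder stage, one down/up step and one skip concatenation; conv_block_same is a hypothetical name and the channel counts are illustrative. A full four-level version additionally needs H and W divisible by 16 so that every pooling step divides evenly.

```julia
using Flux

# "Same"-padded stage: pad=1 keeps H×W unchanged through each 3×3 conv
conv_block_same(in_ch, out_ch) = Chain(
    Conv((3,3), in_ch=>out_ch, relu; pad=1),
    Conv((3,3), out_ch=>out_ch, relu; pad=1),
)

enc  = conv_block_same(1, 16)
down = Chain(MaxPool((2,2)), conv_block_same(16, 32), ConvTranspose((2,2), 32=>16; stride=2))
dec  = conv_block_same(32, 16)

x = randn(Float32, 64, 64, 1, 1)
e = enc(x)                    # 64×64×16×1
d = down(e)                   # pooled to 32×32, then upsampled back to 64×64×16×1
y = dec(cat(d, e; dims=3))    # concatenate directly, no cropping
size(y)                       # (64, 64, 16, 1)
```

The trade-off is that the output then has the same size as the input, whereas the paper's unpadded design predicts only the central region and relies on overlapping tiles for large images.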
Environment
- OS: Windows11
- Version: Julia 1.11.3, Flux v0.16.4