logo

줄리아에서 U-net 구현하기 📂머신러닝

줄리아에서 U-net 구현하기

개요

논문 "U-Net: Convolutional networks for Biomedical Image Segmentation"에서 소개된 U-Net을 줄리아로 구현하는 방법을 소개한다.

코드

U-Net의 구조는 크게 두 부분으로 나뉜다. 인코더가 반복되어 입력 데이터를 압축하는 contracting path와 압축된 데이터를 다시 복원하는 expansive path이다. 아래 그림에서 왼쪽의 절반이 contracting path, 오른쪽이 expansive path이다.

그림에서 보이듯이 스킵커넥션이 존재하기 때문에 sequantial API로는 구현할 수 없다. 구조체와 함수형 API로 구현하는 방법을 정리한다. 우선 컨볼루션을 두 번 적용하는 것이 반복되므로 아래와 같이 함수를 정의한다.

using Flux

function conv_block(in_ch, out_ch)
    return Chain(
        Conv((3,3), in_ch=>out_ch, relu, pad=0),
        Conv((3,3), out_ch=>out_ch, relu, pad=0),
    )
end

그리고 U-Net의 각 부분을 아래 그림과 같이 나누자. 중간에 스킵커넥션이 포함되어있기 때문에에, forward 계산이 편하고 코드가 깔끔하도록 구분하였다.

이제 위 사진의 각 요소를 포함하는 UNet이라는 구조체를 정의하고, 이를 학습가능한 신경망으로 설정한다.

struct UNet
    encoder1; encoder2; encoder3; encoder4
    bottleneck
    decoder4; decoder3; decoder2; decoder1
end

Flux.@functor UNet

그리고 UNet의 forward 계산을 아래와 같이 정의한다. contracting path와 bottleneck 부분은 어려울 것이 없고, expansive path에서는 이전 층의 출력과 contracting path의 출력을 결합해서 다음 층의 입력으로 사용한다. 논문의 구조를 그대로 구현했기 때문에 인코더의 출력을 잘라내는 부분이 있는데, conv_block합성곱층에 패딩을 주면 크기가 보존되므로 cat 함수에 잘라서 넣지 않아도 된다.

function (m::UNet)(x)
    # Contracting Path
    enc1 = m.encoder1(x)
    enc2 = m.encoder2(enc1)
    enc3 = m.encoder3(enc2)
    enc4 = m.encoder4(enc3)

    # Bottleneck
    bn = m.bottleneck(enc4)

    # Expansive Path
    dec4_input = cat(bn, enc4[1+4:end-4, 1+4:end-4, :, :], dims=3)
    dec4 = m.decoder4(dec4_input)

    dec3_input = cat(dec4, enc3[1+16:end-16, 1+16:end-16, :, :], dims=3)
    dec3 = m.decoder3(dec3_input)
  
    dec2_input = cat(dec3, enc2[1+40:end-40, 1+40:end-40, :, :], dims=3)
    dec2 = m.decoder2(dec2_input)

    dec1_input = cat(dec2, enc1[1+88:end-88, 1+88:end-88, :, :], dims=3)
    dec1 = m.decoder1(dec1_input)

    return dec1
end

이제 각각의 encoder, bottleneck, decoder를 정의하고, 이들로부터 U-Net을 만들어서 입력 데이터를 넣어보면 계산이 잘 되는 것을 확인할 수 있다.

encoder1 = conv_block(1, 64)
encoder2 = Chain(MaxPool((2,2)), conv_block(64, 128))
encoder3 = Chain(MaxPool((2,2)), conv_block(128, 256))
encoder4 = Chain(MaxPool((2,2)), conv_block(256, 512))

bottleneck = Chain(MaxPool((2,2)), conv_block(512, 1024), ConvTranspose((2,2), 1024=>512, relu, stride=2))

decoder4 = Chain(conv_block(1024, 512), ConvTranspose((2,2), 512=>256, relu, stride=2))
decoder3 = Chain(conv_block(512, 256), ConvTranspose((2,2), 256=>128, relu, stride=2))
decoder2 = Chain(conv_block(256, 128), ConvTranspose((2,2), 128=>64, relu, stride=2))
decoder1 = Chain(conv_block(128, 64), Conv((1,1), 64=>2, relu, pad=0))

unet = UNet(encoder1, encoder2, encoder3, encoder4, bottleneck, decoder4, decoder3, decoder2, decoder1)

x = randn(Float32, 572, 572, 1, 1)
# 572×572×1×1 Array{Float32, 4}:

unet(x)
# 388×388×2×1 Array{Float32, 4}:

코드 전문

using Flux

function conv_block(in_ch, out_ch)
    return Chain(
        Conv((3,3), in_ch=>out_ch, relu, pad=0),
        Conv((3,3), out_ch=>out_ch, relu, pad=0),
    )
end

struct UNet
    encoder1; encoder2; encoder3; encoder4
    bottleneck
    decoder4; decoder3; decoder2; decoder1
end

Flux.@functor UNet

function (m::UNet)(x)
    # Contracting Path
    enc1 = m.encoder1(x)
    enc2 = m.encoder2(enc1)
    enc3 = m.encoder3(enc2)
    enc4 = m.encoder4(enc3)

    # Bottleneck
    bn = m.bottleneck(enc4)

    # Expansive Path
    dec4_input = cat(bn, enc4[1+4:end-4, 1+4:end-4, :, :], dims=3)
    dec4 = m.decoder4(dec4_input)

    dec3_input = cat(dec4, enc3[1+16:end-16, 1+16:end-16, :, :], dims=3)
    dec3 = m.decoder3(dec3_input)
    
    dec2_input = cat(dec3, enc2[1+40:end-40, 1+40:end-40, :, :], dims=3)
    dec2 = m.decoder2(dec2_input)
    
    dec1_input = cat(dec2, enc1[1+88:end-88, 1+88:end-88, :, :], dims=3)
    dec1 = m.decoder1(dec1_input)

    return dec1
end

encoder1 = conv_block(1, 64)
encoder2 = Chain(MaxPool((2,2)), conv_block(64, 128))
encoder3 = Chain(MaxPool((2,2)), conv_block(128, 256))
encoder4 = Chain(MaxPool((2,2)), conv_block(256, 512))

bottleneck = Chain(MaxPool((2,2)), conv_block(512, 1024), ConvTranspose((2,2), 1024=>512, relu, stride=2))

decoder4 = Chain(conv_block(1024, 512), ConvTranspose((2,2), 512=>256, relu, stride=2))
decoder3 = Chain(conv_block(512, 256), ConvTranspose((2,2), 256=>128, relu, stride=2))
decoder2 = Chain(conv_block(256, 128), ConvTranspose((2,2), 128=>64, relu, stride=2))
decoder1 = Chain(conv_block(128, 64), Conv((1,1), 64=>2, relu, pad=0))

unet = UNet(encoder1, encoder2, encoder3, encoder4, bottleneck, decoder4, decoder3, decoder2, decoder1)

x = randn(Float32, 572, 572, 1, 1)
unet(x)

환경

  • OS: Windows11
  • Version: Julia 1.11.3, Flux v0.16.4