logo

줄리아 플럭스에서 구조체를 이용하여 함수형 API로 신경망 정의하는 방법 📂머신러닝

줄리아 플럭스에서 구조체를 이용하여 함수형 API로 신경망 정의하는 방법

설명

간단한 구조의 신경망은 Flux.Chain을 사용하여 정의할 수 있지만, 복잡한 구조의 신경망은 Chain으로 정의하기 어렵다. 이 때는 @functor 매크로를 사용하여 함수형 API로 신경망을 정의할 수 있다. @functor는 구조체로 정의된 신경망의 파라미터를 추적하여 역전파를 수행할 수 있게 해준다.

코드

선형층 4개로 이루어진 신경망을 정의해보자. 우선 선형층 4개를 필드로 갖는 구조체를 정의하고, @functor 매크로를 사용하여 함수형 API로 신경망을 정의한다.

using Flux

struct CustomNetwork
    layer1::Dense
    layer2::Dense
    layer3::Dense
    layer4::Dense
end

Flux.@functor CustomNetwork

그리고 신경망의 forward pass를 다음과 같이 정의한다.

# forward pass 정의
function (m::CustomNetwork)(x)
    x = m.layer1(x)
    x = m.layer2(x)
    x = m.layer3(x)
    return m.layer4(x)
end

이제 CustomNetwork을 생성하고 forward pass를 수행해보자.

# CustomNetwork 생성
julia> network = CustomNetwork(Dense(2, 10, relu),
                               Dense(10, 10, relu),
                               Dense(10, 10, relu),
                               Dense(10, 1))
CustomNetwork(Dense(2 => 10, relu), Dense(10 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1))

julia> x = randn(Float32, 2)
2-element Vector{Float32}:
  1.5159738
 -2.2359543

julia> network(x)
1-element Vector{Float32}:
 -0.17960261

코드 전문

using Flux

struct CustomNetwork
    layer1::Dense
    layer2::Dense
    layer3::Dense
    layer4::Dense
end

Flux.@functor CustomNetwork

# forward pass 정의
function (m::CustomNetwork)(x)
    x = m.layer1(x)
    x = m.layer2(x)
    x = m.layer3(x)
    return m.layer4(x)
end

# CustomNetwork 생성
network = CustomNetwork(Dense(2, 10, relu),
                        Dense(10, 10, relu),
                        Dense(10, 10, relu),
                        Dense(10, 1))

# 
x = randn(Float32, 2)
network(x)

환경

  • OS: Windows11
  • Version: Julia 1.10.0, Flux v0.14.15