logo

ジュリアフラックスで関数型APIを使用してニューラルネットワークを定義する方法 📂機械学習

ジュリアフラックスで関数型APIを使用してニューラルネットワークを定義する方法

説明

簡単な構造のニューラルネットワークは Flux.Chain を使用して定義できるが、複雑な構造のニューラルネットワークは Chain で定義するのが難しい。この場合、@functor マクロを使用して関数型APIでニューラルネットワークを定義できる。@functor は構造体で定義されたニューラルネットワークのパラメータを追跡し、バックプロパゲーションを実行できるようにする。

コード

線形層 4つからなるニューラルネットワークを定義してみよう。まず、線形層4つをフィールドに持つ構造体を定義し、@functor マクロを使って関数型APIでニューラルネットワークを定義する。

using Flux

struct CustomNetwork
    layer1::Dense
    layer2::Dense
    layer3::Dense
    layer4::Dense
end

Flux.@functor CustomNetwork

そしてニューラルネットワークの順方向伝播を次のように定義する。

# forward pass 정의
function (m::CustomNetwork)(x)
    x = m.layer1(x)
    x = m.layer2(x)
    x = m.layer3(x)
    return m.layer4(x)
end

では CustomNetwork を生成して順方向伝播を実行してみよう。

# CustomNetwork 생성
julia> network = CustomNetwork(Dense(2, 10, relu),
                               Dense(10, 10, relu),
                               Dense(10, 10, relu),
                               Dense(10, 1))
CustomNetwork(Dense(2 => 10, relu), Dense(10 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1))

julia> x = randn(Float32, 2)
2-element Vector{Float32}:
  1.5159738
 -2.2359543

julia> network(x)
1-element Vector{Float32}:
 -0.17960261

コード全文

using Flux

struct CustomNetwork
    layer1::Dense
    layer2::Dense
    layer3::Dense
    layer4::Dense
end

Flux.@functor CustomNetwork

# forward pass 정의
function (m::CustomNetwork)(x)
    x = m.layer1(x)
    x = m.layer2(x)
    x = m.layer3(x)
    return m.layer4(x)
end

# CustomNetwork 생성
network = CustomNetwork(Dense(2, 10, relu),
                        Dense(10, 10, relu),
                        Dense(10, 10, relu),
                        Dense(10, 1))

# 
x = randn(Float32, 2)
network(x)

環境

  • OS: Windows11
  • Version: Julia 1.10.0, Flux v0.14.15