How to Define Neural Networks Using Functional API in Julia Flux

Description

A simple neural network structure can be defined with Flux.Chain, which applies its layers one after another, but a network with a more complex structure is hard to express with Chain alone. In such cases, you can use a functional API instead: define the network as a struct whose fields are layers, give it a forward pass, and register it with the @functor macro. @functor lets Flux traverse the fields of the struct, so the parameters of a struct-defined network can be collected and updated during backpropagation.
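As a rough sketch of both points, the block below defines a hypothetical ResidualBlock whose forward pass adds the input back to the output of two Dense layers. This is not a purely sequential pipeline, so it does not map directly onto a plain Chain (Flux also provides a SkipConnection layer for this particular pattern). After @functor is applied, Flux.params can find the weight and bias arrays of both inner layers. The names ResidualBlock, block, and ps, and the layer sizes, are illustrative only.

```julia
using Flux

# Hypothetical example: a block with a skip connection, y = layer2(layer1(x)) + x
struct ResidualBlock
    layer1::Dense
    layer2::Dense
end

# Register the struct so Flux can traverse its fields and find the parameters
Flux.@functor ResidualBlock

# Non-sequential forward pass: add the input back to the output
(m::ResidualBlock)(x) = m.layer2(m.layer1(x)) .+ x

block = ResidualBlock(Dense(4 => 8, relu), Dense(8 => 4))
x = randn(Float32, 4)
y = block(x)

# Flux.params collects the weight and bias of both Dense layers
ps = Flux.params(block)
println(length(ps))        # 4 parameter arrays
println(sum(length, ps))   # 76 scalars: (8*4 + 8) + (4*8 + 4)
```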

Code

Let’s define a neural network consisting of four linear (Dense) layers. First, define a struct that holds the four layers as fields, then apply the @functor macro to it.

using Flux

struct CustomNetwork
    layer1::Dense
    layer2::Dense
    layer3::Dense
    layer4::Dense
end

Flux.@functor CustomNetwork
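The fields above are annotated as ::Dense. This works, but Dense is itself a parametric type in Flux, so ::Dense is an abstract field type and the compiler cannot infer the concrete layer type when a field is accessed (the Julia performance tips advise avoiding fields with abstract type). A common alternative, sketched here under the hypothetical name CustomNetworkP, is to make the struct parametric; @functor is applied in the same way.

```julia
using Flux

# Hypothetical parametric variant of CustomNetwork: each field type is a type
# parameter, so a concrete instance records the exact Dense type of every layer.
struct CustomNetworkP{L1,L2,L3,L4}
    layer1::L1
    layer2::L2
    layer3::L3
    layer4::L4
end

Flux.@functor CustomNetworkP

# Same sequential forward pass as CustomNetwork
(m::CustomNetworkP)(x) = m.layer4(m.layer3(m.layer2(m.layer1(x))))

net = CustomNetworkP(Dense(2 => 10, relu), Dense(10 => 10, relu),
                     Dense(10 => 10, relu), Dense(10 => 1))

out = net(randn(Float32, 2))

# The field types of the instance are concrete, unlike the bare Dense annotation
println(isconcretetype(fieldtype(typeof(net), :layer1)))  # true
println(isconcretetype(Dense))                             # false
```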

Then, define the forward pass by making CustomNetwork callable, so that network(x) passes x through the four layers in order.

# define the forward pass
function (m::CustomNetwork)(x)
    x = m.layer1(x)
    x = m.layer2(x)
    x = m.layer3(x)
    return m.layer4(x)
end

Now, let’s create the CustomNetwork and perform the forward pass.

# create a CustomNetwork
julia> network = CustomNetwork(Dense(2, 10, relu),
                               Dense(10, 10, relu),
                               Dense(10, 10, relu),
                               Dense(10, 1))
CustomNetwork(Dense(2 => 10, relu), Dense(10 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1))

julia> x = randn(Float32, 2)
2-element Vector{Float32}:
  1.5159738
 -2.2359543

julia> network(x)
1-element Vector{Float32}:
 -0.17960261
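Because this forward pass simply applies the four layers in order, a quick sanity check is to compare CustomNetwork with a Chain built from the very same layer objects; the two should give identical outputs. The sketch below also passes a 2×5 matrix, since Dense layers treat each column as one sample, so the same forward pass handles a batch without changes. It uses the `Dense(in => out, σ)` form that Flux prints in the REPL output above; the names chain and X are illustrative.

```julia
using Flux

# Same struct and forward pass as in the text
struct CustomNetwork
    layer1::Dense
    layer2::Dense
    layer3::Dense
    layer4::Dense
end

Flux.@functor CustomNetwork

function (m::CustomNetwork)(x)
    x = m.layer1(x)
    x = m.layer2(x)
    x = m.layer3(x)
    return m.layer4(x)
end

network = CustomNetwork(Dense(2 => 10, relu), Dense(10 => 10, relu),
                        Dense(10 => 10, relu), Dense(10 => 1))

# A Chain that shares the same layer objects (and therefore the same weights)
chain = Chain(network.layer1, network.layer2, network.layer3, network.layer4)

# A batch of 5 inputs: each column is one 2-dimensional sample
X = randn(Float32, 2, 5)
println(network(X) ≈ chain(X))   # true
println(size(network(X)))        # (1, 5)
```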

Full Code

using Flux

struct CustomNetwork
    layer1::Dense
    layer2::Dense
    layer3::Dense
    layer4::Dense
end

Flux.@functor CustomNetwork

# define the forward pass
function (m::CustomNetwork)(x)
    x = m.layer1(x)
    x = m.layer2(x)
    x = m.layer3(x)
    return m.layer4(x)
end

# create a CustomNetwork
network = CustomNetwork(Dense(2, 10, relu),
                        Dense(10, 10, relu),
                        Dense(10, 10, relu),
                        Dense(10, 1))

# run a forward pass on a random input
x = randn(Float32, 2)
network(x)
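To check that @functor really exposes the parameters for backpropagation, the full code can be extended with a short training loop. The sketch below assumes Flux 0.14's explicit-gradient style (Flux.setup, Flux.gradient, Flux.update!) and fits the network to a hypothetical toy target y = x1 - x2; the data, the loss helper name, the learning rate, and the number of iterations are arbitrary choices, not part of the original post.

```julia
using Flux, Random

Random.seed!(42)  # make the example reproducible

struct CustomNetwork
    layer1::Dense
    layer2::Dense
    layer3::Dense
    layer4::Dense
end

Flux.@functor CustomNetwork

function (m::CustomNetwork)(x)
    x = m.layer1(x)
    x = m.layer2(x)
    x = m.layer3(x)
    return m.layer4(x)
end

network = CustomNetwork(Dense(2 => 10, relu), Dense(10 => 10, relu),
                        Dense(10 => 10, relu), Dense(10 => 1))

# Hypothetical toy data: 64 samples, target y = x1 - x2 (1×64 matrix)
X = randn(Float32, 2, 64)
Y = X[1:1, :] .- X[2:2, :]

loss(m, x, y) = Flux.mse(m(x), y)

# The optimiser state mirrors the fields that @functor exposes
opt_state = Flux.setup(Adam(0.01), network)

loss_before = loss(network, X, Y)
for epoch in 1:200
    # Gradient with respect to the whole struct; grads[1] is a NamedTuple
    # with the same layout as network (layer1, ..., layer4)
    grads = Flux.gradient(m -> loss(m, X, Y), network)
    Flux.update!(opt_state, network, grads[1])
end
loss_after = loss(network, X, Y)
println((loss_before, loss_after))
```

Since Flux.update! modifies the weight arrays in place, the original network object holds the trained parameters after the loop.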

Environment

  • OS: Windows 11
  • Version: Julia 1.10.0, Flux v0.14.15