
Implementing Forward-Mode Automatic Differentiation with Dual Numbers in Julia


Overview

Forward-mode automatic differentiation is easy to implement with dual numbers. This article shows how to do it in Julia. For background on dual numbers and automatic differentiation, the following articles are recommended:

Code [1]

As an example, we use automatic differentiation to compute the value and the derivative of the function $y(x) = \ln (x^{2} + \sin x)$ at $x = 3$.
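For reference, the chain rule gives the derivative that the forward pass should reproduce. Splitting $y$ into the same intermediate quantities $y_1, \dots, y_4$ that appear in the forward-mode computation later in this article:

$$
\begin{aligned}
y_1 &= x^{2}, & y_1' &= 2x,\\
y_2 &= \sin x, & y_2' &= \cos x,\\
y_3 &= y_1 + y_2, & y_3' &= y_1' + y_2' = 2x + \cos x,\\
y_4 &= \ln y_3, & y_4' &= \frac{y_3'}{y_3} = \frac{2x + \cos x}{x^{2} + \sin x}.
\end{aligned}
$$

At $x = 3$ this is $y'(3) = \dfrac{6 + \cos 3}{9 + \sin 3} \approx 0.548$.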

Defining the Dual Number Structure

First, we need to tell Julia what a dual number is. A dual number can be defined as a struct, as follows. With its use in automatic differentiation in mind, we call the first component v the function value and the second component d the derivative.

struct Dual 
    v::Float64 # (function) value
    d::Float64 # derivative
end
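Because both fields are declared `Float64`, the constructor that Julia generates automatically converts other real arguments with `convert`. That is why `Dual(3, 1)` displays as `Dual(3.0, 1.0)`. A quick check, repeating the struct so the snippet runs on its own:

```julia
# Same struct as above: value v and derivative d, both Float64
struct Dual
    v::Float64 # (function) value
    d::Float64 # derivative
end

a = Dual(3, 1)            # Int arguments are converted to Float64
println(a)                # Dual(3.0, 1.0)
println(typeof(a.v))      # Float64

b = Dual(0.5f0, 2//1)     # Float32 and Rational arguments are converted as well
println(b)                # Dual(0.5, 2.0)
```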

Now, we can define dual numbers $x = (3, 1)$ and $y = (2, 0)$.

julia> x = Dual(3, 1)
Dual(3.0, 1.0)

julia> x.v, x.d
(3.0, 1.0)

julia> y = Dual(2, 0)
Dual(2.0, 0.0)
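The second component acts as a seed. $x = (3, 1)$ represents the independent variable at $3$, whose derivative with respect to itself is $1$, while $y = (2, 0)$ represents a constant, whose derivative is $0$. A minimal sketch that makes this explicit; the helper names `variable` and `constant` are hypothetical and not part of the original code:

```julia
struct Dual
    v::Float64 # (function) value
    d::Float64 # derivative
end

# Hypothetical helpers (not in the original article):
# an independent variable has derivative dx/dx = 1,
# a constant has derivative 0.
variable(a::Real) = Dual(a, 1)
constant(c::Real) = Dual(c, 0)

x = variable(3)   # same as Dual(3, 1)
y = constant(2)   # same as Dual(2, 0)
println(x)        # Dual(3.0, 1.0)
println(y)        # Dual(2.0, 0.0)
```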

Defining Addition for Dual Numbers

Let’s try adding x and y. This naturally raises an error, because Julia’s built-in + operator has no method for two Dual arguments.

julia> x + y
ERROR: MethodError: no method matching +(::Dual, ::Dual)

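At this point, 189 methods are defined for +. This count comes from the environment listed at the end (Julia 1.7.1) and differs between Julia versions and loaded packages.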

julia> methods(+)
# 189 methods for generic function "+" from Base:

After + is defined for two Dual arguments as shown below, the number of methods increases by one.

julia> Base.:+(x::Dual, y::Dual) = Dual(x.v + y.v, x.d + y.d)

julia> methods(+)
# 190 methods for generic function "+" from Base:
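Instead of reading the count off the printed header, the increase can be checked programmatically with `length(methods(+))`. Since the absolute numbers depend on the Julia version and loaded packages, this sketch only compares the counts before and after the definition:

```julia
struct Dual
    v::Float64 # (function) value
    d::Float64 # derivative
end

n_before = length(methods(+))   # number of + methods before our definition

# Sum rule: values add, derivatives add
Base.:+(x::Dual, y::Dual) = Dual(x.v + y.v, x.d + y.d)

n_after = length(methods(+))    # one new method, for (Dual, Dual)
println(n_after - n_before)     # 1
```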

We can now calculate x + y.

julia> x + y
Dual(5.0, 1.0)

Defining Multiplication for Dual Numbers

Similarly to addition, we can define and compute multiplication as follows.

julia> Base.:*(x::Dual, y::Dual) = Dual(x.v*y.v, x.v*y.d + x.d*y.v)

julia> x * y
Dual(6.0, 2.0)
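The multiplication rule is not arbitrary. Writing the pair $(v, d)$ as $v + d\epsilon$, where $\epsilon \neq 0$ but $\epsilon^{2} = 0$, the product of two dual numbers expands so that the $\epsilon$ coefficient is exactly the product rule of differentiation, which is what the `*` method above encodes:

$$
(a + a'\epsilon)(b + b'\epsilon) = ab + (ab' + a'b)\epsilon + a'b'\epsilon^{2} = ab + (ab' + a'b)\epsilon.
$$

For $x = (3, 1)$ and $y = (2, 0)$ this gives $(3 \cdot 2,\ 3 \cdot 0 + 1 \cdot 2) = (6, 2)$, matching `Dual(6.0, 2.0)`.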

Defining Functions on Dual Numbers

Since the target function $y(x) = \ln (x^{2} + \sin x)$ involves the logarithm and the sine, we also define these two functions for dual numbers.

Base.sin(x::Dual) = Dual(sin(x.v), x.d * cos(x.v) )
Base.log(x::Dual) = Dual(log(x.v), x.d / x.v)
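Both definitions follow a single pattern, $f(v + d\epsilon) = f(v) + f'(v)\,d\,\epsilon$: apply $f$ to the value and multiply the incoming derivative by $f'(v)$, which is the chain rule. As an illustration beyond what this example needs, the same pattern gives `cos` and `exp`; these two methods are additions for this sketch, not part of the original code:

```julia
struct Dual
    v::Float64 # (function) value
    d::Float64 # derivative
end

# Pattern: f(Dual(v, d)) = Dual(f(v), f'(v) * d)
Base.sin(x::Dual) = Dual(sin(x.v), x.d * cos(x.v))
Base.log(x::Dual) = Dual(log(x.v), x.d / x.v)
Base.cos(x::Dual) = Dual(cos(x.v), -x.d * sin(x.v))   # (cos)' = -sin
Base.exp(x::Dual) = Dual(exp(x.v), x.d * exp(x.v))    # (exp)' = exp

x = Dual(3, 1)
println(cos(x))   # value cos(3), derivative -sin(3)
println(exp(x))   # value and derivative both exp(3)
```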

Forward Mode Calculation

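Now, evaluating the function step by step, starting from x = Dual(3, 1), yields the function value and the derivative at x = 3 simultaneously. The last two lines compute $\ln(3^{2} + \sin 3)$ and the analytic derivative $(2 \cdot 3 + \cos 3)/(3^{2} + \sin 3)$ directly for comparison; both agree with the components of y₄.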

julia> y₁ = x*x
Dual(9.0, 6.0)

julia> y₂ = sin(x)
Dual(0.1411200080598672, -0.9899924966004454)

julia> y₃ = y₁ + y₂
Dual(9.141120008059866, 5.010007503399555)

julia> y₄ = log(y₃)
Dual(2.2127829171337874, 0.548073704204972)

julia> log(3^2 + sin(3))
2.2127829171337874

julia> (2*3 + cos(3)) / (3^2 + sin(3))
0.548073704204972
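The steps above can be collected into one reusable function. A minimal sketch: it repeats the `Dual` type and the four methods defined earlier so that it runs on its own, and `f`, `df`, `derivative`, and `fd` are hypothetical names introduced here. `derivative` seeds the input with derivative 1, evaluates the function once, and returns the `d` component; the central finite difference `fd` is only an approximate, independent sanity check.

```julia
struct Dual
    v::Float64 # (function) value
    d::Float64 # derivative
end

Base.:+(x::Dual, y::Dual) = Dual(x.v + y.v, x.d + y.d)
Base.:*(x::Dual, y::Dual) = Dual(x.v*y.v, x.v*y.d + x.d*y.v)
Base.sin(x::Dual) = Dual(sin(x.v), x.d * cos(x.v))
Base.log(x::Dual) = Dual(log(x.v), x.d / x.v)

# y(x) = ln(x^2 + sin x), written only with operations defined for Dual
f(x) = log(x*x + sin(x))

# Hypothetical helper: seed dx/dx = 1 and return the derivative component
derivative(f, a::Real) = f(Dual(a, 1)).d

# Analytic derivative, for comparison
df(x) = (2x + cos(x)) / (x^2 + sin(x))

# Central finite difference, an approximation used only as a sanity check
fd(f, a; h=1e-6) = (f(a + h) - f(a - h)) / (2h)

for a in (0.5, 1.0, 3.0)
    println(a, "  AD: ", derivative(f, a), "  exact: ", df(a), "  FD: ", fd(f, a))
end
```

Because `f` uses only `+`, `*`, `sin`, and `log`, the same definition works for both `Float64` and `Dual` arguments; this reuse through multiple dispatch is what lets the finite-difference check call `f` on plain numbers.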

Full Code

# Define the dual number struct
struct Dual 
    v::Float64 # (function) value
    d::Float64 # derivative
end

x = Dual(3, 1)
y = Dual(2, 0)

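# Define addition of dual numbers
# x + y   # at this point this throws a MethodError (+ is not yet defined for Dual), which would stop a script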

methods(+)

Base.:+(x::Dual, y::Dual) = Dual(x.v + y.v, x.d + y.d)
methods(+)

x + y

# Define multiplication of dual numbers
Base.:*(x::Dual, y::Dual) = Dual(x.v*y.v, x.v*y.d + x.d*y.v)
x * y

# Define functions on dual numbers
Base.sin(x::Dual) = Dual(sin(x.v), x.d * cos(x.v) )
Base.log(x::Dual) = Dual(log(x.v), x.d / x.v)

# Forward-mode computation
y₁ = x*x
y₂ = sin(x)
y₃ = y₁ + y₂
y₄ = log(y₃)

log(3^2 + sin(3))
(2*3 + cos(3)) / (3^2 + sin(3))

Environment

  • OS: Windows 11
  • Version: Julia 1.7.1

See Also


  1. Mykel J. Kochenderfer and Tim A. Wheeler, Algorithms for Optimization (MIT Press, 2019), pp. 27-32.