Implementing Forward-Mode Automatic Differentiation with Dual Numbers in Julia
Overview
The forward mode of automatic differentiation can be easily implemented using dual numbers. This document explains how to implement the forward mode in Julia. For background knowledge on dual numbers and automatic differentiation, the following articles are recommended:
Code [1]
This example computes the value and the derivative of the function $y(x) = \ln (x^{2} + \sin x)$ at $x = 3$ using forward-mode automatic differentiation.
Defining the Dual Number Structure
First, let's tell the computer what a dual number is. In Julia, a dual number can be defined as a struct as follows. With its use in automatic differentiation in mind, we call the first component v the function value and the second component d the derivative.
struct Dual
    v::Float64 # (function) value
    d::Float64 # derivative
end
Now, we can define dual numbers $x = (3, 1)$ and $y = (2, 0)$.
julia> x = Dual(3, 1)
Dual(3.0, 1.0)
julia> x.v, x.d
(3.0, 1.0)
julia> y = Dual(2, 0)
Dual(2.0, 0.0)
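The two values above play different roles in forward mode: a derivative component of 1 marks the variable being differentiated (since $dx/dx = 1$), while a derivative component of 0 behaves like a constant. A minimal sketch that makes this explicit, assuming two helper names, variable and constant, that are hypothetical and not part of the original code:

```julia
struct Dual
    v::Float64 # (function) value
    d::Float64 # derivative
end

# Hypothetical helpers (not in the original article):
# the input variable is seeded with derivative 1, a constant with derivative 0.
variable(a::Real) = Dual(a, 1)
constant(c::Real) = Dual(c, 0)

x = variable(3) # Dual(3.0, 1.0)
y = constant(2) # Dual(2.0, 0.0)
```

The integer arguments are converted to Float64 by the struct's default constructor, just as in Dual(3, 1) above.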
Defining Addition for Dual Numbers
Let's try to add x and y. This naturally throws an error, since the binary operation + predefined in Julia has no method for two Dual arguments.
julia> x + y
ERROR: MethodError: no method matching +(::Dual, ::Dual)
Currently, 189 methods are defined for the addition operator +. The exact count depends on the Julia version and on which packages are loaded.
julia> methods(+)
# 189 methods for generic function "+" from Base:
By defining + for two Dual arguments as shown below, we can see that the number of methods has increased by one.
julia> Base.:+(x::Dual, y::Dual) = Dual(x.v + y.v, x.d + y.d)
julia> methods(+)
# 190 methods for generic function "+" from Base:
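Because the method count varies between installations, a more direct check is to ask whether a method for this particular pair of argument types exists. The following is a sketch using Base's hasmethod, run as a top-level script; the variable names before and after are only illustrative:

```julia
struct Dual
    v::Float64 # (function) value
    d::Float64 # derivative
end

# Before the definition, no + method accepts two Dual arguments.
before = hasmethod(+, Tuple{Dual, Dual})

Base.:+(x::Dual, y::Dual) = Dual(x.v + y.v, x.d + y.d)

# After the definition, the new method is found.
after = hasmethod(+, Tuple{Dual, Dual})

println((before, after))
```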
We can now calculate x + y.
julia> x + y
Dual(5.0, 1.0)
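The method above only covers two Dual arguments, so an expression such as x + 2 would still fail. One possible extension, which is an assumption of this sketch and not part of the original code, is to treat a plain real number c as the constant dual number (c, 0):

```julia
struct Dual
    v::Float64 # (function) value
    d::Float64 # derivative
end

Base.:+(x::Dual, y::Dual) = Dual(x.v + y.v, x.d + y.d)

# Assumed extension (not in the original): a real number c is added
# as the constant dual number (c, 0), so the derivative is unchanged.
Base.:+(x::Dual, c::Real) = x + Dual(c, 0)
Base.:+(c::Real, x::Dual) = Dual(c, 0) + x

x = Dual(3, 1)
z1 = x + 2   # Dual(5.0, 1.0)
z2 = 1.5 + x # Dual(4.5, 1.0)
```

Adding a constant shifts the value but leaves the derivative component untouched, which is what the rule for the derivative of a sum requires.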
Defining Multiplication for Dual Numbers
As with addition, we can define and compute multiplication as follows. The derivative component follows the product rule.
julia> Base.:*(x::Dual, y::Dual) = Dual(x.v*y.v, x.v*y.d + x.d*y.v)
julia> x * y
Dual(6.0, 2.0)
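One way to see where the second component comes from, assuming the usual representation of a dual number $(a, b)$ as $a + b\epsilon$ with $\epsilon^{2} = 0$ (a notation the code above does not use):

$$
(a + b\epsilon)(c + d\epsilon) = ac + (ad + bc)\epsilon + bd\epsilon^{2} = ac + (ad + bc)\epsilon
$$

For $x = (3, 1)$ and $y = (2, 0)$ this gives $(3 \cdot 2,\ 3 \cdot 0 + 1 \cdot 2) = (6, 2)$, which matches the output above.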
Defining Functions on Dual Numbers
Since we need to compute the function $y(x) = \ln (x^{2} + \sin x)$, let's define the sine and logarithm functions for dual numbers.
Base.sin(x::Dual) = Dual(sin(x.v), x.d * cos(x.v))
Base.log(x::Dual) = Dual(log(x.v), x.d / x.v)
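Both definitions follow the same pattern: the value is $f(v)$ and the derivative is $d \cdot f'(v)$, which is the chain rule. As a sketch of how the pattern carries over to other elementary functions, here are cos and exp, which the original example does not need:

```julia
struct Dual
    v::Float64 # (function) value
    d::Float64 # derivative
end

# Same pattern as sin and log: (f(v), d * f'(v)).
Base.cos(x::Dual) = Dual(cos(x.v), -x.d * sin(x.v))
Base.exp(x::Dual) = Dual(exp(x.v), x.d * exp(x.v))

z = exp(Dual(0, 1)) # Dual(1.0, 1.0), since exp(0) = 1 and exp'(0) = 1
```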
Forward Mode Calculation
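Now, by evaluating the function step by step as follows, we obtain the function value and the derivative at the same time. The last two lines check the result against the closed-form value and derivative at $x = 3$.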
julia> y₁ = x*x
Dual(9.0, 6.0)
julia> y₂ = sin(x)
Dual(0.1411200080598672, -0.9899924966004454)
julia> y₃ = y₁ + y₂
Dual(9.141120008059866, 5.010007503399555)
julia> y₄ = log(y₃)
Dual(2.2127829171337874, 0.548073704204972)
julia> log(3^2 + sin(3))
2.2127829171337874
julia> (2*3 + cos(3)) / (3^2 + sin(3))
0.548073704204972
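The step-by-step evaluation above can also be wrapped in ordinary functions. The following is a minimal sketch, where the names f and derivative are hypothetical helpers introduced only for illustration:

```julia
struct Dual
    v::Float64 # (function) value
    d::Float64 # derivative
end

Base.:+(x::Dual, y::Dual) = Dual(x.v + y.v, x.d + y.d)
Base.:*(x::Dual, y::Dual) = Dual(x.v*y.v, x.v*y.d + x.d*y.v)
Base.sin(x::Dual) = Dual(sin(x.v), x.d * cos(x.v))
Base.log(x::Dual) = Dual(log(x.v), x.d / x.v)

# The target function, written with x*x because only +, *, sin, and log
# are defined for Dual in this sketch.
f(x) = log(x*x + sin(x))

# Hypothetical helper: seed the input with derivative 1
# and read off the derivative component of the result.
derivative(f, a::Real) = f(Dual(a, 1)).d

result = f(Dual(3, 1))  # same value and derivative as y₄ above
slope = derivative(f, 3) # agrees with the closed-form check above
```

Because f is written generically, the same definition works for both ordinary numbers and Dual inputs; only the argument type decides whether the derivative is carried along.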
Full Code
# Define the dual number struct
struct Dual
    v::Float64 # (function) value
    d::Float64 # derivative
end
x = Dual(3, 1)
y = Dual(2, 0)
# Define addition for dual numbers
# x + y # throws a MethodError at this point, since + is not yet defined for two Dual values
methods(+)
Base.:+(x::Dual, y::Dual) = Dual(x.v + y.v, x.d + y.d)
methods(+)
x + y
# Define multiplication for dual numbers
Base.:*(x::Dual, y::Dual) = Dual(x.v*y.v, x.v*y.d + x.d*y.v)
x * y
# Define functions on dual numbers
Base.sin(x::Dual) = Dual(sin(x.v), x.d * cos(x.v))
Base.log(x::Dual) = Dual(log(x.v), x.d / x.v)
# Forward-mode calculation
y₁ = x*x
y₂ = sin(x)
y₃ = y₁ + y₂
y₄ = log(y₃)
log(3^2 + sin(3))
(2*3 + cos(3)) / (3^2 + sin(3))
Environment
- OS: Windows 11
- Version: Julia 1.7.1
See Also
[1] Mykel J. Kochenderfer, Algorithms for Optimization (2019), pp. 27-32