줄리아에서 이원수를 이용하여 자동미분 전진모드 구현하기
개요
자동미분의 전진모드는 이원수를 이용하면 쉽게 구현할 수 있다. 줄리아에서 전진모드를 구현하는 방법을 설명한다. 이원수와 자동미분에 대한 배경지식을 위해 아래의 글을 추천한다.
코드 1
함수 $y(x) = \ln (x^{2} + \sin x)$에 대한 자동 미분을 계산하는 예제이다.
이원수 구조체 정의
우선 컴퓨터에게 이원수가 뭔지 알려주자. 줄리아에서는 이원수를 다음과 같이 구조체struct로 정의하면 된다. 자동미분에서의 응용을 염두해 첫번째 성분 v
를 함숫값value이라 하고, 두번째 성분 d
를 미분계수derivative라 하자.
struct Dual
v::Float64 # (function) value
d::Float64 # derivative
end
이제 이원수 $x = (3, 1)$과 $y = (2, 0)$을 정의할 수 있다.
julia> x = Dual(3, 1)
Dual(3.0, 1.0)
julia> x.v, x.d
(3.0, 1.0)
julia> y = Dual(2, 0)
Dual(2.0, 0.0)
이원수의 덧셈 정의
이제 x
와 y
를 더해보자. 그러면 당연히 에러가 난다. 줄리아에 기본적으로 정의된 이항연산 +
에는 Dual
과 Dual
에 대해서 정의되어있지 않기 때문이다.
julia> x + y
ERROR: MethodError: no method matching +(::Dual, ::Dual)
현재 더하기 +
에 대한 메소드는 189개가 정의되어있다.
julia> methods(+)
# 189 methods for generic function "+" from Base:
아래와 같이 Dual
과 Dual
에 대한 +
를 정의해주면, 메소드가 하나 더 늘어난 것을 볼 수 있다.
julia> Base.:+(x::Dual, y::Dual) = Dual(x.v + y.v, x.d + y.d)
julia> methods(+)
# 190 methods for generic function "+" from Base:
이제 x + y
를 계산할 수 있다.
julia> x + y
Dual(5.0, 1.0)
이원수의 곱셈 정의
덧셈과 같은 방법으로 아래와 같이 곱셈을 정의하고 계산할 수 있다.
julia> Base.:*(x::Dual, y::Dual) = Dual(x.v*y.v, x.v*y.d + x.d*y.v)
julia> x * y
Dual(6.0, 2.0)
이원수 위의 함수 정의
함수 $y(x) = \ln (x^{2} + \sin x)$를 계산해야하므로, 로그함수와 사인함수를 이원수에 대해서 정의해주자.
Base.sin(x::Dual) = Dual(sin(x.v), x.d * cos(x.v) )
Base.log(x::Dual) = Dual(log(x.v), x.d / x.v)
전진모드 계산
이제 다음과 같이 계산하면 함숫값과 미분계수를 동시에 얻을 수 있다.
julia> y₁ = x*x
Dual(9.0, 6.0)
julia> y₂ = sin(x)
Dual(0.1411200080598672, -0.9899924966004454)
julia> y₃ = y₁ + y₂
Dual(9.141120008059866, 5.010007503399555)
julia> y₄ = log(y₃)
Dual(2.2127829171337874, 0.548073704204972)
julia> log(3^2 + sin(3))
2.2127829171337874
julia> (2*3 + cos(3)) / (3^2 + sin(3))
0.548073704204972
코드 전문
# 이원수 구조체 정의
struct Dual
v::Float64 # (function) value
d::Float64 # derivative
end
x = Dual(3, 1)
y = Dual(2, 0)
# 이원수의 덧셈 정의
x + y
methods(+)
Base.:+(x::Dual, y::Dual) = Dual(x.v + y.v, x.d + y.d)
methods(+)
x + y
# 이원수의 곱셈 정의
Base.:*(x::Dual, y::Dual) = Dual(x.v*y.v, x.v*y.d + x.d*y.v)
x * y
# 이원수 위의 함수 정의
Base.sin(x::Dual) = Dual(sin(x.v), x.d * cos(x.v) )
Base.log(x::Dual) = Dual(log(x.v), x.d / x.v)
# 전진모드 계산
y₁ = x*x
y₂ = sin(x)
y₃ = y₁ + y₂
y₄ = log(y₃)
log(3^2 + sin(3))
(2*3 + cos(3)) / (3^2 + sin(3))
환경
- OS: Windows11
- Version: Julia 1.7.1
같이보기
Mykel J. Kochenderfer, Algorithms for Optimization (2019), p27-32 ↩︎