logo

줄리아에서 이원수를 이용하여 자동미분 전진모드 구현하기 📂機械学習

줄리아에서 이원수를 이용하여 자동미분 전진모드 구현하기

概要

自動微分の前進モードは二元数を利用すれば簡単に実装できる。ジュリアで前進モードを実装する方法を説明する。二元数と自動微分に関する背景知識のために以下の記事をお勧めする。

コード 1

関数$y(x) = \ln (x^{2} + \sin x)$に対する自動微分を計算する例である。

二元数の構造体定義

まず、コンピュータに二元数が何かを知らせよう。ジュリアでは二元数を以下のように構造体structとして定義できる。自動微分での応用を念頭に置き、最初の成分vを関数値valueとし、第二の成分d微分係数derivativeとしよう。

struct Dual 
    v::Float64 # (function) value
    d::Float64 # derivative
end

今度は二元数$x = (3, 1)$と$y = (2, 0)$を定義することができる。

julia> x = Dual(3, 1)
Dual(3.0, 1.0)

julia> x.v, x.d
(3.0, 1.0)

julia> y = Dual(2, 0)
Dual(2.0, 0.0)

二元数の加算定義

今度はxyを加えてみよう。これによりエラーが発生する。ジュリアに基本的に定義されている二項演算 +にはDualDualについて定義されていないためである。

julia> x + y
ERROR: MethodError: no method matching +(::Dual, ::Dual)

現在、加算+に対するメソッドは189個定義されている。

julia> methods(+)
# 189 methods for generic function "+" from Base:

以下のようにDualDualに対する+を定義すると、メソッドが一つ増加したのが分かる。

julia> Base.:+(x::Dual, y::Dual) = Dual(x.v + y.v, x.d + y.d)

julia> methods(+)
# 190 methods for generic function "+" from Base:

今度はx + yを計算することができる。

julia> x + y
Dual(5.0, 1.0)

二元数の乗算定義

加算と同じ方法で以下のように乗算を定義して計算できる。

julia> Base.:*(x::Dual, y::Dual) = Dual(x.v*y.v, x.v*y.d + x.d*y.v)

julia> x * y
Dual(6.0, 2.0)

二元数上の関数定義

関数$y(x) = \ln (x^{2} + \sin x)$を計算しなければならないので、対数関数とサイン関数を二元数に対して定義しよう。

Base.sin(x::Dual) = Dual(sin(x.v), x.d * cos(x.v) )
Base.log(x::Dual) = Dual(log(x.v), x.d / x.v)

前進モード計算

次のように計算すれば関数値と微分係数を同時に得ることができる。

julia> y₁ = x*x
Dual(9.0, 6.0)

julia> y₂ = sin(x)
Dual(0.1411200080598672, -0.9899924966004454)

julia> y₃ = y₁ + y₂
Dual(9.141120008059866, 5.010007503399555)

julia> y₄ = log(y₃)
Dual(2.2127829171337874, 0.548073704204972)

julia> log(3^2 + sin(3))
2.2127829171337874

julia> (2*3 + cos(3)) / (3^2 + sin(3))
0.548073704204972

コード全文

# 이원수 구조체 정의
struct Dual 
    v::Float64 # (function) value
    d::Float64 # derivative
end

x = Dual(3, 1)
y = Dual(2, 0)

# 이원수의 덧셈 정의
x + y

methods(+)

Base.:+(x::Dual, y::Dual) = Dual(x.v + y.v, x.d + y.d)
methods(+)

x + y

# 이원수의 곱셈 정의
Base.:*(x::Dual, y::Dual) = Dual(x.v*y.v, x.v*y.d + x.d*y.v)
x * y

# 이원수 위의 함수 정의
Base.sin(x::Dual) = Dual(sin(x.v), x.d * cos(x.v) )
Base.log(x::Dual) = Dual(log(x.v), x.d / x.v)

# 전진모드 계산
y₁ = x*x
y₂ = sin(x)
y₃ = y₁ + y₂
y₄ = log(y₃)

log(3^2 + sin(3))
(2*3 + cos(3)) / (3^2 + sin(3))

環境

  • OS: Windows11
  • Version: Julia 1.7.1

参考資料


  1. Mykel J. Kochenderfer, Algorithms for Optimization (2019), p27-32 ↩︎