ジュリアの自動微分パッケージZygote.jl
概要
ジュリアでは、マシンラーニング、特にディープラーニングに関連した自動微分automatic DifferentiationのためにZygote.jl
というパッケージを使っている1。開発者たちは、このパッケージは次世代の自動微分システムとして、ジュリアで微分可能プログラミングdifferentiable Programmingができるようにすると宣伝していて、実際に使ってみると驚くほど直感的だと分かる。
自動微分ではなく、導関数に関連したパッケージ自体が気になるなら、Calculus.jl
パッケージを参照してほしい。
コード
単変数関数
信じられないくらい簡単だ。普段私たちが微分するのと同じように、関数名の後ろにプライム'
をつけると、まるで本当に導関数を使って計算しているように、微分係数が計算される。
julia> using Zygote
julia> p(x) = 2x^2 + 3x + 1
p (generic function with 1 method)
julia> p(2)
15
julia> p'(2)
11.0
julia> p''(2)
4.0
多変数関数
gradient()
関数を使う。
julia> g(x,y) = 3x^2 + 2y + x*y
g (generic function with 1 method)
julia> gradient(g, 2,-1)
(11.0, 4.0)
もう少し直感的にコードを書きたいなら、次のように\nabla
、すなわち∇
で再び関数を定義して試してみるのも良い。
julia> ∇(f, v...) = gradient(f, v...)
∇ (generic function with 1 method)
julia> ∇(g, 2, -1)
(11.0, 4.0)
全体のコード
using Zygote
p(x) = 2x^2 + 3x + 1
p(2)
p'(2)
p''(2)
g(x,y) = 3x^2 + 2y + x*y
gradient(g, 2,-1)
∇(f, v...) = gradient(f, v...)
∇(g, 2, -1)
環境
- OS: Windows
- julia: v1.9.0
- Zygote: v0.6.62