logo

棄却サンプリング 📂機械学習

棄却サンプリング

概要 1

棄却サンプリングは、与えられた分布 ppからサンプリングするのが難しいときに、サンプリングしやすい提案分布 qqを利用して ppに従うサンプルを得る、モンテカルロ法の一つだ。

ビルドアップ

確率変数 XX確率密度関数 ppが与えられたとする。我々はppの分布からサンプリングしたいが、それが難しい場面だとする。(例えば、XXの累積分布関数の逆関数を求めるのが難しいならば、サンプリングを実装するのは難しい。) そして、全ての xxに対して、kq(x)p(x)kq(x) \ge p(x)を満たす定数kkと別の分布qqを選び出そう。この時、qqは一様分布や正規分布のようにサンプリングしやすい分布であるべきであり、ターゲット分布に似ているほど良い。

3518_RS.png

今、確率変数XqX \sim qによるサンプリングでxxを、YU(0,kq(x))Y \sim U(0,kq(x))によるサンプリングでyyを得たとしよう。ここで、UU一様分布だ。x,yx, yにおいて2つの値p(x),yp(x), yを比較し、

  • p(x)>yp(x) \gt yならば、xx採用acceptしてサンプルに加える。
  • p(x)yp(x) \le yならば、xx棄却rejectしてサンプルに加えない。

定義

XXを確率密度関数がppである確率変数とする。次を満たす別の確率密度関数qqと定数kkを選ぼう。

p(x)kq(x),x p(x) \le k q(x),\quad \forall x

棄却サンプリングは、サンプル{xi}\left\{ x_{i} \right\}を次のように抽出する方法である。

{xip(xi)kq(xi)>yi,Xiq,YiU(0,kq(xi))} \left\{ x_{i} \Big| \dfrac{p(x_{i})}{kq(x_{i})} \gt y_{i},\quad X_{i} \sim q, Y_{i} \sim U(0,kq(x_{i})) \right\}

ここで、ppターゲット分布target distributionqq提案分布proposal distributionと呼ぶ。

説明

3518_RS.png

この方法は、p(x)p(x)について具体的に知っているが、それからサンプリングするのが難しいときに使用できる。要するに、棄却サンプリングとは、kq(x)p(x)kq(x) \ge p(x)を満たしながらサンプリングしやすいqqを選び、qqからxix_{i}をサンプリングし、kq(xi)kq(x_{i})p(xi)p(x_{i})の差(比率)によってxix_{i}を確率的に棄却することを意味する。上の図を見れば、ターゲットppと提案qqの差が大きければ大きいほど、棄却される確率が高いことがわかる。つまり、本来p(xi)p(x_{i})で表されていたxix_{i}が抽出される確率を、少しの数式の操作を通じてxix_{i}が採用される確率p(xi)kq(xi)\frac{p(x_{i})}{kq(x_{i})}に変えたのである。または、分布がppからkqkqに変わったので、その差分ほどサンプリングされる確率にペナルティを与えると理解もできる。

提案分布 qqと定数kkを選んだときのサンプルの採用率は次のように計算される。 acceptance rate=p(x)kq(x)q(x)dx=1kp(x)dx=1k \text{acceptance rate} = \int \frac{p(x)}{kq(x)} q(x) dx = \dfrac{1}{k}\int p(x)dx = \dfrac{1}{k} したがって、十分な量のサンプルを集めるのにかかる時間は、kkに依存し、これを最適化するためにはp(x)kq(x)p(x) \le k q(x)を満たす最も小さいkkを選ばなければならない。

ジュリアで棄却サンプリングを実装し、散布図とヒストグラムを描けば、次のようになる。

3518_RS.png

3518_RS.png

コード

using Distributions
using Plots
using LaTeXStrings

N = 20000
target(x) = 0.5*pdf.(Normal(-3, 0.8), x) + pdf.(Normal(3, 3), x)
proposal = Normal(0, 5)
samples = Float64[]ㅌ

accepted = Array{Float64}[]
rejected = Array{Float64}[]
for i ∈ 1:N
    x = rand(proposal)
    y = rand(Uniform(0,4.2pdf(proposal, x)))
    if y < target(x)
        push!(samples, x)
        push!(accepted, [x, y])
    else
        push!(rejected, [x, y])
    end
end
accepted = hcat(accepted...)
rejected = hcat(rejected...)

scatter(accepted[1,:], accepted[2,:], label="accepted", color=:blue, markerstrokewidth=0, markersize=2, dpi=300, size=(728,300), legend=:outertopright)
scatter!(rejected[1,:], rejected[2,:], label="rejected", color=:red, markerstrokewidth=0, markersize=2)
savefig("RS3.png")

x = range(-15, 15, length=1000)
plot(x, target(x), label="target "*L"p(x)", framestyle=:none)
histogram!(samples, label="accepted samples", color=:blue, alpha=0.5, bins=100, normed=true, dpi=300, size=(728,300), legend=:outertopright)
savefig("RS4.png")

環境

  • OS: Windows11
  • Version: Julia 1.8.3, Plots v1.38.6, Distributions v0.25.80, LaTeXStrings v1.3.0

  1. Christoper M. Bishop, Pattern Recognition annd Machine Learning (2006), p528-531 ↩︎