logo

棄却サンプリング 📂機械学習

棄却サンプリング

概要 1

棄却サンプリングは、与えられた分布 $p$からサンプリングするのが難しいときに、サンプリングしやすい提案分布 $q$を利用して $p$に従うサンプルを得る、モンテカルロ法の一つだ。

ビルドアップ

確率変数 $X$の確率密度関数 $p$が与えられたとする。我々は$p$の分布からサンプリングしたいが、それが難しい場面だとする。(例えば、$X$の累積分布関数の逆関数を求めるのが難しいならば、サンプリングを実装するのは難しい。) そして、全ての $x$に対して、$kq(x) \ge p(x)$を満たす定数$k$と別の分布$q$を選び出そう。この時、$q$は一様分布や正規分布のようにサンプリングしやすい分布であるべきであり、ターゲット分布に似ているほど良い。

3518_RS.png

今、確率変数$X \sim q$によるサンプリングで$x$を、$Y \sim U(0,kq(x))$によるサンプリングで$y$を得たとしよう。ここで、$U$は一様分布だ。$x, y$において2つの値$p(x), y$を比較し、

  • $p(x) \gt y$ならば、$x$を採用acceptしてサンプルに加える。
  • $p(x) \le y$ならば、$x$を棄却rejectしてサンプルに加えない。

定義

$X$を確率密度関数が$p$である確率変数とする。次を満たす別の確率密度関数$q$と定数$k$を選ぼう。

$$ p(x) \le k q(x),\quad \forall x $$

棄却サンプリングは、サンプル$\left\{ x_{i} \right\}$を次のように抽出する方法である。

$$ \left\{ x_{i} \Big| \dfrac{p(x_{i})}{kq(x_{i})} \gt y_{i},\quad X_{i} \sim q, Y_{i} \sim U(0,kq(x_{i})) \right\} $$

ここで、$p$をターゲット分布target distribution、$q$を提案分布proposal distributionと呼ぶ。

説明

3518_RS.png

この方法は、$p(x)$について具体的に知っているが、それからサンプリングするのが難しいときに使用できる。要するに、棄却サンプリングとは、$kq(x) \ge p(x)$を満たしながらサンプリングしやすい$q$を選び、$q$から$x_{i}$をサンプリングし、$kq(x_{i})$と$p(x_{i})$の差(比率)によって$x_{i}$を確率的に棄却することを意味する。上の図を見れば、ターゲット$p$と提案$q$の差が大きければ大きいほど、棄却される確率が高いことがわかる。つまり、本来$p(x_{i})$で表されていた$x_{i}$が抽出される確率を、少しの数式の操作を通じて$x_{i}$が採用される確率$\frac{p(x_{i})}{kq(x_{i})}$に変えたのである。または、分布が$p$から$kq$に変わったので、その差分ほどサンプリングされる確率にペナルティを与えると理解もできる。

提案分布 $q$と定数$k$を選んだときのサンプルの採用率は次のように計算される。 $$ \text{acceptance rate} = \int \frac{p(x)}{kq(x)} q(x) dx = \dfrac{1}{k}\int p(x)dx = \dfrac{1}{k} $$ したがって、十分な量のサンプルを集めるのにかかる時間は、$k$に依存し、これを最適化するためには$p(x) \le k q(x)$を満たす最も小さい$k$を選ばなければならない。

ジュリアで棄却サンプリングを実装し、散布図とヒストグラムを描けば、次のようになる。

3518_RS.png

3518_RS.png

コード

using Distributions
using Plots
using LaTeXStrings

N = 20000
target(x) = 0.5*pdf.(Normal(-3, 0.8), x) + pdf.(Normal(3, 3), x)
proposal = Normal(0, 5)
samples = Float64[]ㅌ

accepted = Array{Float64}[]
rejected = Array{Float64}[]
for i ∈ 1:N
    x = rand(proposal)
    y = rand(Uniform(0,4.2pdf(proposal, x)))
    if y < target(x)
        push!(samples, x)
        push!(accepted, [x, y])
    else
        push!(rejected, [x, y])
    end
end
accepted = hcat(accepted...)
rejected = hcat(rejected...)

scatter(accepted[1,:], accepted[2,:], label="accepted", color=:blue, markerstrokewidth=0, markersize=2, dpi=300, size=(728,300), legend=:outertopright)
scatter!(rejected[1,:], rejected[2,:], label="rejected", color=:red, markerstrokewidth=0, markersize=2)
savefig("RS3.png")

x = range(-15, 15, length=1000)
plot(x, target(x), label="target "*L"p(x)", framestyle=:none)
histogram!(samples, label="accepted samples", color=:blue, alpha=0.5, bins=100, normed=true, dpi=300, size=(728,300), legend=:outertopright)
savefig("RS4.png")

環境

  • OS: Windows11
  • Version: Julia 1.8.3, Plots v1.38.6, Distributions v0.25.80, LaTeXStrings v1.3.0

  1. Christoper M. Bishop, Pattern Recognition annd Machine Learning (2006), p528-531 ↩︎