logo

기각 샘플링 📂머신러닝

기각 샘플링

개요 1

기각 샘플링은 몬테카를로 방법 중 하나로, 주어진 분포 pp로 샘플링하기 어려울 때 샘플링하기 쉬운 제안분포 qq를 이용하여 pp를 따르는 샘플을 얻는 방법이다.

빌드업

확률변수 XX확률밀도함수 pp가 주어졌다고 하자. 우리는 pp의 분포로 샘플링을 하고 싶지만, 이것이 어려운 상황이라고 하자. (가령 XX의 누적분포함수의 역함수를 구하는 것이 어렵다면, 샘플링을 구현하기 힘들다.) 그리고 모든 xx에 대해서, kq(x)p(x)kq(x) \ge p(x)를 만족하는 상수 kk와 다른 분포 qq를 선택하자. 이때 qq는 균등분포나 정규분포와 같이 샘플링하기 쉬운 분포여야하며, 타겟 분포와 비슷할 수록 좋다.

3518_RS.png

이제 확률변수 XqX \sim q에 대한 샘플링으로 xx를 얻고, YU(0,kq(x))Y \sim U(0,kq(x))에 대한 샘플링으로 yy를 얻었다고 하자. 여기서 UU균등분포이다. x,yx, y에 대해서 두 값 p(x),yp(x), y를 비교하여,

  • p(x)>yp(x) \gt y이면, xx채택accept하여 샘플에 추가한다.
  • p(x)yp(x) \le y이면, xx기각reject하여 샘플에 추가하지 않는다.

정의

XX를 확률밀도함수가 pp인 확률변수라고 하자. 다음을 만족하는 또 다른 확률밀도함수 qq와 상수 kk를 선택하자.

p(x)kq(x),x p(x) \le k q(x),\quad \forall x

기각 샘플링rejection sampling이란 샘플 {xi}\left\{ x_{i} \right\}를 다음과 같이 추출하는 방법을 말한다.

{xip(xi)kq(xi)>yi,Xiq,YiU(0,kq(xi))} \left\{ x_{i} \Big| \dfrac{p(x_{i})}{kq(x_{i})} \gt y_{i},\quad X_{i} \sim q, Y_{i} \sim U(0,kq(x_{i})) \right\}

여기서 pp타겟 분포target distribution, qq제안 분포proposal distribution라고 한다.

설명

3518_RS.png

이 방법은 p(x)p(x)는 구체적으로 알고 있지만, 이로부터 샘플링을 하기 어려울 때 사용할 수 있다. 기각 샘플링이란, 풀어서 말하자면, kq(x)p(x)kq(x) \ge p(x)를 만족하며 샘플링하기 쉬운 qq를 선택하여, qq로부터 xix_{i}를 샘플링하고 kq(xi)kq(x_{i})p(xi)p(x_{i})의 차이(비율)에 따라 xix_{i}를 확률적으로 기각하는 것이다. 위의 그림을 보면 타겟 pp와 제안 qq의 차이가 클 수록 기각될 확률이 높다는 것을 알 수 있다. 그러니까 원래 xix_{i}가 추출될 확률을 의미했던 p(xi)p(x_{i})에 약간의 수식적인 조작을 통해 xix_{i}가 채택될 확률 p(xi)kq(xi)\frac{p(x_{i})}{kq(x_{i})}로 바꾼 것이다. 혹은 분포를 pp에서 kqkq로 바꿨으니 그 차이 만큼 샘플링될 확률에 패널티를 준다고 이해할 수 있다.

제안 분포 qq와 상수 kk를 선택했을 때 샘플의 채택 비율은 다음과 같이 계산된다. acceptance rate=p(x)kq(x)q(x)dx=1kp(x)dx=1k \text{acceptance rate} = \int \frac{p(x)}{kq(x)} q(x) dx = \dfrac{1}{k}\int p(x)dx = \dfrac{1}{k} 따라서 충분한 양의 샘플을 모으는데 걸리는 시간은 kk에 의존하고, 이를 최적화하기 위해선 p(x)kq(x)p(x) \le k q(x)를 만족하는 가장 작은 kk를 선택해야 한다.

기각 샘플링을 줄리아로 구현하여 산점도와 히스토그램을 그려보면 다음과 같다.

3518_RS.png

3518_RS.png

코드

using Distributions
using Plots
using LaTeXStrings

N = 20000
target(x) = 0.5*pdf.(Normal(-3, 0.8), x) + pdf.(Normal(3, 3), x)
proposal = Normal(0, 5)
samples = Float64[]ㅌ

accepted = Array{Float64}[]
rejected = Array{Float64}[]
for i ∈ 1:N
    x = rand(proposal)
    y = rand(Uniform(0,4.2pdf(proposal, x)))
    if y < target(x)
        push!(samples, x)
        push!(accepted, [x, y])
    else
        push!(rejected, [x, y])
    end
end
accepted = hcat(accepted...)
rejected = hcat(rejected...)

scatter(accepted[1,:], accepted[2,:], label="accepted", color=:blue, markerstrokewidth=0, markersize=2, dpi=300, size=(728,300), legend=:outertopright)
scatter!(rejected[1,:], rejected[2,:], label="rejected", color=:red, markerstrokewidth=0, markersize=2)
savefig("RS3.png")

x = range(-15, 15, length=1000)
plot(x, target(x), label="target "*L"p(x)", framestyle=:none)
histogram!(samples, label="accepted samples", color=:blue, alpha=0.5, bins=100, normed=true, dpi=300, size=(728,300), legend=:outertopright)
savefig("RS4.png")

환경

  • OS: Windows11
  • Version: Julia 1.8.3, Plots v1.38.6, Distributions v0.25.80, LaTeXStrings v1.3.0

  1. Christoper M. Bishop, Pattern Recognition annd Machine Learning (2006), p528-531 ↩︎