logo

기각 샘플링 📂머신러닝

기각 샘플링

개요 1

기각 샘플링은 몬테카를로 방법 중 하나로, 주어진 분포 $p$로 샘플링하기 어려울 때 샘플링하기 쉬운 제안분포 $q$를 이용하여 $p$를 따르는 샘플을 얻는 방법이다.

빌드업

확률변수 $X$의 확률밀도함수 $p$가 주어졌다고 하자. 우리는 $p$의 분포로 샘플링을 하고 싶지만, 이것이 어려운 상황이라고 하자. (가령 $X$의 누적분포함수의 역함수를 구하는 것이 어렵다면, 샘플링을 구현하기 힘들다.) 그리고 모든 $x$에 대해서, $kq(x) \ge p(x)$를 만족하는 상수 $k$와 다른 분포 $q$를 선택하자. 이때 $q$는 균등분포나 정규분포와 같이 샘플링하기 쉬운 분포여야하며, 타겟 분포와 비슷할 수록 좋다.

3518_RS.png

이제 확률변수 $X \sim q$에 대한 샘플링으로 $x$를 얻고, $Y \sim U(0,kq(x))$에 대한 샘플링으로 $y$를 얻었다고 하자. 여기서 $U$는 균등분포이다. $x, y$에 대해서 두 값 $p(x), y$를 비교하여,

  • $p(x) \gt y$이면, $x$를 채택accept하여 샘플에 추가한다.
  • $p(x) \le y$이면, $x$를 기각reject하여 샘플에 추가하지 않는다.

정의

$X$를 확률밀도함수가 $p$인 확률변수라고 하자. 다음을 만족하는 또 다른 확률밀도함수 $q$와 상수 $k$를 선택하자.

$$ p(x) \le k q(x),\quad \forall x $$

기각 샘플링rejection sampling이란 샘플 $\left\{ x_{i} \right\}$를 다음과 같이 추출하는 방법을 말한다.

$$ \left\{ x_{i} \Big| \dfrac{p(x_{i})}{kq(x_{i})} \gt y_{i},\quad X_{i} \sim q, Y_{i} \sim U(0,kq(x_{i})) \right\} $$

여기서 $p$를 타겟 분포target distribution, $q$를 제안 분포proposal distribution라고 한다.

설명

3518_RS.png

이 방법은 $p(x)$는 구체적으로 알고 있지만, 이로부터 샘플링을 하기 어려울 때 사용할 수 있다. 기각 샘플링이란, 풀어서 말하자면, $kq(x) \ge p(x)$를 만족하며 샘플링하기 쉬운 $q$를 선택하여, $q$로부터 $x_{i}$를 샘플링하고 $kq(x_{i})$와 $p(x_{i})$의 차이(비율)에 따라 $x_{i}$를 확률적으로 기각하는 것이다. 위의 그림을 보면 타겟 $p$와 제안 $q$의 차이가 클 수록 기각될 확률이 높다는 것을 알 수 있다. 그러니까 원래 $x_{i}$가 추출될 확률을 의미했던 $p(x_{i})$에 약간의 수식적인 조작을 통해 $x_{i}$가 채택될 확률 $\frac{p(x_{i})}{kq(x_{i})}$로 바꾼 것이다. 혹은 분포를 $p$에서 $kq$로 바꿨으니 그 차이 만큼 샘플링될 확률에 패널티를 준다고 이해할 수 있다.

제안 분포 $q$와 상수 $k$를 선택했을 때 샘플의 채택 비율은 다음과 같이 계산된다. $$ \text{acceptance rate} = \int \frac{p(x)}{kq(x)} q(x) dx = \dfrac{1}{k}\int p(x)dx = \dfrac{1}{k} $$ 따라서 충분한 양의 샘플을 모으는데 걸리는 시간은 $k$에 의존하고, 이를 최적화하기 위해선 $p(x) \le k q(x)$를 만족하는 가장 작은 $k$를 선택해야 한다.

기각 샘플링을 줄리아로 구현하여 산점도와 히스토그램을 그려보면 다음과 같다.

3518_RS.png

3518_RS.png

코드

using Distributions
using Plots
using LaTeXStrings

N = 20000
target(x) = 0.5*pdf.(Normal(-3, 0.8), x) + pdf.(Normal(3, 3), x)
proposal = Normal(0, 5)
samples = Float64[]ㅌ

accepted = Array{Float64}[]
rejected = Array{Float64}[]
for i ∈ 1:N
    x = rand(proposal)
    y = rand(Uniform(0,4.2pdf(proposal, x)))
    if y < target(x)
        push!(samples, x)
        push!(accepted, [x, y])
    else
        push!(rejected, [x, y])
    end
end
accepted = hcat(accepted...)
rejected = hcat(rejected...)

scatter(accepted[1,:], accepted[2,:], label="accepted", color=:blue, markerstrokewidth=0, markersize=2, dpi=300, size=(728,300), legend=:outertopright)
scatter!(rejected[1,:], rejected[2,:], label="rejected", color=:red, markerstrokewidth=0, markersize=2)
savefig("RS3.png")

x = range(-15, 15, length=1000)
plot(x, target(x), label="target "*L"p(x)", framestyle=:none)
histogram!(samples, label="accepted samples", color=:blue, alpha=0.5, bins=100, normed=true, dpi=300, size=(728,300), legend=:outertopright)
savefig("RS4.png")

환경

  • OS: Windows11
  • Version: Julia 1.8.3, Plots v1.38.6, Distributions v0.25.80, LaTeXStrings v1.3.0

  1. Christoper M. Bishop, Pattern Recognition annd Machine Learning (2006), p528-531 ↩︎