

Rejection Sampling

Overview ¹

Rejection sampling is one of the Monte Carlo methods, where a proposal distribution $q$, which is easy to sample from, is used to obtain samples following a given distribution $p$, especially when it is difficult to sample from $p$ directly.

Build-up

Let’s assume we are given a random variable $X$ with a probability density function $p$. We want to sample from the distribution $p$, but doing so is challenging. (For instance, if it is difficult to find the inverse of the cumulative distribution function of $X$, sampling by inverse transform is impractical.) Then choose another distribution $q$ and a constant $k$ that satisfy $kq(x) \ge p(x)$ for all $x$. Here, $q$ should be an easily samplable distribution, such as a uniform or normal distribution, and should resemble the target distribution as closely as possible.
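As a concrete sketch of this setup (an illustrative example, not the one used later in this post): take the target $p$ to be the Beta(2, 5) density and the proposal $q$ to be the uniform distribution on $[0, 1]$, so a valid $k$ can be estimated numerically as the maximum of $p(x)/q(x)$.

```julia
using Distributions

# Illustrative choice of target and proposal (assumed for this sketch):
# p = Beta(2, 5) density, q = Uniform(0, 1) density (so q(x) = 1 on [0, 1]).
p(x) = pdf(Beta(2, 5), x)
q(x) = pdf(Uniform(0, 1), x)

# Estimate the smallest valid k as the maximum of p(x)/q(x) over a fine grid.
grid = range(0.001, 0.999, length=10_000)
k = maximum(p.(grid) ./ q.(grid))

# Check the envelope condition kq(x) ≥ p(x) on the grid.
@assert all(p.(grid) .≤ k .* q.(grid))
```

With these choices the numerically found $k$ is about $2.46$, the height of the Beta(2, 5) density at its mode.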

[Figure: the target density $p$ under the envelope $kq$]

Now, suppose we’ve obtained $x$ by sampling the random variable $X \sim q$, and $y$ by sampling $Y \sim U(0, kq(x))$, where $U$ denotes a uniform distribution. Upon comparing the two values $p(x)$ and $y$:

  • If $p(x) > y$, then $x$ is accepted into the sample.
  • If $p(x) \le y$, then $x$ is rejected from the sample.
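The two-step procedure above can be sketched as a small Julia function. The target, proposal, and $k$ below are illustrative assumptions (a Beta(2, 5) target with a uniform proposal), not part of the original text:

```julia
using Distributions

# Illustrative target/proposal (assumed): p = Beta(2, 5), q = Uniform(0, 1),
# with k = 2.5 so that kq(x) = 2.5 covers the target's maximum (≈ 2.458).
p(x) = pdf(Beta(2, 5), x)
q = Uniform(0, 1)
k = 2.5

# Repeat "sample x ~ q, sample y ~ U(0, kq(x)), accept if p(x) > y"
# until one proposal is accepted.
function draw_one(p, q, k)
    while true
        x = rand(q)                          # x ~ q
        y = rand(Uniform(0, k * pdf(q, x)))  # y ~ U(0, kq(x))
        y < p(x) && return x                 # accepted
    end
end

xs = [draw_one(p, q, k) for _ in 1:50_000]
```

The accepted draws follow the target: for Beta(2, 5) their sample mean should be near $2/7 \approx 0.286$.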

Definition

Let $X$ be a random variable with the probability density function $p$. Choose another probability density function $q$ and a constant $k$ that satisfy the following.

$$ p(x) \le k q(x),\quad \forall x $$

Rejection sampling is the method of extracting the sample $\left\{ x_{i} \right\}$ as follows.

$$ \left\{ x_{i} \,\Big|\, p(x_{i}) > y_{i},\quad X_{i} \sim q,\ Y_{i} \sim U(0, kq(x_{i})) \right\} $$

Here, pp is called the target distribution, and qq is referred to as the proposal distribution.

Explanation

[Figure: the target density $p$ under the envelope $kq$; draws falling in the gap between them are rejected]

This method can be used when $p(x)$ is known explicitly but sampling from it is difficult. In essence, rejection sampling means selecting a $q$ that is easy to sample from and satisfies $kq(x) \ge p(x)$, sampling $x_{i}$ from $q$, and probabilistically rejecting $x_{i}$ based on the ratio between $p(x_{i})$ and $kq(x_{i})$. The figure above shows that the bigger the gap between the target $p$ and the envelope $kq$, the higher the probability of rejection. In other words, the quantity originally meant to represent the density $p(x_{i})$ is reinterpreted as the probability $\frac{p(x_{i})}{kq(x_{i})}$ that $x_{i}$ is accepted. Equivalently, since sampling is carried out from $kq$ instead of $p$, each draw is penalized in proportion to the discrepancy between the two distributions.

When the proposal distribution $q$ and the constant $k$ are chosen, the acceptance rate of the sample can be calculated as follows.

$$ \text{acceptance rate} = \int \frac{p(x)}{kq(x)} q(x) \, dx = \dfrac{1}{k}\int p(x) \, dx = \dfrac{1}{k} $$

Therefore, the time it takes to collect a sufficient number of samples depends on $k$, and for efficiency, the smallest possible $k$ satisfying $p(x) \le k q(x)$ should be chosen.
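The $1/k$ rate can be checked empirically. Again with an illustrative, normalized target (the Beta(2, 5) density, an assumption of this sketch) and a uniform proposal, the fraction of accepted proposals should settle near $1/k$:

```julia
using Distributions

# Illustrative setup (assumed): normalized target p = Beta(2, 5),
# proposal q = Uniform(0, 1) (so q(x) = 1 and kq(x) = k), with k = 2.5.
p(x) = pdf(Beta(2, 5), x)
k = 2.5
N = 200_000

# Count accepted proposals: x ~ q, y ~ U(0, k), accept when y < p(x).
accepts = count(_ -> rand(Uniform(0, k)) < p(rand()), 1:N)
rate = accepts / N   # expected to be close to 1/k = 0.4
```

Note that the target here integrates to one; for an unnormalized target (like the mixture in the code below, which integrates to 1.5), the acceptance rate becomes $\frac{1}{k}\int p(x)\,dx$ instead of exactly $1/k$.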

Implementing rejection sampling in Julia and creating scatter plots and histograms yields the following.

[Figure: scatter plot of accepted (blue) and rejected (red) samples]

[Figure: histogram of accepted samples against the target density $p(x)$]

Code

using Distributions
using Plots
using LaTeXStrings

N = 20000    # number of proposal draws
k = 4.2      # envelope constant: kq(x) ≥ p(x) must hold for all x

# Unnormalized target: a mixture of two normal densities
target(x) = 0.5*pdf.(Normal(-3, 0.8), x) + pdf.(Normal(3, 3), x)
proposal = Normal(0, 5)
samples = Float64[]

accepted = Vector{Float64}[]
rejected = Vector{Float64}[]
for i ∈ 1:N
    x = rand(proposal)                        # x ~ q
    y = rand(Uniform(0, k*pdf(proposal, x)))  # y ~ U(0, kq(x))
    if y < target(x)                          # accept when y falls under p(x)
        push!(samples, x)
        push!(accepted, [x, y])
    else
        push!(rejected, [x, y])
    end
end
accepted = hcat(accepted...)
rejected = hcat(rejected...)

# Scatter plot of accepted (blue) and rejected (red) pairs (x, y)
scatter(accepted[1,:], accepted[2,:], label="accepted", color=:blue, markerstrokewidth=0, markersize=2, dpi=300, size=(728,300), legend=:outertopright)
scatter!(rejected[1,:], rejected[2,:], label="rejected", color=:red, markerstrokewidth=0, markersize=2)
savefig("RS3.png")

# Histogram of the accepted samples against the target density
x = range(-15, 15, length=1000)
plot(x, target(x), label="target "*L"p(x)", framestyle=:none)
histogram!(samples, label="accepted samples", color=:blue, alpha=0.5, bins=100, normed=true, dpi=300, size=(728,300), legend=:outertopright)
savefig("RS4.png")

Environment

  • OS: Windows 11
  • Version: Julia 1.8.3, Plots v1.38.6, Distributions v0.25.80, LaTeXStrings v1.3.0

  1. Christopher M. Bishop, Pattern Recognition and Machine Learning (2006), pp. 528–531 ↩︎