Example 2.1

In what follows, we will implement inversion methods for sampling discrete and continuous distributions.

Discrete distributions

Example 2.1. We think of a discrete distribution very practically: as a vector of probabilities defined on a set of states. Let us define a probability mass function on the set \(\mathsf{S} = \{1, 2, 3, 4, 5\}\) as follows:

\[\begin{align*} \mathsf{w} = \begin{bmatrix} 0.2 & 0.3 & 0.2 & 0.1 & 0.2 \end{bmatrix} \end{align*}\]

Let us plot the PMF and CDF of this distribution.

import numpy as np
import matplotlib.pyplot as plt

# Probability mass w_k of each state s_k in S = {1, ..., 5}.
w = np.array([2, 3, 2, 1, 2]) / 10
s = np.arange(1, 6)

def discrete_cdf(w):
    """Return the running totals of the weights ``w``, i.e. the CDF at each state."""
    running_total = np.cumsum(w, axis=None)
    return running_total

cw = discrete_cdf(w)

def plot_discrete_cdf(w, cw, s=None):
    """Plot a discrete distribution's PMF (left) and CDF (right) side by side.

    Parameters
    ----------
    w : array_like
        Probability mass of each state.
    cw : array_like
        Cumulative values of ``w`` (e.g. ``discrete_cdf(w)``).
    s : array_like, optional
        States at which the masses sit. Defaults to ``1, 2, ..., len(w)``
        rather than silently reading a global variable.
    """
    if s is None:
        s = np.arange(1, len(w) + 1)
    fig, ax = plt.subplots(1, 2, figsize=(20, 5))
    ax[0].stem(s, w)
    ax[0].set_title("PMF")
    ax[0].set_xlabel("s")
    # 'steps-post' holds each CDF value until the next state, giving the
    # right-continuous staircase of a discrete CDF.
    ax[1].plot(s, cw, 'o-', drawstyle='steps-post')
    ax[1].set_title("CDF")
    ax[1].set_xlabel("s")
    plt.show()

plot_discrete_cdf(w, cw, s)
../_images/0fc9023e2b1a0e9d66cd472e74cff1cf240ede7f37d8df5335ddbd7ed8fcd68a.png

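Next, we implement the sampling method using the inverse of the CDF. The procedure is as follows:
- Draw \(u \sim \mathrm{Unif}(0, 1)\).
- Return the first state \(s_k\) whose CDF value exceeds \(u\), i.e. the state with \(F(s_{k-1}) \le u < F(s_k)\) (taking \(F(s_0) = 0\)).

This happens with probability \(F(s_k) - F(s_{k-1}) = w_k\). Let us implement the method first, then look at some animations.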

import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(seed=4)

# Same example distribution: masses w_k on states s_k = k, k = 1, ..., 5.
w = np.array([2, 3, 2, 1, 2]) / 10
s = np.arange(1, 6)

def discrete_cdf(w):
    """Return the running totals of the weights ``w``, i.e. the CDF at each state."""
    running_total = np.cumsum(w, axis=None)
    return running_total

def sample(u, s, w):
    """Discrete sampler: invert the CDF of weights ``w`` over states ``s`` at ``u``.

    Returns the first state ``s[k]`` whose cumulative weight exceeds
    ``u * total``, where ``total`` is the sum of ``w``.

    Parameters
    ----------
    u : float or array_like
        Uniform random number(s) in [0, 1).
    s : array_like
        States, in the same order as ``w``.
    w : array_like
        Non-negative weights. They need not sum exactly to 1: ``u`` is scaled
        by their total, so normalised and unnormalised weights give the same
        distribution.

    Returns
    -------
    The sampled state (an array of states if ``u`` is an array).

    Raises
    ------
    ValueError
        If ``w`` is empty or its total weight is not positive.
    """
    cdf = np.cumsum(w)  # same as discrete_cdf(w)
    if cdf.size == 0 or not cdf[-1] > 0:
        raise ValueError("w must be non-empty with a positive total weight")
    # Scaling by the total keeps the target below cdf[-1] for u < 1, so a
    # total that rounds to e.g. 0.9999999999999999 (or unnormalised weights)
    # can no longer leave u above every CDF value and fall through to s[0].
    target = np.asarray(u) * cdf[-1]
    # side='right' gives the first index with cdf > target, the same rule as
    # argmax(cdf > u).
    idx = np.searchsorted(cdf, target, side='right')
    # Clamp so that u >= 1 maps to the last state instead of running off the end.
    idx = np.minimum(idx, cdf.size - 1)
    return np.asarray(s)[idx]

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
n = 2000  # number of samples
x = []    # discrete samples
un = []   # uniforms that produced them

# Draw n uniforms and push each through the inverse CDF. A throwaway loop
# variable keeps `n` equal to the sample size after the loop.
for _ in range(n):
    u = rng.uniform(0, 1)
    un.append(u)
    x.append(sample(u, s, w))

# Left: the raw uniforms, which should look flat on [0, 1).
ax[0].hist(un, bins=10, density=True, color='k', alpha=1)
ax[0].set_title("Histogram of Uniform Random Variables")

# Right: empirical frequencies of the sampled states against the true PMF.
# Unit-width bins with edges at s - 0.5 centre each bar on its state
# (assumes consecutive integer states, as here); rwidth keeps the bars thin.
ax[1].stem(s, w, markerfmt='o', linefmt='r-')
ax[1].set_xlabel("s")
ax[1].hist(x, bins=np.arange(s.min(), s.max() + 2) - 0.5,
           density=True, color='k', alpha=1, rwidth=0.1)
ax[1].set_xlim([0, 6])
ax[1].set_title("Histogram of Discrete Random Variables and PMF")
Text(0.5, 1.0, 'Histogram of Discrete Random Variables and PMF')
../_images/ba7f597c5ae578f3a13e2fdeaed8a3df43a4d39fb09293c5e37be3f179ac91ad.png

As we can see, with \(n = 2000\) samples the histogram closely matches the PMF, so the method samples from the correct distribution. Below, the same process is animated. The code is hidden because it contains a lot of animation-specific scaffolding, but feel free to expand it if you are curious!

Animated discrete sampling

Visualising these processes is an important way to build intuition. Below we first animate the discrete case, then the continuous case.

Hide code cell source
%matplotlib notebook

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML, display

rng = np.random.default_rng(seed=4)

# Same example distribution: masses w_k on states s_k = k, k = 1, ..., 5.
w = np.array([2, 3, 2, 1, 2]) / 10
s = np.arange(1, 6)

def discrete_cdf(w):
    """Return the running totals of the weights ``w``, i.e. the CDF at each state."""
    running_total = np.cumsum(w, axis=None)
    return running_total

def sample(u, s, w):
    """Discrete sampler: invert the CDF of weights ``w`` over states ``s`` at ``u``.

    Returns the first state ``s[k]`` whose cumulative weight exceeds
    ``u * total``, where ``total`` is the sum of ``w``.

    Parameters
    ----------
    u : float or array_like
        Uniform random number(s) in [0, 1).
    s : array_like
        States, in the same order as ``w``.
    w : array_like
        Non-negative weights. They need not sum exactly to 1: ``u`` is scaled
        by their total, so normalised and unnormalised weights give the same
        distribution.

    Returns
    -------
    The sampled state (an array of states if ``u`` is an array).

    Raises
    ------
    ValueError
        If ``w`` is empty or its total weight is not positive.
    """
    cdf = np.cumsum(w)  # same as discrete_cdf(w)
    if cdf.size == 0 or not cdf[-1] > 0:
        raise ValueError("w must be non-empty with a positive total weight")
    # Scaling by the total keeps the target below cdf[-1] for u < 1, so a
    # total that rounds to e.g. 0.9999999999999999 (or unnormalised weights)
    # can no longer leave u above every CDF value and fall through to s[0].
    target = np.asarray(u) * cdf[-1]
    # side='right' gives the first index with cdf > target, the same rule as
    # argmax(cdf > u).
    idx = np.searchsorted(cdf, target, side='right')
    # Clamp so that u >= 1 maps to the last state instead of running off the end.
    idx = np.minimum(idx, cdf.size - 1)
    return np.asarray(s)[idx]

# 2x2 grid: PMF | CDF on top, histograms of the uniforms | samples below.
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(10, 8))
n = 200               # number of animation frames (one sample per frame)
cw = discrete_cdf(w)  # CDF drawn in the top-right panel
un, x = [], []        # uniforms drawn so far / the discrete samples they produced

def update(i):
    """Advance the animation by one sample and redraw all four panels.

    Called by FuncAnimation once per frame; the frame index ``i`` is not
    used. Each call draws one uniform ``u`` from ``rng``, maps it to a state
    with ``sample``, and appends both to the module-level lists ``un`` and
    ``x``. The lists are mutated in place, so no ``global`` declaration is
    needed.
    """
    u = rng.uniform(0, 1)
    un.append(u)
    sample_x = sample(u, s, w)
    x.append(sample_x)

    # The rest is for animation.

    # Top left: the target PMF.
    ax[0, 0].cla()
    ax[0, 0].stem(s, w, markerfmt='o', linefmt='r-')
    ax[0, 0].set_title("PMF")
    ax[0, 0].set_xlabel("s")

    # Top right: the CDF, with u marked on the y axis, a dashed line across
    # to the sampled state, and that state marked on the x axis.
    ax[0, 1].cla()
    ax[0, 1].plot(s, cw, 'ro-', drawstyle='steps-post')
    ax[0, 1].set_title("Cumulative Distribution Function")
    ax[0, 1].set_xlabel("s")
    ax[0, 1].set_xlim([0, 6])
    ax[0, 1].set_ylim([0, 1])
    ax[0, 1].plot(0, u, c='k', marker='o', linestyle='none', markersize=10)
    ax[0, 1].plot(sample_x, u, c='k', marker='o', linestyle='none', markersize=10)
    ax[0, 1].plot([0, sample_x], [u, u], c=[0.8, 0, 0], linestyle='--')
    ax[0, 1].plot(sample_x, 0, c='k', marker='o', linestyle='none', markersize=10)

    # Bottom left: histogram of the uniforms drawn so far.
    ax[1, 0].cla()
    ax[1, 0].hist(un, bins=10, density=True, color='k', alpha=1)
    ax[1, 0].set_title("Histogram of Uniform Random Variables")

    # Bottom right: empirical frequencies of the samples against the PMF.
    # Unit-width bins with edges at s - 0.5 centre each bar on its state
    # (assumes consecutive integer states, as here); rwidth keeps the bars thin.
    ax[1, 1].cla()
    ax[1, 1].stem(s, w, markerfmt='o', linefmt='r-')
    ax[1, 1].set_xlabel("s")
    ax[1, 1].hist(x, bins=np.arange(s.min(), s.max() + 2) - 0.5,
                  density=True, color='k', alpha=1, rwidth=0.1)
    ax[1, 1].set_xlim([0, 6])
    ax[1, 1].set_title("Histogram of Discrete Random Variables and PMF")


ani = FuncAnimation(fig, update, frames=n, repeat=False)
# Render every frame into a self-contained HTML/JavaScript player and show it.
display(HTML(ani.to_jshtml()))
# Close the figure afterwards so the backend does not also display the
# figure itself below the player.
plt.close(fig)
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
File ~/miniconda3/envs/StochSim/lib/python3.9/site-packages/ipykernel/pylab/backend_inline.py:49, in show(close, block)
     46 # only call close('all') if any to close
     47 # close triggers gc.collect, which can be slow
     48 if close and Gcf.get_all_fig_managers():
---> 49     matplotlib.pyplot.close('all')

File ~/miniconda3/envs/StochSim/lib/python3.9/site-packages/matplotlib/pyplot.py:1070, in close(fig)
   1068         _pylab_helpers.Gcf.destroy(manager)
   1069 elif fig == 'all':
-> 1070     _pylab_helpers.Gcf.destroy_all()
   1071 elif isinstance(fig, int):
   1072     _pylab_helpers.Gcf.destroy(fig)

File ~/miniconda3/envs/StochSim/lib/python3.9/site-packages/matplotlib/_pylab_helpers.py:82, in Gcf.destroy_all(cls)
     80 for manager in list(cls.figs.values()):
     81     manager.canvas.mpl_disconnect(manager._cidgcf)
---> 82     manager.destroy()
     83 cls.figs.clear()

File ~/miniconda3/envs/StochSim/lib/python3.9/site-packages/matplotlib/backends/backend_nbagg.py:144, in FigureManagerNbAgg.destroy(self)
    142 for comm in list(self.web_sockets):
    143     comm.on_close()
--> 144 self.clearup_closed()

File ~/miniconda3/envs/StochSim/lib/python3.9/site-packages/matplotlib/backends/backend_nbagg.py:152, in FigureManagerNbAgg.clearup_closed(self)
    148 self.web_sockets = {socket for socket in self.web_sockets
    149                     if socket.is_open()}
    151 if len(self.web_sockets) == 0:
--> 152     CloseEvent("close_event", self.canvas)._process()

File ~/miniconda3/envs/StochSim/lib/python3.9/site-packages/matplotlib/backend_bases.py:1271, in Event._process(self)
   1269 def _process(self):
   1270     """Process this event on ``self.canvas``, then unset ``guiEvent``."""
-> 1271     self.canvas.callbacks.process(self.name, self)
   1272     self._guiEvent_deleted = True

File ~/miniconda3/envs/StochSim/lib/python3.9/site-packages/matplotlib/cbook.py:303, in CallbackRegistry.process(self, s, *args, **kwargs)
    301 except Exception as exc:
    302     if self.exception_handler is not None:
--> 303         self.exception_handler(exc)
    304     else:
    305         raise

File ~/miniconda3/envs/StochSim/lib/python3.9/site-packages/matplotlib/cbook.py:87, in _exception_printer(exc)
     85 def _exception_printer(exc):
     86     if _get_running_interactive_framework() in ["headless", None]:
---> 87         raise exc
     88     else:
     89         traceback.print_exc()

File ~/miniconda3/envs/StochSim/lib/python3.9/site-packages/matplotlib/cbook.py:298, in CallbackRegistry.process(self, s, *args, **kwargs)
    296 if func is not None:
    297     try:
--> 298         func(*args, **kwargs)
    299     # this does not capture KeyboardInterrupt, SystemExit,
    300     # and GeneratorExit
    301     except Exception as exc:

File ~/miniconda3/envs/StochSim/lib/python3.9/site-packages/matplotlib/animation.py:924, in Animation._stop(self, *args)
    922     self._fig.canvas.mpl_disconnect(self._resize_id)
    923 self._fig.canvas.mpl_disconnect(self._close_id)
--> 924 self.event_source.remove_callback(self._step)
    925 self.event_source = None

AttributeError: 'NoneType' object has no attribute 'remove_callback'

One can see above the animation of uniform samples being drawn and passed through the inverse of the CDF.