Example 5.10#
Metropolis Hastings on the banana example in Example 5.10.
import numpy as np
import matplotlib.pyplot as plt
rng = np.random.default_rng(24)
# banana function for testing MCMC
def log_banana(x, y):
return -x**2 / 10 - y**2 / 10 - 2 * (y - x**2)**2
N = 1000000
samples_RW = np.zeros((2, N))
# initial values
x = 0
y = 0
samples_RW[:, 0] = np.array([x, y])
# parameters
gamma = 0.005
sigma_rw = 0.5
burnin = 200
for n in range(1, N):
# random walk
x_s = samples_RW[:, n-1] + sigma_rw * np.random.randn(2)
# metropolis
u = rng.uniform(0, 1)
if np.log(u) < log_banana(x_s[0], x_s[1]) - log_banana(samples_RW[0, n-1], samples_RW[1, n-1]):
samples_RW[:, n] = x_s
else:
samples_RW[:, n] = samples_RW[:, n-1]
# for surf plot banana 2d
x_bb = np.linspace(-4, 4, 100)
y_bb = np.linspace(-2, 6, 100)
X_bb, Y_bb = np.meshgrid(x_bb, y_bb)
Z_bb = np.exp(log_banana(X_bb, Y_bb))
plt.figure(figsize=(15, 5))
# make fonts bigger
plt.rcParams.update({'font.size': 20})
plt.subplot(1, 2, 1)
cnt = plt.contourf(X_bb, Y_bb, Z_bb, 100, cmap='RdBu')
for c in cnt.collections:
c.set_edgecolor("face")
plt.title('Target Distribution')
plt.subplot(1, 2, 2)
plt.hist2d(samples_RW[0, burnin:n], samples_RW[1, burnin:n], 100, cmap='RdBu', range=[[-4, 4], [-2, 6]], density=True)
# remove the white edges from plt.hist2d
for c in cnt.collections:
c.set_edgecolor("face")
plt.title('Random Walk Metropolis')
plt.show()
<ipython-input-1-dfcda187d96a>:47: MatplotlibDeprecationWarning: The collections attribute was deprecated in Matplotlib 3.8 and will be removed two minor releases later.
for c in cnt.collections:
<ipython-input-1-dfcda187d96a>:54: MatplotlibDeprecationWarning: The collections attribute was deprecated in Matplotlib 3.8 and will be removed two minor releases later.
for c in cnt.collections: