第19章 馬爾可夫鍊蒙特卡羅法
本文是李航老師的《統計學習方法》一書的代碼複現。作者:黃海廣
備注:代碼都可以在github中下載下傳。我将陸續将代碼釋出在公衆号“機器學習初學者”,可以在這個專輯線上閱讀。
- 蒙特卡羅法是通過基于機率模型的抽樣進行數值近似計算的方法,蒙特卡羅法可以用于機率分布的抽樣、機率分布數學期望的估計、定積分的近似計算。
随機抽樣是蒙特卡羅法的一種應用,有直接抽樣法、接受拒絕抽樣法等。接受拒絕法的基本想法是,找一個容易抽樣的建議分布,其密度函數的數倍大于等于想要抽樣的機率分布的密度函數。按照建議分布随機抽樣得到樣本,再按要抽樣的機率分布與建議分布的倍數的比例随機決定接受或拒絕該樣本,循環執行以上過程。
蒙特卡洛法(Monte Carlo method) , 也稱為統計模拟方法 (statistical simulation method) , 是通過從機率模型的随機抽樣進行近似數值計
算的方法。 馬爾可夫鍊陟特卡羅法 (Markov Chain Monte Carlo, MCMC), 則是以馬爾可夫鍊 (Markov chain)為機率模型的蒙特卡洛法。
馬爾可夫鍊蒙特卡羅法建構一個馬爾可夫鍊,使其平穩分布就是要進行抽樣的分布, 首先基于該馬爾可夫鍊進行随機遊走, 産生樣本的序列,
之後使用該平穩分布的樣本進行近似數值計算。
Metropolis-Hastings算法是最基本的馬爾可夫鍊蒙特卡羅法,Metropolis等人在 1953年提出原始的算法,Hastings在1970年對之加以推廣,
形成了現在的形式。吉布斯抽樣(Gibbs sampling)是更簡單、使用更廣泛的馬爾可夫鍊蒙特卡羅法,1984 年由S. Geman和D. Geman提出。
馬爾可夫鍊蒙特卡羅法被應用于機率分布的估計、定積分的近似計算、最優化問題的近似求解等問題,特别是被應用于統計學習中機率模型的學習
與推理,是重要的統計學習計算方法。
一般的蒙特卡羅法有直接抽樣法、接受-拒絕抽樣法、 重要性抽樣法等。
接受-拒絕抽樣法、重要性抽樣法适合于機率密度函數複雜 (如密度函數含有多個變量,各變量互相不獨立,密度函數形式複雜),不能直接抽樣的情況。
19.1.2 數學期望估計
一舣的蒙特卡羅法, 如直接抽樣法、接受·拒絕抽樣法、重要性抽樣法, 也可以用于數學期望估計 (estimation Of mathematical expectation)。
馬爾可夫鍊
平穩分布
引理19.1
吉布斯采樣
- 得到樣本集合
網絡資源:
LDA-math-MCMC 和 Gibbs Sampling: https://cosx.org/2013/01/lda-math-mcmc-and-gibbs-sampling
MCMC蒙特卡羅方法: https://www.cnblogs.com/pinard/p/6625739.html
import random
import math
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
transfer_matrix = np.array([[0.6, 0.2, 0.2], [0.3, 0.4, 0.3], [0, 0.3, 0.7]],
dtype='float32')
start_matrix = np.array([[0.5, 0.3, 0.2]], dtype='float32')
value1 = []
value2 = []
value3 = []
for i in range(30):
start_matrix = np.dot(start_matrix, transfer_matrix)
value1.append(start_matrix[0][0])
value2.append(start_matrix[0][1])
value3.append(start_matrix[0][2])
print(start_matrix)
複制
[[0.23076935 0.30769244 0.46153864]]
複制
#進行可視化
x = np.arange(30)
plt.plot(x,value1,label='cheerful')
plt.plot(x,value2,label='so-so')
plt.plot(x,value3,label='sad')
plt.legend()
plt.show()
複制
可以發現,從10輪左右開始,我們的狀态機率分布就不變了,一直保持在
[0.23076934,0.30769244,0.4615386]
參考:https://zhuanlan.zhihu.com/p/37121528
M-H采樣python實作
https://zhuanlan.zhihu.com/p/37121528
from scipy.stats import norm
def norm_dist_prob(theta):
y = norm.pdf(theta, loc=3, scale=2)
return y
T = 5000
pi = [0 for i in range(T)]
sigma = 1
t = 0
while t < T - 1:
t = t + 1
pi_star = norm.rvs(loc=pi[t - 1], scale=sigma, size=1,
random_state=None) #狀态轉移進行随機抽樣
alpha = min(
1, (norm_dist_prob(pi_star[0]) / norm_dist_prob(pi[t - 1]))) #alpha值
u = random.uniform(0, 1)
if u < alpha:
pi[t] = pi_star[0]
else:
pi[t] = pi[t - 1]
plt.scatter(pi, norm.pdf(pi, loc=3, scale=2), label='Target Distribution')
num_bins = 50
plt.hist(pi,
num_bins,
density=1,
facecolor='red',
alpha=0.7,
label='Samples Distribution')
plt.legend()
plt.show()
複制
二維Gibbs采樣執行個體python實作
from mpl_toolkits.mplot3d import Axes3D
from scipy.stats import multivariate_normal
samplesource = multivariate_normal(mean=[5,-1], cov=[[1,0.5],[0.5,2]])
def p_ygivenx(x, m1, m2, s1, s2):
return (random.normalvariate(m2 + rho * s2 / s1 * (x - m1), math.sqrt(1 - rho ** 2) * s2))
def p_xgiveny(y, m1, m2, s1, s2):
return (random.normalvariate(m1 + rho * s1 / s2 * (y - m2), math.sqrt(1 - rho ** 2) * s1))
N = 5000
K = 20
x_res = []
y_res = []
z_res = []
m1 = 5
m2 = -1
s1 = 1
s2 = 2
rho = 0.5
y = m2
for i in range(N):
for j in range(K):
x = p_xgiveny(y, m1, m2, s1, s2) #y給定得到x的采樣
y = p_ygivenx(x, m1, m2, s1, s2) #x給定得到y的采樣
z = samplesource.pdf([x,y])
x_res.append(x)
y_res.append(y)
z_res.append(z)
num_bins = 50
plt.hist(x_res, num_bins,density=1, facecolor='green', alpha=0.5,label='x')
plt.hist(y_res, num_bins, density=1, facecolor='red', alpha=0.5,label='y')
plt.title('Histogram')
plt.legend()
plt.show()
複制
本章代碼來源:https://github.com/hktxt/Learn-Statistical-Learning-Method
下載下傳位址
https://github.com/fengdu78/lihang-code
參考資料:
[1] 《統計學習方法》: https://baike.baidu.com/item/統計學習方法/10430179
[2] 黃海廣: https://github.com/fengdu78
[3] github: https://github.com/fengdu78/lihang-code
[4] wzyonggege: https://github.com/wzyonggege/statistical-learning-method
[5] WenDesi: https://github.com/WenDesi/lihang_book_algorithm
[6] 火燙火燙的: https://blog.csdn.net/tudaodiaozhale
[7] hktxt: https://github.com/hktxt/Learn-Statistical-Learning-Method