天天看點

機器學習系列(三)——EM算法1、前言2、EM算法引入3、EM算法4、推導逼近5、證明收斂6、高斯混合分布7、代碼

1、前言

E M EM EM算法即期望最大化算法,是用于計算最大似然估計的疊代方法,常用在隐變量模型的參數學習中,所謂隐變量就是指資料的不完整性,也就是訓練資料并不能給出關于模型結果的全部資訊,是以隻能對模型中未知的狀态做出機率性的推測。

  • E E E步:求期望(expectation),用目前的參數來生成關于隐變量機率的期望函數
  • M M M步:求極大(maximization),尋找讓期望函數最大化的一組參數,并将這組參數應用到下一輪的期望步驟中。

    如此循環往複,算法就可以估計出隐變量的機率分布。

** E M EM EM算法雖然可以在不能直接求解方程時找到統計模型的最大似然參數,但它并不能保證收斂到全局最優。**一般來說,似然函數的最大化會涉及對所有未知參量求導,這在隐變量模型中是無法實作的。

E M EM EM算法的解決方法是将求解過程轉化為一組互鎖的方程,它們就像關聯的齒輪一樣,通過待求解參數和未知狀态變量的不斷疊代、交叉使用來求解最大似然。

2、EM算法引入

先看期刊《自然·生物技術》上的一個例子:

假定有兩枚不同的硬币 A A A和 B B B,它們的重量分布(跑一次出現正面機率) θ A \theta_A θA​和 θ B \theta_B θB​是未知的,其數值可通過抛擲後計算正反面各自出現的次數來估計。具體的估計方法是在每一輪中随機抽出一枚硬币抛擲10次,同樣的過程執行5輪,根據這50次投币的結果來計算 θ A \theta_A θA​和 θ B \theta_B θB​的最大似然估計。

機器學習系列(三)——EM算法1、前言2、EM算法引入3、EM算法4、推導逼近5、證明收斂6、高斯混合分布7、代碼

在上圖的單次試驗中,硬币 A A A被抽到3次,30次投擲中出現了24次正面;硬币 B B B被抽到2次,20次投擲中出現了9次正面。用最大似然估計可以計算出 θ ^ A = 24 / ( 24 + 6 ) = 0.8 , θ ^ B = 9 / ( 9 + 11 ) = 0.45 \hat{\theta}_A=24/(24+6)=0.8,\hat{\theta}_B=9/(9+11)=0.45 θ^A​=24/(24+6)=0.8,θ^B​=9/(9+11)=0.45

這個問題很簡單,但是如果我們隻能知道每一輪中出現的正反面結果,卻不能得知到底選取的硬币是 A A A還是 B B B,問題就沒那麼簡單了。這裡的硬币選擇就是不能直接觀測的隐變量。如果不管這個隐變量,就無法估計未知參數;要确定這一組隐變量,又得基于未知的硬币重量分布,問題進入了死胡同。

既然資料中資訊不完整,那就人為地補充完整。在這個問題中,隐藏的硬币選擇和待估計的重量分布,兩者确定一個就可以計算另一個。

由于觀測結果給出了關于重量分布的資訊,那就不放人為設定一組初始化的參數 θ ^ ( t ) = ( θ ^ A ( t ) , θ ^ B ( t ) ) \hat{\theta}^{(t)}=(\hat{\theta}_A^{(t)},\hat{\theta}_B^{(t)}) θ^(t)=(θ^A(t)​,θ^B(t)​),用這組猜測的重量分布去倒推到底每一輪使用的是哪個硬币。

計算出來的硬币選擇會被用來對原來随機産生的初始化參數進行更新。如果硬币選擇的結果是正确的,就可以利用最大似然估計計算出新的參數 θ ^ ( t + 1 ) \hat{\theta}^{(t+1)} θ^(t+1)。而更新後的參數有可以應用在觀測結果上,對硬币選擇的結果進行修正,進而形成了“批評-自我批評”的循環過程。這個過程會持續到隐變量和未知參數的取值都不再發生變化,其結果就是最終的輸出。

将思路應用到下圖的投擲結果中,就是 E M EM EM算法的雛形。

機器學習系列(三)——EM算法1、前言2、EM算法引入3、EM算法4、推導逼近5、證明收斂6、高斯混合分布7、代碼

将兩個初始的參數随機設定為 θ ^ A ( 0 ) = 0.6 , θ ^ B ( 0 ) = 0.5 \hat{\theta}_A^{(0)}=0.6,\hat{\theta}_B^{(0)}=0.5 θ^A(0)​=0.6,θ^B(0)​=0.5,在這兩個參數下出現第一輪結果,也就是5正5反的機率就可以表示為

P ( H 5 T 5 ∣ A ) = 0. 6 5 × 0. 4 5 , P ( H 5 T 5 ∣ B ) = 0. 5 10 P(H^5T^5|A)=0.6^5\times0.4^5,P(H^5T^5|B)=0.5^{10} P(H5T5∣A)=0.65×0.45,P(H5T5∣B)=0.510

對上面的兩個似然機率進行歸一化可以得出後驗機率,兩者分别是0.45和0.55,也就是下圖中的結果。這說明如果初始結果的随機參數是準确的,拿第一輪結果更可能由硬币 B B B産生(0.55>0.45)。同理也可以計算出其他四輪的結果來自不同硬币的後驗機率,結果已經在上圖中顯示。

在已知硬币的選擇時,所有正反面的結果都有明确的歸屬:要麼來自 A A A要麼來自 B B B。利用後驗機率可以直接對硬币的選擇作出判斷:1、4輪使用的是硬币 B B B,2、3、5輪使用的是硬币 A A A。

既然硬币的選擇已經确定,這時就可以使用最大似然估計,其結果和前文中的最大似然估計結果是相同的,也就是 θ ^ A ( 1 ) = 0.8 , θ ^ B ( 1 ) = 0.45 \hat{\theta}_A^{(1)}=0.8,\hat{\theta}_B^{(1)}=0.45 θ^A(1)​=0.8,θ^B(1)​=0.45。利用這組更新的參數又可以重新計算每一輪次抽取不同硬币的後驗機率。

雖然這種方法能夠實作隐變量和參數的動态更新,但它還不是真正的 E M EM EM算法,而是硬輸出的 k k k均值聚類。真正的 E M EM EM算法并不會将後驗機率最大的值賦給隐變量,而是考慮其所有可能的取值,在機率分布的架構下進行分析。

在前面的例子中,由于第一輪投擲硬币 A A A的可能性是0.45,那麼硬币 A A A對正反面出現次數的貢獻就是45%,在5次正面的結果中,來源于硬币 A A A的就是 5×0.45=2.25 次,來源于硬币 B B B的則是2.75次。同理可以計算出其他輪次中 A A A和 B B B各自的貢獻,貢獻的比例都和計算出的後驗機率相對應。

計算出 A A A和 B B B在不同輪次中的貢獻,就可以對未知參數做出更加精确的估計。在 50 次投擲中,硬币 A A A貢獻了21.3次正面和8.6次反面,其參數估計值 θ ^ A ( 1 ) = 0.71 \hat{\theta}_A^{(1)}=0.71 θ^A(1)​=0.71;硬币 B B B貢獻了11.7次正面和8.4次反面,其參數估計值 θ ^ B ( 1 ) = 0.58 \hat{\theta}_B^{(1)}=0.58 θ^B(1)​=0.58。利用這組參數繼續疊代更新,就可以計算出最終的估計值。

機率模型有時候既含有觀測變量,又含有隐變量或潛在變量,如果機率模型的變量都是觀測變量,那麼給定資料,可以直接用極大似然估計法,或貝葉斯估計方法估計模型參數,但是當模型含有隐變量時,就不能簡單的使用這些方法,EM算法就是含有隐變量的機率模型參數的極大似然估計法,或極大後驗機率估計法,我們讨論極大似然估計,極大後驗機率估計與其類似。

3、EM算法

輸入:觀測變量資料 Y Y Y,隐變量資料 Z Z Z,聯合分布 P ( Y , Z ∣ θ ) P(Y,Z|\theta) P(Y,Z∣θ),條件分布 P ( Z ∣ Y , θ ) P(Z|Y,\theta) P(Z∣Y,θ);

輸出:模型參數 θ \theta θ

  • (1)選擇參數的初值 θ 0 \theta^0 θ0,開始疊代
  • (2) E E E步:記 θ i \theta^i θi為第i次疊代參數 θ \theta θ的估計值,在第i+1次疊代的 E E E步,計算

    Q ( θ , θ i ) = E Z [ l o g P ( Y , Z ∣ θ ) ∣ Y , θ i ] = ∑ Z l o g P ( Y , Z ∣ θ ) P ( Z ∣ Y , θ i ) \begin{aligned} Q(\theta,\theta^i)&=E_{Z}[logP(Y,Z|\theta)|Y,\theta^i]\\ &=\sum_{Z}logP(Y,Z|\theta)P(Z|Y,\theta^i) \end{aligned} Q(θ,θi)​=EZ​[logP(Y,Z∣θ)∣Y,θi]=Z∑​logP(Y,Z∣θ)P(Z∣Y,θi)​

    這裡, P ( Z ∣ Y , θ i ) P(Z|Y,\theta^i) P(Z∣Y,θi)是在給定觀測資料Y和目前的參數估計 θ i \theta^i θi下隐變量資料 Z Z Z的條件機率分布;

  • (3) M M M步:求使 Q ( θ , θ i ) Q(\theta,\theta^i) Q(θ,θi)極大化的 θ \theta θ,确定第 i + 1 i+1 i+1次疊代的參數的估計值 θ i + 1 \theta^{i+1} θi+1,

    θ i + 1 = a r g max ⁡ θ Q ( θ , θ i ) \theta^{i+1}=arg \max \limits_{\theta}Q(\theta,\theta^{i}) θi+1=argθmax​Q(θ,θi)

    Q ( θ , θ i ) Q(\theta,\theta^{i}) Q(θ,θi)是 E M EM EM算法的核心,稱為 Q Q Q函數( Q Q Q function),這個是需要自己構造的。

  • (4) 重複第(2)步和第(3)步,直到收斂,收斂條件:

    ∣ ∣ θ i + 1 − θ i ∣ ∣ < ε 1 || \theta^{i+1}-\theta^{i} || < \varepsilon_1 ∣∣θi+1−θi∣∣<ε1​

    或者:

    ∣ ∣ Q ( θ i + 1 , θ i ) − Q ( θ i , θ i ) ∣ ∣ < ε 2 ||Q(\theta^{i+1},\theta^{i})-Q(\theta^{i},\theta^{i})|| <\varepsilon_2 ∣∣Q(θi+1,θi)−Q(θi,θi)∣∣<ε2​

    收斂疊代就結束了。我們來拆解一下這個 M M M步驟,

4、推導逼近

主要講解 J e n s e n Jensen Jensen不等式,這個公式在推導和收斂都用到,主要是如下的結論:

  • f ( x ) f(x) f(x)是凸函數

    f ( E ( X ) ) ≤ E ( f ( x ) ) f(E(X)) \le E(f(x)) f(E(X))≤E(f(x))

  • f ( x ) f(x) f(x)是凹函數

    f ( E ( X ) ) ≥ E ( f ( x ) ) f(E(X)) \ge E(f(x)) f(E(X))≥E(f(x))

推導出 E M EM EM算法可以近似實作對觀測資料的極大似然估計的辦法是找到E步驟的下界,讓下界最大,通過逼近的方式實作對觀測資料的最大似然估計。統計學習基礎中采用的是相減方式,我們來看下具體的步驟。

  • 增加隐藏變量

    L ( θ ) = ∑ Z l o g P ( Y ∣ Z , θ ) P ( Z , θ ) L(\theta)=\sum_{Z}logP(Y|Z,\theta)P(Z,\theta) L(θ)=Z∑​logP(Y∣Z,θ)P(Z,θ)

    則 L ( θ ) − L ( θ i ) L(\theta)-L(\theta^{i}) L(θ)−L(θi)為:

    L ( θ ) − L ( θ i ) = l o g ( ∑ Z P ( Y ∣ Z , θ i ) P ( Y ∣ Z , θ ) P ( Z , θ ) P ( Y ∣ Z , θ i ) ) − L ( θ i ) ≥ ∑ Z P ( Y ∣ Z , θ i ) l o g ( P ( Y ∣ Z , θ ) P ( Z , θ ) P ( Y ∣ Z , θ i ) ) − L ( θ i ) \begin{aligned} L(\theta)-L(\theta^{i})=log(\sum_{Z} P(Y|Z,\theta^i)\frac{P(Y|Z,\theta)P(Z,\theta)}{P(Y|Z,\theta^i)})-L(\theta^{i})\\ \ge \sum_{Z} P(Y|Z,\theta^i)log(\frac{P(Y|Z,\theta)P(Z,\theta)}{P(Y|Z,\theta^i)})-L(\theta^{i}) \end{aligned} L(θ)−L(θi)=log(Z∑​P(Y∣Z,θi)P(Y∣Z,θi)P(Y∣Z,θ)P(Z,θ)​)−L(θi)≥Z∑​P(Y∣Z,θi)log(P(Y∣Z,θi)P(Y∣Z,θ)P(Z,θ)​)−L(θi)​

    ≥ \ge ≥這一個步驟就是采用了凹函數的 J e n s e n Jensen Jensen不等式做轉換。因為 Z Z Z是隐藏變量,是以有 ∑ Z P ( Y ∣ Z , θ i ) = = 1 , P ( Y ∣ Z , θ i ) > 0 \sum_{Z} P(Y|Z,\theta^i)==1,P(Y|Z,\theta^i)>0 ∑Z​P(Y∣Z,θi)==1,P(Y∣Z,θi)>0,于是繼續變:

L ( θ ) − L ( θ i ) = l o g ( ∑ Z P ( Y ∣ Z , θ i ) P ( Y ∣ Z , θ ) P ( Z , θ ) P ( Y ∣ Z , θ i ) ) − L ( θ i ) ≥ ∑ Z P ( Z ∣ Y , θ i ) l o g ( P ( Y ∣ Z , θ ) P ( Z , θ ) P ( Z ∣ Y , θ i ) ) − L ( θ i ) = ∑ Z P ( Z ∣ Y , θ i ) l o g ( P ( Y ∣ Z , θ ) P ( Z , θ ) P ( Z ∣ Y , θ i ) ) − ∑ Z P ( Z ∣ Y , θ i ) L ( θ i ) = ∑ Z P ( Z ∣ Y , θ i ) l o g ( P ( Y ∣ Z , θ ) P ( Z , θ ) P ( Z ∣ Y , θ i ) ( P ( Y ∣ θ i ) ) ≥ 0 \begin{aligned} L(\theta)-L(\theta^{i})&=log(\sum_{Z} P(Y|Z,\theta^i)\frac{P(Y|Z,\theta)P(Z,\theta)}{P(Y|Z,\theta^i)})-L(\theta^{i})\\ &\ge \sum_{Z} P(Z|Y,\theta^i)log(\frac{P(Y|Z,\theta)P(Z,\theta)}{P(Z|Y,\theta^i)})-L(\theta^{i})\\ &=\sum_{Z} P(Z|Y,\theta^i)log(\frac{P(Y|Z,\theta)P(Z,\theta)}{P(Z|Y,\theta^i)})-\sum_{Z} P(Z|Y,\theta^i)L(\theta^{i})\\ &= \sum_{Z} P(Z|Y,\theta^i)log(\frac{P(Y|Z,\theta)P(Z,\theta)}{P(Z|Y,\theta^i) (P(Y|\theta^{i})}) \\ & \ge0 \end{aligned} L(θ)−L(θi)​=log(Z∑​P(Y∣Z,θi)P(Y∣Z,θi)P(Y∣Z,θ)P(Z,θ)​)−L(θi)≥Z∑​P(Z∣Y,θi)log(P(Z∣Y,θi)P(Y∣Z,θ)P(Z,θ)​)−L(θi)=Z∑​P(Z∣Y,θi)log(P(Z∣Y,θi)P(Y∣Z,θ)P(Z,θ)​)−Z∑​P(Z∣Y,θi)L(θi)=Z∑​P(Z∣Y,θi)log(P(Z∣Y,θi)(P(Y∣θi)P(Y∣Z,θ)P(Z,θ)​)≥0​

也就是: L ( θ ) ≥ L ( θ i ) + ∑ Z P ( Z ∣ Y , θ i ) l o g ( P ( Y ∣ Z , θ ) P ( Z , θ ) P ( Y ∣ Z , θ i ) L ( θ i ) ) L(\theta)\ge L(\theta^{i})+ \sum_{Z} P(Z|Y,\theta^i)log(\frac{P(Y|Z,\theta)P(Z,\theta)}{P(Y|Z,\theta^i) L(\theta^{i})}) L(θ)≥L(θi)+∑Z​P(Z∣Y,θi)log(P(Y∣Z,θi)L(θi)P(Y∣Z,θ)P(Z,θ)​),有下界,最大化下界,來得到近似值。這裡有一個細節: P ( Y ∣ Z , θ i ) P(Y|Z,\theta^i) P(Y∣Z,θi) 變為 P ( Z ∣ Y , θ i ) P(Z|Y,\theta^i) P(Z∣Y,θi)?如果要滿足 J e n s e n Jensen Jensen不等式的等号,則有:

P ( Y ∣ Z , θ ) P ( Z , θ ) P ( Y ∣ Z , θ i ) = c \frac{P(Y|Z,\theta)P(Z,\theta)}{P(Y|Z,\theta^i)} = c P(Y∣Z,θi)P(Y∣Z,θ)P(Z,θ)​=c

c c c為一個常數,而 ∑ Z P ( Y ∣ Z , θ i ) = 1 \sum_{Z}P(Y|Z,\theta^i)=1 ∑Z​P(Y∣Z,θi)=1則:

∑ Z P ( Y ∣ Z , θ ) P ( Z , θ ) = c ∑ Z P ( Y ∣ Z , θ i ) = c = P ( Y ∣ Z , θ ) P ( Z , θ ) P ( Y ∣ Z , θ i ) P ( Y ∣ Z , θ ) = P ( Y ∣ Z , θ ) P ( Z , θ ) ∑ Z P ( Y ∣ Z , θ ) P ( Z , θ ) = P ( Y , Z , θ ) P ( Y , θ ) = P ( Z ∣ Y , θ ) \begin{aligned} \sum_{Z}P(Y|Z,\theta)P(Z,\theta)= c\sum_{Z}P(Y|Z,\theta^i)&=c\\ &=\frac{P(Y|Z,\theta)P(Z,\theta)}{P(Y|Z,\theta^i)}\\ P(Y|Z,\theta)=\frac{P(Y|Z,\theta)P(Z,\theta)}{\sum_{Z}P(Y|Z,\theta)P(Z,\theta)}=\frac{P(Y,Z,\theta)}{P(Y,\theta)}=P(Z|Y,\theta) \end{aligned} Z∑​P(Y∣Z,θ)P(Z,θ)=cZ∑​P(Y∣Z,θi)P(Y∣Z,θ)=∑Z​P(Y∣Z,θ)P(Z,θ)P(Y∣Z,θ)P(Z,θ)​=P(Y,θ)P(Y,Z,θ)​=P(Z∣Y,θ)​=c=P(Y∣Z,θi)P(Y∣Z,θ)P(Z,θ)​​

大家是不是很奇怪 P ( Y ∣ Z , θ ) P ( Z , θ ) P(Y|Z,\theta)P(Z,\theta) P(Y∣Z,θ)P(Z,θ)加上 ∑ \sum ∑之後等于什麼,其實有的部落格這裡使用 P ( Z , θ ) = P ( Y i , Z i , θ i ) P(Z,\theta) = P(Y^i,Z^i,\theta^i) P(Z,θ)=P(Yi,Zi,θi)來替代 P ( Y ∣ Z , θ ) P(Y|Z,\theta) P(Y∣Z,θ)參與計算,這樣 ∑ Z P ( Y i , Z i , θ i ) \sum_{Z}P(Y^i,Z^i,\theta^i) ∑Z​P(Yi,Zi,θi),這樣就友善了解來了。

于是最大化如下:

θ i + 1 = a r g max ⁡ θ ∑ Z P ( Z ∣ Y , θ i ) l o g ( P ( Y ∣ Z , θ ) P ( Z , θ ) P ( Z ∣ Y , θ i ) ) = a r g max ⁡ θ ∑ Z P ( Z ∣ Y , θ i ) l o g ( P ( Y ∣ Z , θ ) P ( Z , θ ) ) = a r g max ⁡ θ ∑ Z P ( Z ∣ Y , θ i ) l o g ( P ( Y , Z ∣ θ ) ) = a r g max ⁡ θ Q ( θ , θ i ) \begin{aligned} \theta^{i+1}&=arg \max_{\theta}\sum_{Z} P(Z|Y,\theta^i)log(\frac{P(Y|Z,\theta)P(Z,\theta)}{P(Z|Y,\theta^i)})\\ &=arg \max_{\theta}\sum_{Z} P(Z|Y,\theta^i)log(P(Y|Z,\theta)P(Z,\theta))\\ & =arg \max_{\theta}\sum_{Z} P(Z|Y,\theta^i)log(P(Y,Z|\theta))\\ &=arg \max_{\theta}Q(\theta,\theta^i) \end{aligned} θi+1​=argθmax​Z∑​P(Z∣Y,θi)log(P(Z∣Y,θi)P(Y∣Z,θ)P(Z,θ)​)=argθmax​Z∑​P(Z∣Y,θi)log(P(Y∣Z,θ)P(Z,θ))=argθmax​Z∑​P(Z∣Y,θi)log(P(Y,Z∣θ))=argθmax​Q(θ,θi)​

其中 l o g log log分母提出來是關于 Z Z Z的 ∑ Z P ( Z ∣ Y , θ i ) l o g P ( Z ∣ Y , θ i ) \sum_{Z} P(Z|Y,\theta^i)logP(Z|Y,\theta^i) ∑Z​P(Z∣Y,θi)logP(Z∣Y,θi),可以去掉。當然也有部落格寫的形式是:

a r g max ⁡ θ ∑ i = 1 M ∑ Z i P ( Z i ∣ Y i , θ i ) l o g ( P ( Y i , Z i ; θ ) ) arg \max_{\theta}\sum_{i=1}^{M}\sum_{Z^{i}} P(Z^{i}|Y^{i},\theta^i)log(P(Y^{i},Z^{i};\theta))\\ argθmax​i=1∑M​Zi∑​P(Zi∣Yi,θi)log(P(Yi,Zi;θ))

形式其實一樣,表示的不一樣而已。

5、證明收斂

我們知道已知觀測資料的似然函數是 P ( Y , θ ) P(Y,\theta) P(Y,θ),對數似然函數為:

L ( ) = ∑ i = 1 M l o g P ( y i , θ ) = ∑ i = 1 M l o g ( P ( y i , Z ∣ θ ) P ( Z ∣ y i , θ ) ) = ∑ i = 1 M l o g P ( y i , Z ∣ θ ) − ∑ i = 1 M l o g P ( Z ∣ y i , θ ) \begin{aligned} L()=\sum_{i=1}^{M}logP(y^{i},\theta) &=\sum_{i=1}^{M}log(\frac{P(y^i,Z|\theta)}{P(Z|y^i,\theta)})\\ &=\sum_{i=1}^{M}logP(y^i,Z|\theta) - \sum_{i=1}^{M}logP(Z|y^i,\theta) \end{aligned} L()=i=1∑M​logP(yi,θ)​=i=1∑M​log(P(Z∣yi,θ)P(yi,Z∣θ)​)=i=1∑M​logP(yi,Z∣θ)−i=1∑M​logP(Z∣yi,θ)​

要證明收斂,就證明單調遞增, ∑ i = 1 M l o g P ( y i , θ j + 1 ) > ∑ i = 1 M l o g P ( y i , θ j ) \sum_{i=1}^{M}logP(y^{i},\theta^{j+1})>\sum_{i=1}^{M}logP(y^{i},\theta^{j}) ∑i=1M​logP(yi,θj+1)>∑i=1M​logP(yi,θj)

由上文知道:

Q ( θ , θ i ) = ∑ Z l o g P ( Y , Z ∣ θ ) P ( Z ∣ Y , θ i ) = ∑ i = 1 M ∑ Z j l o g P ( y i , Z j ∣ θ ) P ( Z j ∣ y i , θ i ) \begin{aligned} Q(\theta,\theta^i)&=\sum_{Z}logP(Y,Z|\theta)P(Z|Y,\theta^i)\\ &=\sum_{i=1}^{M}\sum_{Z^j}logP(y^i,Z^j|\theta)P(Z^j|y^i,\theta^i) \end{aligned} Q(θ,θi)​=Z∑​logP(Y,Z∣θ)P(Z∣Y,θi)=i=1∑M​Zj∑​logP(yi,Zj∣θ)P(Zj∣yi,θi)​

我們構造一個函數 H H H,讓他等于:

H ( θ , θ i ) = ∑ i = 1 M ∑ Z j l o g ( P ( Z ∣ y i , θ ) P ( Z ∣ y i , θ i ) ) H(\theta,\theta^{i})=\sum_{i=1}^{M}\sum_{Z^j}log(P(Z|y^i,\theta)P(Z|y^i,\theta^i)) H(θ,θi)=i=1∑M​Zj∑​log(P(Z∣yi,θ)P(Z∣yi,θi))

讓 Q ( θ , θ i ) − H ( θ , θ i ) Q(\theta,\theta^i)-H(\theta,\theta^{i}) Q(θ,θi)−H(θ,θi):

Q ( θ , θ i ) − H ( θ , θ i ) = ∑ i = 1 M ∑ Z j l o g P ( y i , Z j ∣ θ ) P ( Z j ∣ y i , θ i ) − ∑ i = 1 M ∑ Z j l o g ( P ( Z j ∣ y i , θ ) P ( Z j ∣ y i , θ i ) ) = ∑ i = 1 M ∑ Z j l o g ( P ( y i , Z j ∣ θ ) − P ( Z j ∣ y i , θ ) ) = ∑ i = 1 M l o g P ( y i , θ ) \begin{aligned} Q(\theta,\theta^i)-H(\theta,\theta^{i})&=\sum_{i=1}^{M}\sum_{Z^j}logP(y^i,Z^j|\theta)P(Z^j|y^i,\theta^i) - \sum_{i=1}^{M}\sum_{Z^j}log(P(Z^j|y^i,\theta)P(Z^j|y^i,\theta^i)) \\ &=\sum_{i=1}^{M}\sum_{Z^j}log\bigg(P(y^i,Z^j|\theta)-P(Z^j|y^i,\theta)\bigg) \\ &=\sum_{i=1}^{M}logP(y^{i},\theta) \end{aligned} Q(θ,θi)−H(θ,θi)​=i=1∑M​Zj∑​logP(yi,Zj∣θ)P(Zj∣yi,θi)−i=1∑M​Zj∑​log(P(Zj∣yi,θ)P(Zj∣yi,θi))=i=1∑M​Zj∑​log(P(yi,Zj∣θ)−P(Zj∣yi,θ))=i=1∑M​logP(yi,θ)​

是以:

∑ i = 1 M l o g P ( y i , θ j + 1 ) − ∑ i = 1 M l o g P ( y i , θ j ) = Q ( θ i + 1 , θ i ) − H ( θ i + 1 , θ i ) − ( Q ( θ i , θ i ) − H ( θ i , θ i ) ) = Q ( θ i + 1 , θ i ) − Q ( θ i , θ i ) − ( H ( θ i + 1 , θ i ) − H ( θ i , θ i ) ) \sum_{i=1}^{M}logP(y^{i},\theta^{j+1})-\sum_{i=1}^{M}logP(y^{i},\theta^{j}) \\ = Q(\theta^{i+1},\theta^i)-H(\theta^{i+1},\theta^{i}) - (Q(\theta^{i},\theta^{i})-H(\theta^{i},\theta^{i}))\\ = Q(\theta^{i+1},\theta^i)- Q(\theta^{i},\theta^{i}) -( H(\theta^{i+1},\theta^{i}) - H(\theta^{i},\theta^{i})) i=1∑M​logP(yi,θj+1)−i=1∑M​logP(yi,θj)=Q(θi+1,θi)−H(θi+1,θi)−(Q(θi,θi)−H(θi,θi))=Q(θi+1,θi)−Q(θi,θi)−(H(θi+1,θi)−H(θi,θi))

該公式左邊已經被證明是大于0,證明右邊: H ( θ i + 1 , θ i ) − H ( θ i , θ i ) < 0 H(\theta^{i+1},\theta^{i}) - H(\theta^{i},\theta^{i})<0 H(θi+1,θi)−H(θi,θi)<0:

H ( θ i + 1 , θ i ) − H ( θ i , θ i ) = ∑ Z j ( l o g ( P ( Z j ∣ Y , θ i + 1 ) P ( Z j ∣ Y , θ i ) ) ) P ( Z j ∣ Y , θ i ) = l o g ( ∑ Z j P ( Z j ∣ Y , θ i + 1 ) P ( Z j ∣ Y , θ i ) P ( Z j ∣ Y , θ i ) ) = l o g P ( Z ∣ Y , θ i + 1 ) = l o g 1 = 0 \begin{aligned} H(\theta^{i+1},\theta^{i}) - H(\theta^{i},\theta^{i}) &=\sum_{Z^j}\bigg(log(\frac{P(Z^j|Y,\theta^{i+1})}{P(Z^j|Y,\theta^i)}) \bigg)P(Z^j|Y,\theta^i) \\ &=log\bigg(\sum_{Z^j}\frac{P(Z^j|Y,\theta^{i+1})}{P(Z^j|Y,\theta^i)}P(Z^j|Y,\theta^i) \bigg)\\ &=logP(Z|Y,\theta^{i+1})=log1=0 \end{aligned} H(θi+1,θi)−H(θi,θi)​=Zj∑​(log(P(Zj∣Y,θi)P(Zj∣Y,θi+1)​))P(Zj∣Y,θi)=log(Zj∑​P(Zj∣Y,θi)P(Zj∣Y,θi+1)​P(Zj∣Y,θi))=logP(Z∣Y,θi+1)=log1=0​

其中不等式是由于 J e n s e n Jensen Jensen不等式,由此證明了 ∑ i = 1 M l o g P ( y i , θ j + 1 ) > ∑ i = 1 M l o g P ( y i , θ j ) \sum_{i=1}^{M}logP(y^{i},\theta^{j+1})>\sum_{i=1}^{M}logP(y^{i},\theta^{j}) ∑i=1M​logP(yi,θj+1)>∑i=1M​logP(yi,θj),證明了 E M EM EM算法的收斂性。但不能保證是全局最優,隻能保證局部最優。

6、高斯混合分布

EM算法的一個重要應用場景就是高斯混合模型的參數估計。高斯混合模型就是由多個高斯模型組合在一起的混合模型(可以了解為多個高斯分布函數的線性組合,理論上高斯混合模型是可以拟合任意類型的分布),例如對于下圖中的資料集如果用一個高斯模型來描述的話顯然是不合理的:

機器學習系列(三)——EM算法1、前言2、EM算法引入3、EM算法4、推導逼近5、證明收斂6、高斯混合分布7、代碼

兩個高斯模型可以拟合資料集,如圖所示:

機器學習系列(三)——EM算法1、前言2、EM算法引入3、EM算法4、推導逼近5、證明收斂6、高斯混合分布7、代碼

如果有多個高斯模型,公式表示為:

P ( y ∣ θ ) = ∑ k = 1 K a k ϕ ( y ∣ θ k ) ϕ ( y ∣ θ k ) = 1 2 π δ k e x p ( − ( y − μ k ) 2 2 δ k 2 ) a k > 0 , ∑ a k = 1 P(y|\theta)=\sum_{k=1}^{K}a_k\phi(y|\theta_{k}) \\ \phi(y|\theta_{k})=\frac{1}{\sqrt{2\pi}\delta_{k}}exp(-\frac{(y-\mu_{k})^2}{2 \delta_{k}^{2}}) \\ a_k>0,\sum a_k =1 P(y∣θ)=k=1∑K​ak​ϕ(y∣θk​)ϕ(y∣θk​)=2π

​δk​1​exp(−2δk2​(y−μk​)2​)ak​>0,∑ak​=1

ϕ ( y ∣ θ k ) \phi(y|\theta_{k}) ϕ(y∣θk​)表示為第k個高斯分布密度模型,定義如上,其中 a k a_k ak​表示被選中的機率。在本次模型 P ( y ∣ θ ) P(y|\theta) P(y∣θ)中,觀測資料是已知的,而觀測資料具體來自哪個模型是未知的,有點像之前提過的三硬币模型,我們來對比一下, A A A硬币就像是機率 a k a_k ak​,用來表明具體的模型,而 B 、 C B、C B、C硬币就是具體的模型,隻不過這裡有很多個模型,不僅僅是 B 、 C B、C B、C這兩個模型。我們用 γ j k \gamma_{jk} γjk​來表示,則:

γ j k = { 1 第j個觀測資料來源于第k個模型 0 否則 \gamma_{jk} = \begin{cases} 1& \text{第j個觀測資料來源于第k個模型}\\ 0& \text{否則} \end{cases} γjk​={10​第j個觀測資料來源于第k個模型否則​

是以一個觀測資料 y j y_j yj​的隐藏資料 ( γ j 1 , γ j 2 , . . . , γ j k ) (\gamma_{j1},\gamma_{j2},...,\gamma_{jk}) (γj1​,γj2​,...,γjk​),那麼完全似然函數就是:

P ( y , γ ∣ θ ) = ∏ k = 1 K ∏ j = 1 N [ a k ϕ ( y ∣ θ k ) ] γ j k P(y,\gamma|\theta)= \prod_{k=1}^{K}\prod_{j=1}^{N}[a_{k}\phi(y|\theta_{k})]^{\gamma_{jk}} P(y,γ∣θ)=k=1∏K​j=1∏N​[ak​ϕ(y∣θk​)]γjk​

取對數之後等于:

l o g ( P ( y , γ ∣ θ ) ) = l o g ( ∏ k = 1 K ∏ j = 1 N [ a k ϕ ( y ∣ θ k ) ] γ j k ) = ∑ K k = 1 ( ∑ j = 1 N ( γ j k ) l o g ( a k ) + ∑ j = 1 N ( γ j k ) [ l o g ( 1 2 π ) − l o g ( δ k ) − ( y i − μ k ) 2 2 δ k 2 ] ) \begin{aligned} log(P(y,\gamma|\theta))&=log( \prod_{k=1}^{K}\prod_{j=1}^{N}[a_{k}\phi(y|\theta_{k})]^{\gamma_{jk}})\\ &=\sum_{K}^{k=1}\bigg(\sum_{j=1}^{N}(\gamma_{jk}) log(a_k)+\sum_{j=1}^{N}( \gamma_{jk})\bigg[log(\frac{1}{\sqrt{2\pi}})-log(\delta_{k})-\frac{(y_i-\mu_{k})^2}{2 \delta_{k}^{2}}\bigg]\bigg) \end{aligned} log(P(y,γ∣θ))​=log(k=1∏K​j=1∏N​[ak​ϕ(y∣θk​)]γjk​)=K∑k=1​(j=1∑N​(γjk​)log(ak​)+j=1∑N​(γjk​)[log(2π

​1​)−log(δk​)−2δk2​(yi​−μk​)2​])​

  • E E E 步 :

    Q ( θ . θ i ) = E [ l o g ( P ( y , γ ∣ θ ) ) ] = ∑ K k = 1 ( ∑ j = 1 N ( E γ j k ) l o g ( a k ) + ∑ j = 1 N ( E γ j k ) [ l o g ( 1 2 π ) − l o g ( δ k ) − ( y i − μ k ) 2 2 δ k 2 ] ) \begin{aligned} Q(\theta.\theta^i) &= E[log(P(y,\gamma|\theta))]\\ &=\sum_{K}^{k=1}\bigg(\sum_{j=1}^{N}(E\gamma_{jk}) log(a_k)+\sum_{j=1}^{N}(E\gamma_{jk})\bigg[log(\frac{1}{\sqrt{2\pi}})-log(\delta_{k})-\frac{(y_i-\mu_{k})^2}{2 \delta_{k}^{2}}\bigg]\bigg) \end{aligned} Q(θ.θi)​=E[log(P(y,γ∣θ))]=K∑k=1​(j=1∑N​(Eγjk​)log(ak​)+j=1∑N​(Eγjk​)[log(2π

    ​1​)−log(δk​)−2δk2​(yi​−μk​)2​])​

    其中我們定義 γ j k ^ \hat{\gamma_{jk}} γjk​^​:

    γ j k ^ = E ( γ j k ∣ y , θ ) = a k ϕ ( y i ∣ θ k ) ∑ k = 1 K a k ϕ ( y i ∣ θ k ) j = 1 , 2 , . . , N ; k = 1 , 2 , . . . , K n k = ∑ j = i N E γ j k \hat{\gamma_{jk}} = E(\gamma_{jk}|y,\theta)=\frac{a_k\phi(y_i|\theta_{k})}{\sum_{k=1}^{K}a_k\phi(y_i|\theta_{k}) }\\ j=1,2,..,N;k=1,2,...,K\\ n_k=\sum_{j=i}^{N}E\gamma_{jk} γjk​^​=E(γjk​∣y,θ)=∑k=1K​ak​ϕ(yi​∣θk​)ak​ϕ(yi​∣θk​)​j=1,2,..,N;k=1,2,...,Knk​=j=i∑N​Eγjk​

    于是化簡得到:

    Q ( θ . θ i ) = ∑ K k = 1 ( n k l o g ( a k ) + ∑ j = 1 N ( E γ j k ) [ l o g ( 1 2 π ) − l o g ( δ k ) − ( y i − μ k ) 2 2 δ k 2 ] ) \begin{aligned} Q(\theta.\theta^i) &= \sum_{K}^{k=1}\bigg(n_k log(a_k)+\sum_{j=1}^{N}(E\gamma_{jk})\bigg[log(\frac{1}{\sqrt{2\pi}})-log(\delta_{k})-\frac{(y_i-\mu_{k})^2}{2 \delta_{k}^{2}}\bigg]\bigg) \end{aligned} Q(θ.θi)​=K∑k=1​(nk​log(ak​)+j=1∑N​(Eγjk​)[log(2π

    ​1​)−log(δk​)−2δk2​(yi​−μk​)2​])​

E E E步在代碼設計上隻有 γ j k ^ \hat{\gamma_{jk}} γjk​^​有用,用于M步的計算。

  • M M M步,

    θ i + 1 = a r g max ⁡ θ Q ( θ , θ i ) \theta^{i+1}=arg \max_{\theta}Q(\theta,\theta^i) θi+1=argθmax​Q(θ,θi)

    對 Q ( θ , θ i ) Q(\theta,\theta^i) Q(θ,θi)求導,得到每個未知量的偏導,使其偏導等于0,求解得到:

μ k ^ = ∑ j = 1 N γ j k ^ y i ∑ j = 1 N γ j k ^ δ k ^ = ∑ j = 1 N γ j k ^ ( y i − μ k ) 2 ∑ j = 1 N γ j k ^ a k ^ = ∑ j = 1 N γ j k ^ N \hat{\mu_k}=\frac{\sum_{j=1}^{N}\hat{\gamma_{jk}}y_i}{\sum_{j=1}^{N}\hat{\gamma_{jk}}} \\ \\ \hat{\delta_k}=\frac{\sum_{j=1}^{N}\hat{\gamma_{jk}}(y_i-\mu_k)^2}{\sum_{j=1}^{N}\hat{\gamma_{jk}}} \hat{a_k}=\frac{\sum_{j=1}^{N}\hat{\gamma_{jk}} }{N} μk​^​=∑j=1N​γjk​^​∑j=1N​γjk​^​yi​​δk​^​=∑j=1N​γjk​^​∑j=1N​γjk​^​(yi​−μk​)2​ak​^​=N∑j=1N​γjk​^​​

給一個初始值,來回疊代就可以求得值内容。這一塊主要用到了 Q ( θ . θ i ) Q(\theta.\theta^i) Q(θ.θi)的導

數,并且用到了E步的 γ j k ^ \hat{\gamma_{jk}} γjk​^​。

7、代碼

import numpy as np
import random
import math
import time
           
'''
資料集:僞造資料集(兩個高斯分布混合)
資料集長度:1000
------------------------------
運作結果:
----------------------------
the Parameters set is:
alpha0:0.3, mu0:0.7, sigmod0:-2.0, alpha1:0.5, mu1:0.5, sigmod1:1.0
----------------------------
the Parameters predict is:
alpha0:0.4, mu0:0.6, sigmod0:-1.7, alpha1:0.7, mu1:0.7, sigmod1:0.9
----------------------------
'''

def loadData(mu0, sigma0, mu1, sigma1, alpha0, alpha1):
    '''
    初始化資料集
    這裡通過服從高斯分布的随機函數來僞造資料集
    :param mu0: 高斯0的均值
    :param sigma0: 高斯0的方差
    :param mu1: 高斯1的均值
    :param sigma1: 高斯1的方差
    :param alpha0: 高斯0的系數
    :param alpha1: 高斯1的系數
    :return: 混合了兩個高斯分布的資料
    '''
    # 定義資料集長度為1000
    length = 1000

    # 初始化第一個高斯分布,生成資料,資料長度為length * alpha系數,以此來
    # 滿足alpha的作用
    data0 = np.random.normal(mu0, sigma0, int(length * alpha0))
    # 第二個高斯分布的資料
    data1 = np.random.normal(mu1, sigma1, int(length * alpha1))

    # 初始化總資料集
    # 兩個高斯分布的資料混合後會放在該資料集中傳回
    dataSet = []
    # 将第一個資料集的内容添加進去
    dataSet.extend(data0)
    # 添加第二個資料集的資料
    dataSet.extend(data1)
    # 對總的資料集進行打亂(其實不打亂也沒事,隻不過打亂一下直覺上讓人感覺已經混合了
    # 讀者可以将下面這句話屏蔽以後看看效果是否有差别)
    random.shuffle(dataSet)

    #傳回僞造好的資料集
    return dataSet
           
# 高斯分布公式
def calcGauss(dataSetArr, mu, sigmod):
    '''
    根據高斯密度函數計算值
    依據:“9.3.1 高斯混合模型” 式9.25
    注:在公式中y是一個實數,但是在EM算法中(見算法9.2的E步),需要對每個j
    都求一次yjk,在本執行個體中有1000個可觀測資料,是以需要計算1000次。考慮到
    在E步時進行1000次高斯計算,程式上比較不簡潔,是以這裡的y是向量,在numpy
    的exp中如果exp内部值為向量,則對向量中每個值進行exp,輸出仍是向量的形式。
    是以使用向量的形式1次計算即可将所有計算結果得出,程式上較為簡潔
    
    :param dataSetArr: 可觀測資料集
    :param mu: 均值
    :param sigmod: 方差
    :return: 整個可觀測資料集的高斯分布密度(向量形式)
    '''
    # 計算過程就是依據式9.25寫的,沒有别的花樣
    result = (1 / (math.sqrt(2*math.pi)*sigmod**2)) * np.exp(-1 * (dataSetArr-mu) * (dataSetArr-mu) / (2*sigmod**2))
    # 傳回結果
    return result


def E_step(dataSetArr, alpha0, mu0, sigmod0, alpha1, mu1, sigmod1):
    '''
    EM算法中的E步
    依據目前模型參數,計算分模型k對觀資料y的響應度
    :param dataSetArr: 可觀測資料y
    :param alpha0: 高斯模型0的系數
    :param mu0: 高斯模型0的均值
    :param sigmod0: 高斯模型0的方差
    :param alpha1: 高斯模型1的系數
    :param mu1: 高斯模型1的均值
    :param sigmod1: 高斯模型1的方差
    :return: 兩個模型各自的響應度
    '''
    # 計算y0的響應度
    # 先計算模型0的響應度的分子
    gamma0 = alpha0 * calcGauss(dataSetArr, mu0, sigmod0)
    #print("gamma0=",gamma0.shape) # 1000, 維向量
    # 模型1響應度的分子
    gamma1 = alpha1 * calcGauss(dataSetArr, mu1, sigmod1)

    # 兩者相加為E步中的分布
    sum = gamma0 + gamma1
    # 各自相除,得到兩個模型的響應度
    gamma0 = gamma0 / sum
    gamma1 = gamma1 / sum

    # 傳回兩個模型響應度
    return gamma0, gamma1

def M_step(muo, mu1, gamma0, gamma1, dataSetArr):
    # 依據算法9.2計算各個值
    # 這裡沒什麼花樣,對照書本公式看看這裡就好了
    
    # np.dot 點積:[1,2] [2,3] = [2,6]
    mu0_new = np.dot(gamma0, dataSetArr) / np.sum(gamma0)
    mu1_new = np.dot(gamma1, dataSetArr) / np.sum(gamma1)

    # math.sqrt  平方根 
    sigmod0_new = math.sqrt(np.dot(gamma0, (dataSetArr - muo)**2) / np.sum(gamma0))
    sigmod1_new = math.sqrt(np.dot(gamma1, (dataSetArr - mu1)**2) / np.sum(gamma1))

    alpha0_new = np.sum(gamma0) / len(gamma0)
    alpha1_new = np.sum(gamma1) / len(gamma1)

    # 将更新的值傳回
    return mu0_new, mu1_new, sigmod0_new, sigmod1_new, alpha0_new, alpha1_new


## 訓練主函數
def EM_Train(dataSetList, iter=500):
    '''
    根據EM算法進行參數估計
    算法依據“9.3.2 高斯混合模型參數估計的EM算法” 算法9.2
    :param dataSetList:資料集(可觀測資料)
    :param iter: 疊代次數
    :return: 估計的參數
    '''
    # 将可觀測資料y轉換為數組形式,主要是為了友善後續運算
    dataSetArr = np.array(dataSetList)

    # 步驟1:對參數取初值,開始疊代
    alpha0 = 0.5
    mu0 = 0
    sigmod0 = 1
    alpha1 = 0.5
    mu1 = 1
    sigmod1 = 1

    # 開始疊代
    step = 0
    while (step < iter):
        # 每次進入一次疊代後疊代次數加1
        step += 1
        # 步驟2:E步:依據目前模型參數,計算分模型k對觀測資料y的響應度
        gamma0, gamma1 = E_step(dataSetArr, alpha0, mu0, sigmod0, alpha1, mu1, sigmod1)
        # 步驟3:M步
        mu0, mu1, sigmod0, sigmod1, alpha0, alpha1 = M_step(mu0, mu1, gamma0, gamma1, dataSetArr)

    # 疊代結束後将更新後的各參數傳回
    return alpha0, mu0, sigmod0, alpha1, mu1, sigmod1
           
if __name__ == '__main__':
    start = time.time()

    # 設定兩個高斯模型進行混合,這裡是初始化兩個模型各自的參數
    # 見“9.3 EM算法在高斯混合模型學習中的應用”
    # alpha是“9.3.1 高斯混合模型” 定義9.2中的系數α
    # mu0是均值μ
    # sigmod是方差σ
    # 在設定上兩個alpha的和必須為1,其他沒有什麼具體要求,符合高斯定義就可以
    
    alpha0 = 0.3  # 系數α
    mu0 = -2  # 均值μ
    sigmod0 = 0.5  # 方差σ

    alpha1 = 0.7  # 系數α
    mu1 = 0.5  # 均值μ
    sigmod1 = 1  # 方差σ

    # 初始化資料集
    dataSetList = loadData(mu0, sigmod0, mu1, sigmod1, alpha0, alpha1)

    #列印設定的參數
    print('---------------------------')
    print('the Parameters set is:')
    print('alpha0:%.1f, mu0:%.1f, sigmod0:%.1f, alpha1:%.1f, mu1:%.1f, sigmod1:%.1f' % (
        alpha0, alpha1, mu0, mu1, sigmod0, sigmod1
    ))

    # 開始EM算法,進行參數估計
    alpha0, mu0, sigmod0, alpha1, mu1, sigmod1 = EM_Train(dataSetList)

    # 列印參數預測結果
    print('----------------------------')
    print('the Parameters predict is:')
    print('alpha0:%.1f, mu0:%.1f, sigmod0:%.1f, alpha1:%.1f, mu1:%.1f, sigmod1:%.1f' % (
        alpha0, alpha1, mu0, mu1, sigmod0, sigmod1
    ))

    # 列印時間
    print('----------------------------')
    print('time span:', time.time() - start)
           
---------------------------
the Parameters set is:
alpha0:0.3, mu0:0.7, sigmod0:-2.0, alpha1:0.5, mu1:0.5, sigmod1:1.0
----------------------------
the Parameters predict is:
alpha0:0.4, mu0:0.6, sigmod0:-1.8, alpha1:0.7, mu1:0.7, sigmod1:0.9
----------------------------
time span: 0.28402137756347656
           
import math
import copy
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
 
#生成随機資料,4個高斯模型
def generate_data(sigma,N,mu1,mu2,mu3,mu4,alpha):
    global X                  #可觀測資料集
    X = np.zeros((N, 2))       # 初始化X,2行N列。2維資料,N個樣本
    X=np.matrix(X)
    global mu                 #随機初始化mu1,mu2,mu3,mu4
    mu = np.random.random((4,2))
    mu=np.matrix(mu)
    global excep              #期望第i個樣本屬于第j個模型的機率的期望
    excep=np.zeros((N,4))
    global alpha_             #初始化混合項系數
    alpha_=[0.25,0.25,0.25,0.25]
    for i in range(N):
        if np.random.random(1) < 0.1:  # 生成0-1之間随機數
            X[i,:]  = np.random.multivariate_normal(mu1, sigma, 1)     #用第一個高斯模型生成2維資料
        elif 0.1 <= np.random.random(1) < 0.3:
            X[i,:] = np.random.multivariate_normal(mu2, sigma, 1)      #用第二個高斯模型生成2維資料
        elif 0.3 <= np.random.random(1) < 0.6:
            X[i,:] = np.random.multivariate_normal(mu3, sigma, 1)      #用第三個高斯模型生成2維資料
        else:
            X[i,:] = np.random.multivariate_normal(mu4, sigma, 1)      #用第四個高斯模型生成2維資料
 
    print("可觀測資料:\n",X)       #輸出可觀測樣本
    print("初始化的mu1,mu2,mu3,mu4:",mu)      #輸出初始化的mu


# E 期望
#  \hat{\gamma_{jk}}
def e_step(sigma,k,N):
    global X
    global mu
    global excep
    global alpha_
    for i in range(N):
        denom=0
        for j in range(0,k):
            #  sigma.I 表示矩陣的逆矩陣
            # np.transpose :矩陣轉置   np.linalg.det():矩陣求行列式
            denom += alpha_[j]*  math.exp(-(X[i,:]-mu[j,:])*sigma.I*np.transpose(X[i,:]-mu[j,:]))  /np.sqrt(np.linalg.det(sigma))       #分母
        for j in range(0,k):
            numer = math.exp(-(X[i,:]-mu[j,:])*sigma.I*np.transpose(X[i,:]-mu[j,:]))/np.sqrt(np.linalg.det(sigma))        #分子
            excep[i,j]=alpha_[j]*numer/denom      #求期望
    print("隐藏變量:\n",excep)

    
def m_step(k,N):
    global excep
    global X
    global alpha_
    for j in range(0,k):
        denom=0   #分母
        numer=0   #分子
        for i in range(N):
            numer += excep[i,j]*X[i,:]
            denom += excep[i,j]
        mu[j,:] = numer/denom    #求均值
        alpha_[j]=denom/N        #求混合項系數

        #     #可視化結果
def plotShow():
    # 畫生成的原始資料
    plt.subplot(221)
    plt.scatter(X[:,0].tolist(), X[:,1].tolist(),c='b',s=25,alpha=0.4,marker='o')    #T散點顔色,s散點大小,alpha透明度,marker散點形狀
    plt.title('random generated data')
    #畫分類好的資料
    plt.subplot(222)
    plt.title('classified data through EM')
    order=np.zeros(N)
    color=['b','r','k','y']
    for i in range(N):
        for j in range(k):
            if excep[i,j]==max(excep[i,:]):
                order[i]=j     #選出X[i,:]屬于第幾個高斯模型
            probility[i] += alpha_[int(order[i])]*math.exp(-(X[i,:]-mu[j,:])*sigma.I*np.transpose(X[i,:]-mu[j,:]))/(np.sqrt(np.linalg.det(sigma))*2*np.pi)    #計算混合高斯分布
        plt.scatter(X[i, 0], X[i, 1], c=color[int(order[i])], s=25, alpha=0.4, marker='o')      #繪制分類後的散點圖
    #繪制三維圖像
    ax = plt.subplot(223, projection='3d')
    plt.title('3d view')
    for i in range(N):
        ax.scatter(X[i, 0], X[i, 1], probility[i], c=color[int(order[i])])
    plt.show()
           
if __name__ == '__main__':
    iter_num=1000  #疊代次數
    N=500         #樣本數目
    k=4            #高斯模型數
    probility = np.zeros(N)    #混合高斯分布
    u1=[5,35]
    u2=[30,40]
    u3=[20,20]
    u4=[45,15]
    sigma=np.matrix([[30, 0], [0, 30]])               #協方差矩陣
    alpha=[0.1,0.2,0.3,0.4]         #混合項系數
    generate_data(sigma,N,u1,u2,u3,u4,alpha)     #生成資料
    #疊代計算
    for i in range(iter_num):
        err=0     #均值誤差
        err_alpha=0    #混合項系數誤差
        Old_mu = copy.deepcopy(mu)
        Old_alpha = copy.deepcopy(alpha_)
        
        e_step(sigma,k,N)     # E步
        m_step(k,N)           # M步
        
        print("疊代次數:",i+1)
        print("估計的均值:",mu)
        print("估計的混合項系數:",alpha_)
        for z in range(k):
            err += (abs(Old_mu[z,0]-mu[z,0])+abs(Old_mu[z,1]-mu[z,1]))      #計算誤差
            err_alpha += abs(Old_alpha[z]-alpha_[z])
        if (err<=0.001) and (err_alpha<0.001):     #達到精度退出疊代
            print(err,err_alpha)
            break
           
109e-25 9.99488022e-01 1.29388453e-22]
 ...
 [1.99757796e-05 4.17687378e-28 9.99980024e-01 4.95350725e-25]
 [2.45358493e-10 2.50870039e-01 2.11657235e-12 7.49129961e-01]
 [1.48152801e-01 5.42344357e-09 8.51846979e-01 2.14816543e-07]]
疊代次數: 3
估計的均值: [[18.16352738 18.51789671]
 [44.0188727  11.57156255]
 [18.79097885 37.11812726]
 [43.94919498 19.236304  ]]
估計的混合項系數: [0.1701072194309895, 0.22034483674648622, 0.31885414724744104, 0.2906937965750829]
隐藏變量:
 [[1.26315958e-09 6.52549567e-05 9.48508268e-06 9.99925259e-01]
 [4.33200608e-20 6.66643753e-01 8.84281679e-27 3.33356247e-01]
 [4.38259698e-04 3.97669452e-27 9.99561740e-01 8.38008244e-23]
 ...
 [8.01040377e-06 1.76520329e-30 9.99991990e-01 1.99052502e-25]
 [4.75000771e-09 1.81904192e-01 9.33028509e-14 8.18095803e-01]
 [7.25834510e-01 5.94754597e-10 2.74165248e-01 2.41114905e-07]]
疊代次數: 4
估計的均值: [[18.77195835 19.4084763 ]
 [44.35394848 11.16002574]
 [18.52344821 38.16632375]
 [43.8465529  19.80397859]]
估計的混合項系數: [0.19991031621232774, 0.22656512119250755, 0.29192402223384517, 0.2816005403613193]
隐藏變量:
 [[5.57753418e-09 2.07735437e-05 2.58393783e-06 9.99976637e-01]
 [1.58486276e-19 7.80152724e-01 7.16119250e-28 2.19847276e-01]
 [8.15875236e-04 1.02172345e-27 9.99184125e-01 2.02877035e-22]
 ...
 [1.33772058e-05 3.13492540e-31 9.99986623e-01 4.28060885e-25]
 [1.48788710e-08 1.64964916e-01 1.56538613e-14 8.35035069e-01]
 [9.33691079e-01 1.50659883e-10 6.63087051e-02 2.16111631e-07]]
疊代次數: 5
估計的均值: [[18.86897121 19.80983479]
 [44.56169652 11.17640642]
 [18.43531761 38.55845461]
 [43.69853006 20.09846373]]
估計的混合項系數: [0.21166414491602825, 0.23428522558935938, 0.2805981836543999, 0.27345244584021244]
隐藏變量:
 [[7.75949130e-09 1.65360448e-05 1.48091767e-06 9.99981975e-01]
 [1.62755833e-19 8.50872210e-01 2.58251229e-28 1.49127790e-01]
 [1.25405536e-03 6.78781564e-28 9.98745945e-01 4.09396084e-22]
 ...
 [2.02893540e-05 1.87187802e-31 9.99979711e-01 8.50301544e-25]
 [1.81718076e-08 1.72899671e-01 7.99722795e-15 8.27100310e-01]
 [9.63576417e-01 9.22834832e-11 3.64233473e-02 2.35627153e-07]]
疊代次數: 6
估計的均值: [[18.82268462 19.96689266]
 [44.70364998 11.30798903]
 [18.43338711 38.7100299 ]
 [43.52686588 20.31945405]]
估計的混合項系數: [0.21576276515752263, 0.24333853386050489, 0.27603036091323985, 0.2648683400687325]
隐藏變量:
 [[7.08522242e-09 1.70377434e-05 1.15124494e-06 9.99981804e-01]
 [1.23603423e-19 8.98667049e-01 1.75283637e-28 1.01332951e-01]
 [1.64283218e-03 6.30379751e-28 9.98357168e-01 8.01202925e-22]
 ...
 [2.67221677e-05 1.71040145e-31 9.99973278e-01 1.71471963e-24]
 [1.73252986e-08 1.90047650e-01 6.52073031e-15 8.09952332e-01]
 [9.70516115e-01 8.22114407e-11 2.94835862e-02 2.98779708e-07]]
疊代次數: 7
估計的均值: [[18.74583197 20.02884824]
 [44.81488522 11.46884818]
 [18.4450482  38.77773924]
 [43.33653548 20.53323269]]
估計的混合項系數: [0.21702977097992085, 0.2530421753812168, 0.2738802688037561, 0.25604778483510604]
隐藏變量:
 [[5.63996305e-09 1.86320775e-05 9.59269127e-07 9.99980403e-01]
 [8.94749080e-20 9.32329906e-01 1.43098252e-28 6.76700941e-02]
 [1.95155889e-03 6.48111666e-28 9.98048441e-01 1.58647526e-21]
 ...
 [3.19538273e-05 1.77997659e-31 9.99968046e-01 3.56717194e-24]
 [1.57111530e-08 2.11986251e-01 6.13598326e-15 7.88013733e-01]
 [9.72652796e-01 8.35351570e-11 2.73467963e-02 4.07685738e-07]]
疊代次數: 8
估計的均值: [[18.67126246 20.05487171]
 [44.90781318 11.6377031 ]
 [18.44899629 38.81463704]
 [43.13182143 20.75514003]]
估計的混合項系數: [0.21727589377491396, 0.2630864447889952, 0.2725762790823613, 0.24706138235372946]
隐藏變量:
 [[4.33512973e-09 2.06162523e-05 7.99681053e-07 9.99978580e-01]
 [6.60274572e-20 9.55716869e-01 1.21413821e-28 4.42831312e-02]
 [2.18299241e-03 6.92306613e-28 9.97817008e-01 3.19465252e-21]
 ...
 [3.59061046e-05 1.94252105e-31 9.99964094e-01 7.61475711e-24]
 [1.43233910e-08 2.37888885e-01 5.98344110e-15 7.62111100e-01]
 [9.73461915e-01 8.94107703e-11 2.65375090e-02 5.76022190e-07]]
疊代次數: 9
估計的均值: [[18.60662308 20.0659511 ]
 [44.98532723 11.80879848]
 [18.44104192 38.83873899]
 [42.91603272 20.98804571]]
估計的混合項系數: [0.21712896438460902, 0.27331295159086055, 0.2715793638895945, 0.23797872013493607]
隐藏變量:
 [[3.32494915e-09 2.28300068e-05 6.56384756e-07 9.99976510e-01]
 [5.04703944e-20 9.71506505e-01 1.02881938e-28 2.84934947e-02]
 [2.34497721e-03 7.58484178e-28 9.97655023e-01 6.52363658e-21]
 ...
 [3.86534675e-05 2.18193001e-31 9.99961347e-01 1.65674976e-23]
 [1.32954744e-08 2.67889357e-01 5.85243377e-15 7.32110630e-01]
 [9.73860916e-01 9.84812890e-11 2.61382558e-02 8.28393459e-07]]
疊代次數: 10
估計的均值: [[18.55234695 20.06961945]
 [45.04793103 11.98000757]
 [18.42142179 38.85657654]
 [42.69151146 21.23228625]]
估計的混合項系數: [0.21680857794760375, 0.28362239820277213, 0.2706973336071755, 0.22887169024244847]
隐藏變量:
 [[2.56846009e-09 2.52331624e-05 5.29466128e-07 9.99974235e-01]
 [4.00283252e-20 9.81905060e-01 8.63386584e-29 1.80949405e-02]
 [2.44517897e-03 8.50905794e-28 9.97554821e-01 1.34632159e-20]
 ...
 [4.03036179e-05 2.51362498e-31 9.99959696e-01 3.65580001e-23]
 [1.25713047e-08 3.02326073e-01 5.68627712e-15 6.97673915e-01]
 [9.74132493e-01 1.10830451e-10 2.58663046e-02 1.20278737e-06]]
疊代次數: 11
估計的均值: [[18.50724485 20.06880426]
 [45.09606103 12.15021938]
 [18.39139765 38.8707027 ]
 [42.45986149 21.48793057]]
估計的混合項系數: [0.21639697380634126, 0.2939500015477026, 0.2698585292298478, 0.21979449541610818]
隐藏變量:
 [[2.00246431e-09 2.78122912e-05 4.20405230e-07 9.99971765e-01]
 [3.28711691e-20 9.88632587e-01 7.17769313e-29 1.13674132e-02]
 [2.49104055e-03 9.76948417e-28 9.97508959e-01 2.80090380e-20]
 ...
 [4.09804624e-05 2.96619542e-31 9.99959020e-01 8.15615131e-23]
 [1.20778107e-08 3.41522297e-01 5.47384527e-15 6.58477691e-01]
 [9.74369932e-01 1.27049046e-10 2.56283121e-02 1.75546700e-06]]
疊代次數: 12
估計的均值: [[18.47004434 20.06480445]
 [45.13050115 12.31885085]
 [18.35250595 38.88214567]
 [42.22214886 21.75517746]]
估計的混合項系數: [0.21592942953357389, 0.3042578628245726, 0.26903800854787174, 0.2107746990939817]
隐藏變量:
 [[1.57588390e-09 3.05670651e-05 3.29470049e-07 9.99969102e-01]
 [2.78704171e-20 9.92931244e-01 5.92702541e-29 7.06875608e-03]
 [2.49030546e-03 1.14680916e-27 9.97509695e-01 5.86494436e-20]
 ...
 [4.08247349e-05 3.58089326e-31 9.99959175e-01 1.83668047e-22]
 [1.17558837e-08 3.85694954e-01 5.21976434e-15 6.14305035e-01]
 [9.74602361e-01 1.48038953e-10 2.53950706e-02 2.56832101e-06]]
疊代次數: 13
估計的均值: [[18.43968835 20.05830737]
 [45.15240169 12.48582094]
 [18.30630292 38.89130406]
 [41.97867645 22.03456432]]
估計的混合項系數: [0.21542454899548702, 0.31453315962782175, 0.2682292940918509, 0.20181299728484045]
隐藏變量:
 [[1.25145040e-09 3.35134894e-05 2.55576942e-07 9.99966230e-01]
 [2.43257557e-20 9.95654010e-01 4.87612671e-29 4.34598950e-03]
 [2.45102120e-03 1.37443578e-27 9.97548979e-01 1.23636368e-19]
 ...
 [3.99882159e-05 4.41603778e-31 9.99960012e-01 4.17629061e-22]
 [1.15597030e-08 4.34886362e-01 4.93338841e-15 5.65113626e-01]
 [9.74835658e-01 1.75012439e-10 2.51605808e-02 3.76120952e-06]]
疊代次數: 14
估計的均值: [[18.41529176 20.04973825]
 [45.16329282 12.65161672]
 [18.25398287 38.89835415]
 [41.72817629 22.32757161]]
估計的混合項系數: [0.2148937254777755, 0.32479172342617646, 0.2674282625187707, 0.19288628857727688]
隐藏變量:
 [[1.00203924e-09 3.66828580e-05 1.96611872e-07 9.99963120e-01]
 [2.17881924e-20 9.97369028e-01 4.00399568e-29 2.63097155e-03]
 [2.38077150e-03 1.67847641e-27 9.97619228e-01 2.63491133e-19]
 ...
 [3.86153668e-05 5.55249188e-31 9.99961385e-01 9.63374719e-22]
 [1.14519241e-08 4.88931587e-01 4.62224150e-15 5.11068401e-01]
 [9.75069710e-01 2.09550122e-10 2.49247721e-02 5.51796340e-06]]
疊代次數: 15
估計的均值: [[18.39604475 20.03935155]
 [45.1650092  12.81738281]
 [18.19585438 38.9034572 ]
 [41.46665894 22.63769003]]
估計的混合項系數: [0.21434245651887535, 0.33508661705868736, 0.26662147175989015, 0.18394945466254706]
隐藏變量:
 [[8.07510900e-10 4.01140545e-05 1.49923668e-07 9.99959735e-01]
 [1.99581484e-20 9.98445606e-01 3.27860810e-29 1.55439438e-03]
 [2.28550256e-03 2.08332261e-27 9.97714497e-01 5.74066890e-19]
 ...
 [3.68194830e-05 7.10050154e-31 9.99963181e-01 2.28225102e-21]
 [1.14003203e-08 5.47500270e-01 4.28792338e-15 4.52499719e-01]
 [9.75307553e-01 2.53738445e-10 2.46843001e-02 8.14681041e-06]]
疊代次數: 16
估計的均值: [[18.3811661  20.02720964]
 [45.15943944 12.98503939]
 [18.13083281 38.90682006]
 [41.18618677 22.9718833 ]]
估計的混合項系數: [0.21376860251851398, 0.34552043888529943, 0.2657782243895824, 0.17493273420660424]
隐藏變量:
 [[6.52756515e-10 4.38455031e-05 1.12848348e-07 9.99956041e-01]
 [1.86299137e-20 9.99118094e-01 2.66409166e-29 8.81906490e-04]
 [2.16853172e-03 2.62113303e-27 9.97831468e-01 1.30598917e-18]
 ...
 [3.46647301e-05 9.21150233e-31 9.99965335e-01 5.68321357e-21]
 [1.13758528e-08 6.10202450e-01 3.92493835e-15 3.89797538e-01]
 [9.75559746e-01 3.10497833e-10 2.44280307e-02 1.22230673e-05]]
疊代次數: 17
估計的均值: [[18.36993381 20.01310735]
 [45.14815075 13.15744683]
 [18.05577656 38.9086547 ]
 [40.87328978 23.34293604]]
估計的混合項系數: [0.2131600168729003, 0.35626081901739354, 0.2648413573866065, 0.16573780672309965]
隐藏變量:
 [[5.26436161e-10 4.79022094e-05 8.30565521e-08 9.99952014e-01]
 [1.76628844e-20 9.99532355e-01 2.12646771e-29 4.67645053e-04]
 [2.02972827e-03 3.33586410e-27 9.97970272e-01 3.21349033e-18]
 ...
 [3.21535893e-05 1.20999687e-30 9.99967846e-01 1.54646299e-20]
 [1.13511042e-08 6.76686747e-01 3.52082494e-15 3.23313241e-01]
 [9.75847530e-01 3.84299282e-10 2.41335249e-02 1.89451288e-05]]
疊代次數: 18
估計的均值: [[18.36181366 19.99642836]
 [45.13194653 13.33864512]
 [17.96377638 38.90912525]
 [40.50582685 23.77454953]]
估計的混合項系數: [0.2124913897948912, 0.3675621613044594, 0.26370196128983275, 0.15624448761081694]
隐藏變量:
 [[4.19922618e-10 5.22420603e-05 5.86516632e-08 9.99947699e-01]
 [1.69701495e-20 9.99778714e-01 1.63508976e-29 2.21286364e-04]
 [1.86392256e-03 4.28888969e-27 9.98136077e-01 9.05348630e-18]
 ...
 [2.92025212e-05 1.60753444e-30 9.99970797e-01 4.89927310e-20]
 [1.12987231e-08 7.46632049e-01 3.05415071e-15 2.53367940e-01]
 [9.76212904e-01 4.82680476e-10 2.37559538e-02 3.11418008e-05]]
疊代次數: 19
估計的均值: [[18.35681327 19.9758038 ]
 [45.11022195 13.5342415 ]
 [17.839336   38.90841955]
 [40.04574027 24.31372585]]
估計的混合項系數: [0.21171708081122612, 0.3798029774350422, 0.26212757230080824, 0.14635236945292315]
隐藏變量:
 [[3.26448531e-10 5.65879332e-05 3.81892011e-08 9.99943374e-01]
 [1.65272622e-20 9.99914184e-01 1.16221597e-29 8.58156831e-05]
 [1.65724596e-03 5.56001153e-27 9.98342754e-01 3.21584021e-17]
 ...
 [2.55849565e-05 2.15571751e-30 9.99974415e-01 2.01405379e-19]
 [1.11898790e-08 8.19411738e-01 2.48916311e-15 1.80588250e-01]
 [9.76750970e-01 6.19685613e-10 2.31923497e-02 5.66799532e-05]]
疊代次數: 20
估計的均值: [[18.35648627 19.94820734]
 [45.07963445 13.75206894]
 [17.64657942 38.90744821]
 [39.4227667  25.05902192]]
估計的混合項系數: [0.21075285007135164, 0.39355359432195897, 0.2595886176932576, 0.13610493791343223]
隐藏變量:
 [[2.41495326e-10 6.01870692e-05 2.10556847e-08 9.99939792e-01]
 [1.64250170e-20 9.99977041e-01 6.94683064e-30 2.29592219e-05]
 [1.38295928e-03 7.22172148e-27 9.98617041e-01 1.71024949e-16]
 ...
 [2.08759976e-05 2.89573662e-30 9.99979124e-01 1.30711219e-18]
 [1.09897017e-08 8.92485967e-01 1.78134017e-15 1.07514022e-01]
 [9.77690313e-01 8.24734585e-10 2.21868903e-02 1.22795675e-04]]
疊代次數: 21
估計的均值: [[18.36660536 19.90656541]
 [45.03127824 14.00281454]
 [17.30995131 38.91116537]
 [38.50432922 26.20382356]]
估計的混合項系數: [0.20943400814020083, 0.40965464376871735, 0.25498184749745717, 0.1259295005936244]
隐藏變量:
 [[1.67529498e-10 6.28517695e-05 8.45500068e-09 9.99937140e-01]
 [1.70430798e-20 9.99997042e-01 2.84498981e-30 2.95805573e-06]
 [1.01525828e-03 9.27070051e-27 9.98984742e-01 1.82905382e-15]
 ...
 [1.47275211e-05 3.82648370e-30 9.99985272e-01 1.88328782e-17]
 [1.06494710e-08 9.56184743e-01 9.62044668e-16 4.38152461e-02]
 [9.79482232e-01 1.16821644e-09 2.01657836e-02 3.51983195e-04]]
疊代次數: 22
估計的均值: [[18.4030349  19.83466574]
 [44.94745748 14.2959934 ]
 [16.67817916 38.92306069]
 [37.07629283 28.08159159]]
估計的混合項系數: [0.20745407744439648, 0.42893462046643477, 0.24621129654889992, 0.11740000554026857]
隐藏變量:
 [[1.18683012e-10 7.06304507e-05 2.02387225e-09 9.99929367e-01]
 [1.95905735e-20 9.99999910e-01 5.30842348e-31 9.04070426e-08]
 [5.77411050e-04 1.13305828e-26 9.99422589e-01 5.64752611e-14]
 ...
 [7.76582346e-06 4.74365050e-30 9.99992234e-01 9.34395031e-16]
 [1.02268962e-08 9.92169380e-01 2.91298881e-16 7.83060967e-03]
 [9.82312078e-01 1.83257018e-09 1.62738856e-02 1.41403471e-03]]
疊代次數: 23
估計的均值: [[18.51849592 19.70480461]
 [44.80921449 14.61909758]
 [15.31980639 38.8245099 ]
 [35.06740991 31.22293553]]
估計的混合項系數: [0.2044725229955513, 0.45043667820476124, 0.22864847692404722, 0.11644232187563988]
隐藏變量:
 [[1.22875086e-10 1.13514021e-04 1.78061917e-10 9.99886486e-01]
 [2.93031047e-20 1.00000000e+00 1.64505565e-32 2.26990691e-10]
 [1.79158200e-04 1.02927982e-26 9.99820842e-01 3.39957023e-12]
 ...
 [2.10990012e-06 4.25600590e-30 9.99997890e-01 1.20454824e-13]
 [1.06564197e-08 9.99739434e-01 2.60625914e-17 2.60554917e-04]
 [9.84858990e-01 3.23083042e-09 1.01287725e-02 5.01223398e-03]]
疊代次數: 24
估計的均值: [[18.89941808 19.51408609]
 [44.6501912  14.92607734]
 [12.20919393 38.14395448]
 [32.68794651 35.38534057]]
估計的混合項系數: [0.20100301487178368, 0.4694081577166903, 0.1941118050316388, 0.13547702237988657]
隐藏變量:
 [[5.32493546e-10 5.18971706e-04 1.48110112e-12 9.99481028e-01]
 [8.93404004e-20 1.00000000e+00 7.23462179e-36 3.60726983e-14]
 [1.71356611e-05 3.97189252e-27 9.99982864e-01 7.38459465e-11]
 ...
 [1.66738385e-07 1.58376389e-30 9.99999833e-01 7.06706860e-12]
 [1.68644884e-08 9.99998967e-01 1.47637826e-19 1.01623174e-06]
 [9.91162982e-01 5.61103836e-09 2.86118085e-03 5.97583164e-03]]
疊代次數: 25
估計的均值: [[19.53357786 19.44803077]
 [44.59648418 15.06122594]
 [ 8.14381564 36.57630249]
 [30.97046753 38.06410028]]
估計的混合項系數: [0.20137882718845665, 0.4761886334701343, 0.15768655112596666, 0.16474598821544198]
隐藏變量:
 [[6.15468016e-09 2.91501402e-03 1.37170650e-15 9.97084980e-01]
 [4.56095129e-19 1.00000000e+00 2.80647341e-40 4.52626023e-17]
 [2.03336103e-06 1.25449533e-27 9.99997966e-01 2.63468251e-10]
 ...
 [1.83922001e-08 5.15177030e-31 9.99999982e-01 5.21064624e-11]
 [4.21000553e-08 9.99999945e-01 1.67057218e-22 1.27430535e-08]
 [9.96375872e-01 6.33279993e-09 2.84839736e-04 3.33928191e-03]]
疊代次數: 26
估計的均值: [[20.06702997 19.57904722]
 [44.58833841 15.1369444 ]
 [ 6.20404097 35.8843099 ]
 [29.90970494 38.76115659]]
估計的混合項系數: [0.20463623267757325, 0.4783664423947606, 0.14177558372240673, 0.17522174120525877]
隐藏變量:
 [[4.21893755e-08 8.63797905e-03 3.98201036e-17 9.91361979e-01]
 [1.62059405e-18 1.00000000e+00 1.27726426e-42 2.48479698e-18]
 [1.09336082e-06 1.10795049e-27 9.99998906e-01 1.05996856e-09]
 ...
 [9.58511976e-09 4.56276883e-31 9.99999990e-01 2.63313275e-10]
 [9.03775278e-08 9.99999907e-01 3.91423156e-24 2.28086452e-09]
 [9.96733641e-01 5.87482884e-09 5.19687206e-05 3.21438489e-03]]
疊代次數: 27
估計的均值: [[20.24201439 19.65972979]
 [44.59402075 15.16442444]
 [ 5.44492946 35.6822641 ]
 [29.4015747  38.9346748 ]]
估計的混合項系數: [0.20699103328358764, 0.47876992747693614, 0.13453034708589415, 0.17970869215358223]
隐藏變量:
 [[9.39242105e-08 1.40899513e-02 9.69688957e-18 9.85909955e-01]
 [2.41060616e-18 1.00000000e+00 1.27406692e-43 7.72877288e-19]
 [1.00347358e-06 1.13007855e-27 9.99998994e-01 2.34761172e-09]
 ...
 [8.64066754e-09 4.58406686e-31 9.99999991e-01 6.19562055e-10]
 [1.15994535e-07 9.99999883e-01 7.62623208e-25 1.20916815e-09]
 [9.96503210e-01 5.55928510e-09 2.32193051e-05 3.47356483e-03]]
疊代次數: 28
估計的均值: [[20.28395623 19.69167323]
 [44.59539886 15.17325418]
 [ 5.15460195 35.62222322]
 [29.1938406  38.98645001]]
估計的混合項系數: [0.20800631658867907, 0.478890306022767, 0.13144205711770116, 0.18166132027085297]
隐藏變量:
 [[1.24260513e-07 1.71166272e-02 5.61151975e-18 9.82883249e-01]
 [2.64300584e-18 1.00000000e+00 5.07545786e-44 4.93062316e-19]
 [1.01259051e-06 1.15678189e-27 9.99998984e-01 3.30593408e-09]
 ...
 [8.64749254e-09 4.64644697e-31 9.99999990e-01 8.87739360e-10]
 [1.23114277e-07 9.99999876e-01 3.94225204e-25 9.53963689e-10]
 [9.96343736e-01 5.45424802e-09 1.66835669e-05 3.63957543e-03]]
疊代次數: 29
估計的均值: [[20.29152008 19.70432467]
 [44.59495397 15.17664184]
 [ 5.03992837 35.60667956]
 [29.11184324 39.00052166]]
估計的混合項系數: [0.2083820860351442, 0.4789678579789313, 0.1301657952040437, 0.18248426078188087]
隐藏變量:
 [[1.36695438e-07 1.84515331e-02 4.49824981e-18 9.81548330e-01]
 [2.68378846e-18 1.00000000e+00 3.48128468e-44 4.17419906e-19]
 [1.02843990e-06 1.17406538e-27 9.99998968e-01 3.80174179e-09]
 ...
 [8.74544315e-09 4.68950954e-31 9.99999990e-01 1.02463530e-09]
 [1.24377477e-07 9.99999875e-01 3.00114835e-25 8.75890489e-10]
 [9.96256515e-01 5.42719439e-09 1.45520292e-05 3.72892800e-03]]
疊代次數: 30
估計的均值: [[20.29138883 19.70923731]
 [44.59443504 15.17794424]
 [ 4.99447461 35.60322454]
 [29.07997332 39.00372942]]
估計的混合項系數: [0.20851417991791724, 0.47900989912456504, 0.12964611099550127, 0.18282980996201692]
隐藏變量:
 [[1.41094496e-07 1.89835566e-02 4.11215258e-18 9.81016302e-01]
 [2.68049958e-18 1.00000000e+00 2.98575516e-44 3.92855553e-19]
 [1.03854907e-06 1.18280935e-27 9.99998957e-01 4.01999187e-09]
 ...
 [8.81343366e-09 4.71126521e-31 9.99999990e-01 1.08377759e-09]
 [1.24303524e-07 9.99999875e-01 2.68372661e-25 8.49950448e-10]
 [9.96213342e-01 5.42171816e-09 1.37626647e-05 3.77288980e-03]]
疊代次數: 31
估計的均值: [[20.29025119 19.71109645]
 [44.59412071 15.17843231]
 [ 4.97644294 35.60263679]
 [29.06758752 39.00417033]]
估計的混合項系數: [0.2085589231411389, 0.4790299865583697, 0.12943562485330548, 0.18297546544718607]
隐藏變量:
 [[1.42557397e-07 1.91882513e-02 3.96549337e-18 9.80811606e-01]
 [2.67243953e-18 1.00000000e+00 2.80608265e-44 3.84273603e-19]
 [1.04376148e-06 1.18678967e-27 9.99998952e-01 4.11023769e-09]
 ...
 [8.84964646e-09 4.72109551e-31 9.99999990e-01 1.10782309e-09]
 [1.24074876e-07 9.99999875e-01 2.56467286e-25 8.41050295e-10]
 [9.96193073e-01 5.42137741e-09 1.34565582e-05 3.79346482e-03]]
疊代次數: 32
估計的均值: [[20.28943145 19.71178251]
 [44.59396075 15.17861368]
 [ 4.96927478 35.60262912]
 [29.06274622 39.00405128]]
估計的混合項系數: [0.20857349179188478, 0.4790390253864386, 0.12935059448885797, 0.18303688833281867]
隐藏變量:
 [[1.43029763e-07 1.92664327e-02 3.90773588e-18 9.80733424e-01]
 [2.66697434e-18 1.00000000e+00 2.73680834e-44 3.81176089e-19]
 [1.04620060e-06 1.18851784e-27 9.99998950e-01 4.14677779e-09]
 ...
 [8.86687094e-09 4.72534249e-31 9.99999990e-01 1.11742294e-09]
 [1.23916868e-07 9.99999875e-01 2.51808978e-25 8.37953772e-10]
 [9.96183843e-01 5.42190448e-09 1.33357138e-05 3.80281627e-03]]
疊代次數: 33
估計的均值: [[20.2889821  19.71202963]
 [44.59388579 15.17868121]
 [ 4.96641783 35.60269248]
 [29.06084043 39.00389895]]
估計的混合項系數: [0.2085780168428415, 0.47904295402306146, 0.1293162845510407, 0.18306274458305627]
隐藏變量:
 [[1.43179227e-07 1.92963844e-02 3.88465402e-18 9.80703472e-01]
 [2.66404928e-18 1.00000000e+00 2.70942211e-44 3.80037819e-19]
 [1.04728237e-06 1.18924960e-27 9.99998949e-01 4.16149743e-09]
 ...
 [8.87457408e-09 4.72713541e-31 9.99999990e-01 1.12124579e-09]
 [1.23831579e-07 9.99999875e-01 2.49954741e-25 8.36868134e-10]
 [9.96179723e-01 5.42235530e-09 1.32876573e-05 3.80698370e-03]]
疊代次數: 34
估計的均值: [[20.28876184 19.71211652]
 [44.5938523  15.1787065 ]
 [ 4.9652762  35.60273756]
 [29.0600852  39.00380061]]
估計的混合項系數: [0.2085793312657214, 0.47904462357835054, 0.12930244486837542, 0.1830736002875526]
隐藏變量:
 [[1.43225507e-07 1.93079320e-02 3.87537253e-18 9.80691925e-01]
 [2.66263474e-18 1.00000000e+00 2.69848042e-44 3.79614284e-19]
 [1.04774684e-06 1.18955481e-27 9.99998948e-01 4.16742964e-09]
 ...
 [8.87789388e-09 4.72788173e-31 9.99999990e-01 1.12277204e-09]
 [1.23790110e-07 9.99999875e-01 2.49211329e-25 8.36485393e-10]
 [9.96177912e-01 5.42262253e-09 1.32684823e-05 3.80881389e-03]]
疊代次數: 35
估計的均值: [[20.28866011 19.7121463 ]
 [44.59383778 15.17871603]
 [ 4.9648189  35.60276153]
 [29.05978421 39.00374782]]
估計的混合項系數: [0.20857967121648013, 0.4790453223652429, 0.12929686139065455, 0.18307814502762204]
0.0009906091599516387 1.1166955441399562e-05
           
# 畫圖
plotShow()
           
機器學習系列(三)——EM算法1、前言2、EM算法引入3、EM算法4、推導逼近5、證明收斂6、高斯混合分布7、代碼

參考部落格

統計學習基礎

EM算法 - 期望極大算法