機器學(xué)習(xí)筆記 - 20. EM算法實踐(講師:鄒博)

主要內(nèi)容

2018-12-12 20_25_51-【鄒博_chinahadoop】機器學(xué)習(xí)升級版VII(七).png

多維高斯混合分布聚類

EM算法的聚類效果或許比K均值聚類好一些。


2018-12-12 20_27_30-【鄒博_chinahadoop】機器學(xué)習(xí)升級版VII(七).png

如圖,對于二維數(shù)據(jù)形成概率密度曲線,或者說等值線:


2018-12-12 20_28_09-【鄒博_chinahadoop】機器學(xué)習(xí)升級版VII(七).png

這個圖也說明,身高一定符合高斯分布,不一定對。
下圖表明,男性符合幾個混合高斯分布,女性符合幾個混合高斯分布


2018-12-25 19_27_33-【鄒博_chinahadoop】機器學(xué)習(xí)升級版VII(七).png

問答

問:歸一化的幾種優(yōu)劣之處?
答:比如做min-max,Scalar或標準版,如果數(shù)據(jù)服從均勻分布,可能做min-max好一些,但是如果數(shù)據(jù)服從高斯分布,可能標準化更好一些。
問:為什么鳶尾花沒有隱變量?鳶尾花也可能有某個未知的特征決定它的分類,任何分布不是都可能有隱變量么?
答:是的。也許花萼長寬與花瓣長寬并不是鳶尾花最重要特征。只是沒有提取
問:針對Kmeans多特征的情景,是不是用PCA處理以后,變成3維或2維,然后再用聚類的方式處理?
答:如果不需要做算法解釋的話,這么做是合理的;但是如果需要做算法解釋,建議不要用PCA,否則特征無法解釋。
比如如果數(shù)據(jù)是200維的,那么就針對這200維做K-Means聚類。

2018-12-25 19_29_38-【鄒博_chinahadoop】機器學(xué)習(xí)升級版VII(七).png

高斯分布的公式:
f(x) = 1/((2*π)0.5*σ) * e-(x-μ)2/(2*σ2)
其中:μ是均值,σ2是方差
如果f(x)是多元的,則得到多元高斯分布的概率密度函數(shù):f(x) = (2*π)-n/2*(Σ-1)n*e-(1/2)*(x-μ)TT*(x-μ)
此處Σ為協(xié)方差矩陣,首先它是一個nxn的對稱方陣。
這個Σ矩陣,在做混合高斯模型的時候,就出問題了。
比如現(xiàn)在做一個二聚類:
GMM:
N(μ1,Σ1)以及N(μ2,Σ2)
μ1與μ2都是n元的,而Σ1與Σ2都是矩陣:

  1. 如果矩陣是單位矩陣,如:
    1 0 0 0
    0 1 0 0
    0 0 1 0
    0 0 0 1,
    則σ 乘以單位矩陣得到:σ · I
    則圖形為球面的,即圖中的Spherical柱狀圖。
    參數(shù)有1個

  2. 如果矩陣是對角矩陣,如:
    σ12 0 0 0
    0 σ22 0 0
    0 0 σ32 0
    0 0 0 σ42
    則得到diag柱狀圖
    參數(shù)有n個

  3. 如果Σ1 = Σ2,則形成tied,即相互關(guān)聯(lián)的
    即圖中的tied柱狀圖
    參數(shù)有nxn個,準確的說是nx(n+1)個參數(shù)

  4. Σ1 與 Σ2沒有任何關(guān)聯(lián),我們求正常的EM算法
    理論上有2倍的nx(n+1)個參數(shù)
    即圖中的full柱狀圖

只要是做混合高斯模型,基本都會涉及這四個參數(shù)

問答

問:協(xié)方差矩陣為什么是個對稱陣?
答:因為這是定義。協(xié)方差矩陣是對稱陣。
問:怎么看出來是球形?
答:如果隨機變量是三元的,協(xié)方差矩陣如果三個方差都相等,即主軸,副軸與短軸都相等,得到球形
問:這四種情況怎么來的?
答:就是參數(shù)的設(shè)置。不管是做EM, sklearn的設(shè)置,還有隱馬爾科夫模型,如果說隱馬爾科夫模型符合高斯分布,那么就是高斯隱馬爾科夫模型,那個模型里面,不同的隱變量,如果是方差也有方差是否相等,方差是不是對角陣等情況
問:參數(shù)設(shè)置時,這幾種情況怎么選?
答:如果不知道怎么選,我們選full,即參數(shù)有2倍的nx(n+1)個,如果知道是對角陣,選diagonal,當(dāng)然都試一下也無妨。
問:EM算法是無監(jiān)督學(xué)習(xí)么?
答:EM算法可以看成無監(jiān)督學(xué)習(xí),雖然它是一種算法,是描述how而不是what。比如EM,MLE(最大似然估計),SGD(隨機梯度下降),L-BFGS(擬牛頓)都是講的how,即解決what的具體的方法。

模型選擇的準則

2018-12-25 20_12_36-【鄒博_chinahadoop】機器學(xué)習(xí)升級版VII(七).png

AIC解釋:
負的對數(shù)似然,就可以作為目標函數(shù);
但是我們不希望過擬合,所以需要在損失函數(shù)的前提下,加一個模型的復(fù)雜程度,比如模型的維度作為一個復(fù)雜標準。
哪個模型的這個值小,哪個模型最優(yōu)。
即2k就成為了正則項

BIC解釋:
樣本多可能帶來模型復(fù)雜度變化,如果兩個模型,一個樣本多,一個樣本少,在結(jié)果相同的情況下,樣本少的模型,看起來要好一些。
所以乘以與樣本個數(shù)有關(guān)的項,是有道理的。
即(lnn)k
BIC也可以認為貝葉斯信息準則
BIC看相對大小才有意義,絕對大小沒有意義

2018-12-25 20_27_03-【鄒博_chinahadoop】機器學(xué)習(xí)升級版VII(七).png

很顯然,當(dāng)參數(shù)選擇full的時候,錯誤率幾乎就是0,并且BIC最小。即選擇full參數(shù)的時候,模型是最優(yōu)的。

問題:為什么上圖右下角有一小塊紅色?
答:因為紅色方差大。

問答

問:上述的例子說明什么?
答:對于Σ1與Σ的選擇,引入更多的參數(shù)是否值得
問:平時計算EM的Σ不都是full類型么?
答:是的
問:樣本的個數(shù)n為什么越大,BIC就大呢?為什么和樣本個數(shù)n有關(guān)系?
答:能達到相同效果的時候,如果樣本比別人多,那么模型就沒有別人好。

2018-12-25 20_37_50-【鄒博_chinahadoop】機器學(xué)習(xí)升級版VII(七).png
2018-12-25 20_38_39-【鄒博_chinahadoop】機器學(xué)習(xí)升級版VII(七).png

如圖,上圖中三分類效果遠遠比二分類要差,所以可以加入一些先驗知識。
如圖,如果模型中的參數(shù)θ是未知定值,則可以通過最大似然估計(MLE)以及期望最大化(EM)去求。
如果θ也是變化的,且符合概率分布,即P(θ|α),這個是先驗分布,
對于樣本y,只要給出x就能算出y的分布,且是對于θ的概率分布,這個是似然分布
P(θ|x, y),則是屬于后驗分布


2019-01-03 19_41_14-【鄒博_chinahadoop】機器學(xué)習(xí)升級版VII(七).png

2019-01-03 19_48_16-【鄒博_chinahadoop】機器學(xué)習(xí)升級版VII(七).png

接著進行計算,如果θ有無窮多個,那么哪一個θ是最大的,就是我們想要求的:


2019-01-03 19_51_04-【鄒博_chinahadoop】機器學(xué)習(xí)升級版VII(七).png

θ這個值如何去求?
如圖,后驗分布,可以認為與似然分布 * 先驗分布成正比。

2019-01-03 19_53_45-【鄒博_chinahadoop】機器學(xué)習(xí)升級版VII(七).png

如果θ是Dirichlet(狄利克雷)分布,可以演化為:

Dirichlet分布(參數(shù)為α+x) = 多項分布 * Dirichlet分布(參數(shù)為α)

如果α采樣的值,拍腦袋選擇1,10,或100,
假定α=1,得到θ1,θ2,θ3,。。。,θ100
從而分別得到(x1, y1),(x2, y2), (x3, y3), 。。。, (x100, y100)
這些是我們看到的樣本數(shù)據(jù)。
其中每一個θ都是根據(jù)α采樣得到的,即每一個θ都是一個隨機變量,構(gòu)成了一個隨機過程,或者說構(gòu)成了Dirichlet過程


2019-01-03 20_07_03-【鄒博_chinahadoop】機器學(xué)習(xí)升級版VII(七).png

α取1的時候是最特殊的,即此時為均勻分布

再回頭來看這張圖,我們使用了符合高斯混合分布的模型,但是我們希望對參數(shù)做一個影響,那么使用Dirichlet過程+高斯混合分布模型,就得到DPGMM的模型,此時分類是合理的。
如圖左邊是正常的高斯混合模型,分類分錯了;
右邊是使用了貝葉斯的高斯混合模型,是特定的Dirichlet過程+高斯混合分布模型得到的結(jié)果,即使分類選擇3,但是得到的結(jié)果也是正確的

在sk-learn中, DPGMM的相關(guān)類為:BayesianGaussianMixture


2018-12-25 20_38_39-【鄒博_chinahadoop】機器學(xué)習(xí)升級版VII(七).png

相關(guān)的代碼如下:

# !/usr/bin/python
# -*- coding:utf-8 -*-

import numpy as np
from sklearn.mixture import GaussianMixture, BayesianGaussianMixture
import scipy as sp
import matplotlib as mpl
import matplotlib.colors
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse


def expand(a, b, rate=0.05):
    d = (b - a) * rate
    return a-d, b+d


matplotlib.rcParams['font.sans-serif'] = ['SimHei']
matplotlib.rcParams['axes.unicode_minus'] = False


if __name__ == '__main__':
    np.random.seed(0)
    cov1 = np.diag((1, 2))
    N1 = 500
    N2 = 300
    N = N1 + N2
    x1 = np.random.multivariate_normal(mean=(3, 2), cov=cov1, size=N1)
    m = np.array(((1, 1), (1, 3)))
    x1 = x1.dot(m)
    x2 = np.random.multivariate_normal(mean=(-1, 10), cov=cov1, size=N2)
    x = np.vstack((x1, x2))
    y = np.array([0]*N1 + [1]*N2)
    n_components = 3

    # 繪圖使用
    colors = '#A0FFA0', '#2090E0', '#FF8080'
    cm = mpl.colors.ListedColormap(colors)
    x1_min, x1_max = x[:, 0].min(), x[:, 0].max()
    x2_min, x2_max = x[:, 1].min(), x[:, 1].max()
    x1_min, x1_max = expand(x1_min, x1_max)
    x2_min, x2_max = expand(x2_min, x2_max)
    x1, x2 = np.mgrid[x1_min:x1_max:500j, x2_min:x2_max:500j]
    grid_test = np.stack((x1.flat, x2.flat), axis=1)

    plt.figure(figsize=(6, 6), facecolor='w')
    plt.suptitle('GMM/DPGMM比較', fontsize=15)

    ax = plt.subplot(211)
    gmm = GaussianMixture(n_components=n_components, covariance_type='full', random_state=0)
    gmm.fit(x)
    centers = gmm.means_
    covs = gmm.covariances_
    print('GMM均值 = \n', centers)
    print('GMM方差 = \n', covs)
    y_hat = gmm.predict(x)

    grid_hat = gmm.predict(grid_test)
    grid_hat = grid_hat.reshape(x1.shape)
    plt.pcolormesh(x1, x2, grid_hat, cmap=cm)
    plt.scatter(x[:, 0], x[:, 1], s=20, c=y, cmap=cm, marker='o', edgecolors='#202020')

    clrs = list('rgbmy')
    for i, (center, cov) in enumerate(zip(centers, covs)):
        value, vector = sp.linalg.eigh(cov)
        width, height = value[0], value[1]
        v = vector[0] / sp.linalg.norm(vector[0])
        angle = 180* np.arctan(v[1] / v[0]) / np.pi
        e = Ellipse(xy=center, width=width, height=height,
                    angle=angle, color=clrs[i], alpha=0.5, clip_box = ax.bbox)
        ax.add_artist(e)

    ax1_min, ax1_max, ax2_min, ax2_max = plt.axis()
    plt.xlim((x1_min, x1_max))
    plt.ylim((x2_min, x2_max))
    plt.title('GMM', fontsize=15)
    plt.grid(b=True, ls=':', color='#606060')

    # DPGMM
    dpgmm = BayesianGaussianMixture(n_components=n_components, covariance_type='full', max_iter=1000, n_init=5,
                                    weight_concentration_prior_type='dirichlet_process', weight_concentration_prior=0.1)
    dpgmm.fit(x)
    centers = dpgmm.means_
    covs = dpgmm.covariances_
    print('DPGMM均值 = \n', centers)
    print('DPGMM方差 = \n', covs)
    y_hat = dpgmm.predict(x)
    print(y_hat)

    ax = plt.subplot(212)
    grid_hat = dpgmm.predict(grid_test)
    grid_hat = grid_hat.reshape(x1.shape)
    plt.pcolormesh(x1, x2, grid_hat, cmap=cm)
    plt.scatter(x[:, 0], x[:, 1], s=20, c=y, cmap=cm, marker='o', edgecolors='#202020')

    for i, cc in enumerate(zip(centers, covs)):
        if i not in y_hat:
            continue
        center, cov = cc
        value, vector = sp.linalg.eigh(cov)
        width, height = value[0], value[1]
        v = vector[0] / sp.linalg.norm(vector[0])
        angle = 180* np.arctan(v[1] / v[0]) / np.pi
        e = Ellipse(xy=center, width=width, height=height,
                    angle=angle, color='m', alpha=0.5, clip_box = ax.bbox)
        ax.add_artist(e)
    plt.xlim((x1_min, x1_max))
    plt.ylim((x2_min, x2_max))
    plt.title('DPGMM', fontsize=15)
    plt.grid(b=True, ls=':', color='#606060')
    plt.tight_layout(2, rect=(0, 0, 1, 0.95))
    plt.show()

得到的結(jié)果如圖:


2019-01-03 20_15_36_20.EM_20.png

問:看來要講貝葉斯啊
答:是的,要為下次LDA做鋪墊
問:DPGMM其實就是一個主題和樣本的分布作為權(quán)重,乘以主題和樣本的高斯混合分布?
答:一定程度可以這樣解釋。
問:P(,;|)這三個符號一般是指什么呢?
答:P(x,y) <=> P(y, x),這個代表x與y的聯(lián)合分布
P(y; x)與P(y|x)是等價的,即x是條件,y是x的因變量
但是如果代入θ,就不一樣了:
但是P(y;θ),屬于頻率學(xué)派;其中θ為參數(shù),θ是未知的定值;
而P(y|θ),屬于貝葉斯學(xué)派,則樣本θ是未知的隨機變量
問:都是同樣阿爾法,怎么能采樣出多個θ呢?
答:高斯分布中,如果均值為170,方差為10,取樣可以為226,175,168多個值;同理,同樣的阿爾法,也可以采樣出多個θ
問:模型和里面的實現(xiàn)可以組合么?比如這種混合高斯模型,能用隨機或批量梯度下降達到目的么?
答:其實是達不到的。因為混合高斯分布,其目標函數(shù)是有隱變量的存在,所以沒辦法對其直接求取梯度,只能固定隱變量求梯度;固定梯度求隱變量,二者不斷迭代,最后才得到EM
問:均值為0和不為0有什么區(qū)別?效果會變化么?
答:均值是否為0只是一個解釋,因為不為0的時候,我們總是會將其調(diào)整為0附近的。比如事先減均值。

2019-02-04 15_07_45-【鄒博_chinahadoop】機器學(xué)習(xí)升級版VII(七).png

求導(dǎo)的過程很簡單:
?h(p)/?p = n*pn-1*(1-p)(N-n) - pn*(N-n)*(1-p)(N-n-1)
假定導(dǎo)數(shù)為0,即:
?h(p)/?p = n*pn-1*(1-p)(N-n) - pn*(N-n)*(1-p)(N-n-1) = 0
則等式兩邊除以pn*(1-p)(N-n),得到:
?h(p)/?p = n*p-1 - (N-n)*(1-p)-1 = n/p - (N-n)/(1-p) = 0,
可得:p = n/N
下面看二項分布與先驗舉例:
2019-02-04 15_47_50-【鄒博_chinahadoop】機器學(xué)習(xí)升級版VII(七).png

可以觀察到,修正公式的分子各加了一個5,而這個5是Dirichlet(狄利克雷)分布的超參數(shù)α。

2019-02-04 15_52_03-【鄒博_chinahadoop】機器學(xué)習(xí)升級版VII(七).png

問答
問:是不是這個課程代碼都敲會,加上一個項目經(jīng)驗就OK?
答:這個看需要。比如現(xiàn)在的情況,機器學(xué)習(xí)是研究整個這套方式的一個根基。不用基礎(chǔ)這個詞,是因為以為其很簡單,所以用根基這個詞。有了這個根基之后,大家再去往上做其他應(yīng)用,不會感覺困難。比如用卷積網(wǎng)絡(luò),最后一層我們使用SoftMax的全連接,還是用SVM,本質(zhì)上是換損失函數(shù)。然后我們解釋模型是否有效,都是可以用上的。
基礎(chǔ)是夠的,但是如果大家沒有深度學(xué)習(xí)的應(yīng)用實踐,或者只有一個項目經(jīng)驗,還是不夠的。所以需要實際項目進行反復(fù)驗證與活學(xué)活用,或者參加競賽,比賽也可以。
問:強化學(xué)習(xí)和機器學(xué)習(xí)的關(guān)聯(lián)大么?強化學(xué)習(xí)未來的應(yīng)用前景如何?
答:強化學(xué)習(xí)可能是近兩三年的爆發(fā)點,可能是大公司玩的。需要的算力,比數(shù)據(jù)要強。比如飛翔的小鳥,或者行走的人,需要對當(dāng)前的動作進行反饋,然后根據(jù)反饋的結(jié)果,去更正動作,并不斷學(xué)習(xí)。算力要求非常高,前景要求是有的,但目前只能進行簡單游戲、博弈、對抗這種內(nèi)容,沒法成為最主力的算法應(yīng)用。也許不對,但最主流的還是在有監(jiān)督應(yīng)用。
問:算力是什么?可以簡單理解為硬件速度么?
答:這個理解沒問題的。

EM算法代碼

下面代碼有自己實現(xiàn)的高斯混合模型,以及通過sk-learn庫的高斯混合模型類直接實現(xiàn)的兩種方式。
自己實現(xiàn)的高斯混合模型,其實就是實現(xiàn)期望最大化,即EM算法,公式如下:


2019-02-04 17_19_35-【鄒博_chinahadoop】機器學(xué)習(xí)升級版VII(七).png
import numpy as np
from scipy.stats import multivariate_normal
from sklearn.mixture import GaussianMixture
from mpl_toolkits.mplot3d import Axes3D
import matplotlib as mpl
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import pairwise_distances_argmin


mpl.rcParams['font.sans-serif'] = ['SimHei']
mpl.rcParams['axes.unicode_minus'] = False


if __name__ == '__main__':
    #style = 'sklearn'
    style = 'myself'
    np.random.seed(0)
    mu1_fact = (0, 0, 0)
    cov1_fact = np.diag((1, 2, 3))
    # 根據(jù)實際情況生成一個多元正態(tài)分布矩陣,np.random.multivariate_normal
    # 參數(shù)就是高斯分布所需的均值與方差
    # 第一個參數(shù): mean:mean是多維分布的均值維度為1;
    # 第二個參數(shù):cov:協(xié)方差矩陣,注意:協(xié)方差矩陣必須是對稱的且需為半正定矩陣;
    # 第三個參數(shù):size:指定生成的正態(tài)分布矩陣的維度
    data1 = np.random.multivariate_normal(mu1_fact, cov1_fact, 400)
    print('data1 shape: {0}'.format(data1.shape))
    mu2_fact = (2, 2, 1)
    # 方差對稱且正定(positive-semidefinite): (4, 1, 3), (1, 2, 1), (3, 1, 4)
    cov2_fact = np.array(((4, 1, 3), (1, 2, 1), (3, 1, 4)))
    data2 = np.random.multivariate_normal(mu2_fact, cov2_fact, 100)
    print('data2 shape: {0}'.format(data2.shape))

    data = np.vstack((data1, data2))
    print('data shape: {0}'.format(data.shape))
    y = np.array([True] * 400 + [False] * 100)

    if style == 'sklearn':
        g = GaussianMixture(n_components=2, covariance_type='full', tol=1e-6, max_iter=1000)
        g.fit(data)
        print('類別概率:\t', g.weights_[0])
        print('均值:\n', g.means_, '\n')
        print('方差:\n', g.covariances_, '\n')
        mu1, mu2 = g.means_
        sigma1, sigma2 = g.covariances_
    else:
        num_iter = 100
        n, d = data.shape
        # 隨機指定
        # mu1 = np.random.standard_normal(d)
        # print mu1
        # mu2 = np.random.standard_normal(d)
        # print mu2
        mu1 = data.min(axis=0)
        mu2 = data.max(axis=0)
        # 創(chuàng)建d行d列的單位矩陣(對角線為1,其余為0)
        sigma1 = np.identity(d)
        sigma2 = np.identity(d)
        pi = 0.5
        # EM
        for i in range(num_iter):
            # E Step
            # 通過初始化的均值與方差,做多元的正態(tài)分布
            norm1 = multivariate_normal(mu1, sigma1)
            norm2 = multivariate_normal(mu2, sigma2)
            # 概率密度 * pi
            tau1 = pi * norm1.pdf(data)
            tau2 = (1 - pi) * norm2.pdf(data)
            gamma = tau1 / (tau1 + tau2)

            # M Step
            mu1 = np.dot(gamma, data) / np.sum(gamma)
            mu2 = np.dot((1 - gamma), data) / np.sum((1 - gamma))
            sigma1 = np.dot(gamma * (data - mu1).T, data - mu1) / np.sum(gamma)
            sigma2 = np.dot((1 - gamma) * (data - mu2).T, data - mu2) / np.sum(1 - gamma)
            pi = np.sum(gamma) / n
            print(i, ":\t", mu1, mu2)
        print('類別概率:\t', pi)
        print('均值:\t', mu1, mu2)
        print('方差:\n', sigma1, '\n\n', sigma2, '\n')

    # 預(yù)測分類
    # multivariate_normal獲得多元正態(tài)分布
    norm1 = multivariate_normal(mu1, sigma1)
    norm2 = multivariate_normal(mu2, sigma2)
    # pdf: Probability density function,連續(xù)性概率分布函數(shù)
    tau1 = norm1.pdf(data)
    tau2 = norm2.pdf(data)

    fig = plt.figure(figsize=(10, 5), facecolor='w')
    ax = fig.add_subplot(121, projection='3d')
    ax.scatter(data[:, 0], data[:, 1], data[:, 2], c='b', s=30, marker='o', edgecolors='k', depthshade=True)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title('原始數(shù)據(jù)', fontsize=15)
    ax = fig.add_subplot(122, projection='3d')
    # 求取點距離
    order = pairwise_distances_argmin([mu1_fact, mu2_fact], [mu1, mu2], metric='euclidean')
    # order = pairwise_distances_argmin([mu1_fact, mu2_fact], [mu1, mu2], metric='cosine')

    # 通過歐式距離,將點分為兩類
    print(order)
    if order[0] == 0:
        c1 = tau1 > tau2
    else:
        c1 = tau1 < tau2
    c2 = ~c1
    # 機器學(xué)習(xí)計算準確率的常用做法
    # 原理:真實值是y,預(yù)測值是c1,相等則為True,否則為False。True為1,F(xiàn)alse為0
    # 求均值則為:預(yù)測準確的數(shù)目/總數(shù)目,這不就是準確率么
    acc = np.mean(y == c1)
    print('準確率:%.2f%%' % (100*acc))
    ax.scatter(data[c1, 0], data[c1, 1], data[c1, 2], c='r', s=30, marker='o', edgecolors='k', depthshade=True)
    ax.scatter(data[c2, 0], data[c2, 1], data[c2, 2], c='g', s=30, marker='^', edgecolors='k', depthshade=True)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title('EM算法分類', fontsize=15)
    plt.suptitle('EM算法的實現(xiàn)', fontsize=18)
    plt.subplots_adjust(top=0.90)
    # plt.tight_layout()
    plt.show()

得到的圖形界面:

emdraw.png

如果將data2,那100個數(shù)據(jù),均值設(shè)置為(5, 5, 5),則分類效果更明顯,如圖:
emdraw.png

當(dāng)均值為5, 5, 5時,且自己實現(xiàn)高斯分布,輸出如下。
可以發(fā)現(xiàn),迭代24次之后,均值就不再變化了,可以稱為模型收斂了。
真實值為[0, 0, 0]與[5, 5, 5],計算的均值為:
[-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ],有差別,但是靠譜。
當(dāng)然,也可以增加樣本,使得結(jié)果更接近真實值的情況。比如將400, 100替換為4000, 1000

data1 shape: (400, 3)
data2 shape: (100, 3)
data shape: (500, 3)

0 :  [-0.02992749  0.09146815  0.03351835] [5.43632719 5.10518101 5.44044355]
1 :  [-0.0577343   0.02743837  0.00975419] [5.23010882 5.07679997 5.22085292]
2 :  [-0.064707    0.00613346 -0.0011657 ] [5.13849255 5.04871819 5.14698533]
3 :  [-0.0673045  -0.00120726 -0.00558504] [5.0995056  5.03026551 5.1158224 ]
4 :  [-0.068465   -0.00420981 -0.00743412] [5.08222193 5.02089901 5.10147628]
5 :  [-0.06899479 -0.00554239 -0.00824704] [5.07428887 5.01640324 5.09475058]
6 :  [-0.06923923 -0.00615431 -0.00861565] [5.07059253 5.01427675 5.09158405]
7 :  [-0.06935293 -0.00643926 -0.00878579] [5.06886018 5.01327497 5.09009261]
8 :  [-0.06940609 -0.00657271 -0.00886508] [5.06804645 5.01280356 5.0893904 ]
9 :  [-0.06943103 -0.00663537 -0.00890221] [5.06766389 5.01258178 5.0890599 ]
10 :     [-0.06944274 -0.00666482 -0.00891964] [5.06748396 5.01247744 5.08890438]
11 :     [-0.06944825 -0.00667867 -0.00892783] [5.06739933 5.01242836 5.08883121]
12 :     [-0.06945084 -0.00668519 -0.00893168] [5.06735951 5.01240527 5.08879678]
13 :     [-0.06945206 -0.00668825 -0.00893349] [5.06734079 5.01239441 5.08878059]
14 :     [-0.06945263 -0.00668969 -0.00893434] [5.06733197 5.0123893  5.08877297]
15 :     [-0.0694529  -0.00669037 -0.00893474] [5.06732783 5.0123869  5.08876938]
16 :     [-0.06945302 -0.00669069 -0.00893493] [5.06732588 5.01238577 5.0887677 ]
17 :     [-0.06945308 -0.00669084 -0.00893502] [5.06732496 5.01238523 5.0887669 ]
18 :     [-0.06945311 -0.00669091 -0.00893506] [5.06732453 5.01238498 5.08876653]
19 :     [-0.06945313 -0.00669095 -0.00893508] [5.06732433 5.01238487 5.08876635]
20 :     [-0.06945313 -0.00669096 -0.00893509] [5.06732423 5.01238481 5.08876627]
21 :     [-0.06945314 -0.00669097 -0.00893509] [5.06732419 5.01238478 5.08876623]
22 :     [-0.06945314 -0.00669097 -0.00893509] [5.06732417 5.01238477 5.08876621]
23 :     [-0.06945314 -0.00669097 -0.0089351 ] [5.06732416 5.01238477 5.08876621]
24 :     [-0.06945314 -0.00669097 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
25 :     [-0.06945314 -0.00669097 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
26 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
27 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
28 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
29 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
30 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
31 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
32 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
33 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
34 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
35 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
36 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
37 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
38 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
39 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
40 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
41 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
42 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
43 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
44 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
45 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
46 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
47 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
48 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
49 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
50 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
51 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
52 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
53 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
54 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
55 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
56 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
57 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
58 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
59 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
60 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
61 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
62 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
63 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
64 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
65 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
66 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
67 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
68 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
69 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
70 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
71 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
72 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
73 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
74 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
75 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
76 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
77 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
78 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
79 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
80 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
81 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
82 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
83 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
84 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
85 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
86 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
87 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
88 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
89 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
90 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
91 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
92 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
93 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
94 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
95 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
96 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
97 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
98 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
99 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
類別概率:    0.7987220297951044
均值:  [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
方差:
 [[ 0.87148101 -0.05642494  0.03198856]
 [-0.05642494  2.09700921 -0.12547629]
 [ 0.03198856 -0.12547629  2.745459  ]] 

 [[4.08142083 0.79087313 3.107469  ]
 [0.79087313 1.79995257 0.75954681]
 [3.107469   0.75954681 4.04331614]] 

[0 1]
準確率:98.60%

問答
問:協(xié)方差一定是對稱的么?
答:是的,協(xié)方差一定是對稱的
問:np.identity是什么意思?
答:創(chuàng)建單位矩陣,即對角線為1,其余值為0

GMM代碼實現(xiàn)

對應(yīng)業(yè)務(wù)為性別-身高-體重數(shù)據(jù)
通過高斯混合模型,預(yù)測身高與體重所屬的性別

# !/usr/bin/python
# -*- coding:utf-8 -*-

import numpy as np
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import train_test_split
import matplotlib as mpl
import matplotlib.colors
import matplotlib.pyplot as plt

mpl.rcParams['font.sans-serif'] = ['SimHei']
mpl.rcParams['axes.unicode_minus'] = False
# from matplotlib.font_manager import FontProperties
# font_set = FontProperties(fname=r"c:\windows\fonts\simsun.ttc", size=15)
# fontproperties=font_set


def expand(a, b):
    d = (b - a) * 0.05
    return a-d, b+d


if __name__ == '__main__':
    data = np.loadtxt('./HeightWeight.csv', dtype=np.float, delimiter=',', skiprows=1)
    print(data.shape)
    y, x = np.split(data, [1, ], axis=1)
    x, x_test, y, y_test = train_test_split(x, y, train_size=0.6, random_state=0)
    gmm = GaussianMixture(n_components=2, covariance_type='full', random_state=0)
    x_min = np.min(x, axis=0)
    x_max = np.max(x, axis=0)
    gmm.fit(x)
    print('均值 = \n', gmm.means_)
    print('方差 = \n', gmm.covariances_)
    y_hat = gmm.predict(x)
    y_test_hat = gmm.predict(x_test)
    change = (gmm.means_[0][0] > gmm.means_[1][0])
    if change:
        z = y_hat == 0
        y_hat[z] = 1
        y_hat[~z] = 0
        z = y_test_hat == 0
        y_test_hat[z] = 1
        y_test_hat[~z] = 0
    acc = np.mean(y_hat.ravel() == y.ravel())
    acc_test = np.mean(y_test_hat.ravel() == y_test.ravel())
    acc_str = '訓(xùn)練集準確率:%.2f%%' % (acc * 100)
    acc_test_str = '測試集準確率:%.2f%%' % (acc_test * 100)
    print(acc_str)
    print(acc_test_str)

    cm_light = mpl.colors.ListedColormap(['#FF8080', '#77E0A0'])
    cm_dark = mpl.colors.ListedColormap(['r', 'g'])
    x1_min, x1_max = x[:, 0].min(), x[:, 0].max()
    x2_min, x2_max = x[:, 1].min(), x[:, 1].max()
    x1_min, x1_max = expand(x1_min, x1_max)
    x2_min, x2_max = expand(x2_min, x2_max)
    x1, x2 = np.mgrid[x1_min:x1_max:500j, x2_min:x2_max:500j]
    grid_test = np.stack((x1.flat, x2.flat), axis=1)
    grid_hat = gmm.predict(grid_test)
    grid_hat = grid_hat.reshape(x1.shape)
    if change:
        z = grid_hat == 0
        grid_hat[z] = 1
        grid_hat[~z] = 0
    plt.figure(figsize=(7, 6), facecolor='w')
    plt.pcolormesh(x1, x2, grid_hat, cmap=cm_light)
    plt.scatter(x[:, 0], x[:, 1], s=50, c=y.ravel(), marker='o', cmap=cm_dark, edgecolors='k')
    plt.scatter(x_test[:, 0], x_test[:, 1], s=60, c=y_test.ravel(), marker='^', cmap=cm_dark, edgecolors='k')

    p = gmm.predict_proba(grid_test)
    print(p)
    p = p[:, 0].reshape(x1.shape)
    CS = plt.contour(x1, x2, p, levels=(0.1, 0.5, 0.8), colors=list('rgb'), linewidths=2)
    plt.clabel(CS, fontsize=12, fmt='%.1f', inline=True)
    ax1_min, ax1_max, ax2_min, ax2_max = plt.axis()
    xx = 0.95*ax1_min + 0.05*ax1_max
    yy = 0.05*ax2_min + 0.95*ax2_max
    plt.text(xx, yy, acc_str, fontsize=12)
    yy = 0.1*ax2_min + 0.9*ax2_max
    plt.text(xx, yy, acc_test_str, fontsize=12)
    plt.xlim((x1_min, x1_max))
    plt.ylim((x2_min, x2_max))
    plt.xlabel('身高(cm)', fontsize=13)
    plt.ylabel('體重(kg)', fontsize=13)
    plt.title('EM算法估算GMM的參數(shù)', fontsize=15)
    plt.grid(b=True, ls=':', color='#606060')
    plt.tight_layout(2)
    plt.show()

其中HeightWeight.csv的數(shù)據(jù)如下,直接將其拷貝到文本文件,然后保存為文件名為HeightWeight.csv的文件即可

Sex,Height(cm),Weight(kg)
0,156,50
0,160,60
0,162,54
0,162,55
0,160.5,56
0,160,53
0,158,55
0,164,60
0,165,50
0,166,55
0,158,47.5
0,161,49
0,169,55
0,161,46
0,160,45
0,167,44
0,155,49
0,154,57
0,172,52
0,155,56
0,157,55
0,165,65
0,156,52
0,155,50
0,156,56
0,160,55
0,158,55
0,162,70
0,162,65
0,155,57
0,163,70
0,160,60
0,162,55
0,165,65
0,159,60
0,147,47
0,163,53
0,157,54
0,160,55
0,162,48
0,158,60
0,155,48
0,165,60
0,161,58
0,159,45
0,163,50
0,158,49
0,155,50
0,162,55
0,157,63
0,159,49
0,152,47
0,156,51
0,165,49
0,154,47
0,156,52
0,162,48
1,162,60
1,164,62
1,168,86
1,187,75
1,167,75
1,174,64
1,175,62
1,170,65
1,176,73
1,169,58
1,178,54
1,165,66
1,183,68
1,171,61
1,179,64
1,172,60
1,173,59
1,172,58
1,175,62
1,160,60
1,160,58
1,160,60
1,175,75
1,163,60
1,181,77
1,172,80
1,175,73
1,175,60
1,167,65
1,172,60
1,169,75
1,172,65
1,175,72
1,172,60
1,170,65
1,158,59
1,167,63
1,164,61
1,176,65
1,182,95
1,173,75
1,176,67
1,163,58
1,166,67
1,162,59
1,169,56
1,163,59
1,163,56
1,176,62
1,169,57
1,173,61
1,163,59
1,167,57
1,176,63
1,168,61
1,167,60
1,170,69

圖形示例如下:


2019-02-04 17_41_49-Start.png

問:from sklearn.metrics.pairwise import pairwise_distances_argmin,這個是干嘛的?
答:是用于計算任意的兩個值里面,誰和誰是最小的。比如:order = pairwise_distances_argmin([mu1_fact, mu2_fact], [mu1, mu2], metric='euclidean'),返回的值是[0,1],表明mu1與mu1_fact最近,mu2與mu2_fact最近。換句話說,我們做的順序是做對了。

通過GMM實現(xiàn)鳶尾花分類

通過高斯混合模型,對鳶尾花數(shù)據(jù)做分類

# !/usr/bin/python
# -*- coding:utf-8 -*-

import numpy as np
import pandas as pd
from sklearn.mixture import GaussianMixture
import matplotlib as mpl
import matplotlib.colors
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import pairwise_distances_argmin

mpl.rcParams['font.sans-serif'] = ['SimHei']
mpl.rcParams['axes.unicode_minus'] = False

iris_feature = '花萼長度', '花萼寬度', '花瓣長度', '花瓣寬度'


def expand(a, b, rate=0.05):
    d = (b - a) * rate
    return a-d, b+d


if __name__ == '__main__':
    path = '..\9.Regression\iris.data'
    data = pd.read_csv(path, header=None)
    x_prime = data[np.arange(4)]
    y = pd.Categorical(data[4]).codes

    n_components = 3
    feature_pairs = [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]
    plt.figure(figsize=(8, 6), facecolor='w')
    for k, pair in enumerate(feature_pairs, start=1):
        x = x_prime[pair]
        m = np.array([np.mean(x[y == i], axis=0) for i in range(3)])  # 均值的實際值
        print('實際均值 = \n', m)

        gmm = GaussianMixture(n_components=n_components, covariance_type='full', random_state=0)
        gmm.fit(x)
        print('預(yù)測均值 = \n', gmm.means_)
        print('預(yù)測方差 = \n', gmm.covariances_)
        y_hat = gmm.predict(x)
        print(y_hat)
        order = pairwise_distances_argmin(m, gmm.means_, axis=1, metric='euclidean')
        print(order)
        print('順序:\t', order)

        n_sample = y.size
        n_types = 3
        change = np.empty((n_types, n_sample), dtype=np.bool)
        for i in range(n_types):
            change[i] = y_hat == order[i]
        for i in range(n_types):
            y_hat[change[i]] = i
        acc = '準確率:%.2f%%' % (100*np.mean(y_hat == y))
        print(acc)

        cm_light = mpl.colors.ListedColormap(['#FF8080', '#77E0A0', '#A0A0FF'])
        cm_dark = mpl.colors.ListedColormap(['r', 'g', '#6060FF'])
        x1_min, x2_min = x.min()
        x1_max, x2_max = x.max()
        x1_min, x1_max = expand(x1_min, x1_max)
        x2_min, x2_max = expand(x2_min, x2_max)
        x1, x2 = np.mgrid[x1_min:x1_max:200j, x2_min:x2_max:200j]
        grid_test = np.stack((x1.flat, x2.flat), axis=1)
        grid_hat = gmm.predict(grid_test)

        change = np.empty((n_types, grid_hat.size), dtype=np.bool)
        for i in range(n_types):
            change[i] = grid_hat == order[i]
        for i in range(n_types):
            grid_hat[change[i]] = i

        grid_hat = grid_hat.reshape(x1.shape)
        plt.subplot(2, 3, k)
        plt.pcolormesh(x1, x2, grid_hat, cmap=cm_light)
        plt.scatter(x[pair[0]], x[pair[1]], s=20, c=y, marker='o', cmap=cm_dark, edgecolors='k')
        xx = 0.95 * x1_min + 0.05 * x1_max
        yy = 0.1 * x2_min + 0.9 * x2_max
        plt.text(xx, yy, acc, fontsize=10)
        plt.xlim((x1_min, x1_max))
        plt.ylim((x2_min, x2_max))
        plt.xlabel(iris_feature[pair[0]], fontsize=11)
        plt.ylabel(iris_feature[pair[1]], fontsize=11)
        plt.grid(b=True, ls=':', color='#606060')
    plt.suptitle('EM算法無監(jiān)督分類鳶尾花數(shù)據(jù)', fontsize=14)
    plt.tight_layout(1, rect=(0, 0, 1, 0.95))
    plt.show()

圖例如下:


2019-02-04 18_17_05-Figure 1.png

繪制高斯混合模型的等值線

# !/usr/bin/python
# -*- coding:utf-8 -*-

import numpy as np
from sklearn.mixture import GaussianMixture
import scipy as sp
import matplotlib as mpl
import matplotlib.colors
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
import warnings


def expand(a, b, rate=0.05):
    d = (b - a) * rate
    return a-d, b+d


if __name__ == '__main__':
    warnings.filterwarnings(action='ignore', category=RuntimeWarning)
    np.random.seed(0)
    cov1 = np.diag((1, 2))
    N1 = 500
    N2 = 300
    N = N1 + N2
    x1 = np.random.multivariate_normal(mean=(3, 2), cov=cov1, size=N1)
    m = np.array(((1, 1), (1, 3)))
    x1 = x1.dot(m)
    x2 = np.random.multivariate_normal(mean=(-1, 10), cov=cov1, size=N2)
    x = np.vstack((x1, x2))
    y = np.array([0]*N1 + [1]*N2)

    gmm = GaussianMixture(n_components=2, covariance_type='full', random_state=0)
    gmm.fit(x)
    centers = gmm.means_
    covs = gmm.covariances_
    print('GMM均值 = \n', centers)
    print('GMM方差 = \n', covs)
    y_hat = gmm.predict(x)

    colors = '#A0FFA0', '#E080A0',
    levels = 10
    cm = mpl.colors.ListedColormap(colors)
    x1_min, x1_max = x[:, 0].min(), x[:, 0].max()
    x2_min, x2_max = x[:, 1].min(), x[:, 1].max()
    x1_min, x1_max = expand(x1_min, x1_max)
    x2_min, x2_max = expand(x2_min, x2_max)
    x1, x2 = np.mgrid[x1_min:x1_max:500j, x2_min:x2_max:500j]
    grid_test = np.stack((x1.flat, x2.flat), axis=1)
    print(gmm.score_samples(grid_test))
    grid_hat = -gmm.score_samples(grid_test)
    grid_hat = grid_hat.reshape(x1.shape)
    plt.figure(figsize=(7, 6), facecolor='w')
    ax = plt.subplot(111)
    cmesh = plt.pcolormesh(x1, x2, grid_hat, cmap=plt.cm.Spectral)
    plt.colorbar(cmesh, shrink=0.9)
    CS = plt.contour(x1, x2, grid_hat, levels=np.logspace(0, 2, num=levels, base=10), colors='w', linewidths=1)
    plt.clabel(CS, fontsize=9, inline=True, fmt='%.1f')
    plt.scatter(x[:, 0], x[:, 1], s=30, c=y, cmap=cm, marker='o', edgecolors='#202020')

    for i, cc in enumerate(zip(centers, covs)):
        center, cov = cc
        value, vector = sp.linalg.eigh(cov)
        width, height = value[0], value[1]
        v = vector[0] / sp.linalg.norm(vector[0])
        angle = 180* np.arctan(v[1] / v[0]) / np.pi
        e = Ellipse(xy=center, width=width, height=height,
                    angle=angle, color='m', alpha=0.5, clip_box = ax.bbox)
        ax.add_artist(e)

    plt.xlim((x1_min, x1_max))
    plt.ylim((x2_min, x2_max))
    mpl.rcParams['font.sans-serif'] = ['SimHei']
    mpl.rcParams['axes.unicode_minus'] = False
    plt.title('GMM似然函數(shù)值', fontsize=15)
    plt.grid(b=True, ls=':', color='#606060')
    plt.tight_layout(2)
    plt.show()

圖例如下:


2019-02-04 18_29_19-Start.png

問答
問:DPGMM選的k是不是要盡量???
答:不一定,與k值選擇是否小沒關(guān)系。
問:矩陣運算不是不能用交換律么?怎么直接交換了?
答:是對這個代碼:sigma1 = np.dot(gamma * (data - mu1).T, data - mu1) / np.sum(gamma),矩陣不能交換,但是標量值就可以進行交換。

EM算法內(nèi)容完結(jié)

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容