題目
假設(shè)有如下八個(gè)點(diǎn):(3,1)(3,2)(4,1)(4,2)(1,3)(1,4)(2,3)(2,4),使用KMeans算法對(duì)其進(jìn)行聚類(lèi)。假設(shè)初始聚類(lèi)中心點(diǎn)分別為(0,4)和(3,3),則最終的聚類(lèi)中為(x , y)和(x,y)。
實(shí)現(xiàn)代碼
import numpy as np
from sklearn.cluster import KMeans
from matplotlib import pyplot
points = np.array([[3,1],[3,2],[4,1],[4,2],[1,3],[1,4],[2,3],[2,4]])
# pyplot.scatter(points[:,0],points[:,1])
# pyplot.show()
# 把數(shù)據(jù)點(diǎn)分組
clf = KMeans(n_clusters = 2)
clf.fit(points)
# 數(shù)據(jù)點(diǎn)的中心點(diǎn)
centers = clf.cluster_centers_
print(centers)
# 每個(gè)數(shù)據(jù)點(diǎn)所屬分組
labels = clf.labels_
# print(labels)
for i in range(len(labels)):
pyplot.scatter(points[i][0], points[i][1], c=('r' if labels[i] == 0 else 'b'))
pyplot.scatter(centers[:,0],centers[:,1], marker='*', s=100)
# 預(yù)測(cè)
predict = [[1.5,1.5], [3.5,3.5]]
label = clf.predict(predict)
for i in range(len(label)):
pyplot.scatter(predict[i][0], predict[i][1], c=('r' if label[i] == 0 else 'b'), marker='x')
pyplot.show()
參考資料:博客
可視化顯示

show()輸出的圖像@2x.png
*為兩組數(shù)據(jù)的中心點(diǎn),x為預(yù)測(cè)的中心點(diǎn)