k-近鄰算法

機(jī)器學(xué)習(xí)實(shí)戰(zhàn) http://www.ituring.com.cn/book/1021

簡介

工作原理:存在一個(gè)樣本數(shù)據(jù)集,并且樣本集中每個(gè)數(shù)據(jù)都存在標(biāo)簽(樣本集中每一數(shù)據(jù)與所屬分類的對應(yīng)關(guān)系)。輸入沒有標(biāo)簽的新數(shù)據(jù)后,將新數(shù)據(jù)的每個(gè)特征與樣本集中數(shù)據(jù)對應(yīng)的特征進(jìn)行比較,然后算法提取樣本集中特征最相似數(shù)據(jù)(最鄰近)的分類標(biāo)簽。一般來說,選擇樣本數(shù)據(jù)集中前k個(gè)最相似的數(shù)據(jù),最后選擇k個(gè)最相似數(shù)據(jù)中出現(xiàn)次數(shù)最多的分類,作為新數(shù)據(jù)的分類。

例如:

電影 打斗次數(shù) 接吻次數(shù) 電影類型
California Man 3 104 Romance
He's Not Really into Dudes 2 100 Romance
Beautiful Woman 1 81 Romance
Kevin Longblade 101 10 Action
Robo Slayer 3000 99 5 Action
Amped II 98 2 Action
未知 18 90 Unknown

已知電影與未知電影的距離:

電影名稱 與未知電影的距離
California Man 20.5
He's Not Really into Dudes 18.7
Beautiful Woman 19.2
Kevin Longblade 115.3
Robo Slayer 3000 117.4
Amped II 118.9

按照距離遞增排序可以找到k個(gè)距離最近的電影,假設(shè)k=3,取最近三部電影的類型可知未知電影也是一部愛情片。

算法步驟

1)計(jì)算測試數(shù)據(jù)與各個(gè)訓(xùn)練數(shù)據(jù)之間的距離;
2)按照距離的遞增關(guān)系進(jìn)行排序;
3)選取距離最小的K個(gè)點(diǎn);
4)確定前K個(gè)點(diǎn)所在類別的出現(xiàn)頻率;
5)返回前K個(gè)點(diǎn)中出現(xiàn)頻率最高的類別作為測試數(shù)據(jù)的預(yù)測分類

import numpy as np
import operator
##給出訓(xùn)練數(shù)據(jù)以及對應(yīng)的類別
def createDataSet():
    group = np.array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
    labels = ['A','A','B','B']
    return group,labels

###通過KNN進(jìn)行分類
def classify0(inX,dataSet,labels,k):
    dataSetSize = dataSet.shape[0]
    # print(dataSetSize)
    # print(np.tile(inX,(dataSetSize,1)))
    ###計(jì)算距離
    diffMat = np.tile(inX,(dataSetSize,1)) - dataSet
    # print(diffMat)
    sqDiffMat = diffMat**2
    # print(sqDiffMat)
    sqDistances = sqDiffMat.sum(axis=1)
    # print(sqDistances)
    distances = sqDistances**0.5
    print('distances:',distances)
    #根據(jù)元素的值從大到小對元素進(jìn)行排序,返回下標(biāo)
    sortedDistIndicies = distances.argsort()
    # print('sortedDistIndicies:',sortedDistIndicies)
    classCount = {}
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel,0)+1
    #選取出最多的類別
    sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
    return sortedClassCount[0][0]

if __name__ == '__main__':
    group, labels = createDataSet()
    input = [0.5,0.5]
    output = classify0(input,group,labels,3)
    print("測試數(shù)據(jù)為:",input,"分類結(jié)果為:",output)
>>>distances: [ 0.78102497  0.70710678  0.70710678  0.64031242]
>>>測試數(shù)據(jù)為: [0.5, 0.5] 分類結(jié)果為: B
歸一化數(shù)值

在計(jì)算測試樣本和樣本集中數(shù)據(jù)距離的時(shí)候,有些特征的差值較大,對計(jì)算結(jié)果有較大影響。在處理這種不同取值范圍的特征值時(shí),我們通常采用的方法就是將數(shù)值歸一化,如將取值范圍處理為0到1或者-1到1之間。

newValue = (oldValue-min)/(max-min)
def autoNum(dataSet):
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals
    normDataSet = np.zeros(np.shape(dataSet))
    m = dataSet.shape[0]
    normDataSet = dataSet - np.tile(minVals,(m,1))
    normDataSet = normDataSet/np.tile(ranges,(m,1))
    return normDataSet,ranges,minVals

測試算法

'''
將文本記錄裝換為NumPy的解析程序
'''
from collections import Counter
def file2matrix(filename):
    fr = open(filename)
    arrayOLines = fr.readlines()
    numberOfLines = len(arrayOLines)
    returnMat = np.zeros((numberOfLines,3))
    classLabelVector = []
    index = 0
    for line in arrayOLines:
        line = line.strip()
        listFormLine = line.split('\t')
        returnMat[index,:] = listFormLine[0:3]
        classLabelVector.append(listFormLine[-1])
        index += 1
    dictClassLabel = Counter(classLabelVector)
    classLabel = []
    kind = list(dictClassLabel)
    for item in classLabelVector:
        if item == kind[0]:
            item = 1
        elif item == kind[1]:
            item = 2
        else:
            item = 3
        classLabel.append(item)
    return returnMat,classLabel

def datingClassTest():
    hoRatio = 0.1  #選取10%測試,剩下的是已知數(shù)據(jù)集
    datingDataMat,datingLabels = file2matrix('datingTestSet.txt')
    normMat,ranges,minVals = autoNum(datingDataMat)
    m = normMat.shape[0]
    numTestVecs = int(m*hoRatio)
    errorCount = 0.0
    for i in range(numTestVecs):
        classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)
        print("the classifier came back with %d, the real answer is %d"%(classifierResult,datingLabels[i]))
        if(classifierResult != datingLabels[i]):errorCount += 1.0
    print("the total error rate is %f" %(errorCount/float(numTestVecs)))

算法的scikit-learn實(shí)現(xiàn)

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容