k-近鄰算法構(gòu)建手寫識別系統(tǒng)

姓名:劉強(qiáng)
【嵌牛導(dǎo)讀】
手寫識別是計(jì)算機(jī)視覺的一個研究方向,可以看成是一個分類問題。機(jī)器學(xué)習(xí)的任務(wù),便是解決分類(有監(jiān)督學(xué)習(xí))、聚類(無監(jiān)督學(xué)習(xí))和回歸(強(qiáng)化學(xué)習(xí))問題。k-近鄰算法(簡稱kNN)是最簡單的有監(jiān)督學(xué)習(xí)算法,本文介紹了如何用k-近鄰算法構(gòu)建一個手寫識別系統(tǒng),并附上其python實(shí)現(xiàn)。
【嵌牛鼻子】
k-近鄰算法 機(jī)器學(xué)習(xí) 分類 手寫識別
【嵌牛提問】
k-近鄰算法是什么? 如何構(gòu)建一個手寫識別系統(tǒng)?
【嵌牛正文】

k近鄰算法基本思想

存在一個樣本數(shù)據(jù)集,稱為訓(xùn)練集,訓(xùn)練集中每個數(shù)據(jù)都存在標(biāo)簽(標(biāo)簽即數(shù)據(jù)所屬的類別,從這一點(diǎn)可以看出,k近鄰算法屬于有監(jiān)督學(xué)習(xí))。對于不知道標(biāo)簽的新數(shù)據(jù),將新數(shù)據(jù)的每個特征與訓(xùn)練集中數(shù)據(jù)對應(yīng)的特征相比較,選出訓(xùn)練集中前k個最相似的數(shù)據(jù)(這就是k-近鄰算法名稱中k的出處),然后對這k個數(shù)據(jù)做統(tǒng)計(jì),選擇出現(xiàn)次數(shù)最多的標(biāo)簽作為新數(shù)據(jù)的標(biāo)簽(即k-近鄰算法的輸出)。
從其基本思想可以看出,k-近鄰算法用于解決分類問題。所謂近鄰,其實(shí)是用數(shù)據(jù)之間的歐氏距離來衡量它們的相似程度,距離越短,表示兩個數(shù)據(jù)越相似。

圖片來源于知乎

構(gòu)建手寫識別系統(tǒng)

需求分析

很多輸入法都支持手寫輸入,實(shí)現(xiàn)手寫輸入通常的做法是把手寫的結(jié)果生成圖片,進(jìn)行圖像識別。我們知道,圖片可以用矩陣表示,對于單通道的灰度圖像,假如分辨率為32X32,則可以用一個32X32的矩陣表示,矩陣中的每個元素表示圖片中該位置的像素,元素的值為0~255之間的灰度值。
而對于手寫圖片,表示方法則更加簡單,因?yàn)槭謱憟D片是只有黑白兩色的二值圖像,利用圖像處理軟件,黑色的位置寫1,白色背景寫0,將其轉(zhuǎn)成文本文件,如下圖所示:

手寫圖片轉(zhuǎn)成的文本文件

雖然這樣表示不能有效利用內(nèi)存空間(本來0/1只需占據(jù)1bit的空間,但是變成字符“0”,“1”之后需要用char類型所占的字節(jié)數(shù)),但是對于圖像到矩陣的轉(zhuǎn)換這一過程非常直觀,方便演示。
我們的目標(biāo)是:將這樣的一幅“圖像”輸入我們的系統(tǒng),我們能夠輸出“圖像”中所顯示的數(shù)字(只做數(shù)字0~9的識別)。

系統(tǒng)組成

我們的手寫識別系統(tǒng)由以下部分組成:

  • 已知標(biāo)簽的訓(xùn)練集
  • 文件輸入輸出模塊
  • kNN算法模塊
已知標(biāo)簽的訓(xùn)練集

點(diǎn)此下載:用到的數(shù)據(jù)及源代碼
其中,trainingDigits文件夾中存放的是用作訓(xùn)練集的的圖片,其中包含了1934個訓(xùn)練樣本,testDigits文件夾中存放的是用作測試集的圖片,其中包含了946個測試樣本。每個文件的文件名中含有它的標(biāo)簽。

文件輸入輸出模塊

python讀文本文件相當(dāng)簡單,為了迎合后續(xù)的kNN算法,我們不把圖像表示成32X32的矩陣形式,而是將其轉(zhuǎn)化成1X1024的向量,為此我們定義一個img2vector函數(shù):

def img2vector(filename):
    returnVect = zeros((1,1024))
    fr = open(filename)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            returnVect[0,32*i+j] = int(lineStr[j])
    return returnVect
kNN算法模塊

根據(jù)上述對kNN算法的描述,kNN算法有如下步驟:

  • 測試數(shù)據(jù)與訓(xùn)練集中的每個數(shù)據(jù)進(jìn)行比較,以這兩個數(shù)據(jù)間的歐氏距離作為測試數(shù)據(jù)和訓(xùn)練數(shù)據(jù)間的相似性度量
  • 將算出的歐式距離列表從小到大排序,取前k名所對應(yīng)的訓(xùn)練集中的數(shù)據(jù)
  • 取出這k個數(shù)據(jù)的標(biāo)簽,對數(shù)目進(jìn)行統(tǒng)計(jì),出現(xiàn)次數(shù)最多的標(biāo)簽作為算法的輸出,即分類的結(jié)果
def classify0(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]
    diffMat = tile(inX, (dataSetSize,1)) - dataSet
    sqDiffMat = diffMat**2
    sqDistances = sqDiffMat.sum(axis=1)
    distances = sqDistances**0.5
    sortedDistIndicies = distances.argsort()     
    classCount={}          
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
系統(tǒng)整體代碼
'''
kNN: k Nearest Neighbors

Input:      inX: vector to compare to existing dataset (1xN)
            dataSet: size m data set of known vectors (NxM)
            labels: data set labels (1xM vector)
            k: number of neighbors to use for comparison (should be an odd number)
            
Output:     the most popular class label
'''
from numpy import *
import operator
from os import listdir

def classify0(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]
    diffMat = tile(inX, (dataSetSize,1)) - dataSet
    sqDiffMat = diffMat**2
    sqDistances = sqDiffMat.sum(axis=1)
    distances = sqDistances**0.5
    sortedDistIndicies = distances.argsort()     
    classCount={}          
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
    
def img2vector(filename):
    returnVect = zeros((1,1024))
    fr = open(filename)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            returnVect[0,32*i+j] = int(lineStr[j])
    return returnVect

def handwritingClassTest():
    hwLabels = []
    trainingFileList = listdir('trainingDigits')           #load the training set
    m = len(trainingFileList)
    trainingMat = zeros((m,1024))
    for i in range(m):
        fileNameStr = trainingFileList[i]
        fileStr = fileNameStr.split('.')[0]     #take off .txt
        classNumStr = int(fileStr.split('_')[0])
        hwLabels.append(classNumStr)
        trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)
    testFileList = listdir('testDigits')        #iterate through the test set
    errorCount = 0.0
    mTest = len(testFileList)
    for i in range(mTest):
        fileNameStr = testFileList[i]
        fileStr = fileNameStr.split('.')[0]     #take off .txt
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
        if (classifierResult != classNumStr): errorCount += 1.0
    print("\nthe total number of errors is: %d" % errorCount)
    print("\nthe total error rate is: %f" % (errorCount/float(mTest)))

系統(tǒng)測試

測試環(huán)境
  • win10 64位
  • python3.6.2
測試步驟
    1. 打開cmd,進(jìn)入kNN.py所在的文件夾
    1. 輸入python進(jìn)入python shell
    1. 輸入from kNN import *導(dǎo)入kNN模塊中所有函數(shù)
    1. 輸入handwritingClassTest(),回車
測試結(jié)果
測試結(jié)果

從測試結(jié)果來看,1.0571%的錯誤率,準(zhǔn)確度還是蠻高的……

增加訓(xùn)練集的樣本容量能有效提高系統(tǒng)的準(zhǔn)確度,但是同時增加了運(yùn)算量,使計(jì)算耗時增加。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容