機器學(xué)習(xí)KNN算法之手寫數(shù)字?jǐn)?shù)字識別

算法簡介

手寫數(shù)字識別是KNN算法一個特別經(jīng)典的實例,其數(shù)據(jù)源獲取方式有兩種,一種是來自MNIST數(shù)據(jù)集,另一種是從UCI歐文大學(xué)機器學(xué)習(xí)存儲庫中下載,本文基于后者講解該例。
基本思想就是利用KNN算法推斷出如下圖一個32x32的二進制矩陣代表的數(shù)字是處于0-9之間哪一個數(shù)字。

在這里插入圖片描述

數(shù)據(jù)集包括兩部分,一部分是訓(xùn)練數(shù)據(jù)集,共有1934個數(shù)據(jù);另一部分是測試數(shù)據(jù)集,共有946個數(shù)據(jù)。所有數(shù)據(jù)命名格式都是統(tǒng)一的,例如數(shù)字5的第56個樣本——5_56.txt,這樣做為了方便提取出樣本的真實標(biāo)簽。
在這里插入圖片描述

數(shù)據(jù)的格式也有兩種,一種是像上圖一樣由0、1組成的文本文件;另一種則是手寫數(shù)字圖片,需要對圖片做一些處理,轉(zhuǎn)化成像上圖一樣的格式,下文皆有介紹。

算法步驟

  1. 收集數(shù)據(jù):公開數(shù)據(jù)源
  2. 分析數(shù)據(jù),構(gòu)思如何處理數(shù)據(jù)
  3. 導(dǎo)入訓(xùn)練數(shù)據(jù),轉(zhuǎn)化為結(jié)構(gòu)化的數(shù)據(jù)格式
  4. 計算距離(歐式距離)
  5. 導(dǎo)入測試數(shù)據(jù),計算模型準(zhǔn)確率
  6. 手寫數(shù)字,實際應(yīng)用模型

由于所有數(shù)據(jù)皆由0和1構(gòu)成,所以不需要數(shù)據(jù)標(biāo)準(zhǔn)化和歸一化這一步驟

算法實現(xiàn)

處理數(shù)據(jù)

在計算兩個樣本之間的距離時,每一個屬性是一一對應(yīng)的,所以這里將32x32的數(shù)字矩陣轉(zhuǎn)化成1x1024數(shù)字矩陣,方便計算樣本之間距離。

#處理文本文件
def img_deal(file):
    #創(chuàng)建一個1*1024的一維零矩陣
    the_matrix = np.zeros((1,1024))
    fb = open(file)
    for i in range(32):
        #逐行讀取
        lineStr = fb.readline()
        for j in range(32):
            #將32*32=1024個元素賦值給一維零矩陣
            the_matrix[0,32*i+j] = int(lineStr[j])
    return the_matrix
計算歐式距離

numpy有一個tile方法,可以將一個一維矩陣橫向復(fù)制若干次,縱向復(fù)制若干次,所以將一個測試數(shù)據(jù)經(jīng)過tile方法處理后再減去訓(xùn)練數(shù)據(jù),得到新矩陣后,再將該矩陣中每一條數(shù)據(jù)(橫向)平方加和并開根號后即可得到測試數(shù)據(jù)與每一條訓(xùn)練數(shù)據(jù)之間的距離。

下一步將所有距離升序排列,取到前K個,并在這個范圍里,每個數(shù)字類別的個數(shù),并返回出現(xiàn)次數(shù)較多那個數(shù)字類別的標(biāo)簽。

def classify(test_data,train_data,label,k):
    Size = train_data.shape[0]
    #將測試數(shù)據(jù)每一行復(fù)制Size次減去訓(xùn)練數(shù)據(jù),橫向復(fù)制Size次,縱向復(fù)制1次
    the_matrix = np.tile(test_data,(Size,1)) - train_data
    #將相減得到的結(jié)果平方
    sq_the_matrix = the_matrix ** 2
    #平方加和,axis = 1 代表橫向
    all_the_matrix = sq_the_matrix.sum(axis = 1)
    #結(jié)果開根號得到最終距離
    distance = all_the_matrix ** 0.5
    #將距離由小到大排序,給出結(jié)果為索引
    sort_distance = distance.argsort()
    dis_Dict = {}
    #取到前k個
    for i in range(k):
        #獲取前K個標(biāo)簽
        the_label = label[sort_distance[i]]
        #將標(biāo)簽的key和value傳入字典
        dis_Dict[the_label] = dis_Dict.get(the_label,0)+1
    #將字典按value值的大小排序,由大到小,即在K范圍內(nèi),篩選出現(xiàn)次數(shù)最多幾個標(biāo)簽
    sort_Count = sorted(dis_Dict.items(), key=operator.itemgetter(1), reverse=True)
    #返回出現(xiàn)次數(shù)最多的標(biāo)簽
    return sort_Count[0][0]
測試數(shù)據(jù)集應(yīng)用

首先要對訓(xùn)練數(shù)據(jù)集處理,listdir方法是返回一個文件夾下所有的文件,隨后生成一個行數(shù)為文件個數(shù),列數(shù)為1024的訓(xùn)練數(shù)據(jù)矩陣,并且將訓(xùn)練數(shù)據(jù)集中每條數(shù)據(jù)的真實標(biāo)簽切割提取存入至labels列表中,即計算距離classify函數(shù)中傳入的label。

labels = []
#listdir方法是返回一個文件夾中包含的文件
    train_data = listdir('trainingDigits')
    #獲取該文件夾中文件的個數(shù)
    m_train=len(train_data)
    #生成一個列數(shù)為train_matrix,行為1024的零矩陣
    train_matrix = np.zeros((m_train,1024))
    for i in range(m_train):
        file_name_str = train_data[i]
        file_str = file_name_str.split('.')[0]
        #切割出訓(xùn)練集中每個數(shù)據(jù)的真實標(biāo)簽
        file_num = int(file_str.split('_')[0])
        labels.append(file_num)
        #將所有訓(xùn)練數(shù)據(jù)集中的數(shù)據(jù)都傳入到train_matrix中
        train_matrix[i,:] = img_deal('trainingDigits/%s'%file_name_str)

然后對測試訓(xùn)練數(shù)據(jù)集做與上述一樣的處理,并將測試數(shù)據(jù)矩陣TestClassify、訓(xùn)練數(shù)據(jù)矩陣train_matrix、訓(xùn)練數(shù)據(jù)真實標(biāo)簽labelsK共4個參數(shù)傳入計算距離classify函數(shù)中,最后計算出模型準(zhǔn)確率并輸出預(yù)測錯誤的數(shù)據(jù)。

error = []
 test_matrix = listdir('testDigits')
    correct = 0.0
    m_test = len(test_matrix)
    for i in range(m_test):
        file_name_str = test_matrix[i]
        file_str = file_name_str.split('.')[0]
        #測試數(shù)據(jù)集每個數(shù)據(jù)的真實結(jié)果
        file_num = int(file_str.split('_')[0])
        TestClassify = img_deal('testDigits/%s'%file_name_str)
        classify_result = classify(TestClassify,train_matrix,labels,3)
        print('預(yù)測結(jié)果:%s\t真實結(jié)果:%s'%(classify_result,file_num))
        if classify_result == file_num:
            correct += 1.0
        else:
            error.append((file_name_str,classify_result))
    print("正確率:{:.2f}%".format(correct / float(m_test) * 100))
    print(error)
    print(len(error))

代碼運行部分截圖如下

在這里插入圖片描述

當(dāng)K = 3時,準(zhǔn)確率達到了98.94%,對于這個模型而言,準(zhǔn)確率是十分可觀的,但運行效率卻比較低,接近30秒的運行時間。因為每個測試數(shù)據(jù)都要與近2000個訓(xùn)練數(shù)據(jù)進行距離計算,而每次計算又包含1024個維度浮點運算,高次數(shù)多維度的計算是導(dǎo)致模型運行效率低的主要原因。

K值

下圖是K值與模型準(zhǔn)確率的關(guān)系變化圖,K = 3時,模型準(zhǔn)確率達到峰值,隨著K增大,準(zhǔn)確率越來越小,所以這份數(shù)據(jù)的噪聲還是比較小的。


在這里插入圖片描述
手寫數(shù)字測試

建模完成了,模型的準(zhǔn)確率也不錯,為何自己手寫的數(shù)字測試一下呢?所以偶就手動寫了幾個數(shù)字


在這里插入圖片描述

正常拍出的圖片是RGB彩色圖片,并且像素也各不相同,所以需要對圖片做兩項處理:轉(zhuǎn)化成黑白圖片、將像素轉(zhuǎn)化為32x32,這樣才符合我們上文算法的要求;對于像素點,數(shù)值一般位于0-255,255代表白、0代表黑,但因為手寫數(shù)字像素點顏色并不規(guī)范,所以我們設(shè)置一個閾值用以判斷黑白之分。
圖片轉(zhuǎn)文本代碼如下:

def pic_txt():
    for i in range(0,10):
        img = Image.open('.\handwritten\%s.png'%i)
        #將圖片像素更改為32X32
        img = img.resize((32,32))
        #將彩色圖片變?yōu)楹诎讏D片
        img = img.convert('L')
        #保存
        path = '.\handwritten\%s_new.jpg'%i
        img.save(path)
    for i in range(0,10):
        fb = open('.\hand_written\%s_handwritten.txt'%i,'w')
        new_img = Image.open('.\handwritten\%s_new.jpg'%i)
        #讀取圖片的寬和高
        width,height = new_img.size
        for i in range(height):
            for j in range(width):
                # 獲取像素點
                color = new_img.getpixel((j, i))
                #像素點較高的為圖片中的白色
                if color>170:
                    fb.write('0')
                else:
                    fb.write('1')
            fb.write('\n')
        fb.close()

整體代碼運行截圖如下:

在這里插入圖片描述

正確率為70%,畢竟測試數(shù)據(jù)很少,10個數(shù)字中4、7、8三個數(shù)字預(yù)測錯誤,還算可觀;由于光線問題,有幾個數(shù)字左下角會有一些黑影,也會對測試結(jié)果產(chǎn)生一定的影響,若避免類似情況,并且多增加一些測試數(shù)據(jù),正確率定會得到提升的。

公眾號【奶糖貓】后臺回復(fù)“手寫數(shù)字”即可獲取源碼和數(shù)據(jù)供參考,感謝閱讀。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容