算法簡介
手寫數(shù)字識別是KNN算法一個特別經(jīng)典的實例,其數(shù)據(jù)源獲取方式有兩種,一種是來自MNIST數(shù)據(jù)集,另一種是從UCI歐文大學(xué)機器學(xué)習(xí)存儲庫中下載,本文基于后者講解該例。
基本思想就是利用KNN算法推斷出如下圖一個32x32的二進制矩陣代表的數(shù)字是處于0-9之間哪一個數(shù)字。
數(shù)據(jù)集包括兩部分,一部分是訓(xùn)練數(shù)據(jù)集,共有1934個數(shù)據(jù);另一部分是測試數(shù)據(jù)集,共有946個數(shù)據(jù)。所有數(shù)據(jù)命名格式都是統(tǒng)一的,例如數(shù)字5的第56個樣本——5_56.txt,這樣做為了方便提取出樣本的真實標(biāo)簽。

數(shù)據(jù)的格式也有兩種,一種是像上圖一樣由0、1組成的文本文件;另一種則是手寫數(shù)字圖片,需要對圖片做一些處理,轉(zhuǎn)化成像上圖一樣的格式,下文皆有介紹。
算法步驟
- 收集數(shù)據(jù):公開數(shù)據(jù)源
- 分析數(shù)據(jù),構(gòu)思如何處理數(shù)據(jù)
- 導(dǎo)入訓(xùn)練數(shù)據(jù),轉(zhuǎn)化為結(jié)構(gòu)化的數(shù)據(jù)格式
- 計算距離(歐式距離)
- 導(dǎo)入測試數(shù)據(jù),計算模型準(zhǔn)確率
- 手寫數(shù)字,實際應(yīng)用模型
由于所有數(shù)據(jù)皆由0和1構(gòu)成,所以不需要數(shù)據(jù)標(biāo)準(zhǔn)化和歸一化這一步驟
算法實現(xiàn)
處理數(shù)據(jù)
在計算兩個樣本之間的距離時,每一個屬性是一一對應(yīng)的,所以這里將32x32的數(shù)字矩陣轉(zhuǎn)化成1x1024數(shù)字矩陣,方便計算樣本之間距離。
#處理文本文件
def img_deal(file):
#創(chuàng)建一個1*1024的一維零矩陣
the_matrix = np.zeros((1,1024))
fb = open(file)
for i in range(32):
#逐行讀取
lineStr = fb.readline()
for j in range(32):
#將32*32=1024個元素賦值給一維零矩陣
the_matrix[0,32*i+j] = int(lineStr[j])
return the_matrix
計算歐式距離
numpy有一個tile方法,可以將一個一維矩陣橫向復(fù)制若干次,縱向復(fù)制若干次,所以將一個測試數(shù)據(jù)經(jīng)過tile方法處理后再減去訓(xùn)練數(shù)據(jù),得到新矩陣后,再將該矩陣中每一條數(shù)據(jù)(橫向)平方加和并開根號后即可得到測試數(shù)據(jù)與每一條訓(xùn)練數(shù)據(jù)之間的距離。
下一步將所有距離升序排列,取到前K個,并在這個范圍里,每個數(shù)字類別的個數(shù),并返回出現(xiàn)次數(shù)較多那個數(shù)字類別的標(biāo)簽。
def classify(test_data,train_data,label,k):
Size = train_data.shape[0]
#將測試數(shù)據(jù)每一行復(fù)制Size次減去訓(xùn)練數(shù)據(jù),橫向復(fù)制Size次,縱向復(fù)制1次
the_matrix = np.tile(test_data,(Size,1)) - train_data
#將相減得到的結(jié)果平方
sq_the_matrix = the_matrix ** 2
#平方加和,axis = 1 代表橫向
all_the_matrix = sq_the_matrix.sum(axis = 1)
#結(jié)果開根號得到最終距離
distance = all_the_matrix ** 0.5
#將距離由小到大排序,給出結(jié)果為索引
sort_distance = distance.argsort()
dis_Dict = {}
#取到前k個
for i in range(k):
#獲取前K個標(biāo)簽
the_label = label[sort_distance[i]]
#將標(biāo)簽的key和value傳入字典
dis_Dict[the_label] = dis_Dict.get(the_label,0)+1
#將字典按value值的大小排序,由大到小,即在K范圍內(nèi),篩選出現(xiàn)次數(shù)最多幾個標(biāo)簽
sort_Count = sorted(dis_Dict.items(), key=operator.itemgetter(1), reverse=True)
#返回出現(xiàn)次數(shù)最多的標(biāo)簽
return sort_Count[0][0]
測試數(shù)據(jù)集應(yīng)用
首先要對訓(xùn)練數(shù)據(jù)集處理,listdir方法是返回一個文件夾下所有的文件,隨后生成一個行數(shù)為文件個數(shù),列數(shù)為1024的訓(xùn)練數(shù)據(jù)矩陣,并且將訓(xùn)練數(shù)據(jù)集中每條數(shù)據(jù)的真實標(biāo)簽切割提取存入至labels列表中,即計算距離classify函數(shù)中傳入的label。
labels = []
#listdir方法是返回一個文件夾中包含的文件
train_data = listdir('trainingDigits')
#獲取該文件夾中文件的個數(shù)
m_train=len(train_data)
#生成一個列數(shù)為train_matrix,行為1024的零矩陣
train_matrix = np.zeros((m_train,1024))
for i in range(m_train):
file_name_str = train_data[i]
file_str = file_name_str.split('.')[0]
#切割出訓(xùn)練集中每個數(shù)據(jù)的真實標(biāo)簽
file_num = int(file_str.split('_')[0])
labels.append(file_num)
#將所有訓(xùn)練數(shù)據(jù)集中的數(shù)據(jù)都傳入到train_matrix中
train_matrix[i,:] = img_deal('trainingDigits/%s'%file_name_str)
然后對測試訓(xùn)練數(shù)據(jù)集做與上述一樣的處理,并將測試數(shù)據(jù)矩陣TestClassify、訓(xùn)練數(shù)據(jù)矩陣train_matrix、訓(xùn)練數(shù)據(jù)真實標(biāo)簽labels、K共4個參數(shù)傳入計算距離classify函數(shù)中,最后計算出模型準(zhǔn)確率并輸出預(yù)測錯誤的數(shù)據(jù)。
error = []
test_matrix = listdir('testDigits')
correct = 0.0
m_test = len(test_matrix)
for i in range(m_test):
file_name_str = test_matrix[i]
file_str = file_name_str.split('.')[0]
#測試數(shù)據(jù)集每個數(shù)據(jù)的真實結(jié)果
file_num = int(file_str.split('_')[0])
TestClassify = img_deal('testDigits/%s'%file_name_str)
classify_result = classify(TestClassify,train_matrix,labels,3)
print('預(yù)測結(jié)果:%s\t真實結(jié)果:%s'%(classify_result,file_num))
if classify_result == file_num:
correct += 1.0
else:
error.append((file_name_str,classify_result))
print("正確率:{:.2f}%".format(correct / float(m_test) * 100))
print(error)
print(len(error))
代碼運行部分截圖如下
當(dāng)K = 3時,準(zhǔn)確率達到了98.94%,對于這個模型而言,準(zhǔn)確率是十分可觀的,但運行效率卻比較低,接近30秒的運行時間。因為每個測試數(shù)據(jù)都要與近2000個訓(xùn)練數(shù)據(jù)進行距離計算,而每次計算又包含1024個維度浮點運算,高次數(shù)多維度的計算是導(dǎo)致模型運行效率低的主要原因。
K值
下圖是K值與模型準(zhǔn)確率的關(guān)系變化圖,K = 3時,模型準(zhǔn)確率達到峰值,隨著K增大,準(zhǔn)確率越來越小,所以這份數(shù)據(jù)的噪聲還是比較小的。
手寫數(shù)字測試
建模完成了,模型的準(zhǔn)確率也不錯,為何自己手寫的數(shù)字測試一下呢?所以偶就手動寫了幾個數(shù)字
正常拍出的圖片是RGB彩色圖片,并且像素也各不相同,所以需要對圖片做兩項處理:轉(zhuǎn)化成黑白圖片、將像素轉(zhuǎn)化為32x32,這樣才符合我們上文算法的要求;對于像素點,數(shù)值一般位于0-255,255代表白、0代表黑,但因為手寫數(shù)字像素點顏色并不規(guī)范,所以我們設(shè)置一個閾值用以判斷黑白之分。
圖片轉(zhuǎn)文本代碼如下:
def pic_txt():
for i in range(0,10):
img = Image.open('.\handwritten\%s.png'%i)
#將圖片像素更改為32X32
img = img.resize((32,32))
#將彩色圖片變?yōu)楹诎讏D片
img = img.convert('L')
#保存
path = '.\handwritten\%s_new.jpg'%i
img.save(path)
for i in range(0,10):
fb = open('.\hand_written\%s_handwritten.txt'%i,'w')
new_img = Image.open('.\handwritten\%s_new.jpg'%i)
#讀取圖片的寬和高
width,height = new_img.size
for i in range(height):
for j in range(width):
# 獲取像素點
color = new_img.getpixel((j, i))
#像素點較高的為圖片中的白色
if color>170:
fb.write('0')
else:
fb.write('1')
fb.write('\n')
fb.close()
整體代碼運行截圖如下:
正確率為70%,畢竟測試數(shù)據(jù)很少,10個數(shù)字中4、7、8三個數(shù)字預(yù)測錯誤,還算可觀;由于光線問題,有幾個數(shù)字左下角會有一些黑影,也會對測試結(jié)果產(chǎn)生一定的影響,若避免類似情況,并且多增加一些測試數(shù)據(jù),正確率定會得到提升的。
公眾號【奶糖貓】后臺回復(fù)“手寫數(shù)字”即可獲取源碼和數(shù)據(jù)供參考,感謝閱讀。