《機器學習》-感知機源碼解析

import numpy as np
import time
def loadData(fileName):
    '''
    加載Mnist數(shù)據(jù)集
    :param fileName:要加載的數(shù)據(jù)集路徑
    :return: list形式的數(shù)據(jù)集及標記
    '''
    print('start to read data')
    # 存放數(shù)據(jù)及標記的list
    dataArr = []; labelArr = []
    # 打開文件
    fr = open(fileName, 'r')
    # 將文件按行讀取
    for line in fr.readlines():
        # 對每一行數(shù)據(jù)按切割福','進行切割,返回字段列表
        curLine = line.strip().split(',')
        # Mnsit有0-9是個標記,由于是二分類任務,所以將>=5的作為1,<5為-1
        if int(curLine[0]) >= 5:
            labelArr.append(1)
        else:
            labelArr.append(-1)
        #存放標記
        #[int(num) for num in curLine[1:]] -> 遍歷每一行中除了以第一哥元素(標記)外將所有元素轉(zhuǎn)換成int類型
        #[int(num)/255 for num in curLine[1:]] -> 將所有數(shù)據(jù)除255歸一化(非必須步驟,可以不歸一化)
        dataArr.append([int(num)/255 for num in curLine[1:]])
    #返回data和label
    return dataArr, labelArr
def perceptron(dataArr, labelArr, iter=50):
    '''
    感知器訓練過程
    :param dataArr:訓練集的數(shù)據(jù) (list)
    :param labelArr: 訓練集的標簽(list)
    :param iter: 迭代次數(shù),默認50
    :return: 訓練好的w和b
    '''
    print('start to trans')
    #將數(shù)據(jù)轉(zhuǎn)換成矩陣形式(在機器學習中因為通常都是向量的運算,轉(zhuǎn)換稱矩陣形式方便運算)
    #轉(zhuǎn)換后的數(shù)據(jù)中每一個樣本的向量都是橫向的
    dataMat = np.mat(dataArr)
    #將標簽轉(zhuǎn)換成矩陣,之后轉(zhuǎn)置(.T為轉(zhuǎn)置)。
    #轉(zhuǎn)置是因為在運算中需要單獨取label中的某一個元素,如果是1xN的矩陣的話,無法用label[i]的方式讀取
    #對于只有1xN的label可以不轉(zhuǎn)換成矩陣,直接label[i]即可,這里轉(zhuǎn)換是為了格式上的統(tǒng)一
    labelMat = np.mat(labelArr).T
    #獲取數(shù)據(jù)矩陣的大小,為m*n
    m, n = np.shape(dataMat)
    #創(chuàng)建初始權(quán)重w,初始值全為0。
    #np.shape(dataMat)的返回值為m,n -> np.shape(dataMat)[1])的值即為n,與
    #樣本長度保持一致
    w = np.zeros((1, np.shape(dataMat)[1]))
    #初始化偏置b為0
    b = 0
    #初始化步長,也就是梯度下降過程中的n,控制梯度下降速率
    h = 0.0001
    #進行iter次迭代計算
    for k in range(iter):
        #對于每一個樣本進行梯度下降
        #李航書中在2.3.1開頭部分使用的梯度下降,是全部樣本都算一遍以后,統(tǒng)一
        #進行一次梯度下降
        #在2.3.1的后半部分可以看到(例如公式2.6 2.7),求和符號沒有了,此時用
        #的是隨機梯度下降,即計算一個樣本就針對該樣本進行一次梯度下降。
        #兩者的差異各有千秋,但較為常用的是隨機梯度下降。
        for i in range(m):
            #獲取當前樣本的向量
            xi = dataMat[i]
            #獲取當前樣本所對應的標簽
            yi = labelMat[i]
            #判斷是否是誤分類樣本
            #誤分類樣本特診為: -yi(w*xi+b)>=0,詳細可參考書中2.2.2小節(jié)
            #在書的公式中寫的是>0,實際上如果=0,說明改點在超平面上,也是不正確的
            if -1 * yi * (w * xi.T + b) >= 0:
                #對于誤分類樣本,進行梯度下降,更新w和b
                w = w + h *  yi * xi
                b = b + h * yi
        #打印訓練進度
        print('Round %d:%d training' % (k, iter))
    #返回訓練完的w、b
    return w, b
def test(dataArr, labelArr, w, b):
    '''
    測試準確率
    :param dataArr:測試集
    :param labelArr: 測試集標簽
    :param w: 訓練獲得的權(quán)重w
    :param b: 訓練獲得的偏置b
    :return: 正確率
    '''
    print('start to test')
    #將數(shù)據(jù)集轉(zhuǎn)換為矩陣形式方便運算
    dataMat = np.mat(dataArr)
    #將label轉(zhuǎn)換為矩陣并轉(zhuǎn)置,詳細信息參考上文perceptron中
    #對于這部分的解說
    labelMat = np.mat(labelArr).T
    #獲取測試數(shù)據(jù)集矩陣的大小
    m, n = np.shape(dataMat)
    #錯誤樣本數(shù)計數(shù)
    errorCnt = 0
    #遍歷所有測試樣本
    for i in range(m):
        #獲得單個樣本向量
        xi = dataMat[i]
        #獲得該樣本標記
        yi = labelMat[i]
        #獲得運算結(jié)果
        result = -1 * yi * (w * xi.T + b)
        #如果-yi(w*xi+b)>=0,說明該樣本被誤分類,錯誤樣本數(shù)加一
        if result >= 0: errorCnt += 1
    #正確率 = 1 - (樣本分類錯誤數(shù) / 樣本總數(shù))
    accruRate = 1 - (errorCnt / m)
    #返回正確率
    return accruRate
if __name__ == '__main__':
    #獲取當前時間
    #在文末同樣獲取當前時間,兩時間差即為程序運行時間
    start = time.time()
    #獲取訓練集及標簽
    train_path= r"F:\機器學習入門\統(tǒng)計學習\Statistical-Learning-Method_Code-master\Mnist\mnist_train.csv"


    test_path = r"F:\機器學習入門\統(tǒng)計學習\Statistical-Learning-Method_Code-master\Mnist\mnist_test\mnist_test.csv"
    trainData, trainLabel = loadData(train_path)
    #獲取測試集及標簽
    testData, testLabel = loadData(test_path)
    #訓練獲得權(quán)重
    w, b = perceptron(trainData, trainLabel, iter = 30)
    #進行測試,獲得正確率
    accruRate = test(testData, testLabel, w, b)
    #獲取當前時間,作為結(jié)束時間
    end = time.time()
    #顯示正確率
    print('accuracy rate is:', accruRate)
    #顯示用時時長
    print('time span:', end - start)

原理解析的非常清楚
[https://www.pkudodo.com/2018/11/18/1-4/]

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容