Regression and gradient descent: source code and analysis

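References:
1. The first 5 chapters of the 2014 Stanford machine learning video course. Excellent; studying for the graduate entrance exam really did improve my math a lot.
2. Learned two libraries, Numpy and matplotlib.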

Further down there is also a method called the normal equation, which solves for the parameters with a matrix inverse.

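The printed results come first; the source code is at the end.
The first parameter (the intercept) is noticeably off, while the other two are close. A likely cause is that, with unscaled features and a step of 0.0001, the intercept converges much more slowly than the two slopes, so 10000 iterations are not enough for it.
86% of the predictions fall within ±0.5 of the true value (the features range over 0~100). The long list at the end compares the two sets of results, and they agree reasonably well.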

The parameters (model) actually used to generate the data are: 1, 2, -3
The parameters (model) computed by gradient descent are: [0.13278973], [2.00740646], [-2.99257226]
Accuracy within 1% of the overall range: 86%
Predicted results compared with the expected results:
[[-1.35539348e+02 -1.36000000e+02]
 [-2.80146762e+02 -2.80000000e+02]
 [-2.91997353e+01 -2.90000000e+01]
 [-7.64876940e+01 -7.70000000e+01]
 [-1.71280696e+02 -1.71000000e+02]
 [-6.47769289e+01 -6.50000000e+01]
 [-7.91253728e+01 -7.90000000e+01]
 [ 1.56888486e+01  1.60000000e+01]
 [ 1.79814205e+02  1.80000000e+02]
 [-1.57303000e+02 -1.57000000e+02]
 [-1.39310492e+02 -1.39000000e+02]
 [ 7.08188028e+00  7.00000000e+00]
 [ 4.03179937e+01  4.10000000e+01]
 [-3.85482006e+01 -3.80000000e+01]
 [-1.84850568e+02 -1.85000000e+02]
 [-2.22664999e+01 -2.20000000e+01]
 [-4.71551689e+01 -4.70000000e+01]
 [-1.54199203e+02 -1.54000000e+02]
 [-7.45915121e+01 -7.50000000e+01]
 [ 8.49629088e+01  8.50000000e+01]
 [-6.19697308e+01 -6.20000000e+01]
 [-1.34850781e+02 -1.35000000e+02]
 [-3.87547946e+01 -3.90000000e+01]
 [ 6.13378932e+00  6.00000000e+00]
 [ 3.33476834e+01  3.40000000e+01]
 [-2.31243365e+02 -2.31000000e+02]
 [ 9.41259998e+01  9.40000000e+01]
 [-1.98902413e+02 -1.99000000e+02]
 [-2.25630986e+01 -2.20000000e+01]
 [ 1.24629065e+02  1.25000000e+02]
 [-1.58362316e+02 -1.58000000e+02]
 [-8.12069290e+01 -8.10000000e+01]
 [ 7.44951003e-02  0.00000000e+00]
 [ 1.42844022e+02  1.43000000e+02]
 [-2.12880109e+02 -2.13000000e+02]
 [-1.54665321e+02 -1.55000000e+02]
 [-1.44310470e+02 -1.44000000e+02]
 [-2.04961707e+02 -2.05000000e+02]
 [ 1.20577177e+02  1.21000000e+02]
 [-7.52588593e+01 -7.50000000e+01]
 [-3.45704625e+01 -3.40000000e+01]
 [-2.82080019e+02 -2.82000000e+02]
 [-1.83902477e+02 -1.84000000e+02]
 [-2.15295338e+02 -2.15000000e+02]
 [-1.25369871e+02 -1.25000000e+02]
 [-7.96444207e+01 -7.90000000e+01]
 [-1.62080530e+02 -1.62000000e+02]
 [-5.72292761e+01 -5.70000000e+01]
 [-1.96446761e+01 -1.90000000e+01]
 [ 1.22955332e+02  1.23000000e+02]
 [ 1.28210999e+00  1.00000000e+00]
 [-1.51932273e+02 -1.52000000e+02]
 [-1.25873064e+02 -1.26000000e+02]
 [ 6.40149030e+01  6.40000000e+01]
 [ 2.63773731e+01  2.70000000e+01]
 [ 1.06525332e+02  1.07000000e+02]
 [-2.43250729e+02 -2.43000000e+02]
 [-7.73404155e+01 -7.70000000e+01]
 [-5.07992334e+01 -5.10000000e+01]
 [ 8.48146094e+01  8.50000000e+01]
 [-2.35909671e+02 -2.36000000e+02]
 [-1.77006316e+02 -1.77000000e+02]
 [-1.34050989e+02 -1.34000000e+02]
 [-1.15295764e+02 -1.15000000e+02]
 [-8.86804321e+01 -8.90000000e+01]
 [-2.41256069e+01 -2.40000000e+01]
 [-1.04088192e+02 -1.04000000e+02]
 [ 3.57026016e+00  4.00000000e+00]
 [ 1.54925536e+02  1.55000000e+02]
 [-1.16132630e+02 -1.16000000e+02]
 [ 8.19703365e+01  8.20000000e+01]
 [-3.20810831e+01 -3.20000000e+01]
 [-1.04310641e+02 -1.04000000e+02]
 [-1.25110347e+02 -1.25000000e+02]
 [-1.91783803e+02 -1.92000000e+02]
 [ 3.43328492e+01  3.50000000e+01]
 [-1.48139909e+02 -1.48000000e+02]
 [ 1.34777322e+02  1.35000000e+02]
 [ 6.83846093e+01  6.90000000e+01]
 [ 4.40891378e+01  4.40000000e+01]
 [-9.61697903e+01 -9.60000000e+01]
 [ 5.88666248e+01  5.90000000e+01]
 [-6.44220355e+01 -6.40000000e+01]
 [ 9.56885081e+01  9.60000000e+01]
 [-2.82080019e+02 -2.82000000e+02]
 [ 4.37183893e+01  4.40000000e+01]
 [-2.62071631e+01 -2.60000000e+01]
 [-7.36434211e+01 -7.40000000e+01]
 [-1.09477030e+01 -1.10000000e+01]
 [-7.94961213e+01 -7.90000000e+01]
 [ 8.20815611e+01  8.20000000e+01]
 [ 2.02068543e+01  2.10000000e+01]
 [ 3.28951644e+00  3.00000000e+00]
 [ 5.48147371e+01  5.50000000e+01]
 [ 2.74737634e+01  2.80000000e+01]
 [ 1.48977466e+02  1.49000000e+02]
 [-1.38510700e+02 -1.38000000e+02]
 [ 5.00001326e+01  5.00000000e+01]
 [-2.19243529e+00 -2.00000000e+00]
 [-8.08361805e+01 -8.10000000e+01]]

As the plot shows, once the number of iterations is large enough the cost stays essentially unchanged.


[Figure: cost versus number of iterations]
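Since the cost flattens out, the fixed iteration count could be replaced by a stopping rule. The following is only a sketch under assumptions: the function name `train_until_converged` and the parameters `tol` and `max_iter` are hypothetical and not part of the original code, and the data is regenerated here so the block runs on its own.

```python
import numpy as np

def train_until_converged(X, y, step=0.0001, tol=1e-2, max_iter=10000):
    """Gradient descent that stops once the cost changes by less than tol."""
    m, n = X.shape
    theta = np.zeros((n, 1))
    cost_list = []
    for _ in range(max_iter):
        theta = theta - step * (X.T @ (X @ theta - y)) / m
        cost = float(np.sum((X @ theta - y) ** 2))
        # Stop when the improvement over the previous iteration is negligible
        if cost_list and abs(cost_list[-1] - cost) < tol:
            cost_list.append(cost)
            break
        cost_list.append(cost)
    return theta, cost_list

# Synthetic data with the same model as the article: price = 1 + 2*area - 3*age
rng = np.random.default_rng(0)
X = np.hstack([np.ones((200, 1)), rng.integers(0, 100, size=(200, 2))])
y = X @ np.array([[1.0], [2.0], [-3.0]])

theta, cost_list = train_until_converged(X, y)
print(len(cost_list), cost_list[-1], theta.ravel())
```

A flat cost curve does not guarantee that every parameter has converged: with a loose tolerance the loop may stop while the intercept is still far from its true value, so `tol` is a trade-off rather than a safe default.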

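Rough walkthrough of the code:
The main function has three parts: generating random data, training the model, and testing the model. Training is the important part. The parameters are computed iteratively with the formula in the upper left of the figure below, and the key piece of that formula is its right-hand side, the partial derivative. It is not hard to compute; the code simply does it in vectorized form. See the code further down for details.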


[Figure: image.png, with the gradient descent update formula in its upper left]
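To see what "vectorized" means here, the partial derivatives can be computed once with explicit loops and once with a single matrix product, and the two compared. This is a minimal sketch assuming the usual squared-error cost; the names `gradient_loop`, `gradient_vectorized`, `X` and `y` are illustrative and do not appear in the original code.

```python
import numpy as np

def gradient_loop(X, y, theta):
    """Partial derivatives of the squared-error cost, one feature at a time."""
    m, n = X.shape
    grad = np.zeros((n, 1))
    for j in range(n):
        for i in range(m):
            # Prediction minus target for sample i
            error = X[i] @ theta[:, 0] - y[i, 0]
            grad[j, 0] += error * X[i, j]
    return grad / m

def gradient_vectorized(X, y, theta):
    """The same quantity as one matrix product, as in the training loop."""
    return X.T @ (X @ theta - y) / X.shape[0]

rng = np.random.default_rng(0)
X = np.hstack([np.ones((20, 1)), rng.integers(0, 100, size=(20, 2))])
y = X @ np.array([[1.0], [2.0], [-3.0]])
theta = np.zeros((3, 1))

grad_loop = gradient_loop(X, y, theta)
grad_vec = gradient_vectorized(X, y, theta)
print(np.allclose(grad_loop, grad_vec))
```

The vectorized form is what `dot(input.T, predict_result(input,theta) - output)/numOfSamples` computes in the full code.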

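Summary:
Accuracy is affected by many factors, including the number of samples (numOfSamples), the step size of each parameter update (step; a step that is too large usually makes the cost grow very large), the number of iterations (i), and the starting point (all set to 0 here).
Note: the code generates new random data on every run, so to check whether a particular variable affects accuracy, change it so that the data is the same every time.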
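The effect of the step size can be checked directly. The sketch below is an illustration under assumptions: the helper `cost_after` and the value 0.001 for the "too large" step are chosen here and are not from the original code.

```python
import numpy as np

def cost_after(X, y, step, iterations=50):
    """Sum of squared errors after a fixed number of gradient descent updates."""
    m, n = X.shape
    theta = np.zeros((n, 1))
    for _ in range(iterations):
        theta = theta - step * (X.T @ (X @ theta - y)) / m
    return float(np.sum((X @ theta - y) ** 2))

# Synthetic data with the same model as the article: price = 1 + 2*area - 3*age
rng = np.random.default_rng(0)
X = np.hstack([np.ones((500, 1)), rng.integers(0, 100, size=(500, 2))])
y = X @ np.array([[1.0], [2.0], [-3.0]])

initial_cost = float(np.sum(y ** 2))              # cost at theta = 0
small_step_cost = cost_after(X, y, step=0.0001)   # the step used in the article
large_step_cost = cost_after(X, y, step=0.001)    # ten times larger
print(initial_cost, small_step_cost, large_step_cost)
```

With features in the range 0~100, the smaller step lowers the cost, while the ten times larger one makes each update overshoot and the cost grows instead of shrinking.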
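One way to make the data identical on every run is to seed the random number generator. This is a sketch only: `generate_data_seeded` and its `seed` parameter are hypothetical additions that mirror the layout of `generate_data`, not the original function.

```python
import numpy as np

def generate_data_seeded(numOfSamples, seed=None):
    """Like generate_data, but reproducible when a seed is given."""
    rng = np.random.RandomState(seed)
    area = rng.randint(0, 100, size=numOfSamples)
    age = rng.randint(0, 100, size=numOfSamples)
    price = 1 + 2 * area - 3 * age
    # Leading column of ones for the intercept, then the two features
    X = np.column_stack([np.ones(numOfSamples), area, age])
    return X, price.reshape((-1, 1))

X1, y1 = generate_data_seeded(500, seed=42)
X2, y2 = generate_data_seeded(500, seed=42)
print(np.array_equal(X1, X2), np.array_equal(y1, y2))
```

With the data fixed, a change in accuracy can be attributed to the variable being varied (for example step or the iteration count) rather than to a different random sample.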

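To improve:
1. The features here all have roughly the same dynamic range. The videos explain that when feature ranges differ greatly, the features need to be normalized.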
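A possible normalization step is sketched below, assuming standardization to zero mean and unit variance (mean normalization as described in the course would work similarly). The function `standardize_features` and the `income` feature are hypothetical and only serve to show two features with very different ranges.

```python
import numpy as np

def standardize_features(X):
    """Scale every column except the leading column of ones to zero mean, unit variance."""
    X_scaled = X.astype(float).copy()
    mean = X_scaled[:, 1:].mean(axis=0)
    std = X_scaled[:, 1:].std(axis=0)
    X_scaled[:, 1:] = (X_scaled[:, 1:] - mean) / std
    return X_scaled, mean, std

rng = np.random.default_rng(0)
area = rng.integers(0, 100, size=500)        # range roughly 0 to 100
income = rng.integers(0, 100000, size=500)   # hypothetical feature with a much larger range
X = np.column_stack([np.ones(500), area, income])

X_scaled, mean, std = standardize_features(X)
print(X_scaled[:, 1:].mean(axis=0), X_scaled[:, 1:].std(axis=0))
```

The same `mean` and `std` would have to be applied to the test data before predicting, and the learned theta then refers to the scaled features rather than the raw ones.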

Full code:

from numpy import *
import matplotlib.pyplot as plt


theta_pre_setting = [1,2,-3]

# Train a multivariate linear regression model with gradient descent
# Arguments: input is an m*n array (m samples, n features); output is the m*1 array of targets
# Returns the linear coefficients theta
def train_linear_regression_model(input,output):
    numOfSamples = input.shape[0]
    numOfFeatures = input.shape[1]
    # Start from all zeros, one float coefficient per feature
    theta = zeros((numOfFeatures,1))

    # Partial derivatives
    partial_diff = zeros(numOfFeatures)

    # Step size of each update
    step = 0.0001

    # Record how the cost changes during training
    cost_list = []

    i = 0
    while True:
        # Update theta
        partial_diff = dot(input.T,predict_result(input,theta) - output)/numOfSamples
        theta = theta - step * partial_diff

        # Compute the new cost (sum of squared errors)
        cost = sum(power((predict_result(input,theta) - output),2))
        # print(theta)
        # print(cost)

        cost_list.append(cost)

        # if cost < 1:
        #     break

        i += 1
        if i > 10000:
            break

    plt.plot(cost_list)
    plt.show()
    return theta


def predict_result(input,theta):
    return dot(input,theta)


def generate_data(numOfSamples):
    area = random.randint(0, 100, size=[numOfSamples,])
    age = random.randint(0, 100, size=[numOfSamples,])
    price = calc_price(area, age)

    # reshape((-1,1)) turns each 1-D array into a column vector
    input = concatenate([area.reshape((-1, 1)), age.reshape((-1, 1))], axis=1)

    # For convenience, prepend an extra feature that is always 1 (the intercept term)
    temp = ones(numOfSamples)
    input = concatenate([temp.reshape((-1, 1)), input], axis=1)
    output = price.reshape((-1, 1))

    return input,output


# Test
def test_data(theta):
    input,output = generate_data(100)
    ret = predict_result(input,theta)

    # With exactly 100 test samples, the number of hits equals the percentage
    print('Accuracy within 1% of the overall range: {a}%'.format(a=len(where(abs(output - ret)<=0.5)[0])))

    # Put the predictions next to the values that should have been produced
    print('Predicted results compared with the expected results:')
    ret = concatenate([ret.reshape((-1, 1)), output.reshape((-1, 1))], axis=1)
    print(ret)


def calc_price(area,age,theta = theta_pre_setting):
    return theta[0]+theta[1]*area+theta[2]*age


def main():
    print('The parameters (model) actually used to generate the data are: {theta0}, {theta1}, {theta2}'.format(theta0=theta_pre_setting[0],theta1=theta_pre_setting[1],theta2=theta_pre_setting[2]))

    # Generate simulated data
    numOfSamples = 500
    input,output = generate_data(numOfSamples)

    # Train the model to get the coefficients theta
    theta = train_linear_regression_model(input,output)
    print('The parameters (model) computed by gradient descent are: {theta0}, {theta1}, {theta2}'.format(theta0=theta[0],theta1=theta[1],theta2=theta[2]))
    test_data(theta)


if __name__ == '__main__':
    main()

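  • normal equation
import numpy as np

# Solve by setting the derivative to 0
def normal_equation(input,output):
    # Least-squares solve; lstsq returns (solution, residuals, rank, singular values), so keep only the solution
    return np.linalg.lstsq(input,output,rcond=None)[0]
    # Explicit form: np.dot(np.linalg.inv(np.dot(input.T,input)),np.dot(input.T,output))

def main():
    input,output = generate_data(100)
    theta = normal_equation(input,output)
    print(theta)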
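For comparison, the normal equation can also be written out explicitly. The sketch below assumes the standard closed form theta = (X^T X)^(-1) X^T y and solves the linear system rather than forming the inverse; `normal_equation_explicit` is an illustrative name, and the data is regenerated here so the block runs on its own.

```python
import numpy as np

def normal_equation_explicit(X, y):
    """Solve (X^T X) theta = X^T y without forming the inverse explicitly."""
    return np.linalg.solve(X.T @ X, X.T @ y)

# Synthetic data with the same model as the article: price = 1 + 2*area - 3*age
rng = np.random.default_rng(0)
X = np.column_stack([
    np.ones(100),
    rng.integers(0, 100, size=100),
    rng.integers(0, 100, size=100),
])
y = X @ np.array([[1.0], [2.0], [-3.0]])

theta_explicit = normal_equation_explicit(X, y)
theta_lstsq = np.linalg.lstsq(X, y, rcond=None)[0]
print(theta_explicit.ravel(), theta_lstsq.ravel())
```

Because the simulated data contains no noise, both approaches should recover the preset parameters, including the intercept, with no step size or iteration count to tune.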
