TensorFlow HOWTO 1.2 LASSO、嶺和 Elastic Net

1.2 LASSO、嶺和 Elastic Net

當(dāng)參數(shù)變多的時候,就要考慮使用正則化進(jìn)行限制,防止過擬合。

操作步驟

導(dǎo)入所需的包。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import sklearn.datasets as ds
import sklearn.model_selection as ms

導(dǎo)入數(shù)據(jù),并進(jìn)行預(yù)處理。我們使用波士頓數(shù)據(jù)集所有數(shù)據(jù)的全部特征。

boston = ds.load_boston()

x_ = boston.data
y_ = np.expand_dims(boston.target, 1)

x_train, x_test, y_train, y_test = \
    ms.train_test_split(x_, y_, train_size=0.7, test_size=0.3)
    
mu_train = x_train.mean(0)
sigma_train = x_train.std(0)
x_train = (x_train - mu_train) / sigma_train
x_test = (x_test - mu_train) / sigma_train

定義超參數(shù)。

n_input = 13
n_epoch = 2000
lr = 0.05
lam = 0.1
l1_ratio = 0.5
變量 含義
n_input 樣本特征數(shù)
n_epoch 迭代數(shù)
lr 學(xué)習(xí)率
lam 正則化系數(shù)
l1_ratio L1 正則化比例。如果它是 1,模型為 LASSO 回歸;如果它是 0,模型為嶺回歸;如果在 01 之間,模型為 Elastic Net。

搭建模型。

變量 含義
x 輸入
y 真實標(biāo)簽
w 權(quán)重
b 偏置
z 輸出,也就是標(biāo)簽預(yù)測值
x = tf.placeholder(tf.float64, [None, n_input])
y = tf.placeholder(tf.float64, [None, 1])
w = tf.Variable(np.random.rand(n_input, 1))
b = tf.Variable(np.random.rand(1, 1))
z = x @ w + b

定義損失、優(yōu)化操作、和 R 方度量指標(biāo)。

我們在 MSE 基礎(chǔ)上加上兩個正則項:

\begin{matrix} L_1 = \lambda_1 \|w\|_1 \\ L_2 = \lambda_2 \|w\|^2 \\ L = L_{MSE} + L_1 + L_2 \end{matrix}

變量 含義
mse_loss MSE 損失
l1_loss L1 損失
l2_loss L2 損失
loss 總損失
op 優(yōu)化操作
y_mean y的均值
r_sqr R 方值
mse_loss = tf.reduce_mean((z - y) ** 2)
l1_loss = lam * l1_ratio * tf.reduce_sum(tf.abs(w))
l2_loss = lam * (1 - l1_ratio) * tf.reduce_sum(w ** 2)
loss = mse_loss + l1_loss + l2_loss
op = tf.train.AdamOptimizer(lr).minimize(loss)

y_mean = tf.reduce_mean(y)
r_sqr = 1 - tf.reduce_sum((y - z) ** 2) / tf.reduce_sum((y - y_mean) ** 2)

使用訓(xùn)練集訓(xùn)練模型。

losses = []
r_sqrs = []

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for e in range(n_epoch):
        _, loss_ = sess.run([op, loss], feed_dict={x: x_train, y: y_train})
        losses.append(loss_)

使用測試集計算 R 方。

        r_sqr_ = sess.run(r_sqr, feed_dict={x: x_test, y: y_test})
        r_sqrs.append(r_sqr_)

每一百步打印損失和度量值。

        if e % 100 == 0:
            print(f'epoch: {e}, loss: {loss_}, r_sqr: {r_sqr_}')

輸出:

epoch: 0, loss: 601.4143942455931, r_sqr: -5.632461200109857
epoch: 100, loss: 337.83817233312953, r_sqr: -2.8921127959091235
epoch: 200, loss: 205.95485710264686, r_sqr: -1.3905038082279204
epoch: 300, loss: 122.56157140781264, r_sqr: -0.4299323503419834
epoch: 400, loss: 73.34245865955972, r_sqr: 0.13473129501015224
epoch: 500, loss: 46.62652385307641, r_sqr: 0.4391669119513518
epoch: 600, loss: 33.418871666746185, r_sqr: 0.5880392599137905
epoch: 700, loss: 27.51559958401544, r_sqr: 0.6533498987634062
epoch: 800, loss: 25.14275351335227, r_sqr: 0.6787325098436232
epoch: 900, loss: 24.28818622078879, r_sqr: 0.6872955402664112
epoch: 1000, loss: 24.01321943982539, r_sqr: 0.689688496343003
epoch: 1100, loss: 23.93439017638524, r_sqr: 0.6901611522536858
epoch: 1200, loss: 23.914316369424643, r_sqr: 0.690163604062231
epoch: 1300, loss: 23.909792588385457, r_sqr: 0.6901031472929803
epoch: 1400, loss: 23.908894366923214, r_sqr: 0.6900616479035429
epoch: 1500, loss: 23.90873804289015, r_sqr: 0.6900411329923608
epoch: 1600, loss: 23.90871433783755, r_sqr: 0.6900324529674866
epoch: 1700, loss: 23.908711226897406, r_sqr: 0.690029151344134
epoch: 1800, loss: 23.908710876248833, r_sqr: 0.6900280037335323
epoch: 1900, loss: 23.908710842591514, r_sqr: 0.6900276378081478

繪制訓(xùn)練集上的損失。

plt.figure()
plt.plot(losses)
plt.title('Loss on Training Set')
plt.xlabel('#epoch')
plt.ylabel('MSE')
plt.show()

https://github.com/wizardforcel/how2tf/raw/master/img/1-2-1.png

繪制測試集上的 R 方。

plt.figure()
plt.plot(r_sqrs)
plt.title('$R^2$ on Testing Set')
plt.xlabel('#epoch')
plt.ylabel('$R^2$')
plt.show()

https://github.com/wizardforcel/how2tf/raw/master/img/1-2-2.png

擴(kuò)展閱讀

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容