NN_5種參數(shù)優(yōu)化器

參數(shù)定義.jpg

一階動(dòng)量定義梯度下降的方向,二階動(dòng)量定義下降的步長(zhǎng)。

1、SGD(不含動(dòng)量常用的梯度下降法)

m = g(梯度), v = 1(常量)

參數(shù)更新
w1.assign_sub(lr * grad[0])
b1.assign_sub(lr * grad[1])
2、SGDM(在SGD的基礎(chǔ)上增加了一階動(dòng)量)

m = βm + (1-β)g(梯度方向上的指數(shù)滑動(dòng)平均值),v = 1(常量)

w和b初始一階動(dòng)量均為0
m_w, m_b = 0, 0
β超參數(shù),經(jīng)驗(yàn)值是0.9
beta = 0.9
一階動(dòng)量計(jì)算公式
m_w = beta * m_w + (1 - beta) * grads[0]
m_b = beta * m_b + (1 - beta) * grads[1]
參數(shù)更新
w1.assign_sub(lr * m_w)
b1.assign_sub(lr * m_b)
3、Adagrad (在SGD的基礎(chǔ)上增加了二階動(dòng)量)

m = g(梯度),v = Σg2(梯度平方的累計(jì)和)

設(shè)二階動(dòng)量初始值為0
v_w, v_b = 0, 0
計(jì)算二階動(dòng)量梯度平方的累計(jì)和
v_w += tf.square(grads[0])
v_b += tf.square(grads[1])
參數(shù)更新
w1.assign_sub(lr * grads[0] / tf.sqrt(v_w))
b1.assign_sub(lr * grads[1] / tf.sqrt(v_b))
4、RMSProp(在SGD的基礎(chǔ)上增加了二階動(dòng)量)

m = g(梯度),v = βv +(1-β)g2(各時(shí)刻梯度方向的指數(shù)滑動(dòng)平均)

設(shè)二階動(dòng)量初始值為0
v_w, v_b = 0, 0
β超參數(shù),經(jīng)驗(yàn)值是0.9
beta = 0.9
計(jì)算指數(shù)滑動(dòng)平均
v_w = beta * v_w + (1 - beta) * tf.square(grads[0])
v_b = beta * v_b + (1 - beta) * tf.square(grads[1])
參數(shù)更新
w1.assign_sub(lr * grads[0] / tf.sqrt(v_w))
b1.assign_sub(lr * grads[1] / tf.sqrt(v_b))
5、Adam(同時(shí)結(jié)合了SGDM的一階動(dòng)量和RMSProp二階動(dòng)量,并增加了兩個(gè)修正項(xiàng),把修正后的一階動(dòng)量和二階動(dòng)量帶入?yún)?shù)更新公式)

m = βm + (1-β)g,v = βv +(1-β)g2
m(修正項(xiàng))=m/1-βt,v(修正項(xiàng))=m/1-βt

初始化參數(shù)
m_w, m_b = 0, 0
v_w, v_b = 0, 0
beta1, beta2 = 0.9, 0.999
delta_w, delta_b = 0, 0
更新的總batch數(shù)
global_step = 0
計(jì)算一階動(dòng)量
m_w = beta1 * m_w + (1 - beta1) * grads[0]
m_b = beta1 * m_b + (1 - beta1) * grads[1]
計(jì)算二階動(dòng)量
v_w = beta2 * v_w + (1 - beta2) * tf.square(grads[0])
v_b = beta2 * v_b + (1 - beta2) * tf.square(grads[1])
計(jì)算修正項(xiàng)
m_w_correction = m_w / (1 - tf.pow(beta1, int(global_step)))
m_b_correction = m_b / (1 - tf.pow(beta1, int(global_step)))
v_w_correction = v_w / (1 - tf.pow(beta2, int(global_step)))
v_b_correction = v_b / (1 - tf.pow(beta2, int(global_step)))
將修正項(xiàng)代入公式,參數(shù)更新
w1.assign_sub(lr * m_w_correction / tf.sqrt(v_w_correction))
b1.assign_sub(lr * m_b_correction / tf.sqrt(v_b_correction))
結(jié)果對(duì)比
loss及acc曲線.jpg

訓(xùn)練耗時(shí).png
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

友情鏈接更多精彩內(nèi)容