An Introduction to Learning Rate Decay in TensorFlow

When training a deep learning model, it is generally recommended to gradually reduce the learning rate as the epochs progress, and experiments have shown that this has a positive effect on training convergence. The learning rate can be changed either by a hand-crafted decay schedule or by an algorithm that tunes it automatically. This article introduces two decay schedules built into TF: exponential decay and polynomial decay.

Exponential decay (tf.train.exponential_decay)

Signature:

tf.train.exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase=False, name=None)

Parameters:

learning_rate: the initial learning rate

global_step: the global step count (one step per batch)

decay_steps: the decay period in steps, i.e. how many steps between learning-rate updates

decay_rate: the exponential decay factor (the α in α^t)

staircase: whether to update the learning rate in discrete jumps, i.e. whether global_step/decay_steps is kept as a float or floored to an integer

Formula:

decayed_learning_rate = learning_rate * decay_rate^(global_step/decay_steps)
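The formula can be checked with a small pure-Python sketch (plain floats standing in for the Tensors that TF would return; the function name and values here are illustrative, not TF's implementation):

```python
import math

def exponential_decay(learning_rate, global_step, decay_steps, decay_rate,
                      staircase=False):
    # Pure-Python sketch of the exponential-decay formula above.
    p = global_step / decay_steps
    if staircase:
        p = math.floor(p)  # staircase mode: floor to an integer exponent
    return learning_rate * decay_rate ** p

# With staircase=True the rate is constant within each decay period:
exponential_decay(0.1, 5, 10, 0.96, staircase=True)   # 0.1 (floor(0.5) = 0)
exponential_decay(0.1, 10, 10, 0.96, staircase=True)  # ≈ 0.096 (0.1 * 0.96)
exponential_decay(0.1, 5, 10, 0.96)                   # smooth: 0.1 * 0.96^0.5
```

With staircase=False the rate shrinks a little every step; with staircase=True it drops only once per full decay period, which is the common choice in practice.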


Polynomial decay (tf.train.polynomial_decay)

Signature:

tf.train.polynomial_decay(learning_rate, global_step, decay_steps, end_learning_rate=0.0001, power=1.0, cycle=False, name=None)

Parameters:

learning_rate: the initial learning rate

global_step: the global step count (one step per batch)

decay_steps: the decay period in steps, i.e. how many steps between learning-rate updates

end_learning_rate: the final value the learning rate decays to

power: the polynomial decay exponent (the α in (1-t)^α)

cycle: whether t keeps cycling after global_step exceeds decay_steps

Formulas:

When cycle=False:

global_step = min(global_step, decay_steps)

decayed_learning_rate = (learning_rate - end_learning_rate) * (1 - global_step/decay_steps)^power + end_learning_rate

When cycle=True:

decay_steps = decay_steps * ceil(global_step/decay_steps)

decayed_learning_rate = (learning_rate - end_learning_rate) * (1 - global_step/decay_steps)^power + end_learning_rate

Note: ceil rounds up to the nearest integer.
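Both cases can likewise be sketched in pure Python (again an illustrative re-derivation of the formulas, not TF's code; the max(1, ...) guard is an added assumption to avoid dividing by zero at step 0 in cycle mode):

```python
import math

def polynomial_decay(learning_rate, global_step, decay_steps,
                     end_learning_rate=0.0001, power=1.0, cycle=False):
    # Pure-Python sketch of the two polynomial-decay formulas above.
    if cycle:
        # Grow decay_steps to the next multiple of itself so the
        # normalized position global_step/decay_steps stays in [0, 1].
        decay_steps = decay_steps * max(1, math.ceil(global_step / decay_steps))
    else:
        # Clamp: once past decay_steps, stay at end_learning_rate.
        global_step = min(global_step, decay_steps)
    return ((learning_rate - end_learning_rate) *
            (1 - global_step / decay_steps) ** power + end_learning_rate)

# power=1.0 gives a linear ramp from 0.1 down to 0.01 over 100 steps:
polynomial_decay(0.1, 50, 100, end_learning_rate=0.01)              # ≈ 0.055
polynomial_decay(0.1, 150, 100, end_learning_rate=0.01)             # clamped to 0.01
polynomial_decay(0.1, 150, 100, end_learning_rate=0.01, cycle=True) # ≈ 0.0325
```

In the cycle=True case at step 150, decay_steps becomes 100 * ceil(1.5) = 200, so the rate climbs back up and decays again, rather than staying flat at end_learning_rate.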


Typical code for configuring the learning rate (in the style of TF-Slim's training scripts):

```python
def _configure_learning_rate(num_samples_per_epoch, global_step):
  """Configures the learning rate.

  Args:
    num_samples_per_epoch: The number of samples in each epoch of training.
    global_step: The global_step tensor.

  Returns:
    A `Tensor` representing the learning rate.

  Raises:
    ValueError: if FLAGS.learning_rate_decay_type is not recognized.
  """
  # Convert a decay period expressed in epochs into a period in steps.
  decay_steps = int(num_samples_per_epoch / FLAGS.batch_size *
                    FLAGS.num_epochs_per_decay)
  if FLAGS.sync_replicas:
    decay_steps /= FLAGS.replicas_to_aggregate

  if FLAGS.learning_rate_decay_type == 'exponential':
    return tf.train.exponential_decay(FLAGS.learning_rate,
                                      global_step,
                                      decay_steps,
                                      FLAGS.learning_rate_decay_factor,
                                      staircase=True,
                                      name='exponential_decay_learning_rate')
  elif FLAGS.learning_rate_decay_type == 'fixed':
    return tf.constant(FLAGS.learning_rate, name='fixed_learning_rate')
  elif FLAGS.learning_rate_decay_type == 'polynomial':
    return tf.train.polynomial_decay(FLAGS.learning_rate,
                                     global_step,
                                     decay_steps,
                                     FLAGS.end_learning_rate,
                                     power=1.0,
                                     cycle=False,
                                     name='polynomial_decay_learning_rate')
  else:
    raise ValueError('learning_rate_decay_type [%s] was not recognized' %
                     FLAGS.learning_rate_decay_type)
```
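The epoch-to-step conversion at the top of the helper can be checked on its own; the numbers below are hypothetical (50000 samples per epoch, batch size 32, decay every 2 epochs), with the FLAGS.* values replaced by plain arguments:

```python
def steps_per_decay(num_samples_per_epoch, batch_size, num_epochs_per_decay):
    # One decay period = num_epochs_per_decay full passes over the data,
    # where each epoch takes num_samples_per_epoch / batch_size steps.
    return int(num_samples_per_epoch / batch_size * num_epochs_per_decay)

steps_per_decay(50000, 32, 2.0)  # 50000 / 32 * 2 = 3125 steps
```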
