tensorflow固定部分參數(shù)進(jìn)行訓(xùn)練

模型預(yù)覽

假設(shè)新模型分encoder+decoder兩部分。其中encoder模塊要導(dǎo)入預(yù)訓(xùn)練的參數(shù),并且數(shù)值固定,不參與訓(xùn)練。decoder則是在encoder的基礎(chǔ)上增加的分支,需要通過數(shù)據(jù)訓(xùn)練不斷優(yōu)化參數(shù)。

大體步驟

主要分為四個(gè)步驟:
1. 繪制整體網(wǎng)絡(luò)圖
2. 固定encoder參數(shù)
3. 導(dǎo)入encoder參數(shù)
4. 訓(xùn)練 + 模型保存

代碼

part1:畫圖

#設(shè)置網(wǎng)絡(luò)整體結(jié)構(gòu)....

part2:固定參數(shù)

# 選擇decode部分的參數(shù)
train_var_list = [var for var in tf.trainable_variables() if 'decode' in var.name] 
# 優(yōu)化器只優(yōu)化選中的參數(shù)list
with tf.control_dependencies():
      optimizer = optimizer.minimize(loss, global_step=global_step, var_list = train_var_list) #自行選擇優(yōu)化器

part3 導(dǎo)入舊參

# 選擇encode部分參數(shù)
no_train_var = [var for var in tf.global_variables() if 'encode' in var.name]  #這里的'encode'是在設(shè)置網(wǎng)絡(luò)過程中某個(gè)scope的命名
# saver選擇要導(dǎo)入的參數(shù)
saver = tf.train.Saver(no_train_var)
# 對(duì)整個(gè)網(wǎng)絡(luò)所有參數(shù)做初始化
init = tf.global_variables_initializer()
sess.run(init)
# encode部分參數(shù)覆蓋
saver.restore(sess, weights_path) #這里的weights_path是ckpt文件保存路徑

part4 訓(xùn)練+保存

# 訓(xùn)練......
# 保存模型
# 重新定義saver為選中所有參數(shù),否則最后將只保存no_train_var
saver = tf.train.Saver()
saver.save(sess=sess, save_path=model_save_path, global_step=epoch) 

其他

  1. 對(duì)于該網(wǎng)絡(luò)還有另外一種方法:encode前向傳播保存結(jié)果,將其作為decode網(wǎng)絡(luò)輸入,進(jìn)行訓(xùn)練。
  2. 模型導(dǎo)入還有其他方法,可參考https://blog.csdn.net/CV_YOU/article/details/80698942。
    不同類型的模型(npy, ckpt)導(dǎo)入保存方式有差異。
  3. 固定參數(shù)還可以在構(gòu)建網(wǎng)絡(luò)的時(shí)候選擇變量的trainable為False,或者設(shè)置變量學(xué)習(xí)率為0.

參數(shù)導(dǎo)入方法2

當(dāng)預(yù)訓(xùn)練模型和新模型的圖不同時(shí),無法用Saver導(dǎo)入?yún)?shù),這時(shí)候要用到tf.assign函數(shù)。
假設(shè)預(yù)訓(xùn)練模型只有encode部分,新模型encode+decode。遍歷模型參數(shù),用預(yù)訓(xùn)練參數(shù)進(jìn)行替換。

代碼

part3 導(dǎo)入舊參

# 導(dǎo)入所有參數(shù)
saver = tf.train.Saver()
# 對(duì)整個(gè)網(wǎng)絡(luò)所有參數(shù)做初始化
init = tf.global_variables_initializer()
sess.run(init)
#讀取預(yù)訓(xùn)練模型
reader = pywrap_tensorflow.NewCheckpointReader(weights_path)
# 逐層遍歷參數(shù)并替換
for vv in tf.trainable_variables():
    if 'encode' in vv.name:
        weights = reader.get_tensor(weights_key)
        _op = tf.assign(vv, weights)
        sess.run(_op)
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

友情鏈接更多精彩內(nèi)容