模型預(yù)覽
假設(shè)新模型分encoder+decoder兩部分。其中encoder模塊要導(dǎo)入預(yù)訓(xùn)練的參數(shù),并且數(shù)值固定,不參與訓(xùn)練。decoder則是在encoder的基礎(chǔ)上增加的分支,需要通過數(shù)據(jù)訓(xùn)練不斷優(yōu)化參數(shù)。
大體步驟
主要分為四個(gè)步驟:
1. 繪制整體網(wǎng)絡(luò)圖
2. 固定encoder參數(shù)
3. 導(dǎo)入encoder參數(shù)
4. 訓(xùn)練 + 模型保存
代碼
part1:畫圖
#設(shè)置網(wǎng)絡(luò)整體結(jié)構(gòu)....
part2:固定參數(shù)
# 選擇decode部分的參數(shù)
train_var_list = [var for var in tf.trainable_variables() if 'decode' in var.name]
# 優(yōu)化器只優(yōu)化選中的參數(shù)list
with tf.control_dependencies():
optimizer = optimizer.minimize(loss, global_step=global_step, var_list = train_var_list) #自行選擇優(yōu)化器
part3 導(dǎo)入舊參
# 選擇encode部分參數(shù)
no_train_var = [var for var in tf.global_variables() if 'encode' in var.name] #這里的'encode'是在設(shè)置網(wǎng)絡(luò)過程中某個(gè)scope的命名
# saver選擇要導(dǎo)入的參數(shù)
saver = tf.train.Saver(no_train_var)
# 對(duì)整個(gè)網(wǎng)絡(luò)所有參數(shù)做初始化
init = tf.global_variables_initializer()
sess.run(init)
# encode部分參數(shù)覆蓋
saver.restore(sess, weights_path) #這里的weights_path是ckpt文件保存路徑
part4 訓(xùn)練+保存
# 訓(xùn)練......
# 保存模型
# 重新定義saver為選中所有參數(shù),否則最后將只保存no_train_var
saver = tf.train.Saver()
saver.save(sess=sess, save_path=model_save_path, global_step=epoch)
其他
- 對(duì)于該網(wǎng)絡(luò)還有另外一種方法:encode前向傳播保存結(jié)果,將其作為decode網(wǎng)絡(luò)輸入,進(jìn)行訓(xùn)練。
- 模型導(dǎo)入還有其他方法,可參考https://blog.csdn.net/CV_YOU/article/details/80698942。
不同類型的模型(npy, ckpt)導(dǎo)入保存方式有差異。 - 固定參數(shù)還可以在構(gòu)建網(wǎng)絡(luò)的時(shí)候選擇變量的trainable為False,或者設(shè)置變量學(xué)習(xí)率為0.
參數(shù)導(dǎo)入方法2
當(dāng)預(yù)訓(xùn)練模型和新模型的圖不同時(shí),無法用Saver導(dǎo)入?yún)?shù),這時(shí)候要用到tf.assign函數(shù)。
假設(shè)預(yù)訓(xùn)練模型只有encode部分,新模型encode+decode。遍歷模型參數(shù),用預(yù)訓(xùn)練參數(shù)進(jìn)行替換。
代碼
part3 導(dǎo)入舊參
# 導(dǎo)入所有參數(shù)
saver = tf.train.Saver()
# 對(duì)整個(gè)網(wǎng)絡(luò)所有參數(shù)做初始化
init = tf.global_variables_initializer()
sess.run(init)
#讀取預(yù)訓(xùn)練模型
reader = pywrap_tensorflow.NewCheckpointReader(weights_path)
# 逐層遍歷參數(shù)并替換
for vv in tf.trainable_variables():
if 'encode' in vv.name:
weights = reader.get_tensor(weights_key)
_op = tf.assign(vv, weights)
sess.run(_op)