一、鏈接
1、tensorflow 保存訓(xùn)練模型ckpt 查看ckpt文件中的變量名和對應(yīng)值
https://www.cnblogs.com/adong7639/p/7764769.html
2、Tensorflow修改已訓(xùn)練模型變量名字的方法
https://zhuanlan.zhihu.com/p/36982683
3、如何修改在TensorFlow框架下訓(xùn)練保存的模型參數(shù)名稱
https://blog.csdn.net/jiongnima/article/details/86632517
二、原理介紹:
? ? ?? 在解釋代碼之前,先介紹一下要用到的兩個重要的接口:
1. tf.contrib.framework.list_variables。將已保存參數(shù)的(名稱,形狀)以列表的形式返回。在更新的TensorFlow版本中,該接口已經(jīng)被整合到了tf.train.list_variables里面。
2. tf.contrib.framework.load_variable??梢詡魅朊Q,返回讀取的已保存參數(shù)的值。在更新的TensorFlow版本中,該接口已經(jīng)被整合到了tf.train.load_variable里面。
? ?? 在修改保存的參數(shù)名稱時,做法分為以下6步:
1. 使用list_variables函數(shù)逐個讀出已保存的參數(shù)名稱
2. 使用load_variable函數(shù)逐個讀取已保存的參數(shù)值
3. 逐個修改參數(shù)名稱
4. 使用已修改的參數(shù)名稱,結(jié)合tf.Variable函數(shù)逐個重建參數(shù)
5. 將已重建的參數(shù)逐個加入新參數(shù)列表
6. 使用tf.train.Saver().save將新參數(shù)列表寫入硬盤
? ? ?? 下面放出筆者的代碼,在代碼中,筆者給DeepLab V2預(yù)訓(xùn)練的模型參數(shù)全加上了前綴“deeplab_v2”。在這里筆者使用的還是許久之前的DeepLab預(yù)訓(xùn)練模型,參數(shù)保存還是一個ckpt文件(deeplab_resnet.ckpt)。代碼如下:
import tensorflow as tf
import argparse
import os
parser = argparse.ArgumentParser(description='')
parser.add_argument("--checkpoint_path", default='../deeplab_resnet/deeplab_resnet.ckpt', help="restore ckpt") #原參數(shù)路徑
parser.add_argument("--new_checkpoint_path", default='../deeplab_resnet_altered/', help="path_for_new ckpt") #新參數(shù)保存路徑
parser.add_argument("--add_prefix", default='deeplab_v2/', help="prefix for addition") #新參數(shù)名稱中加入的前綴名
args = parser.parse_args()
def main():
? ? ? ? if not os.path.exists(args.new_checkpoint_path):
? ? ? ? ? ? ? ? os.makedirs(args.new_checkpoint_path)
? ? ?? with tf.Session() as sess:
? ? ? ? ? ? ?? new_var_list=[] #新建一個空列表存儲更新后的Variable變量
? ? ? ? ? ? ? for var_name, _ in tf.contrib.framework.list_variables(args.checkpoint_path): #得到checkpoint文件中所有的參數(shù)(名字,形狀)元組
? ? ? ? ? ? ? ? ? ? ?? var = tf.contrib.framework.load_variable(args.checkpoint_path, var_name) #得到上述參數(shù)的值
? ? ? ? ? ? ? ? ? ? ?? new_name = var_name
? ? ? ? ? ? ? ? ? ? ?? new_name = args.add_prefix + new_name #在這里加入了名稱前綴,大家可以自由地作修改
? ? ? ? ? ? ? ? ? ? ? ? #除了修改參數(shù)名稱,還可以修改參數(shù)值(var)
? ? ? ? ? ? ? ? ? ? ? ? print('Renaming %s to %s.' % (var_name, new_name))
? ? ? ? ? ? ? ? ? ? ? ? renamed_var = tf.Variable(var, name=new_name) #使用加入前綴的新名稱重新構(gòu)造了參數(shù)
? ? ? ? ? ? ? ? ? ? ?? new_var_list.append(renamed_var) #把賦予新名稱的參數(shù)加入空列表
? ? ? ? ? ? ? print('starting to write new checkpoint !')
? ? ? ? ? ? ?? saver = tf.train.Saver(var_list=new_var_list) #構(gòu)造一個保存器
? ? ? ? ? ? ?? sess.run(tf.global_variables_initializer()) #初始化一下參數(shù)(這一步必做)
? ? ? ? ? ? ?? model_name = 'deeplab_resnet_altered' #構(gòu)造一個保存的模型名稱
? ? ? ? ? ? ?? checkpoint_path = os.path.join(args.new_checkpoint_path, model_name) #構(gòu)造一下保存路徑
? ? ? ? ? ? ?? saver.save(sess, checkpoint_path) #直接進行保存
? ? ? ? ? ? ?? print("done !")
if __name__ == '__main__':
? ? main()
在終端下面運行一下代碼:

可以看到參數(shù)名稱都被重置了,加上了前綴“deeplab_v2”:

在代碼中設(shè)定的保存文件夾下,能夠查看已保存的新參數(shù)名稱的模型參數(shù):

? ? ?? 由于后來的TensorFlow框架在保存模型時已經(jīng)放棄了保存單個ckpt文件的做法,因此都是得到4個文件,如上所示。然后我們就可以在代碼中愉快地使用新參數(shù)名稱的模型進行初始化啦~
loader = tf.train.Saver(var_list=restore_vars) #設(shè)置一下要初始化哪些參數(shù)
checkpoint = tf.train.latest_checkpoint(args.checkpoint_path) #保存的新參數(shù)名的模型路徑
loader.restore(sess, ckpt_path) #初始化模型參數(shù)
? ? ?? 到這里,本篇博文就接近尾聲了。本篇博文主要講述了如何修改TensorFlow框架下訓(xùn)練的參數(shù)名稱,核心還是找出參數(shù)名->更改參數(shù)名->重建參數(shù)->保存。