For an upgraded version, see: TensorFlow: Training a Model with tf.estimator (Pretrained ResNet-50).
Earlier articles covered how to build, train, save, and export models with TensorFlow. This one shows how to fine-tune a neural network starting from a pretrained model. To keep things simple, the example is image classification with a pretrained ResNet-50, and the module used is still tf.contrib.slim.
All of TensorFlow's pretrained image-classification models can be downloaded from models/research/slim, including the common VGG, Inception, ResNet, and MobileNet models as well as the more recent NasNet. The key to using them is loading the pretrained parameters correctly into the network you have defined, which the function slim.assign_from_checkpoint_fn makes easy. The code below shows how.
All of the code is available at GitHub/finetune_classification.
1. Fine-tuning model definition
As mentioned above, all of TensorFlow's pretrained models live in the GitHub project models/research/slim, and the corresponding network implementations are in its nets subfolder. Taking ResNet-50 as the example (other models work the same way), we first define the network structure:
import tensorflow as tf
from tensorflow.contrib.slim import nets

slim = tf.contrib.slim


def predict(self, preprocessed_inputs):
    """Predict prediction tensors from inputs tensor.

    Outputs of this function can be passed to loss or postprocess functions.

    Args:
        preprocessed_inputs: A float32 tensor with shape [batch_size,
            height, width, num_channels] representing a batch of images.

    Returns:
        prediction_dict: A dictionary holding prediction tensors to be
            passed to the Loss or Postprocess functions.
    """
    net, endpoints = nets.resnet_v1.resnet_v1_50(
        preprocessed_inputs, num_classes=None,
        is_training=self._is_training)
    net = tf.squeeze(net, axis=[1, 2])
    net = slim.fully_connected(net, num_outputs=self.num_classes,
                               activation_fn=None, scope='Predict')
    prediction_dict = {'logits': net}
    return prediction_dict
Suppose the images to classify fall into self.num_classes classes. We randomly pick a batch of images, preprocess them, and pass them to predict, which calls the TensorFlow-slim network nets.resnet_v1.resnet_v1_50 directly to obtain image features. Because the pretrained ResNet-50 classifies 1000 classes, we set num_classes=None to disable its final output layer. Assuming the input batch has shape [None, 224, 224, 3], resnet_v1_50 returns a tensor of shape [None, 1, 1, 2048]; to feed it to a fully connected layer, tf.squeeze removes the size-1 dimensions at indices 1 and 2. Finally, one more fully connected layer produces the prediction output for the self.num_classes classes.
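The shape bookkeeping described above can be checked without TensorFlow. The following is a minimal pure-Python sketch; squeeze_shape is a hypothetical helper written for this illustration (it is not a TensorFlow function) and only mimics what tf.squeeze does to a tensor's shape.

```python
def squeeze_shape(shape, axes):
    """Return `shape` with the listed size-1 axes removed.

    Hypothetical helper: it only tracks shapes, it does not touch tensors.
    """
    for axis in axes:
        if shape[axis] != 1:
            raise ValueError(
                'Cannot squeeze axis %d with size %r' % (axis, shape[axis]))
    return [dim for i, dim in enumerate(shape) if i not in axes]


# Shape of the features for 224x224 inputs when num_classes=None;
# None stands for the unknown batch size.
features_shape = [None, 1, 1, 2048]
print(squeeze_shape(features_shape, axes=[1, 2]))  # [None, 2048]
```

The resulting [None, 2048] shape is what the final fully connected layer expects: one 2048-dimensional feature vector per image.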
As you can see, the tf.contrib.slim module makes calling networks such as ResNet-50 remarkably simple. The key question that remains is how to load the pretrained parameters and then fine-tune the pretrained model on our own data. Before answering it, here is the complete model definition file, model.py, for easier reading:
# -*- coding: utf-8 -*-
"""
Created on Thu Oct 11 17:21:12 2018

@author: shirhe-lyh
"""

import tensorflow as tf

from tensorflow.contrib.slim import nets

import preprocessing

slim = tf.contrib.slim


class Model(object):
    """xxx definition."""

    def __init__(self, num_classes, is_training,
                 fixed_resize_side=368,
                 default_image_size=336):
        """Constructor.

        Args:
            num_classes: Number of classes.
            is_training: A boolean indicating whether the training version of
                computation graph should be constructed.
            fixed_resize_side: Passed to the preprocessing step as
                resize_side_min.
            default_image_size: Height and width of the preprocessed images.
        """
        self._num_classes = num_classes
        self._is_training = is_training
        self._fixed_resize_side = fixed_resize_side
        self._default_image_size = default_image_size

    @property
    def num_classes(self):
        return self._num_classes

    def preprocess(self, inputs):
        """Preprocess a batch of input images.

        Outputs of this function can be passed to the predict function.

        Args:
            inputs: A tensor with shape [batch_size, height, width,
                num_channels] representing a batch of images.

        Returns:
            preprocessed_inputs: A float32 tensor holding the preprocessed
                images.
        """
        preprocessed_inputs = preprocessing.preprocess_images(
            inputs, self._default_image_size, self._default_image_size,
            resize_side_min=self._fixed_resize_side,
            is_training=self._is_training,
            border_expand=True, normalize=False,
            preserving_aspect_ratio_resize=False)
        preprocessed_inputs = tf.cast(preprocessed_inputs, tf.float32)
        return preprocessed_inputs

    def predict(self, preprocessed_inputs):
        """Predict prediction tensors from inputs tensor.

        Outputs of this function can be passed to loss or postprocess functions.

        Args:
            preprocessed_inputs: A float32 tensor with shape [batch_size,
                height, width, num_channels] representing a batch of images.

        Returns:
            prediction_dict: A dictionary holding prediction tensors to be
                passed to the Loss or Postprocess functions.
        """
        with slim.arg_scope(nets.resnet_v1.resnet_arg_scope()):
            net, endpoints = nets.resnet_v1.resnet_v1_50(
                preprocessed_inputs, num_classes=None,
                is_training=self._is_training)
        net = tf.squeeze(net, axis=[1, 2])
        logits = slim.fully_connected(net, num_outputs=self.num_classes,
                                      activation_fn=None, scope='Predict')
        prediction_dict = {'logits': logits}
        return prediction_dict

    def postprocess(self, prediction_dict):
        """Convert predicted output tensors to final forms.

        Args:
            prediction_dict: A dictionary holding prediction tensors.

        Returns:
            A dictionary containing the postprocessed results.
        """
        logits = prediction_dict['logits']
        # Note: after the softmax, the 'logits' entry holds class probabilities.
        logits = tf.nn.softmax(logits)
        classes = tf.argmax(logits, axis=1)
        postprocessed_dict = {'logits': logits,
                              'classes': classes}
        return postprocessed_dict

    def loss(self, prediction_dict, groundtruth_lists):
        """Compute scalar loss tensors with respect to provided groundtruth.

        Args:
            prediction_dict: A dictionary holding prediction tensors.
            groundtruth_lists: A tensor holding the groundtruth class labels,
                with one entry for each image in the batch.

        Returns:
            A dictionary mapping strings (loss names) to scalar tensors
                representing loss values.
        """
        logits = prediction_dict['logits']
        slim.losses.sparse_softmax_cross_entropy(
            logits=logits,
            labels=groundtruth_lists,
            scope='Loss')
        loss = slim.losses.get_total_loss()
        loss_dict = {'loss': loss}
        return loss_dict

    def accuracy(self, postprocessed_dict, groundtruth_lists):
        """Calculate accuracy.

        Args:
            postprocessed_dict: A dictionary containing the postprocessed
                results.
            groundtruth_lists: A tensor holding the groundtruth class labels,
                with one entry for each image in the batch.

        Returns:
            accuracy: The scalar accuracy.
        """
        classes = postprocessed_dict['classes']
        accuracy = tf.reduce_mean(
            tf.cast(tf.equal(classes, groundtruth_lists), dtype=tf.float32))
        return accuracy
2. Loading the pretrained model
To load the parameters of the pretrained ResNet-50 into the model defined above, we again rely on tf.contrib.slim, and the method is simple: pass an initialization function, init_fn, to the training function slim.learning.train. That function can be built conveniently with
slim.assign_from_checkpoint_fn(model_path, var_list,
                               ignore_missing_vars=False,
                               reshape_variables=False)
The first parameter, model_path, is the path to the pretrained xxx.ckpt file. The second, var_list, lists all the variables that should receive pretrained values; it can be produced quickly with
slim.get_variables_to_restore(include=None,
                              exclude=None)
If some variables should be left out, that is, initialized randomly rather than from the pretrained model, name them in the exclude parameter. The third parameter, ignore_missing_vars, is very important and must be set to True: variables that exist in your model definition but not in the pretrained checkpoint have to be ignored, because otherwise the restore fails. There are many such variables; convolution-layer biases, for example, are generally absent from pretrained models, so they are skipped and keep their default zero initialization. In the model above, the variables of the new 'Predict' layer are another example. The last parameter, reshape_variables, controls whether certain variables are reshaped; it is rarely needed, so the default False is fine.
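To make the roles of exclude and ignore_missing_vars concrete, here is a minimal pure-Python sketch of the selection logic described above. It is a conceptual stand-in, not slim's implementation: plan_restore is a hypothetical function, it works on variable names only, it raises a plain ValueError where TensorFlow would report its own error, and the variable names are illustrative.

```python
def plan_restore(model_vars, checkpoint_vars, exclude_scopes=(),
                 ignore_missing_vars=False):
    """Split model variable names into (restored, left_at_initializer).

    Hypothetical helper that mirrors the behaviour described in the text.
    """
    # Variables under an excluded scope are never looked up in the checkpoint.
    candidates = [v for v in model_vars
                  if not any(v.startswith(s) for s in exclude_scopes)]
    missing = [v for v in candidates if v not in checkpoint_vars]
    if missing and not ignore_missing_vars:
        raise ValueError('Not found in checkpoint: %s' % ', '.join(missing))
    restored = [v for v in candidates if v in checkpoint_vars]
    untouched = [v for v in model_vars if v not in restored]
    return restored, untouched


# Illustrative names: two backbone variables plus the new 'Predict' layer.
model_vars = [
    'resnet_v1_50/conv1/weights',
    'resnet_v1_50/block1/unit_1/bottleneck_v1/conv1/weights',
    'Predict/weights',
    'Predict/biases',
]
checkpoint_vars = set(model_vars[:2])  # the checkpoint has no 'Predict' layer

restored, untouched = plan_restore(model_vars, checkpoint_vars,
                                   ignore_missing_vars=True)
print(restored)   # the two resnet_v1_50 variables
print(untouched)  # ['Predict/weights', 'Predict/biases']
```

With ignore_missing_vars=False the same call fails because of the two 'Predict' variables, unless they are excluded explicitly, for example with exclude_scopes=['Predict'].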
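With this background, and if you have read the previous article, Training a CNN Classification Model with TensorFlow-slim (continued), the complete training file train.py for the pretrained model is easy to write. Here it is (the important part is the function get_init_fn):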
# -*- coding: utf-8 -*-
"""
Created on Thu Oct 11 17:21:35 2018

@author: shirhe-lyh
"""

"""Train a CNN classification model via pretrained ResNet-50 model.

Example Usage:
---------------
python3 train.py \
    --checkpoint_path: Path to pretrained ResNet-50 model.
    --record_path: Path to training tfrecord file.
    --logdir: Path to log directory.
"""

import os
import tensorflow as tf

import model
import preprocessing

slim = tf.contrib.slim
flags = tf.app.flags

flags.DEFINE_string('record_path',
                    '/data2/raycloud/jingxiong_datasets/AIChanllenger/' +
                    'AgriculturalDisease_trainingset/train.record',
                    'Path to training tfrecord file.')
flags.DEFINE_string('checkpoint_path',
                    '/home/jingxiong/python_project/model_zoo/' +
                    'resnet_v1_50.ckpt',
                    'Path to pretrained ResNet-50 model.')
flags.DEFINE_string('logdir', './training', 'Path to log directory.')
flags.DEFINE_float('learning_rate', 0.0001, 'Initial learning rate.')
flags.DEFINE_float(
    'learning_rate_decay_factor', 0.1, 'Learning rate decay factor.')
flags.DEFINE_float(
    'num_epochs_per_decay', 3.0,
    'Number of epochs after which learning rate decays. Note: this flag counts '
    'epochs per clone but aggregates per sync replicas. So 1.0 means that '
    'each clone will go over full epoch individually, but replicas will go '
    'once across all replicas.')
flags.DEFINE_integer('num_samples', 32739, 'Number of samples.')
flags.DEFINE_integer('num_steps', 10000, 'Number of steps.')
flags.DEFINE_integer('batch_size', 48, 'Batch size')

FLAGS = flags.FLAGS


def get_record_dataset(record_path,
                       reader=None,
                       num_samples=50000,
                       num_classes=7):
    """Get a tensorflow record file.

    Args:
        record_path: Path to the tfrecord file.
        reader: The reader class; defaults to tf.TFRecordReader.
        num_samples: Number of samples in the record file.
        num_classes: Number of classes.

    Returns:
        A slim.dataset.Dataset describing the record file.
    """
    if not reader:
        reader = tf.TFRecordReader

    keys_to_features = {
        'image/encoded':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format':
            tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/class/label':
            tf.FixedLenFeature([1], tf.int64, default_value=tf.zeros([1],
                               dtype=tf.int64))}

    items_to_handlers = {
        'image': slim.tfexample_decoder.Image(image_key='image/encoded',
                                              format_key='image/format'),
        'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[])}
    decoder = slim.tfexample_decoder.TFExampleDecoder(
        keys_to_features, items_to_handlers)

    labels_to_names = None
    items_to_descriptions = {
        'image': 'An image with shape image_shape.',
        'label': 'A single integer.'}
    return slim.dataset.Dataset(
        data_sources=record_path,
        reader=reader,
        decoder=decoder,
        num_samples=num_samples,
        num_classes=num_classes,
        items_to_descriptions=items_to_descriptions,
        labels_to_names=labels_to_names)


def configure_learning_rate(num_samples_per_epoch, global_step):
    """Configures the learning rate.

    Modified from:
        https://github.com/tensorflow/models/blob/master/research/slim/
        train_image_classifier.py

    Args:
        num_samples_per_epoch: The number of samples in each epoch of training.
        global_step: The global_step tensor.

    Returns:
        A `Tensor` representing the learning rate.
    """
    decay_steps = int(num_samples_per_epoch * FLAGS.num_epochs_per_decay /
                      FLAGS.batch_size)
    return tf.train.exponential_decay(FLAGS.learning_rate,
                                      global_step,
                                      decay_steps,
                                      FLAGS.learning_rate_decay_factor,
                                      staircase=True,
                                      name='exponential_decay_learning_rate')


def get_init_fn(checkpoint_exclude_scopes=None):
    """Returns a function run by the chief worker to warm-start the training.

    Modified from:
        https://github.com/tensorflow/models/blob/master/research/slim/
        train_image_classifier.py

    Note that the init_fn is only run when initializing the model during the
    very first global step.

    Args:
        checkpoint_exclude_scopes: Comma-separated list of scopes of variables
            to exclude when restoring from a checkpoint.

    Returns:
        An init function run by the supervisor.
    """
    if FLAGS.checkpoint_path is None:
        return None

    # Warn the user if a checkpoint exists in the train_dir. Then we'll be
    # ignoring the checkpoint anyway.
    if tf.train.latest_checkpoint(FLAGS.logdir):
        tf.logging.info(
            'Ignoring --checkpoint_path because a checkpoint already exists ' +
            'in %s' % FLAGS.logdir)
        return None

    exclusions = []
    if checkpoint_exclude_scopes:
        exclusions = [scope.strip() for scope in
                      checkpoint_exclude_scopes.split(',')]

    # Keep every model variable whose name does not start with an excluded
    # scope.
    variables_to_restore = []
    for var in slim.get_model_variables():
        excluded = False
        for exclusion in exclusions:
            if var.op.name.startswith(exclusion):
                excluded = True
        if not excluded:
            variables_to_restore.append(var)

    if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
        checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
    else:
        checkpoint_path = FLAGS.checkpoint_path

    tf.logging.info('Fine-tuning from %s' % checkpoint_path)

    return slim.assign_from_checkpoint_fn(
        checkpoint_path,
        variables_to_restore,
        ignore_missing_vars=True)


def get_trainable_variables(checkpoint_exclude_scopes=None):
    """Return the trainable variables.

    Args:
        checkpoint_exclude_scopes: Comma-separated list of scopes of variables
            to freeze, i.e. to leave out of the returned list.

    Returns:
        The trainable variables.
    """
    exclusions = []
    if checkpoint_exclude_scopes:
        exclusions = [scope.strip() for scope in
                      checkpoint_exclude_scopes.split(',')]
    variables_to_train = []
    for var in tf.trainable_variables():
        excluded = False
        for exclusion in exclusions:
            if var.op.name.startswith(exclusion):
                excluded = True
        if not excluded:
            variables_to_train.append(var)
    return variables_to_train


def main(_):
    # Specify which gpu to be used
    os.environ["CUDA_VISIBLE_DEVICES"] = '1'

    num_samples = FLAGS.num_samples
    dataset = get_record_dataset(FLAGS.record_path, num_samples=num_samples,
                                 num_classes=61)
    data_provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
    image, label = data_provider.get(['image', 'label'])

    # Border expand and resize
    image = preprocessing.border_expand(image, resize=True, output_height=368,
                                        output_width=368)

    inputs, labels = tf.train.batch([image, label],
                                    batch_size=FLAGS.batch_size,
                                    #capacity=5*FLAGS.batch_size,
                                    allow_smaller_final_batch=True)

    cls_model = model.Model(is_training=True, num_classes=61)
    preprocessed_inputs = cls_model.preprocess(inputs)
    prediction_dict = cls_model.predict(preprocessed_inputs)
    loss_dict = cls_model.loss(prediction_dict, labels)
    loss = loss_dict['loss']
    postprocessed_dict = cls_model.postprocess(prediction_dict)
    acc = cls_model.accuracy(postprocessed_dict, labels)
    tf.summary.scalar('loss', loss)
    tf.summary.scalar('accuracy', acc)

    global_step = slim.create_global_step()
    learning_rate = configure_learning_rate(num_samples, global_step)
    optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                           momentum=0.9)
    # optimizer = tf.train.AdamOptimizer(learning_rate=0.00001)
    vars_to_train = get_trainable_variables()
    train_op = slim.learning.create_train_op(loss, optimizer,
                                             summarize_gradients=True,
                                             variables_to_train=vars_to_train)
    tf.summary.scalar('learning_rate', learning_rate)

    init_fn = get_init_fn()

    slim.learning.train(train_op=train_op, logdir=FLAGS.logdir,
                        init_fn=init_fn, number_of_steps=FLAGS.num_steps,
                        save_summaries_secs=20,
                        save_interval_secs=600)


if __name__ == '__main__':
    tf.app.run()
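As a side note on configure_learning_rate in the listing above: with staircase=True, tf.train.exponential_decay multiplies the learning rate by the decay factor once every decay_steps steps. The pure-Python sketch below reproduces that arithmetic with the default flag values; decay_steps and staircase_lr are hypothetical helper names used only for this illustration.

```python
def decay_steps(num_samples, num_epochs_per_decay, batch_size):
    """Number of training steps between two learning-rate drops."""
    return int(num_samples * num_epochs_per_decay / batch_size)


def staircase_lr(initial_lr, step, steps_per_decay, decay_factor):
    """Staircase exponential decay: lr * factor ** floor(step / decay_steps)."""
    return initial_lr * decay_factor ** (step // steps_per_decay)


# Defaults from train.py: 32739 samples, 3 epochs per decay, batch size 48.
steps = decay_steps(32739, 3.0, 48)
print(steps)  # 2046

# Learning rate at the start, after the first drop, and at the last step.
for step in (0, 2046, 9999):
    print(step, staircase_lr(1e-4, step, steps, 0.1))
```

With these defaults the rate drops four times within the 10000 training steps and ends near 1e-8, so it may be worth revisiting num_epochs_per_decay or the decay factor for a dataset of a different size.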
The function get_init_fn reads the pretrained model from the given path. If no pretrained model path (FLAGS.checkpoint_path) is specified, it returns None, meaning the parameters are initialized randomly. If a trained model has already been saved in the training directory (FLAGS.logdir), it also returns None: the pretrained parameters are ignored and the most recently saved model is used for initialization instead. If you want to load the pretrained parameters for only some layers and skip the others, set the checkpoint_exclude_scopes argument to the names of the layers to exclude (those that should not receive pretrained parameters). For example, to skip the first convolution layer and block1, specify:
checkpoint_exclude_scopes = 'resnet_v1_50/conv1,resnet_v1_50/block1'
init_fn = get_init_fn(checkpoint_exclude_scopes)
The function get_trainable_variables collects the variables to train; by default it returns all trainable variables. To freeze some layers so that their parameters are not updated, name them through the checkpoint_exclude_scopes argument. For example, to freeze block1 and block2/unit_1 of ResNet-50, call:
scopes_to_freeze = 'resnet_v1_50/block1,resnet_v1_50/block2/unit_1'
vars_to_train = get_trainable_variables(scopes_to_freeze)
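Both get_init_fn and get_trainable_variables select variables with the same name-prefix test. The pure-Python sketch below applies that test to plain strings so the effect of a scope list is easy to see; split_by_scopes is a hypothetical helper and the variable names are illustrative.

```python
def split_by_scopes(var_names, scopes):
    """Return (kept, excluded) using the startswith test from train.py.

    `scopes` is a comma-separated string, as in checkpoint_exclude_scopes.
    """
    prefixes = [s.strip() for s in scopes.split(',')] if scopes else []
    kept, excluded = [], []
    for name in var_names:
        if any(name.startswith(prefix) for prefix in prefixes):
            excluded.append(name)
        else:
            kept.append(name)
    return kept, excluded


names = [
    'resnet_v1_50/conv1/weights',
    'resnet_v1_50/block1/unit_1/bottleneck_v1/conv1/weights',
    'resnet_v1_50/block2/unit_1/bottleneck_v1/conv1/weights',
    'resnet_v1_50/block2/unit_2/bottleneck_v1/conv1/weights',
    'Predict/weights',
]
kept, frozen = split_by_scopes(
    names, 'resnet_v1_50/block1, resnet_v1_50/block2/unit_1')
print(frozen)  # the block1 and block2/unit_1 variables
print(kept)    # conv1, block2/unit_2 and Predict variables
```

Because the match is a plain string prefix, a scope such as 'resnet_v1_50/block1' would also match a name beginning with 'resnet_v1_50/block10'. ResNet-50 has only four blocks, so this is harmless here, but appending a trailing '/' to each scope is a safer habit for other networks.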
3. Dataset and training
The code for this article at GitHub/finetune_classification uses the AI Challenger Global AI Challenge / Crop Disease Detection dataset by default. After downloading the dataset, run:
$ python3 generate_tfrecord.py \
    --images_dir Path/to/AgriculturalDisease_trainingset/images \
    --annotation_path Path/to/AgriculturalDisease_train_annotations.json \
    --output_path Path/to/train.record
This writes the training images into the file train.record. Then run:
$ python3 train.py \
    --record_path Path/to/train.record \
    --checkpoint_path Path/to/pretrained_ResNet-50_model/resnet_v1_50.ckpt
to start training. Once training starts, a folder named training is created next to train.py to hold the trained model. Note that training does not print accuracy, loss, or similar values to the terminal; to see them, run:
$ tensorboard --logdir Path/to/training
and open the returned http link in a browser to view the accuracy and loss curves (this works both during and after training). After training starts normally, the model is saved to the training folder every 10 minutes (as files named like model.ckpt-xxx). You can predict directly from one of these model.ckpt-xxx checkpoints, or first convert it to a .pb file. To convert, run:
$ python3 export_inference_graph.py \
    --trained_checkpoint_prefix Path/to/model.ckpt-xxx \
    --output_directory Path/to/exported_pb_file_directory
A folder is then created under the specified output path (Path/to/exported_pb_file_directory); the frozen_inference_graph.pb inside that folder is the resulting frozen model file (frozen means that all parameters have been converted to constants). You can then use evaluate.py or predict.py for evaluation or prediction.
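If you use a different dataset, the training procedure is the same as above. You only need to adapt the function provide in data_provider.py to your annotation file: it returns a dictionary whose keys are the paths of the training images and whose values are the corresponding class labels. Other settings, such as the number of training images, the number of classes, and the learning rate, are changed in train.py.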
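As a rough illustration of the contract just described (image path as key, class label as value), a provide function for a JSON annotation file might look like the sketch below. This is not the repository's implementation: the JSON layout (a list of objects) and the field names image_id and disease_class are assumptions, exposed as parameters so they can be changed to match your own annotations.

```python
import json
import os


def provide(annotation_path, images_dir,
            image_key='image_id', label_key='disease_class'):
    """Map each training image path to its integer class label.

    Assumes the annotation file is a JSON list of objects; the two key
    names are hypothetical defaults, not taken from the original code.
    """
    with open(annotation_path) as f:
        annotations = json.load(f)
    return {os.path.join(images_dir, item[image_key]): int(item[label_key])
            for item in annotations}


# Usage (paths are placeholders):
# image_to_label = provide('Path/to/annotations.json', 'Path/to/images')
```

Whatever the annotation format, keeping the return value a plain path-to-label dictionary means the rest of the pipeline does not need to change.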
Preview: the next article will show how to train multi-task, multi-label models with TensorFlow. Stay tuned.