TensorFlow 使用預(yù)訓(xùn)練模型 ResNet-50

????????升級版見:TensorFlow 使用 tf.estimator 訓(xùn)練模型(預(yù)訓(xùn)練 ResNet-50)。

????????前面的文章已經(jīng)說明了怎么使用 TensorFlow 來構(gòu)建、訓(xùn)練、保存、導(dǎo)出模型等,現(xiàn)在來說明怎么使用 TensorFlow 調(diào)用預(yù)訓(xùn)練模型來精調(diào)神經(jīng)網(wǎng)絡(luò)。為了簡單起見,以調(diào)用預(yù)訓(xùn)練的 ResNet-50 用于圖像分類為例,使用的模塊仍然是 tf.contrib.slim。

????????TensorFlow 的所有用于圖像分類的預(yù)訓(xùn)練模型的下載地址為 models/research/slim,包含常用的 VGG,Inception,ResNet,MobileNet 以及最新的 NasNet 模型等。要使用這些預(yù)訓(xùn)練模型的關(guān)鍵是將這些預(yù)訓(xùn)練的參數(shù)正確的導(dǎo)入到定義好的神經(jīng)網(wǎng)絡(luò),這可以通過函數(shù) slim.assign_from_checkpoint_fn 來方便的實(shí)現(xiàn)。下面,用代碼來說明。

????????所有代碼見 GitHub/finetune_classification

一、Fine tuning 模型定義

????????前已提及,TensorFlow 所有預(yù)訓(xùn)練模型均在 GitHub 項(xiàng)目 models/research/slim,而其對應(yīng)的神經(jīng)網(wǎng)絡(luò)實(shí)現(xiàn)則在其子文件夾 nets。我們以調(diào)用 ResNet-50 為例(其它模型類似),首先來定義網(wǎng)絡(luò)結(jié)構(gòu):

import tensorflow as tf

from tensorflow.contrib.slim import nets

slim = tf.contrib.slim


def predict(self, preprocessed_inputs):
    """Predict prediction tensors from inputs tensor.

    Outputs of this function can be passed to loss or postprocess functions.

    Args:
        preprocessed_inputs: A float32 tensor with shape [batch_size,
            height, width, num_channels] representing a batch of images.
            
    Returns:
        prediction_dict: A dictionary holding prediction tensors to be
            passed to the Loss or Postprocess functions.
    """
    net, endpoints = nets.resnet_v1.resnet_v1_50(
        preprocessed_inputs, num_classes=None,
        is_training=self._is_training)
    net = tf.squeeze(net, axis=[1, 2])
    net = slim.fully_connected(net, num_outputs=self.num_classes,
                               activation_fn=None, scope='Predict')
    prediction_dict = {'logits': net}
    return prediction_dict

????????我們假設(shè)要分類的圖像有 self.num_classes 個(gè)類,隨機(jī)選擇一個(gè)批量的圖像,對這些圖像進(jìn)行預(yù)處理后,把它們作為參數(shù)傳入 predict 函數(shù),此時(shí)直接調(diào)用 TensorFlow-slim 封裝好的 nets.resnet_v1.resnet_v1_50 神經(jīng)網(wǎng)絡(luò)得到圖像特征,因?yàn)?ResNet-50 是用于 1000 個(gè)類的分類的,所以需要設(shè)置參數(shù) num_classes=None 禁用它的最后一個(gè)輸出層。我們假設(shè)輸入的圖像批量形狀為 [None, 224, 224, 3],則 resnet_v1_50 函數(shù)返回的形狀為 [None, 1, 1, 2048],為了輸入到全連接層,需要用函數(shù) tf.squeeze 去掉形狀為 1 的第 1,2 個(gè)索引維度。最后,連接再一個(gè)全連接層得到 self.num_classes 個(gè)類的預(yù)測輸出。

????????可以看到,使用 tf.contrib.slim 模塊,調(diào)用 ResNet-50 等神經(jīng)網(wǎng)絡(luò)變得異常簡單。而接下來的關(guān)鍵問題是怎么導(dǎo)入預(yù)訓(xùn)練的參數(shù),進(jìn)而使用我們自己的數(shù)據(jù)來對預(yù)訓(xùn)練模型進(jìn)行精調(diào)。在闡述怎么解決這個(gè)問題之前,先將整個(gè)模型定義的文件 model.py 列出以方便閱讀:

# -*- coding: utf-8 -*-
"""
Created on Thu Oct 11 17:21:12 2018

@author: shirhe-lyh
"""

import tensorflow as tf

from tensorflow.contrib.slim import nets

import preprocessing

slim = tf.contrib.slim
    
        
class Model(object):
    """xxx definition."""
    
    def __init__(self, num_classes, is_training,
                 fixed_resize_side=368,
                 default_image_size=336):
        """Constructor.
        
        Args:
            is_training: A boolean indicating whether the training version of
                computation graph should be constructed.
            num_classes: Number of classes.
        """
        self._num_classes = num_classes
        self._is_training = is_training
        self._fixed_resize_side = fixed_resize_side
        self._default_image_size = default_image_size
        
    @property
    def num_classes(self):
        return self._num_classes
        
    def preprocess(self, inputs):
        """preprocessing.
        
        Outputs of this function can be passed to loss or postprocess functions.
        
        Args:
            preprocessed_inputs: A float32 tensor with shape [batch_size,
                height, width, num_channels] representing a batch of images.
            
        Returns:
            prediction_dict: A dictionary holding prediction tensors to be
                passed to the Loss or Postprocess functions.
        """
        preprocessed_inputs = preprocessing.preprocess_images(
            inputs, self._default_image_size, self._default_image_size, 
            resize_side_min=self._fixed_resize_side,
            is_training=self._is_training,
            border_expand=True, normalize=False,
            preserving_aspect_ratio_resize=False)
        preprocessed_inputs = tf.cast(preprocessed_inputs, tf.float32)
        return preprocessed_inputs
    
    def predict(self, preprocessed_inputs):
        """Predict prediction tensors from inputs tensor.
        
        Outputs of this function can be passed to loss or postprocess functions.
        
        Args:
            preprocessed_inputs: A float32 tensor with shape [batch_size,
                height, width, num_channels] representing a batch of images.
            
        Returns:
            prediction_dict: A dictionary holding prediction tensors to be
                passed to the Loss or Postprocess functions.
        """
        with slim.arg_scope(nets.resnet_v1.resnet_arg_scope()):
            net, endpoints = nets.resnet_v1.resnet_v1_50(
                preprocessed_inputs, num_classes=None,
                is_training=self._is_training)
        net = tf.squeeze(net, axis=[1, 2])
        logits = slim.fully_connected(net, num_outputs=self.num_classes,
                                      activation_fn=None, scope='Predict')
        prediction_dict = {'logits': logits}
        return prediction_dict
    
    def postprocess(self, prediction_dict):
        """Convert predicted output tensors to final forms.
        
        Args:
            prediction_dict: A dictionary holding prediction tensors.
            **params: Additional keyword arguments for specific implementations
                of specified models.
                
        Returns:
            A dictionary containing the postprocessed results.
        """
        logits = prediction_dict['logits']
        logits = tf.nn.softmax(logits)
        classes = tf.argmax(logits, axis=1)
        postprocessed_dict = {'logits': logits,
                              'classes': classes}
        return postprocessed_dict
    
    def loss(self, prediction_dict, groundtruth_lists):
        """Compute scalar loss tensors with respect to provided groundtruth.
        
        Args:
            prediction_dict: A dictionary holding prediction tensors.
            groundtruth_lists_dict: A dict of tensors holding groundtruth
                information, with one entry for each image in the batch.
                
        Returns:
            A dictionary mapping strings (loss names) to scalar tensors
                representing loss values.
        """
        logits = prediction_dict['logits']
        slim.losses.sparse_softmax_cross_entropy(
            logits=logits, 
            labels=groundtruth_lists,
            scope='Loss')
        loss = slim.losses.get_total_loss()
        loss_dict = {'loss': loss}
        return loss_dict
        
    def accuracy(self, postprocessed_dict, groundtruth_lists):
        """Calculate accuracy.
        
        Args:
            postprocessed_dict: A dictionary containing the postprocessed 
                results
            groundtruth_lists: A dict of tensors holding groundtruth
                information, with one entry for each image in the batch.
                
        Returns:
            accuracy: The scalar accuracy.
        """
        classes = postprocessed_dict['classes']
        accuracy = tf.reduce_mean(
            tf.cast(tf.equal(classes, groundtruth_lists), dtype=tf.float32))
        return accuracy

二、預(yù)訓(xùn)練模型導(dǎo)入

????????要將預(yù)訓(xùn)練模型 ResNet-50 的參數(shù)導(dǎo)入到前面定義好的模型,需要繼續(xù)借助 tf.contrib.slim 模塊,而且方法很簡單,只需要在訓(xùn)練函數(shù) slim.learning.train 中指定初始化參數(shù)來源函數(shù) init_fn 即可,而這可以通過函數(shù)

slim.assign_from_checkpoint_fn(model_path, var_list,
                               ignore_missing_vars=False,
                               reshape_variables=False)

很方便的實(shí)現(xiàn)。其中,第一個(gè)參數(shù) model_path 指定預(yù)訓(xùn)練模型 xxx.ckpt 文件的路徑,第二個(gè)參數(shù) var_list 指定需要導(dǎo)入對應(yīng)預(yù)訓(xùn)練參數(shù)的所有變量,通過函數(shù)

slim.get_variables_to_restore(include=None,
                              exclude=None)

可以快速指定,如果需要排除一些變量,也就是如果想讓某些變量隨機(jī)初始化而不是直接使用預(yù)訓(xùn)練模型來初始化,則直接在參數(shù) exclude 中指定即可。第三個(gè)參數(shù) ignore_missing_vars 非常重要,一定要將其設(shè)置為 True,也就是說,一定要忽略那些在定義的模型結(jié)構(gòu)中可能存在的而在預(yù)訓(xùn)練模型中沒有的變量,因?yàn)槿绻约憾x的模型結(jié)構(gòu)中存在一個(gè)參數(shù),而這些參數(shù)在預(yù)訓(xùn)練模型文件 xxx.ckpt 中沒有,那么如果不忽略的話,就會(huì)導(dǎo)入失?。ㄟ@樣的變量很多,比如卷積層的偏置項(xiàng) bias,一般預(yù)訓(xùn)練模型中沒有,所以需要忽略,即使用默認(rèn)的零初始化)。最后一個(gè)參數(shù) reshape_variabels 指定對某些變量進(jìn)行變形,這個(gè)一般用不到,使用默認(rèn)的 False 即可。

????????有了以上的基礎(chǔ),而且你還閱讀過上一篇文章 TensorFlow-slim 訓(xùn)練 CNN 分類模型(續(xù)) 的話,那么整個(gè)使用預(yù)訓(xùn)練模型的訓(xùn)練文件 train.py 就很容易寫出了,如下(重點(diǎn)在函數(shù) get_init_fn):

# -*- coding: utf-8 -*-
"""
Created on Thu Oct 11 17:21:35 2018

@author: shirhe-lyh
"""

"""Train a CNN classification model via pretrained ResNet-50 model.

Example Usage:
---------------
python3 train.py \
    --checkpoint_path: Path to pretrained ResNet-50 model.
    --record_path: Path to training tfrecord file.
    --logdir: Path to log directory.
"""

import os
import tensorflow as tf

import model
import preprocessing

slim = tf.contrib.slim
flags = tf.app.flags

flags.DEFINE_string('record_path', 
                    '/data2/raycloud/jingxiong_datasets/AIChanllenger/' +
                    'AgriculturalDisease_trainingset/train.record',
                    'Path to training tfrecord file.')
flags.DEFINE_string('checkpoint_path', 
                    '/home/jingxiong/python_project/model_zoo/' +
                    'resnet_v1_50.ckpt', 
                    'Path to pretrained ResNet-50 model.')
flags.DEFINE_string('logdir', './training', 'Path to log directory.')
flags.DEFINE_float('learning_rate', 0.0001, 'Initial learning rate.')
flags.DEFINE_float(
    'learning_rate_decay_factor', 0.1, 'Learning rate decay factor.')
flags.DEFINE_float(
    'num_epochs_per_decay', 3.0,
    'Number of epochs after which learning rate decays. Note: this flag counts '
    'epochs per clone but aggregates per sync replicas. So 1.0 means that '
    'each clone will go over full epoch individually, but replicas will go '
    'once across all replicas.')
flags.DEFINE_integer('num_samples', 32739, 'Number of samples.')
flags.DEFINE_integer('num_steps', 10000, 'Number of steps.')
flags.DEFINE_integer('batch_size', 48, 'Batch size')

FLAGS = flags.FLAGS


def get_record_dataset(record_path,
                       reader=None, 
                       num_samples=50000, 
                       num_classes=7):
    """Get a tensorflow record file.
    
    Args:
        
    """
    if not reader:
        reader = tf.TFRecordReader
        
    keys_to_features = {
        'image/encoded': 
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': 
            tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/class/label': 
            tf.FixedLenFeature([1], tf.int64, default_value=tf.zeros([1], 
                               dtype=tf.int64))}
        
    items_to_handlers = {
        'image': slim.tfexample_decoder.Image(image_key='image/encoded',
                                              format_key='image/format'),
        'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[])}
    decoder = slim.tfexample_decoder.TFExampleDecoder(
        keys_to_features, items_to_handlers)
    
    labels_to_names = None
    items_to_descriptions = {
        'image': 'An image with shape image_shape.',
        'label': 'A single integer.'}
    return slim.dataset.Dataset(
        data_sources=record_path,
        reader=reader,
        decoder=decoder,
        num_samples=num_samples,
        num_classes=num_classes,
        items_to_descriptions=items_to_descriptions,
        labels_to_names=labels_to_names)
    
    
def configure_learning_rate(num_samples_per_epoch, global_step):
    """Configures the learning rate.
    
    Modified from:
        https://github.com/tensorflow/models/blob/master/research/slim/
        train_image_classifier.py
    
    Args:
        num_samples_per_epoch: he number of samples in each epoch of training.
        global_step: The global_step tensor.
        
    Returns:
        A `Tensor` representing the learning rate.
    """
    decay_steps = int(num_samples_per_epoch * FLAGS.num_epochs_per_decay /
                      FLAGS.batch_size)
    return tf.train.exponential_decay(FLAGS.learning_rate,
                                      global_step,
                                      decay_steps,
                                      FLAGS.learning_rate_decay_factor,
                                      staircase=True,
                                      name='exponential_decay_learning_rate')
    
    
def get_init_fn(checkpoint_exclude_scopes=None):
    """Returns a function run by che chief worker to warm-start the training.
    
    Modified from:
        https://github.com/tensorflow/models/blob/master/research/slim/
        train_image_classifier.py
    
    Note that the init_fn is only run when initializing the model during the 
    very first global step.
    
    Args:
        checkpoint_exclude_scopes: Comma-separated list of scopes of variables
            to exclude when restoring from a checkpoint.
    
    Returns:
        An init function run by the supervisor.
    """
    if FLAGS.checkpoint_path is None:
        return None
    
    # Warn the user if a checkpoint exists in the train_dir. Then we'll be
    # ignoring the checkpoint anyway.
    if tf.train.latest_checkpoint(FLAGS.logdir):
        tf.logging.info(
            'Ignoring --checkpoint_path because a checkpoint already exists ' +
            'in %s' % FLAGS.logdir)
        return None
    
    exclusions = []
    if checkpoint_exclude_scopes:
        exclusions = [scope.strip() for scope in 
                     checkpoint_exclude_scopes.split(',')]
    variables_to_restore = []
    for var in slim.get_model_variables():
        excluded = False
        for exclusion in exclusions:
            if var.op.name.startswith(exclusion):
                excluded = True
        if not excluded:
            variables_to_restore.append(var)
    
    if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
        checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
    else:
        checkpoint_path = FLAGS.checkpoint_path

    tf.logging.info('Fine-tuning from %s' % checkpoint_path)
    
    return slim.assign_from_checkpoint_fn(
        checkpoint_path,
        variables_to_restore,
        ignore_missing_vars=True)


def get_trainable_variables(checkpoint_exclude_scopes=None):
    """Return the trainable variables.
    
    Args:
        checkpoint_exclude_scopes: Comma-separated list of scopes of variables
            to exclude when restoring from a checkpoint.
    
    Returns:
        The trainable variables.
    """
    exclusions = []
    if checkpoint_exclude_scopes:
        exclusions = [scope.strip() for scope in 
                     checkpoint_exclude_scopes.split(',')]
    variables_to_train = []
    for var in tf.trainable_variables():
        excluded = False
        for exclusion in exclusions:
            if var.op.name.startswith(exclusion):
                excluded = True
        if not excluded:
            variables_to_train.append(var)
    return variables_to_train


def main(_):
    # Specify which gpu to be used
    os.environ["CUDA_VISIBLE_DEVICES"] = '1'
    
    num_samples = FLAGS.num_samples
    dataset = get_record_dataset(FLAGS.record_path, num_samples=num_samples, 
                                 num_classes=61)
    data_provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
    image, label = data_provider.get(['image', 'label'])
    
    # Border expand and resize
    image = preprocessing.border_expand(image, resize=True, output_height=368,
                                        output_width=368)
        
    inputs, labels = tf.train.batch([image, label],
                                    batch_size=FLAGS.batch_size,
                                    #capacity=5*FLAGS.batch_size,
                                    allow_smaller_final_batch=True)
    
    cls_model = model.Model(is_training=True, num_classes=61)
    preprocessed_inputs = cls_model.preprocess(inputs)
    prediction_dict = cls_model.predict(preprocessed_inputs)
    loss_dict = cls_model.loss(prediction_dict, labels)
    loss = loss_dict['loss']
    postprocessed_dict = cls_model.postprocess(prediction_dict)
    acc = cls_model.accuracy(postprocessed_dict, labels)
    tf.summary.scalar('loss', loss)
    tf.summary.scalar('accuracy', acc)

    global_step = slim.create_global_step()
    learning_rate = configure_learning_rate(num_samples, global_step)
    optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, 
                                           momentum=0.9)
#    optimizer = tf.train.AdamOptimizer(learning_rate=0.00001)
    vars_to_train = get_trainable_variables()
    train_op = slim.learning.create_train_op(loss, optimizer,
                                             summarize_gradients=True,
                                             variables_to_train=vars_to_train)
    tf.summary.scalar('learning_rate', learning_rate)
    
    init_fn = get_init_fn()
    
    slim.learning.train(train_op=train_op, logdir=FLAGS.logdir, 
                        init_fn=init_fn, number_of_steps=FLAGS.num_steps,
                        save_summaries_secs=20,
                        save_interval_secs=600)
    
if __name__ == '__main__':
    tf.app.run()

????????函數(shù) get_init_fn 從指定路徑下讀取預(yù)訓(xùn)練模型。如果沒有指定預(yù)訓(xùn)練模型路徑(FLAGS.checkpoint_path),則返回 None(表示隨機(jī)初始化參數(shù))。如果在訓(xùn)練路徑下(FLAGS.logdir)已經(jīng)保存過訓(xùn)練后的模型,也返回 None(即忽略預(yù)訓(xùn)練模型參數(shù),而使用最后訓(xùn)練保存下來的模型初始化參數(shù))。如果你只想導(dǎo)入部分層的預(yù)訓(xùn)練參數(shù),而忽略其它層的預(yù)訓(xùn)練參數(shù),則可以設(shè)置 checkpoint_exclude_scopes 這個(gè)參數(shù),用來指定你要排除掉(即不需要導(dǎo)入預(yù)訓(xùn)練參數(shù))的那些層的名字,比如你要禁用第一卷積層,以及第一個(gè) block1,只需要指定:

checkpoint_exclude_scopes = 'resnet_v1_50/conv1,resnet_v1_50/block1'
init_fn = get_init_fn(checkpoint_exclude_scopes)

函數(shù) get_trainable_variables 的作用是獲取需要訓(xùn)練的變量,它默認(rèn)返回所有可訓(xùn)練的變量。當(dāng)你需要凍結(jié)一些層,讓這些層的參數(shù)不更新時(shí),通過參數(shù) checkpoint_exclude_scopes 指定,比如我想讓 ResNet-50 的 block1block2/unit_1 凍結(jié)時(shí),通過:

scopes_to_freeze = 'resnet_v1_50/block1,resnet_v1_50/block2/unit_1'
vars_to_train = get_trainable_variables(scopes_to_freeze )

調(diào)用即可。

三、數(shù)據(jù)集以及訓(xùn)練

????????本文 GitHub/finetune_classification 上的代碼默認(rèn)使用 AI Challenger 全球AI挑戰(zhàn)賽/農(nóng)作物病害檢測 數(shù)據(jù)集。下載好數(shù)據(jù)集之后,執(zhí)行如下指令:

$ python3 generate_tfrecord.py \
    --images_dir Path/to/AgriculturalDisease_trainingset/images \
    --annotation_path Path/to/AgriculturalDisease_train_annotations.json \
    --output_path Path/to/train.record

將訓(xùn)練集圖像寫入到 train.record 文件中。之后,執(zhí)行:

$ python3 train.py \
    --record_path Path/to/train.record \
    --checkpoint_path Path/to/pretrained_ResNet-50_model/resnet_v1_50.ckpt

開始訓(xùn)練。訓(xùn)練開始之后,會(huì)在當(dāng)前 train.py 路徑下生成一個(gè)文件夾 training 用來保存訓(xùn)練模型。需要額外說明的是,訓(xùn)練過程不會(huì)在終端輸出準(zhǔn)確率、損失等數(shù)據(jù),需要在終端執(zhí)行:

$ tensorboard --logdir Path/to/training

之后,打開返回的 http 鏈接在瀏覽器查看準(zhǔn)確率、損失等訓(xùn)練曲線(訓(xùn)練過程中,訓(xùn)練結(jié)束后都可查看)。訓(xùn)練正常啟動(dòng)后,每 10 分鐘會(huì)保存一次模型到 training 文件夾(諸如 model.ckpt-xxx 之類的文件),你可以選擇使用其中的 model.ckpt-xxx 模型來直接進(jìn)行預(yù)測,也可以選擇將 model.ckpt-xxx 轉(zhuǎn)化為 .pb 文件之后再進(jìn)行預(yù)測,如果選擇轉(zhuǎn)化,執(zhí)行:

$ python3 export_inference_graph.py \
    --trained_checkpoint_prefix Path/to/model.ckpt-xxx \
    --output_directory Path/to/exported_pb_file_directory

之后,在指定的輸出路徑下(Path/to/exported_pb_file_directory)會(huì)生成一個(gè)文件夾,該文件內(nèi)的 frozen_inference_graph.pb 即是轉(zhuǎn)化成的固化模型文件(固化指的是所有參數(shù)都轉(zhuǎn)化成了常數(shù))。之后就可以使用 evaluate.py 或者 predict.py 進(jìn)行驗(yàn)證或預(yù)測了。

????????如果你使用其它數(shù)據(jù)集,整個(gè)訓(xùn)練過程和上面的步驟一樣,只需要根據(jù)具體的標(biāo)注文件來修改文件data_provider.py 中函數(shù) provide,該函數(shù)返回一個(gè)字典,其中 key 代表訓(xùn)練數(shù)據(jù)集中圖像的路徑,value 代表圖像對應(yīng)的類標(biāo)號;其它參數(shù),比如訓(xùn)練圖像個(gè)數(shù),類別數(shù)目,學(xué)習(xí)率等,在 train.py 中修改。

預(yù)告:下一篇文章將要介紹如何用 TensorFlow 來訓(xùn)練多任務(wù)多標(biāo)簽?zāi)P?,敬請期待?/p>

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲(chǔ)服務(wù)。

友情鏈接更多精彩內(nèi)容