[0.2] Tensorflow踩坑記之頭疼的tf.data

今天嘗試總結(jié)一下 tf.data 這個(gè)API的一些用法吧。之所以會(huì)用到這個(gè)API,是因?yàn)樾枰幚淼臄?shù)據(jù)量很大,而且數(shù)據(jù)均是分布式的存儲(chǔ)在多臺(tái)服務(wù)器上,所以沒有辦法采用傳統(tǒng)的喂數(shù)據(jù)方式,而是運(yùn)用了 tf.data 對(duì)數(shù)據(jù)進(jìn)行了相應(yīng)的預(yù)處理,并且最近正趕上總結(jié)需要,嘗試寫一下關(guān)于 tf.data 的一些用法,有錯(cuò)誤的地方一定告訴我哈。

Tensorflow的數(shù)據(jù)讀取

先來看一下Tensorflow的數(shù)據(jù)讀取機(jī)制吧

這一篇文章對(duì)于 tensorflow的數(shù)據(jù)讀取機(jī)制 講解得很不錯(cuò),大噶可以先看一下,有一個(gè)了解。

Dataset API是怎么用的呢

雖然上面的資料關(guān)于 tf.data 講解得都很好,但是我沒有找到一個(gè)很完整滴運(yùn)用 tf.data.TextLineDataset()tf.data.TFRecordDataset() 的例子,所以才想嘗試寫一寫這篇總結(jié)。

MNIST的經(jīng)典例子

本篇博客結(jié)合 mnist 的經(jīng)典例子,針對(duì)不同的源數(shù)據(jù):csv數(shù)據(jù)和tfrecord數(shù)據(jù),分別運(yùn)用 tf.data.TextLineDataset()tf.data.TFRecordDataset() 創(chuàng)建不同的 Dataset 并運(yùn)用四種不同的 Iterator ,分別是 單次,可初始化,可重新初始化,以及可饋送迭代器 的方式實(shí)現(xiàn)對(duì)源數(shù)據(jù)的預(yù)處理工作。

我將相關(guān)的資料放在了瀾子的Github 上,歡迎互粉哇(星星眼)。其中包括了所需的 后綴名為csv和tfrecords的源數(shù)據(jù) (data的文件夾),以及在 jupyter notebook實(shí)現(xiàn)的具體代碼 (tf_dataset_learn.ipynb)。

如果有需要的同學(xué)可以直接
git clone https://github.com/lanhongvp/tensorflow_dataset_learn.git
然后用 jupyter 跑一跑看看輸出,這樣可以有一個(gè)比較直觀的認(rèn)識(shí)。關(guān)于 Git和Github 的使用,大噶可以看我VSCODE_GIT這一篇博客啦。接下來,針對(duì)MNIST例子做一個(gè)簡(jiǎn)單的說明吧。

tf.data.TFRecordDataset() & make_one_shot_iterator()

tf.data.TFRecordDataset() 輸入?yún)?shù)直接是后綴名為tfrecords的文件路徑,正因如此,即可解決數(shù)據(jù)量過大,導(dǎo)致無法單機(jī)訓(xùn)練的問題。本篇博客中,文件路徑即為/Users/honglan/Desktop/train_output.tfrecords,此處是我自己電腦上的路徑,大家可以 根據(jù)自己的需要修改為對(duì)應(yīng)的文件路徑。
make_one_shot_iterator() 即為單次迭代器,是最簡(jiǎn)單的迭代器形式,僅支持對(duì)數(shù)據(jù)集進(jìn)行一次迭代,不需要顯式初始化。
配合 MNIST數(shù)據(jù)集以及tf.data.TFRecordDataset(),實(shí)現(xiàn)代碼如下。

# Validate tf.data.TFRecordDataset() using make_one_shot_iterator()
import tensorflow as tf
import numpy as np

num_epochs = 2
num_class = 10
sess = tf.Session()

# Use `tf.parse_single_example()` to extract data from a `tf.Example`
# protocol buffer, and perform any additional per-record preprocessing.
def parser(record):
    keys_to_features = {
        "image_raw": tf.FixedLenFeature((), tf.string, default_value=""),
        "pixels": tf.FixedLenFeature((), tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
        "label": tf.FixedLenFeature((), tf.int64,
                                    default_value=tf.zeros([], dtype=tf.int64)),
    }
    parsed = tf.parse_single_example(record, keys_to_features)

    # Parse the string into an array of pixels corresponding to the image
    images = tf.decode_raw(parsed["image_raw"],tf.uint8)
    images = tf.reshape(images,[28,28,1])
    labels = tf.cast(parsed['label'], tf.int32)
    labels = tf.one_hot(labels,num_class)
    pixels = tf.cast(parsed['pixels'], tf.int32)
    print("IMAGES",images)
    print("LABELS",labels)
    
    return {"image_raw": images}, labels


filenames = ["/Users/honglan/Desktop/train_output.tfrecords"] 
# replace the filenames with your own path
dataset = tf.data.TFRecordDataset(filenames)
print("DATASET",dataset)

# Use `Dataset.map()` to build a pair of a feature dictionary and a label
# tensor for each example.
dataset = dataset.map(parser)
print("DATASET_1",dataset)
dataset = dataset.shuffle(buffer_size=10000)
print("DATASET_2",dataset)
dataset = dataset.batch(32)
print("DATASET_3",dataset)
dataset = dataset.repeat(num_epochs)
print("DATASET_4",dataset)
iterator = dataset.make_one_shot_iterator()

# `features` is a dictionary in which each value is a batch of values for
# that feature; `labels` is a batch of labels.
features, labels = iterator.get_next()

print("FEATURES",features)
print("LABELS",labels)
print("SESS_RUN_LABELS \n",sess.run(labels))

tf.data.TFRecordDataset() & Initializable iterator

make_initializable_iterator() 為可初始化迭代器,運(yùn)用此迭代器首先需要先運(yùn)行顯式 iterator.initializer 操作,然后才能使用。并且,可運(yùn)用 可初始化迭代器實(shí)現(xiàn)訓(xùn)練集和驗(yàn)證集的切換。
配合 MNIST數(shù)據(jù)集 實(shí)現(xiàn)代碼如下。

# Validate tf.data.TFRecordDataset() using make_initializable_iterator()
# In order to switch between train and validation data
num_epochs = 2
num_class = 10

def parser(record):
    keys_to_features = {
        "image_raw": tf.FixedLenFeature((), tf.string, default_value=""),
        "pixels": tf.FixedLenFeature((), tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
        "label": tf.FixedLenFeature((), tf.int64,
                                    default_value=tf.zeros([], dtype=tf.int64)),
    }
    parsed = tf.parse_single_example(record, keys_to_features)
    
    # Parse the string into an array of pixels corresponding to the image
    images = tf.decode_raw(parsed["image_raw"],tf.uint8)
    images = tf.reshape(images,[28,28,1])
    labels = tf.cast(parsed['label'], tf.int32)
    labels = tf.one_hot(labels,10)
    pixels = tf.cast(parsed['pixels'], tf.int32)
    print("IMAGES",images)
    print("LABELS",labels)
    
    return {"image_raw": images}, labels


filenames = tf.placeholder(tf.string, shape=[None])
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(parser) # Parse the record into tensors
# print("DATASET",dataset)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat(num_epochs)
print("DATASET",dataset)
iterator = dataset.make_initializable_iterator()
features, labels = iterator.get_next()
print("ITERATOR",iterator)
print("FEATURES",features)
print("LABELS",labels)


# Initialize `iterator` with training data.
training_filenames = ["/Users/honglan/Desktop/train_output.tfrecords"] 
# replace the filenames with your own path
sess.run(iterator.initializer,feed_dict={filenames: training_filenames})
print("TRAIN\n",sess.run(labels))
# print(sess.run(features))

# Initialize `iterator` with validation data.
validation_filenames = ["/Users/honglan/Desktop/val_output.tfrecords"] 
# replace the filenames with your own path
sess.run(iterator.initializer, feed_dict={filenames: validation_filenames})
print("VAL\n",sess.run(labels))

tf.data.TextLineDataset() & Reinitializable iterator

tf.data.TextLineDataset(),輸入?yún)?shù)可以是后綴名為csv或者是txt的源數(shù)據(jù)的文件路徑。
此處用的迭代器是 Reinitializable iterator,即為可重新初始化迭代器。官方定義如下。配合 MNIST數(shù)據(jù)集 實(shí)現(xiàn)代碼見第二部分。

可重新初始化迭代器可以通過多個(gè)不同的 Dataset 對(duì)象進(jìn)行初始化。例如,您可能有一個(gè)訓(xùn)練輸入管道,它會(huì)對(duì)輸入圖片進(jìn)行隨機(jī)擾動(dòng)來改善泛化;還有一個(gè)驗(yàn)證輸入管道,它會(huì)評(píng)估對(duì)未修改數(shù)據(jù)的預(yù)測(cè)。這些管道通常會(huì)使用不同的 Dataset 對(duì)象,這些對(duì)象具有相同的結(jié)構(gòu)(即每個(gè)組件具有相同類型和兼容形狀)。

# validate tf.data.TextLineDataset() using Reinitializable iterator
# In order to switch between train and validation data

def decode_line(line):
    # Decode the line to tensor
    record_defaults = [[1.0] for col in range(785)]
    items = tf.decode_csv(line, record_defaults)
    features = items[1:785]
    label = items[0]

    features = tf.cast(features, tf.float32)
    features = tf.reshape(features,[28,28,1])
    label = tf.cast(label, tf.int64)
    label = tf.one_hot(label,num_class)
    return features,label


def create_dataset(filename, batch_size=32, is_shuffle=False, n_repeats=0):
    """create dataset for train and validation dataset"""
    dataset = tf.data.TextLineDataset(filename).skip(1)
    if n_repeats > 0:
        dataset = dataset.repeat(n_repeats)         # for train
    # dataset = dataset.map(decode_line).map(normalize)  
    dataset = dataset.map(decode_line)    
    # decode and normalize
    if is_shuffle:
        dataset = dataset.shuffle(10000)            # shuffle
    dataset = dataset.batch(batch_size)
    return dataset


training_filenames = ["/Users/honglan/Desktop/train.csv"] 
# replace the filenames with your own path
validation_filenames = ["/Users/honglan/Desktop/val.csv"] 
# replace the filenames with your own path

# Create different datasets
training_dataset = create_dataset(training_filenames, batch_size=32, \
                                  is_shuffle=True, n_repeats=num_epochs) # train_filename
validation_dataset = create_dataset(validation_filenames, batch_size=32, \
                                  is_shuffle=True, n_repeats=num_epochs) # val_filename

# A reinitializable iterator is defined by its structure. We could use the
# `output_types` and `output_shapes` properties of either `training_dataset`
# or `validation_dataset` here, because they are compatible.
iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
                                           training_dataset.output_shapes)
features, labels = iterator.get_next()

training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)

# Using reinitializable iterator to alternate between training and validation.
sess.run(training_init_op)
print("TRAIN\n",sess.run(labels))
# print(sess.run(features))

# Reinitialize `iterator` with validation data.
sess.run(validation_init_op)
print("VAL\n",sess.run(labels))

tf.data.TextLineDataset() & Feedable iterator.

數(shù)據(jù)集讀取方式同上一部分一樣,運(yùn)用tf.data.TextLineDataset()此處運(yùn)用的迭代器是 可饋送迭代器,其可以與 tf.placeholder 一起使用,通過熟悉的 feed_dict 機(jī)制選擇每次調(diào)用 tf.Session.run 時(shí)所使用的 Iterator。并使用 tf.data.Iterator.from_string_handle定義一個(gè)可讓在兩個(gè)數(shù)據(jù)集之間切換的可饋送迭代器,結(jié)合 MNIST數(shù)據(jù)集 的代碼如下

# validate tf.data.TextLineDataset() using two different iterator
# In order to switch between train and validation data

def decode_line(line):
    # Decode the line to tensor
    record_defaults = [[1.0] for col in range(785)]
    items = tf.decode_csv(line, record_defaults)
    features = items[1:785]
    label = items[0]

    features = tf.cast(features, tf.float32)
    features = tf.reshape(features,[28,28])
    label = tf.cast(label, tf.int64)
    label = tf.one_hot(label,num_class)
    return features,label


def create_dataset(filename, batch_size=32, is_shuffle=False, n_repeats=0):
    """create dataset for train and validation dataset"""
    dataset = tf.data.TextLineDataset(filename).skip(1)
    if n_repeats > 0:
        dataset = dataset.repeat(n_repeats)         # for train
    # dataset = dataset.map(decode_line).map(normalize)  
    dataset = dataset.map(decode_line)    
    # decode and normalize
    if is_shuffle:
        dataset = dataset.shuffle(10000)            # shuffle
    dataset = dataset.batch(batch_size)
    return dataset


training_filenames = ["/Users/honglan/Desktop/train.csv"] 
# replace the filenames with your own path
validation_filenames = ["/Users/honglan/Desktop/val.csv"] 
# replace the filenames with your own path

# Create different datasets
training_dataset = create_dataset(training_filenames, batch_size=32, \
                                  is_shuffle=True, n_repeats=num_epochs) # train_filename
validation_dataset = create_dataset(validation_filenames, batch_size=32, \
                                  is_shuffle=True, n_repeats=num_epochs) # val_filename

# A feedable iterator is defined by a handle placeholder and its structure. We
# could use the `output_types` and `output_shapes` properties of either
# `training_dataset` or `validation_dataset` here, because they have
# identical structure.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, training_dataset.output_types, training_dataset.output_shapes)
features, labels = iterator.get_next()

# You can use feedable iterators with a variety of different kinds of iterator
# (such as one-shot and initializable iterators).
training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()

# The `Iterator.string_handle()` method returns a tensor that can be evaluated
# and used to feed the `handle` placeholder.
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())

# Using different handle to alternate between training and validation.
print("TRAIN\n",sess.run(labels, feed_dict={handle: training_handle}))
# print(sess.run(features))

# Initialize `iterator` with validation data.
sess.run(validation_iterator.initializer)
print("VAL\n",sess.run(labels, feed_dict={handle: validation_handle}))

小結(jié)

  • 運(yùn)用tfrecords處理數(shù)據(jù)的速度明顯加快
  • 可以根據(jù)自身需要選擇不同的iterator方式對(duì)源數(shù)據(jù)進(jìn)行預(yù)處理
  • 單機(jī)訓(xùn)練時(shí)也可以采用 tf.data中API的相應(yīng)處理方式
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容