Tensorflow-導(dǎo)入數(shù)據(jù)-翻譯整理

tf.data可以創(chuàng)建復(fù)雜的數(shù)據(jù)輸入流水線,實現(xiàn)圖像、文字的導(dǎo)入。
tf.data產(chǎn)生兩個新的抽象類:

  • tf.data. Dataset,數(shù)據(jù)集,多個元素組成的隊列,每個元素包含一個或多個張量。有兩種方式創(chuàng)建Dataset對象:
    • 從張量創(chuàng)建,比如tf.data.Dataset.from_tensor_slices((dict(train_x), train_y))。
    • 從另外的數(shù)據(jù)集創(chuàng)建,比如dataset.shuffle(1000).repeat().batch(batch_size)
  • tf.data.Iterator,從數(shù)據(jù)集中提取元素,Iterator.get_next()產(chǎn)生下一個元素。Iterator.initializer方法可以使用不同的數(shù)據(jù)集不同的參數(shù)進(jìn)行初始化。

基本原理

構(gòu)建數(shù)據(jù)輸入流:

  • 定義數(shù)據(jù)源,tf.data.Dataset.from_tensors(tensors)tf.data.Dataset.from_tensor_slices(tensors),或者使用tf.data.TFRecordDataset讀取本地的TFRecord格式文件。
  • 變換數(shù)據(jù)集,例如使用Dataset.map(map_func)的方法對每個張量元素執(zhí)行操作。
  • 提取數(shù)據(jù),tf.data.Iterator,使用Iterator.initializer重新初始化數(shù)據(jù),使用Iterator.get_next()獲取下一個tensor張量。
  1. 數(shù)據(jù)集結(jié)構(gòu)

Dataset必須包含多個同類元素elements,每個元素包含一個或多個tensor稱為組件components,每個組件的構(gòu)成:

  • tf.Dtype表示數(shù)據(jù)類
  • tf. TensorShape表示數(shù)據(jù)形狀
  • Dataset.output_types,Dataset.output_shapes用來查看以上兩個屬性
    示例代碼:
import tensorflow as tf

tensor=tf.random_uniform([4, 10])
with tf.Session() as session:
    print(session.run([tensor]))

dataset1 = tf.data.Dataset.from_tensor_slices(tensor)
print(dataset1.output_types)  # ==> "tf.float32"
print(dataset1.output_shapes)  # ==> "(10,)"

dataset2 = tf.data.Dataset.from_tensor_slices(
   (tf.random_uniform([4]),
    tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)))
print(dataset2.output_types)  # ==> "(tf.float32, tf.int32)"
print(dataset2.output_shapes)  # ==> "((), (100,))"

dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
print(dataset3.output_types)  # ==> (tf.float32, (tf.float32, tf.int32))
print(dataset3.output_shapes)  # ==> "(10, ((), (100,)))"

tf.random_uniform([4,10])產(chǎn)生的是一個4行10列的0~1隨機(jī)數(shù)字,類似

[array([[0.8139155 , 0.00317001, 0.1536988, ..., 0.625741  ],
       [0.00984228, 0.88505733, 0.44980478, ..., 0.30504322],
       [0.10747015, 0.639518  , 0.6030766 , ..., 0.5297921 ],
       [0.48373353, 0.7960038 , 0.666453,..., 0.7486484 ]],

使用字典標(biāo)記字段名,dataset的dtype和shape可以是個字典:

import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices(
   {"a": tf.random_uniform([4]),
    "b": tf.random_uniform([4, 100], maxval=10, dtype=tf.int32)})
print(dataset.output_types)  # ==> "{'a': tf.float32, 'b': tf.int32}"
print(dataset.output_shapes)  # ==> "{'a': (), 'b': (100,)}"

對數(shù)據(jù)集進(jìn)行處理,Dataset.map(), Dataset.flat_map(), and Dataset.filter(),它們可以對數(shù)據(jù)集的每個元素進(jìn)行處理,元素的結(jié)構(gòu)決定函數(shù)的參數(shù)。

  1. 創(chuàng)建迭代器

數(shù)據(jù)集自帶創(chuàng)建各種迭代器的方法,迭代器分類:

  • one-shot,
  • initializable,
  • reinitializable, and
  • feedable.

one-shot不需要初始化,尤其適合estimator使用。

import tensorflow as tf

dataset = tf.data.Dataset.range(100) #生成0~99共100個元素的數(shù)據(jù)集
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(100):
      v = sess.run(next_element)
      print(v)

#輸出0,1,2,3,4...,99

可初始化的initializable類型的迭代器需要初始化之后才能使用,但可以使用參數(shù)feed_dict,sess.run(iterator.initializer, feed_dict={placeholder: 10})

import tensorflow as tf
sess= tf.Session()

#創(chuàng)建一個站位符,和后面的feed_dict聯(lián)合用來喂數(shù)據(jù)
max_v = tf.placeholder(tf.int64, shape=[]) 

#創(chuàng)建一個數(shù)據(jù)集,從max_v到max_v*2
dataset = tf.data.Dataset.range(max_v, max_v*2)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

sess.run(iterator.initializer, feed_dict={max_v: 10}) #max_v必須和占位符一致
for i in range(3):
  value = sess.run(next_element)
  print('max_v=10',value)

sess.run(iterator.initializer, feed_dict={max_v: 100}) #max_v必須和占位符一致
for i in range(3):
  value = sess.run(next_element)
  print('max_v=100',value)

上面的代碼輸出

max_v=10 10
max_v=10 11
max_v=10 12
max_v=100 100
max_v=100 101
max_v=100 102

可重復(fù)初始化的reinitializable類型迭代器,可以使用多個不同的數(shù)據(jù)集進(jìn)行初始化。比如創(chuàng)建一個訓(xùn)練輸入管道故意把數(shù)據(jù)擾亂,再創(chuàng)建一個評價管道使用正常的數(shù)據(jù)集。不同數(shù)據(jù)集必須有完全相同的結(jié)構(gòu)。

iterator =tf.data.Iterator.from_structure(dtypes,shapes)
next_element = iterator.get_next()
training_init_op =iterator.make_initializer(dataset)

下面是示例代碼,會輸出2遍10個train和5個validation。僅供演示,并沒有什么實際作用。

import tensorflow as tf
sess=tf.Session()

training_dataset = tf.data.Dataset.range(10).map(
    lambda x: x + tf.random_uniform([], -10, 10, tf.int64))
validation_dataset = tf.data.Dataset.range(5)

iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
                                           training_dataset.output_shapes)
next_element = iterator.get_next()

training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)

for _ in range(2): #兩個周期
  sess.run(training_init_op)
  for _ in range(10):
    v = sess.run(next_element)
    print('train',v)

  sess.run(validation_init_op)
  for _ in range(5):
    v = sess.run(next_element)
    print('validation',v)

可喂食的的迭代器feedable,不直接由dataset生成,但它可以聯(lián)合tf.placeholder一起為tf.Session.run每次呼叫切換不同的迭代器(由dataset生成),類似feed_dict的機(jī)制。

handle = tf.placeholder(tf.string, shape=[]) 
iterator = tf.data.Iterator.from_string_handle(handle, dtypes, shapes)
next_element = iterator.get_next()
...
training_iterator = training_dataset.make_one_shot_iterator()
training_handle = sess.run(training_iterator.string_handle())
...
sess.run(next_element, feed_dict={handle: training_handle})

示例代碼,會輸出3次10個train和5個validation。

import tensorflow as tf
sess=tf.Session()

#定義相同結(jié)構(gòu)的數(shù)據(jù)集
training_dataset = tf.data.Dataset.range(0,10).map(
    lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat()
validation_dataset = tf.data.Dataset.range(100,200)

#feedable迭代器由placeholder及其結(jié)構(gòu)設(shè)定
handle = tf.placeholder(tf.string, shape=[]) #這個handle名和下面的必須一致
iterator = tf.data.Iterator.from_string_handle(
    handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()

#不同類型的迭代器
training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()

#Iterator.string_handle()返回一個張量,可以用來喂食placeholder
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())

for _ in range(3):
  for _ in range(10):
    v=sess.run(next_element, feed_dict={handle: training_handle})
    print('train',v)

  sess.run(validation_iterator.initializer)
  for _ in range(5):
    v=sess.run(next_element, feed_dict={handle: validation_handle})
    print('validation',v)
  1. 提取迭代器的數(shù)據(jù)

Iterator.get_next()方法返回一個或多個張量對象,它們關(guān)聯(lián)到迭代器的下一個元素,當(dāng)這些張量被計算的時候,才會獲取下一個元素的數(shù)據(jù)。Iterator.get_next()方法并不會立即運(yùn)算,而是必須把返回的對象放到表達(dá)式里面,并且把表達(dá)式結(jié)果傳遞到tf.Session.run()中,才會被計算。
到達(dá)最后一個元素的時候再執(zhí)行Iterator.get_next()會出錯,tf.errors.OutOfRangeError,如果需要再使用,必須重新初始化。
示例代碼

import tensorflow as tf
sess=tf.Session()

dataset = tf.data.Dataset.range(5)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next() #獲得張量對象
result = tf.add(next_element, next_element) #傳入表達(dá)式

#將表達(dá)式結(jié)果傳入run執(zhí)行
sess.run(iterator.initializer)
print(sess.run(result))  # ==> "0"
print(sess.run(result))  # ==> "2"
print(sess.run(result))  # ==> "4"
print(sess.run(result))  # ==> "6"
print(sess.run(result))  # ==> "8"

try:
  sess.run(result)
except tf.errors.OutOfRangeError:
  print("End of dataset")  # ==> "End of dataset"

如果Dataset包含嵌套結(jié)構(gòu),Iterator.get_next()將返回同樣結(jié)構(gòu)的張量。示例代碼:

import tensorflow as tf
sess=tf.Session()

dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
dataset2 = tf.data.Dataset.from_tensor_slices((tf.random_uniform([4]), tf.random_uniform([4, 100])))
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))

iterator = dataset3.make_initializable_iterator()

sess.run(iterator.initializer)
next1, (next2, next3) = iterator.get_next()
ziped= iterator.get_next()
print('ziped',ziped);
print('next1',next1);
print('next2',next2);
print('next3',next3);

以上代碼將輸出,看到iterator.get_next()方法得到結(jié)果的結(jié)構(gòu)和dataset設(shè)置的相同:

ziped (<tf.Tensor 'IteratorGetNext_1:0' shape=(10,) dtype=float32>, (<tf.Tensor 'IteratorGetNext_1:1' shape=() dtype=float32>, <tf.Tensor 'IteratorGetNext_1:2' shape=(100,) dtype=float32>))
next1 Tensor("IteratorGetNext:0", shape=(10,), dtype=float32)
next2 Tensor("IteratorGetNext:1", shape=(), dtype=float32)
next3 Tensor("IteratorGetNext:2", shape=(100,), dtype=float32)

讀取輸入數(shù)據(jù)

  1. 使用Numpy數(shù)組

如果數(shù)據(jù)就在內(nèi)存里,最簡單的創(chuàng)建dataset的方法就是,用Dataset.from_tensor_slices()把它們轉(zhuǎn)為張量。但是這將把數(shù)據(jù)完全嵌入到計算圖graph中,消耗大量內(nèi)存(最多2G)。

#僅供示意,請勿執(zhí)行
with np.load("/var/data/training_data.npy") as data:
  features = data["features"]
  labels = data["labels"]
assert features.shape[0] == labels.shape[0] #確保數(shù)據(jù)形狀一致
dataset = tf.data.Dataset.from_tensor_slices((features, labels))

利用placeholder占位符和feed_dict可以優(yōu)化內(nèi)存占用:

# 僅供示意,請勿執(zhí)行
with np.load("/var/data/training_data.npy") as data:
  features = data["features"]
  labels = data["labels"]
assert features.shape[0] == labels.shape[0]

features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
...對dataset進(jìn)行其他操作...
dataset = ...
iterator = dataset.make_initializable_iterator()

sess.run(iterator.initializer, feed_dict={
                          features_placeholder: features,
                          labels_placeholder: labels})
  1. 使用TFRecord數(shù)據(jù)

TFRecord是簡單的面向記錄record-oriented的二進(jìn)制數(shù)據(jù)格式;tf.data.TFRecordDataset讓我們可以把單個或多個TFRecord文件作為輸入管道的一部分。

#示意代碼。一次性讀取兩個文件
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)

結(jié)合tf.placeholder和feed_dict使用,分別喂食訓(xùn)練數(shù)據(jù)和驗證數(shù)據(jù):

#示意代碼,請勿運(yùn)行
filenames = tf.placeholder(tf.string, shape=[None])
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)  # 將記錄數(shù)據(jù)解析為張量.
dataset = dataset.repeat()  # 重復(fù)輸入.
dataset = dataset.batch(32) #合并成批次
iterator = dataset.make_initializable_iterator() 

training_filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
sess.run(iterator.initializer, feed_dict={filenames: training_filenames})

validation_filenames = ["/var/data/validation1.tfrecord", ...]
sess.run(iterator.initializer, feed_dict={filenames: validation_filenames})
  1. 使用文本數(shù)據(jù)

使用tf.data.TextLineDataset可以從單個或多個文本文件中逐行的讀取數(shù)據(jù)。

filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
dataset = tf.data.TextLineDataset(filenames)

當(dāng)文本文件第一行包含標(biāo)題欄信息的時候,我們使用Dataset.skip()和Dataset.filter()進(jìn)行處理。為了對每個文件進(jìn)行處理,可以使用Dataset.flat_map()為每個文件創(chuàng)建嵌套的數(shù)據(jù)集:

filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
dataset = tf.data.Dataset.from_tensor_slices(filenames)

dataset = dataset.flat_map(
    lambda filename: (
        tf.data.TextLineDataset(filename)
        .skip(1) #跳過第一行
        .filter(lambda line: tf.not_equal(tf.substr(line, 0, 1), "#")))) #忽略#注釋行

使用Dataset.map()處理數(shù)據(jù)

Dataset.map(f)處理數(shù)據(jù)中的每一個元素element,其中函數(shù)f(element) 需要返回一個新的元素element。

解析tf.Example緩沖格式信息

輸入管道從TFRecord格式文件提取tf.train.Example,每個tf.train.Example包含了單個或多個特征,然后輸入管道把這些特征轉(zhuǎn)為張量。

#將一個標(biāo)量字符串example_proto轉(zhuǎn)為1個標(biāo)量字符串+1個標(biāo)量整數(shù),表示某圖片的名稱和標(biāo)簽
def _parse_function(example_proto):
  features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
              "label": tf.FixedLenFeature((), tf.int32, default_value=0)}
  parsed_features = tf.parse_single_example(example_proto, features)
  return parsed_features["image"], parsed_features["label"]

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(_parse_function)
解碼圖片數(shù)據(jù)并重新設(shè)置大小
def _parse_function(filename, label):
  image_string = tf.read_file(filename) #讀取圖片
  image_decoded = tf.image.decode_image(image_string) #解碼圖片
  image_resized = tf.image.resize_images(image_decoded, [28, 28]) #統(tǒng)一大小
  return image_resized, label #返回特征和標(biāo)簽

filenames = tf.constant(["/var/data/image1.jpg", "/var/data/image2.jpg", ...])

# `labels[i]`對應(yīng) `filenames[i]`圖片的標(biāo)簽.
labels = tf.constant([0, 37, ...])

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels)) #參數(shù)結(jié)構(gòu)與parse函數(shù)一致
dataset = dataset.map(_parse_function)
使用tf.py_func()應(yīng)用任意Python邏輯

盡必要時候,在Dataset.map(f)中使用tf.py_func().

import cv2
# 使用自定義的OpenCV函數(shù)代替 `tf.read_file()`讀取圖片
def _read_py_function(filename, label):
  image_decoded = cv2.imread(filename.decode(), cv2.IMREAD_GRAYSCALE)
  return image_decoded, label

# 使用tf方法把圖片調(diào)整到統(tǒng)一尺寸
def _resize_function(image_decoded, label):
  image_decoded.set_shape([None, None, None])
  image_resized = tf.image.resize_images(image_decoded, [28, 28])
  return image_resized, label

filenames = ["/var/data/image1.jpg", "/var/data/image2.jpg", ...]
labels = [0, 37, 29, 1, ...]

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(
    lambda filename, label: tuple(tf.py_func( #返回2元素的元組
        _read_py_function, [filename, label], [tf.uint8, label.dtype])))
dataset = dataset.map(_resize_function)

將數(shù)據(jù)集元素分批次Batching

簡單分批處理

最簡單的分批方法是使用Dataset.batch()把n個連續(xù)元素組成一個新元素,要求每個舊元素結(jié)構(gòu)必須相同。

以下示例代碼把4個數(shù)字分批處理。

import tensorflow as tf
sess=tf.Session()

inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
batched_dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))
batched_dataset = dataset.batch(4)

iterator = batched_dataset.make_one_shot_iterator()
next_element = iterator.get_next()

print(sess.run(next_element))  # ==> ([0, 1, 2,   3],   [ 0, -1,  -2,  -3])
print(sess.run(next_element))  # ==> ([4, 5, 6,   7],   [-4, -5,  -6,  -7])
print(sess.run(next_element))  # ==> ([8, 9, 10, 11],   [-8, -9, -10, -11])
帶有填充的分批

對于長度不同的張量可以使用Dataset.padded_batch()填充分批.

dataset = tf.data.Dataset.range(100)
dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))
dataset = dataset.padded_batch(4, padded_shapes=[None])

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

print(sess.run(next_element))  # ==> [[0, 0, 0], [1, 0, 0], [2, 2, 0], [3, 3, 3]]
print(sess.run(next_element))  # ==> [[4, 4, 4, 4, 0, 0, 0],
                               #      [5, 5, 5, 5, 5, 0, 0],
                               #      [6, 6, 6, 6, 6, 6, 0],
                               #      [7, 7, 7, 7, 7, 7, 7]]

tf.fill([x,y],n)方法形成x行y列都是n的列表,比如tf.fill([2,3],9):

[9 9 9]
[9 9 9]

tf.cast(x,dtype)是對張量x進(jìn)行數(shù)據(jù)格式轉(zhuǎn)化,比如從tf.float32轉(zhuǎn)為tf.int32,在這里轉(zhuǎn)為整數(shù)才能使用。
以上代碼中dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))得到的是

[]
[1]
[2 2]
[3 3 3]
[4 4 4 4]
[5 5 5 5 5]
[6 6 6 6 6 6]
[7 7 7 7 7 7 7]
[8 8 8 8 8 8 8 8]

然后dataset.padded_batch(4, padded_shapes=[None])把每4個元素分為一個批次,并填充0處理,得到

[[0 0 0]
 [1 0 0]
 [2 2 0]
 [3 3 3]]
[[4 4 4 4 0 0 0]
 [5 5 5 5 5 0 0]
 [6 6 6 6 6 6 0]
 [7 7 7 7 7 7 7]]
[[ 8  8  8  8  8  8  8  8  0  0  0]
 [ 9  9  9  9  9  9  9  9  9  0  0]
 [10 10 10 10 10 10 10 10 10 10  0]
 [11 11 11 11 11 11 11 11 11 11 11]]

訓(xùn)練流程

處理多個周期epochs

Dataset有兩種方法創(chuàng)建周期

  1. 使用Dataset.repeat(n)
    重復(fù)n次,如果為空則無限重復(fù)下去:
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.repeat(10) #重復(fù)10次
dataset = dataset.batch(32)
  1. 無限循環(huán)并捕獲異常
    為了避免無限重復(fù)到達(dá)結(jié)尾時候產(chǎn)生的異常,可以做如下處理
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.batch(32)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

# 100個周期
for _ in range(100):
  sess.run(iterator.initializer)
  while True:
    try:
      sess.run(next_element)
    except tf.errors.OutOfRangeError:
      break
  #在這里執(zhí)行其他的代碼
隨機(jī)順序調(diào)整

Dataset.shuffle(n)將元素填充到n個緩沖,然后再隨機(jī)提取處理,示意代碼

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat()
使用高級API接口

tf.train.MonitoredTrainingSession簡化了分布式運(yùn)算設(shè)置,當(dāng)運(yùn)算完成時候通過tf.errors.OutOfRangeError得知,推薦結(jié)合Dataset.make_one_shot_iterator()迭代器使用:

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat(num_epochs)
iterator = dataset.make_one_shot_iterator()

next_example, next_label = iterator.get_next()
loss = model_function(next_example, next_label)

training_op = tf.train.AdagradOptimizer(...).minimize(loss)

with tf.train.MonitoredTrainingSession(...) as sess:
  while not sess.should_stop():
    sess.run(training_op)

結(jié)合Estimator一起使用的示例代碼:

def dataset_input_fn():
  filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
  dataset = tf.data.TFRecordDataset(filenames)

  def parser(record):
    keys_to_features = {
        "image_data": tf.FixedLenFeature((), tf.string, default_value=""),
        "date_time": tf.FixedLenFeature((), tf.int64, default_value=""),
        "label": tf.FixedLenFeature((), tf.int64,
                                    default_value=tf.zeros([], dtype=tf.int64)),
    }
    parsed = tf.parse_single_example(record, keys_to_features)

    image = tf.image.decode_jpeg(parsed["image_data"])
    image = tf.reshape(image, [299, 299, 1])
    label = tf.cast(parsed["label"], tf.int32)

    return {"image_data": image, "date_time": parsed["date_time"]}, label

  dataset = dataset.map(parser)
  dataset = dataset.shuffle(buffer_size=10000)
  dataset = dataset.batch(32)
  dataset = dataset.repeat(num_epochs)
  iterator = dataset.make_one_shot_iterator()

  features, labels = iterator.get_next()
  return features, labels

小結(jié)

  • 導(dǎo)入數(shù)據(jù)集的基本機(jī)制
    • 數(shù)據(jù)集Dataset
      • 同類元素elements
        • 張量組件components
          • dtype
          • shape
    • 迭代器iterator
      • one-shot
      • initializable
      • reinitializable
      • feedable
    • 從迭代器獲取元素get_next()
  • 讀取數(shù)據(jù)
    • Numpy arrays
    • TFRecord
    • Text
  • 使用Dataset.map處理數(shù)據(jù)
    • tf.Example協(xié)議緩沖
    • 解碼圖像
    • 自定義tf.py_func()
  • 數(shù)據(jù)元素分批
    • Dataset.batch()
    • Dataset. padded_batch()
  • 訓(xùn)練流程
    • 多個周期epochs,Dataset.repeat()
    • 隨機(jī)洗牌Dataset.shuffle()
    • 使用高級接口tf.train.MonitoredTrainingSession

探索人工智能的新邊界

如果您發(fā)現(xiàn)文章錯誤,請不吝留言指正;
如果您覺得有用,請點喜歡;
如果您覺得很有用,感謝轉(zhuǎn)發(fā)~


END

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容