介紹
最近在讀取tfrecord時,遇到了關(guān)于tensorf shape的問題。
我們需要知道,大多數(shù)情況下圖片進(jìn)行encode編碼保存在tfrecord時 是一個一維張量,shape為(1,)。 而在輸入神經(jīng)網(wǎng)絡(luò)之前,我們必須要將這個圖片張量reshape成一個合乎網(wǎng)絡(luò)結(jié)構(gòu)需求的三維張量。
在針對這樣的需求時,我們會發(fā)現(xiàn),大部分同學(xué)會選擇在生成tfrecord前就定義好網(wǎng)絡(luò)的輸入shape,例如[224,224,3], 然后將所有的圖片先reshape成這個大小,接著存儲在tfrecord中。
這種方式的優(yōu)點在于提前完成的reshape,避免了后續(xù)很多的shape uncompatible 的問題,以及后續(xù)訓(xùn)練中不用再對圖片進(jìn)行reshape,加快了訓(xùn)練速度。
缺點在于,限制了網(wǎng)絡(luò)輸入尺寸的定義。每修改一次神經(jīng)網(wǎng)絡(luò)的輸入shape。
當(dāng)我們需要從存儲著不定尺寸圖片的tfrecord讀取數(shù)據(jù)時, 我們是無法直接將圖片reshape成指定的網(wǎng)絡(luò)結(jié)構(gòu)輸入尺寸的。例如圖片大小 [667,1085,3]。顯然,我們無法直接將其reshape成 [224,224,3]的。那么我們該如何處理呢?
按照思路,我們應(yīng)該先將圖片的一維tensor 轉(zhuǎn)換成三維tensor, 然后再利用 tf.image庫中不同的reshape 操作,將三維圖片tensor轉(zhuǎn)換為需要的 tensor大小。
按照這種思路,在這里,我總結(jié)了兩種讀寫tfrecord的方式,并對這兩種方式的不同點,尤其是容易導(dǎo)致bug的地方進(jìn)行了整理。
第一種: 利用slim.dataset.Dataset讀寫tfrecord文件,這種方式常見于利用slim庫進(jìn)行目標(biāo)檢測等網(wǎng)絡(luò)的實現(xiàn)過程中。
第二種:tf.parse_single_example 是更為常見的一種方式
利用slim.dataset.Dataset讀寫tfrecord文件
利用這個這個接口讀寫tfrecord非常的方便。它的神奇之處在于,
它不需要圖片寬高的信息,只需要其二進(jìn)制string tensor。 這個接口會自動返回一個三維圖片tensor。 在此基礎(chǔ)上,我們可以很方便的對其進(jìn)行reshape,然后輸入神經(jīng)網(wǎng)絡(luò)。
具體步驟如下:
在生成tfrecord文件時,我們需要先定義 tf_example的寫入格式,然后在將圖片文件依據(jù)這個寫入格式,生成tfrecord文件
- 定義 tf_example的寫入特征
def int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def int64_list_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def bytes_list_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
def float_list_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
def create_tf_example(image_path, label, resize_size=None):
with tf.gfile.GFile(image_path, 'rb') as fid:
encoded_jpg = fid.read()
encoded_jpg_io = io.BytesIO(encoded_jpg)
image = Image.open(encoded_jpg_io)
# 對于可能存在RGBA 4通道的圖片進(jìn)行處理
image,is_process = process_image_channels(image)
# 如有必要,那么就在生成tfrecord時即進(jìn)行resize
width, height = image.size
if resize_size is not None:
if width > height:
width = int(width * resize_size / height)
height = resize_size
else:
width = resize_size
height = int(height * resize_size / width)
image = image.resize((width, height), Image.ANTIALIAS)
# update encode_jpg
if is_process or resize_size is not None:
bytes_io = io.BytesIO()
image.save(bytes_io, format='JPEG')
encoded_jpg = bytes_io.getvalue()
tf_example = tf.train.Example(
features=tf.train.Features(feature={
'image/encoded': bytes_feature(encoded_jpg),
'image/format': bytes_feature('jpg'.encode()),
'image/class/label': int64_feature(label),
'image/height': int64_feature(height),
'image/width': int64_feature(width)}))
return tf_example
- 生成完整的tfrecord文件
在定義完對應(yīng)的tf_example 方式后,我們可以遍歷圖片文件,生成完整的tfrecord文件了。
def generate_tfrecord(annotation_dict, output_path, resize_size=None):
num_valid_tf_example = 0
writer = tf.python_io.TFRecordWriter(output_path)
for image_path, label in annotation_dict.items():
if not tf.gfile.GFile(image_path):
print('%s does not exist.' % image_path)
continue
tf_example = create_tf_example(image_path, label, resize_size)
if tf_example:
writer.write(tf_example.SerializeToString())
num_valid_tf_example += 1
if num_valid_tf_example % 100 == 0:
print('Create %d TF_Example.' % num_valid_tf_example)
writer.close()
print('Total create TF_Example: %d' % num_valid_tf_example)
對應(yīng)著,在讀取tfrecord時,slim提供了 slim.dataset.Dataset 的API接口,非常方便對讀入的tfrecord數(shù)據(jù)進(jìn)行操作。
def get_record_dataset(record_path,
reader=None,
num_samples=50000,
num_classes=32):
"""Get a tensorflow record file.
Args:
"""
if not reader:
reader = tf.TFRecordReader
keys_to_features = {
'image/encoded':
tf.FixedLenFeature((), tf.string, default_value=''),
'image/format':
tf.FixedLenFeature((), tf.string, default_value='jpeg'),
'image/class/label':
tf.FixedLenFeature([1], tf.int64, default_value=tf.zeros([1],
dtype=tf.int64))}
items_to_handlers = {
'image': slim.tfexample_decoder.Image(image_key='image/encoded',
format_key='image/format'),
'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[])}
decoder = slim.tfexample_decoder.TFExampleDecoder(
keys_to_features, items_to_handlers)
labels_to_names = None
items_to_descriptions = {
'image': 'An image with shape image_shape.',
'label': 'A single integer.'}
return slim.dataset.Dataset(
data_sources=record_path,
reader=reader,
decoder=decoder,
num_samples=num_samples,
num_classes=num_classes,
items_to_descriptions=items_to_descriptions,
labels_to_names=labels_to_names)
在返回了slim.dataset.Dataset這個slim支持的data封裝后, 我們可直接對返回的圖片數(shù)據(jù)進(jìn)行reshape,保證這個圖片張量的shape與網(wǎng)絡(luò)結(jié)構(gòu)的輸入層shape一致。
dataset = get_record_dataset(FLAGS.record_path, num_samples=num_samples,
num_classes=FLAGS.num_classes)
data_provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
image, label = data_provider.get(['image', 'label'])
# 輸出當(dāng)前tensor的靜態(tài)shape 和動態(tài)shape,與另一種讀取方式進(jìn)行對比
print("----------tf.shape(image): ",tf.shape(image))
print("----------image.get_shape(): ",image.get_shape())
image = _fixed_sides_resize(image, output_height=368, output_width=368)
inputs, labels = tf.train.batch([image, label],
batch_size=FLAGS.batch_size,
#capacity=5*FLAGS.batch_size,
allow_smaller_final_batch=True)
其中,對三維圖片張量進(jìn)行reshape的代碼如下
def _fixed_sides_resize(image, output_height, output_width):
"""Resize images by fixed sides.
Args:
image: A 3-D image `Tensor`.
output_height: The height of the image after preprocessing.
output_width: The width of the image after preprocessing.
Returns:
resized_image: A 3-D tensor containing the resized image.
"""
output_height = tf.convert_to_tensor(output_height, dtype=tf.int32)
output_width = tf.convert_to_tensor(output_width, dtype=tf.int32)
image = tf.expand_dims(image, 0)
resized_image = tf.image.resize_nearest_neighbor(
image, [output_height, output_width], align_corners=False)
resized_image = tf.squeeze(resized_image)
resized_image.set_shape([None, None, 3])
return resized_image
完成了這幾步之后,我們就可以利用image 和 label 進(jìn)行神經(jīng)網(wǎng)絡(luò)訓(xùn)練了。
利用tf.parse_single_example 讀寫tfrecord文件
這種方式我們需要自己手動將一維的圖片tensor,先還原成三維圖片tensor。 因為每一張圖片的shape不相同。那么我們需要將圖片的shape也存入tfrecord文件中。當(dāng)我們從tfrecord文件中讀取時,我們先利用tf.reshape將一維圖片張量還原成三維圖片張量,再reshape規(guī)定的網(wǎng)絡(luò)輸入尺寸。
- 照例,此處的重點在于tf_example的構(gòu)建。在這一部分,我將圖片的shape作為一個feature,也存入了tfrecord里面。 那么,在對張量的還原時,我們可以利用這個三維的shape tensor,
def create_tf_example(image_path, label, resize_size=None):
with tf.gfile.GFile(image_path, 'rb') as fid:
encoded_jpg = fid.read()
encoded_jpg_io = io.BytesIO(encoded_jpg)
image = Image.open(encoded_jpg_io)
# 對于RGBA 4通道的圖片進(jìn)行處理
image,is_process = process_image_channels(image)
# Resize
width, height = image.size
if resize_size is not None:
if width > height:
width = int(width * resize_size / height)
height = resize_size
else:
width = resize_size
height = int(height * resize_size / width)
image = image.resize((width, height), Image.ANTIALIAS)
img_array = np.asarray(image)
shape = img_array.shape
byte_image = image.tobytes()
tf_example = tf.train.Example(
features=tf.train.Features(feature={
'image': bytes_feature(byte_image),
'label': int64_feature(label),
'img_shape': int64_list_feature(shape)}))
return tf_example
在完成這個后,我們?nèi)耘f可以使用上述提及的generate_tfrecord 函數(shù)來生成對應(yīng)的tfrecord
那么,對應(yīng)這種方式生成的tfrecord文件,我們該如何讀取呢?
在這里,我給出對應(yīng)的parse_example函數(shù)就足以了。
def parse(serialized):
# Define a dict with the data-names and types we expect to
# find in the TFRecords file.
# It is a bit awkward that this needs to be specified again,
# because it could have been written in the header of the
# TFRecords file instead.
features = {
'image':
tf.FixedLenFeature((), tf.string, default_value=''),
'label':
tf.FixedLenFeature([1], tf.int64, default_value=tf.zeros([1],
dtype=tf.int64)),
'img_shape':
tf.FixedLenFeature(shape=(3,), dtype=tf.int64)}
# Parse the serialized data so we get a dict with our data.
parsed_example = tf.parse_single_example(
serialized=serialized, features=features)
# Get the image as raw bytes.
image_raw = parsed_example['image']
# Decode the raw bytes so it becomes a tensor with type.
image = tf.decode_raw(image_raw, tf.uint8)
# The type is now uint8 but we need it to be float.
image = tf.cast(image, tf.float32)
shape = parsed_example['img_shape']
image = tf.reshape(image,shape=shape)
if not (shape[0] == shape[1] == default_img_size):
image = _fixed_sides_resize(image,default_img_size,default_img_size)
image.set_shape([default_img_size,default_img_size,3])
label = parsed_example['label']
# The image and label are now correct TensorFlow types.
return image, label
在這里,讀寫tfrecord的重要流程就已經(jīng)展現(xiàn)好了。
對比
這兩種方式有一個比較重要的區(qū)別,那就是制作tfrecord時存儲的圖片信息不同。
使用slim api時 我們制作tfrecord 時,相關(guān)代碼為
with tf.gfile.GFile(image_path, 'rb') as fid:
encoded_jpg = fid.read()
當(dāng)我們使用第二種方式時,制作tfrecord時存儲的圖片信息的相關(guān)代碼如下所示。
image = Image.open(img_dir)
byte_image = image.tobytes()
第一種方式保存的圖片信息,其字節(jié)數(shù)不等于圖片的height, width, channel的乘積。 所以不能用 第二種的方式去讀取這種方式存儲的tfrecord。 會出現(xiàn) reshape時 維度不對的錯誤。 當(dāng)然,使用slim.dataset.Dataset 則不需要考慮這個問題了。 網(wǎng)絡(luò)上使用slim.dataset.Dataset 來加載tfrecord的方式,都是使用第一種方式存儲的tfrecord數(shù)據(jù)。
第二種方式,其存儲的圖片字節(jié)大小等于圖片的height, width, channel的乘積。所以它可以直接用tf.reshape直接將原圖矩陣還原回來,然后再進(jìn)行下一步的reshape操作。
總結(jié)
之所以寫這篇文章,是因為網(wǎng)絡(luò)上針對不定尺寸圖片tfrecord讀取的解決方案不是很完善。
例如 https://stackoverflow.com/questions/40258943/using-height-width-information-stored-in-a-tfrecords-file-to-set-shape-of-a-ten
將height, width,channel 分別存入tfrecord,然后按照提問者描述這樣是不成功的。
再例如https://stackoverflow.com/questions/35028173/how-to-read-images-with-different-size-in-a-tfrecord-file 提供的解決方案
image_rows = tf.cast(features['rows'], tf.int32)
image_cols = tf.cast(features['cols'], tf.int32)
image_data = tf.decode_raw(features['image_raw'], tf.uint8)
image = tf.reshape(image_data, tf.pack([image_rows, image_cols, 3]))
這種方式在tf.reshape階段會報錯,因為我們無法將 兩個tensor和一個int數(shù)值組合起來。最完善的方式是直接將shape作為一個整體存入tfrecord中,最終讀取出來就是一個張量了。