Tensorflow針對不定尺寸的圖片讀寫tfrecord文件總結(jié)

介紹

最近在讀取tfrecord時,遇到了關(guān)于tensorf shape的問題。

我們需要知道,大多數(shù)情況下圖片進(jìn)行encode編碼保存在tfrecord時 是一個一維張量,shape為(1,)。 而在輸入神經(jīng)網(wǎng)絡(luò)之前,我們必須要將這個圖片張量reshape成一個合乎網(wǎng)絡(luò)結(jié)構(gòu)需求的三維張量。
在針對這樣的需求時,我們會發(fā)現(xiàn),大部分同學(xué)會選擇在生成tfrecord前就定義好網(wǎng)絡(luò)的輸入shape,例如[224,224,3], 然后將所有的圖片先reshape成這個大小,接著存儲在tfrecord中。
這種方式的優(yōu)點在于提前完成的reshape,避免了后續(xù)很多的shape uncompatible 的問題,以及后續(xù)訓(xùn)練中不用再對圖片進(jìn)行reshape,加快了訓(xùn)練速度。
缺點在于,限制了網(wǎng)絡(luò)輸入尺寸的定義。每修改一次神經(jīng)網(wǎng)絡(luò)的輸入shape。

當(dāng)我們需要從存儲著不定尺寸圖片的tfrecord讀取數(shù)據(jù)時, 我們是無法直接將圖片reshape成指定的網(wǎng)絡(luò)結(jié)構(gòu)輸入尺寸的。例如圖片大小 [667,1085,3]。顯然,我們無法直接將其reshape成 [224,224,3]的。那么我們該如何處理呢?

按照思路,我們應(yīng)該先將圖片的一維tensor 轉(zhuǎn)換成三維tensor, 然后再利用 tf.image庫中不同的reshape 操作,將三維圖片tensor轉(zhuǎn)換為需要的 tensor大小。

按照這種思路,在這里,我總結(jié)了兩種讀寫tfrecord的方式,并對這兩種方式的不同點,尤其是容易導(dǎo)致bug的地方進(jìn)行了整理。

第一種: 利用slim.dataset.Dataset讀寫tfrecord文件,這種方式常見于利用slim庫進(jìn)行目標(biāo)檢測等網(wǎng)絡(luò)的實現(xiàn)過程中。
第二種:tf.parse_single_example 是更為常見的一種方式

利用slim.dataset.Dataset讀寫tfrecord文件

利用這個這個接口讀寫tfrecord非常的方便。它的神奇之處在于,
它不需要圖片寬高的信息,只需要其二進(jìn)制string tensor。 這個接口會自動返回一個三維圖片tensor。 在此基礎(chǔ)上,我們可以很方便的對其進(jìn)行reshape,然后輸入神經(jīng)網(wǎng)絡(luò)。
具體步驟如下:
在生成tfrecord文件時,我們需要先定義 tf_example的寫入格式,然后在將圖片文件依據(jù)這個寫入格式,生成tfrecord文件

  • 定義 tf_example的寫入特征
def int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def int64_list_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def bytes_list_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))

def float_list_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))

def create_tf_example(image_path, label, resize_size=None):
    with tf.gfile.GFile(image_path, 'rb') as fid:
        encoded_jpg = fid.read()
    encoded_jpg_io = io.BytesIO(encoded_jpg)
    image = Image.open(encoded_jpg_io)

    # 對于可能存在RGBA 4通道的圖片進(jìn)行處理
    image,is_process = process_image_channels(image)

    # 如有必要,那么就在生成tfrecord時即進(jìn)行resize
    width, height = image.size
    if resize_size is not None:
        if width > height:
            width = int(width * resize_size / height)
            height = resize_size
        else:
            width = resize_size
            height = int(height * resize_size / width)
        image = image.resize((width, height), Image.ANTIALIAS)
    # update encode_jpg
    if is_process or resize_size is not None:
        bytes_io = io.BytesIO()
        image.save(bytes_io, format='JPEG')
        encoded_jpg = bytes_io.getvalue()

    tf_example = tf.train.Example(
        features=tf.train.Features(feature={
            'image/encoded': bytes_feature(encoded_jpg),
            'image/format': bytes_feature('jpg'.encode()),
            'image/class/label': int64_feature(label),
            'image/height': int64_feature(height),
            'image/width': int64_feature(width)}))
    return tf_example
  • 生成完整的tfrecord文件
    在定義完對應(yīng)的tf_example 方式后,我們可以遍歷圖片文件,生成完整的tfrecord文件了。
def generate_tfrecord(annotation_dict, output_path, resize_size=None):
    num_valid_tf_example = 0
    writer = tf.python_io.TFRecordWriter(output_path)
    for image_path, label in annotation_dict.items():
        if not tf.gfile.GFile(image_path):
            print('%s does not exist.' % image_path)
            continue
        tf_example = create_tf_example(image_path, label, resize_size)
        if tf_example:
            writer.write(tf_example.SerializeToString())
            num_valid_tf_example += 1

            if num_valid_tf_example % 100 == 0:
                print('Create %d TF_Example.' % num_valid_tf_example)
    writer.close()
    print('Total create TF_Example: %d' % num_valid_tf_example)

對應(yīng)著,在讀取tfrecord時,slim提供了 slim.dataset.Dataset 的API接口,非常方便對讀入的tfrecord數(shù)據(jù)進(jìn)行操作。

def get_record_dataset(record_path,
                       reader=None, 
                       num_samples=50000, 
                       num_classes=32):
    """Get a tensorflow record file.
    
    Args:
        
    """
    if not reader:
        reader = tf.TFRecordReader
        
    keys_to_features = {
        'image/encoded': 
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': 
            tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/class/label': 
            tf.FixedLenFeature([1], tf.int64, default_value=tf.zeros([1], 
                               dtype=tf.int64))}
        
    items_to_handlers = {
        'image': slim.tfexample_decoder.Image(image_key='image/encoded',
                                              format_key='image/format'),
        'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[])}
    decoder = slim.tfexample_decoder.TFExampleDecoder(
        keys_to_features, items_to_handlers)
    
    labels_to_names = None
    items_to_descriptions = {
        'image': 'An image with shape image_shape.',
        'label': 'A single integer.'}
    return slim.dataset.Dataset(
        data_sources=record_path,
        reader=reader,
        decoder=decoder,
        num_samples=num_samples,
        num_classes=num_classes,
        items_to_descriptions=items_to_descriptions,
        labels_to_names=labels_to_names)

在返回了slim.dataset.Dataset這個slim支持的data封裝后, 我們可直接對返回的圖片數(shù)據(jù)進(jìn)行reshape,保證這個圖片張量的shape與網(wǎng)絡(luò)結(jié)構(gòu)的輸入層shape一致。

   dataset = get_record_dataset(FLAGS.record_path, num_samples=num_samples, 
                                 num_classes=FLAGS.num_classes)
    data_provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
    image, label = data_provider.get(['image', 'label'])
    
    # 輸出當(dāng)前tensor的靜態(tài)shape 和動態(tài)shape,與另一種讀取方式進(jìn)行對比
    print("----------tf.shape(image): ",tf.shape(image))
    print("----------image.get_shape(): ",image.get_shape())
    image = _fixed_sides_resize(image, output_height=368, output_width=368)
        
    inputs, labels = tf.train.batch([image, label],
                                    batch_size=FLAGS.batch_size,
                                    #capacity=5*FLAGS.batch_size,
                                    allow_smaller_final_batch=True)

其中,對三維圖片張量進(jìn)行reshape的代碼如下

def _fixed_sides_resize(image, output_height, output_width):
    """Resize images by fixed sides.
    
    Args:
        image: A 3-D image `Tensor`.
        output_height: The height of the image after preprocessing.
        output_width: The width of the image after preprocessing.

    Returns:
        resized_image: A 3-D tensor containing the resized image.
    """
    output_height = tf.convert_to_tensor(output_height, dtype=tf.int32)
    output_width = tf.convert_to_tensor(output_width, dtype=tf.int32)

    image = tf.expand_dims(image, 0)
    resized_image = tf.image.resize_nearest_neighbor(
        image, [output_height, output_width], align_corners=False)
    resized_image = tf.squeeze(resized_image)
    resized_image.set_shape([None, None, 3])
    return resized_image

完成了這幾步之后,我們就可以利用image 和 label 進(jìn)行神經(jīng)網(wǎng)絡(luò)訓(xùn)練了。

利用tf.parse_single_example 讀寫tfrecord文件

這種方式我們需要自己手動將一維的圖片tensor,先還原成三維圖片tensor。 因為每一張圖片的shape不相同。那么我們需要將圖片的shape也存入tfrecord文件中。當(dāng)我們從tfrecord文件中讀取時,我們先利用tf.reshape將一維圖片張量還原成三維圖片張量,再reshape規(guī)定的網(wǎng)絡(luò)輸入尺寸。

  • 照例,此處的重點在于tf_example的構(gòu)建。在這一部分,我將圖片的shape作為一個feature,也存入了tfrecord里面。 那么,在對張量的還原時,我們可以利用這個三維的shape tensor,
def create_tf_example(image_path, label, resize_size=None):
    with tf.gfile.GFile(image_path, 'rb') as fid:
        encoded_jpg = fid.read()
    encoded_jpg_io = io.BytesIO(encoded_jpg)
    image = Image.open(encoded_jpg_io)
    # 對于RGBA 4通道的圖片進(jìn)行處理
    image,is_process = process_image_channels(image)

    # Resize
    width, height = image.size
    if resize_size is not None:
        if width > height:
            width = int(width * resize_size / height)
            height = resize_size
        else:
            width = resize_size
            height = int(height * resize_size / width)
        image = image.resize((width, height), Image.ANTIALIAS)
    
    img_array = np.asarray(image)
    shape = img_array.shape
    byte_image = image.tobytes()
    
    tf_example = tf.train.Example(
        features=tf.train.Features(feature={
            'image': bytes_feature(byte_image),
            'label': int64_feature(label),
            'img_shape': int64_list_feature(shape)}))
    return tf_example
  • 在完成這個后,我們?nèi)耘f可以使用上述提及的generate_tfrecord 函數(shù)來生成對應(yīng)的tfrecord

  • 那么,對應(yīng)這種方式生成的tfrecord文件,我們該如何讀取呢?
    在這里,我給出對應(yīng)的parse_example函數(shù)就足以了。

def parse(serialized):
    # Define a dict with the data-names and types we expect to
    # find in the TFRecords file.
    # It is a bit awkward that this needs to be specified again,
    # because it could have been written in the header of the
    # TFRecords file instead.

    features = {
        'image':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'label':
            tf.FixedLenFeature([1], tf.int64, default_value=tf.zeros([1],
                                                                     dtype=tf.int64)),
        'img_shape': 
            tf.FixedLenFeature(shape=(3,), dtype=tf.int64)}

    # Parse the serialized data so we get a dict with our data.
    parsed_example = tf.parse_single_example(
        serialized=serialized, features=features)

    # Get the image as raw bytes.
    image_raw = parsed_example['image']

    # Decode the raw bytes so it becomes a tensor with type.
    image = tf.decode_raw(image_raw, tf.uint8)
    # The type is now uint8 but we need it to be float.
    image = tf.cast(image, tf.float32)
    
    shape = parsed_example['img_shape']
    
    image = tf.reshape(image,shape=shape)
    
    if not (shape[0] == shape[1] == default_img_size):
        image = _fixed_sides_resize(image,default_img_size,default_img_size)
    
    image.set_shape([default_img_size,default_img_size,3])
    label = parsed_example['label']
    # The image and label are now correct TensorFlow types.
    return image, label

在這里,讀寫tfrecord的重要流程就已經(jīng)展現(xiàn)好了。

對比

這兩種方式有一個比較重要的區(qū)別,那就是制作tfrecord時存儲的圖片信息不同。
使用slim api時 我們制作tfrecord 時,相關(guān)代碼為

    with tf.gfile.GFile(image_path, 'rb') as fid:
        encoded_jpg = fid.read()

當(dāng)我們使用第二種方式時,制作tfrecord時存儲的圖片信息的相關(guān)代碼如下所示。

image = Image.open(img_dir)
byte_image = image.tobytes()

第一種方式保存的圖片信息,其字節(jié)數(shù)不等于圖片的height, width, channel的乘積。 所以不能用 第二種的方式去讀取這種方式存儲的tfrecord。 會出現(xiàn) reshape時 維度不對的錯誤。 當(dāng)然,使用slim.dataset.Dataset 則不需要考慮這個問題了。 網(wǎng)絡(luò)上使用slim.dataset.Dataset 來加載tfrecord的方式,都是使用第一種方式存儲的tfrecord數(shù)據(jù)。

第二種方式,其存儲的圖片字節(jié)大小等于圖片的height, width, channel的乘積。所以它可以直接用tf.reshape直接將原圖矩陣還原回來,然后再進(jìn)行下一步的reshape操作。

總結(jié)

之所以寫這篇文章,是因為網(wǎng)絡(luò)上針對不定尺寸圖片tfrecord讀取的解決方案不是很完善。
例如 https://stackoverflow.com/questions/40258943/using-height-width-information-stored-in-a-tfrecords-file-to-set-shape-of-a-ten
將height, width,channel 分別存入tfrecord,然后按照提問者描述這樣是不成功的。
再例如https://stackoverflow.com/questions/35028173/how-to-read-images-with-different-size-in-a-tfrecord-file 提供的解決方案

image_rows = tf.cast(features['rows'], tf.int32)
image_cols = tf.cast(features['cols'], tf.int32)
image_data = tf.decode_raw(features['image_raw'], tf.uint8)
image = tf.reshape(image_data, tf.pack([image_rows, image_cols, 3]))

這種方式在tf.reshape階段會報錯,因為我們無法將 兩個tensor和一個int數(shù)值組合起來。最完善的方式是直接將shape作為一個整體存入tfrecord中,最終讀取出來就是一個張量了。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容