TensorFlow學習8:制作數(shù)據(jù)集

將所有圖片生成一個二進制數(shù)據(jù)集文件的過程

示例代碼

#可以將圖片和標簽制作成二進制文件,讀取二進制文件進行數(shù)據(jù)讀取,會提高內存利用率。
#訓練數(shù)據(jù)的特征用鍵值對的形式表示
def write_tfRecord(tfRecordName,image_path,label_path):
    #創(chuàng)建寫入
    writer=tf.python_io.TFRecordWriter(tfRecordName)
    num_pic=0
    f=open(label_path,'r')
    contents=f.readlines()
    f.close()
    #遍歷每張圖和標簽
    for content in contents:
        value=content.split()
        img_path=image_path+value[0]
        img=Image.open(img_path)
        img_raw=img.tobytes()
        labels=[0]*10
        lables[int(value[1])]=1
        example=tf.train.Example(features=tf.train.features(feature={
            'img_raw':tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
            'label':tf.train.Feature(int64_list=tf.train.Int64List(value=labels))
            }))
        writer.write(example.SerializeToString())
        num_pic+=1
        #序列化
        print("the number of picture:",num_pic)
    writer.close()


def generate_tfRecord():
    isExists=os.path.exists(data_path)
    if not isExists:
        os.makedirs(data_path)
        print("Created")
    else:
        print("Already Exists")

    write_tfRecord(tfRecord_train,image_train_path,label_train_path)
    write_tfRecord(tfRecord_test,image_test_path,label_test_path)


#解析文件
def read_tfRecord(tfRecord_path):
    #生成一個先入先出的隊列
    filename_queue=tf.train.string_input_producer([tfRecord_path])
    reader=tf.TFRecordReader()
    _,serialized_example=reader.read(filename_queue)
    features=tf.parse_single_example(serialized_example,features={
        'label':tf.FixedLenFeature([10],tf.int64),
        'img_raw':tf.FixedLenFeature([],tf.string)
        })
    img=tf.decode_raw(features['img_raw'],tf.uint8)
    img.set_shape([784])
    img=tf.cast(img,tf.float32)*(1./255)
    label=tf.cast(features['label'],tf.float32)

    return img,lable

def get_tfrecord(num,isTrain=True):
    if isTrain:
        tfRecord_path=tfRecord_path
    else:
        tfRecord_path=tfRecord_test
    img,label=read_tfRecord(tfRecord_path)

    img_batch,label_batch=tf.train.shuffle_batch([img,label],batch_size=num,num_threads=2,capacity=1000,min_after_dequeue=700)

    return img_batch,label_batch

def main():
    generate_tfRecord()

if __name__=='__main__':
    main()




參考:人工智能實踐:Tensorflow筆記

?著作權歸作者所有,轉載或內容合作請聯(lián)系作者
【社區(qū)內容提示】社區(qū)部分內容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發(fā)布,文章內容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

相關閱讀更多精彩內容

友情鏈接更多精彩內容