TensorFlow自學(xué)第3篇——異步文件

沒有經(jīng)過實(shí)踐檢驗(yàn)的理論,不管它多么漂亮,都會失去分量,不會為人所承認(rèn);沒有以有分量的理論作基礎(chǔ)的實(shí)踐一定會遭到失?。ㄩT捷列夫)。

看視頻——記筆記——翻書查漏補(bǔ)缺,這種方式到目前看來,是經(jīng)過實(shí)踐檢驗(yàn)的,正確的學(xué)習(xí)方式。獨(dú)自學(xué)習(xí)不易,且學(xué)且行。

線程與隊(duì)列

在使用TensorFlow進(jìn)行異步計(jì)算時,隊(duì)列是一種強(qiáng)大的機(jī)制。
一個簡單的例子。先創(chuàng)建一個“先入先出”的隊(duì)列(FIFOQueue),并將其內(nèi)部所有元素初始化為零。然后,構(gòu)建一個TensorFlow圖,它從隊(duì)列前端取走一個元素,加上1之后,放回隊(duì)列的后端。慢慢地,隊(duì)列的元素的值就會增加。

TensorFlow提供了兩個類來幫助多線程的實(shí)現(xiàn):tf.Coordinator和 tf.QueueRunner。Coordinator類可以用來同時停止多個工作線程并且向那個在等待所有工作線程終止的程序報(bào)告異常,QueueRunner類用來協(xié)調(diào)多個工作線程同時將多個張量推入同一個隊(duì)列中。

tf.QueueRunner

QueueRunner類會創(chuàng)建一組線程, 這些線程可以重復(fù)的執(zhí)行Enquene操作, 他們使用同一個Coordinator來處理線程同步終止。此外,一個QueueRunner會運(yùn)行一個closer thread,當(dāng)Coordinator收到異常報(bào)告時,這個closer thread會自動關(guān)閉隊(duì)列。

您可以使用一個queue runner,來實(shí)現(xiàn)上述結(jié)構(gòu)。 首先建立一個TensorFlow圖表,這個圖表使用隊(duì)列來輸入樣本。增加處理樣本并將樣本推入隊(duì)列中的操作。增加training操作來移除隊(duì)列中的樣本。

tf.Coordinator

Coordinator類用來幫助多個線程協(xié)同工作,多個線程同步終止。 其主要方法有:
should_stop():如果線程應(yīng)該停止則返回True。
request_stop(): 請求該線程停止。
join():等待被指定的線程終止。
首先創(chuàng)建一個Coordinator對象,然后建立一些使用Coordinator對象的線程。這些線程通常一直循環(huán)運(yùn)行,一直到should_stop()返回True時停止。 任何線程都可以決定計(jì)算什么時候應(yīng)該停止。它只需要調(diào)用request_stop(),同時其他線程的should_stop()將會返回True,然后都停下來。

舉個栗子

"""
CPU負(fù)責(zé)TensorFlow的計(jì)算,IO負(fù)責(zé)讀取文件
由于速度上的差異,通常做法是:主線程進(jìn)行模型訓(xùn)練,子線程讀取數(shù)據(jù),二者通過隊(duì)列進(jìn)行數(shù)據(jù)傳輸
相當(dāng)于主線程從隊(duì)列讀數(shù)據(jù),子進(jìn)程往隊(duì)列放數(shù)據(jù)
"""

import tensorflow as tf
import os
# 忽略不必要的警告信息
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'


# 一、模擬同步,先處理數(shù)據(jù),然后取數(shù)據(jù)訓(xùn)練

# 1、首先定義隊(duì)列
Q1 = tf.FIFOQueue(3, tf.float32)
# 放入一些數(shù)據(jù)
# 此處不能直接傳入[0.1,0.2,0.3],因?yàn)閷τ趖f來說,接收的一切值都是tensor張量
# 但是此處需要傳入的是列表,所以改為[[0.1, 0.2, 0.3], ]
enq_many1 = Q1.enqueue_many([[0.1, 0.2, 0.3], ])
# 2、定義一些處理數(shù)據(jù)的邏輯,取數(shù)據(jù),+1,入隊(duì)列
out_q1 = Q1.dequeue()
data1 = out_q1 + 1
en_q1 = Q1.enqueue(data1)

with tf.Session() as sess1:
    # 初始化隊(duì)列
    sess1.run(enq_many1)
    # 處理數(shù)據(jù)
    for i in range(100):
        sess1.run(en_q1) # TensorFlow中,計(jì)算有依賴性
    # 訓(xùn)練數(shù)據(jù)
    for i in range(Q1.size().eval()):
        print(sess1.run(Q1.dequeue()))

# --------------------------------分割線---------------------------------

# 二、模擬異步,子線程存入樣本,主線程讀取樣本

# 1、定義一個隊(duì)列,1000
Q2 = tf.FIFOQueue(1000, tf.float32)
# 2、定義要做的事情,循環(huán),+1,放隊(duì)列
var2 = tf.Variable(0.0)
# 實(shí)現(xiàn)自增op
data2 = tf.assign_add(var2, tf.constant(1.0))
en_q2 = Q2.enqueue(data2)
# 3、定義隊(duì)列管理器op,指定多少個子線程,子線程該干什么事
qr2 = tf.train.QueueRunner(Q2, enqueue_ops=[en_q2] * 2)
# 初始化變量op
init_op2 = tf.global_variables_initializer()

with tf.Session() as sess2:
    # 初始化變量
    sess2.run(init_op2)
    # 開啟線程協(xié)調(diào)器
    coord = tf.train.Coordinator()
    # 開啟子線程
    threads = qr2.create_threads(sess2, coord=coord, start=True)
    # 主線程讀取數(shù)據(jù),等待訓(xùn)練
    for i in range(300):
        print(sess2.run(Q2.dequeue()))

    # 回收線程
    coord.request_stop()
    coord.join(threads)


# --------------------------------分割線---------------------------------


# 三、文件讀取

# 1、構(gòu)造一個文件隊(duì)列
# 2、構(gòu)造文件閱讀器,讀取隊(duì)列一個樣本的內(nèi)容,解碼
# 3、批處理
# 4、主線程取樣本數(shù)據(jù)訓(xùn)練
# TensorFlow默認(rèn)一次讀取一個樣本,即對CSV文件只讀一行,二進(jìn)制位文件只讀一個樣本的字節(jié)數(shù),圖片文件讀取一張

def csvread(filelist):
    """
    讀取CSV文件
    :param filelist: 文件路徑+文件名的列表
    :return: 讀取的內(nèi)容
    """
    # 1、構(gòu)造文件隊(duì)列
    file_queue = tf.train.string_input_producer(filelist)
    # 2、構(gòu)造CSV閱讀器,讀取隊(duì)列,默認(rèn)以行讀取
    reader = tf.TextLineReader()
    key, value = reader.read(file_queue)
    print(value)
    # 3、對每行內(nèi)容解碼
    records = [["None"],["None"]]
    # record_defaults指定每一個樣本的每一列的類型,指定默認(rèn)值[["None"],[4.0]]
    example, label = tf.decode_csv(value, record_defaults=records)
    print(example, label)
    # 讀取多個數(shù)據(jù),批處理
    example_batch, label_batch = tf.train.batch([example, label], batch_size=9, num_threads=1, capacity=9)
    return example_batch, label_batch


if __name__=="__main__":

    # 1、找到文件,放入列表
    filename = os.listdir("./csvdata/")
    print(filename)
    filelist = [os.path.join("./csvdata/", file) for file in filename]
    print(filelist)

    example_batch, label_batch = csvread(filelist)

    with tf.Session() as sess:
        # 定義線程協(xié)調(diào)器
        coord = tf.train.Coordinator()
        # 開啟讀取文件的線程
        threads = tf.train.start_queue_runners(sess, coord=coord)
        # 打印讀取內(nèi)容
        print(sess.run([example_batch, label_batch]))

        # 回收子線程
        coord.request_stop()
        coord.join(threads)

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容