TensorFlow API解釋

1、tf.shape(varA)和varA.get_shape()
這兩個API都是返回varA的size大小,但是tf.shape(varA)中varA的數(shù)據(jù)類型可以是Tensor、list、Array;而varA.get_shape()中varA的數(shù)據(jù)類型只能是Tensor,返回的是一個元組tuple。
例如:
import tensorflow as tf
import numpy as np
varX=tf.constant([[1,2,3],[4,5,6],[7,8,9]])
varY=[[1,2,3],[4,5,6],[7,8,9]]
varZ=np.arange(9).reshape([1,3,3])
sess=tf.Session()
varX_shape=tf.shape(varX)
varY_shape=tf.shape(varY)
varZ_shape=tf.shape(varZ)
print sess.run(varX_shape)
print sess.run(varY_shape)
print sess.run(varZ_shape)
輸出結(jié)果是:
[3 3]
[3 3]
[1 3 3]
而:
varX_shape=varX.get_shape()
返回varX_shape是元組TensorShape([Dimension(3), Dimension(3)]),所以不能用sess.run(varX_shape)輸出結(jié)果,因為既不是Tensor,也不是string。但是可以用.as_list()得到size,即
varX_shape=varX.get_shape().as_list()返回varX_shape=[3, 3]
varY_shape=varY.get_shape()則會報錯AttributeError: 'list' object has no attribute 'get_shape',因為varY不是Tensor,是list
varZ_shape=varZ.get_shape()也會報錯AttributeError: 'numpy.ndarray' object has no attribute 'get_shape',因為varZ也不是Tensor,是Array。
所以,如果一個Tensor的靜態(tài)shape未定義,則可用上述API來獲得其動態(tài)shape。
PS: 看一個varX的維度數(shù),可以看varX定義或者輸出的最左邊(或最右邊)有幾個[(或])。

2、tf.Tensor.set_shape()和tf.reshape()
tf.Tensor.set_shape的函數(shù)調(diào)用為tf.Tensor.set_shape(shape),參數(shù)shape為要調(diào)整為的形狀(注意shape里最多只能有一個維度的值可以為-1,表示該維度自動計算得到),作用是更新一個Tensor對象的靜態(tài)shape。如果該Tensor對象的靜態(tài)shape信息不能夠直接推導(dǎo)出來時,tf.Tensor.set_shape()設(shè)置該Tensor額外的shape信息,要注意的是該方法不改變Tensor的動態(tài)shape信息。
tf.reshape()的函數(shù)調(diào)用為tf.reshape(tensorVar, shape, name=None),參數(shù)tensorVar為被調(diào)整維度的張量,shape為要調(diào)整為的形狀,作用為返回一個shape形狀的新Tensor。
示例:
varA = [1, 2, 3, 4, 5, 6, 7, 8, 9]
tensorVarA = tf.constant(varA)
tensorVarB = tf.reshape(tensorVarA, [1, 3, 3])
with tf.Session() as sess:
print(sess.run(tensorVarB))
輸出為:
[[[1 2 3]
[4 5 6]
[7 8 9]]]

3、tf.ConfigProto函數(shù)
該API用在創(chuàng)建session的時候,用于對session進(jìn)行參數(shù)配置。調(diào)用過程為:
with tf.Session(config = tf.ConfigProto(), ...)
tf.ConfigProto()由以下兩個參數(shù):
log_device_placement = True/Flase,指示是否打印參數(shù)在設(shè)備上的分配信息
allow_soft_placement = True/Flase,指示如果指定的設(shè)備不存在是,是否允許TensorFlow自動分配其他設(shè)備。為了避免出現(xiàn)指定設(shè)備不存在而出錯的情況,可以在創(chuàng)建session的時候?qū)llow_soft_placement設(shè)置為True,這樣TensorFlow會自動選擇一個存在并且支持的設(shè)備來運行operation
gpu_options = ...,設(shè)置每個GPU應(yīng)該使用的顯存容量,由tf.GPUOptions(per_process_gpu_memory_fraction = 0.9)設(shè)置
tf.ConfigProto().gpu_options.allow_growth = True/False,指示使用的GPU容量,是否按需增加
所以,一個完整的tf.ConfigProto使用如下:
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9)
config = tf.ConfigProto(log_device_placement = True, allow_soft_placement = True, gpu_options = gpu_options)
config.gpu_options.allow_growth = True
With tf.Session(config = config, ...) as sess:
...
PS: CUDA_VISIBLE_DEVICE=...,指示使用哪塊GPU運行,如:
CUDA_VISIBLE_DEVICE=0 python a.py,在0號GPU上運行a.py
CUDA_VISIBLE_DEVICE=0,1 python b.py,在0、1號GPU上運行b.py

4、tf.where函數(shù)
tf.where的函數(shù)調(diào)用為tf.where(condition, x = None, y = None, name = None),根據(jù)condition判定返回。即condition是True,選擇x;condition是False,選擇y。

5、tf.nn.dropout函數(shù)
tf.nn.dropout的函數(shù)調(diào)用為tf.nn.dropout(tensorVar, keep_prob, noise_shape = None, seed = None, name = None),使tensor以一定的概率保留。tensorVar為輸入tensor,keep_prob為神經(jīng)元保留的概率。
典型用法如下:
tensorVarY =tf.matmul(tf.nn.dropout(tensorVarX, keep_prob = 0.5), W)
lstm_cell = tf.nn.rnn_cell.DropoutWrapper(lstm_cell, output_keep_prob = 0.5)為RNN中dropout的用法
PS: 一般dropout在訓(xùn)練的時候用,test的時候就不需要dropout了。

6、Saver
Saver的函數(shù)調(diào)用為tf.train.Saver(),實現(xiàn)保存訓(xùn)練的結(jié)果,即保存模型的參數(shù),以便下一次的迭代訓(xùn)練或用于預(yù)測。Saver類提供了以下三點功能:
1)提供了向checkpoint文件保存和從checkpoint文件中恢復(fù)變量的相關(guān)方法。checkpoint文件是一個二進(jìn)制文件,它把變量名映射到對應(yīng)的tensor值 。
2)只要提供一個計數(shù)器,當(dāng)計數(shù)器觸發(fā)時,Saver類可以自動的生成checkpoint文件。這讓我們可以在訓(xùn)練過程中保存多個中間結(jié)果。例如,我們可以保存每一步訓(xùn)練的結(jié)果。
3)為了避免填滿整個磁盤,Saver可以自動的管理checkpoint文件。例如,我們可以指定保存最近的N個checkpoint文件。checkpoint文件指明最新的模型,和模型的存儲位置。restore時,也是查看checkpoint文件獲取最新的模型。
典型用法如下:

import os
cwd = os.cwd()
saver = tf.train.Saver()
saver.save(sess, cwd + 'model.ckpt', global_step)

預(yù)測或載入模型:

ckpt = tf.train.get_checkpoint_state(cwd)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
print(sess.run(W))

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

  • 1. tf函數(shù) tensorflow 封裝的工具類函數(shù) | 操作組 | 操作 ||:-------------| ...
    南墻已破閱讀 5,631評論 0 5
  • 簡單線性回歸 import tensorflow as tf import numpy # 創(chuàng)造數(shù)據(jù) x_dat...
    CAICAI0閱讀 3,668評論 0 49
  • 若是問我的擇偶標(biāo)準(zhǔn),我希望他既hold得住白襯衫牛仔衫,又能駕馭得了運動衫。想了下身邊的男生朋友,能符合這一點的好...
    晨光微熙2017閱讀 211評論 0 0
  • 簡介 OkHttp3IdlingResource是Jake Wharton大神為okhttp寫的Espresso的...
    sylviaMo閱讀 1,178評論 0 0
  • 眾口難調(diào) 廚師口頭禪 炒菜既不能鹽重又不能鹽少 適度為妙 而這個度往往難把握 高明的廚藝 也會嘆氣 因為眾人口味不...
    旖旎i閱讀 169評論 4 8

友情鏈接更多精彩內(nèi)容