AutoGraph是TF提供的一個非常具有前景的工具, 它能夠?qū)⒁徊糠謕ython語法的代碼轉(zhuǎn)譯成高效的圖表示代碼. 由于從TF 2.0開始, TF將會默認使用動態(tài)圖(eager execution), 因此利用AutoGraph, 在理想情況下, 能讓我們實現(xiàn)用動態(tài)圖寫(方便, 靈活), 用靜態(tài)圖跑(高效, 穩(wěn)定).
但是! 在使用的過程中, 如無意外肯定是會有意外的, 這篇文章就是指出一些AutoGraph和tf.function的奇怪的行為, 讓你更愉快地使用它們.
本文假設(shè)讀者具有一定的Python和TensorFlow的使用經(jīng)驗.
會話執(zhí)行
對tf1.X有經(jīng)驗的讀者應(yīng)該不會對讓我們又愛又恨的計算圖(tf.Graph)和執(zhí)行會話(tf.Session)感到陌生, 一個常規(guī)的流程如下:
- 初始化一個計算圖并且將該計算圖設(shè)置為當前scope下的默認計算圖
- 用TF API設(shè)計計算圖(比如:
y=tf.matmul(a, x) + b) - 提前界定好參數(shù)共享并劃分相應(yīng)的參數(shù)scope
- 創(chuàng)建并配置好
tf.Session - 將計算圖傳給
tf.Session - 初始化參數(shù)
- 用
tf.Session.run來執(zhí)行計算圖的節(jié)點, 被執(zhí)行的節(jié)點會反向追蹤所有依賴的需要執(zhí)行的節(jié)點并執(zhí)行計算.
以下是上述過程的一個代碼例子:
g = tf.Graph() #初始化計算圖
with g.as_default(): # 設(shè)置為默認計算圖
a = tf.constant([[10,10],[11.,1.]])
x = tf.constant([[1.,0.],[0.,1.]])
b = tf.Variable(12.)
y = tf.matmul(a, x) + b # 描述計算圖
init_op = tf.global_variables_initializer() # 待執(zhí)行節(jié)點
with tf.Session() as sess: # 配置會話
sess.run(init_op) # 執(zhí)行節(jié)點
print(sess.run(y)) # 輸出結(jié)果
在TF 2.0中, 由于默認為動態(tài)圖, 計算會直接被執(zhí)行, 也就是說, 我們不需要
- 定義計算圖
- 會話執(zhí)行
- 參數(shù)初始化
- 用scope定義參數(shù)分享
- 用
tf.control_dependencies來聲明節(jié)點的非直接依賴
我們可以像寫普通python代碼(or pytorch)一樣, 寫了就執(zhí)行:
a = tf.constant([[10,10],[11.,1.]])
x = tf.constant([[1.,0.],[0.,1.]])
b = tf.Variable(12.)
y = tf.matmul(a, x) + b
print(y.numpy())
一般來說, eager代碼會比執(zhí)行相同操作的靜態(tài)圖代碼的效率低, 因為很多計算圖優(yōu)化的方法只能用在數(shù)據(jù)流圖上.
如果想在TF 2.0上構(gòu)建傳統(tǒng)的計算圖, 我們就需要用到tf.function.
函數(shù), 而非會話
TF 2.0的其中一個重要改變就是去除tf.Session(此處應(yīng)有掌聲). 這個改變會迫使用戶用更好的方式來組織代碼: 不用再用讓人糾結(jié)的tf.Session來執(zhí)行代碼, 就是一個個python函數(shù), 加上一個簡單的裝飾器.
在TF 2.0里面, 如果需要構(gòu)建計算圖, 我們只需要給python函數(shù)加上@tf.function的裝飾器.
上文提到靜態(tài)圖的執(zhí)行效率更高, 但是加速并不是一定的. 一般來說, 計算圖越復(fù)雜, 加速效果越明顯. 對于復(fù)雜的計算圖, 比如訓(xùn)練深度學(xué)習(xí)模型, 獲得的加速是巨大的. (譯者注: 個人感覺還是要結(jié)合實際來看, 如果某一部分的計算既有復(fù)雜的計算圖, 而計算圖的復(fù)雜性又帶來了額外的內(nèi)存消耗
或者計算量, 那么加速會比較明顯, 但是很多時候, 比如一般的CNN模型, 主要計算量并不在于圖的復(fù)雜性, 而在于卷積、矩陣乘法等操作, 加速并不會很明顯. 此處想法有待驗證)
這個自動將python代碼轉(zhuǎn)成圖表示代碼的工具就叫做AutoGraph.
在TF 2.0中, 如果一個函數(shù)被@tf.function裝飾了, 那么AutoGraph將會被自動調(diào)用, 從而將python函數(shù)轉(zhuǎn)換成可執(zhí)行的圖表示.
tf.function: 究竟發(fā)生了什么?
在第一次調(diào)用被@tf.function裝飾的函數(shù)時, 下列事情將會發(fā)生:
- 該函數(shù)被執(zhí)行并跟蹤。和Tensorflow 1.x類似, Eager會在這個函數(shù)中被禁用,因此每個
tf.API只會定義一個生成tf.Tensor輸出的節(jié)點 - AutoGraph用于檢測可以轉(zhuǎn)換為等效圖表示的Python操作(
while→tf.while,for→tf.while,if→tf.cond,assert→tf.assert...) - 為了保留執(zhí)行順序,在每個語句之后自動添加
tf.control_dependencies,以便在執(zhí)行第i+1行時確保第i行已經(jīng)被執(zhí)行. 至此計算圖已經(jīng)確定 - 根據(jù)函數(shù)名稱和輸入?yún)?shù),創(chuàng)建唯一ID并將其與定義好的計算圖相關(guān)聯(lián)。計算圖被緩存到一個映射表中:
map [id] = graph - 如果ID配對上了,之后的函數(shù)調(diào)用都會直接使用該計算圖
下一節(jié)將會具體闡述如何將TF 1.X代碼塊分別改寫到eager和計算圖版本.
改寫到eager execution
要使用tf.function, 第一步需要先將TF 1.X的設(shè)計計算圖的代碼放進python函數(shù)里面.
def f():
a = tf.constant([[10,10],[11.,1.]])
x = tf.constant([[1.,0.],[0.,1.]])
b = tf.Variable(12.)
y = tf.matmul(a, x) + b
return y
應(yīng)為TF 2.0默認是eager的, 我們可以直接執(zhí)行該函數(shù)(不需要tf.Session):
print(f().numpy())
我們就會得到輸出:
[[22. 22.]
[23. 13.]]
從eager到tf.function
我們可以直接用@tf.function來裝飾函數(shù)f, 我們在原來f的基礎(chǔ)上加上宇宙第一的debug大法: print來更好地看看究竟發(fā)生了什么.
@tf.function
def f():
a = tf.constant([[10,10],[11.,1.]])
x = tf.constant([[1.,0.],[0.,1.]])
b = tf.Variable(12.)
y = tf.matmul(a, x) + b
print("PRINT: ", y)
tf.print("TF-PRINT: ", y)
return y
f()
所以發(fā)生了什么呢?
-
@tf.function將函數(shù)f包進了tensorflow.python.eager.def_function.Function這個對象, 函數(shù)f被賦予到了這個對象的.python_function屬性. - 當
f()被執(zhí)行的時候, 計算圖會同時被構(gòu)建, 但是計算不會執(zhí)行, 因此我們會得到以下結(jié)果,tf.的操作不會被執(zhí)行:
PRINT: Tensor("add:0", shape=(2, 2), dtype=float32)
- 最終, 你會看到代碼會執(zhí)行失敗:
ValueError: tf.function-decorated function tried to create variables on non-first call.
在 RFC: Functions, not Session里面有個非常明確的指示:
State (like
tf.Variableobjects) are only created the first time the function f is called. 狀態(tài)(比如tf.Variable) 只會在函數(shù)被第一次調(diào)用時創(chuàng)建.
但是 Alexandre Passos指出, 在函數(shù)轉(zhuǎn)換成圖表示時, 我們沒有辦法確定tf.function調(diào)用了多少次函數(shù), 因此我們在第一次調(diào)用函數(shù)f時, 在圖構(gòu)建的過程中, 可能會被執(zhí)行了多次, 這就導(dǎo)致了上述錯誤.
造成這個錯誤的根源在于同樣的命令在動態(tài)圖和靜態(tài)圖中的不一致性. 在動態(tài)圖中, tf.Variable時一個普通的python變量, 超出了其作用域范圍就會被銷毀. 而在靜態(tài)圖中, tf.Variable則是計算圖中一個持續(xù)存在的節(jié)點, 不受python的作用域的影響. 因此, 這是使用tf.function的第一個教訓(xùn):
將一個在動態(tài)圖中可行的函數(shù)轉(zhuǎn)換成靜態(tài)圖需要用靜態(tài)圖的方式思考該函數(shù)是否可行
那么我們可以怎樣去規(guī)避這個錯誤呢?
- 將
tf.Variable作為函數(shù)的參數(shù)傳入 - 將父作用域繼承
tf.Variable - 將
tf.Variable作為類屬性來調(diào)用
用改變變量作用域來處理
這里指方法2和方法3. 顯然的, 我們推薦使用方法3:
class F():
def __init__(self):
self._b = None
@tf.function
def __call__(self):
a = tf.constant([[10, 10], [11., 1.]])
x = tf.constant([[1., 0.], [0., 1.]])
if self._b is None:
self._b = tf.Variable(12.)
y = tf.matmul(a, x) + self._b
print("PRINT: ", y)
tf.print("TF-PRINT: ", y)
return y
f = F()
f()
將狀態(tài)作為傳入?yún)?shù)來處理
我們之后會看到, 我們并不能隨意地用tf.function來轉(zhuǎn)化eager的代碼并達到加速的目的, 我們需要想象一下轉(zhuǎn)化是怎么完成的, 在轉(zhuǎn)python的代碼到圖操作的時候究竟發(fā)生了什么, 這些轉(zhuǎn)化包含了什么黑魔法. 這里的例子比較簡單, 我們會在接下來的文章中更深入的探討.
@tf.function
def f(b):
a = tf.constant([[10,10],[11.,1.]])
x = tf.constant([[1.,0.],[0.,1.]])
y = tf.matmul(a, x) + b
print("PRINT: ", y)
tf.print("TF-PRINT: ", y)
return y
b = tf.Variable(12.)
f(b)
上述函數(shù)會得到我們想要的結(jié)果, 另外, 作為參數(shù)被傳入的變量能夠在函數(shù)中直接更新, 而更新后的值會在函數(shù)外也適用. 下面的代碼會打印出1,2,3
a = tf.Variable(0)
@tf.function
def g(x):
x.assign_add(1)
return x
print(g(a))
print(g(a))
print(g(a))
總結(jié)
- 我們可以用
@tf.function裝飾器來將python代碼轉(zhuǎn)成圖表示代碼 - 我們不能在被裝飾函數(shù)中初始化
tf.Variable - 可以用變量作用域繼承(對象屬性)或者參數(shù)傳入的方法使用在函數(shù)外初始化的變量
在之后的部分我們會更加深入地探討輸入?yún)?shù)類型對效率的影響, 以及python操作的轉(zhuǎn)換細節(jié).
聲明: 本文翻譯自Paolo Galeone的博客, 已取得作者的同意, 如需轉(zhuǎn)載本文請聯(lián)系本人
Disclaimer: This is a translation of the article Analyzing tf.function to discover AutoGraph strengths and subtleties by Paolo Galeone.