閱讀筆記-Transformer Quality in Linear Time

這篇文章的目的是為了解決transformer 處理長序列任務(wù)遇到的計算復(fù)雜度較高的問題。為了解決這個問題,許多工作聚焦于探索更有效的注意力機制,比如linear attention,但這類方法往往存在著以下三個缺陷:

  • inferior quality. linear attention 相對于vanilla attention 往往會帶來明顯的指標掉點。
  • overhead in practice. efficient attention往往是進行了復(fù)雜的layout 變換,這種操作計算復(fù)雜度體現(xiàn)不出來,但在實際應(yīng)用中往往帶來較大的時間開銷。
  • inefficient auto-regressive training. 這一塊主要是針對時序上使用的efficient attention,比如RNN-style 的序列狀態(tài)更新,因為時序的存在導(dǎo)致并行度降低,訓(xùn)練上速度較慢。

所以本文不是直接去擬合MHSA,而是設(shè)計了一種新的quad attention結(jié)構(gòu)GAU,這種結(jié)構(gòu)對于attention不是特別敏感,從而能夠進一步使用線性時間去近似擬合。最終這種線性擬合的結(jié)構(gòu)被命名為 FLASH (Fast Linear Attention With a Single Head.)
下面來分別看一下GAU和FLASH的設(shè)計。

1. GAU (Gated Attention Unit)

GAU 是從GLU(Gated Linear Unit) 發(fā)展而來的,GLU的定義如下
U= \phi_u(XW_u), ~~~ V = \phi_v(XW_v) \\ O = (U\cdot V)W_o
這個式子可以理解成對relu的輸出變量加了一個平方的非線性變換,當然這里只是類比,因為\phi 是放在外面的,如果是 \phi(XW_u\cdot XW_v)W_o 就更明顯了,只是relu放在里面和外面還是不完全等價的,相當于更復(fù)雜的非線性變換。GLU的解釋是v對u做了一次gate。
GAU將上式中的V替換成了quad attention的輸出,見下圖

image.png

這里需要注意的幾點:

  • GAU中的attention部分 使用的relu^2, 而不是softmax, 這個操作和我們上面解釋GLU差不多;relu^2 作者驗證在自然語言上同樣有效,但CV上還不確定。
  • GAU中的QKV都是single head的, Q、K的生成方式不同,采用類似于BN中可學(xué)習(xí)參數(shù)的那種形式的scale_offset 算子
  • GAU的參數(shù)量相對于MHSA要小近似一半,因此同樣的參數(shù)量和速度下,可以多堆疊一倍的GAU。
    下面兩張表分別對比了結(jié)構(gòu)中一些操作的影響,gating這個文章中沒細說,加上gating操作,參數(shù)量反而降低了,推測應(yīng)該是一種過濾的操作?如果是GLU中的gate操作,參數(shù)量應(yīng)該增大吧。


    image.png
    image.png

2. Fast Linear Attention with GAU

對GAU進行線性逼近,其出發(fā)點基于兩點觀測:

  • GAU的門限機制允許一個更弱的,比如單頭,無softmax (見Table1和Table2的對比)注意力機制實現(xiàn)相似的性能;
  • 相同的計算代價下,MHSA+MLP 差不多等于2個GAU,而注意力逼近一般需要更多的layer去捕獲更全面的依賴,所以GAU是一個好的選擇。

目前存在的線性復(fù)雜度的attention可以大致劃分為兩類: Partial Attention和Linear Attention。
Partial Attention,包括劃分window,local+sparse, axial, hash, clustering等,這個方法雖然表現(xiàn)不如full attention,但理論上確實能夠?qū)ong sequence有較好的速度表現(xiàn)。但問題關(guān)鍵在于實際應(yīng)用中往往需要gather,scatter,slice和cat等layout操作,這類操作并行度較差,對硬件不友好,導(dǎo)致實際場景中速度要慢很多。
Linear Attention,主要是去掉Softmax操作,從而能夠利用矩陣乘法結(jié)合律先計算K^TV, 將計算復(fù)雜度降下來,同時對于NLP中的時序任務(wù),隨著序列逐步增長K^TV 是個累加的過程,每次只需要計算當前時刻的K_i^TV_i加上歷史累積的M_t=K^T_{:t}V_{:t}。相比于quad attention每次都需要計算全部QK^T 計算量顯然小很多。但是,對于長序列而言,雖然每次計算量小很多,但是這種序列計算就帶來RNN類型存在的先天時序性,沒法并行操作,只有執(zhí)行完t-1步才能執(zhí)行第t步,所以計算復(fù)雜度小但是計算時間反而有可能更長。對于移動芯片來說,存儲器較小會使這種情況加劇。

3.1 Mixed Chunk Attention

本文提出的線性逼近結(jié)構(gòu)。首先將序列劃分為G個不重疊的chunk,每個chunk生成對應(yīng)的U_g, V_g, Z_g, 由Z_g通過per-dim的scaling和offset 生成 Q_g^{quad}, K_g^{quad}, Q_g^{lin}, K_g^{lin}, 每個chunk會同時參與local attention 和 global attention, local attention的話
\hat{V}_q^{quad} = relu^2(Q_g^{quad}K_g^{quad^T} +b)V_g
而global attention則劃分為non-causal和 causal,即是否時序上算, causal 會帶來訓(xùn)練時間的大幅增加。
Non-Causal: \hat{V}_g^{lin} = Q_g^{lin} \Big (\sum_{h=1}^GK_h^{lin^T}V_h \Big) \\ Casual: \hat{V}_g^{lin} = Q_g^{lin} \Big (\sum_{h=1}^{g-1}K_h^{lin^T}V_h \Big)
那么最終在當前幀的輸出是把local 和 global的attention 疊加在一起放到GAU中,
O_g = \Big [ U_g \odot (\hat{V}_g^{quad} + \hat{V}_g^{lin})\Big ]
偽代碼如下:

image.png

這種操作其實是quad attention和linear attention的這種方案。

一些討論:

  • chunk的劃分能加快auto-regressive training 的過程
  • overlapping local attention能夠改善質(zhì)量,但是會存在memory re-formatting operations導(dǎo)致實際運行速度差很多。另外作者認為optimal partial attention 是任務(wù)相關(guān)的,但non-overlapping 是通用的。
  • 和combiner相比,combiner也劃分chunk,但每個chunk內(nèi)使用的quad local attention,這樣FLASH在chunk內(nèi)就能夠允許更長的chunk。

3. 實驗

實驗在NLP上做的,這里只看下結(jié)果。


image.png

時間對比:


image.png

4. 結(jié)論

這篇文章先是提出一種GAU,然后GAU對attention的依賴較小,進而可以把GAU中的attention替換成linear attention。這個倒是可以在CV上嘗試,GAU可以多堆一倍。另一點mixed-chunk 在CV上感覺用處不大,倒是可以在track任務(wù)上使用,track的memory bank上更新query。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容