這篇文章的目的是為了解決transformer 處理長序列任務(wù)遇到的計算復(fù)雜度較高的問題。為了解決這個問題,許多工作聚焦于探索更有效的注意力機制,比如linear attention,但這類方法往往存在著以下三個缺陷:
- inferior quality. linear attention 相對于vanilla attention 往往會帶來明顯的指標掉點。
- overhead in practice. efficient attention往往是進行了復(fù)雜的layout 變換,這種操作計算復(fù)雜度體現(xiàn)不出來,但在實際應(yīng)用中往往帶來較大的時間開銷。
- inefficient auto-regressive training. 這一塊主要是針對時序上使用的efficient attention,比如RNN-style 的序列狀態(tài)更新,因為時序的存在導(dǎo)致并行度降低,訓(xùn)練上速度較慢。
所以本文不是直接去擬合MHSA,而是設(shè)計了一種新的quad attention結(jié)構(gòu)GAU,這種結(jié)構(gòu)對于attention不是特別敏感,從而能夠進一步使用線性時間去近似擬合。最終這種線性擬合的結(jié)構(gòu)被命名為 FLASH (Fast Linear Attention With a Single Head.)
下面來分別看一下GAU和FLASH的設(shè)計。
1. GAU (Gated Attention Unit)
GAU 是從GLU(Gated Linear Unit) 發(fā)展而來的,GLU的定義如下
這個式子可以理解成對relu的輸出變量加了一個平方的非線性變換,當然這里只是類比,因為 是放在外面的,如果是
就更明顯了,只是relu放在里面和外面還是不完全等價的,相當于更復(fù)雜的非線性變換。GLU的解釋是v對u做了一次gate。
GAU將上式中的V替換成了quad attention的輸出,見下圖

這里需要注意的幾點:
- GAU中的attention部分 使用的
, 而不是softmax, 這個操作和我們上面解釋GLU差不多;
作者驗證在自然語言上同樣有效,但CV上還不確定。
- GAU中的QKV都是single head的, Q、K的生成方式不同,采用類似于BN中可學(xué)習(xí)參數(shù)的那種形式的scale_offset 算子
-
GAU的參數(shù)量相對于MHSA要小近似一半,因此同樣的參數(shù)量和速度下,可以多堆疊一倍的GAU。
下面兩張表分別對比了結(jié)構(gòu)中一些操作的影響,gating這個文章中沒細說,加上gating操作,參數(shù)量反而降低了,推測應(yīng)該是一種過濾的操作?如果是GLU中的gate操作,參數(shù)量應(yīng)該增大吧。
image.pngimage.png
2. Fast Linear Attention with GAU
對GAU進行線性逼近,其出發(fā)點基于兩點觀測:
- GAU的門限機制允許一個更弱的,比如單頭,無softmax (見Table1和Table2的對比)注意力機制實現(xiàn)相似的性能;
- 相同的計算代價下,MHSA+MLP 差不多等于2個GAU,而注意力逼近一般需要更多的layer去捕獲更全面的依賴,所以GAU是一個好的選擇。
目前存在的線性復(fù)雜度的attention可以大致劃分為兩類: Partial Attention和Linear Attention。
Partial Attention,包括劃分window,local+sparse, axial, hash, clustering等,這個方法雖然表現(xiàn)不如full attention,但理論上確實能夠?qū)ong sequence有較好的速度表現(xiàn)。但問題關(guān)鍵在于實際應(yīng)用中往往需要gather,scatter,slice和cat等layout操作,這類操作并行度較差,對硬件不友好,導(dǎo)致實際場景中速度要慢很多。
Linear Attention,主要是去掉Softmax操作,從而能夠利用矩陣乘法結(jié)合律先計算K^TV, 將計算復(fù)雜度降下來,同時對于NLP中的時序任務(wù),隨著序列逐步增長 是個累加的過程,每次只需要計算當前時刻的
加上歷史累積的
。相比于quad attention每次都需要計算全部
計算量顯然小很多。但是,對于長序列而言,雖然每次計算量小很多,但是這種序列計算就帶來RNN類型存在的先天時序性,沒法并行操作,只有執(zhí)行完t-1步才能執(zhí)行第t步,所以計算復(fù)雜度小但是計算時間反而有可能更長。對于移動芯片來說,存儲器較小會使這種情況加劇。
3.1 Mixed Chunk Attention
本文提出的線性逼近結(jié)構(gòu)。首先將序列劃分為G個不重疊的chunk,每個chunk生成對應(yīng)的, 由
通過per-dim的scaling和offset 生成
, 每個chunk會同時參與local attention 和 global attention, local attention的話
而global attention則劃分為non-causal和 causal,即是否時序上算, causal 會帶來訓(xùn)練時間的大幅增加。
那么最終在當前幀的輸出是把local 和 global的attention 疊加在一起放到GAU中,
偽代碼如下:

這種操作其實是quad attention和linear attention的這種方案。
一些討論:
- chunk的劃分能加快auto-regressive training 的過程
- overlapping local attention能夠改善質(zhì)量,但是會存在memory re-formatting operations導(dǎo)致實際運行速度差很多。另外作者認為optimal partial attention 是任務(wù)相關(guān)的,但non-overlapping 是通用的。
- 和combiner相比,combiner也劃分chunk,但每個chunk內(nèi)使用的quad local attention,這樣FLASH在chunk內(nèi)就能夠允許更長的chunk。
3. 實驗
實驗在NLP上做的,這里只看下結(jié)果。

時間對比:

4. 結(jié)論
這篇文章先是提出一種GAU,然后GAU對attention的依賴較小,進而可以把GAU中的attention替換成linear attention。這個倒是可以在CV上嘗試,GAU可以多堆一倍。另一點mixed-chunk 在CV上感覺用處不大,倒是可以在track任務(wù)上使用,track的memory bank上更新query。

