線性變換器

變換器在一些任務(wù)中效果很好,但是它的時(shí)間復(fù)雜度是二次的,因而有個(gè)局限性,就是輸入的序列很長(zhǎng)的時(shí)候,它的計(jì)算會(huì)非常慢。為了解決這個(gè)局限性,我們將自注意力表示為核特征圖的線性點(diǎn)積,并利用矩陣乘積的結(jié)合性,將復(fù)雜度從O(N^2 )降低至O(N ),其中N是序列長(zhǎng)度。我們證明,這個(gè)公式可以迭代地實(shí)現(xiàn),這樣就可以極大地加速自回歸變換器,我們還揭示了其余遞歸網(wǎng)絡(luò)的關(guān)系。在序列非常長(zhǎng)的自回歸預(yù)測(cè)任務(wù)上,我們的線性變換器與普通的變換器性能相似,速度速度提高了4000倍。

https://fast-transformers.github.io/

https://linear-transformers.com/

https://github.com/idiap/fast-transformers

還有另外兩篇,和這篇有關(guān)系:

Fast?Transformers?with?Clustered?Attention????(2020.9)????https://arxiv.org/abs/2007.04825?(和本文一作二作換個(gè)位置,需要關(guān)注,對(duì)語(yǔ)義分割可能非常適用,開(kāi)源又是放在一起的),這個(gè)是在NIPS2020里的,可關(guān)注兩篇論文的團(tuán)隊(duì)leader的網(wǎng)址:https://fleuret.org/francois/

Efficient Attention: Attention with Linear Complexities? v1:2018,????v9:To appear at WACV 2021?????https://arxiv.org/abs/1812.01243v9

隔壁家的Performer的實(shí)現(xiàn):https://github.com/lucidrains/performer-pytorch

上面這個(gè)人好可怕,實(shí)現(xiàn)了好多:

https://github.com/lucidrains/linear-attention-transformer

https://github.com/lucidrains?tab=repositories

1. 引言

變換器最早由Vaswani等(2017)提出,是用在神經(jīng)網(wǎng)絡(luò)機(jī)器翻譯(Sutskever et al.,2014;Bahdanau et al.,2015)中,并且在自然語(yǔ)言處理(Devlin et al.,2019)、音頻(Sperber et al.,2018)和圖像(Parmar et al.,2019)領(lǐng)域中的各種任務(wù)中展示了令人印象深刻的結(jié)果。除了在充分監(jiān)督的任務(wù)上,當(dāng)使用自回歸(Radford et al.,2018;2019)或掩蓋語(yǔ)言建模目標(biāo)(Devlin et al.,2019;Yang et al.,2019;Song et al.,2019;Liu et al.,2020)的預(yù)訓(xùn)練時(shí),變換器也能有效地將知識(shí)遷移到有限監(jiān)督或無(wú)監(jiān)督的任務(wù)上。

然而,其雖然性能好,但計(jì)算量和內(nèi)存占用大。這一問(wèn)題主要是自注意力的全局感受野造成的,其在處理N個(gè)上下文輸入時(shí),內(nèi)存和時(shí)間復(fù)雜度是二次的O(N^2 )。因此在實(shí)際應(yīng)用中,它的訓(xùn)練很慢,因而會(huì)限制它的上下文長(zhǎng)度,但是這樣又會(huì)破壞時(shí)序的一致性,并有損于長(zhǎng)期依賴關(guān)系的捕獲。Dai等人(2019)(Transformer-XL)通過(guò)關(guān)注先前上下文中的記憶來(lái)解決后一個(gè)問(wèn)題,但這樣計(jì)算效率低。

最近的研究有關(guān)注在不降低速度的情況下增加上下文長(zhǎng)度。Child(2019)(Sparse Transformer)提出注意力矩陣的稀疏因子分解,將自注意力的復(fù)雜度降低至O(N\sqrt{N}  )。Kitaev等人(2020)(Reformer)使用局部敏感的哈希方法進(jìn)一步將復(fù)雜度降低至O(N\log_{}N  ),這樣就可以處理長(zhǎng)序列。盡管這兩個(gè)方法可以在長(zhǎng)序列上快速訓(xùn)練,但是在自回歸任務(wù)的推斷上沒(méi)有帶來(lái)速度的提升。

本文提出線性變換器,可顯著減小內(nèi)存占用,而且其復(fù)雜度和上下文長(zhǎng)度是線性關(guān)系。我們是通過(guò)基于核的自注意力公式,并利用矩陣乘積的結(jié)合性來(lái)計(jì)算自注意力權(quán)重(第3.2節(jié))。我們還展示了,在使用線性變換器的基礎(chǔ)上,復(fù)雜度為線性、內(nèi)存占用為恒定值的因果掩膜(3.3節(jié))。這揭示了我們方法和RNN的關(guān)系,這使得我們可以更快地執(zhí)行巨量級(jí)的自回歸推理(3.4節(jié))。

我們?cè)趫D像生成和語(yǔ)音識(shí)別上的實(shí)驗(yàn)表明,線性變換器可以達(dá)到傳統(tǒng)變換器的性能,而推理速度可以提升3個(gè)數(shù)量級(jí)。

2. 相關(guān)工作

本節(jié)將概述與本文最相關(guān)的一些工作,這些先前工作旨在解決變換器的內(nèi)存占用大、計(jì)算量大的問(wèn)題。此外,我們還從理論上分析變換器的核心部分,即自注意力機(jī)制。最后,我們還提出另一個(gè)工作方向,即試圖緩解softmax在自注意力中帶來(lái)的計(jì)算量大的問(wèn)題。

2.1 快速的變換器

現(xiàn)有的一些工作通過(guò)剪枝(Michel et al.,2019)、權(quán)值因子分解(Lan et al.,2020)、模型量化(Zafrir et al.,2019)、知識(shí)蒸餾來(lái)減小變換器的內(nèi)存占用。Clark等人(2020)提出一種新的預(yù)訓(xùn)練目標(biāo),稱為 replaced token detection,能使用更少的樣本來(lái)訓(xùn)練,并減小了整體的計(jì)算量。Lample(2019)使用product-key注意力來(lái)提高任意層的容量,而帶來(lái)的額外的計(jì)算代價(jià)可忽略不計(jì)。

使用這些方法可以減小內(nèi)存占用、減小計(jì)算量、使訓(xùn)練得更快、推理得更快,但從根本上講這些方法仍然是二次的,因而難以應(yīng)用到長(zhǎng)序列上。相比之下,我們證明我們的方法在理論上(3.2節(jié))和實(shí)證上(4.1節(jié))都降低了變換器的內(nèi)存占用和時(shí)間復(fù)雜度。

另一個(gè)研究方向是旨在增加變換器中自注意力的上下文。上下文是指序列中用來(lái)計(jì)算自注意力的最遠(yuǎn)部分。Dai(2019)提出TransformerXL,通過(guò)學(xué)習(xí)固定長(zhǎng)度以外的上下文依賴關(guān)系而不破壞時(shí)序一致性,達(dá)到語(yǔ)言建模當(dāng)時(shí)的最佳性能。然而,在內(nèi)存中保留先前的上下文會(huì)帶來(lái)額外的計(jì)算代價(jià)。相比之下,Sukhbaatar(2019)通過(guò)學(xué)習(xí)每個(gè)注意力頭的最佳注意力跨距,同時(shí)對(duì)歷史內(nèi)存和計(jì)算時(shí)間進(jìn)行控制,顯著增加了上下文長(zhǎng)度。請(qǐng)注意,這兩種方法的漸進(jìn)時(shí)間復(fù)雜度和原初的變換器是相同的。相比之下,我們減小了漸進(jìn)時(shí)間復(fù)雜度,這使得我們可以使用更遠(yuǎn)程的上下文。

和我們的工作更相關(guān)的是Child(2019)和Kitaev(2020)的工作,前者(SparseTransformer)提出注意力矩陣的稀疏因子分解,在長(zhǎng)序列的生成任務(wù)上,將整體復(fù)雜度從二次降低到O(N\sqrt{N}  )。更晚近的,Kitaev(2020)提出Reformer,使用局部敏感散列(LSH),這樣可以計(jì)算更少的點(diǎn)積,從而將復(fù)雜度降低至O(N\log_{}N  )。請(qǐng)注意,為了能夠使用LSH,Reformer增加了約束,也就是keys和queries要是相同的,因此,該方法不適用于解碼任務(wù),在解碼任務(wù)中keys和queries是不同的。相比之下,我們的方法對(duì)keys和queries沒(méi)有任何約束,而且其時(shí)間復(fù)雜度和序列長(zhǎng)度是線性關(guān)系。此外,我們的方法在自回歸任務(wù)上的推理速度是提高3個(gè)數(shù)量級(jí),在驗(yàn)證困惑度上取得相當(dāng)?shù)男阅堋?/p>

2.2 理解自注意力

很少有人從理論分析上更好地理解自注意力機(jī)制。Tscai(2019)等人對(duì)變換器中注意力機(jī)制提出一個(gè)基于核的公式,該公式將注意力視為對(duì)輸入做平滑的核函數(shù),核得分為輸入之間的相似程度。該公式是有助于更好地理解注意力機(jī)制和位置編碼模塊。相比之下,我們使用基于核的公式來(lái)加速注意力的計(jì)算過(guò)程并降低它的時(shí)間復(fù)雜度。另外,我們還發(fā)現(xiàn),在keys和queries上使用相似度為正的核函數(shù),線性注意力通常會(huì)收斂。

最近,Cordonnier等人(2020)從理論上和實(shí)驗(yàn)上證明了,多頭注意力的頭夠多的話,可以表達(dá)任何卷積層。這里,我們證明了,用在自回歸任務(wù)上的自注意力層可以看做是一個(gè)遞歸網(wǎng)絡(luò),而且還可以顯著加快在自回歸任務(wù)上變換器的推理速度。

2.3 將softmax線性化

多年來(lái),softmax一直是類別很多的多分類任務(wù)的計(jì)算瓶頸(Goodman,2001;Morin&Bengio,2005;Mnih & Hinton,2009)。最近的研究(Blanc&Rendle,2017;Rawat et al.,2019)通過(guò)用特征圖的線性點(diǎn)積來(lái)近似softmax,以及采樣,來(lái)加快訓(xùn)練速度。受這些工作啟發(fā),我們將變換器中的注意力中的softmax線性化。在本文工作進(jìn)展的同時(shí),Shen(2020)探索了線性化注意力在圖像中目標(biāo)檢測(cè)任務(wù)中的應(yīng)用。相比之下,我們不僅將注意力計(jì)算線性化,而且建立了一個(gè)在訓(xùn)練階段和推理階段,復(fù)雜度為線性、內(nèi)存占用為恒定值的自回歸變換器模型。此外,我們還證明,通過(guò)核的映射(the lens of kernels),每個(gè)變換器可看做是個(gè)遞歸網(wǎng)絡(luò)。

3. 線性變換器

本節(jié)將我們提出的線性變換器公式化。我們展示了,通過(guò)將注意力從傳統(tǒng)的softmax注意力轉(zhuǎn)換為基于特征圖點(diǎn)積的注意力,可以減小時(shí)間復(fù)雜度和內(nèi)存占用,以及可以構(gòu)建在耗時(shí)與序列長(zhǎng)度呈線性關(guān)系的序列生成的因果模型,類似于遞歸網(wǎng)絡(luò)。

首先,在3.1節(jié)中,我們會(huì)介紹Vaswani等(2017)提出的原初版的變換器的公式。然后,在3.2節(jié)和3.3節(jié)我們提出線性變換器。最后,在3.4節(jié),我們將變換器改寫為遞歸網(wǎng)絡(luò)的形式。

3.1 原初版的變換器

x\in R^{N\times F}表示N個(gè)維度為F的特征向量。變換器是從R^{N\times F}\rightarrow R^{N\times F}映射的函數(shù),函數(shù)的計(jì)算為L個(gè)變換器層的組成:

T_{l} (x)=f_{l} (A_{l} (x)+x)? ??(1)

其中f_{l} (\cdot )單獨(dú)地變換每個(gè)特征,通常都是簡(jiǎn)單地由兩層前饋神經(jīng)網(wǎng)絡(luò)構(gòu)成。A_{l} (\cdot )是自注意力函數(shù),是變換器中唯一執(zhí)行序列間交叉作用的部分。

對(duì)于每個(gè)位置,自注意力函數(shù)計(jì)算所有其他位置的特征表示的加權(quán)平均值,權(quán)重與表示之間的相似性分?jǐn)?shù)成比例。公式上,輸入序列x被3個(gè)矩陣權(quán)重投影:W_{Q} \in R^{F\times D} 、W_{K} \in R^{F\times D} 、W_{V} \in R^{F\times D} ,分別對(duì)應(yīng)Q、K、V。所有位置的最終輸出,?A_{l} (x )=V^{’},計(jì)算為:

Q=xW_{Q} ,K=xW_{K} ,V=xW_{V} ,

A_{l} (x )=V^{’}=softmax(\frac{QK ^ T  }{\sqrt{D} } )V?(2)

注意,在上述公式中,softmax函數(shù)按行用于QK ^ T。通常地,Q、KV被稱為queries、keys和values。

公式(2)實(shí)現(xiàn)的是自注意力的一種特例,被稱為softmax注意力,其中相似度分?jǐn)?shù)是一個(gè)query和一個(gè)key的點(diǎn)積的指數(shù)。

令一個(gè)矩陣帶一個(gè)下標(biāo)i表示該矩陣的第i行的向量,我們可以寫一個(gè)廣義的注意力函數(shù),適用于任意的相似度計(jì)算方式:

V_{i}^{’}  = \frac{\sum\nolimits_{j=1}^N  sim(Q_{i},  K _{j})V_{j} }{\sum\nolimits_{j=1}^N  sim(Q_{i},  K _{j})} ?(3)

如果我們令相似度函數(shù)為sim(q,k)=exp(\frac{q^T k}{ \sqrt{D} } ),公式(3)和公式(2)就是等效的。

3.2 線性化的注意力

公式(2)中的注意力的定義是通用的,可以用來(lái)定義其它的注意力的實(shí)習(xí),例如多項(xiàng)式注意力或者RBF核注意力(Tsai,2019)。要使公式(3)來(lái)定義一個(gè)注意力函數(shù),唯一的約束是它的計(jì)算是非負(fù)的。這包括所有的核函數(shù)k(x,y):R^{2\times F}\rightarrow R_{+}

給定一個(gè)特征表示的核函數(shù)\phi (x),我們可以將公式(2)重寫為:

V_{i}^{’}  = \frac{\sum\nolimits_{j=1}^N  \phi( Q_{i}) ^T \phi( K_{j})  V_{j} }{\sum\nolimits_{j=1}^N   \phi( Q_{i}) ^T \phi( K_{j} )} ?(4)

然后利用矩陣相乘的結(jié)合性,將上述公式進(jìn)一步簡(jiǎn)化為:

V_{i}^{’}  = \frac{\phi( Q_{i}) ^T\sum\nolimits_{j=1}^N   \phi( K_{j})  V_{j} ^T }{\phi( Q_{i}) ^T\sum\nolimits_{j=1}^N   \phi( K_{j} )} ? ??(5)

我們將分子寫成向量的形式,這個(gè)公式會(huì)更容易理解:

(\phi (Q)\phi (K)^T)V= \phi (Q)(\phi (K)^T V)? ??(6)

注意,特征映射\phi (\cdot )是按行應(yīng)用于矩陣QK的。

從公式(2)可以看出,softmax注意力的計(jì)算量是與O(N^2 )成正比的,其中N是序列長(zhǎng)度。內(nèi)存占用也是一樣的,因?yàn)楸仨毚鎯?chǔ)完整的注意力矩陣來(lái)計(jì)算關(guān)于查詢、鍵和值的梯度。相比之下,我們從公式(5)中提出的線性變換器的時(shí)間復(fù)雜度和內(nèi)存占用為O(N),因?yàn)槲覀兛梢杂?jì)算\sum\nolimits_{j=1}^N   \phi( K_{j})  V_{j} ^T \sum\nolimits_{j=1}^N   \phi( K_{j} ),然后在每個(gè)query上重復(fù)使用已計(jì)算結(jié)果。

3.2.1 特征映射和計(jì)算代價(jià)

對(duì)于softmax注意力,乘加的總的計(jì)算損失為O(N^2 max(D,M) ),其中D是查詢和鍵的維度,M是值的維度。相反,對(duì)于我們提出的線性注意力,我們首先計(jì)算維度C的特征圖,然后計(jì)算新的值的乘加運(yùn)算的復(fù)雜度為O(NCM)。

前面的分析沒(méi)有考慮核和特征函數(shù)的選擇。注意,對(duì)應(yīng)于指數(shù)核的特征函數(shù)是無(wú)窮維的,這使得精確的softmax注意力的線性化是不可行的。另一方面,例如,多項(xiàng)式核有一個(gè)精確的有限維度的特征映射,并且被證明與指數(shù)核或RBF核(Tsai et al.,2019)同樣有效。一個(gè)線性化的度為2的多項(xiàng)式變換器的計(jì)算復(fù)雜度為O(ND^2 M)。這樣的話,當(dāng)N>D^2 時(shí),計(jì)算復(fù)雜度就非常的好。請(qǐng)注意,這個(gè)假設(shè)在實(shí)踐中是正確的,因?yàn)槲覀兿M軌蛱幚砩先f(wàn)個(gè)元素的序列。

在我們的實(shí)驗(yàn)中,當(dāng)處理較小的序列時(shí),我們使用一個(gè)特征映射,其使得相似度函數(shù)的計(jì)算結(jié)果總為正:

\phi (x)=elu(x)+1? ??(7)

其中?elu(\cdot )表示指數(shù)線性單應(yīng)激活函數(shù)(Clevert et al.,2015)。相比relu,我們更喜歡elu,這樣可避免當(dāng)x為負(fù)值時(shí)梯度為0。這種特征映射方式使得注意力函數(shù)的乘加運(yùn)算的復(fù)雜度為O(NDM)。在本文的實(shí)驗(yàn)部分,我們展示了公式(7)的性能與完整的變換器的性能相當(dāng),卻顯著降低了計(jì)算量和內(nèi)存占用。

3.3 因果屏蔽

通過(guò)屏蔽注意力計(jì)算,變換器結(jié)構(gòu)可用于高效地訓(xùn)練自回歸模型,也就是第i個(gè)位置的計(jì)算只能受它之前的位置的影響,也即位置j,j\leq i,而不能受它之后的位置的影響。因果屏蔽將公式(3)改寫為:

V_{i}^{’}  = \frac{\sum\nolimits_{j=1}^i  sim(Q_{i},  K _{j})V_{j} }{\sum\nolimits_{j=1}^i  sim(Q_{i},  K _{j})} ? ??(8)

按照第3.2節(jié)的推導(dǎo),我們將帶屏蔽機(jī)制的注意力線性化為:

V_{i}^{’}  = \frac{\phi( Q_{i}) ^T\sum\nolimits_{j=1}^i   \phi( K_{j})  V_{j} ^T }{\phi( Q_{i}) ^T\sum\nolimits_{j=1}^i   \phi( K_{j} )} ? ??(9)

S_{i} Z_{i} 為:

S_{i} =\sum\nolimits_{j=1}^i   \phi( K_{j})  V_{j} ^T ? ??(10)

Z_{i} =\sum\nolimits_{j=1}^i   \phi( K_{j})  ? ??(11)

我們可以將公式(9)進(jìn)一步簡(jiǎn)化為:

V_{i}^{’}  = \frac{\phi( Q_{i}) ^T S_{i}  }{\phi( Q_{i}) ^T Z_{i} } ? ??(12)

請(qǐng)注意,S_{i} Z_{i} 可以從S_{i-1} Z_{i-1} 計(jì)算得到,計(jì)算耗時(shí)是恒定值,這使得帶因果屏蔽機(jī)制的線性變換器的計(jì)算復(fù)雜度和序列長(zhǎng)度成線性關(guān)系。

3.3.1 梯度計(jì)算

任何深度學(xué)習(xí)框架下,公式(12)的樸素實(shí)現(xiàn)都需要存儲(chǔ)所有中間值S_{i} 以計(jì)算梯度。這導(dǎo)致內(nèi)存占用以max(D,M)成倍增長(zhǎng),從而阻礙因果模型在長(zhǎng)序列或深網(wǎng)絡(luò)上的使用。

為了解決這個(gè)問(wèn)題,我們將公式(9)中的分子項(xiàng)的梯度導(dǎo)出為累加和。這樣我們就可以在前向傳播和反向傳播中,以線性的時(shí)間復(fù)雜度恒定的內(nèi)存占用,來(lái)計(jì)算因果注意力。詳細(xì)的推導(dǎo)過(guò)程在補(bǔ)充材料中。

給定分子項(xiàng)\tilde{V_{i} } ,以及標(biāo)量損失函數(shù)對(duì)該分子項(xiàng)的梯度\nabla_{\tilde{V_{i} } } L,此處省略一段話以及三個(gè)公式,回頭再看,為因果模型而設(shè)計(jì)的,不一定用得上。

這樣,公式(9,13-15)中的累加和項(xiàng)的計(jì)算的時(shí)間復(fù)雜度與序列長(zhǎng)度呈線性關(guān)系,內(nèi)存占用為恒定值。這樣,給定一個(gè)維度為C的特征圖,算法的計(jì)算復(fù)雜度為O(NCM),內(nèi)存占用為O(Nmax(C,M))。算法1是分子項(xiàng)的前向傳播和反向傳播的偽代碼實(shí)現(xiàn)。

算法1

3.3.2 訓(xùn)練和推理

當(dāng)訓(xùn)練自回歸變換器模型時(shí),完整的序列真值是可得的。這使得,對(duì)于公式(1)中的f_{l} (\cdot ),以及注意力的計(jì)算,在層的層面上的并行化是可行的。因此,變換器比遞歸網(wǎng)絡(luò)訓(xùn)練得更快。另一方面,在推理階段,時(shí)間步i的輸出是時(shí)間步i+1的輸入,這使得自回歸模型無(wú)法并行化。此外,變換器的每個(gè)時(shí)間步的計(jì)算耗時(shí)不是不變的,而是和當(dāng)前序列長(zhǎng)度的平方成正比,因?yàn)楸仨氂?jì)算所有以前的時(shí)間步的注意力。

我們提出的線性變換器結(jié)合了兩者的優(yōu)點(diǎn)。訓(xùn)練時(shí),計(jì)算可以并行化,充分利用GPU或其它加速器的優(yōu)勢(shì)。推理時(shí),每個(gè)時(shí)間步的預(yù)測(cè)的計(jì)算耗時(shí)和內(nèi)存占用都是恒定的。這意味著我們可以簡(jiǎn)單地將矩陣\phi( K_{j})  V_{j} ^T 存儲(chǔ)為內(nèi)部狀態(tài),并像遞歸網(wǎng)絡(luò)一樣在每個(gè)時(shí)間步對(duì)其進(jìn)行更新。這樣我們的線性變換器要比其它的變換器模型要快千倍。

3.4 變換器就是遞歸網(wǎng)絡(luò)

已有文獻(xiàn)中,變換器和遞歸網(wǎng)絡(luò)被認(rèn)為是從根本上不同的。然而,從3.3節(jié)的因果屏蔽公式和上一節(jié)的討論中可以明顯看出,任何具有因果屏蔽的變換器層都可以表示成這樣一個(gè)模型:給一個(gè)輸入,修改內(nèi)部狀態(tài),預(yù)測(cè)一個(gè)輸出,這就是遞歸網(wǎng)絡(luò)RNN。請(qǐng)注意,和Universal Transformers(Dehghani,2018)不同,我們說(shuō)的遞歸是指計(jì)算時(shí)間,而不是網(wǎng)絡(luò)層的深度。

在下面的公式中,我們將公式(1)中的變換器層寫成遞歸網(wǎng)絡(luò)的形式。得到的RNN有兩個(gè)隱藏狀態(tài),稱為注意力記憶s和歸一化記憶z。我們用下標(biāo)表示遞歸的時(shí)間步。

s_{0} =0? ??(16)

z_{0} =0? ??(17)

s_{i} = s_{i-1} +\phi (x_{i} W_{K} ) (x_{i} W_{V} )^T ? ??(18)

z_{i} = z_{i-1} +\phi (x_{i} W_{K} ) ? ??(19)

y_{i} = f_{l} (\frac{\phi (x_{i}  W _{Q} )^T s_{i} }{\phi (x_{i}  W _{Q} )^T z_{i} }+x_{i}  )? ??(20)

在上面這些公式中,x_{i} 表示某個(gè)變換器層的第i個(gè)輸入,y_{i} 表示該層的第i個(gè)輸出。請(qǐng)注意,我們的公式對(duì)特征函數(shù)沒(méi)有任何的限定,它可以表示任何變換器模型,理論上甚至適用于softmax注意力。這些公式是為更好地理解變換器和常用的遞歸網(wǎng)絡(luò)(Hochreiter&Schmidhuber,1997)之間的關(guān)系,以及存儲(chǔ)和檢索信息的過(guò)程,邁出的第一步。

4. 實(shí)驗(yàn)

本節(jié)從實(shí)驗(yàn)分析我們提的線性變換器的性能。首先,在4.1節(jié),我們從計(jì)算量、內(nèi)存占用,以及在合成數(shù)據(jù)上的收斂性等方面評(píng)估線性注意力。在4.2節(jié)和4.3節(jié),分別在圖像生成和語(yǔ)音識(shí)別兩個(gè)實(shí)際應(yīng)用中評(píng)估我們的模型,以進(jìn)一步展示我們提出的線性變換器的性能。結(jié)果表明,我們的模型的性能與最先進(jìn)的變換器模型相比是由競(jìng)爭(zhēng)力的,而GPU內(nèi)存占用和計(jì)算量上是大大減少了。

我們實(shí)驗(yàn)中用于對(duì)比的基線有兩個(gè),一個(gè)是帶softmax注意力的完整變換器,另一個(gè)是Reformer(Kitaev et al., 2020),后者是最先進(jìn)的快速變換器。對(duì)于Reformer,我們使用已經(jīng)開(kāi)源的PyTorch復(fù)現(xiàn)代碼。對(duì)于完整的變換器,我們使用默認(rèn)的PyTorch實(shí)現(xiàn)。請(qǐng)注意,對(duì)于Reformer,我們不使用可逆層,但這不影響,因?yàn)槲覀冎粶y(cè)量自注意力層的內(nèi)存占用。在所有實(shí)驗(yàn)中,我們使用softmax表示標(biāo)準(zhǔn)的變換器(Vaswani et al., 2017),用linear表示我們提出的線性變換器,用lsh-X表示Reformer(Kitaev et al., 2020),其中X表示散列輪數(shù)。

在訓(xùn)練線性變換器中,我們使用公式(7)的特征映射方式。我們的PyTorch(Paszke et al., 2019)代碼、文檔、例子在https://linear-transformers.com/? 公式(13-15)的使梯度計(jì)算的內(nèi)存占用為恒定值的實(shí)現(xiàn)是大概200行的CUDA代碼。

4.1 合成任務(wù)

4.1.1 收斂性分析

為了檢驗(yàn)線性變換器的收斂特性,我們?cè)谝粋€(gè)帶因果屏蔽的人工復(fù)制任務(wù)上訓(xùn)練模型。也就是說(shuō),變換器必須復(fù)制一系列符號(hào),類似于Kitaev等人的序列復(fù)制任務(wù)(2020)。我們使用一個(gè)最長(zhǎng)為128的序列,包含10個(gè)不同的符號(hào),符號(hào)由專門的分隔符分開(kāi)。對(duì)于三種變換器,我們都構(gòu)建4層的網(wǎng)絡(luò),8個(gè)注意力頭,批量大小都是64,都是要RAdam優(yōu)化器(Liu et al., 2019),初始學(xué)習(xí)率都為10-3,更新3000次后學(xué)習(xí)率減小至10-4。圖2所示是損失相對(duì)于梯度步數(shù)的曲線。我們觀察到,線性變換器收斂得更平滑,并且比Reformer收斂到一個(gè)更小的損失值,這是因?yàn)榫€性變換器沒(méi)有引入散列導(dǎo)致的噪聲。線性變換器收斂到與完整變換器相同的損失值。

圖2 在序列復(fù)制任務(wù)上的收斂性分析。線性變換器收斂地穩(wěn)定,并最終性能與softmax變換器一樣。

4.1.2 內(nèi)存占用和計(jì)算量

本節(jié)比較變換器的計(jì)算量和內(nèi)存占用。我們計(jì)算在合成數(shù)據(jù)上不同的序列長(zhǎng)度N=\left\{ 2^9 , 2^{10} ... 2^{16} \right\} ,并衡量GPU占用的峰值,以及變換器的每個(gè)變換消耗的時(shí)間。我們將批量大小設(shè)置為與序列長(zhǎng)度成反比,并報(bào)告批量中每個(gè)樣本的時(shí)間消耗和內(nèi)存占用。

每種變換器模型都被評(píng)估到適合GPU內(nèi)存的最大序列長(zhǎng)度。我們使用的是NVidia GTX 1080 Ti,內(nèi)存為11GB。這使得softmax的最大序列長(zhǎng)度為4096個(gè)元素,lsh-4和lsh-8的最大序列長(zhǎng)度為16384個(gè)。正如預(yù)期的那樣,softmax的時(shí)間復(fù)雜度和序列長(zhǎng)度的二次方成正比。我們的方法比其它的方法都更快,且只需要更少的內(nèi)存,如圖1所示。我們注意到,線性注意力機(jī)制和Reformer的時(shí)間復(fù)雜度、內(nèi)存占用都與序列長(zhǎng)度成線性的正比關(guān)系。注意,盡管Reformer的漸進(jìn)復(fù)雜度為O(N\log_{}N  ),\log_{}N 是足夠的小所以不影響計(jì)算耗時(shí)。

圖1 前向-反向傳播的耗時(shí)與內(nèi)存占用。線性變換器與Reformer相對(duì)于序列長(zhǎng)度是線性的,softmax是二次的

4.2 圖像生成

變換器在一些條件的或者是非條件的自回歸生成任務(wù)上展現(xiàn)出很好的性能(Radford et al.,2019;Child et al.,2019)。但是由于任務(wù)本身是按順序的,而且內(nèi)存占用和序列長(zhǎng)度成正比,從變換器中采樣是很慢的。本節(jié),我們訓(xùn)練帶因果屏蔽的變換器,以逐像素地生成圖像。我們?nèi)〉玫男阅?bits per dimension)與softmax注意力不相上下,而速度快千倍以上,而且每張圖像從第1個(gè)像素到最后個(gè)像素的內(nèi)存占用是不變的。在我們的補(bǔ)充材料中,有訓(xùn)練過(guò)程、生成圖像的質(zhì)量、生成單張圖像的耗時(shí)等方面的比較。此外,我們還和faster softmax transformer進(jìn)行比較,其不同于PyTorch實(shí)現(xiàn),是在推理階段讀取keys和values的緩存。

4.2.1 MNIST

首先,我們?cè)趶V泛使用的MNIST數(shù)據(jù)集(LeCun et al., 2010)用自回歸變換器訓(xùn)練圖像生成模型。本實(shí)驗(yàn)的網(wǎng)絡(luò)結(jié)構(gòu)是8個(gè)注意力層,每層8個(gè)注意力頭。我們將embedding的長(zhǎng)度設(shè)置為256,也就是每個(gè)注意力頭是32維。我們的前饋的大小是embedding大小的4倍。我們使用Salimans等人(2017)提出的10種logistics的混合來(lái)建模我們的輸出。我們使用RAdam優(yōu)化器,學(xué)習(xí)率設(shè)置為10-4,所有模型訓(xùn)練250個(gè)epoch。對(duì)于Reformer的散列,我們使用1到4輪的散列。另外,按照Kitaev(2020,即Reformer作者)所提的,我們使用64個(gè)buckets以及大約32個(gè)元素為一個(gè)chunk。具體地是,我們將長(zhǎng)度為783的輸入序列劃分為27個(gè)chunks,每個(gè)chunk有29個(gè)元素。由于序列長(zhǎng)度實(shí)際上很小,即只有784個(gè)像素,為了避免因設(shè)置不同的批量大小而導(dǎo)致對(duì)比不公平,我們將3種變換器的批量大小都設(shè)為10。

表1是實(shí)驗(yàn)結(jié)果。我們觀察到,就最終的性能上(就困惑度而言),線性變換器與softmax變換器性能幾乎相同,而速度快300倍以上。這是因?yàn)槲覀兊哪P蛢?nèi)存占用低,它能夠用一個(gè)GPU同時(shí)生成10000個(gè)MNIST圖像。特別的,無(wú)論當(dāng)前序列長(zhǎng)度多少,內(nèi)存占用是恒定的,因?yàn)樵谙袼刂g唯一需要存儲(chǔ)的東西是公式(18)和公式(19)中的s_{i} z_{i} 。相比之下,softmax變換器和Reformer的內(nèi)存占用都會(huì)隨著當(dāng)前序列長(zhǎng)度的增長(zhǎng)而增加。

表1 在MINST圖像上的自回歸圖像生成任務(wù)。線性變換器的性能與softmax變換器幾乎相同,但是要快300倍多

圖3所示是我們的模型的圖像補(bǔ)全以及無(wú)條件圖像生成的結(jié)果的樣例。我們觀察到,我們的線性變換器生成的樣本是很逼真的,邊界清晰,沒(méi)有噪聲。在圖像補(bǔ)全的結(jié)果中,我們還觀察到,我們的線性變換器的結(jié)果有與原圖像相同的筆劃樣式和寬度,也就是能夠有效地捕獲遠(yuǎn)程關(guān)系。請(qǐng)注意,所有變換器產(chǎn)生的困惑度都是相當(dāng)?shù)?,我們沒(méi)有觀察到不同變換器生成結(jié)果的質(zhì)量上的差異。

圖3 無(wú)條件生成樣本和圖像補(bǔ)全結(jié)果。(a)是被遮擋的樣本,(b)是補(bǔ)全結(jié)果,(c)是原始圖像,我們的模型與softmax變換器性能相當(dāng),但快300倍多,每秒生成142張圖像。

4.2.2 CIFAR-10

當(dāng)序列越長(zhǎng)時(shí),我們的線性變換器就越有優(yōu)勢(shì)。為了展示這一點(diǎn),我們訓(xùn)練了16層的變換器來(lái)生成CIFAR-10圖像(Krizhevsky,2019)。每層的設(shè)置與前一個(gè)實(shí)驗(yàn)相同。對(duì)于Reformer,我們?cè)僖淮问褂?4 buckets and 83 chunks of 37 elements,這和論文中建議的32接近。由于序列長(zhǎng)度幾乎是前一個(gè)實(shí)驗(yàn)的4倍,在我們能使用的最大GPU上(即NVidia P40,24GB的內(nèi)存)上,批量大小設(shè)置為1。線性變換器和Reformer的批量大小都設(shè)置為4。所有模型都訓(xùn)練7天。表2中是每維比特?cái)?shù)和圖像生成吞吐量。請(qǐng)注意,盡管這個(gè)實(shí)驗(yàn)不是重點(diǎn)關(guān)注最終的困惑度,但是很明顯,序列越長(zhǎng),更快的變換器在每GPU小時(shí)上的效率的優(yōu)勢(shì)就越明顯,比更慢的變換器模型的得分也更高。

表2 在CIFAR-10上生成圖像,單塊GPU,用7天訓(xùn)練自回歸變換器模型。我們的線性變換器要比softmax訓(xùn)練3倍多的epochs,因而困惑度更好。我們的模型在生成圖像上比基線快4000多倍

由于對(duì)于Reformer和softmax注意力來(lái)說(shuō),生成單個(gè)像素的內(nèi)存占用和耗時(shí),與像素的個(gè)數(shù)成二次的正比關(guān)系,因此線性變換器的的吞吐量增加更為明顯。尤其是,當(dāng)softmax變換器生成1張圖像的時(shí)候,我們的線性變換器可以生成4460張圖像。我們觀察到,我們模型生成的圖像是具有空間一致性的,并且生成結(jié)果逼真,不會(huì)顯著阻礙圖像分類識(shí)別。例如,圖4b中,所有圖像都成功地完成了狗的鼻子(第一排)或卡車的擋風(fēng)玻璃(最后一行)。

圖4 在CIFAR-10上生成圖像的無(wú)條件生成樣本和圖像補(bǔ)全結(jié)果,(a)是被遮擋的原圖,(b)是補(bǔ)全結(jié)果,(c)是原圖。當(dāng)序列越長(zhǎng)時(shí),線性變換器在速度上相比softmax變換器的優(yōu)勢(shì)就越能體現(xiàn)出來(lái),我們的模型要快4000多倍,每秒生成1785張圖像

4.3 自動(dòng)語(yǔ)音識(shí)別(ASR)

為了證明我們的方法也可以用于非自回歸任務(wù),我們?cè)u(píng)估了線性變換器在端到端ASR任務(wù)上的性能,使用CTC損失(Connectionist Temporal Classification,2006)。我們以非自回歸方式預(yù)測(cè)每個(gè)輸入幀的音素分布。我們使用我們使用80小時(shí)的《華爾街日?qǐng)?bào)》數(shù)據(jù)集(Paul&Baker,1992)和40維的不帶temporal differences的mel-scale filterbanks作為特征。這個(gè)數(shù)據(jù)集的平均序列長(zhǎng)度是800幀,最大長(zhǎng)度是2400幀。本實(shí)驗(yàn)我們還對(duì)比了雙向LSTM(Hochreiter&Schmidhuber,1997),其有3個(gè)隱藏層,隱藏層size是320。我們使用Adam優(yōu)化器,學(xué)習(xí)率設(shè)置為10-3,當(dāng)驗(yàn)證集錯(cuò)誤率不下降時(shí)學(xué)習(xí)率就減小。變換器模型有9層,每層6個(gè)注意力頭,embedding的維度和圖像的實(shí)驗(yàn)相同。使用RAdam優(yōu)化器,初始學(xué)習(xí)率是10-4,當(dāng)驗(yàn)證集錯(cuò)誤率不下降時(shí)學(xué)習(xí)率減半。

所有模型都是在每個(gè)訓(xùn)練epoch上評(píng)估音素錯(cuò)誤率(PER)。我們觀察到,線性變換器在性能和速度方面都遠(yuǎn)遠(yuǎn)優(yōu)于遞歸網(wǎng)絡(luò)基線和Reformer,如表3所示。請(qǐng)注意,與所有基線相比,softmax變換器實(shí)現(xiàn)了更低的錯(cuò)誤率,但速度很慢。尤其是,線性變換器在每個(gè)epoch上要比它快3倍。我們?cè)谘a(bǔ)充材料中提供了訓(xùn)練過(guò)程的曲線圖。

表3 在WSJ數(shù)據(jù)集上的語(yǔ)音識(shí)別。評(píng)估指標(biāo)是音素錯(cuò)誤率(PER)和每個(gè)epoch的訓(xùn)練耗時(shí)。我們的模型比LSTM和Reformer性能更好,且訓(xùn)練得更快,推理得更快。

5. 結(jié)論

我們提出線性變換器,大大減小原初變換器的內(nèi)存占用和計(jì)算量。具體地是,利用矩陣相乘的結(jié)合性,我們使自注意力計(jì)算的時(shí)間復(fù)雜度、內(nèi)存占用與序列長(zhǎng)度成線性關(guān)系。我們證明我們的模型可以用于因果屏蔽任務(wù),并仍然保持線性漸近的復(fù)雜度。最后,我們將變換器模型表示為遞歸網(wǎng)絡(luò)的形式,這樣我們就可以對(duì)自回歸任務(wù)更快地推理。

這一特性為將來(lái)研究RNN和Transformer中的信息存儲(chǔ)和檢索開(kāi)辟了許多方向。另一個(gè)有待探索的研究方向是選取線性注意力的特征映射方式。例如,用隨機(jī)Fourier特征逼近RBF核可以使我們使用帶有softmax注意力的預(yù)訓(xùn)練模型。

補(bǔ)充材料:

A. 梯度推導(dǎo)

B. 訓(xùn)練過(guò)程曲線

在圖5中,我們展示了實(shí)驗(yàn)中所有變換器模型的訓(xùn)練過(guò)程曲線。對(duì)于MNIST實(shí)驗(yàn)(圖5a),所有方法都是訓(xùn)練250個(gè)epoch。因?yàn)樾蛄虚L(zhǎng)度是很小的,所以不同方法的訓(xùn)練耗時(shí)不會(huì)有明顯變化。我們還觀察到,我們的方法和softmax注意力是收斂到相同的水平,顯然優(yōu)于兩個(gè)Reformer模型。

圖5 線性變換器總是收斂地比Reformer更快,在自回歸任務(wù)上與softmax相當(dāng)。在MNIST上都是訓(xùn)練250epoch,在CIFAR-10上都是訓(xùn)練7天,在語(yǔ)音識(shí)別上都是訓(xùn)練到收斂為止。實(shí)驗(yàn)細(xì)節(jié)見(jiàn)4.2.1節(jié)、4.2.2節(jié)、4.3節(jié)

在CIFAR-10上,見(jiàn)圖5b,我們?cè)O(shè)置所有模型的訓(xùn)練耗時(shí)都為7天。我們觀察到,線性變換器和lsh-1比另外兩個(gè)模型能訓(xùn)練更多的epoch,并且取得更好的性能。如果序列更長(zhǎng)的話,這一差距有望繼續(xù)增大。

最后,在語(yǔ)音識(shí)別任務(wù)上,見(jiàn)圖5c,softmax收斂時(shí)的性能明顯優(yōu)于Reformer和線性變換器。注意,線性變換器訓(xùn)練每個(gè)epoch的速度要快3倍,這意味著它比softmax訓(xùn)練了大約4倍多的epoch。盡管在這項(xiàng)任務(wù)中,softmax變換器是表現(xiàn)最好的,但我們觀察到,和Reformer相比,線性變換器在收斂性和最終性能上都更好。

C. 圖像生成吞吐量討論

C.1. Stateful softmax 注意力

4.2節(jié),我們報(bào)告了圖像生成的吞吐量,并與softmax transformer和lsh進(jìn)行了比較。在本節(jié)中,我們創(chuàng)建另一個(gè)基線,表示為stateful softmax,它是在自回歸任務(wù)上將softmax變換器改成遞歸網(wǎng)絡(luò)的計(jì)算方式,也就是說(shuō),所有的keys和values都會(huì)被保存下來(lái),然后當(dāng)預(yù)測(cè)序列的下一個(gè)元素時(shí),將它們?cè)賯鬟f給模型。這個(gè)遞歸模型的狀態(tài)(state)是一組鍵和值,其規(guī)模與序列長(zhǎng)度成正比。這與我們提出的模型有質(zhì)的不同,我們提出的模型的狀態(tài)的維度是固定不變的,并且是給定第i-1個(gè)狀態(tài)來(lái)計(jì)算第i個(gè)狀態(tài),無(wú)論i是多少,計(jì)算量都是不變的。

表4是結(jié)果。我們注意到,stateful-softmax比原初的變換器要快得多,然而它的復(fù)雜度仍然是序列長(zhǎng)度的二次方,在CIFAR-10上,我們提的方法要比stateful-softmax快50倍以上。另外我們要說(shuō)明的是,構(gòu)建stateful-Reformer并不容易,因?yàn)槊看翁峁┬碌妮斎霑r(shí)都需要執(zhí)行排序(sorting)和分塊( chunking )操作。

表4

C.2. 將批量大小都設(shè)為1

在前面的部分中,我們?cè)u(píng)估了用于自回歸圖像生成任務(wù)的所有變換器變量的吞吐量。然而,要考慮的另一個(gè)重要因素是延遲,即生成單個(gè)圖像所需的總時(shí)間。為此,我們使用批處理大小為1并測(cè)量所有方法生成單個(gè)圖像所需的時(shí)間。除了在GPU上運(yùn)行推理,我們還評(píng)估了CPU上所需的時(shí)間。結(jié)果見(jiàn)表5。

表5

我們觀察到,所有的方法都沒(méi)有充分利用GPU,并且比表4的圖像生成吞吐量小得多。我們所提的線性變換器是最快的,尤其是,它在CIFAR-10上比softmax變換器要幾乎快6.6倍(筆者注,CIFAR-10的圖像尺寸是32x32)。請(qǐng)注意,我們的線性自回歸變換器是唯一一個(gè),在各種設(shè)置下,在CPU上比在GPU上更快的。這是因?yàn)閷⒆⒁饬C(jī)制以RNN的形式的計(jì)算成本很低,以至于主要的計(jì)算瓶頸是序列上不可避免的外循環(huán)。

D 圖像生成的定性分析

在本節(jié)中,我們將為我們的圖像生成實(shí)驗(yàn)提供定性的結(jié)果。由于所有模型的困惑度大致相同,如預(yù)期的那樣,質(zhì)量的差異并不顯著。然而,一個(gè)相當(dāng)有趣的觀察結(jié)果是,Reformer模型在其無(wú)條件生成樣本中提供的變化要少得多。此外,我們觀察到,圖像完成是一項(xiàng)比無(wú)條件生成更容易的任務(wù),因?yàn)樗心P投急憩F(xiàn)得更好。

圖6 變換器生成的無(wú)條件樣本。見(jiàn)4.2.1節(jié)
圖7 MNIST的digit補(bǔ)全。見(jiàn)4.2.1節(jié)
圖8 CIFAR-10上生成的無(wú)條件樣本,見(jiàn)4.2.2節(jié)
圖9 CIFAR-10上的圖像補(bǔ)全結(jié)果,見(jiàn)4.2.2節(jié)

源碼里這個(gè)地方會(huì)比較暈:linear_attention.py

答案在這里:????https://fast-transformers.github.io/attention/

別人讀這篇論文
https://zhuanlan.zhihu.com/p/157490738

https://mp.weixin.qq.com/s?__biz=MzIwMTc4ODE0Mw==&mid=2247506101&idx=1&sn=ba682426ace59910a6837a4cef2bf9cf&chksm=96ea0735a19d8e23615c7d58a73870d48eb149cb4cb474377ceea9d745efb6dd5aee9c6f126c#rd

https://kexue.fm/archives/7325


Pytorch torch.norm, torch.cosine_similarity 對(duì)向量或者張量計(jì)算Cosine相似度, 歐式距離

https://blog.csdn.net/dongfangxiaozi_/article/details/93882664

能否直接用torch.cosine_similarity替換你設(shè)計(jì)的那一套?不能。最終的目的是設(shè)計(jì)phi(x),使phi(x)phi(y)來(lái)近似f(x,y)。使用torch.cosine_similarity相當(dāng)于還是沒(méi)把Q和K拆開(kāi)。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容