淺析DeepSeek多頭潛在注意力機制(MLA)

淺析DeepSeek多頭潛在注意力機制(MLA)

背景:DeepSeek在無損模型效果的同時大幅降低了大模型的訓(xùn)練以及推理成本,引起業(yè)界廣范關(guān)注。所涉及的優(yōu)化包括不限于:使用低精度計算,知識蒸餾和稀疏計算,軟硬件協(xié)同優(yōu)化,模型架構(gòu)優(yōu)化【Mutil-head Latent Attention 以及DeepSeekMoE】,分布式訓(xùn)練和優(yōu)化通信策略,自我學(xué)習(xí)和高效利用GPU資源等。本文將重點圍繞從DeepSeek-V2開始的注意力優(yōu)化,潛在注意力機制Mutil-head Latent Attention進行淺析。MLA是通過顯著減少大模型推理過程中的kv-緩存來實現(xiàn)推理加速的。所以本文將從(1)大模型推理過程中為什么需要KV緩存(2)常用的幾種優(yōu)化推理過程中KV緩存的方法 (3)MLA是什么以及為什么能減少KV緩存的數(shù)學(xué)原理。 三部分展開。

為什么需要KV緩存:

大模型的常用架構(gòu):

常見大模型的基本架構(gòu)是基于transformer的,僅僅使用transformer的decoder部分,而且將transformer的decoder中的Mutil-HEAD-Self-Attention去掉,僅保留MASK-Mutil-Head-Self-Attention。如下圖所示(左傳統(tǒng)tranformer-右常用大模型架構(gòu))。


所以整個訓(xùn)練以及推理過程中,Mask-Mutil-Head-SelfAttention作為Token學(xué)習(xí)上下文依賴的部分是比較重要的。

大模型的整體推理過程:

? 當(dāng)我們輸入一段問題(提示詞)給大模型后,大模型是按下述方式完成回答。

? 1. Prefill:提示詞的embedding prefill&&推理出第一個詞。

? ? ? (1) 提示詞lookup table 得到每一個詞的embedding && 計算Position Embedding。

? ? ? (2)計算提示詞的每一個詞的Q,K,V 矩陣(并將KV緩存起來)。 并得到self-attention矩陣 softmax(Q*transpose(K)),如果提示詞的長度是N,attention矩陣是N*N的。多個頭的話 就是HEAD_NUM *N*N. 【實際上大模型會pdding成輸入的max_len比如2048緯度確保長度統(tǒng)一,】HEAD_NUM*MAX_LEN*MAX_LEN.

(3)然后根據(jù)Attention矩陣的最后一行以及每一個詞的V得到 最后一個詞的MHA階段的輸出,將最后一行該輸出放入FFN,就會得到1個輸出,這個詞就是回答階段的第一個詞。


2. Decoding.第一個詞

(1)這個詞looking up table? && 計算 PE。

(2)得到該詞的Q,K,V矩陣。計算該詞的attention矩陣需要所有輸入詞的K,V矩陣(并加入緩存)。來得到一個1*(N+1)的attention權(quán)重。根據(jù)這個權(quán)重和所有的V【之前提示詞的以及這個詞的】得到這個詞的MHA的輸出。如果我們不緩存所有的KV,那么相當(dāng)推理每一個詞的時候都將之前的KV按照 1里面的(1)(2)步驟重復(fù)計算。

(3)將MHA的輸出放入FFN,得到1個輸出這個詞就是大模型回答的第二個答案。

3.Decoding 直到結(jié)束。

將第二個token 按照步驟2的操作 執(zhí)行一遍會得到第三個詞。直到出現(xiàn)end_token或者到達MAX_LEN終止。


KVCache-必要性演示:

下面以一個簡化版本的LLM推理過程來演示KV-cache必要性。第一個詞,有沒有kv-cache計算量是一樣的。對于當(dāng)前query token詞來說attention矩陣就是1*1.

第二個詞,如果沒有kv cache 需要重新計算key Token1以及Value Token1?!疽簿褪怯衚v cache的紫色部分】


第三個詞,如果沒有kv cache,需要重新計算key Token1以及2, Value Token1&&2.


重復(fù)計算的次數(shù)隨著Token位置的延后 越累積越多。所以KVcache 是必要。實際上將MHA 部分的復(fù)雜度從n*2降低到n。

訓(xùn)練過程中為什么不需要KV-Cache:

(1)訓(xùn)練過程中是batch 訓(xùn)練,不像推理那樣一個一個token都需要依賴上文信息,每一個batch 實際上是一個完整句子,拆開的多個樣本,也就是這個batch的最后一個樣本是見到了這個句子全部token的。所有token的注意力矩陣一下就能全部算出來。舉一個例子。上圖的圖三,maxlen就是3個詞,那么一次計算就得到了這個batch的3*3的mask-attention矩陣。KVcache沒有意義。

常用的幾種優(yōu)化推理過程中KV緩存的方法 :

通過上面的分析可以看出KV-cache就是一個空間換時間的過程,但是GPU相對CPU的單核緩存較小問題是不允許cache過大的。所以有很多研究圍繞著縮小kvcache 展開。


普通MHA kv_cache 存儲數(shù):

按上文所述。來看一下普通的MHA需要存儲多少kv_cache.

按GPT中的一些數(shù)字大小。

句子的最大長度 L??2048

一個詞embeding的size: DIM? ?12288

注意力的頭數(shù):Head_Num? ?96

每個頭的注意力維數(shù)? ?Head_DIM? ? ?128. [128*96=12288]

解碼器層數(shù)? 96

所以一層解碼器需要的存儲kv-cache是 2*Head_Num*Head_DIM*L.

因為需要存儲k&&V,所以是2,不同頭的參數(shù)不同,所以需要HEAD_NUM, HEAD_DIM是每個K,V的向量維度。L是句子的最大長度。所以解碼階段最大的存儲就是2*Head_Num*Head_DIM*L。

所以一層解碼器需要的存儲kv-cache是 2*Head_Num*Head_DIM*L.

因為需要存儲k&&V,所以是2,不同頭的參數(shù)不同,所以需要HEAD_NUM, HEAD_DIM是每個K,V的向量維度。L是句子的最大長度。所以解碼階段最大的存儲就是2*Head_Num*Head_DIM*L。

常用的幾種優(yōu)化KV的方法:

什么是MLA:

MLA旨在進一步縮小KV緩存的大小,同時在性能上超越之前提到的注意力機制(包括MHA)。它通過將KV緩存壓縮到低維潛在空間,成功將緩存大小減小了90%+,

MLA不會像傳統(tǒng)方式那樣在每個頭計算和存儲每個令牌的鍵和值,而是使用下投影矩陣DownLinear把它們壓縮成潛在向量C。想達到一個C解決所有頭的KV_Cache問題,在推理時候,再通過一個UPLinear 升維變相達到kv_cache的目的。這就是MLA的核心思想 - 在保持模型能力的同時,通過降維來減少內(nèi)存占用


MLA的數(shù)學(xué)原理:

數(shù)學(xué)原理上:首先創(chuàng)造一個維度為dc的latent向量,這個向量維度遠小于Head_Num*Head_DIM。對于第S個head,我們建立參數(shù)矩陣,注意這里的參數(shù)矩陣,和傳統(tǒng)的MHA比,多了一個Wc ,(傳統(tǒng)的只有Wq,Wk,Wv)同時Wq的維度與傳統(tǒng)的MHA一樣是d*dk 緯度d也就是前面我們常說的DIM(一個詞的embedding size),dk 就是我們前面說的每一個頭的維度Head_DIM. Wk,Wv 緯度已經(jīng)變了,從d*dk 變成了dc *dk . 現(xiàn)在的Wc*Wv 才是老的Wv。Wc*Wk 才是老的Wk


首先,將輸入的詞xi投影到一個低維空間得到ci。xi的維度是1*d。


然后利用ci和其他參數(shù)矩陣得到輸入詞xi在第s個頭的q,k,v矩陣:

這里的ki = xi *老的Wk= xi *Wc*Wk = Ci*Wk ,同樣vi也是一樣的推導(dǎo)過程。


對計算句子中位置t的詞qt的attention矩陣。 只需要計算qt*transpose(ki), i<=t即可,這里利用乘法交換律


可以看出對qt 這個詞來說,在第S個頭上,他的attention矩陣的對i這個詞的值只需要自己的輸入xt*(Wq*transpose(Wk))*tranpose(ci),那么這里我們只需要存儲若干個ci就行了,他就是每一個詞對應(yīng)的潛在向量。 對不同的頭,ci都是一樣的。而我們知道老的MHA中,每次qt的得到也是需要xt*Wq,這里只不過是變成了xt*Wq*transpose(Wk),可以說運算量幾乎沒增加。但是所需要存儲的cache數(shù)大大減少。僅僅需要。

L*dc,直接和MQA一個量級,但是效果確和MHA一樣。DeepSeek-v2中 dc是512維度,其存儲和group=2的MGA是一樣的。

MLA與ROPE的融合:

什么是ROPE:rope是這樣一個矩陣,與q,k相乘后,實現(xiàn)位置編碼的功效。m就是位置【0,1,2,3。。。】詞是第幾個位置就是幾,theta_i=10000^{-2i/ d}, d就是128, i就是0 到d/2-1.


但當(dāng)加入RoPE后,這個合并就無法實現(xiàn)了。因為RoPE是一個與位置相關(guān)的 dk×dk 分塊對角矩陣 Rm,它滿足 Rm*tranpose(Rn)=Rm?n。加入RoPE后的注意力計算變?yōu)?


多了Rt-i這個是與位置i相關(guān)的使得MLA的kv存儲失效了。

DeepSeek采用了一個混合方案 - 在每個Attention Head的Q、K中新增 dr 【實戰(zhàn)中dr=dk/2=64】個維度用于RoPE,其中K的新增維度在所有Head間共享:(同時v3開始 q也被壓縮到低維度



帶ROPE的QKV如下圖:


對第s個頭的token_i計算attention 矩陣的第i個位置 qi與ki的轉(zhuǎn)置 相乘。因為q,k這里都是1行多列的向量【Qpart1,Qpart2】*tranpose(【Kpart1,Kpart2】)= Qpart1*tranpose(Kpart1) + Qpart2*tranpose(Kpart2)

實際上

這部分在上面已經(jīng)推導(dǎo)過,只緩存ci即可。而Qpart2*tranpose(kPart2)WkrRI 緩存起來,這樣R部分就沒有任何和位置計算相關(guān)的地方在計算attention的過程中。

整個的MLA的流程如下圖所示,對于單一token,只需要存儲latent向量C,維度512.以及WkrRi 也就是圖中的KtR 緯度64,(【512+64】/128 = 4.5)。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容