大模型算法面試筆記——多頭潛在注意力(MLA)

注意力機(jī)制解決的問題:傳統(tǒng)序列處理模型如RNN和LSTM,捕捉長距離依賴關(guān)系的難題。注意力機(jī)制允許模型在序列的不同位置之間建立直接聯(lián)系,有效捕捉遠(yuǎn)距離依賴關(guān)系。

為了減少推理過程中KV Cache占用的顯存,GQA和MQA通過head之間共享KV實(shí)現(xiàn),這是一種犧牲性能對存儲(chǔ)空間妥協(xié)的方案,而MLA通過對KV對做低秩聯(lián)合壓縮來減少推理中的KV緩存,目標(biāo)是減少kv cache存儲(chǔ)量的同時(shí),保存模型的效果。

具體做法是,對于每個(gè)token,先通過一個(gè)低秩矩陣將KV聯(lián)合壓縮到一個(gè)低維向量c^{KV}中,然后通過兩個(gè)升維矩陣W^{UK},W^{UV}解壓縮回高維,后續(xù)進(jìn)行普通的多頭注意力計(jì)算,這樣每次只需要存這個(gè)低維向量。

這樣做有個(gè)問題,就是壓縮和解壓操作使計(jì)算量增加了,而實(shí)際計(jì)算中,通過“矩陣吸收”操作,也就是矩陣運(yùn)算過程中的結(jié)合律使多個(gè)矩陣合并,從而減少計(jì)算量。
具體計(jì)算過程如下(對q做相同壓縮操作,于是也有了c^QW^{UQ}):

\operatorname{attention}=\operatorname{softmax}(\frac{qk^T}{\sqrtu0z1t8os})vW^O
=\operatorname{softmax}(\frac{c^QW^{UQ}(c^{KV}W^{UK})^T}{\sqrtu0z1t8os})c^{KV}W^{UV}W^O
=\operatorname{softmax}(\frac{c^Q(W^{UQ}(W^{UK})^T)(c^{KV})^T}{\sqrtu0z1t8os})c^{KV}(W^{UV}W^O)
=\operatorname{softmax}(\frac{c^QW^{UQUK}(c^{KV})^T}{\sqrtu0z1t8os})c^{KV}W^{UVO}

如上所示,計(jì)算過程中,由于矩陣乘法結(jié)合律,W^{UQ}W^{UK}合并成一個(gè)矩陣W^{UQUK},同理,W^{UV},W^O合并成W^{UVO}。

對比普通MHA計(jì)算公式:

\operatorname{attention}=\operatorname{softmax}(\frac{qk^T}{\sqrtu0z1t8os})vW^O
=\operatorname{softmax}(\frac{h_tW^QK^T}{\sqrtu0z1t8os})vW^O

可知,兩種注意力機(jī)制計(jì)算量相同,沒有引入額外計(jì)算量,而緩存從兩個(gè)高維K,V變成了一個(gè)低維c^{KV}

接下來是位置編碼RoPE的處理,MHA中,RoPE可以通過對q,k向量乘以一個(gè)位置相關(guān)的變換矩陣R_i(i為當(dāng)前token所處的位置)。然而,在MLA中,如果做相同的處理將會(huì)如下所示:

q_iR_i(k_jR_j)^T=c^QW^{UQ}R_i(c^{KV}_jW^{UK}R_j)^T
=c^QW^{UQ}R_iR^T_j(W^{UK})^T(c_j^{KV})^T

由于R_i不是一個(gè)固定的矩陣,無法實(shí)現(xiàn)矩陣吸收來減少計(jì)算量。對于這個(gè)問題,deepseek的做法是將參與注意力計(jì)算的qv分成兩部分,一部分進(jìn)行矩陣吸收操作,不帶位置信息,一部分進(jìn)行位置信息計(jì)算。

對于q,基于潛在向量c^Q通過矩陣W^{QR}變換為低維向量后進(jìn)行RoPE變換得到q^R;對于k,直接將輸入h_t也通過一個(gè)矩陣W^{KR}變換后做RoPE變換得到k^R,其中k^R按照MQA的處理方式,各個(gè)head之間共享,既減少了顯存調(diào)用又保證了位置編碼的全局一致。然后將q_R,k^R拼接到前面計(jì)算得到的q,k向量后面,得到最終用于計(jì)算注意力的q=[q^C;q^R],k=[k^C;k^R]
這樣計(jì)算點(diǎn)積時(shí)如下,其中t,j表示token,i表示head:

q_{t,i}k_{j,i}^T=[q_{t,i}^C;q_{t,i}^R]\times[k_{j,i}^C;k_t^R]=q^C_{t,i}k_{j,i}^C+q_{t,i}^Rk_t^R

這樣不包含位置編碼的部分q^C_{t,i}k_{j,i}^C就可以進(jìn)行矩陣吸收的處理,每個(gè)head緩存一個(gè)c_{t,i}^{KV};后一項(xiàng)按MQA的方式計(jì)算,所有head只需緩存一個(gè)共享的k^R_t。

學(xué)習(xí)參考資料:zhihu-冷面爸

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容