注意力機(jī)制解決的問題:傳統(tǒng)序列處理模型如RNN和LSTM,捕捉長距離依賴關(guān)系的難題。注意力機(jī)制允許模型在序列的不同位置之間建立直接聯(lián)系,有效捕捉遠(yuǎn)距離依賴關(guān)系。
為了減少推理過程中KV Cache占用的顯存,GQA和MQA通過head之間共享KV實(shí)現(xiàn),這是一種犧牲性能對存儲(chǔ)空間妥協(xié)的方案,而MLA通過對KV對做低秩聯(lián)合壓縮來減少推理中的KV緩存,目標(biāo)是減少kv cache存儲(chǔ)量的同時(shí),保存模型的效果。
具體做法是,對于每個(gè)token,先通過一個(gè)低秩矩陣將KV聯(lián)合壓縮到一個(gè)低維向量中,然后通過兩個(gè)升維矩陣
,
解壓縮回高維,后續(xù)進(jìn)行普通的多頭注意力計(jì)算,這樣每次只需要存這個(gè)低維向量。
這樣做有個(gè)問題,就是壓縮和解壓操作使計(jì)算量增加了,而實(shí)際計(jì)算中,通過“矩陣吸收”操作,也就是矩陣運(yùn)算過程中的結(jié)合律使多個(gè)矩陣合并,從而減少計(jì)算量。
具體計(jì)算過程如下(對做相同壓縮操作,于是也有了
和
):
如上所示,計(jì)算過程中,由于矩陣乘法結(jié)合律,,
合并成一個(gè)矩陣
,同理,
,
合并成
。
對比普通MHA計(jì)算公式:
可知,兩種注意力機(jī)制計(jì)算量相同,沒有引入額外計(jì)算量,而緩存從兩個(gè)高維,
變成了一個(gè)低維
接下來是位置編碼RoPE的處理,MHA中,RoPE可以通過對,
向量乘以一個(gè)位置相關(guān)的變換矩陣
(
為當(dāng)前token所處的位置)。然而,在MLA中,如果做相同的處理將會(huì)如下所示:
由于不是一個(gè)固定的矩陣,無法實(shí)現(xiàn)矩陣吸收來減少計(jì)算量。對于這個(gè)問題,deepseek的做法是將參與注意力計(jì)算的
,
分成兩部分,一部分進(jìn)行矩陣吸收操作,不帶位置信息,一部分進(jìn)行位置信息計(jì)算。
對于,基于潛在向量
通過矩陣
變換為低維向量后進(jìn)行RoPE變換得到
;對于
,直接將輸入
也通過一個(gè)矩陣
變換后做RoPE變換得到
,其中
按照MQA的處理方式,各個(gè)head之間共享,既減少了顯存調(diào)用又保證了位置編碼的全局一致。然后將
,
拼接到前面計(jì)算得到的
,
向量后面,得到最終用于計(jì)算注意力的
,
。
這樣計(jì)算點(diǎn)積時(shí)如下,其中,
表示token,
表示head:
這樣不包含位置編碼的部分就可以進(jìn)行矩陣吸收的處理,每個(gè)head緩存一個(gè)
;后一項(xiàng)按MQA的方式計(jì)算,所有head只需緩存一個(gè)共享的
。
學(xué)習(xí)參考資料:zhihu-冷面爸