LLM面面觀之MoE

1. 背景

根據(jù)本qiang~最新的趨勢(shì)觀察,基于MoE架構(gòu)的開源大模型越來越多,比如馬斯克的Grok-1(314B), Qwen1.5-MoE-A2.7B等,因此想探究一下MoE里面的部分細(xì)節(jié)。

此文是本qiang~針對(duì)大語言模型的MoE的整理,包括原理、流程及部分源碼。

2. MoE原理

MoE的流行源于”歐洲的OpenAI”

Mistral AI發(fā)布的論文及模型《Mixtral of Experts》,評(píng)測(cè)集上的效果吊打眾多開源模型,如Llama 2 70B和GPT3.5。

《Mixtral of Experts》基礎(chǔ)模型使用的是Mistral

AI自研的Mistral 7B,該模型的特點(diǎn)包括:滑窗注意力(Sliding Window Aattention), 滾動(dòng)緩沖區(qū)緩存(Rolling

Buffer Cache)以及預(yù)填充-分塊(Pre-fill and Chunking),具體細(xì)節(jié)可以查閱文末的論文地址。

本文以《Mixtral of

Experts》為引子,探究MoE的相關(guān)細(xì)節(jié),MoE的原理如下圖所示:


圖2.1 MoE的原理

(1) Transformers架構(gòu)中的每一層中的FFN網(wǎng)絡(luò)均替換為了8個(gè)FFN(專家),且由一個(gè)網(wǎng)關(guān)路由(gate

router)進(jìn)行控制

(2) 針對(duì)每一個(gè)token,每一層的網(wǎng)關(guān)路由僅選擇其中的2個(gè)FFN(專家)來處理當(dāng)前狀態(tài)并進(jìn)行加權(quán)輸出

(3) 結(jié)果就是,每一個(gè)token訪問了47B參數(shù),但是在推理階段僅僅使用了13B的激活參數(shù)(即,只使用2個(gè)專家,凍結(jié)其他6個(gè)專家)。

(4) 與Dropout機(jī)制對(duì)比,Dropout讓部分神經(jīng)元失活,而MoE是讓部分專家失活。

3. 源碼

本qiang~研讀并嘗試執(zhí)行了Mistral官網(wǎng)的github推理代碼,該代碼框架非常適合新手,無他,只因其幾乎只是在torch上層做的封裝,很少引擎其他第三方庫,不像transformers,功能強(qiáng)大,但不適合新手研讀代碼…

為了普適性,下面的代碼截取了transformers框架中的代碼。

首先看下通用Transformers中FFN中的代碼模塊,代碼位置在transformers.models.mistral.modeling_mistral,主要流程是:

(1) 先經(jīng)過gate_proj和up_proj的2個(gè)[hidden_size,

intermediate_size]的線性轉(zhuǎn)換

(2) 使用激活函數(shù)對(duì)gate_proj進(jìn)行激活

(3) 二者的內(nèi)積再經(jīng)過down_proj線性轉(zhuǎn)換。


class MistralMLP(nn.Module):

??? def __init__(self,? config):

??????? super().__init__()

??????? self.config = config

??????? self.hidden_size =? config.hidden_size


? self.intermediate_size = config.intermediate_size

??????? self.gate_proj =? nn.Linear(self.hidden_size, self.intermediate_size, bias=False)

??????? self.up_proj =? nn.Linear(self.hidden_size, self.intermediate_size, bias=False)

??????? self.down_proj =? nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

??? ????self.act_fn = ACT2FN[config.hidden_act]


??? def forward(self, x):

??????? return? self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


再來看下MoE中的專家模塊,代碼位置在transformers.models.mixtral.modeling_mixtral,主要流程是:

(1) 首先經(jīng)過網(wǎng)關(guān)路由self.gate

(2) 然后選擇其中2個(gè)專家,并歸一化

(3) 之后遍歷每個(gè)專家網(wǎng)絡(luò),并按照expert_mask進(jìn)行篩選

(4) 如果expert_mask有值,則選擇指定部分的隱藏層進(jìn)行FFN操作,且輸出結(jié)果進(jìn)行加權(quán)

(5) 最后原地增加先前初始化的最終結(jié)果變量final_hidden_states


class MixtralSparseMoeBlock(nn.Module):


??? def __init__(self,? config):

??????? super().__init__()

??????? self.hidden_dim =? config.hidden_size

??????? self.ffn_dim =? config.intermediate_size

??????? self.num_experts =? config.num_local_experts

??????? self.top_k =? config.num_experts_per_tok


??????? # gating

??????? self.gate =? nn.Linear(self.hidden_dim, self.num_experts, bias=False)


??????? self.experts =? nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in? range(self.num_experts)])


??? def forward(self,? hidden_states: torch.Tensor) -> torch.Tensor:

??????? """? """

??????? batch_size,? sequence_length, hidden_dim = hidden_states.shape

??????? hidden_states =? hidden_states.view(-1, hidden_dim)

??????? # router_logits:? (batch * sequence_length, n_experts)

??????? router_logits =? self.gate(hidden_states)


??????? routing_weights =? F.softmax(router_logits, dim=1, dtype=torch.float)

??????? routing_weights,? selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)

??????? routing_weights /=? routing_weights.sum(dim=-1, keepdim=True)

??????? # we cast back to? the input dtype

??????? routing_weights =? routing_weights.to(hidden_states.dtype)


??????? final_hidden_states? = torch.zeros(

??????????? (batch_size *? sequence_length, hidden_dim), dtype=hidden_states.dtype,? device=hidden_states.device

??????? )


??????? # One hot encode the? selected experts to create an expert mask

??????? # this will be used? to easily index which expert is going to be sollicitated

??????? expert_mask =? torch.nn.functional.one_hot(selected_experts,? num_classes=self.num_experts).permute(2, 1, 0)


??????? # Loop over all? available experts in the model and perform the computation on each expert

??????? for expert_idx in? range(self.num_experts):

??????????? expert_layer =? self.experts[expert_idx]

??????????? idx, top_x =? torch.where(expert_mask[expert_idx])


??????????? if? top_x.shape[0] == 0:

??????????????? continue


??????????? # in torch it is? faster to index using lists than torch tensors

??????????? top_x_list =? top_x.tolist()

??????????? idx_list =? idx.tolist()


??????????? # Index the? correct hidden states and compute the expert hidden state for

??????????? # the current? expert. We need to make sure to multiply the output hidden

??????????? # states by? `routing_weights` on the corresponding tokens (top-1 and top-2)

??????????? current_state =? hidden_states[None, top_x_list].reshape(-1, hidden_dim)


? current_hidden_states = expert_layer(current_state) *? routing_weights[top_x_list, idx_list, None]


??????????? # However? `index_add_` only support torch tensors for indexing so we'll use

??????????? # the `top_x`? tensor here.


? final_hidden_states.index_add_(0, top_x,? current_hidden_states.to(hidden_states.dtype))

??????? final_hidden_states? = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)

???? ???return final_hidden_states, router_logits


其中MixtralBlockSparseTop2MLP代碼如下,可以看到和傳統(tǒng)MistralMLP內(nèi)容完全一致。


class MixtralBlockSparseTop2MLP(nn.Module):

??? def __init__(self,? config: MixtralConfig):

??????? super().__init__()

??????? self.ffn_dim =? config.intermediate_size

??????? self.hidden_dim =? config.hidden_size


??????? self.w1 =? nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)

??????? self.w2 =? nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)

??????? self.w3 = nn.Linear(self.hidden_dim,? self.ffn_dim, bias=False)


??????? self.act_fn =? ACT2FN[config.hidden_act]


??? def forward(self,? hidden_states):


? current_hidden_states = self.act_fn(self.w1(hidden_states)) *? self.w3(hidden_states)

??????? current_hidden_states? = self.w2(current_hidden_states)

??????? return? current_hidden_states


4. MoE微調(diào)

由于MoE只是將每一層的FFN改變?yōu)榱嗣恳粚拥膅ate網(wǎng)關(guān)路由+8個(gè)FFN專家,且gate網(wǎng)關(guān)路由和8個(gè)專家內(nèi)部均為線性運(yùn)算,所以可以無縫地結(jié)合LoRA、QLoRA進(jìn)行指令微調(diào)。

可以參考開源項(xiàng)目:https://github.com/yangjianxin1/Firefly

5. 答疑解惑

(1) 問:MoE 8*7B的模型是56B參數(shù)?

答:MoE 8*7B的參數(shù)量是47B,而不是56B,原因是每一層除了8個(gè)專家網(wǎng)絡(luò)外,其他層均是復(fù)用的。

(2) 問:MoE的基礎(chǔ)模型是Mistral7B?

答:不是,MoE的模型架構(gòu)與Mistral

7B相同,但其中的FFN替換為了8個(gè)FFN,且MoE是基于多語言數(shù)據(jù)集預(yù)訓(xùn)練而來的。

(3) MoE的稀疏性(sparse)體現(xiàn)在哪里?

答:在訓(xùn)練和推理時(shí),同時(shí)只有兩個(gè)專家網(wǎng)絡(luò)會(huì)被激活,進(jìn)行前向計(jì)算,其它專家網(wǎng)絡(luò)處于失活狀態(tài)。

6. 總結(jié)

一句話足矣~

本文主要針對(duì)大語言模型的MoE,包括原理及部分源碼。

此外,建議大家可以針對(duì)源碼進(jìn)行運(yùn)行,關(guān)于源碼,歡迎大家一塊交流。

7. 參考

(1) Mistral 7B:https://arxiv.org/pdf/2310.06825v1.pdf

(2) MoE:https://arxiv.org/pdf/2401.04088v1.pdf

(3) MoE開源指令微調(diào)框架Firefly:https://github.com/yangjianxin1/Firefly

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容