2025-04-11-SwiGLU激活函數(shù)

【文章來(lái)自】
github-SwiGLU激活函數(shù)

SwiGLU激活函數(shù)

Gated Linear Units由兩個(gè)線性映射的元素乘組成,其中一個(gè)過(guò)一道sigmoid函數(shù)。用其他激活函數(shù)替代sigmoid,就形成了GLU的變種,SwiGLU就是其中一個(gè)。SwiGLU替換Transformer FFN中的ReLU,可提升Transformer的訓(xùn)練效果。

GLU

Gated Linear Units(GLUs)是一種在神經(jīng)網(wǎng)絡(luò)中使用的激活機(jī)制,旨在提高模型的表達(dá)能力和計(jì)算效率。GLUs是由Yann Dauphin等人于2016年在論文《Language Modeling with Gated Convolutional Networks》中提出的,最初用于提升語(yǔ)言模型中的卷積神經(jīng)網(wǎng)絡(luò)(CNN)的性能。

GLUs利用了門控機(jī)制(gating mechanism),類似于LSTM中的門控單元,通過(guò)引入額外的門控結(jié)構(gòu)來(lái)控制信息的傳遞。具體的公式如下:

y=(X W+b) \otimes \sigma(X V+c)

其中:

  • X 是輸入張量。
  • WV 是權(quán)重矩陣。
  • bc 是偏置向量。
  • \sigma 是sigmoid激活函數(shù)。
  • \otimes 表示元素乘。

公式的前半部分是一個(gè)線性變換,后半部分是一個(gè)門控信號(hào),利用sigmoid函數(shù)將其范圍壓縮到0,1之間。最終的輸出是這兩個(gè)部分的逐元素乘積。

優(yōu)點(diǎn):

  1. 提高模型表達(dá)能力:GLUs通過(guò)引入門控機(jī)制,允許網(wǎng)絡(luò)選擇性地傳遞信息,增強(qiáng)了模型的非線性表達(dá)能力。
  2. 緩解梯度消失問(wèn)題:由于使用了sigmoid函數(shù),GLUs可以有效緩解梯度消失問(wèn)題,使得訓(xùn)練更加穩(wěn)定。
  3. 計(jì)算效率高:與一些復(fù)雜的激活函數(shù)相比,GLUs的計(jì)算開(kāi)銷較低,適合大規(guī)模神經(jīng)網(wǎng)絡(luò)的訓(xùn)練。

實(shí)現(xiàn):

import torch
import torch.nn as nn

class GLU(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(GLU, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.gate = nn.Linear(input_dim, output_dim)
        
    def forward(self, x):
        return self.linear(x) * torch.sigmoid(self.gate(x))

# 示例使用
input_dim = 10
output_dim = 20
glu = GLU(input_dim, output_dim)
x = torch.randn(5, input_dim)
output = glu(x)
print(output)

Swish

Swish激活函數(shù)是一種由谷歌研究團(tuán)隊(duì)提出的激活函數(shù),具有以下形式:

\operatorname{Swish}_{\beta}(x)=x \cdot \operatorname{sigmoid}(\beta x)

Swish 激活函數(shù)的特點(diǎn)是它的輸出不僅僅依賴于輸入 x 本身,還受\sigma(x) 的影響,使其在一定程度上保留了輸入的信息,同時(shí)引入了非線性變換。這種特性使得 Swish 函數(shù)在深度神經(jīng)網(wǎng)絡(luò)中表現(xiàn)出色,尤其是在處理梯度傳播和優(yōu)化問(wèn)題時(shí)。

Swish 激活函數(shù)的優(yōu)點(diǎn):

  • 平滑性:Swish 是一個(gè)平滑的激活函數(shù),沒(méi)有 ReLU 激活函數(shù)的硬拐點(diǎn),這有助于梯度的穩(wěn)定傳播。
  • 自門控機(jī)制:Swish 函數(shù)通過(guò)輸入本身進(jìn)行自門控,允許小負(fù)值通過(guò),這對(duì)于梯度的流動(dòng)有利,尤其是深層網(wǎng)絡(luò)。
  • 實(shí)驗(yàn)表現(xiàn):在一些實(shí)驗(yàn)中,Swish 激活函數(shù)在圖像分類和機(jī)器翻譯等任務(wù)上顯示出優(yōu)于 ReLU 和其他傳統(tǒng)激活函數(shù)的性能。

[圖片上傳失敗...(image-ef2212-1744340658672)]

Transformer中的SwiGLU

在Transformer中, attention層之后,還要過(guò)一層position-wise feed-forward networks。
\operatorname{FFN}(x)=\max \left(0, x W_1+b_1\right) W_2+b_2
其中,x是每個(gè)位置的向量。這個(gè)公式的含義是x經(jīng)過(guò)一個(gè)線性映射,再經(jīng)過(guò)ReLU,最后再經(jīng)過(guò)另一個(gè)線性映射。

SwiGLU替換ReLU之后,F(xiàn)FN則變成:
\operatorname{FFN}_{\text {SwiGLU }}\left(x\right)=\left(\operatorname{Swish}_1(x W) \otimes x V\right) W_2

其中, \operatorname{Swish}_1 意思是 \beta=1。公式中省略偏置項(xiàng)。

代碼中的公式應(yīng)該為:

{\text {SwiGLU }}\left(x\right)= \operatorname{Swish}_1(x W) \otimes x V

{\text {SwiGLU }}\left(x\right)= x \cdot \operatorname{sigmoid}(\beta x W) \otimes x V

# 原始FFN
import torch
import torch.nn as nn

class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ffn):
        super(PositionWiseFeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(d_ffn, d_model)
    
    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))

#SwiGLU FFN
class SwiGLUFeedForward(nn.Module):
    def __init__(self, d_model, d_ffn):
        super(SwiGLUFeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.linear2 = nn.Linear(d_model, d_ffn)
        self.linear3 = nn.Linear(d_ffn, d_model)
    
    def forward(self, x):
        swish = x * torch.sigmoid(self.linear1(x))
        v = self.linear2(x)
        x = swish * v
        return self.linear3(x)

參考資料

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容