【文章來(lái)自】
github-SwiGLU激活函數(shù)
SwiGLU激活函數(shù)
Gated Linear Units由兩個(gè)線性映射的元素乘組成,其中一個(gè)過(guò)一道sigmoid函數(shù)。用其他激活函數(shù)替代sigmoid,就形成了GLU的變種,SwiGLU就是其中一個(gè)。SwiGLU替換Transformer FFN中的ReLU,可提升Transformer的訓(xùn)練效果。
GLU
Gated Linear Units(GLUs)是一種在神經(jīng)網(wǎng)絡(luò)中使用的激活機(jī)制,旨在提高模型的表達(dá)能力和計(jì)算效率。GLUs是由Yann Dauphin等人于2016年在論文《Language Modeling with Gated Convolutional Networks》中提出的,最初用于提升語(yǔ)言模型中的卷積神經(jīng)網(wǎng)絡(luò)(CNN)的性能。
GLUs利用了門控機(jī)制(gating mechanism),類似于LSTM中的門控單元,通過(guò)引入額外的門控結(jié)構(gòu)來(lái)控制信息的傳遞。具體的公式如下:
其中:
-
是輸入張量。
-
和
是權(quán)重矩陣。
-
和
是偏置向量。
-
是sigmoid激活函數(shù)。
-
表示元素乘。
公式的前半部分是一個(gè)線性變換,后半部分是一個(gè)門控信號(hào),利用sigmoid函數(shù)將其范圍壓縮到0,1之間。最終的輸出是這兩個(gè)部分的逐元素乘積。
優(yōu)點(diǎn):
- 提高模型表達(dá)能力:GLUs通過(guò)引入門控機(jī)制,允許網(wǎng)絡(luò)選擇性地傳遞信息,增強(qiáng)了模型的非線性表達(dá)能力。
- 緩解梯度消失問(wèn)題:由于使用了sigmoid函數(shù),GLUs可以有效緩解梯度消失問(wèn)題,使得訓(xùn)練更加穩(wěn)定。
- 計(jì)算效率高:與一些復(fù)雜的激活函數(shù)相比,GLUs的計(jì)算開(kāi)銷較低,適合大規(guī)模神經(jīng)網(wǎng)絡(luò)的訓(xùn)練。
實(shí)現(xiàn):
import torch
import torch.nn as nn
class GLU(nn.Module):
def __init__(self, input_dim, output_dim):
super(GLU, self).__init__()
self.linear = nn.Linear(input_dim, output_dim)
self.gate = nn.Linear(input_dim, output_dim)
def forward(self, x):
return self.linear(x) * torch.sigmoid(self.gate(x))
# 示例使用
input_dim = 10
output_dim = 20
glu = GLU(input_dim, output_dim)
x = torch.randn(5, input_dim)
output = glu(x)
print(output)
Swish
Swish激活函數(shù)是一種由谷歌研究團(tuán)隊(duì)提出的激活函數(shù),具有以下形式:
Swish 激活函數(shù)的特點(diǎn)是它的輸出不僅僅依賴于輸入 本身,還受
的影響,使其在一定程度上保留了輸入的信息,同時(shí)引入了非線性變換。這種特性使得 Swish 函數(shù)在深度神經(jīng)網(wǎng)絡(luò)中表現(xiàn)出色,尤其是在處理梯度傳播和優(yōu)化問(wèn)題時(shí)。
Swish 激活函數(shù)的優(yōu)點(diǎn):
- 平滑性:Swish 是一個(gè)平滑的激活函數(shù),沒(méi)有 ReLU 激活函數(shù)的硬拐點(diǎn),這有助于梯度的穩(wěn)定傳播。
- 自門控機(jī)制:Swish 函數(shù)通過(guò)輸入本身進(jìn)行自門控,允許小負(fù)值通過(guò),這對(duì)于梯度的流動(dòng)有利,尤其是深層網(wǎng)絡(luò)。
- 實(shí)驗(yàn)表現(xiàn):在一些實(shí)驗(yàn)中,Swish 激活函數(shù)在圖像分類和機(jī)器翻譯等任務(wù)上顯示出優(yōu)于 ReLU 和其他傳統(tǒng)激活函數(shù)的性能。
[圖片上傳失敗...(image-ef2212-1744340658672)]
Transformer中的SwiGLU
在Transformer中, attention層之后,還要過(guò)一層position-wise feed-forward networks。
其中,是每個(gè)位置的向量。這個(gè)公式的含義是
經(jīng)過(guò)一個(gè)線性映射,再經(jīng)過(guò)ReLU,最后再經(jīng)過(guò)另一個(gè)線性映射。
SwiGLU替換ReLU之后,F(xiàn)FN則變成:
其中, 意思是
。公式中省略偏置項(xiàng)。
代碼中的公式應(yīng)該為:
# 原始FFN
import torch
import torch.nn as nn
class PositionWiseFeedForward(nn.Module):
def __init__(self, d_model, d_ffn):
super(PositionWiseFeedForward, self).__init__()
self.linear1 = nn.Linear(d_model, d_ffn)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(d_ffn, d_model)
def forward(self, x):
return self.linear2(self.relu(self.linear1(x)))
#SwiGLU FFN
class SwiGLUFeedForward(nn.Module):
def __init__(self, d_model, d_ffn):
super(SwiGLUFeedForward, self).__init__()
self.linear1 = nn.Linear(d_model, d_ffn)
self.linear2 = nn.Linear(d_model, d_ffn)
self.linear3 = nn.Linear(d_ffn, d_model)
def forward(self, x):
swish = x * torch.sigmoid(self.linear1(x))
v = self.linear2(x)
x = swish * v
return self.linear3(x)
參考資料