35、AFT模塊
論文《An Attention Free Transformer》
1、作用
注意力自由變換器(AFT)旨在通過去除傳統(tǒng)Transformer中的點積自注意力機制,提供一種更高效的變換器模型。它特別適用于需要高計算效率和較低內(nèi)存消耗的應(yīng)用場景,如移動設(shè)備和邊緣計算。
2、機制
AFT通過直接對輸入特征進行變換來實現(xiàn)序列間的關(guān)聯(lián),不再需要復雜的自注意力計算。它使用一種簡單的基于位置的加權(quán)策略,通過這種方式,每個輸出元素是輸入元素的加權(quán)和,權(quán)重由元素的相對位置決定。這種方法極大地降低了模型的復雜性和運行時內(nèi)存需求。
3、獨特優(yōu)勢
1、高效性:AFT由于避免了昂貴的自注意力計算,因此在執(zhí)行速度和計算效率上有明顯優(yōu)勢。
2、簡化模型結(jié)構(gòu):通過消除自注意力機制,AFT簡化了模型結(jié)構(gòu),使得模型更加輕量化,易于實現(xiàn)和部署。
3、適應(yīng)性強:AFT的結(jié)構(gòu)使其更容易適應(yīng)于不同的任務(wù)和數(shù)據(jù)集,具有良好的泛化能力。
4、資源占用低:對于資源受限的環(huán)境,如移動設(shè)備和邊緣計算設(shè)備,AFT提供了一種實用的解決方案,能夠在保持較高性能的同時,降低資源消耗。
4、代碼
import numpy as np
import torch
from torch import nn
from torch.nn import init
class AFT_FULL(nn.Module):
# 初始化AFT_FULL模塊
def __init__(self, d_model, n=49, simple=False):
super(AFT_FULL, self).__init__()
# 定義QKV三個線性變換層
self.fc_q = nn.Linear(d_model, d_model)
self.fc_k = nn.Linear(d_model, d_model)
self.fc_v = nn.Linear(d_model, d_model)
# 根據(jù)simple參數(shù)決定位置偏置的初始化方式
if (simple):
self.position_biases = torch.zeros((n, n)) # 簡單模式下為零矩陣
else:
self.position_biases = nn.Parameter(torch.ones((n, n))) # 非簡單模式下為可學習的參數(shù)
self.d_model = d_model
self.n = n # 輸入序列的長度
self.sigmoid = nn.Sigmoid() # 使用Sigmoid函數(shù)
self.init_weights() # 初始化模型權(quán)重
def init_weights(self):
# 對模塊中的參數(shù)進行初始化
for m in self.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
init.normal_(m.weight, std=0.001)
if m.bias is not None:
init.constant_(m.bias, 0)
def forward(self, input):
bs, n, dim = input.shape # 輸入的批大小、序列長度和特征維度
# 通過QKV變換生成查詢、鍵和值
q = self.fc_q(input) # bs,n,dim
k = self.fc_k(input).view(1, bs, n, dim) # 1,bs,n,dim,為了后續(xù)運算方便
v = self.fc_v(input).view(1, bs, n, dim) # 1,bs,n,dim
# 使用位置偏置和鍵值對進行加權(quán)求和
numerator = torch.sum(torch.exp(k + self.position_biases.view(n, 1, -1, 1)) * v, dim=2) # n,bs,dim
denominator = torch.sum(torch.exp(k + self.position_biases.view(n, 1, -1, 1)), dim=2) # n,bs,dim
# 計算加權(quán)求和的結(jié)果,并通過sigmoid函數(shù)調(diào)制查詢向量
out = (numerator / denominator) # n,bs,dim
out = self.sigmoid(q) * (out.permute(1, 0, 2)) # bs,n,dim,最后將結(jié)果重新排列
return out
# 示例使用
if __name__ == '__main__':
block = AFT_FULL(d_model=512, n=64).cuda()
input = torch.rand(64, 64, 512).cuda()
output = block(input)
print(output.shape) # 打印輸出形狀