深度學習模塊36-AFT模塊

35、AFT模塊

論文《An Attention Free Transformer》

1、作用

注意力自由變換器(AFT)旨在通過去除傳統(tǒng)Transformer中的點積自注意力機制,提供一種更高效的變換器模型。它特別適用于需要高計算效率和較低內(nèi)存消耗的應(yīng)用場景,如移動設(shè)備和邊緣計算。

2、機制

AFT通過直接對輸入特征進行變換來實現(xiàn)序列間的關(guān)聯(lián),不再需要復雜的自注意力計算。它使用一種簡單的基于位置的加權(quán)策略,通過這種方式,每個輸出元素是輸入元素的加權(quán)和,權(quán)重由元素的相對位置決定。這種方法極大地降低了模型的復雜性和運行時內(nèi)存需求。

3、獨特優(yōu)勢

1、高效性:AFT由于避免了昂貴的自注意力計算,因此在執(zhí)行速度和計算效率上有明顯優(yōu)勢。

2、簡化模型結(jié)構(gòu):通過消除自注意力機制,AFT簡化了模型結(jié)構(gòu),使得模型更加輕量化,易于實現(xiàn)和部署。

3、適應(yīng)性強:AFT的結(jié)構(gòu)使其更容易適應(yīng)于不同的任務(wù)和數(shù)據(jù)集,具有良好的泛化能力。

4、資源占用低:對于資源受限的環(huán)境,如移動設(shè)備和邊緣計算設(shè)備,AFT提供了一種實用的解決方案,能夠在保持較高性能的同時,降低資源消耗。

4、代碼

import numpy as np
import torch
from torch import nn
from torch.nn import init

class AFT_FULL(nn.Module):
    # 初始化AFT_FULL模塊
    def __init__(self, d_model, n=49, simple=False):
        super(AFT_FULL, self).__init__()
        # 定義QKV三個線性變換層
        self.fc_q = nn.Linear(d_model, d_model)
        self.fc_k = nn.Linear(d_model, d_model)
        self.fc_v = nn.Linear(d_model, d_model)
        # 根據(jù)simple參數(shù)決定位置偏置的初始化方式
        if (simple):
            self.position_biases = torch.zeros((n, n))  # 簡單模式下為零矩陣
        else:
            self.position_biases = nn.Parameter(torch.ones((n, n)))  # 非簡單模式下為可學習的參數(shù)
        self.d_model = d_model
        self.n = n  # 輸入序列的長度
        self.sigmoid = nn.Sigmoid()  # 使用Sigmoid函數(shù)

        self.init_weights()  # 初始化模型權(quán)重

    def init_weights(self):
        # 對模塊中的參數(shù)進行初始化
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, input):
        bs, n, dim = input.shape  # 輸入的批大小、序列長度和特征維度

        # 通過QKV變換生成查詢、鍵和值
        q = self.fc_q(input)  # bs,n,dim
        k = self.fc_k(input).view(1, bs, n, dim)  # 1,bs,n,dim,為了后續(xù)運算方便
        v = self.fc_v(input).view(1, bs, n, dim)  # 1,bs,n,dim

        # 使用位置偏置和鍵值對進行加權(quán)求和
        numerator = torch.sum(torch.exp(k + self.position_biases.view(n, 1, -1, 1)) * v, dim=2)  # n,bs,dim
        denominator = torch.sum(torch.exp(k + self.position_biases.view(n, 1, -1, 1)), dim=2)  # n,bs,dim

        # 計算加權(quán)求和的結(jié)果,并通過sigmoid函數(shù)調(diào)制查詢向量
        out = (numerator / denominator)  # n,bs,dim
        out = self.sigmoid(q) * (out.permute(1, 0, 2))  # bs,n,dim,最后將結(jié)果重新排列

        return out

# 示例使用
if __name__ == '__main__':
    block = AFT_FULL(d_model=512, n=64).cuda()
    input = torch.rand(64, 64, 512).cuda()
    output = block(input)
    print(output.shape) # 打印輸出形狀

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容