深度學(xué)習(xí)模塊8-Axial_attention模塊

論文《AXIAL ATTENTION IN MULTIDIMENSIONAL TRANSFORMERS》

1、作用

Axial Attention 提出了一種用于圖像和其他作為高維張量組織的數(shù)據(jù)的自注意力基的自回歸模型。傳統(tǒng)的自回歸模型要么因高維數(shù)據(jù)而導(dǎo)致計(jì)算資源需求過(guò)大,要么為了減少資源需求而在分布表達(dá)性或?qū)崿F(xiàn)簡(jiǎn)便性方面做出妥協(xié)。Axial Transformers 設(shè)計(jì)旨在在保持?jǐn)?shù)據(jù)上聯(lián)合分布的完整表達(dá)性和易于使用標(biāo)準(zhǔn)深度學(xué)習(xí)框架實(shí)現(xiàn)的同時(shí),要求合理的內(nèi)存和計(jì)算資源,并在標(biāo)準(zhǔn)生成建?;鶞?zhǔn)上實(shí)現(xiàn)最先進(jìn)的結(jié)果。

2、機(jī)制

1、軸向注意力

與對(duì)張量元素的序列應(yīng)用標(biāo)準(zhǔn)自注意力不同,Axial Transformer 沿著張量的單個(gè)軸應(yīng)用注意力,稱為“軸向注意力”,而不是展平張量。這種操作在計(jì)算和內(nèi)存使用上比標(biāo)準(zhǔn)自注意力節(jié)省顯著,因?yàn)樗匀坏嘏c張量的多個(gè)維度對(duì)齊。

2、半并行結(jié)構(gòu)

Axial Transformer 的層結(jié)構(gòu)允許在解碼時(shí)并行計(jì)算絕大多數(shù)上下文,而無(wú)需引入任何獨(dú)立性假設(shè),這對(duì)于即使是非常大的Axial Transformer也是廣泛適用的。

3、獨(dú)特優(yōu)勢(shì)

1、計(jì)算效率

Axial Transformer 通過(guò)軸向注意力操作在資源使用上實(shí)現(xiàn)了顯著節(jié)省,對(duì)于具有 N = N1/d × · · · × N1/d 形狀的 d 維張量,相比標(biāo)準(zhǔn)自注意力,軸向注意力在資源上節(jié)省了 O(N(d?1)/d) 因子。

2、完全表達(dá)性

盡管Axial Transformer沿單個(gè)軸應(yīng)用注意力,但其結(jié)構(gòu)設(shè)計(jì)確保了模型可以表達(dá)數(shù)據(jù)的全局依賴性,不丟失對(duì)前一個(gè)像素的依賴性。

3、簡(jiǎn)單易實(shí)現(xiàn)

Axial Transformer 不需要為GPU或TPU編寫(xiě)特定的子程序,它可以使用深度學(xué)習(xí)框架中廣泛可用的高效操作(主要是密集的MatMul操作)簡(jiǎn)單實(shí)現(xiàn)。

4、代碼

import torch
from torch import nn
from operator import itemgetter
from torch.autograd.function import Function
from torch.utils.checkpoint import get_device_states, set_device_states

# 定義一個(gè)模塊包裝器,確保通過(guò)保存和恢復(fù)隨機(jī)數(shù)生成器(RNG)狀態(tài)的確定性行為。
class Deterministic(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net # 要包裝的網(wǎng)絡(luò)
        self.cpu_state = None   # CPU RNG狀態(tài)
        self.cuda_in_fwd = None # 前向傳遞中是否使用了CUDA
        self.gpu_devices = None  # 使用的GPU設(shè)備
        self.gpu_states = None # GPU RNG狀態(tài)
    
    # 記錄當(dāng)前的隨機(jī)狀態(tài)
    def record_rng(self, *args):
        self.cpu_state = torch.get_rng_state()
        if torch.cuda._initialized:
            self.cuda_in_fwd = True
            self.gpu_devices, self.gpu_states = get_device_states(*args)
    # 前向傳遞
    def forward(self, *args, record_rng=False, set_rng=False, **kwargs):
        if record_rng:
            self.record_rng(*args)

        if not set_rng:
            return self.net(*args, **kwargs)

        rng_devices = []
        if self.cuda_in_fwd:
            rng_devices = self.gpu_devices

        with torch.random.fork_rng(devices=rng_devices, enabled=True):
            torch.set_rng_state(self.cpu_state)
            if self.cuda_in_fwd:
                set_device_states(self.gpu_devices, self.gpu_states)
            return self.net(*args, **kwargs)


# 可逆塊模塊,實(shí)現(xiàn)可逆網(wǎng)絡(luò)中的一個(gè)塊
class ReversibleBlock(nn.Module):
    def __init__(self, f, g):
        super().__init__()
        self.f = Deterministic(f) # 包裝f函數(shù),確保確定性
        self.g = Deterministic(g)  # 包裝g函數(shù),確保確定性
# 前向傳遞,實(shí)現(xiàn)可逆計(jì)算
    def forward(self, x, f_args={}, g_args={}):
        x1, x2 = torch.chunk(x, 2, dim=1) # 將輸入分為兩部分
        y1, y2 = None, None

        with torch.no_grad():
            y1 = x1 + self.f(x2, record_rng=self.training, **f_args)  # 計(jì)算y1
            y2 = x2 + self.g(y1, record_rng=self.training, **g_args) # 計(jì)算y2

        return torch.cat([y1, y2], dim=1) # 返回合并后的結(jié)果
# 反向傳遞,用于梯度計(jì)算
    def backward_pass(self, y, dy, f_args={}, g_args={}):
        y1, y2 = torch.chunk(y, 2, dim=1)
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim=1)
        del dy

        with torch.enable_grad():
            y1.requires_grad = True
            gy1 = self.g(y1, set_rng=True, **g_args)
            torch.autograd.backward(gy1, dy2)

        with torch.no_grad():
            x2 = y2 - gy1
            del y2, gy1

            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f(x2, set_rng=True, **f_args)
            torch.autograd.backward(fx2, dx1, retain_graph=True)

        with torch.no_grad():
            x1 = y1 - fx2
            del y1, fx2

            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

            x = torch.cat([x1, x2.detach()], dim=1)
            dx = torch.cat([dx1, dx2], dim=1)

        return x, dx

# 不可逆塊模塊,對(duì)比可逆塊的實(shí)現(xiàn)
class IrreversibleBlock(nn.Module):
    def __init__(self, f, g):
        super().__init__()
        self.f = f# 直接使用f函數(shù)
        self.g = g# 直接使用g函數(shù)

    def forward(self, x, f_args, g_args):
        x1, x2 = torch.chunk(x, 2, dim=1)
        y1 = x1 + self.f(x2, **f_args)
        y2 = x2 + self.g(y1, **g_args)
        return torch.cat([y1, y2], dim=1)

# 可逆函數(shù)實(shí)現(xiàn),用于在可逆網(wǎng)絡(luò)中應(yīng)用自定義的可逆操作
class _ReversibleFunction(Function):
    @staticmethod
    def forward(ctx, x, blocks, kwargs):
        ctx.kwargs = kwargs
        for block in blocks:
            x = block(x, **kwargs)
        ctx.y = x.detach()
        ctx.blocks = blocks
        return x

    @staticmethod
    def backward(ctx, dy):
        y = ctx.y
        kwargs = ctx.kwargs
        for block in ctx.blocks[::-1]:
            y, dy = block.backward_pass(y, dy, **kwargs)
        return dy, None, None


class ReversibleSequence(nn.Module): #逆塊串聯(lián)起來(lái),構(gòu)成一個(gè)可逆的網(wǎng)絡(luò)結(jié)構(gòu)。
    def __init__(self, blocks, ):
        super().__init__()
        self.blocks = nn.ModuleList([ReversibleBlock(f, g) for (f, g) in blocks])# 將傳入的函數(shù)對(duì)構(gòu)建為可逆塊,并加入模塊列表

    def forward(self, x, arg_route=(True, True), **kwargs):
        f_args, g_args = map(lambda route: kwargs if route else {}, arg_route)# 將傳入的函數(shù)對(duì)構(gòu)建為可逆塊,并加入模塊列表
        block_kwargs = {'f_args': f_args, 'g_args': g_args}
        x = torch.cat((x, x), dim=1)  # 將輸入復(fù)制一份并合并,為可逆計(jì)算做準(zhǔn)備
        x = _ReversibleFunction.apply(x, self.blocks, block_kwargs)# 通過(guò)_ReversibleFunction執(zhí)行可逆序列的前向計(jì)算
        return torch.stack(x.chunk(2, dim=1)).mean(dim=0)# 將結(jié)果拆分并取均值,完成前向傳遞



# 檢查值是否非None
def exists(val):
    return val is not None

# 從數(shù)組中按索引映射元素
def map_el_ind(arr, ind):
    return list(map(itemgetter(ind), arr))

# 對(duì)數(shù)組進(jìn)行排序并返回原始索引
def sort_and_return_indices(arr):
    indices = [ind for ind in range(len(arr))]# 創(chuàng)建索引列表
    arr = zip(arr, indices)  # 將數(shù)組的元素與它們的索引配對(duì)
    arr = sorted(arr) # 對(duì)配對(duì)進(jìn)行排序
    return map_el_ind(arr, 0), map_el_ind(arr, 1) # 返回排序后的數(shù)組和對(duì)應(yīng)的原始索引



# 計(jì)算維度排列
def calculate_permutations(num_dimensions, emb_dim):
    total_dimensions = num_dimensions + 2
    emb_dim = emb_dim if emb_dim > 0 else (emb_dim + total_dimensions)
    axial_dims = [ind for ind in range(1, total_dimensions) if ind != emb_dim]

    permutations = []

    for axial_dim in axial_dims:
        last_two_dims = [axial_dim, emb_dim]
        dims_rest = set(range(0, total_dimensions)) - set(last_two_dims)
        permutation = [*dims_rest, *last_two_dims]
        permutations.append(permutation)

    return permutations



# 通道層歸一化
class ChanLayerNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        std = torch.var(x, dim=1, unbiased=False, keepdim=True).sqrt()
        mean = torch.mean(x, dim=1, keepdim=True)
        return (x - mean) / (std + self.eps) * self.g + self.b

# 前置歸一化
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)

# 順序執(zhí)行模塊
class Sequential(nn.Module):
    def __init__(self, blocks):
        super().__init__()
        self.blocks = blocks

    def forward(self, x):
        for f, g in self.blocks:
            x = x + f(x)
            x = x + g(x)
        return x

# 維度置換
class PermuteToFrom(nn.Module):
    def __init__(self, permutation, fn):
        super().__init__()
        self.fn = fn
        _, inv_permutation = sort_and_return_indices(permutation)
        self.permutation = permutation
        self.inv_permutation = inv_permutation

    def forward(self, x, **kwargs):
        axial = x.permute(*self.permutation).contiguous()

        shape = axial.shape
        *_, t, d = shape

      
        axial = axial.reshape(-1, t, d)

        
        axial = self.fn(axial, **kwargs)

       
        axial = axial.reshape(*shape)
        axial = axial.permute(*self.inv_permutation).contiguous()
        return axial



#軸向位置嵌入
class AxialPositionalEmbedding(nn.Module):
    def __init__(self, dim, shape, emb_dim_index=1):
        super().__init__()
        parameters = []
        total_dimensions = len(shape) + 2
        ax_dim_indexes = [i for i in range(1, total_dimensions) if i != emb_dim_index]

        self.num_axials = len(shape)

        for i, (axial_dim, axial_dim_index) in enumerate(zip(shape, ax_dim_indexes)):
            shape = [1] * total_dimensions
            shape[emb_dim_index] = dim
            shape[axial_dim_index] = axial_dim
            parameter = nn.Parameter(torch.randn(*shape))
            setattr(self, f'param_{i}', parameter)

    def forward(self, x):
        for i in range(self.num_axials):
            x = x + getattr(self, f'param_{i}')
        return x


#自注意力模塊
class SelfAttention(nn.Module):
    def __init__(self, dim, heads, dim_heads=None):
        super().__init__()
        self.dim_heads = (dim // heads) if dim_heads is None else dim_heads
        dim_hidden = self.dim_heads * heads

        self.heads = heads
        self.to_q = nn.Linear(dim, dim_hidden, bias=False)
        self.to_kv = nn.Linear(dim, 2 * dim_hidden, bias=False)
        self.to_out = nn.Linear(dim_hidden, dim)

    def forward(self, x, kv=None):
        kv = x if kv is None else kv
        q, k, v = (self.to_q(x), *self.to_kv(kv).chunk(2, dim=-1))

        b, t, d, h, e = *q.shape, self.heads, self.dim_heads

        merge_heads = lambda x: x.reshape(b, -1, h, e).transpose(1, 2).reshape(b * h, -1, e)
        q, k, v = map(merge_heads, (q, k, v))

        dots = torch.einsum('bie,bje->bij', q, k) * (e ** -0.5)
        dots = dots.softmax(dim=-1)
        out = torch.einsum('bij,bje->bie', dots, v)

        out = out.reshape(b, h, -1, e).transpose(1, 2).reshape(b, -1, d)
        out = self.to_out(out)
        return out


#軸向注意力模塊
class AxialAttention(nn.Module):
    def __init__(self, dim, num_dimensions=2, heads=8, dim_heads=None, dim_index=-1, sum_axial_out=True):
        assert (dim % heads) == 0, 'hidden dimension must be divisible by number of heads'
        super().__init__()
        self.dim = dim# 特征維度
        self.total_dimensions = num_dimensions + 2# 總維度數(shù)
        self.dim_index = dim_index if dim_index > 0 else (dim_index + self.total_dimensions)

        attentions = []
        for permutation in calculate_permutations(num_dimensions, dim_index):
            attentions.append(PermuteToFrom(permutation, SelfAttention(dim, heads, dim_heads)))

        self.axial_attentions = nn.ModuleList(attentions)
        self.sum_axial_out = sum_axial_out

    def forward(self, x):
        assert len(x.shape) == self.total_dimensions, 'input tensor does not have the correct number of dimensions'
        assert x.shape[self.dim_index] == self.dim, 'input tensor does not have the correct input dimension'

        if self.sum_axial_out:
            return sum(map(lambda axial_attn: axial_attn(x), self.axial_attentions))

        out = x
        for axial_attn in self.axial_attentions:
            out = axial_attn(out)
        return out




class AxialImageTransformer(nn.Module):
    def __init__(self, dim, depth, heads=8, dim_heads=None, dim_index=1, reversible=True, axial_pos_emb_shape=None):
        super().__init__()
        permutations = calculate_permutations(2, dim_index)

        get_ff = lambda: nn.Sequential(
            ChanLayerNorm(dim),
            nn.Conv2d(dim, dim * 4, 3, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(dim * 4, dim, 3, padding=1)
        )

        self.pos_emb = AxialPositionalEmbedding(dim, axial_pos_emb_shape, dim_index) if exists(
            axial_pos_emb_shape) else nn.Identity()

        layers = nn.ModuleList([])
        for _ in range(depth):
            attn_functions = nn.ModuleList(
                [PermuteToFrom(permutation, PreNorm(dim, SelfAttention(dim, heads, dim_heads))) for permutation in
                 permutations])
            conv_functions = nn.ModuleList([get_ff(), get_ff()])
            layers.append(attn_functions)
            layers.append(conv_functions)

        execute_type = ReversibleSequence if reversible else Sequential
        self.layers = execute_type(layers)

    def forward(self, x):
        x = self.pos_emb(x)
        return self.layers(x)



if __name__ == '__main__':
    block = AxialImageTransformer(
        dim=64,
        depth=12,
        reversible=True
    ).cuda()
    input = torch.rand(1, 64, 64, 64).cuda()
    output = block(input)
    print(output.shape)

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容