Adding an Attention Mechanism to Common CNN Architectures

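Many CNN architectures are in common use today, for example:

  • LeNet: Gradient-Based Learning Applied to Document Recognition
  • AlexNet: ImageNet Classification with Deep Convolutional Neural Networks
  • VGGNet: Very Deep Convolutional Networks for Large-Scale Image Recognition
  • GoogLeNet: Going Deeper with Convolutions
  • Inception-v3: Rethinking the Inception Architecture for Computer Vision
  • ResNet: Deep Residual Learning for Image Recognition
  • Inception-ResNet: Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning
  • SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size
  • MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications
  • ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices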

This post uses the mini_XCEPTION network as the example:

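from tensorflow.keras import layers
from tensorflow.keras.layers import (Activation, BatchNormalization, Conv2D,
                                     GlobalAveragePooling2D, Input,
                                     MaxPooling2D, SeparableConv2D)
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2


def mini_XCEPTION(input_shape, num_classes, l2_regularization=0.01):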
    regularization = l2(l2_regularization)

    # base
    img_input = Input(input_shape)
    x = Conv2D(8, (3, 3), strides=(1, 1), kernel_regularizer=regularization,
                                            use_bias=False)(img_input)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(8, (3, 3), strides=(1, 1), kernel_regularizer=regularization,
                                            use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # module 1
    residual = Conv2D(16, (1, 1), strides=(2, 2),
                      padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)


    x = SeparableConv2D(16, (3, 3), padding='same',
                        kernel_regularizer=regularization,
                        use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = SeparableConv2D(16, (3, 3), padding='same',
                        kernel_regularizer=regularization,
                        use_bias=False)(x)
    x = BatchNormalization()(x)

    x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
    x = layers.add([x, residual])

    # module 2
    residual = Conv2D(32, (1, 1), strides=(2, 2),
                      padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)

    x = SeparableConv2D(32, (3, 3), padding='same',
                        kernel_regularizer=regularization,
                        use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = SeparableConv2D(32, (3, 3), padding='same',
                        kernel_regularizer=regularization,
                        use_bias=False)(x)
    x = BatchNormalization()(x)

    x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
    x = layers.add([x, residual])

    # module 3
    residual = Conv2D(64, (1, 1), strides=(2, 2),
                      padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)

    x = SeparableConv2D(64, (3, 3), padding='same',
                        kernel_regularizer=regularization,
                        use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = SeparableConv2D(64, (3, 3), padding='same',
                        kernel_regularizer=regularization,
                        use_bias=False)(x)
    x = BatchNormalization()(x)

    x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
    x = layers.add([x, residual])

    # module 4
    residual = Conv2D(128, (1, 1), strides=(2, 2),
                      padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)

    x = SeparableConv2D(128, (3, 3), padding='same',
                        kernel_regularizer=regularization,
                        use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = SeparableConv2D(128, (3, 3), padding='same',
                        kernel_regularizer=regularization,
                        use_bias=False)(x)
    x = BatchNormalization()(x)

    x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
    x = layers.add([x, residual])

    x = Conv2D(num_classes, (3, 3),
            #kernel_regularizer=regularization,
            padding='same')(x)
    x = GlobalAveragePooling2D()(x)
    output = Activation('softmax',name='predictions')(x)

    model = Model(img_input, output)
    return model
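
To make the downsampling in mini_XCEPTION concrete, the sketch below traces the spatial size and channel count of the feature map after each stage, using the usual output-size rules for 'valid' and 'same' padding. The 48x48 input size and the helper names are assumptions chosen for illustration; they do not come from the original code.

```python
import math


def valid_out(size, kernel, stride=1):
    """Output length of a convolution with padding='valid'."""
    return (size - kernel) // stride + 1


def same_out(size, stride=1):
    """Output length of a convolution or pooling layer with padding='same'."""
    return math.ceil(size / stride)


def trace_mini_xception(size):
    """Return (stage name, spatial size, channels) after each stage."""
    stages = []
    # base: two 3x3 convolutions that use Keras' default 'valid' padding
    size = valid_out(valid_out(size, 3), 3)
    stages.append(("base", size, 8))
    # modules 1-4: the main branch is halved by stride-2 'same' max pooling,
    # the residual branch by a stride-2 1x1 convolution, so both agree
    for index, channels in enumerate((16, 32, 64, 128), start=1):
        size = same_out(size, stride=2)
        stages.append((f"module {index}", size, channels))
    return stages


trace = trace_mini_xception(48)  # 48 is an assumed input size
for name, size, channels in trace:
    print(f"{name}: {size}x{size}x{channels}")
```

Both branches of every module reach the same shape, which is what allows them to be merged with layers.add.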

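First, the attention mechanism itself. A paper at ECCV 2018 proposed CBAM (Convolutional Block Attention Module); the paper is available at https://arxiv.org/abs/1807.06521. It not only improves on the existing channel attention mechanism but also adds a spatial attention mechanism, as shown in the figure below.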

[Figure: CBAM]

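The paper makes three main contributions:

(1) It proposes CBAM, an efficient attention module that can be embedded into today's mainstream CNN architectures.
(2) It verifies the effectiveness of the attention in CBAM through additional ablation studies.
(3) It demonstrates CBAM's performance gains on several benchmarks (ImageNet-1K, MS COCO, and VOC 2007).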

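  • Channel attention

[Figure: channel attention module]

This part is very similar to SENet: both first compress the feature map along its spatial dimensions into a one-dimensional vector and then operate on that vector. The difference from SENet is that, when compressing the input feature map spatially, the authors use max pooling in addition to average pooling, so the two pooling functions yield two one-dimensional vectors in total. Global average pooling passes gradient back to every pixel of the feature map, whereas global max pooling passes gradient only to the location with the strongest response during backpropagation, which makes it a useful complement to GAP. The formula is as follows:

[Formula image: channel attention]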
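
As a rough illustration of the computation described above, the following sketch implements channel attention in plain Python on nested lists, so it runs without TensorFlow. It follows the form given in the CBAM paper, M_c(F) = sigmoid(MLP(AvgPool(F)) + MLP(MaxPool(F))), where one MLP is shared by the two pooled vectors. The helper names, the random weights, and the 4x4x8 input are assumptions made for illustration only.

```python
import math
import random


def sigmoid(value):
    return 1.0 / (1.0 + math.exp(-value))


def dense(vector, weights, biases, relu=False):
    """Fully connected layer; weights[i][j] maps input i to output j."""
    outputs = []
    for j, bias in enumerate(biases):
        total = bias + sum(v * weights[i][j] for i, v in enumerate(vector))
        outputs.append(max(total, 0.0) if relu else total)
    return outputs


def channel_attention_weights(feature, w0, b0, w1, b1):
    """feature is an H x W x C nested list; returns C weights in (0, 1)."""
    channels = len(feature[0][0])
    per_channel = [[pixel[c] for row in feature for pixel in row]
                   for c in range(channels)]
    # Spatial compression: one average and one maximum per channel.
    avg_desc = [sum(values) / len(values) for values in per_channel]
    max_desc = [max(values) for values in per_channel]
    # The same two-layer MLP (shared weights) processes both descriptors.
    avg_out = dense(dense(avg_desc, w0, b0, relu=True), w1, b1)
    max_out = dense(dense(max_desc, w0, b0, relu=True), w1, b1)
    return [sigmoid(a + m) for a, m in zip(avg_out, max_out)]


random.seed(0)
H, W, C, RATIO = 4, 4, 8, 4  # illustrative sizes, not from the post
feature = [[[random.uniform(-1, 1) for _ in range(C)]
            for _ in range(W)] for _ in range(H)]
hidden = C // RATIO
w0 = [[random.gauss(0, 0.5) for _ in range(hidden)] for _ in range(C)]
w1 = [[random.gauss(0, 0.5) for _ in range(C)] for _ in range(hidden)]
b0, b1 = [0.0] * hidden, [0.0] * C

weights = channel_attention_weights(feature, w0, b0, w1, b1)
# Each channel of the input is rescaled by its attention weight.
refined = [[[value * weights[c] for c, value in enumerate(pixel)]
            for pixel in row] for row in feature]
```

The sketch uses a plain sigmoid as in the paper's formula; the Keras implementation later in this post uses 'hard_sigmoid' instead.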
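  • Spatial attention

[Figure: spatial attention module]

This part is a key contribution that sets the paper apart from SENet. Besides generating an attention model over channels, the authors argue that the network should also learn, at the spatial level, which parts of the feature map deserve a stronger response. Average pooling and max pooling are again used to compress the input feature map, but here the compression is along the channel axis: the mean and the max are taken over the channel dimension of the input features. This gives two two-dimensional maps, which are concatenated along the channel axis into a feature map with 2 channels. A convolutional layer with a single filter is then applied to it, and the output must match the input feature map in its spatial dimensions.

[Formula image: spatial attention]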
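
The same steps can be sketched in plain Python. This is a minimal illustration, assuming the form given in the CBAM paper, M_s(F) = sigmoid(f7x7([AvgPool(F); MaxPool(F)])), with zero padding so the output keeps the input's height and width. The function name, the constant kernels, and the 5x6x3 input are hypothetical.

```python
import math


def sigmoid(value):
    return 1.0 / (1.0 + math.exp(-value))


def spatial_attention_map(feature, kernel):
    """feature: H x W x C nested list; kernel: K x K list of (w_avg, w_max).

    Returns an H x W map of attention weights in (0, 1).
    """
    height, width = len(feature), len(feature[0])
    pad = len(kernel) // 2
    # Channel compression: one mean and one max per pixel (a 2-channel map).
    pooled = [[(sum(pixel) / len(pixel), max(pixel)) for pixel in row]
              for row in feature]
    attention = []
    for y in range(height):
        row_out = []
        for x in range(width):
            total = 0.0
            for ky, kernel_row in enumerate(kernel):
                for kx, (w_avg, w_max) in enumerate(kernel_row):
                    yy, xx = y + ky - pad, x + kx - pad
                    # Positions outside the map count as zero ('same' padding).
                    if 0 <= yy < height and 0 <= xx < width:
                        avg_value, max_value = pooled[yy][xx]
                        total += w_avg * avg_value + w_max * max_value
            row_out.append(sigmoid(total))
        attention.append(row_out)
    return attention


feature = [[[float(y + x + c) for c in range(3)] for x in range(6)]
           for y in range(5)]
zero_kernel = [[(0.0, 0.0)] * 7 for _ in range(7)]
small_kernel = [[(0.01, 0.01)] * 7 for _ in range(7)]

flat_map = spatial_attention_map(feature, zero_kernel)
attention = spatial_attention_map(feature, small_kernel)
# Every channel at a pixel is rescaled by that pixel's attention weight.
refined = [[[value * attention[y][x] for value in pixel]
            for x, pixel in enumerate(row)] for y, row in enumerate(feature)]
```

With an all-zero kernel every position gets a weight of 0.5, which is a quick sanity check that the map depends only on the learned filter.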
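  • How the two mechanisms are combined

The paper reports experiments on this question. The authors considered three arrangements: channel-first, spatial-first, and parallel. The results show that channel-first gives the best classification results.

[Figure: experimental results]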
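
In the paper's notation, the channel-first arrangement can be written as follows, where F is the input feature map, M_c and M_s are the channel and spatial attention maps, and the product is element-wise with broadcasting. The notation is taken from the CBAM paper and is shown here only as a compact summary.

```latex
\begin{aligned}
F'  &= M_c(F) \otimes F, \\
F'' &= M_s(F') \otimes F'.
\end{aligned}
```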

Implementation:

from tensorflow.keras import backend as K
from tensorflow.keras.layers import GlobalAveragePooling2D, GlobalMaxPooling2D, Reshape, Dense, multiply, Permute, Concatenate, Conv2D, Add, Activation, Lambda

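''' Channel attention:
    When compressing the input feature map along its spatial dimensions, the authors
    use max pooling in addition to average pooling, so the two pooling functions
    yield two one-dimensional vectors. Global average pooling passes gradient back to
    every pixel of the feature map, while global max pooling passes gradient only to
    the location with the strongest response, which makes it a complement to GAP.
'''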
def channel_attention(input_feature, ratio=8):

    channel_axis = 1 if K.image_data_format() == "channels_first" else -1
    channel = input_feature.shape[channel_axis]

    shared_layer_one = Dense(channel // ratio,
                             kernel_initializer='he_normal',
                             activation='relu',
                             use_bias=True,
                             bias_initializer='zeros')

    shared_layer_two = Dense(channel,
                             kernel_initializer='he_normal',
                             use_bias=True,
                             bias_initializer='zeros')

    avg_pool = GlobalAveragePooling2D()(input_feature)
    avg_pool = Reshape((1, 1, channel))(avg_pool)
    assert avg_pool.shape[1:] == (1, 1, channel)
    avg_pool = shared_layer_one(avg_pool)
    assert avg_pool.shape[1:] == (1, 1, channel // ratio)
    avg_pool = shared_layer_two(avg_pool)
    assert avg_pool.shape[1:] == (1, 1, channel)

    max_pool = GlobalMaxPooling2D()(input_feature)
    max_pool = Reshape((1, 1, channel))(max_pool)
    assert max_pool.shape[1:] == (1, 1, channel)
    max_pool = shared_layer_one(max_pool)
    assert max_pool.shape[1:] == (1, 1, channel // ratio)
    max_pool = shared_layer_two(max_pool)
    assert max_pool.shape[1:] == (1, 1, channel)

    cbam_feature = Add()([avg_pool, max_pool])
    cbam_feature = Activation('hard_sigmoid')(cbam_feature)

    if K.image_data_format() == "channels_first":
        cbam_feature = Permute((3, 1, 2))(cbam_feature)

    return multiply([input_feature, cbam_feature])

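''' Spatial attention:
    Average pooling and max pooling are again used to compress the input feature map,
    but here the compression is along the channel axis: the mean and the max are taken
    over the channel dimension. The two resulting 2-D maps are concatenated along the
    channel axis into a 2-channel feature map, and a convolutional layer with a single
    filter is applied to it. The output must match the input feature map in its
    spatial dimensions.
'''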
def spatial_attention(input_feature):
    kernel_size = 7

    if K.image_data_format() == "channels_first":
        channel = input_feature.shape[1]
        cbam_feature = Permute((2, 3, 1))(input_feature)
    else:
        channel = input_feature.shape[-1]
        cbam_feature = input_feature

    avg_pool = Lambda(lambda x: K.mean(x, axis=3, keepdims=True))(cbam_feature)
    assert avg_pool.shape[-1] == 1
    max_pool = Lambda(lambda x: K.max(x, axis=3, keepdims=True))(cbam_feature)
    assert max_pool.shape[-1] == 1
    concat = Concatenate(axis=3)([avg_pool, max_pool])
    assert concat.shape[-1] == 2
    cbam_feature = Conv2D(filters=1,
                          kernel_size=kernel_size,
                          activation='hard_sigmoid',
                          strides=1,
                          padding='same',
                          kernel_initializer='he_normal',
                          use_bias=False)(concat)
    assert cbam_feature.shape[-1] == 1

    if K.image_data_format() == "channels_first":
        cbam_feature = Permute((3, 1, 2))(cbam_feature)

    return multiply([input_feature, cbam_feature])

def cbam_block(cbam_feature, ratio=8):
    """Contains the implementation of Convolutional Block Attention Module(CBAM) block.
    As described in https://arxiv.org/abs/1807.06521.
    """
    # Experiments show that channel-first works better than spatial-first
    # or applying channel and spatial attention in parallel
    cbam_feature = channel_attention(cbam_feature, ratio)
    cbam_feature = spatial_attention(cbam_feature)

    return cbam_feature

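Add the block wherever you want it in the network, for example in the second residual module. The diagram is the one provided in the paper:

[Figure: image.png]

After these lines:

    # module 2
    residual = Conv2D(32, (1, 1), strides=(2, 2),
                      padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)

add the following line: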

    cbam = cbam_block(residual)

Then change

    x = layers.add([x, residual])

to:

    x = layers.add([x, residual, cbam])

The merge does not have to use add(); concatenate() and similar methods can also be used.

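Like the SE module, CBAM can be embedded into most mainstream networks today, and it improves a model's feature-extraction ability without significantly increasing computation or parameter count. In short, adding attention to a network architecture is a sound option.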
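
To put a number on the parameter overhead, the sketch below counts the trainable parameters of one CBAM block as implemented in this post: a shared two-layer Dense MLP with biases, plus a single 7x7 convolution over 2 channels without bias. The function name is hypothetical, and the channel widths are those of the mini_XCEPTION modules.

```python
def cbam_param_count(channels, ratio=8, kernel_size=7):
    """Trainable parameters of one CBAM block as implemented in this post."""
    hidden = channels // ratio
    # Shared MLP: Dense(channels -> hidden) and Dense(hidden -> channels),
    # both with bias; the layers are shared, so they are counted once.
    mlp = (channels * hidden + hidden) + (hidden * channels + channels)
    # Spatial branch: one kernel_size x kernel_size filter over 2 channels,
    # created with use_bias=False.
    conv = kernel_size * kernel_size * 2
    return mlp + conv


counts = {channels: cbam_param_count(channels) for channels in (16, 32, 64, 128)}
for channels, count in counts.items():
    print(f"{channels} channels: {count} parameters")
```

Even at 128 channels the block adds only a few thousand parameters.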
