精簡CNN模型系列之三:SkipNet

介紹

CNN模型為了追求精度提高層數(shù)已經(jīng)是愈來愈多,可更多的層次帶來的精度邊際提升卻不斷減小?;蛘邔δ承┹斎雸D片而言,真正所需的layers并非那么多,只有一些真正模糊、特征不明顯、即使人看上去也較難分辨的圖片才需要較多的layers處理最終得到能分別其類別的表達特征。

SkipNet主要是以此假設(shè)出發(fā),通過在傳統(tǒng)CNN的每個layer(或module)上設(shè)置判斷其是否需要執(zhí)行的Gate module來決定是否需要真的執(zhí)行此層計算,若判斷為否則直接將activation feature maps傳入到下一層,越過當(dāng)下層的運算不做。無益這樣做可以有效地節(jié)省傳統(tǒng)CNN模型在部署時進行推理工作所需的時間。

就這樣一旦訓(xùn)練好,SkipNet在做圖片推理時可根據(jù)輸入的feature maps不同靈活地決定是否執(zhí)行某一網(wǎng)絡(luò)中的層。下圖可反映SkipNet這一根本特點。

SkipNet根本思想

SkiptNet

對于每一層操作而言,SkipNet可表示為:xi+1 = GiFi(xi)+(1-Gi)xi。其中xi和Fi(xi)分別表示第ith layer的輸入與輸出feature maps;Gi ∈{0,1} 為第ith layer的Gate函數(shù)。

對于此處的Gate函數(shù),作者實驗了兩種不同的表示方法。Paper中SkipNet基于的CNN網(wǎng)絡(luò)為Resnet,其中Gate即可以被獨立地添加在各個Residual block上面作為單獨的個體,有著不同的參數(shù)即Feed-forward Gate;還可以所有的Residual blocks復(fù)用一個Gate module即Recurrent Gate。其不同之處可從下圖中看出。

SkipNet中兩種不同的Gate函數(shù)選擇

Gate module設(shè)計

作者在論文中共嘗試了三種不同的Gate module設(shè)計,它們對計算與accuracy的考量略有不同。

FFGate-I: MaxPool(2x2) -> Conv(3x3, 1) -> Conv(3x3, 2) -> AvgPool -> FC,整體計算量約為Residual block的19%,在論文中主要用于較淺的一些網(wǎng)絡(luò)(層數(shù)小于100);
FFGate-II: Conv(3x3, 2) -> AvgPool -> FC,整體計算量約為Residual block的12.5%,主要用于較深的一些網(wǎng)絡(luò)(層數(shù)大于100);
RNNGate: AvgPool -> Conv(1x1) -> LSTM(10 hidden units) -> FC,整體計算量約為Residual block的0.04%,是論文中首選的Gate函數(shù)。在深層次網(wǎng)絡(luò)中它相對于Feed-forward Gate有較大的性能與分類精度優(yōu)勢,只是在較淺的層次上它精度略低,但計算開銷仍有較大優(yōu)勢。

下圖為以上三種Gate module的概況描述。

三種具體的Gate_module設(shè)計

使用Hybrid RL的Skipping policy學(xué)習(xí)

對于上節(jié)所介紹的Gate函數(shù)可理解為是這么一種決策:Π(xi,i) = P(Gi(xi) = gi),(其中g(shù)i∈{0,1},分別表示執(zhí)行還是略過第ith層執(zhí)行的兩種離散決策)。

這樣對于有N層的CNN來說,我們在forward時需要決定下如此一個輸入為x的決策序列:g = [g1,....,gN] ? Π(F<sub>&theta;</sub>)。在這里F&theta; = [F&theta;1,....,F&theta;N]表示CNN網(wǎng)絡(luò)中N個layers的計算。

而整體的目標(biāo)函數(shù)則可表示如下:

Skip_learning中使用Hybrid_RL時的整體目標(biāo)函數(shù)

其中Ri = (1-gi)Ci表示的是每個Gate module所節(jié)省的計算,亦為它的激勵函數(shù)。因為paper中用的是Resnet,故假定所有的Ci相同,設(shè)為1。然后α 則為CNN分類準(zhǔn)確率與計算節(jié)省之間的平衡系數(shù)。可以看出這里的目標(biāo)函數(shù)設(shè)計同時考慮了模型分類精度與計算效率并力圖在其中尋找平衡。

下式為具體計算時的梯度計算公式??梢钥闯鏊饕蓛刹糠纸M成,第一部分表示的是學(xué)習(xí)分類精度的supervised loss,第二部分則是要接合RL最終學(xué)習(xí)出來的反映計算節(jié)省的Skip learning policy。

Skip_learning中使用Hybrid_RL時的梯度計算

下圖為使用Hybrid RL的具體算法概述。

Hybrid_RL_learning算法

實驗結(jié)果

下圖為SkipNet在各大數(shù)據(jù)集上得到的分類精度結(jié)果。

在各大數(shù)據(jù)集上SkipNet得到的分類精度

下表中反映了不同SkipNet配置與訓(xùn)練方法在達到與原生ResNet相似精度的情況下?lián)Q來的計算節(jié)省。

不同SkipNet配置在達到相似精度情況下得到的計算節(jié)省

代碼分析

如下為FFGate-I的設(shè)計實現(xiàn),其它Gate module的寫法并無太多不同。

# Feedforward-Gate (FFGate-I)
class FeedforwardGateI(nn.Module):
    """ Use Max Pooling First and then apply to multiple 2 conv layers.
    The first conv has stride = 1 and second has stride = 2"""
    def __init__(self, pool_size=5, channel=10):
        super(FeedforwardGateI, self).__init__()
        self.pool_size = pool_size
        self.channel = channel

        self.maxpool = nn.MaxPool2d(2)
        self.conv1 = conv3x3(channel, channel)
        self.bn1 = nn.BatchNorm2d(channel)
        self.relu1 = nn.ReLU(inplace=True)

        # adding another conv layer
        self.conv2 = conv3x3(channel, channel, stride=2)
        self.bn2 = nn.BatchNorm2d(channel)
        self.relu2 = nn.ReLU(inplace=True)

        pool_size = math.floor(pool_size/2)  # for max pooling
        pool_size = math.floor(pool_size/2 + 0.5)  # for conv stride = 2

        self.avg_layer = nn.AvgPool2d(pool_size)
        self.linear_layer = nn.Conv2d(in_channels=channel, out_channels=2,
                                      kernel_size=1, stride=1)
        self.prob_layer = nn.Softmax()
        self.logprob = nn.LogSoftmax()

    def forward(self, x):
        x = self.maxpool(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)

        x = self.avg_layer(x)
        x = self.linear_layer(x).squeeze()
        softmax = self.prob_layer(x)
        logprob = self.logprob(x)

        # discretize output in forward pass.
        # use softmax gradients in backward pass
        x = (softmax[:, 1] > 0.5).float().detach() - \
            softmax[:, 1].detach() + softmax[:, 1]

        x = x.view(x.size(0), 1, 1, 1)
        return x, logprob

下面這個class里面則具體實現(xiàn)了如何將Gate module與某一CNN網(wǎng)絡(luò)結(jié)合起來從而實現(xiàn)相關(guān)的SkipNet。

class ResNetFeedForwardRL(nn.Module):
    """Adding gating module on every basic block"""

    def __init__(self, block, layers, num_classes=10,
                 gate_type='ffgate1', **kwargs):
        self.inplanes = 16
        super(ResNetFeedForwardRL, self).__init__()

        self.num_layers = layers
        self.conv1 = conv3x3(3, 16)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)

        self.gate_instances = []
        self.gate_type = gate_type
        self._make_group(block, 16, layers[0], group_id=1,
                         gate_type=gate_type, pool_size=32)
        self._make_group(block, 32, layers[1], group_id=2,
                         gate_type=gate_type, pool_size=16)
        self._make_group(block, 64, layers[2], group_id=3,
                         gate_type=gate_type, pool_size=8)

        # remove the last gate instance, (not optimized)
        del self.gate_instances[-1]

        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        self.softmax = nn.Softmax()
        self.saved_actions = []
        self.rewards = []

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(0) * m.weight.size(1)
                m.weight.data.normal_(0, math.sqrt(2. / n))

    def _make_group(self, block, planes, layers, group_id=1,
                    gate_type='fisher', pool_size=16):
        """ Create the whole group"""
        for i in range(layers):
            if group_id > 1 and i == 0:
                stride = 2
            else:
                stride = 1

            meta = self._make_layer_v2(block, planes, stride=stride,
                                       gate_type=gate_type,
                                       pool_size=pool_size)

            setattr(self, 'group{}_ds{}'.format(group_id, i), meta[0])
            setattr(self, 'group{}_layer{}'.format(group_id, i), meta[1])
            setattr(self, 'group{}_gate{}'.format(group_id, i), meta[2])

            # add into gate instance collection
            self.gate_instances.append(meta[2])

    def _make_layer_v2(self, block, planes, stride=1,
                       gate_type='fisher', pool_size=16):
        """ create one block and optional a gate module """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),

            )
        layer = block(self.inplanes, planes, stride, downsample)
        self.inplanes = planes * block.expansion

        if gate_type == 'ffgate1':
            gate_layer = RLFeedforwardGateI(pool_size=pool_size,
                                            channel=planes*block.expansion)
        elif gate_type == 'ffgate2':
            gate_layer = RLFeedforwardGateII(pool_size=pool_size,
                                             channel=planes*block.expansion)
        else:
            gate_layer = None

        if downsample:
            return downsample, layer, gate_layer
        else:
            return None, layer, gate_layer

    def repackage_vars(self):
        self.saved_actions = repackage_hidden(self.saved_actions)

    def forward(self, x, reinforce=False):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        masks = []
        gprobs = []
        # must pass through the first layer in first group
        x = getattr(self, 'group1_layer0')(x)
        # gate takes the output of the current layer
        mask, gprob = getattr(self, 'group1_gate0')(x)
        gprobs.append(gprob)
        masks.append(mask.squeeze())
        prev = x  # input of next layer

        for g in range(3):
            for i in range(0 + int(g == 0), self.num_layers[g]):
                if getattr(self, 'group{}_ds{}'.format(g+1, i)) is not None:
                    prev = getattr(self, 'group{}_ds{}'.format(g+1, i))(prev)
                x = getattr(self, 'group{}_layer{}'.format(g+1, i))(x)
                # new mask is taking the current output
                prev = x = mask.expand_as(x) * x \
                           + (1 - mask).expand_as(prev) * prev
                mask, gprob = getattr(self, 'group{}_gate{}'.format(g+1, i))(x)
                gprobs.append(gprob)
                masks.append(mask.squeeze())

        del masks[-1]

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        # collect all actions
        for inst in self.gate_instances:
            self.saved_actions.append(inst.saved_action)

        if reinforce:  # for pure RL
            softmax = self.softmax(x)
            action = softmax.multinomial()
            self.saved_actions.append(action)

        return x, masks, gprobs

參考文獻

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

  • 介紹 SqueezeNet同這個系列要介紹的其它任一CNN模型一樣不只關(guān)心模型分類精度,同樣也重視其計算速度與模型...
    manofmountain閱讀 3,932評論 0 4
  • 最近發(fā)現(xiàn)自己的一個缺點,很多原理雖然從理論上或著數(shù)學(xué)上理解了,但是難以用一種簡潔的偏于溝通的方式表達出來。所以合上...
    給力桃閱讀 1,810評論 0 0
  • 文章作者:Tyan博客:noahsnail.com | CSDN | 簡書 聲明:作者翻譯論文僅為學(xué)習(xí),如有侵權(quán)請...
    SnailTyan閱讀 9,568評論 0 16
  • 介紹 終于可以說一下Resnet分類網(wǎng)絡(luò)了,它差不多是當(dāng)前應(yīng)用最為廣泛的CNN特征提取網(wǎng)絡(luò)。它的提出始于2015年...
    manofmountain閱讀 295,908評論 3 79
  • 這是一道再平常不過的門 被歲月侵蝕的藍幾處剝落 任夕陽下固執(zhí)的風(fēng) 拂過不動的環(huán) 空自搖落幾度黃昏 門后的兩院青草 ...
    張秉初閱讀 340評論 0 6

友情鏈接更多精彩內(nèi)容