ResNet18/50 network structure and the PyTorch implementation

1 A brief introduction to ResNet

There are plenty of articles online explaining the ideas behind ResNet. In short, ResNet makes clever use of shortcut connections to solve the degradation problem of deep networks.

2 The structure from the paper

(figure: 網(wǎng)絡(luò)結(jié)構(gòu).png, the network structure table from the paper)

2.1 Based on the PyTorch implementation, I drew my own network diagram, including the parameters and output of every layer

(figure: resnet18&resnet50.jpg, my annotated diagram)

PS: As @字里行間_yan pointed out in the comments, some descriptions in the original image were ambiguous and have been corrected. In general, changes in feature map size should be described as upsampling and downsampling; it is changes in channel count that are increases and reductions in dimensionality.


2020/4/17 update
I originally wrote this up casually, but since quite a few people seem to land on it, I've come back to fill in the gaps:

  • Added a walkthrough of the PyTorch code.
  • Fixed a mislabeled parameter k in the diagram (1x1 convolution: k=1).
  • Added a download link for the SVG file (it's a simple file, but if you download it, please leave a like to show support): https://pan.baidu.com/s/183ReRQMJXt2yUkhExnezBA extraction code: kmf2

3 ResNet in PyTorch

3.1 Code files

The complete source is available in the official PyTorch documentation at https://pytorch.org/docs/stable/_modules/torchvision/models/resnet.html#resnet18
Incidentally, torchvision implements most of the classic networks for classification, segmentation, and detection, and also includes loaders for common datasets, which saves beginners a lot of work.

3.2 Reading the code

I like to start from where the model is constructed, so first the entry points:

def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model

def resnet50(pretrained=False, progress=True, **kwargs):
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)

As you can see, ResNet takes at least two explicit arguments, block and layers. block is one of the two building-block structures from the paper, used by resnet18 and resnet50 respectively. layers is the network's layer configuration, i.e. the number of blocks in each stage, which also appears in the diagram above.
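As a sanity check on the naming, the depth in each model's name can be recovered from block and layers. A small illustrative sketch (depth is a hypothetical helper, not part of torchvision; resnet18 uses BasicBlock with layers=[2, 2, 2, 2]):

```python
# Each BasicBlock holds 2 convolutions and each Bottleneck holds 3; adding
# the initial 7x7 conv and the final fc layer gives the depth in the name.
def depth(convs_per_block, layers):
    return convs_per_block * sum(layers) + 2

print(depth(2, [2, 2, 2, 2]))  # resnet18 -> 18
print(depth(3, [3, 4, 6, 3]))  # resnet50 -> 50
```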
Next, the network class itself. The code is somewhat long, so for readability I've kept only the important parts and annotated them inline; I recommend reading this alongside the complete source.

class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        # More parameters than the call sites pass in; the model has been
        # updated since this article was first written.
        # block: BasicBlock or Bottleneck, discussed below
        # layers: number of blocks per stage, e.g. layers=[3, 4, 6, 3] for resnet50
        # num_classes: number of classes in the dataset
        # zero_init_residual: a small trick from a later paper; zero-initialize
        #   the residual branch's last BN
        # groups: grouped convolution, for the ResNeXt variant
        # width_per_group: likewise, and also covers the Wide ResNet variant
        # replace_stride_with_dilation: dilated convolution, not in the paper
        # norm_layer: the paper uses BN; made customizable here

        # Intermediate code omitted; only the model assembly is shown
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    # Intermediate code omitted
    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Built when the feature map is downsampled or the channel
            # counts don't match
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        # Only the first block of each self.layer needs downsample, so it is
        # appended separately; this matches the 1 in range(1, blocks) below
        # the block classes are defined further down
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # Forward pass
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x
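To connect this with the diagram above, here is a minimal sketch (assumed helper name, not torchvision code) tracing channels and spatial size through _forward_impl for a 224x224 resnet50 input, using the standard convolution output-size formula:

```python
# out = floor((n + 2*pad - k) / stride) + 1 for a k x k conv
def conv_out(n, k, stride, pad):
    return (n + 2 * pad - k) // stride + 1

expansion, channels, size = 4, 64, 224   # Bottleneck: expansion = 4
size = conv_out(size, 7, 2, 3)           # conv1:   64 x 112 x 112
size = conv_out(size, 3, 2, 1)           # maxpool: 64 x 56 x 56
for planes, stride in zip([64, 128, 256, 512], [1, 2, 2, 2]):
    size = conv_out(size, 3, stride, 1)  # stride sits in the stage's first block
    channels = planes * expansion
    print(channels, size)
# layer1..layer4: 256 56 / 512 28 / 1024 14 / 2048 7
# avgpool -> 2048 x 1 x 1, flatten -> 2048, fc -> num_classes
```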

Next, the blocks from the paper:

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
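The reason conv3x3 sets padding=dilation is that a dilated 3x3 kernel spans dilation*2 + 1 pixels, so that padding keeps a stride-1 output the same size as the input. A quick sketch with the standard output-size formula (conv_out is a hypothetical helper):

```python
# Effective receptive span of a dilated k x k kernel is dilation*(k-1) + 1
def conv_out(n, k=3, stride=1, pad=0, dilation=1):
    eff = dilation * (k - 1) + 1
    return (n + 2 * pad - eff) // stride + 1

for d in (1, 2, 4):
    print(conv_out(56, pad=d, dilation=d))  # 56 every time
```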

# The structure used in resnet18: two 3x3 convolutions
class BasicBlock(nn.Module):
    expansion = 1
    __constants__ = ['downsample']
    # inplanes: number of input channels
    # planes: number of output channels
    # base_width, dilation, norm_layer are outside the scope of this article
    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        # Intermediate code omitted
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Save the input for the addition later
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            # When the spatial size shrinks or the channels grow, the identity
            # must be projected so the addition is valid
            identity = self.downsample(x)

        out += identity  # the heart of the paper: the simplicity and elegance of ResNet
        out = self.relu(out)

        return out
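The self.downsample branch above fires exactly when _make_layer's test, stride != 1 or self.inplanes != planes * block.expansion, is true. A small sketch (needs_downsample is a hypothetical helper, not torchvision code) of which stages project the identity:

```python
# Mirrors the condition in ResNet._make_layer
def needs_downsample(inplanes, planes, expansion, stride):
    return stride != 1 or inplanes != planes * expansion

# resnet18 (expansion=1): only layer2..layer4 project the identity
print([needs_downsample(c, p, 1, s) for c, p, s in
       [(64, 64, 1), (64, 128, 2), (128, 256, 2), (256, 512, 2)]])
# [False, True, True, True]

# resnet50 (expansion=4): even layer1 needs a 1x1 projection, 64 -> 256
print(needs_downsample(64, 64, 4, 1))  # True
```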

# Bottleneck is the structure used in resnet50 and deeper: 1x1, 3x3, 1x1
class Bottleneck(nn.Module):
    expansion = 4
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        # Intermediate code omitted, except the width computation that
        # conv1/conv2 below depend on:
        width = int(planes * (base_width / 64.)) * groups
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
    # Same as BasicBlock
    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
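A quick back-of-the-envelope sketch (conv_params is a hypothetical helper) of why the 1x1-3x3-1x1 bottleneck is used in deeper models: squeezing to width 64 before the 3x3 makes the block far cheaper than two 3x3 convolutions at the full 256-channel width (weight counts only; bias is disabled in these convs):

```python
# Weight count of a k x k conv layer, no bias
def conv_params(cin, cout, k):
    return cin * cout * k * k

basic = 2 * conv_params(256, 256, 3)       # two 3x3 convs at 256 channels
bottleneck = (conv_params(256, 64, 1)      # 1x1 squeeze: 256 -> 64
              + conv_params(64, 64, 3)     # cheap 3x3 at width 64
              + conv_params(64, 256, 1))   # 1x1 expand: 64 -> 256 (x4)
print(basic, bottleneck)  # 1179648 69632
```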

I'm already very familiar with ResNet, so it's hard for me to guess where readers will get stuck; I've only annotated the key spots in the code. Reading the code alongside the diagram I drew above should make it easier to follow.
If you have questions, feel free to leave a comment and I'll reply when I see it.
