1 resnet簡(jiǎn)介
關(guān)于resnet,網(wǎng)上有大量的文章講解其原理和思路,簡(jiǎn)單來說,resnet巧妙地利用了shortcut連接,解決了深度網(wǎng)絡(luò)中模型退化的問題。
2 論文中的結(jié)構(gòu)如下

2.1 參考pytorch中的實(shí)現(xiàn),自己畫了一個(gè)網(wǎng)絡(luò)圖,包含了每一層的參數(shù)和輸出

PS:經(jīng)評(píng)論區(qū)@字里行間_yan提醒,原始圖片中部分描述有歧義,已更正。一般來說,特征圖的尺寸變化應(yīng)表述為上采樣和下采樣,通道數(shù)的變化才是升維和降維。
2020/4/17更新
本來自己隨便寫寫,感覺看到這篇文章的人挺多,回來填坑
- 增加了pytorch中的代碼解讀。
- 修復(fù)了圖中參數(shù)k的標(biāo)識(shí)錯(cuò)誤(1x1卷積 k=1)
- 新增SVG文件下載地址鏈接(雖然文件簡(jiǎn)單,下載后還請(qǐng)?jiān)诒疚狞c(diǎn)個(gè)贊鼓勵(lì)下): 鏈接: https://pan.baidu.com/s/183ReRQMJXt2yUkhExnezBA 提取碼: kmf2
3 pytorch中的resnet
3.1 代碼文件
完整代碼文件可在pytorch官方文檔查到,地址:https://pytorch.org/docs/stable/_modules/torchvision/models/resnet.html#resnet18
順便說一句,torchvision里面實(shí)現(xiàn)了大部分經(jīng)典的網(wǎng)絡(luò),分類分割檢測(cè)都有,還包含了常用的一些數(shù)據(jù)庫加載等,對(duì)于剛?cè)腴T的同學(xué)來說會(huì)省很多事。
3.2 代碼閱讀
個(gè)人習(xí)慣從模型調(diào)用開始看,首先看調(diào)用
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Instantiate a ResNet and, if requested, load its pretrained weights.

    arch:       key into ``model_urls`` identifying the checkpoint.
    block:      residual block class (BasicBlock or Bottleneck).
    layers:     number of blocks per stage, e.g. [3, 4, 6, 3].
    pretrained: when True, download and load the published state dict.
    progress:   show a download progress bar.
    """
    net = ResNet(block, layers, **kwargs)
    if not pretrained:
        return net
    # Fetch the checkpoint matching this architecture and load it in place.
    weights = load_state_dict_from_url(model_urls[arch],
                                       progress=progress)
    net.load_state_dict(weights)
    return net
def resnet50(pretrained=False, progress=True, **kwargs):
    """Build a ResNet-50: Bottleneck blocks with 3/4/6/3 blocks per stage."""
    stage_sizes = [3, 4, 6, 3]
    return _resnet('resnet50', Bottleneck, stage_sizes, pretrained, progress,
                   **kwargs)
可以看到resnet至少需要兩個(gè)顯式的參數(shù),分別是block和layers。這里的block就是論文里提到的resnet18和resnet50中應(yīng)用的兩種不同結(jié)構(gòu)。layers就是網(wǎng)絡(luò)層數(shù),也就是每個(gè)block的個(gè)數(shù),在前文圖中也有體現(xiàn)。
然后看網(wǎng)絡(luò)結(jié)構(gòu),代碼略長(zhǎng),為了閱讀體驗(yàn)就直接截取了重要部分以及在代碼中注釋,建議配合完整代碼閱讀。
class ResNet(nn.Module):
    """ResNet backbone (excerpt from torchvision; parts marked as omitted
    below — the stem convolution, weight initialization, etc. — are not
    shown in this article)."""
    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        # More parameters than the factory call passes in; the model has been
        # updated since the paper was first published.
        # block: BasicBlock or Bottleneck, discussed later in the article
        # layers: number of blocks per stage, e.g. resnet50 uses layers=[3,4,6,3]
        # num_classes: number of classes in the dataset
        # zero_init_residual: small trick from a later paper — initialize the
        #   residual branch's BN weights to zero
        # groups: grouped convolution, presumably added for the ResNeXt extension
        # width_per_group: same as above; also enables the WideResNet extension
        # replace_stride_with_dilation: dilated (atrous) convolution, not part
        #   of the original paper
        # norm_layer: the paper uses BN; made customizable here
        # (middle of the code omitted — only the model assembly is shown)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        # (middle of the code omitted)
    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage (self.layerN) out of `blocks` residual blocks."""
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Trade stride for dilation: keep the spatial resolution and
            # enlarge the receptive field instead of downsampling.
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Used when the feature map is downsampled or the channel counts
            # of the shortcut and the main path do not match.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )
        layers = []
        # The first block of every stage needs the downsample branch, so it is
        # appended separately — this matches the `1` in range() below.
        # See below for the definition of `block`.
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))
        return nn.Sequential(*layers)
    def _forward_impl(self, x):
        """Forward pass: stem -> 4 stages -> global average pool -> classifier."""
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
然后是論文中的block
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Return a bias-free 3x3 convolution whose padding equals the dilation,
    so the spatial size is preserved when stride == 1 (BN supplies the bias)."""
    conv_cfg = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        dilation=dilation,
        bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_cfg)
def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 (pointwise) convolution, used to change the
    channel count and/or downsample the shortcut branch."""
    pointwise = nn.Conv2d(in_planes, out_planes,
                          kernel_size=1, stride=stride, bias=False)
    return pointwise
#用在resnet18中的結(jié)構(gòu),也就是兩個(gè)3x3卷積
class BasicBlock(nn.Module):
    """Residual block used in ResNet-18/34: two 3x3 convolutions plus an
    identity shortcut."""

    expansion = 1  # output channels == planes (no widening)
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        # inplanes: input channel count
        # planes: output channel count
        # downsample: optional module projecting the shortcut to match `out`
        # base_width, dilation, norm_layer: kept for API parity with Bottleneck
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            # Default to BatchNorm as in torchvision; the article's excerpt
            # omitted this line, which made the default norm_layer=None crash.
            norm_layer = nn.BatchNorm2d
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Keep the input for the residual addition below.
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            # When the spatial size shrinks or the channel count grows, the
            # shortcut must be projected so the addition is shape-compatible.
            identity = self.downsample(x)
        out += identity  # the paper's core idea — the residual connection
        out = self.relu(out)
        return out
#bottleneck是應(yīng)用在resnet50及其以上的結(jié)構(gòu),主要是1x1,3x3,1x1
class Bottleneck(nn.Module):
    """Residual block used in ResNet-50 and deeper: a 1x1 reduce, 3x3, 1x1
    expand sequence plus an identity shortcut."""

    expansion = 4  # the last 1x1 conv expands channels to planes * 4
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            # Default to BatchNorm as in torchvision (omitted in the excerpt).
            norm_layer = nn.BatchNorm2d
        # Width of the middle 3x3 conv; equals `planes` for the defaults
        # (groups=1, base_width=64). The article's excerpt omitted this line,
        # leaving `width` below as a NameError.
        width = int(planes * (base_width / 64.)) * groups
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Same structure as BasicBlock.forward, with three convolutions.
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            # Project the shortcut when shape or channels change.
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out
由于我已經(jīng)很熟悉resnet,不太清楚讀者會(huì)在哪些地方遇到閱讀困難,所以只在代碼中把一些關(guān)鍵的地方注釋了,代碼可以跟我畫的那個(gè)圖結(jié)合起來看,更容易理解。
如有問題歡迎留言,我看到就會(huì)回復(fù)。