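I want to make some changes to the transformer and see how they affect the results, so over the next few days I will first read through the code of Deformable DETR.
Let's start with the contents of models: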

The models directory
File position_encoding.py
"""
Various positional encodings for the transformer.
"""
import math
import torch
from torch import nn
from util.misc import NestedTensor
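Among the imports, only NestedTensor needs explaining. It is a wrapper that bundles a batch of tensors with the corresponding mask, so that both can be transformed together (for example, moved to a device). It is defined as follows: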
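from typing import Optional
from torch import Tensor

class NestedTensor(object):
    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors  # padded batch of images or feature maps
        self.mask = mask        # mask marking the padded positions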
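To make the role of the mask concrete, here is a minimal sketch of how two images of different sizes could be padded into one batch tensor plus a padding mask. The helper name `pad_to_batch` is hypothetical and written only for illustration (the repository has its own utilities in util.misc); the convention that True marks padding is an assumption taken from the `not_mask = ~mask` line in the position encoding code.

```python
import torch

def pad_to_batch(images):
    """Pad a list of (C, H, W) tensors to a common size and build a padding mask.

    Hypothetical helper for illustration; True in the mask marks padded pixels.
    """
    max_h = max(img.shape[1] for img in images)
    max_w = max(img.shape[2] for img in images)
    channels = images[0].shape[0]
    batch = torch.zeros(len(images), channels, max_h, max_w)
    mask = torch.ones(len(images), max_h, max_w, dtype=torch.bool)
    for idx, img in enumerate(images):
        _, h, w = img.shape
        batch[idx, :, :h, :w] = img   # copy the real pixels
        mask[idx, :h, :w] = False     # real pixels are not masked
    return batch, mask

images = [torch.rand(3, 4, 6), torch.rand(3, 5, 3)]
batch, mask = pad_to_batch(images)
print(batch.shape)  # torch.Size([2, 3, 5, 6])
print(mask.shape)   # torch.Size([2, 5, 6])
```

The pair (batch, mask) is what a NestedTensor would carry as `tensors` and `mask`.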
The class PositionEmbeddingSine, which builds the positional encoding with the sine formula
class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats  # encoding length per position (per axis)
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        mask = tensor_list.mask
        assert mask is not None
        not_mask = ~mask
        # cumulative sum along the height (dim 1): the y index of each valid position
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        # cumulative sum along the width (dim 2): the x index of each valid position
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale

        # indices along the feature dimension
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # sin on even channels, cos on odd channels, interleaved back together
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
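The two cumsum calls are easy to misread, so here is a small sketch of what they produce on a toy mask. The 2x3 mask below is made up for illustration; it assumes, as in the code above, that True marks padded positions.

```python
import torch

# One image of size 2x3 whose last column is padding (True = padded).
mask = torch.tensor([[[False, False, True],
                      [False, False, True]]])
not_mask = ~mask

# Running count of valid positions down each column: the 1-based row index.
y_embed = not_mask.cumsum(1, dtype=torch.float32)
# Running count of valid positions along each row: the 1-based column index.
x_embed = not_mask.cumsum(2, dtype=torch.float32)

print(y_embed[0].tolist())  # [[1.0, 1.0, 0.0], [2.0, 2.0, 0.0]]
print(x_embed[0].tolist())  # [[1.0, 2.0, 2.0], [1.0, 2.0, 2.0]]
```

Padded positions do not advance the count, so the last valid value along each axis (used as the denominator when normalize is True) equals the size of the valid region.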
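The sine/cosine construction in PositionEmbeddingSine comes from the formula in "Attention Is All You Need":

$$PE_{(pos,\,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$$

In the code, pos corresponds to x_embed or y_embed, 10000 to temperature, and d_model to num_pos_feats. In the end every position gets its own encoding of length num_pos_feats along each axis, and the y and x encodings are concatenated into 2 * num_pos_feats channels.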
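As a worked example of the formula, the following is a scalar, pure-Python sketch for a single coordinate. The function name `sine_encoding` and the small default num_pos_feats=8 are assumptions chosen for readability; the interleaving of sin on even channels and cos on odd channels is meant to mirror the stack/flatten step in forward.

```python
import math

def sine_encoding(position, num_pos_feats=8, temperature=10000):
    """Encode one scalar coordinate; illustrative sketch, not the repository code."""
    encoding = []
    for k in range(num_pos_feats):
        # Channels 2i and 2i+1 share the same frequency.
        dim_t = temperature ** (2 * (k // 2) / num_pos_feats)
        angle = position / dim_t
        encoding.append(math.sin(angle) if k % 2 == 0 else math.cos(angle))
    return encoding

enc0 = sine_encoding(0.0)
enc1 = sine_encoding(1.0)
print(enc0)      # [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
print(enc1[:2])  # sin(1) and cos(1)
```

Low channels vary quickly with the coordinate and high channels vary slowly, which is what lets nearby and distant positions both be told apart.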
Besides this, there is also a learnable positional encoding:
class PositionEmbeddingLearned(nn.Module):
    """
    Absolute pos embedding, learned.
    """
    def __init__(self, num_pos_feats=256):
        super().__init__()
        self.row_embed = nn.Embedding(50, num_pos_feats)
        self.col_embed = nn.Embedding(50, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.uniform_(self.row_embed.weight)
        nn.init.uniform_(self.col_embed.weight)

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        h, w = x.shape[-2:]
        i = torch.arange(w, device=x.device)
        j = torch.arange(h, device=x.device)
        x_emb = self.col_embed(i)
        y_emb = self.row_embed(j)
        pos = torch.cat([
            x_emb.unsqueeze(0).repeat(h, 1, 1),
            y_emb.unsqueeze(1).repeat(1, w, 1),
        ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
        return pos
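The nn.Embedding layer was new to me. Here nn.Embedding(50, num_pos_feats) creates 50 learnable embedding vectors, each of length num_pos_feats. The 50 is presumably chosen to suit the actual setting, for example as an upper bound on the height and width of the feature map. In forward, the layer selects the embedding vectors at the given indices. For example, x_emb = self.col_embed(i) picks the embedding vectors that the indices in i point to.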
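The lookup behaviour can be checked directly. The sketch below uses a small num_pos_feats=4 and a 2x3 feature map, both made up for illustration, and rebuilds the per-image grid the same way forward does (without the batch dimension).

```python
import torch
from torch import nn

num_pos_feats = 4
col_embed = nn.Embedding(50, num_pos_feats)  # 50 learnable vectors of length 4
row_embed = nn.Embedding(50, num_pos_feats)

h, w = 2, 3
x_emb = col_embed(torch.arange(w))  # shape (w, num_pos_feats)
y_emb = row_embed(torch.arange(h))  # shape (h, num_pos_feats)

# Calling the layer with indices is a row lookup in its weight matrix.
same_as_weight = torch.equal(x_emb, col_embed.weight[:w])

# Broadcast column codes over rows and row codes over columns, then concatenate.
pos = torch.cat([
    x_emb.unsqueeze(0).repeat(h, 1, 1),
    y_emb.unsqueeze(1).repeat(1, w, 1),
], dim=-1).permute(2, 0, 1)

print(same_as_weight)  # True
print(pos.shape)       # torch.Size([8, 2, 3])
```

Since the table only has 50 rows, a feature map with more than 50 rows or columns would index past the end of the table, which is why the 50 has to cover the largest feature map in use.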
The unified interface for the two encoding schemes:
def build_position_encoding(args):
    N_steps = args.hidden_dim // 2
    if args.position_embedding in ('v2', 'sine'):
        # TODO find a better way of exposing other arguments
        position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
    elif args.position_embedding in ('v3', 'learned'):
        position_embedding = PositionEmbeddingLearned(N_steps)
    else:
        raise ValueError(f"not supported {args.position_embedding}")

    return position_embedding
The file backbone.py mainly relies on the class torchvision.models._utils.IntermediateLayerGetter, which wraps a model and returns the outputs of selected layers. Its code is as follows:
from collections import OrderedDict
from torch import nn

# Notes on the docstring below:
# - model is the network to pick layers from, e.g. a resnet.
# - return_layers maps a layer name in the model to the new name of its output.
# - Only directly registered children can be queried. In a resnet, nested layers
#   have names such as layer1.1.conv1; you can query layer1 but not layer1.1.
class IntermediateLayerGetter(nn.ModuleDict):
    """
    Module wrapper that returns intermediate layers from a model
    It has a strong assumption that the modules have been registered
    into the model in the same order as they are used.
    This means that one should **not** reuse the same nn.Module
    twice in the forward if you want this to work.
    Additionally, it is only able to query submodules that are directly
    assigned to the model. So if `model` is passed, `model.feature1` can
    be returned, but not `model.feature1.layer2`.
    Arguments:
        model (nn.Module): model on which we will extract the features
        return_layers (Dict[name, new_name]): a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).
    """
    def __init__(self, model, return_layers):
        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
            raise ValueError("return_layers are not present in model")

        orig_return_layers = return_layers
        return_layers = {k: v for k, v in return_layers.items()}
        layers = OrderedDict()
        # keep children in registration order, up to the last requested layer
        for name, module in model.named_children():
            layers[name] = module
            if name in return_layers:
                del return_layers[name]
            if not return_layers:
                break

        super(IntermediateLayerGetter, self).__init__(layers)
        self.return_layers = orig_return_layers

    def forward(self, x):
        out = OrderedDict()
        for name, module in self.named_children():
            x = module(x)
            if name in self.return_layers:
                out_name = self.return_layers[name]
                out[out_name] = x
        return out
For example:
import torch
import torchvision

m = torchvision.models.resnet50(pretrained=True)
new_m = torchvision.models._utils.IntermediateLayerGetter(m, {'layer1': '1', 'layer2': '2'})
out = new_m(torch.rand(1, 3, 224, 224))
print([(k, v.shape) for k, v in out.items()])
You can see that the outputs have been renamed, and that their contents are taken from layer1 and layer2 of resnet50.
The function that builds the backbone:
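def build_backbone(args):
    # positional encoding: one vector per position, shape N x d x H x W
    position_embedding = build_position_encoding(args)
    # True: train the backbone; False: keep the pretrained parameters frozen
    train_backbone = args.lr_backbone > 0
    # whether to return intermediate layers; if True, layers 2, 3 and 4 are returned
    return_interm_layers = args.masks or (args.num_feature_levels > 1)
    backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
    # returns the output of each selected layer together with its positional encoding
    model = Joiner(backbone, position_embedding)
    return model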
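These two files mainly define the backbone, which provides the image features produced by the CNN together with the corresponding positional encodings. Next, we turn to the transformer.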