mmdetection組件構(gòu)成與注冊表分析

mmdetection 使用模塊化設(shè)計(jì),將一般的目標(biāo)檢測算法分成了幾個(gè)不同的模塊,在使用時(shí)只需要在配置文件中聲明各個(gè)模塊使用的組件名稱和參數(shù),就可以像搭建積木一樣搭建一個(gè)完整的目標(biāo)檢測模型;

基本組件

mmdetection的組件大多數(shù)以類的形式定義:

  • BACKBONES 對應(yīng)目標(biāo)檢測模型的主干網(wǎng)絡(luò),用以對圖片進(jìn)行特征抽取.如常用的Resnet,ResNeXt,HRNet等.
  • NECKS 對主干網(wǎng)絡(luò)產(chǎn)生的特征圖做一些特定的處理,最常見的就是fpn多尺度抽取信息.現(xiàn)有(FPN,BFP,HRFPN等)
  • Heads 目標(biāo)檢測的頭部,包含了目標(biāo)檢測的主要算法邏輯,包括bbox的產(chǎn)生,回歸target的計(jì)算,loss的計(jì)算等
  • LOSS 損失函數(shù)的定義
  • DETECTOR 前面所介紹的組件搭建而成的一個(gè)整體,通過加載detector來運(yùn)行整體算法
  • PIPELINES 數(shù)據(jù)增強(qiáng)管道類.定義了數(shù)據(jù)預(yù)處理和后處理部分

mmdetection中提供了類似注冊表的實(shí)現(xiàn)方式,對各個(gè)組件進(jìn)行注冊和使用:
首先我們來看Registry類的定義:
mmdet/utils/registry.py

class Registry(object):
    #初始化name是什么組件,組件里面是一個(gè)dict,保存name跟它的具體類
    def __init__(self, name):
        self._name = name
        self._module_dict = dict()

    def __repr__(self):
        format_str = self.__class__.__name__ + '(name={}, items={})'.format(
            self._name, list(self._module_dict.keys()))
        return format_str

    @property
    def name(self):
        return self._name

    @property
    def module_dict(self):
        return self._module_dict

    def get(self, key):
        return self._module_dict.get(key, None)

    #把組件類與類名注冊到注冊表中,方便從config文件構(gòu)建類
    def _register_module(self, module_class):
        """Register a module.

        Args:
            module (:obj:`nn.Module`): Module to be registered.
        """
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, but got {}'.format(
                type(module_class)))
        module_name = module_class.__name__
        if module_name in self._module_dict:
            raise KeyError('{} is already registered in {}'.format(
                module_name, self.name))
        self._module_dict[module_name] = module_class

    def register_module(self, cls):
        self._register_module(cls)
        return cls

我們看到Registry類其實(shí)底層保存一個(gè)dict,用于保存組件名字跟具體的類.方便從注冊表中找到相應(yīng)的類進(jìn)行初始化.接著,定義了全局注冊表:
mmdet/models/registry.py

BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head')
LOSSES = Registry('loss')
DETECTORS = Registry('detector')

我們來看,注冊表如何使用:
如果我們自定義了一個(gè)resnet的backbone類,我們將這樣使用Registry類的register_module裝飾函數(shù),將resnet注冊到BACKBONES注冊表中;

@BACKBONES.register_module
class ResNet(nn.Module):

那么我們該如何從config中構(gòu)建起一個(gè)類呢:

#mmdet/models/builder.py
def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)

def build_backbone(cfg):
    return build(cfg, BACKBONES)

##mmdet/utils/registry.py
def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        obj: The constructed object.
    """
    #type即注冊表中類名字,代表了要從注冊表中根據(jù)type的name來獲得類
    assert isinstance(cfg, dict) and 'type' in cfg
    assert isinstance(default_args, dict) or default_args is None
    args = cfg.copy()
    obj_type = args.pop('type')
    if mmcv.is_str(obj_type):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError('{} is not in the {} registry'.format(
                obj_type, registry.name))
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError('type must be a str or valid type, but got {}'.format(
            type(obj_type)))
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    ##進(jìn)行類的實(shí)例化,并傳入config中的參數(shù)
    return obj_cls(**args)

build_from_cfg函數(shù)的作用是,根據(jù)config文件中的type與傳入的注冊表來獲取需要實(shí)例化的具體類,然后再將config中的參數(shù)傳入類初始化函數(shù)中,得到一個(gè)實(shí)例化的組件類.

resnet為例,整體流程如下所示:

  • (1)resnet類編寫完成后,用@BACKBONES.register_module裝飾器將自身注冊到BACKBONES注冊表中.

  • (2)在config中定義backbone,并指明了具體參數(shù)

#config/faster_rcnn_r50_fpn_1x.py
backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        style='pytorch')
  • (3)通過build_from_cfg()函數(shù),傳入的分別是backbone這個(gè)dict和BACKBONES注冊表類
  • (4)通過'type'為ResNet找到resnet的類,并初始化參數(shù)depth,num_stages,out_indices.

mmdetection這樣通過注冊表的方式實(shí)現(xiàn)了數(shù)據(jù)與實(shí)現(xiàn)的分離;能更好地對組件進(jìn)行抽象.

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容