Pytorch 實(shí)現(xiàn) MobileNet V3 模型,并從 TensorFlow 轉(zhuǎn)化預(yù)訓(xùn)練參數(shù)

????????隨著移動(dòng)終端的普及,以及在其上運(yùn)行深度學(xué)習(xí)模型的需求,神經(jīng)網(wǎng)絡(luò)小型化越來(lái)越得到重視和關(guān)注,已經(jīng)成為研究的熱門之一。作為小型化模型的經(jīng)典代表,MobileNet 系列模型已經(jīng)先后迭代了 3 代,在保持模型參數(shù)量和運(yùn)算量都極其小的情況下,其性能越來(lái)越優(yōu)異。本文我們將實(shí)現(xiàn)最新一代的 MobileNet V3,為了能不花費(fèi)時(shí)間在 ImageNet 數(shù)據(jù)集上訓(xùn)練而直接使用,我們將從 TensorFlow 官方實(shí)現(xiàn)的 MobileNet V3 上轉(zhuǎn)化預(yù)訓(xùn)練參數(shù)。

????????本文將重點(diǎn)關(guān)注以下兩個(gè)方面:

  • 詳細(xì)解讀 MobileNet V3 的網(wǎng)絡(luò)結(jié)構(gòu);
  • 詳細(xì)講述從 TensorFlow 轉(zhuǎn)化預(yù)訓(xùn)練參數(shù)的方法;

????????本文所有代碼見 GitHub: mobilenet_v3

一、MobileNet V3 模型

二、模型實(shí)現(xiàn)

三、預(yù)訓(xùn)練參數(shù)轉(zhuǎn)化

????????完全采用手動(dòng)指定的方式進(jìn)行,即對(duì)于 Pytorch 模型的每一參數(shù),從對(duì)應(yīng)的 TensorFlow 預(yù)訓(xùn)練參數(shù)里取出,然后賦值給它即可。為了保證轉(zhuǎn)化的準(zhǔn)確性,我們的目標(biāo)是:

  • 原 TensorFlow 預(yù)訓(xùn)練模型和轉(zhuǎn)化后的 Pytorch 模型的預(yù)測(cè)結(jié)果要絕對(duì)一致

以下,我們?cè)敿?xì)的來(lái)描述怎么從 TensorFlow 轉(zhuǎn)化預(yù)訓(xùn)練參數(shù)。

1.查看 TensorFlow 預(yù)訓(xùn)練模型參數(shù)名

????????首先到 此頁(yè) 下載 MobileNet V3 模型的 TensorFlow 預(yù)訓(xùn)練模型,下載后請(qǐng)解壓。我們以 large dm=1 (float) 預(yù)訓(xùn)練模型為例來(lái)說(shuō)明。首先,使用如下代碼:

import json
import tensorflow as tf

if __name__ == '__main__':
    checkpoint_path = 'xxx/v3-large_224_1.0_float/ema/model-540000'
    output_path = './mobilenet_v3_large.json'

    reader = tf.train.NewCheckpointReader(checkpoint_path)
    weights = {var: 1 for (var, _) in
               reader.get_variable_to_shape_map().items()}
    
    with open(output_path, 'w') as writer:
        json.dump(weights, writer)

預(yù)訓(xùn)練模型中的所有參數(shù)名都寫到一個(gè) json 文件里,為了不把高維的數(shù)據(jù)寫進(jìn)去,我們都將確切的值改成了 1。但直接寫進(jìn)去的內(nèi)容很亂,可以借助 json 串格式化的工具(比如,在線格式化,或者 Google Chrome 瀏覽器插件 FeHelper)將 mobilenet_v3_large.json 文件里的內(nèi)容格式化,這樣你看到的形式就大概如下了:

格式化之后的 mobilenet_v3_large.json 內(nèi)容

接著,結(jié)合 TensorFlow 官方開源的 MobileNet V3 Large 模型的網(wǎng)絡(luò)定義

MobileNet V3 large 模型 TensorFlow 網(wǎng)絡(luò)定義

就基本可以知道整個(gè)模型參數(shù)命名的具體名字和順序了:

MobilenetV3/Conv/
MobilenetV3/expanded_conv/
MobilenetV3/expanded_conv_1/
...
MobilenetV3/expanded_conv_14/
MobilenetV3/Conv_1/
MobilenetV3/Conv_2/
MobilenetV3/Logits/Conv2d_1c_1x1

以上是 large 模型的總共 19 個(gè)大的命名空間(scope),每個(gè) / 之后會(huì)接小的命名空間。對(duì)于普通的卷積層,比如 Conv, Conv_1, Conv_2, Logits/Conv2d_1c_1x1 你要關(guān)注兩個(gè)東西:

  • 是否有偏置參數(shù):biases
  • 是否有批標(biāo)準(zhǔn)化:BatchNorm

這既可以幫助你修正你定義的 Pytorch 模型,也可以在轉(zhuǎn)化賦值的時(shí)候防止被遺忘。類似的思想可以直接移植到復(fù)雜的模塊 mbv3_op 對(duì)應(yīng)的命名空間,expanded_conv, expanded_conv_1, ...。舉個(gè)簡(jiǎn)單的例子,看 large 模型的第一卷積層:MobilenetV3/Conv/,因?yàn)樵搶邮褂昧伺鷺?biāo)準(zhǔn)化(batch normalization),因此是沒(méi)有偏置參數(shù)的,那么就只有如下的 5 個(gè)參數(shù):

MobilenetV3/Conv/weights,
MobilenetV3/BatchNorm/beta,
MobilenetV3/Conv/BatchNorm/gamma
MobilenetV3/Conv/BatchNorm/moving_mean
MobilenetV3/Conv/BatchNorm/moving_variance

其中后 4 個(gè)參數(shù)對(duì)應(yīng)于批標(biāo)準(zhǔn)化的公式:
\gamma \frac{x - \mu}{\sigma} + \beta. \\
再看 large 模型的最后一個(gè)卷積層(分類層):MobilenetV3/Logits/Conv2d_1c_1x1,因?yàn)樵搶記](méi)有使用批標(biāo)準(zhǔn)化的正規(guī)化函數(shù),因此帶有偏置項(xiàng),就只有兩個(gè)參數(shù):

MobilenetV3/Logits/Conv2d_1c_1x1/weights
MobilenetV3/Logits/Conv2d_1c_1x1/biases

至于其他復(fù)雜模塊,分割開單獨(dú)考慮中間命名空間: project, expand, depthwise, squeeze_excite 之后,其實(shí)就是簡(jiǎn)單的卷積層了,因此也很容易處理。

2.查看 Pytorch 模型結(jié)構(gòu)

????????這一步更容易,直接實(shí)例化定義的 Pytorch 模型,然后打印出來(lái)(這里,模型的所有的層都定義在了屬性 _layers 里,見 mobilenet_v3.MobileNet 類):

import mobilenet_v3

large = mobilenet_v3.large()
print(large._layers[:10])
print(large._layers[10:])

因?yàn)槟P徒Y(jié)構(gòu)很長(zhǎng),所以打印的時(shí)候分成了前后兩部分。保存在 txt 文件里如下:

MobileNet V3 large 模型網(wǎng)絡(luò)結(jié)構(gòu)

這一步我們唯一需要關(guān)注的就是每一層在網(wǎng)絡(luò)結(jié)構(gòu)里的下標(biāo)了,比如 _layers[0] 就是整個(gè)網(wǎng)絡(luò)的第 1 個(gè)卷積層模塊,而 _layers[0]._layers[0] 是這個(gè)模塊內(nèi)的二維卷積層,_layers[0]._layers[1] 是這個(gè)模塊內(nèi)的批標(biāo)準(zhǔn)化層。因?yàn)?torch.nn.Sequential 的行為和 list 一樣,因此它們的順序是確定不變的,取下標(biāo)是非常安全的操作。

3.對(duì)照參數(shù)名逐一賦值

????????經(jīng)過(guò)前面兩步之后,應(yīng)該對(duì) TensorFlow 預(yù)訓(xùn)練模型 和 Pytorch 定義的模型結(jié)構(gòu) 之間的對(duì)應(yīng)關(guān)系應(yīng)該有所印象了,下面需要將它們嚴(yán)格的對(duì)應(yīng)起來(lái),以便預(yù)訓(xùn)練參數(shù)轉(zhuǎn)化。

????????首先,看第一個(gè)卷積模塊,它包含一個(gè)卷積層、批標(biāo)準(zhǔn)化層和一個(gè)激活函數(shù)層,其中只有前兩者是有訓(xùn)練參數(shù)的。而且,根據(jù)第一步,我們知道對(duì)應(yīng)的 TensorFlow 模型這一個(gè)模塊的命名空間是:MobilenetV3/Conv/,因此如果我聲明了

import mobilenet_v3

model = mobilenet_v3.large()

large 模型,那么對(duì)應(yīng)的第 1 個(gè)卷積模塊的二維卷積層是 model._layers[0]._layers[0],批標(biāo)準(zhǔn)化層是 model._layers[0]._layers[1]。它們所含有的參數(shù)如下:

model._layers[0]._layers[0].weight
model._layers[0]._layers[1].bias:
model._layers[0]._layers[1].weight
model._layers[0]._layers[1].running_mean
model._layers[0]._layers[1].running_var

即卷積層的權(quán)重參數(shù)(對(duì)于 slim.conv2d(),如果指定了正規(guī)化函數(shù),即關(guān)鍵字參數(shù) normalizer_fn 不為 None,那么這個(gè)卷積層是沒(méi)有偏置項(xiàng)的;反之,則有,除非將偏置的初始化函數(shù) biases_initializer 設(shè)為 None),和批標(biāo)準(zhǔn)化層的 4 個(gè)參數(shù):
\gamma \frac{x - \mu}{\sigma} + \beta. \\
很容易的,你可以從 mobilenet_v3_large.json 里找到對(duì)應(yīng)的 TensorFlow 變量名:

conversion_map_for_root_block = {
    model._layers[0]._layers[0].weight: 
        'MobilenetV3/Conv/weights',
    model._layers[0]._layers[1].bias: 
        'MobilenetV3/Conv/BatchNorm/beta',
    model._layers[0]._layers[1].weight:
        'MobilenetV3/Conv/BatchNorm/gamma',
    model._layers[0]._layers[1].running_mean: 
        'MobilenetV3/Conv/BatchNorm/moving_mean',
    model._layers[0]._layers[1].running_var: 
        'MobilenetV3/Conv/BatchNorm/moving_variance',
}

然后用函數(shù) tf.train.load_variable,按照 TensorFlow 的變量名從預(yù)訓(xùn)練模型中取出變量的名字賦值給對(duì)應(yīng)的 Pytorch 變量,比如:

checkpoint_path = 'xxx/v3-large_224_1.0_float/ema/model-540000'

tf_param = tf.train.load_variable(checkpoint_path, 'MobilenetV3/Conv/weights')
tf_param = np.transpose(tf_param, (3, 2, 0, 1))
model._layers[0]._layers[0].weight.data = torch.from_numpy(tf_param)

就將第 1 個(gè)卷積層的參數(shù)轉(zhuǎn)化好了。這里,唯一需要注意的是,TensorFlow 權(quán)重的順序是 [kernel_size, kernel_size, in_channels, out_channels],而 Pytorch 的順序是 [out_channels, in_channels, kernel_size, kernel_size],因此要將它們的順序調(diào)整到一致。

????????其它參數(shù)完全按照一樣的方式轉(zhuǎn)化即可。完整的轉(zhuǎn)化代碼請(qǐng)見 converter.py。

????????以上過(guò)程結(jié)束之后,我們來(lái)轉(zhuǎn)化幾個(gè)模型

1.large 模型

????????執(zhí)行(tf_checkpoint_path 參數(shù)指定 TensorFlow 預(yù)訓(xùn)練模型參數(shù)的保存路徑):

python3 tf_weights_to_pth.py --tf_checkpoint_path xxx/v3-large_224_1.0_float/ema/model-540000

將在當(dāng)前項(xiàng)目路徑下生成一個(gè) pretrained_models 文件夾,里面保存了轉(zhuǎn)化后的模型:mobilenet_v3_large.pth,同時(shí)將輸出測(cè)試圖片(熊貓圖片):

panda.jpg

的分類結(jié)果:

large 模型 TensorFlow 原預(yù)訓(xùn)練模型和轉(zhuǎn)化的 Pytorch 模型對(duì)熊貓圖片的識(shí)別結(jié)果

可以看到兩者的結(jié)果是一模一樣的。類似的,再指定另一張測(cè)試圖片(貓圖片),執(zhí)行以下命令(image_path 參數(shù)指定測(cè)試圖片的路徑):

python3 tf_weights_to_pth.py --tf_checkpoint_path xxx/v3-large_224_1.0_float/ema/model-540000 \
    --image_path ./test/cat.jpg
cat.jpg

就可以看到對(duì)貓的分類結(jié)果:

large 模型 TensorFlow 原預(yù)訓(xùn)練模型和轉(zhuǎn)化的 Pytorch 模型對(duì)貓圖片的識(shí)別結(jié)果

顯然,TensorFlow 官方和本文實(shí)現(xiàn)的 Pytorch 模型的預(yù)測(cè)結(jié)果也是一模一樣的。

2.small 模型(depth_multiplier = 0.75)

執(zhí)行(output_name 指定轉(zhuǎn)化來(lái)的模型的保存名字,depth_multiplier 指定卷積層的通道數(shù)乘子,model_name 指定轉(zhuǎn)化的模型名):

python3 tf_weights_to_pth.py --tf_checkpoint_path xxx/v3-small_224_0.75_float/ema/model-497500 \
    --output_name mobilenet_v3_small_0.75.pth --depth_multiplier 0.75 --model_name small

得到熊貓圖片的分類結(jié)果:

small-dm=0.75 模型 TensorFlow 原預(yù)訓(xùn)練模型和轉(zhuǎn)化的 Pytorch 模型對(duì)熊貓圖片的識(shí)別結(jié)果

也得到一模一樣的結(jié)果,說(shuō)明轉(zhuǎn)化參數(shù)是正確的。

????????當(dāng)前支持參數(shù)轉(zhuǎn)化的預(yù)訓(xùn)練模型如下:

本文所有支持參數(shù)轉(zhuǎn)化的預(yù)訓(xùn)練模型

對(duì)應(yīng)的模型名(由 model_name 參數(shù)指定)分別為:large, small, large_minimalistic, small_minimalistic,如果 dm=0.75,請(qǐng)指定參數(shù) depth_multiplier。你可以逐一轉(zhuǎn)化并驗(yàn)證本文定義的 MobileNet V3 模型的正確性,不出意外應(yīng)該是準(zhǔn)確的(作者未轉(zhuǎn)化 8-bit 的預(yù)訓(xùn)練模型)。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容