使用MxNet新接口Gluon提供的預訓練模型進行微調(diào)

1. 導入各種包

from mxnet import gluon
import mxnet as mx
from mxnet.gluon import nn
from mxnet import ndarray as nd
import matplotlib.pyplot as plt
import cv2
from mxnet import image
from mxnet import autograd

2. 導入數(shù)據(jù)

我使用cifar10這個數(shù)據(jù)集,使用gluon自帶的模塊下載到本地并且為了配合后面的網(wǎng)絡,我將大小調(diào)整到224*224

def transform(data, label):
    data = image.imresize(data, 224, 224)
    return data.astype('float32'), label.astype('float32')
cifar10_train = gluon.data.vision.CIFAR10(root='./',train=True, transform=transform)
cifar10_test = gluon.data.vision.CIFAR10(root='./',train=False, transform=transform)
batch_size = 64
train_data = gluon.data.DataLoader(cifar10_train, batch_size, shuffle=True)
test_data = gluon.data.DataLoader(cifar10_test, batch_size, shuffle=False)

3. 加載預訓練模型

gluon提供的很多預訓練模型,我選擇一個簡單的模型AlexNet
首先下載AlexNet模型和模型參數(shù)
使用下面的代碼會獲取AlexNet的模型并且加載預訓練好的模型參數(shù),但是鑒于網(wǎng)絡的原因,我提前下好了

alexnet = mx.gluon.model_zoo.vision.alexnet(pretrained=True)#如果pretrained值為True,則會下載預訓練參數(shù),否則是空模型

獲取模型并從本地加載參數(shù)

alexnet = mx.gluon.model_zoo.vision.alexnet()
alexnet.load_params('alexnet-44335d1f.params',ctx=mx.gpu())

看下AlexNet網(wǎng)絡結構,發(fā)現(xiàn)分為兩部分,features,classifier,而features正好是需要的

print(alexnet)
AlexNet(
  (features): HybridSequential(
    (0): Conv2D(64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): MaxPool2D(size=(3, 3), stride=(2, 2), padding=(0, 0), ceil_mode=False)
    (2): Conv2D(192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (3): MaxPool2D(size=(3, 3), stride=(2, 2), padding=(0, 0), ceil_mode=False)
    (4): Conv2D(384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): Conv2D(256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): Conv2D(256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): MaxPool2D(size=(3, 3), stride=(2, 2), padding=(0, 0), ceil_mode=False)
    (8): Flatten
  )
  (classifier): HybridSequential(
    (0): Dense(4096, Activation(relu))
    (1): Dropout(p = 0.5)
    (2): Dense(4096, Activation(relu))
    (3): Dropout(p = 0.5)
    (4): Dense(1000, linear)
  )
)

4. 組合新的網(wǎng)絡

截取想要的features,并且固定參數(shù)。這樣防止訓練的時候把預訓練好的參數(shù)給搞壞了

featuresnet = alexnet.features
for _, w in featuresnet.collect_params().items():
    w.grad_req = 'null'

自己定義后面的網(wǎng)絡,因為數(shù)據(jù)集是10類,就把最后的輸出從1000改成了10。

def Classifier():
    net = nn.HybridSequential()
    net.add(nn.Dense(4096, activation="relu"))
    net.add(nn.Dropout(.5))
    net.add(nn.Dense(4096, activation="relu"))
    net.add(nn.Dropout(.5))
    net.add(nn.Dense(10))
    return net

接著需要把兩部分組合起來,并且對第二部分機進行初始化

net = nn.HybridSequential()
with net.name_scope():
    net.add(featuresnet)
    net.add(Classifier())
    net[1].collect_params().initialize(init=mx.init.Xavier(),ctx=mx.gpu())
net.hybridize()

5. 訓練

最后就是訓練了,看看效果如何

#定義準確率函數(shù)
def accuracy(output, label):
    return nd.mean(output.argmax(axis=1)==label).asscalar()
def evaluate_accuracy(data_iterator, net, ctx=mx.gpu()):
    acc = 0.
    for data, label in data_iterator:
        data = data.transpose([0,3,1,2])
        data = data/255
        output = net(data.as_in_context(ctx))
        acc += accuracy(output, label.as_in_context(ctx))
    return acc / len(data_iterator)
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
trainer = gluon.Trainer(
    net.collect_params(), 'sgd', {'learning_rate': 0.01})
for epoch in range(1):
    train_loss = 0.
    train_acc = 0.
    test_acc = 0.
    for data, label in train_data:
        label = label.as_in_context(mx.gpu())
        data = data.transpose([0,3,1,2])
        data = data/255
        with autograd.record():
            output = net(data.as_in_context(mx.gpu()))
            loss = softmax_cross_entropy(output, label)
        loss.backward()
        trainer.step(batch_size)

        train_loss += nd.mean(loss).asscalar()
        train_acc += accuracy(output, label)
    test_acc = evaluate_accuracy(test_data, net)
    print("Epoch %d. Loss: %f, Train acc %f, Test acc %f" % (
        epoch, train_loss/len(train_data), 
        train_acc/len(train_data),test_acc))
Epoch 0. Loss: 1.249197, Train acc 0.558764, Test acc 0.696756
最后編輯于
?著作權歸作者所有,轉載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

相關閱讀更多精彩內(nèi)容

  • Android 自定義View的各種姿勢1 Activity的顯示之ViewRootImpl詳解 Activity...
    passiontim閱讀 179,112評論 25 709
  • 首頁 資訊 文章 資源 小組 相親 登錄 注冊 首頁 最新文章 IT 職場 前端 后端 移動端 數(shù)據(jù)庫 運維 其他...
    Helen_Cat閱讀 4,153評論 1 10
  • 聲明:作者翻譯論文僅為學習,如有侵權請聯(lián)系作者刪除博文,謝謝! 翻譯論文匯總:https://github.com...
    SnailTyan閱讀 12,749評論 1 27
  • 王同學的第三十篇亂寫 犯了錯誤不可怕,自己意識到了,改了就好了,可怕的是知道這么做不好,還知道應該如何去做,但仍舊...
    王雪梅2017閱讀 145評論 0 0
  • 你總能在平常生活中看到某些人或事,你可以在這些人或事身上學到些什么,你要多動動你的腦袋想想,如果發(fā)生在你身上你應該...
    小木易楊閱讀 241評論 0 0

友情鏈接更多精彩內(nèi)容