用PyTorch實現(xiàn)MNIST手寫數(shù)字識別

image

MNIST可以說是機器學習入門的hello word了!導師一般第一個就讓你研究MNIST,研究透了,也算基本入門了。好的,今天就來扯一扯學一學。

image

在本文中,我們將在PyTorch中構建一個簡單的卷積神經(jīng)網(wǎng)絡,并使用MNIST數(shù)據(jù)集訓練它識別手寫數(shù)字。在MNIST數(shù)據(jù)集上訓練分類器可以看作是圖像識別的“hello world”。

MNIST包含70,000張手寫數(shù)字圖像: 60,000張用于培訓,10,000張用于測試。圖像是灰度的,28x28像素的,并且居中的,以減少預處理和加快運行。

設置環(huán)境

在本文中,我們將使用PyTorch訓練一個卷積神經(jīng)網(wǎng)絡來識別MNIST的手寫數(shù)字。PyTorch是一個非常流行的深度學習框架,比如Tensorflow、CNTK和caffe2。但是與其他框架不同的是,PyTorch具有動態(tài)執(zhí)行圖,這意味著計算圖是動態(tài)創(chuàng)建的。

先去官網(wǎng)上根據(jù)指南在PC上裝好PyTorch環(huán)境,然后引入庫。
import torch

準備數(shù)據(jù)集

導入就緒后,我們可以繼續(xù)準備將要使用的數(shù)據(jù)。但在那之前,我們將定義超參數(shù),我們將使用的實驗。在這里,epoch的數(shù)量定義了我們將循環(huán)整個訓練數(shù)據(jù)集的次數(shù),而learning_rate和momentum是我們稍后將使用的優(yōu)化器的超參數(shù)。
n_epochs = 3
對于可重復的實驗,我們必須為任何使用隨機數(shù)產(chǎn)生的東西設置隨機種子——如numpy和random! 

現(xiàn)在我們還需要數(shù)據(jù)集的dataloader。這就是TorchVision發(fā)揮作用的地方。它讓我們用一種方便的方式來加載MNIST數(shù)據(jù)集。我們將使用batch_size=64進行訓練,并使用size=1000對這個數(shù)據(jù)集進行測試。下面的Normalize()轉換使用的值0.1307和0.3081是MNIST數(shù)據(jù)集的全局平均值和標準偏差,這里我們將它們作為給定值。

TorchVision提供了許多方便的轉換,比如裁剪或標準化。
train_loader = torch.utils.data.DataLoader(
運行上面的程序后,會自動將數(shù)據(jù)集下載到目錄下的data文件夾。下載過程可能有點煩,經(jīng)??ㄗ〔粍?,只能多來幾遍。完成后就是這樣了:
image
除了數(shù)據(jù)集和批處理大小之外,PyTorch的DataLoader還包含一些有趣的選項。例如,我們可以使用num_workers > 1來使用子進程異步加載數(shù)據(jù),或者使用固定RAM(通過pin_memory)來加速RAM到GPU的傳輸。但是因為這些在我們使用GPU時很重要,我們可以在這里省略它們。

現(xiàn)在讓我們看一些例子。我們將為此使用test_loader。

讓我們看看一批測試數(shù)據(jù)由什么組成。
examples = enumerate(test_loader)

example_targets是圖片實際對應的數(shù)字標簽:

image

一批測試數(shù)據(jù)是一個形狀張量:

image
這意味著我們有1000個例子的28x28像素的灰度(即沒有rgb通道)。

我們可以使用matplotlib來繪制其中的一些
import matplotlib.pyplot as plt
image
好的,經(jīng)過一些訓練,這些應該不難識別。

構建網(wǎng)絡

現(xiàn)在讓我們開始建立我們的網(wǎng)絡。我們將使用兩個2d卷積層,然后是兩個全連接(或線性)層。作為激活函數(shù),我們將選擇整流線性單元(簡稱ReLUs),作為正則化的手段,我們將使用兩個dropout層。在PyTorch中,構建網(wǎng)絡的一個好方法是為我們希望構建的網(wǎng)絡創(chuàng)建一個新類。讓我們在這里導入一些子模塊,以獲得更具可讀性的代碼。
import torch.nn as nn
class Net(nn.Module):
**具體各部分的含義,在下面詳細講!**

廣義地說,我們可以想到torch.nn層中包含可訓練的參數(shù),而torch.nn.functional就是純粹的功能性。forward()傳遞定義了使用給定的層和函數(shù)計算輸出的方式。為了便于調(diào)試,在前向傳遞中打印出張量是完全可以的。在試驗更復雜的模型時,這就派上用場了。請注意,前向傳遞可以使用成員變量甚至數(shù)據(jù)本身來確定執(zhí)行路徑——它還可以使用多個參數(shù)!

現(xiàn)在讓我們初始化網(wǎng)絡和優(yōu)化器。
network = Net()
注意:如果我們使用GPU進行訓練,我們也應該使用例如network.cuda()將網(wǎng)絡參數(shù)發(fā)送給GPU。在將網(wǎng)絡參數(shù)傳遞給優(yōu)化器之前,將它們傳輸?shù)竭m當?shù)脑O備是很重要的,否則優(yōu)化器將無法以正確的方式跟蹤它們。

模型訓練

是時候建立我們的訓練循環(huán)了。首先,我們要確保我們的網(wǎng)絡處于訓練模式。然后,每個epoch對所有訓練數(shù)據(jù)進行一次迭代。加載單獨批次由DataLoader處理。

首先,我們需要使用optimizer.zero_grad()手動將梯度設置為零,因為PyTorch在默認情況下會累積梯度。然后,我們生成網(wǎng)絡的輸出(前向傳遞),并計算輸出與真值標簽之間的負對數(shù)概率損失?,F(xiàn)在,我們收集一組新的梯度,并使用optimizer.step()將其傳播回每個網(wǎng)絡參數(shù)。有關PyTorch自動漸變系統(tǒng)內(nèi)部工作方式的詳細信息,請參閱autograd的官方文檔(強烈推薦)。

我們還將使用一些打印輸出來跟蹤進度。為了在以后創(chuàng)建一個良好的培訓曲線,我們還創(chuàng)建了兩個列表來節(jié)省培訓和測試損失。在x軸上,我們希望顯示網(wǎng)絡在培訓期間看到的培訓示例的數(shù)量。
train_losses = []
在開始訓練之前,我們將運行一次測試循環(huán),看看僅使用隨機初始化的網(wǎng)絡參數(shù)可以獲得多大的精度/損失。你能猜出我們的準確度是多少嗎?
def train(epoch):
image
神經(jīng)網(wǎng)絡模塊以及優(yōu)化器能夠使用.state_dict()保存和加載它們的內(nèi)部狀態(tài)。這樣,如果需要,我們就可以繼續(xù)從以前保存的狀態(tài)dict中進行訓練——只需調(diào)用.load_state_dict(state_dict)。

現(xiàn)在進入測試循環(huán)。在這里,我們總結了測試損失,并跟蹤正確分類的數(shù)字來計算網(wǎng)絡的精度。
def test():
image
使用上下文管理器no_grad(),我們可以避免將生成網(wǎng)絡輸出的計算結果存儲在計算圖中。
是時候開始訓練了!我們將在循環(huán)遍歷n_epochs之前手動添加test()調(diào)用,以使用隨機初始化的參數(shù)來評估我們的模型。
test()

震驚了,我的電腦!!

image

運行結果:

image

評估模型的性能

就是這樣。僅僅經(jīng)過3個階段的訓練,我們已經(jīng)能夠達到測試集97%的準確率!我們開始使用隨機初始化的參數(shù),正如預期的那樣,在開始訓練之前,測試集的準確率只有10%左右。

我們來畫一下訓練曲線。
test()
image
從訓練曲線來看,看起來我們甚至可以繼續(xù)訓練幾個epoch!

但在此之前,讓我們再看看幾個例子,正如我們之前所做的,并比較模型的輸出。
examples = enumerate(test_loader)
image
我們的模型對這些例子的預測似乎是正確的!

檢查點的持續(xù)訓練

現(xiàn)在讓我們繼續(xù)對網(wǎng)絡進行訓練,或者看看如何從第一次培訓運行時保存的state_dicts中繼續(xù)進行訓練。我們將初始化一組新的網(wǎng)絡和優(yōu)化器。
continued_network = Net()
使用.load_state_dict(),我們現(xiàn)在可以加載網(wǎng)絡的內(nèi)部狀態(tài),并在最后一次保存它們時優(yōu)化它們。
network_state_dict = torch.load('model.pth')
同樣,運行一個訓練循環(huán)應該立即恢復我們之前的訓練。為了檢查這一點,我們只需使用與前面相同的列表來跟蹤損失值。由于我們?yōu)樗吹降挠柧毷纠臄?shù)量構建測試計數(shù)器的方式,我們必須在這里手動添加它。
for i in range(4,9):
image
太棒了!我們再次看到測試集的準確性從一個epoch到另一個epoch有了(運行更慢的,慢的多了)提高。讓我們用圖像來進一步檢查訓練進度。
fig = plt.figure()

(我沒跑出了,但)

這看起來仍然像一個相當平滑的學習曲線,就像我們最初要訓練8個epoch!請記住,我們只是將值添加到從第5個紅點開始的相同列表中。

由此我們可以得出兩個結論:

    1\. 從檢查點內(nèi)部狀態(tài)繼續(xù)按預期工作。
    2\. 我們似乎仍然沒有遇到過擬合問題!看起來我們的dropout層做了一個很好的規(guī)范模型。

總結

總之,我們使用PyTorch和TorchVision構建了一個新環(huán)境,并使用它從MNIST數(shù)據(jù)集中對手寫數(shù)字進行分類,希望使用PyTorch開發(fā)出一個良好的直覺。對于進一步的信息,官方的PyTorch文檔確實寫得很好,論壇也很活躍!
?著作權歸作者所有,轉載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

相關閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容