在介紹softmax回歸的實(shí)現(xiàn)前我們先引入一個多類圖像分類數(shù)據(jù)集。它將在后面的章節(jié)中被多次使用,以方便我們觀察比較算法之間在模型精度和計算效率上的區(qū)別。圖像分類數(shù)據(jù)集中最常用的是手寫數(shù)字識別數(shù)據(jù)集MNIST[1]。但大部分模型在MNIST上的分類精度都超過了95%。為了更直觀地觀察算法之間的差異,我們將使用一個圖像內(nèi)容更加復(fù)雜的數(shù)據(jù)集Fashion-MNIST[2](這個數(shù)據(jù)集也比較小,只有幾十M,沒有GPU的電腦也能吃得消)。
本節(jié)我們將使用torchvision包,它是服務(wù)于PyTorch深度學(xué)習(xí)框架的,主要用來構(gòu)建計算機(jī)視覺模型。torchvision主要由以下幾部分構(gòu)成:
torchvision.datasets: 一些加載數(shù)據(jù)的函數(shù)及常用的數(shù)據(jù)集接口;
torchvision.models: 包含常用的模型結(jié)構(gòu)(含預(yù)訓(xùn)練模型),例如AlexNet、VGG、ResNet等;
torchvision.transforms: 常用的圖片變換,例如裁剪、旋轉(zhuǎn)等;
torchvision.utils: 其他的一些有用的方法。
3.5.1 獲取數(shù)據(jù)集
首先導(dǎo)入本節(jié)需要的包或模塊。
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
sys.path.append("..") # 為了導(dǎo)入上層目錄的d2lzh_pytorch
import d2lzh_pytorch as d2l
在Anaconda 中安裝
conda install matplotlib #安裝畫圖包
conda install tqdm #安裝tadm包 是在文件中需要
conda install -c pytorch torchtext #安裝torchtext文件
下面,我們通過torchvision的torchvision.datasets來下載這個數(shù)據(jù)集。第一次調(diào)用時會自動從網(wǎng)上獲取數(shù)據(jù)。我們通過參數(shù)train來指定獲取訓(xùn)練數(shù)據(jù)集或測試數(shù)據(jù)集(testing data set)。測試數(shù)據(jù)集也叫測試集(testing set),只用來評價模型的表現(xiàn),并不用來訓(xùn)練模型。
另外我們還指定了參數(shù)transform = transforms.ToTensor()使所有數(shù)據(jù)轉(zhuǎn)換為Tensor,如果不進(jìn)行轉(zhuǎn)換則返回的是PIL圖片。transforms.ToTensor()將尺寸為 (H x W x C) 且數(shù)據(jù)位于[0, 255]的PIL圖片或者數(shù)據(jù)類型為np.uint8的NumPy數(shù)組轉(zhuǎn)換為尺寸為(C x H x W)且數(shù)據(jù)類型為torch.float32且位于[0.0, 1.0]的Tensor。
注意: 由于像素值為0到255的整數(shù),所以剛好是uint8所能表示的范圍,包括transforms.ToTensor()在內(nèi)的一些關(guān)于圖片的函數(shù)就默認(rèn)輸入的是uint8型,若不是,可能不會報錯但可能得不到想要的結(jié)果。所以,如果用像素值(0-255整數(shù))表示圖片數(shù)據(jù),那么一律將其類型設(shè)置成uint8,避免不必要的bug。 本人就被這點(diǎn)坑過,詳見ta的這個博客2.2.4節(jié)。