寫此文原因
網(wǎng)上其實有不少關(guān)于pytorch自定義數(shù)據(jù)集的tutorial,但是之所以要寫這個,是因為我發(fā)現(xiàn)他們并沒有結(jié)合一兩個的神經(jīng)網(wǎng)絡(luò)來講解。所以我覺得再寫一個tutorial講解關(guān)于如何讀取任意的數(shù)據(jù)集,并且讓某個網(wǎng)絡(luò)訓(xùn)練該數(shù)據(jù)集還是有必要的。
在初學(xué)pytorch的時候,我們一般使用的是pytorch自帶的一些數(shù)據(jù)集,比如 (代碼參考1)
from torchvision.datasets.mnist import MNIST
...
data_train = MNIST('./data/mnist',
download=True,
transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()]))
....
data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True, num_workers=8)
引入MNIST數(shù)據(jù)集。最初始的訓(xùn)練網(wǎng)絡(luò)是Lenet-5識別MNIST里面的數(shù)字。這就導(dǎo)致當(dāng)你面對很多JPG, PNG的格式的torchvision.datasets里沒有的圖像時,不知道怎么讀取他們。這篇文章會帶領(lǐng)大家讀取自定義的數(shù)據(jù)集并訓(xùn)練他們。
最后的lenet5代碼自定義數(shù)據(jù)集的實現(xiàn)請在我的github下載
https://github.com/zhaozhongch/Pytorch_Lenet5_CustomDataset
內(nèi)容
下面我們從網(wǎng)上下載PNG格式的MNIST數(shù)據(jù)集。
git clone https://github.com/myleott/mnist_png.git
cd mnist_png
tar -xvf mnist_png.tar.gz #解壓文件夾
解壓之后在minst_png/mnist_png文件夾里你會看到testing和training兩個文件夾,進(jìn)入testing你會看到10個文件夾分別儲存數(shù)字為0~9的圖片。下面我們簡單實現(xiàn)Lenet-5網(wǎng)絡(luò)來識別圖片中的數(shù)字。
Lenet5網(wǎng)絡(luò)如下圖

途中范例給的輸入圖片是32X32,實際我們上面的PNG圖片大小是28X28,網(wǎng)絡(luò)其他結(jié)構(gòu)依次減小即可。
輸入圖片1通道28X28,輸入給第一層
第一層卷積層,卷積核大小5X5,輸出圖像6通道,24X24,卷積之后接激勵函數(shù)ReLU
第二層池化層,使用平均池化,池化核大小2X2,輸出圖像6通道,12X12
第三層卷積層,卷積核大小還是5X5,輸出圖像16通道,大小8X8。之后再接ReLu
第四層再接2X2池化。輸出16通道,4X4大小圖片。
第五層全連接層,先把16X4X4的圖片"展平"為線性向量,再通過線性變換把圖片"展平"為120維的變量,接ReLu
第六層再把120維降為84維,接ReLu
第七層再降為10維(對應(yīng)0~9 10種數(shù)字可能性)輸出。
講解Lenet5并不是本文的重點,所以簡單的說了上面的網(wǎng)絡(luò)結(jié)構(gòu)后我們就給出網(wǎng)絡(luò)實現(xiàn),對于細(xì)節(jié)不熟悉的新手可以參考文章1。
根據(jù)上面的網(wǎng)絡(luò)結(jié)構(gòu),網(wǎng)絡(luò)在pytorch中的實現(xiàn)如下
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1,6,5)
self.pool = nn.AvgPool2d(2,2)
self.conv2 = nn.Conv2d(6,16,5)
self.linear1 = nn.Linear(16*4*4, 120)
self.linear2 = nn.Linear(120,84)
self.linear3 = nn.Linear(84,10)
def forward(self,x):
x = self.conv1(x)
x = F.relu(x)
x = self.pool(x)
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16*4*4)
x = F.relu(self.linear1(x))
x = F.relu(self.linear2(x))
x = self.linear3(x)
return x
理論上來說是很簡單的。
那么針對網(wǎng)絡(luò)對輸入數(shù)據(jù)的要求,我們應(yīng)該怎么把最開始下載的一堆圖片輸入進(jìn)去呢?這就要用到pytorch里的Dataset類了。
你需要定義一個類,繼承Dataset類,然后類里必須包含3個函數(shù)__init__,__len__,__getitem__,具體結(jié)構(gòu)如下
class ReadDataset(Dataset):
def __init__(self, 參數(shù)...):
def __len__(self, 參數(shù)...):
...
return 數(shù)據(jù)長度
def __getitem__(self, 參數(shù)...):
...
return 字典
__len__需要返回一個表示數(shù)據(jù)長度的整型量,__getitem__需要返回一個字典。ReadDataset這個類名是自定義的,繼承了Dataset即可。
接下來的過程,我們先簡單過一遍得到結(jié)果,再回看為什么這么做。
為了處理MNIST dataset,我們先把training文件夾里的圖像label讀取進(jìn)來
data_length = 60000
data_label = [-1] * data_length
prev_dir = './mnist_png/mnist_png/training/'
after_dir = '.png'
for id in range(10):
id_string = str(id)
for filename in glob(prev_dir + id_string +'/*.png'):
position = filename.replace(prev_dir+id_string+'/', '')
position = position.replace(after_dir, '')
data_label[int(position)] = id
這幾行代碼的作用,是把training文件夾里的10個文件夾里的共計60000張圖片放入到data_label里。舉個例子,圖片編號為21的圖,包含的數(shù)字是0(在training文件夾的0文件夾里),那么data_label[21] = 0。
接下來定義繼承了Dataset類的ReadDataset類,具體如下。
class ReadDataset(Dataset):
def __init__(self, imgs_dir, data_label):
self.imgs_dir = imgs_dir
self.ids = data_label
def __len__(self):
return len(self.ids)
def __getitem__(self, i):
idx = self.ids[i]
imgs_file = self.imgs_dir+ str(idx) + '/' + str(i) + '.png'
img = Image.open(imgs_file).convert('L')
img = np.array(img)
img = img.reshape(1,28,28)
if img.max() > 1:
img = img / 255
return {'image': torch.from_numpy(img), 'label': torch.tensor(idx)}
可以看到,構(gòu)造函數(shù)__init__里我們有兩個參數(shù),一個是imgs_dir,圖像地址,另一個是我們之前創(chuàng)建的列表data_label,賦值給self.ids. __len__()僅僅是返回了data_label的長度。
有趣的是__getitem__函數(shù),我們看到這個函數(shù)的參數(shù)是i,傳入了i之后,我們首先根據(jù)ids找到它對應(yīng)的圖像里所標(biāo)識的數(shù)字,繼而根據(jù)
imgs_file = self.imgs_dir+ str(idx) + '/' + str(i) + '.png'
img = Image.open(imgs_file).convert('L')
找到圖像并轉(zhuǎn)化為黑白。之后再轉(zhuǎn)化為np,再reshape。原圖像讀進(jìn)來本來是28X28,但是根據(jù)網(wǎng)絡(luò)的要求,輸入需要是圖像通道數(shù)X圖像尺寸,黑白圖片通道為1,所以我們reshape為1X28X28。最后圖像的像素點的灰度值歸一化到0到1.因為我們要使用cross entropy代價函數(shù)來訓(xùn)練,根據(jù)官網(wǎng),要求cross entropy的矩陣輸入的值為0到1。返回的內(nèi)容格式必須是字典,我們這兒字典的內(nèi)容圖像和圖像內(nèi)對應(yīng)的數(shù)字(label)是
{'image': torch.from_numpy(img), 'label': torch.tensor(idx)}
這個getitem函數(shù)如果調(diào)用,最終達(dá)到的目的就是,假如我在代碼中輸入A = __getitem__(0),我就應(yīng)該能得到0.png對應(yīng)的那張圖像,獲取圖像的方式就是A['image'],獲取圖像是數(shù)字幾的方式是A['label']。
有了上面的內(nèi)容作為鋪墊,我們看看主函數(shù)里讀取數(shù)據(jù)的具體操作。首先有下面一行內(nèi)容
prev_dir = './mnist_png/mnist_png/testing/'
...
all_data = ReadDataset(prev_dir, data_label)
我們把prev_dir和之前得到的data_label作為參數(shù)傳入了ReadDataset并返回了all_data。有的人可能說,誒,我沒看到ReadDataset有返回值呀。這是因為這些寫在了Dataset這個類里,不然繼承它干什么呢。隨后,我們把這個返回值賦值給DataLoader,就可以定義從torchvision里自帶的MNIST dataset一樣的操作了。
test_loader = DataLoader(all_data, batch_size=batch_size, shuffle=True, num_workers=4)
定義好batch_size,num_workers,代價函數(shù)這些之后,我們就可以在訓(xùn)練的時候使用返回值test_loader了。
with torch.no_grad():
for data in test_loader:
images = data['image']
labels = data['label']
...
我們可以看到其實我們并沒有顯式地調(diào)用__getitem__函數(shù),而是通過data遍歷test_loader, data會自動根據(jù)ReadDataset里ids的長度,從1到ids.length來批量讀取圖像。如果你設(shè)置了batch_size等于4,那么for循環(huán)的第一次循環(huán),會調(diào)用__getitem__四次,data['image']就會返回__getitem__,return {'image':...}中image所對應(yīng)的內(nèi)容。
設(shè)置代價函數(shù)這些不是本文的內(nèi)容,就不細(xì)講了。具體的可參見github代碼。
可能大家看了上面的例子還是有些不明不白,因為雖然ReadDataset這個類的內(nèi)容就是定義三個函數(shù),但是這三個函數(shù)具體的內(nèi)容是什么,就需要根據(jù)實際情況確定了。我們上面的數(shù)據(jù)集的圖像是分別儲存在0~9個文件夾中,其他的數(shù)據(jù)可能不是這么儲存的,就需要想新的辦法獲得那個data_label列表。但是你的最終目的是很明白的,
1:getitem所返回的內(nèi)容,需要能輸入到網(wǎng)絡(luò)里,比如我們的
images = data['image']
...
outputs = net(images.float())
2: 根據(jù)0到ids的長度的indx,能遍歷你想要使用的所有圖像。
假想你顯式調(diào)用__getitem__(0) ,你需要能獲得名字為0.png或者0.jpg之類的圖像的內(nèi)容。
說這些不如多看兩個例子再自己實踐一下。上面的lenet5的例子之外,我在github里分別分開寫了CPU的方法和GPU的方法,當(dāng)然其實就一兩行代碼的事兒。不過考慮到這還是屬于接近新手范疇的tutorial,就分開寫了。
另外我還在github代碼里提供了稍微復(fù)雜的網(wǎng)絡(luò)UNET的實現(xiàn),UNET是用來做語義分割的網(wǎng)絡(luò),不熟悉的同學(xué)可以自行看下語義分割是什么blabla。在UNET的這個網(wǎng)絡(luò)里,我同樣是讀取的自定義的數(shù)據(jù)集而不是使用torchvision.dataset里帶的數(shù)據(jù)集。代碼放于github
https://github.com/zhaozhongch/Pytorch_UNET_MultiObjects
當(dāng)下用得最多的pytroch的UNET的實現(xiàn)還只是一個物體分割的(參考此處
),我順便拔高了一下實現(xiàn)多個物體的語義分割了,不過最后語義分割的效果圖不是非常好,因為懶得花時間去仔細(xì)fine tune了。但我相信作為tutorial級別的代碼,我覺得跑一遍熟悉一下網(wǎng)絡(luò)結(jié)構(gòu)怎么定義,怎么自定義數(shù)據(jù)集,已經(jīng)很夠了。覺得還不錯的可以給這倆倉庫點個小star哈哈。
關(guān)于這兩個網(wǎng)絡(luò)實現(xiàn)或者其他內(nèi)容不懂的同學(xué)歡迎私信。