
第三章 基于Pytorch的深度學(xué)習(xí)開發(fā)
前面章節(jié)我們已經(jīng)了解tensor及其操作,這章主要就是學(xué)習(xí)如何用Pytorch進(jìn)行基礎(chǔ)深度學(xué)習(xí)應(yīng)用開發(fā)。
本章主要內(nèi)容
- 創(chuàng)建一深度學(xué)習(xí)模型
- 通過普通模式進(jìn)行訓(xùn)練
- 測試模型效果
- 同時(shí)對超參數(shù)進(jìn)行微調(diào)以提高模型精度和速度
- 部署模型到生產(chǎn)環(huán)境中
在整個過程中的每一步,都將提供相應(yīng)的代碼參考及一些小建議。
在后面的篇章中,本書也會提供其它更加復(fù)雜模型來闡述一個完整的深度學(xué)習(xí)結(jié)構(gòu)中各個環(huán)節(jié)的技巧及意義,如自主化,優(yōu)化,加速,分布式訓(xùn)練及高效部署方案等等。
本章主要以一個基本的神經(jīng)網(wǎng)絡(luò)模型結(jié)構(gòu)展開討論的,即達(dá)到以小見大的作用。
整體流程
雖然每個人所構(gòu)建的模型結(jié)構(gòu)各不相同,但整個流程卻是一樣的。即不管是監(jiān)督學(xué)習(xí),非監(jiān)督學(xué)習(xí)或是半監(jiān)督學(xué)習(xí),整個流程還都是 訓(xùn)練,測試,部署。
- 在訓(xùn)練過程中,每次epoch后,我們都用驗(yàn)證數(shù)據(jù)(validation data)去驗(yàn)證一下我們的模型并微調(diào)超參數(shù),使模型能達(dá)到一個比較好的泛化能力。
- 最后通過測試數(shù)據(jù)(unseen data)進(jìn)行模型評價(jià),驗(yàn)證模型是否具有泛化性。
- 深度學(xué)習(xí)模型開發(fā)的最后一步就是模型的部署,包括服務(wù)器或智能終端部署。

數(shù)據(jù)預(yù)處理(Data Preparation)
-
數(shù)據(jù)加載 Data Loading
Pytorch內(nèi)置了數(shù)據(jù)處理的幾個類及工具,諸如Dataset,DataLoader和Sampler類。
- Dataset類提供怎么從文件或數(shù)據(jù)源里獲取和預(yù)處理數(shù)據(jù)的方法
- Sampler類提供了如何進(jìn)行數(shù)據(jù)采樣并形成批量數(shù)據(jù)
- DataLoader類聯(lián)合了Dataset和Sampler進(jìn)行數(shù)據(jù)分批次的迭代提取
Pytorch的一些包,如Torchvision和Torchtext等也都提供了一些不錯數(shù)據(jù)加載處理方法。
torchvison.datasets就提供了很多的子類用于加載一些不錯的圖片數(shù)據(jù)(CIFAR-10,MNIST等)。
示例:CIFAR10數(shù)據(jù)加載及預(yù)覽
from torchvision.datasets import CIFAR10
from PIL import Image
train_data = CIFAR10(root='./train/', train=True, download=True)
print(train_data)
"""
Dataset CIFAR10
Number of datapoints: 50000
Root location: ./train/
Split: Train
"""
print(len(train_data), train_data.data.shape, train_data.classes, train_data.class_to_idx)
print(train_data[0])
"""
50000
(50000, 32, 32, 3)
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
{'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
(<PIL.Image.Image image mode=RGB size=32x32 at 0x2577B5677F0>, 6)
"""
test_data = CIFAR10(root='./test/', train=False, download=True)
print(test_data)
"""
Dataset CIFAR10
Number of datapoints: 10000
Root location: ./test/
Split: Test
"""
print(len(test_data))
print(test_data.data.shape)
"""
10000
(10000, 32, 32, 3)
"""
查看一下圖片數(shù)據(jù)

- 數(shù)據(jù)轉(zhuǎn)換 Data Transforms
一般情況下,我們都需要將原始數(shù)據(jù)進(jìn)行轉(zhuǎn)換后,才能輸出到Pytorch的處理流程中,即轉(zhuǎn)換為torch形式數(shù)據(jù)。
這些轉(zhuǎn)換操作基本上都是通過transforms包里提供的各類方法進(jìn)行一系列處理操作。
示例:對CIFAR10數(shù)據(jù)進(jìn)行一些預(yù)處理操作
from torchvision.datasets import CIFAR10
from torchvision import transforms
# 定義一系列transform算子,隨機(jī)裁剪,水平翻轉(zhuǎn)等
train_transforms = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(
mean=(0.4914, 0.4822, 0.4465),
std=(0.2023, 0.1994, 0.2010)
)
])
# 通過參數(shù)tranform指定數(shù)據(jù)預(yù)處理操作
train_data = CIFAR10(root="./train/",
train=True,
download=True,
transform=train_transforms
)
print(train_data)
"""
Dataset CIFAR10
Number of datapoints: 50000
Root location: ./train/
Split: Train
StandardTransform
Transform: Compose(
RandomCrop(size=(32, 32), padding=4)
RandomHorizontalFlip(p=0.5)
ToTensor()
Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.201))
)
"""
# torch.Size([3, 32, 32])
print(train_data[0][0].size())
print(train_data[0])
"""
(tensor([[[-2.4291, -2.4291, -2.4291, ..., -2.4291, -2.4291, -2.4291],
[-2.4291, -2.4291, -2.4291, ..., -2.4291, -2.4291, -2.4291],
[-2.4291, -2.4291, -2.4291, ..., -2.4291, -2.4291, -2.4291],
...,
[ 1.1959, 0.4981, 0.0522, ..., 1.1765, -0.1610, -2.4291],
[ 1.3122, 0.8276, 0.4981, ..., -0.0253, -1.0527, -2.4291],
[ 1.4673, 1.1765, 0.9051, ..., -1.3435, -1.7894, -2.4291]],
[[-2.4183, -2.4183, -2.4183, ..., -2.4183, -2.4183, -2.4183],
[-2.4183, -2.4183, -2.4183, ..., -2.4183, -2.4183, -2.4183],
[-2.4183, -2.4183, -2.4183, ..., -2.4183, -2.4183, -2.4183],
...,
[ 0.1188, -0.4516, -0.8646, ..., 0.6301, -0.7269, -2.4183],
[ 0.2564, -0.0189, -0.2352, ..., -0.5892, -1.4742, -2.4183],
[ 0.5318, 0.4924, 0.3154, ..., -1.8479, -2.0446, -2.4183]],
[[-2.2214, -2.2214, -2.2214, ..., -2.2214, -2.2214, -2.2214],
[-2.2214, -2.2214, -2.2214, ..., -2.2214, -2.2214, -2.2214],
[-2.2214, -2.2214, -2.2214, ..., -2.2214, -2.2214, -2.2214],
...,
[-1.7141, -1.7336, -1.5580, ..., -0.4460, -1.2849, -2.2214],
[-1.9092, -1.8507, -1.5385, ..., -1.2654, -1.7141, -2.2214],
[-1.7922, -1.7531, -1.6751, ..., -2.0458, -2.0458, -2.2214]]]), 6)
"""
# 對于測試數(shù)據(jù)同樣需要做相應(yīng)的處理,至少也要將圖像數(shù)據(jù)轉(zhuǎn)換為Tensor形式
test_transforms = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(
(0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010))
])
test_data = CIFAR10(
root="./test/",
train=False,
transform=test_transforms)
print(test_data)
"""
Dataset CIFAR10
Number of datapoints: 10000
Root location: ./test/
Split: Test
StandardTransform
Transform: Compose(
ToTensor()
Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.201))
)
"""
對比預(yù)處理前后照片效果圖

從圖中我們可以發(fā)現(xiàn)前后數(shù)據(jù)發(fā)生了一些變化,可能對于我們來說看不懂,但經(jīng)過這樣的處理后,模型的預(yù)測效率有不錯的提升。
%matplotlib inline
import matplotlib.pyplot as plt
from torchvision.datasets import CIFAR10
from torchvision import transforms
# 加載數(shù)據(jù),不做任何預(yù)處理
train_data_ori = CIFAR10(root='./train/', train=True, download=True)
# 定義一系列transform算子,隨機(jī)裁剪,水平翻轉(zhuǎn)等
train_transforms = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(
mean=(0.4914, 0.4822, 0.4465),
std=(0.2023, 0.1994, 0.2010)
)
])
# 通過參數(shù)tranform指定數(shù)據(jù)預(yù)處理操作
train_data_trans = CIFAR10(root="./train/",
train=True,
download=True,
transform=train_transforms
)
data_ori, label_ori = train_data_ori[0]
data_trans, label_trans = train_data_trans[0]
# transform前后圖片對比
plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.imshow(data_ori)
plt.subplot(1, 2, 2)
# 將tensor轉(zhuǎn)換為PIL Image格式
image_trans = transforms.ToPILImage()(data_trans) # 自動轉(zhuǎn)換為0-255
plt.imshow(image_trans)
plt.tight_layout()
plt.show()
- 數(shù)據(jù)批量化 Data Batching
因?yàn)樵趯?shí)際模型訓(xùn)練中,受限于我們的的資源條件等因素,也不可能一下子把所有數(shù)據(jù)全部加載到內(nèi)存中進(jìn)行計(jì)算的,通常情況下都是會對訓(xùn)練數(shù)據(jù)進(jìn)行批量化處理,每批次進(jìn)行迭代訓(xùn)練模型,即在有限資源條件下快速訓(xùn)練模型效果。
這樣做不僅達(dá)到高效的模型訓(xùn)練效果,同時(shí)也能充分發(fā)揮GPUs的并行計(jì)算優(yōu)勢。
同樣數(shù)據(jù)批處理也十分簡單的,只要通過torch.utils.data.DataLoader類即可實(shí)現(xiàn)。
示例:對訓(xùn)練數(shù)據(jù)進(jìn)行批量化處理
下面的例子是有放回的批量化處理,每批次有16個樣本進(jìn)行模型訓(xùn)練計(jì)算。經(jīng)過DataLoader處理過的數(shù)據(jù)相當(dāng)于一個迭代器,可以通過next()及iter()方法進(jìn)行數(shù)據(jù)獲取next(iter(train_loader))。
train_loader = torch.utils.data.DataLoader(
train_data_trans,
batch_size=16,
shuffle=True
)
data_batch, labels_batch = next(iter(train_loader))
# 每批次的數(shù)據(jù)情況,16個 3*32*32
print(data_batch.size())
print(labels_batch.size())
"""
torch.Size([16, 3, 32, 32])
torch.Size([16])
"""
普通應(yīng)用中數(shù)據(jù)預(yù)處理(General Data Preparation)
torch.utils.data
前面的例子我們知道了圖片數(shù)據(jù)是如何通過torchvision這個包進(jìn)行加載,轉(zhuǎn)換和批處理。但對于我們實(shí)際應(yīng)用中的數(shù)據(jù),我們是要怎么來做這些預(yù)處理呢?Pytorch提供了一個內(nèi)置模塊來幫我們完成這些工作,即torch.utils.data。
Pytorch提供了對數(shù)據(jù)進(jìn)行映射和迭代類型的數(shù)據(jù)集類torch.utils.data.Dataset,通過繼承其并重寫相關(guān)函數(shù)方法實(shí)現(xiàn)數(shù)據(jù)的加載,處理及返回等預(yù)處理操作。
-
Dataset類
在實(shí)際應(yīng)用中,子類需要重寫getitem(),len()這兩個方法。- getitem(),通過一給定key取得樣本數(shù)據(jù)(數(shù)據(jù)及標(biāo)簽值)
- len(),返回?cái)?shù)據(jù)size
-
Sampler類
提供數(shù)據(jù)采樣器的方法,且這些采樣器一般不直接使用,而是直接內(nèi)嵌在數(shù)據(jù)加載器中,作為一參數(shù)配置一起使用的。
sampler DataLoader類
Dataset類返回包含數(shù)據(jù)和相關(guān)信息的數(shù)據(jù)對象,Sampler類以特定方式或隨機(jī)返回實(shí)際數(shù)據(jù)本身,DataLoader類就是把Dataset和Sampler類聯(lián)合起來構(gòu)建出一數(shù)據(jù)迭代器返回?cái)?shù)據(jù)。
torch.utils.data.DataLoader(
dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
multiprocessing_context=None,
generator=None)
在正常情況下,dataset,batch_size,shuffle和sampler這幾個參數(shù)比較常用到,其他一些參數(shù)主要是在特殊場景下使用,主要還是要根據(jù)實(shí)際應(yīng)用來選擇。期中num_workers主要是利用cpu的多核技術(shù)來并行處理生成數(shù)據(jù),提高效率。
If you write your own dataset class, all you need to do is call the built-in DataLoader to generate an iterable for your data. There is no need to create a dataloader class from scratch.
