快速上手Pytorch


這篇文章需要大家對(duì)深度學(xué)習(xí)里的神經(jīng)網(wǎng)絡(luò)訓(xùn)練有一定的基礎(chǔ),我以前訓(xùn)練網(wǎng)絡(luò)一直都是用的TensorFlow,后面需要把模型和數(shù)據(jù)遷移到Pytorch平臺(tái)上去,發(fā)現(xiàn)很多里面有很多知識(shí)點(diǎn)需要注意,寫(xiě)這篇文章一方面是給自己做個(gè)筆記,總結(jié)下自己的經(jīng)驗(yàn),另一方面是為了方便想要快速上手Pytorch的同學(xué)。這篇文章主要內(nèi)容有:

  • Tensorflow的PlayGround
  • Pytorch介紹和安裝
  • Torch和Torchvision里的常用包
  • Variable、Tensor、Numpy之間的關(guān)系
  • CPU與GPU
  • 示例--GAN生成MINIST數(shù)據(jù)

Tensorflow的PlayGround

PlayGround是一個(gè)在線演示、實(shí)驗(yàn)的神經(jīng)網(wǎng)絡(luò)平臺(tái),是一個(gè)入門(mén)神經(jīng)網(wǎng)絡(luò)非常直觀的網(wǎng)站。這個(gè)圖形化平臺(tái)非常強(qiáng)大,將神經(jīng)網(wǎng)絡(luò)的訓(xùn)練過(guò)程直接可視化。假若有的同學(xué)剛剛想入門(mén)深度學(xué)習(xí)這一領(lǐng)域,可以去看看:
PlayGround地址:http://playground.tensorflow.org
這里也有一篇PlayGround介紹寫(xiě)的非常詳細(xì)的文章:
參考地址:https://finthon.com/tensorflow-playground-nn/

Pytorch介紹和安裝


2017年1月,由Facebook人工智能研究院(FAIR)基于Torch推出了PyTorch。Pytorch和Torch底層實(shí)現(xiàn)都用的是C語(yǔ)言,但是Torch的調(diào)用需要掌握Lua語(yǔ)言,相比而言使用Python的人更多,根本不是一個(gè)數(shù)量級(jí),所以Pytorch基于Torch做了些底層修改、優(yōu)化并且支持Python語(yǔ)言調(diào)用。
它是一個(gè)基于Python的可續(xù)計(jì)算包,目標(biāo)用戶有兩類:

  1. 使用GPU來(lái)運(yùn)算numpy
  2. 一個(gè)深度學(xué)習(xí)平臺(tái),提供最大的靈活型和速度

如何安裝Pytorch呢?

  • 基礎(chǔ)環(huán)境
    一臺(tái)PC設(shè)備、一張高性能NVIDIA顯卡(可選)、Ubuntu系統(tǒng)
  • 安裝步驟
  1. Anaconda(可選)和Python
  2. 顯卡驅(qū)動(dòng)和CUDA
  3. 運(yùn)行Pytorch的安裝命令

Torch和Torchvision里的常用包

Torch

  • torch:張量相關(guān)的運(yùn)算,例如創(chuàng)建、索引、切片、連接轉(zhuǎn)置、加減乘除

  • torch.nn:包含搭建網(wǎng)絡(luò)層的模塊(Modules)和一系列的loss函數(shù),例如全連接、卷積、池化、BN批處理、dropout、CrossEntropyLoss、MSELoss

  • torch.nn.functional:常用的激活函數(shù)reluleaky_relu、sigmoid

  • torch.autograd:提供Tensor所有操作的自動(dòng)求導(dǎo)方法

  • torch.optim:各種參數(shù)優(yōu)化方法,例如SGD、AdaGrad、RMSProp、Adam

  • torch.nn.init:可以用它更改nn.Module的默認(rèn)參數(shù)初始化方式

  • torch.utils.data:用于加載數(shù)據(jù)

Torchvision

  • torchvision.datasets:常用數(shù)據(jù)集,MNIST、COCOCIFAR10、Imagenet

  • torchvision.models:常用模型,AlextNet、VGGResNet、DenseNet

  • torchvision.transforms:圖片相關(guān)處理,裁剪、尺寸縮放歸一化

  • torchvision.utils:將給定的Tensor保存成image文件

Variable、Tensor、Numpy之間的關(guān)系

  • Numpy
    NumPy是Python語(yǔ)言的一個(gè)擴(kuò)充程序庫(kù)。支持高級(jí)大量的維度數(shù)組與矩陣運(yùn)算,此外也針對(duì)數(shù)組運(yùn)算提供大量的數(shù)學(xué)函數(shù)庫(kù)。
    例子:
>>> import numpy as np
>>> x=np.array([[1,2,3],[9,8,7],[6,5,4]])
  • Tensor
    PyTorch 提供一種類似 NumPy 的抽象方法來(lái)表征張量(或多維數(shù)組),它可以利用 GPU 來(lái)加速訓(xùn)練。


  • Variable


  1. PyTorch 張量的簡(jiǎn)單封裝
  2. 幫助建立計(jì)算圖
  3. Autograd(自動(dòng)微分庫(kù))的必要部分
  4. 將關(guān)于這些變量的梯度保存在 .grad 中
  • Tensor、Variable、Numpy之間相互轉(zhuǎn)化
  1. 將Numpy矩陣轉(zhuǎn)換為T(mén)ensor張量
    sub_ts = torch.from_numpy(sub_img)
  2. 將Tensor張量轉(zhuǎn)化為Numpy矩陣
    sub_np1 = sub_ts.numpy()
  3. 將Tensor轉(zhuǎn)換為Variable
    sub_va = Variable(sub_ts)
  4. 將Variable轉(zhuǎn)換為T(mén)ensor
    sub_np2 = sub_va.data

CPU與GPU

Pytorch支持CPU運(yùn)行,但是速度非常慢,一張好的NVIDIA顯卡能夠大大減少網(wǎng)絡(luò)訓(xùn)練時(shí)間,以我自己經(jīng)驗(yàn)來(lái)看,15年MacBook Pro 與戴爾工作站附加一張顯存11GB的1080ti顯卡相比,后者速度是前者速度的224倍,尤其訓(xùn)練復(fù)雜網(wǎng)絡(luò)一定要在GPU上跑。Pytorch中把數(shù)據(jù)和模型從CPU遷移到GPU非常簡(jiǎn)單:


直接對(duì)變量、張量、模型使用.cuda()即可把他們遷移到GPU上,反過(guò)來(lái)遷移到CPU上,使用.cpu()。
當(dāng)有多行顯卡時(shí),想充分利用它們,則可使用model = nn.DataParallel(model)命令:

常見(jiàn)問(wèn)題

  • 這里的不同位置包含GPU與CPU,還包含不同GPU之間
  • 不同位置的Variable之間不能直接相互運(yùn)算
  • 不同位置的Tensor直接不能直接相互運(yùn)算
  • 不同位置的Variable模型不能直接訓(xùn)練
  • 使用指定顯卡:.cuda(<顯卡號(hào)數(shù)>)

示例--GAN生成MINIST數(shù)據(jù)

最后看個(gè)實(shí)例,如何使用GAN網(wǎng)絡(luò)生成MINIST 數(shù)據(jù),主要內(nèi)容有:


MNIST數(shù)據(jù)集

MNIST數(shù)據(jù)集是一個(gè)手寫(xiě)體數(shù)據(jù)集,圖片大小都是28x28,包含0-9共10個(gè)數(shù)字,各種風(fēng)格:



下載好的數(shù)據(jù)集:


測(cè)試集t10k開(kāi)頭,訓(xùn)練集train開(kāi)頭,images是圖片,labels是標(biāo)簽

GAN網(wǎng)絡(luò)模型


輸入100長(zhǎng)度的噪聲向量,經(jīng)過(guò)一個(gè)全連接,兩個(gè)卷積層,一個(gè)下采樣之后生成成28x28大小的圖片,這一部分是生成器
生成的假圖片和MNIST里的真圖片經(jīng)過(guò)兩個(gè)卷積層下采樣之后,再次經(jīng)歷兩個(gè)全連接層后輸出一個(gè)1長(zhǎng)度的單位向量,1代表輸入圖片為真,0代表輸入圖片為假

GAN訓(xùn)練和Loss



訓(xùn)練判別器D時(shí),要使得V整體變大,訓(xùn)練生成器G時(shí),要使得V整體變小。
這是一個(gè)博弈的過(guò)程,就像制造假錢(qián)的犯罪團(tuán)伙和驗(yàn)鈔機(jī)的關(guān)系,犯罪團(tuán)伙需要努力提高技術(shù),讓驗(yàn)鈔機(jī)無(wú)法識(shí)別出來(lái)其制造的假幣,而驗(yàn)鈔機(jī)要能夠正確的分辨出真正的紙幣還有假幣。
理論上當(dāng)判別器D只有一半的概率0.5能識(shí)別出假圖片時(shí),就已經(jīng)收斂了,實(shí)際上達(dá)不到一半的概率,沒(méi)關(guān)系,使得假圖片概率盡量高就行了,最終看上去效果不錯(cuò)。
這是一張由生成器生成的假圖片,你能區(qū)分出來(lái)嗎?

可視化

可視化方式有兩種,一種是利用torchvision里面的包 torchvision.utils,另外一種是利用visdom插件,下面是二者的對(duì)比:


上面那張生成的假圖片就是利用torchvision.utils里的save_image函數(shù)來(lái)存儲(chǔ)在本地的。
而以下這張圖是利用visdom,在瀏覽器中查看到的效果:

visdom不光可以查看圖片,還可以查看loss變化曲線圖等各種功能。

具體的代碼實(shí)現(xiàn)去工程里查看,這里給出分享地址:
https://github.com/gcfrun/GAN_MNIST_Pytorch
mnist_data.py:數(shù)據(jù)輸入模塊
mnist_net.py:網(wǎng)絡(luò)模型模塊
mnist_loss.py:Loss計(jì)算模塊
mnist_train.py:迭代訓(xùn)練模塊
mnist_visual.py:可視化模塊

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

友情鏈接更多精彩內(nèi)容