基礎(chǔ)

一、數(shù)據(jù)導(dǎo)入部分

torch.utils.data.Dataset,這是一個抽象類,在pytorch中所有和數(shù)據(jù)相關(guān)的類都要繼承這個類來實現(xiàn)。

torchvision.datasets.ImageFolder接口實現(xiàn)數(shù)據(jù)導(dǎo)入。torchvision.datasets.ImageFolder會返回一個列表(比如image_datasets[‘train’]或者image_datasets[‘val]),列表中的每個值都是一個tuple,每個tuple包含圖像和標(biāo)簽信息。列表list是不能作為模型輸入的,因此在PyTorch中需要用另一個類來封裝list,那就是:

torch.utils.data.DataLoader,它可以將list類型的輸入數(shù)據(jù)封裝成Tensor數(shù)據(jù)格式,以備模型使用。這里是對圖像和標(biāo)簽分別封裝成一個Tensor。

data_transforms是一個字典。主要是進(jìn)行一些圖像預(yù)處理,比如resize、crop等。實現(xiàn)的時候采用的是torchvision.transforms模塊,比如torchvision.transforms.Compose是用來管理所有transforms操作的,torchvision.transforms.RandomSizedCrop是做crop的。需要注意的是對于torchvision.transforms.RandomSizedCrop和transforms.RandomHorizontalFlip()等,輸入對象都是PILImage,也就是用python的PIL庫讀進(jìn)來的圖像內(nèi)容,而transforms.Normalize([0.5, 0.5,0.4], [0.2, 0.2,0.5])的作用對象需要是一個Tensor,因此在transforms.Normalize([0.5, 0.5, 0.4],[0.2, 0.2, 0.5])之前有一個

transforms.ToTensor()就是用來生成Tensor的。另外transforms.Scale(256)其實就是resize操作,目前已經(jīng)被transforms.Resize類取代了。

將Tensor數(shù)據(jù)類型封裝成Variable數(shù)據(jù)類型后就可以作為模型的輸入了,用torch.autograd.Variable將Tensor封裝成模型真正可以用的Variable數(shù)據(jù)類型。Variable可以看成是tensor的一種包裝,其不僅包含了tensor的內(nèi)容,還包含了梯度等信息。

二、模塊導(dǎo)入

torchvision.models用來導(dǎo)入模塊

torch.nn模塊來定義網(wǎng)絡(luò)的所有層,比如卷積、降采樣、損失層等等

torch.optim模塊定義優(yōu)化函數(shù)

三、訓(xùn)練

在每個epoch開始時都要更新學(xué)習(xí)率:scheduler.step()

設(shè)置模型狀態(tài)為訓(xùn)練狀態(tài):model.train(True)

先將網(wǎng)絡(luò)中的所有梯度置0:optimizer.zero_grad()

網(wǎng)絡(luò)的前向傳播:outputs =

model(inputs)

然后將輸出的outputs和原來導(dǎo)入的labels作為loss函數(shù)的輸入就可以得到損失了:loss =

criterion(outputs, labels)

輸出的outputs也是torch.autograd.Variable格式,得到輸出后(網(wǎng)絡(luò)的全連接層的輸出)還希望能到到模型預(yù)測該樣本屬于哪個類別的信息,這里采用torch.max。torch.max()的第一個輸入是tensor格式,所以用outputs.data而不是outputs作為輸入;第二個參數(shù)1是代表dim的意思,也就是取每一行的最大值,其實就是我們常見的取概率最大的那個index;第三個參數(shù)loss也是torch.autograd.Variable格式。

?_, preds =

torch.max(outputs.data, 1)

計算得到loss后就要回傳損失。要注意的是這是在訓(xùn)練的時候才會有的操作,測試時候只有forward過程。

loss.backward()

回傳損失過程中會計算梯度,然后需要根據(jù)這些梯度更新參數(shù),optimizer.step()就是用來更新參數(shù)的。optimizer.step()后,你就可以從optimizer.param_groups[0][‘params’]里面看到各個層的梯度和權(quán)值信息。

optimizer.step()

這樣一個batch數(shù)據(jù)的訓(xùn)練就結(jié)束了!不斷重復(fù)這樣的訓(xùn)練過程。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

  • 四月低廣州的雨季來了,今年的雨季好像比往年來的晚一些,從四月底就一直開啟下雨模式,隨著雨水增,氣溫回升,水跟熱攪渾...
    中醫(yī)二羊閱讀 2,184評論 7 18
  • 舞臺上,你是朕的皇后,下了臺,我是你的愛人。 陸佳佳心里明鏡兒似的,自己的演員生涯終于要進(jìn)入一個新階段了,她就要紅...
    我叫楊大山閱讀 262評論 0 1
  • 在一個微信群里意外獲得要舉辦一個免費的互聯(lián)網(wǎng)思維的沙龍,看下時間就是明天。作為一個長期關(guān)注互聯(lián)網(wǎng)領(lǐng)域的少年,你沒猜...
    陳靜人閱讀 160評論 0 0

友情鏈接更多精彩內(nèi)容