Datawhale深入淺出pytorchtask02筆記

pytorch主要組成模塊分為

數(shù)據(jù)讀入,模型構(gòu)建,模型初始化,損失函數(shù),優(yōu)化器,訓(xùn)練,評估,可視化等

數(shù)據(jù)讀入:分兩種,

一,pytorch自帶數(shù)據(jù)集

自帶方法

通過torchvision中的datasets來獲取

二,自定義方法

自己動手

通過集成pytorch自帶的Dataset類和Dataloader來設(shè)立自己的數(shù)據(jù)集

Dataset中主要為__init__,__len__,__getitem__三個函數(shù)

__init__用于傳入外部參數(shù),定義樣本集

__len__用過返回數(shù)據(jù)集的樣本個數(shù)

__getitem__用于逐個讀出樣本集合中的元素并最終返回訓(xùn)練和驗證所用的數(shù)據(jù)集

Dataloader用于按批次讀入數(shù)據(jù)

主要參數(shù):batch_size(每批次的樣本數(shù))

num_workers(多少個進(jìn)程)

shuffle(是否打亂)

drop_last(未滿批次數(shù)的樣本去除)

模型構(gòu)建主要依賴于pytorch中的nn.Module

集成nn.Module的類,在類中有__init__和forward函數(shù)

前者負(fù)責(zé)定義模型中的所有層,后者定義這些層如何向前傳播

一個簡單的cnn模型

Alexnet

初始化基于pytorch.nn.init

相關(guān)函數(shù)

通常會自己去封裝一個initialize_weights函數(shù)來進(jìn)行初始化

大概:判斷是什么層,如何初始化

初始化函數(shù)的封裝

pytorch提供了很多種損失函數(shù),種類太多日后再慢慢分析

訓(xùn)練和評估模型通過定義train和val函數(shù)來實(shí)現(xiàn),

訓(xùn)練:送入數(shù)據(jù),梯度置零,送入模型,計算損失函數(shù),反向傳播,用優(yōu)化器更新參數(shù)

測試的不同點(diǎn)在于:

預(yù)先設(shè)置torch.no_grad,不需要把優(yōu)化器梯度置零,不需要反向傳播和更新參數(shù)

圖像分類的訓(xùn)練和評估函數(shù)
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容