pytorch主要組成模塊分為
數(shù)據(jù)讀入,模型構(gòu)建,模型初始化,損失函數(shù),優(yōu)化器,訓(xùn)練,評估,可視化等
數(shù)據(jù)讀入:分兩種,
一,pytorch自帶數(shù)據(jù)集

通過torchvision中的datasets來獲取
二,自定義方法

通過集成pytorch自帶的Dataset類和Dataloader來設(shè)立自己的數(shù)據(jù)集
Dataset中主要為__init__,__len__,__getitem__三個函數(shù)
__init__用于傳入外部參數(shù),定義樣本集
__len__用過返回數(shù)據(jù)集的樣本個數(shù)
__getitem__用于逐個讀出樣本集合中的元素并最終返回訓(xùn)練和驗證所用的數(shù)據(jù)集
Dataloader用于按批次讀入數(shù)據(jù)
主要參數(shù):batch_size(每批次的樣本數(shù))
num_workers(多少個進(jìn)程)
shuffle(是否打亂)
drop_last(未滿批次數(shù)的樣本去除)
模型構(gòu)建主要依賴于pytorch中的nn.Module
集成nn.Module的類,在類中有__init__和forward函數(shù)
前者負(fù)責(zé)定義模型中的所有層,后者定義這些層如何向前傳播


初始化基于pytorch.nn.init

通常會自己去封裝一個initialize_weights函數(shù)來進(jìn)行初始化
大概:判斷是什么層,如何初始化

pytorch提供了很多種損失函數(shù),種類太多日后再慢慢分析
訓(xùn)練和評估模型通過定義train和val函數(shù)來實(shí)現(xiàn),
訓(xùn)練:送入數(shù)據(jù),梯度置零,送入模型,計算損失函數(shù),反向傳播,用優(yōu)化器更新參數(shù)
測試的不同點(diǎn)在于:
預(yù)先設(shè)置torch.no_grad,不需要把優(yōu)化器梯度置零,不需要反向傳播和更新參數(shù)
