Pytorch-lightning入門實例

本文將通過Colab平臺及MINIST數(shù)據(jù)集指導(dǎo)你了解Pytorch-lightning的核心組成。

注意:任何的深度學習、機器學習的Pytorch工程都可以轉(zhuǎn)變?yōu)閘ightning結(jié)構(gòu)

從MNIST到自動編碼器

安裝Lightning

雖然說安裝Lightning非常的容易,但還是建議大家在本地通過conda來安裝Lightning

conda activate my_env
pip install pytorch-lightning

當你運行在在Google Colab上時,需要執(zhí)行

!pip install pytorch-lightning
截屏2020-12-27 上午10.29.06
截屏2020-12-27 上午10.29.19

當出現(xiàn)上述內(nèi)容是,代表Lightning已經(jīng)安裝完成。

或者你也可以使用conda命令來安裝

conda install pytorch-lightning -c conda-forge

The research

The Model

Lightning由以下核心部分組成:

  • The model
  • The optimizers
  • The train/val/test steps

我們通過Model引入這一部分,下面我們將會設(shè)計一個三層的神經(jīng)網(wǎng)絡(luò)模型

import torch
from torch.nn import functional as F
from torch import nn
from pytorch_lightning.core.lightning import LightningModule

class LitMNIST(LightningModule):

  def __init__(self):
    super().__init__()

    # mnist images are (1, 28, 28) (channels, width, height)
    self.layer_1 = torch.nn.Linear(28 * 28, 128)
    self.layer_2 = torch.nn.Linear(128, 256)
    self.layer_3 = torch.nn.Linear(256, 10)

  def forward(self, x):
    batch_size, channels, width, height = x.size()

    # (b, 1, 28, 28) -> (b, 1*28*28)
    x = x.view(batch_size, -1)
    x = self.layer_1(x)
    x = F.relu(x)
    x = self.layer_2(x)
    x = F.relu(x)
    x = self.layer_3(x)

    x = F.log_softmax(x, dim=1)
    return x

可以觀察到我們的LitMNIST類并不是像Pytorch中繼承于torch.nn.Module類,而是Pytorch-lightning中的LightningModule類。該類與Pytorch中的torch.nn.Module類并沒有太大的區(qū)別,只是增加了一些功能,我們可以像在Pytorch中使用torch.nn.Module類一樣使用它。例如:

net = LitMNIST()
x = torch.randn(1, 1, 28, 28)
out = net(x)
截屏2020-12-27 上午10.39.35

下一步我們添加訓練過程training_step,它繼承于LightningModule,其中包含了所有訓練過程中的邏輯內(nèi)容

class LitMNIST(LightningModule):

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss
Data

LIghtning運行在dataloders上,以下是數(shù)據(jù)加載部分的代碼:

from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
import os
from torchvision import datasets, transforms

# transforms
# prepare transforms standard to MNIST
transform=transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize((0.1307,), (0.3081,))])

# data
mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
mnist_train = DataLoader(mnist_train, batch_size=64)

我比較推薦采用下面的方式加載數(shù)據(jù):

class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=64):
        super().__init__()
        self.batch_size = batch_size

    def prepare_data(self):
        # download only
        MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
        MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())

    def setup(self, stage):
        # transform
        transform=transforms.Compose([transforms.ToTensor()])
        mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transform)
        mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transform)

        # train/val split
        mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])

        # assign to use in dataloaders
        self.train_dataset = mnist_train
        self.val_dataset = mnist_val
        self.test_dataset = mnist_test

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)

使用上面的DataModule可以更加方便的分項數(shù)據(jù)定義

# use an MNIST dataset
mnist_dm = MNISTDatamodule()
model = LitModel(num_classes=mnist_dm.num_classes)
trainer.fit(model, mnist_dm)

# or other datasets with the same model
imagenet_dm = ImagenetDatamodule()
model = LitModel(num_classes=imagenet_dm.num_classes)
trainer.fit(model, imagenet_dm)
Optimizer

在Pytorch中,我們通過下列方式來Optimizer代碼:

from torch.optim import Adam
optimizer = Adam(LitMNIST().parameters(), lr=1e-3)

在Lightning中,方法類似,但是我們會將其包含在configure_optimizers()方法中

class LitMNIST(LightningModule):

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=1e-3)
Training step

我們將其寫在training_step()`方法中

class LitMNIST(LightningModule):

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss
Training

目前為止,我們已經(jīng)完成了四個主要的代碼部分

import torch
from torch.nn import functional as F
from torch import nn
from pytorch_lightning.core.lightning import LightningModule

class LitMNIST(LightningModule):

    def __init__(self):
        super().__init__()

        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

    def forward(self, x):
        batch_size, channels, width, height = x.size()

        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        x = F.relu(x)
        x = self.layer_3(x)

        x = F.log_softmax(x, dim=1)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

最后執(zhí)行下列代碼訓練我們的數(shù)據(jù):

from pytorch_lightning import Trainer
dm = MNISTDataModule()
model = LitMNIST()
trainer = Trainer(gpus=8)
trainer.fit(model, dm)

訓練時間比較長

小結(jié)

我在博客中的每篇文章都是我一字一句敲出來的,轉(zhuǎn)載的文章我也注明了出處,表示對原作者的尊重。同時也希望大家都能尊重我的付出。

最后,也希望大家關(guān)注我的個人博客HD Blog

謝謝~

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容