6. pytorch-保存與恢復

官方序列化教程

1. 只保存參數(shù)

推薦

1.1 示例

  • 訓練過程: main.py
# file: main.py
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.h1 = torch.nn.Linear(1, 10)
        self.h2 = torch.nn.Linear(10, 1)

    def forward(self, x):
        x = F.relu(self.h1(x)) 
        x = self.h2(x)
        return x

def prepare_data():
    torch.manual_seed(1)  # 保證每次生成的隨機數(shù)相同
    x = torch.linspace(-1, 1, 50)
    x = torch.unsqueeze(x, 1)
    y = x ** 2 + 0.2 * torch.rand(x.size())
    return x, y

if __name__ == "__main__":
    # 1. 數(shù)據(jù)準備
    x,y=prepare_data()
    plt.scatter(x.numpy(), y.numpy())
    plt.show()
    # 2. 網(wǎng)絡(luò)搭建
    net = Net()
    # 3. 訓練
    optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
    loss_F = torch.nn.MSELoss()
    for iter in range(100):
        pred = net(x)
        loss = loss_F(pred, y)
        optimizer.zero_grad()
        loss.backward()
        print(loss.detach().numpy())
        optimizer.step()
    # 只保存網(wǎng)絡(luò)狀態(tài)
    torch.save(net.state_dict(), "./net_param.pkl")
  • 恢復過程: test.py
from main import Net, prepare_data
import torch
import matplotlib.pyplot as plt

if __name__ == "__main__":
    net = Net()
    x, y = prepare_data()
    plt.scatter(x.numpy(), y.numpy())
    plt.show()
    # load是加載成dict形式
    net.load_state_dict(torch.load("net_param.pkl"))
    loss_F = torch.nn.MSELoss()
    pred = net(x)
    loss = loss_F(pred, y) # loos值與訓練最后一次迭代的loss值相同
    print(loss.detach().numpy())

1.2 好處

  • 可以定義新的類。在test.py中可以定義新的class, forward可以有不同的方式。只要有相同名字的參數(shù),都可以load成功
class New_Net(torch.nn.Module): # class名字也修改了
    def __init__(self):
        super().__init__()
        self.h1 = torch.nn.Linear(1, 10)
        self.h2 = torch.nn.Linear(10, 1)

    def forward(self, x):
        x = F.tanh(self.h1(x)) # 修改了激活函數(shù)
        x = self.h2(x)
        return x
  • 加載更快

2. 保存網(wǎng)絡(luò)結(jié)構(gòu)和參數(shù)

2.1 示例

  • main.py
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.h1 = torch.nn.Linear(1, 10)
        self.h2 = torch.nn.Linear(10, 1)

    def forward(self, x):
        x = F.relu(self.h1(x))
        x = self.h2(x)
        return x

def prepare_data():
    torch.manual_seed(1)  # 保證每次生成的隨機數(shù)相同
    x = torch.linspace(-1, 1, 50)
    x = torch.unsqueeze(x, 1)
    y = x ** 2 + 0.2 * torch.rand(x.size())
    return x, y

if __name__ == "__main__":
    # 1. 數(shù)據(jù)準備
    x,y=prepare_data()
    plt.scatter(x.numpy(), y.numpy())
    plt.show()
    # 2. 網(wǎng)絡(luò)搭建
    net = Net()
    # 3. 訓練
    optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
    loss_F = torch.nn.MSELoss()
    for iter in range(100):
        pred = net(x)
        loss = loss_F(pred, y)
        optimizer.zero_grad()
        loss.backward()
        print(loss.detach().numpy())
        optimizer.step()
    # 只保存網(wǎng)絡(luò)狀態(tài)
    torch.save(net, "./net.pkl") #直接保存net,而不是net.state_dict()
  • test.py
from main import Net, prepare_data
import torch
import matplotlib.pyplot as plt

if __name__ == "__main__":
    x, y = prepare_data()
    plt.scatter(x.numpy(), y.numpy())
    plt.show()
    # load是加載成dict形式
    net = torch.load("net.pkl")
    loss_F = torch.nn.MSELoss()
    pred = net(x)
    loss = loss_F(pred, y)
    print(loss.detach().numpy())

2.2 弊端

  • 與特定class綁定了。即: 雖然test.py中的net是通過load的來的, 但是還是需要import訓練時候的那個類Net(否則會報錯)
  • 不靈活。因為結(jié)構(gòu)被定死了,不能定義新的層等等。
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容