1. 只保存參數(shù)
推薦
1.1 示例
- 訓練過程: main.py
# file: main.py
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
class Net(torch.nn.Module):
    """Two-layer MLP: 1 -> 10 -> 1 with a ReLU hidden activation."""

    def __init__(self):
        super().__init__()
        # Attribute names h1/h2 must stay fixed: they become the
        # state_dict keys used when saving/loading parameters.
        self.h1 = torch.nn.Linear(1, 10)
        self.h2 = torch.nn.Linear(10, 1)

    def forward(self, x):
        hidden = F.relu(self.h1(x))
        return self.h2(hidden)
def prepare_data():
    """Build a fixed noisy-quadratic regression dataset.

    Returns:
        (x, y): x is a (50, 1) tensor of points evenly spaced in [-1, 1];
        y = x**2 plus uniform noise drawn from [0, 0.2).
    """
    torch.manual_seed(1)  # fixed seed -> identical data on every call
    inputs = torch.linspace(-1, 1, 50).unsqueeze(1)
    targets = inputs.pow(2) + 0.2 * torch.rand(inputs.size())
    return inputs, targets
if __name__ == "__main__":
    # 1. Prepare the data.
    x, y = prepare_data()
    plt.scatter(x.numpy(), y.numpy())
    plt.show()
    # 2. Build the network.
    net = Net()
    # 3. Train.
    optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
    loss_F = torch.nn.MSELoss()
    # 'epoch' instead of 'iter' — the original name shadowed the builtin iter().
    for epoch in range(100):
        pred = net(x)
        loss = loss_F(pred, y)
        optimizer.zero_grad()
        loss.backward()
        print(loss.item())  # .item() gives a Python float; no detach/numpy round-trip
        optimizer.step()
    # Save only the parameters (state_dict), not the whole module.
    torch.save(net.state_dict(), "./net_param.pkl")
- 恢復過程: test.py
from main import Net, prepare_data
import torch
import matplotlib.pyplot as plt
if __name__ == "__main__":
    net = Net()
    x, y = prepare_data()
    plt.scatter(x.numpy(), y.numpy())
    plt.show()
    # torch.load returns a dict of parameter tensors here;
    # load_state_dict copies them into the freshly built network.
    net.load_state_dict(torch.load("net_param.pkl"))
    criterion = torch.nn.MSELoss()
    prediction = net(x)
    # This loss equals the loss of the final training iteration in main.py.
    final_loss = criterion(prediction, y)
    print(final_loss.detach().numpy())
1.2 好處
- 可以定義新的類。在test.py中可以定義新的class,forward可以有不同的方式。只要有相同名字的參數(shù),都可以load成功
class New_Net(torch.nn.Module):
    """Alternative class for loading the saved state_dict.

    The class name and forward pass differ from the original Net, but the
    parameter names (h1, h2) match, so load_state_dict still succeeds.
    """

    def __init__(self):
        super().__init__()
        self.h1 = torch.nn.Linear(1, 10)
        self.h2 = torch.nn.Linear(10, 1)

    def forward(self, x):
        # torch.tanh replaces the deprecated F.tanh, and also removes the
        # need for a torch.nn.functional import in test.py.
        x = torch.tanh(self.h1(x))  # changed activation function
        x = self.h2(x)
        return x
- 加載更快
2. 保存網(wǎng)絡(luò)結(jié)構(gòu)和參數(shù)
2.1 示例
- main.py
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
class Net(torch.nn.Module):
    """Small regression MLP (1 input -> 10 hidden units -> 1 output)."""

    def __init__(self):
        super().__init__()
        # h1/h2 are the layer names that appear in the saved module.
        self.h1 = torch.nn.Linear(1, 10)
        self.h2 = torch.nn.Linear(10, 1)

    def forward(self, x):
        activated = F.relu(self.h1(x))
        return self.h2(activated)
def prepare_data():
    """Return a deterministic (x, y) toy dataset for 1-D regression.

    x: shape (50, 1), evenly spaced over [-1, 1].
    y: x squared with additive uniform noise in [0, 0.2).
    """
    torch.manual_seed(1)  # reseed so every call produces the same tensors
    xs = torch.unsqueeze(torch.linspace(-1, 1, 50), dim=1)
    ys = xs * xs + 0.2 * torch.rand(xs.size())
    return xs, ys
if __name__ == "__main__":
    # 1. Prepare the data.
    x, y = prepare_data()
    plt.scatter(x.numpy(), y.numpy())
    plt.show()
    # 2. Build the network.
    net = Net()
    # 3. Train.
    optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
    loss_F = torch.nn.MSELoss()
    # 'epoch' instead of 'iter' — the original name shadowed the builtin iter().
    for epoch in range(100):
        pred = net(x)
        loss = loss_F(pred, y)
        optimizer.zero_grad()
        loss.backward()
        print(loss.item())  # .item() gives a Python float; no detach/numpy round-trip
        optimizer.step()
    # Save the ENTIRE module (structure + parameters), not just its state_dict.
    # (The original comment wrongly said "save only the network state".)
    torch.save(net, "./net.pkl")  # save net directly, not net.state_dict()
- test.py
from main import Net, prepare_data
import torch
import matplotlib.pyplot as plt
if __name__ == "__main__":
    x, y = prepare_data()
    plt.scatter(x.numpy(), y.numpy())
    plt.show()
    # Here torch.load reconstructs the whole pickled module, so no separate
    # network construction or load_state_dict call is needed.
    net = torch.load("net.pkl")
    criterion = torch.nn.MSELoss()
    prediction = net(x)
    final_loss = criterion(prediction, y)
    print(final_loss.detach().numpy())
2.2 弊端
- 與特定class綁定了。即: 雖然test.py中的net是通過load得來的, 但是還是需要import訓練時候的那個類Net(否則會報錯)
- 不靈活。因為結(jié)構(gòu)被定死了,不能定義新的層等等。