The overall workflow for training a complete neural network:
1. Define the neural network
2. Iterate over the input data
3. Compute the loss with a loss function
4. Backpropagate the gradients
5. Update the network's weight parameters
Define the neural network
import torch
import torch.nn.functional as F

# First way: subclass torch.nn.Module
class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)
        self.out = torch.nn.Linear(n_hidden, n_output)

    def forward(self, x):
        x = F.relu(self.hidden(x))
        x = self.out(x)
        return x

net1 = Net(n_feature=2, n_hidden=10, n_output=2)
print(net1)

# Second way: torch.nn.Sequential
net2 = torch.nn.Sequential(
    torch.nn.Linear(2, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 2)
)
print(net2)
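The two definitions build the same architecture; only the printed layer names differ (hidden/out versus the indices 0/1/2). As a quick sanity check, here is a minimal sketch that pushes a dummy batch through both nets (the batch size of 4 is an arbitrary assumption):

# Hypothetical batch: 4 samples, 2 features each
dummy = torch.randn(4, 2)
print(net1(dummy).shape)  # torch.Size([4, 2])
print(net2(dummy).shape)  # torch.Size([4, 2])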
Iterate over the input data
# Build a toy dataset: 100 points around (2, 2) labelled 0,
# and 100 points around (-2, -2) labelled 1
n_data = torch.ones(100, 2)
x0 = torch.normal(2 * n_data, 1)
y0 = torch.zeros(100)
x1 = torch.normal(-2 * n_data, 1)
y1 = torch.ones(100)
x = torch.cat((x0, x1), 0).type(torch.FloatTensor)
y = torch.cat((y0, y1), 0).type(torch.LongTensor)
# Variable is deprecated since PyTorch 0.4; plain tensors work directly
out = net1(x)
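Before computing the loss, it is worth confirming the tensor shapes: CrossEntropyLoss expects (N, C) logits and an (N,) LongTensor of class indices. A quick check on the data built above:

print(x.shape)    # torch.Size([200, 2]) – 200 samples, 2 features
print(y.shape)    # torch.Size([200])    – one class index per sample
print(out.shape)  # torch.Size([200, 2]) – raw logits for the 2 classes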
Compute the loss with a loss function
Available loss functions include the L1 loss, the MSE loss, the cross-entropy loss, and others; cross-entropy is the usual choice for this two-class classification task.
loss_func = torch.nn.CrossEntropyLoss()
loss = loss_func(out, y)
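For regression tasks, the L1 and MSE losses mentioned above apply instead of cross-entropy; a minimal sketch with made-up prediction and target tensors:

pred = torch.randn(5, 1)                 # hypothetical predictions
target = torch.randn(5, 1)               # hypothetical targets
print(torch.nn.L1Loss()(pred, target))   # mean absolute error
print(torch.nn.MSELoss()(pred, target))  # mean squared error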
Gradient descent and backpropagation
optimizer = torch.optim.SGD(net1.parameters(), lr=0.02)
optimizer.zero_grad()  # clear gradients left over from the previous step
loss.backward()        # backpropagate to compute the gradients
# Update the parameters
optimizer.step()
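Putting the five steps together, a minimal sketch of the full training loop, reusing net1, x, y, loss_func, and optimizer defined above (100 epochs is an arbitrary choice):

for epoch in range(100):
    out = net1(x)             # 1–2. forward pass over the data
    loss = loss_func(out, y)  # 3. compute the loss
    optimizer.zero_grad()     # clear old gradients
    loss.backward()           # 4. backpropagate
    optimizer.step()          # 5. update the weights
    if epoch % 20 == 0:
        print(f'epoch {epoch}, loss {loss.item():.4f}')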
Reference: https://github.com/MorvanZhou/PyTorch-Tutorial/tree/master/tutorial-contents