學(xué)習(xí)速率的設(shè)置

先上一段代碼

from __future__ import print_function
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

import time
%matplotlib inline
from IPython import display
# Hyper-parameters
input_size = 1
output_size = 1
learning_rate = 0.001

def load_data(filename):
    xys=[]
    with open(filename,'r') as f:
        for line in f:
            xys.append(map(float, line.strip().split()))
        xs, ys = zip(*xys)#解壓,返回二維矩陣式
        return np.asarray(xs), np.asarray(ys)
x_t, y_t=load_data(r'train.txt')
x=[]
y=[]
for i in range(len(x_t)):
    x.append([x_t[i]])
    y.append([y_t[i]])

x_train = np.array(x, dtype=np.float32)
y_train = np.array(y, dtype=np.float32)

plt.scatter(x_train,y_train)
plt.show()
output_1_0.png
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = nn.Linear(input_size, output_size) # One in and one out

    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred
model = Model()
# model = nn.Linear(input_size, output_size)

criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.00001)  
for epoch in range(50):
    inputs=torch.from_numpy(x_train)
    targets=torch.from_numpy(y_train)
    
    outputs = model(inputs)
    loss = criterion(outputs, targets)

    # Zero gradients
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    
    if (epoch+1) % 5 == 0:
        predicted = model(torch.from_numpy(x_train)).detach().numpy()
        plt.plot(x_train, y_train, 'ro', label='Original data')
        plt.plot(x_train, predicted, label='Fitted line')
        plt.legend()
        plt.show()
        display.clear_output(wait=True)
        plt.pause(1)
output_4_0.png

這玩意調(diào)了一整天,就因?yàn)閷W(xué)習(xí)速率,剛開始設(shè)置的只有0.001,本來感覺夠小了,做出來圖發(fā)現(xiàn)擬合線不斷上下跳躍,剛開始還以為是數(shù)據(jù)讀取的問題,換了隨機(jī)數(shù)據(jù)發(fā)現(xiàn)正常。試著調(diào)小學(xué)習(xí)速率,擬合線才出來。第一次遇到……

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容