天池-街景字符編碼識別-task4

模型復(fù)雜度并非越高越好,要避免過擬合。訓(xùn)練集和驗證集選擇,數(shù)據(jù)分布盡量保持一致。

劃分的方式: hold-out validation, k-fold cross validation和bootstrapping.

hold-out劃分簡單,適用數(shù)據(jù)量大的情況,但因為只得到了一份驗證集,有可能導(dǎo)致模型在驗證集上過擬合。

*cross validation優(yōu)點(diǎn)是驗證集精度比較可靠,訓(xùn)練K次可以得到K個有多樣性差異的模型;CV驗證的缺點(diǎn)是需要訓(xùn)練K次,不適合數(shù)據(jù)量很大的情況。

bootstrap適用數(shù)據(jù)較少的情況

train_loader = torch.utils.data.DataLoader(

? ? train_dataset,

? ? batch_size=10,

? ? shuffle=True,

? ? num_workers=10,

)

? ?

val_loader = torch.utils.data.DataLoader(

? ? val_dataset,

? ? batch_size=10,

? ? shuffle=False,

? ? num_workers=10,

)


model = SVHN_Model1()

criterion = nn.CrossEntropyLoss (size_average=False)

optimizer = torch.optim.Adam(model.parameters(), 0.001)

best_loss = 1000.0

for epoch in range(20):

? ? print('Epoch: ', epoch)


? ? train(train_loader, model, criterion, optimizer, epoch)

? ? val_loss = validate(val_loader, model, criterion)

? ?

? ? # 記錄下驗證集精度

? ? if val_loss < best_loss:

? ? ? ? best_loss = val_loss

? ? ? ? torch.save(model.state_dict(), './model.pt')

def train(train_loader, model, criterion, optimizer, epoch):

? ? # 切換模型為訓(xùn)練模式

? ? model.train()


? ? for i, (input, target) in enumerate(train_loader):

? ? ? ? c0, c1, c2, c3, c4, c5 = model(data[0])

? ? ? ? loss = criterion(c0, data[1][:, 0]) + \

? ? ? ? ? ? ? ? criterion(c1, data[1][:, 1]) + \

? ? ? ? ? ? ? ? criterion(c2, data[1][:, 2]) + \

? ? ? ? ? ? ? ? criterion(c3, data[1][:, 3]) + \

? ? ? ? ? ? ? ? criterion(c4, data[1][:, 4]) + \

? ? ? ? ? ? ? ? criterion(c5, data[1][:, 5])

? ? ? ? loss /= 6

? ? ? ? optimizer.zero_grad()

? ? ? ? loss.backward()

? ? ? ? optimizer.step()

def validate(val_loader, model, criterion):

? ? # 切換模型為預(yù)測模型

? ? model.eval()

? ? val_loss = []


? ? # 不記錄模型梯度信息

? ? with torch.no_grad():

? ? ? ? for i, (input, target) in enumerate(val_loader):

? ? ? ? ? ? c0, c1, c2, c3, c4, c5 = model(data[0])

? ? ? ? ? ? loss = criterion(c0, data[1][:, 0]) + \

? ? ? ? ? ? ? ? ? ? criterion(c1, data[1][:, 1]) + \

? ? ? ? ? ? ? ? ? ? criterion(c2, data[1][:, 2]) + \

? ? ? ? ? ? ? ? ? ? criterion(c3, data[1][:, 3]) + \

? ? ? ? ? ? ? ? ? ? criterion(c4, data[1][:, 4]) + \

? ? ? ? ? ? ? ? ? ? criterion(c5, data[1][:, 5])

? ? ? ? ? ? loss /= 6

? ? ? ? ? ? val_loss.append(loss.item())

? ? return np.mean(val_loss)

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

友情鏈接更多精彩內(nèi)容