Task04: 模型訓(xùn)練與驗證

0. 數(shù)據(jù)集搭建

  1. 訓(xùn)練集(Train Set):模型用于訓(xùn)練和調(diào)整模型參數(shù);
  2. 驗證集(Validation Set):用來驗證模型精度和調(diào)整模型超參數(shù);
  3. 測試集(Test Set):驗證模型的泛化能力。

在數(shù)據(jù)建模比賽中,一般三者都已經(jīng)分好了,即:訓(xùn)練集、驗證集發(fā)放數(shù)據(jù)和標(biāo)簽,而測試集僅發(fā)放數(shù)據(jù);而如果賽方?jīng)]有提前劃分驗證集,則需要參賽人員自行劃分,有以下劃分方法

  1. 留出法(Hold-Out):直接對數(shù)據(jù)隨機劃分成兩份,適用于數(shù)據(jù)量大的情況;
  2. 交叉驗證法(Cross Validation,CV):對數(shù)據(jù)劃分為若干折(等分),對每一折進行驗證時,其余折用于訓(xùn)練,驗證集精度以各份平均表示,適用于數(shù)據(jù)量一般的情況;
  3. 自助采樣法(BootStrap):有放回采樣獲得訓(xùn)練集和驗證集,適用于數(shù)據(jù)量較少的情況。

數(shù)據(jù)集劃分的原則是:每份數(shù)據(jù)的標(biāo)簽分布都能代表整體分布。

1. 模型訓(xùn)練與驗證

主要步驟如下:

  • 構(gòu)造訓(xùn)練集和驗證集;
  • 每輪進行訓(xùn)練和驗證,并根據(jù)最優(yōu)驗證集精度保存模型。

模型構(gòu)建、數(shù)據(jù)集構(gòu)建過程與上一task一致,這里不贅述。需要注意的是每次交替訓(xùn)練和驗證時要切換模型的狀態(tài):

def train(train_loader, model, criterion, optimizer, epoch):
    # 切換模型為訓(xùn)練模式
    model.train()

    for i, (input, target) in enumerate(train_loader):
        c0, c1, c2, c3, c4, c5 = model(data[0])
        loss = criterion(c0, data[1][:, 0]) + \
                criterion(c1, data[1][:, 1]) + \
                criterion(c2, data[1][:, 2]) + \
                criterion(c3, data[1][:, 3]) + \
                criterion(c4, data[1][:, 4]) + \
                criterion(c5, data[1][:, 5])
        loss /= 6
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def validate(val_loader, model, criterion):
    # 切換模型為預(yù)測模型
    model.eval()
    val_loss = []

    # 不記錄模型梯度信息
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            c0, c1, c2, c3, c4, c5 = model(data[0])
            loss = criterion(c0, data[1][:, 0]) + \
                    criterion(c1, data[1][:, 1]) + \
                    criterion(c2, data[1][:, 2]) + \
                    criterion(c3, data[1][:, 3]) + \
                    criterion(c4, data[1][:, 4]) + \
                    criterion(c5, data[1][:, 5])
            loss /= 6
            val_loss.append(loss.item())
    return np.mean(val_loss)

最后保存/加載最優(yōu)模型

torch.save(model_object.state_dict(), 'model.pt')
model.load_state_dict(torch.load(' model.pt'))

2. 調(diào)lian參dan trick

一般流程
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

友情鏈接更多精彩內(nèi)容