pytorch 批數(shù)據(jù)訓(xùn)練

import torch
import torch.utils.data as Data

BATCH_SIZE = 8

x = torch.linspace(1,10,10)
y = torch.linspace(10,1,10)

torch_dataset = Data.TensorDataset(data_tensor=x,target_tensor=y)
loader = Data.DataLoader(
    dataset = torch_dataset,
    batch_size = BATCH_SIZE,
    shuffle = True,
    num_workers = 2,
)

for epoch in range(3):
    for step,(batch_x,batch_y) in enumerate(loader):
        # training
        print('Epoch: ',epoch,
              '| Step: ',step,
              '| batch x: ',batch_x.numpy(),
              '| batch y: ',batch_y.numpy())

result:

Epoch:  0 | Step:  0 | batch x:  [  9.   5.   4.   8.  10.   1.   3.   6.] | batch y:  [  2.   6.   7.   3.   1.  10.   8.   5.]
Epoch:  0 | Step:  1 | batch x:  [ 2.  7.] | batch y:  [ 9.  4.]
Epoch:  1 | Step:  0 | batch x:  [  3.   7.   8.   4.  10.   2.   9.   6.] | batch y:  [ 8.  4.  3.  7.  1.  9.  2.  5.]
Epoch:  1 | Step:  1 | batch x:  [ 5.  1.] | batch y:  [  6.  10.]
Epoch:  2 | Step:  0 | batch x:  [ 1.  8.  2.  7.  3.  5.  6.  4.] | batch y:  [ 10.   3.   9.   4.   8.   6.   5.   7.]
Epoch:  2 | Step:  1 | batch x:  [  9.  10.] | batch y:  [ 2.  1.]
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容