DDP

DistributedDataParallel (DDP) configuration

Import

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

main

if __name__ == '__main__':
    ...
    # local_rank is filled in automatically by the distributed launcher
    parser.add_argument("--local_rank", type=int, default=0)
    ...
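
The --local_rank argument is supplied by the PyTorch launcher, which starts one process per GPU. A typical launch command (the script name train.py is assumed here, not given in the original) looks like:

python -m torch.distributed.launch --nproc_per_node=4 train.py

Newer PyTorch versions replace this launcher with torchrun, which passes the local rank through the LOCAL_RANK environment variable instead of a command-line argument.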

train

def train(args):
    # Initialize the process group and choose the communication backend
    torch.distributed.init_process_group(backend="nccl")
    # Bind this process to its GPU; local_rank is the device index
    local_rank = args.local_rank
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    ...
    dataset = ...
    # Each process reads the data independently, so a DistributedSampler is required
    sampler = DistributedSampler(dataset)
    # batch_size is per process; the effective batch size is batch_size * num_gpus
    # When a sampler is given, shuffle must not be set to True
    train_loader = DataLoader(dataset, batch_size=6,
                              num_workers=8, sampler=sampler)
    ...
    # The model must start with identical parameters in every process,
    # e.g. by setting the same random seed (see the sketch after this block)
    model = ...
    # Move the model onto this process's GPU
    model.to(device)
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[local_rank],
        output_device=local_rank, find_unused_parameters=True)
    ...
    for epoch in range(args.start_ep, 20):
        # set_epoch makes the shuffling different in every epoch
        sampler.set_epoch(epoch)
        ...
        for step, batch_x in enumerate(train_loader):
            batch_x = batch_x.to(device)  # .to() is not in-place; reassign
        ...
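
As noted in the comment above, every process should build the model with identical initial weights. A minimal sketch of fixing the random seeds before constructing the model (the helper name set_seed is my own, and the numpy seeding only matters if numpy randomness is used):

import random

import numpy as np
import torch

def set_seed(seed: int = 42):
    # Same seed in every process so model initialization is identical
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)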

Synchronized output

With DDP, every process produces its own output (e.g. print, log), and the values may differ between processes. To synchronize this information across processes and emit it only once, you can use code like the following:

if step % 100 == 0:
    # Synchronize data across processes; only tensor values can be synchronized
    # By default, all_reduce sums the values from every process
    torch.distributed.all_reduce(loss)  # synchronize the loss
    ...
    # Once the values agree across processes, print / save from a single process only
    if torch.distributed.get_rank() == 0:
        ...
    else:
        pass
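
Because all_reduce sums by default, dividing by the world size turns the result into the average loss across processes; saving a checkpoint from rank 0 only follows the same pattern. A minimal sketch (the checkpoint path and the print format are illustrative, not from the original):

if step % 100 == 0:
    torch.distributed.all_reduce(loss)
    loss /= torch.distributed.get_world_size()  # sum -> mean over processes
    if torch.distributed.get_rank() == 0:
        print(f"epoch {epoch} step {step} loss {loss.item():.4f}")
        # Unwrap DDP with .module so the checkpoint can be loaded without DDP
        torch.save(model.module.state_dict(), "checkpoint.pt")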

Warning

If the following error appears, it is because the GPU holds parameters that do not take part in computing the loss. If you cannot tell which parameters these are, compute the loss, run the backward pass, and print .grad for every entry of model.parameters(); any parameter whose gradient is None did not participate in the computation (a sketch of this check follows the error message).

RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by (1) passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`; (2) making sure all `forward` function outputs participate in calculating loss. If you already have done the above two steps, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable).
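
A minimal sketch of the check described above, assuming model and loss come from the training loop (named_parameters is used so the offending parameters can be identified by name):

loss.backward()
for name, param in model.named_parameters():
    if param.grad is None:
        # This parameter did not contribute to the loss
        print(name)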