DistributedDataParallel config
Import
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
main
if __name__ == '__main__':
    ...
    # torch.distributed.launch passes --local_rank to each spawned process
    parser.add_argument("--local_rank", type=int, default=0)
    ...
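A minimal entry point matching the snippet above might look like this (the `build_parser` helper and the `--start_ep` flag are assumptions for illustration):

```python
import argparse

def build_parser():
    # Hypothetical helper wrapping the argparse setup shown above
    parser = argparse.ArgumentParser()
    # torch.distributed.launch passes --local_rank to each spawned process
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--start_ep", type=int, default=0)  # assumed flag, used by train()
    return parser

args = build_parser().parse_args(["--local_rank", "2"])
print(args.local_rank)  # → 2
```

The script is typically launched with one process per GPU, e.g. `python -m torch.distributed.launch --nproc_per_node=4 train.py`; newer PyTorch versions use `torchrun`, which supplies the rank via the `LOCAL_RANK` environment variable instead of a command-line flag.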
train
def train(args):
    # Initialize the process group; NCCL is the backend for GPU communication
    torch.distributed.init_process_group(backend="nccl")
    # Bind this process to its own GPU; local_rank is the device index
    local_rank = args.local_rank
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    ...
    dataset = ...
    # Each process runs independently, so a DistributedSampler is needed to
    # give every process its own disjoint shard of the data
    sampler = DistributedSampler(dataset)
    # batch_size is per process: the effective batch size is batch_size * num_gpus
    # shuffle must not be set to True when a sampler is supplied
    train_loader = DataLoader(dataset, batch_size=6,
                              num_workers=8, sampler=sampler)
    ...
    # The model must start from identical parameters in every process;
    # seeding all processes with the same seed ensures this
    model = ...
    # Move the model onto this process's GPU
    model.to(device)
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[local_rank],
        output_device=local_rank, find_unused_parameters=True)
    ...
    for epoch in range(args.start_ep, 20):
        # Re-seed the sampler here so each epoch uses a different shuffle
        sampler.set_epoch(epoch)
        ...
        for step, batch_x in enumerate(train_loader):
            # .to() is not in-place for tensors, so assign the result
            batch_x = batch_x.to(device)
            ...
Synchronized output
With DDP every process produces its own output (e.g. print, log messages), and the values may differ between processes. To synchronize values across processes and emit a single consistent output, use code like the following:
if step % 100 == 0:
    # Synchronize values across processes; only tensors can be reduced
    # By default all_reduce sums the tensor over all processes
    torch.distributed.all_reduce(loss)  # synchronize the loss
    # (divide by torch.distributed.get_world_size() if the mean is wanted)
    ...
    # With the values now synchronized, print/save from a single process
    if torch.distributed.get_rank() == 0:
        ...
    else:
        pass
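The reduce-then-average pattern can be exercised without GPUs: a single-process gloo group is enough to run `all_reduce` on CPU (the address/port values here are arbitrary placeholders). With more processes, the same call would sum the tensor across all of them:

```python
import os
import torch
import torch.distributed as dist

# Single-process gloo group for demonstration; env:// init needs these set
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", world_size=1, rank=0)

loss = torch.tensor(2.5)
dist.all_reduce(loss)             # default op is SUM over all processes
loss /= dist.get_world_size()     # divide to turn the sum into a mean
print(loss.item())                # → 2.5 (only one process here)

assert dist.get_rank() == 0       # rank 0 is the usual place to log/save
dist.destroy_process_group()
```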
Warning
If the following error occurs, it is because some parameters on the GPU do not take part in computing the loss. If you cannot find them by inspection, compute the loss, run the backward pass, and print .grad for every entry of parameters(): any parameter whose .grad is None did not participate in the computation.
RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by (1) passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`; (2) making sure all `forward` function outputs participate in calculating loss. If you already have done the above two steps, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable).
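The debugging recipe above can be sketched as a standalone check (the model and its layer names are made up for illustration):

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 1)
        self.unused = nn.Linear(4, 1)  # never called in forward

    def forward(self, x):
        return self.used(x)  # self.unused does not contribute to the output

model = Net()
loss = model(torch.randn(2, 4)).sum()
loss.backward()

# Parameters whose .grad is still None after backward never took part in
# computing the loss — these are exactly what trips up DDP
unused = [name for name, p in model.named_parameters() if p.grad is None]
print(unused)  # → ['unused.weight', 'unused.bias']
```

Once identified, such parameters can either be removed from the module or tolerated via `find_unused_parameters=True`, as in the DDP constructor above.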