我有個(gè)同事的簽名是"大模型是真的大",大模型(如GPT、LLaMA)之所以被稱(chēng)為“大”,不僅因?yàn)樗鼈兊膮?shù)量高達(dá)數(shù)十億甚至上萬(wàn)億,更因?yàn)樗鼈冃枰獜?qiáng)大的計(jì)算資源來(lái)完成訓(xùn)練。這些資源遠(yuǎn)超單機(jī)能承載的范圍,因此,大模型的訓(xùn)練離不開(kāi)分布式訓(xùn)練技術(shù)。
分布式訓(xùn)練概述
分布式訓(xùn)練通過(guò)將訓(xùn)練任務(wù)分?jǐn)偟蕉鄠€(gè)計(jì)算設(shè)備(如GPU、NPU或CPU)上,以加速訓(xùn)練過(guò)程。其主要目標(biāo)包括:
- 提高計(jì)算效率:減少訓(xùn)練所需時(shí)間。
- 擴(kuò)展模型規(guī)模:支持更大的模型和更復(fù)雜的數(shù)據(jù)集。(單個(gè)設(shè)備的顯存不足以支撐大模型)
- 高效利用資源:通過(guò)并行計(jì)算,充分利用硬件能力。
在實(shí)際操作中,分布式訓(xùn)練需要解決以下關(guān)鍵問(wèn)題:
- 如何將數(shù)據(jù)分配到多個(gè)設(shè)備上?
- 如何在多個(gè)設(shè)備之間共享和同步模型參數(shù)?
- 如何保證訓(xùn)練的準(zhǔn)確性和效率?
分布式訓(xùn)練的關(guān)鍵概念
數(shù)據(jù)并行(Data Parallelism)
數(shù)據(jù)并行是最常見(jiàn)的分布式訓(xùn)練方法。其核心思想是將數(shù)據(jù)切分成多個(gè)小批次(mini-batches),并將這些小批次分發(fā)到不同的設(shè)備上進(jìn)行計(jì)算。
流程:
- 每個(gè)設(shè)備(如GPU)獲取不同的小批次數(shù)據(jù)。
- 每個(gè)設(shè)備獨(dú)立計(jì)算其對(duì)應(yīng)數(shù)據(jù)的小批次的梯度。
- 匯總所有設(shè)備的梯度,并更新全局模型參數(shù)。
- 廣播更新后的參數(shù)給所有設(shè)備。
優(yōu)點(diǎn):實(shí)現(xiàn)簡(jiǎn)單,對(duì)數(shù)據(jù)量較大的場(chǎng)景效果顯著。
缺點(diǎn):對(duì)模型規(guī)模的擴(kuò)展性有限,設(shè)備間通信開(kāi)銷(xiāo)較大。
以下是一個(gè)示意圖:

模型并行(Model Parallelism)
當(dāng)模型參數(shù)過(guò)大,單個(gè)設(shè)備無(wú)法容納時(shí),可以采用模型并行。模型并行將模型拆分成多個(gè)部分,并分配到不同設(shè)備上。
流程:
- 將模型的不同層(或塊)分配到不同的設(shè)備。
- 每個(gè)設(shè)備只計(jì)算屬于自己的那部分模型。
- 設(shè)備間通過(guò)通信共享中間結(jié)果。
優(yōu)點(diǎn):適用于超大模型。
缺點(diǎn):實(shí)現(xiàn)復(fù)雜,設(shè)備間的同步通信成本較高。
以下是模型并行的示意圖:

混合并行(Hybrid Parallelism)
混合并行結(jié)合了數(shù)據(jù)并行和模型并行的優(yōu)點(diǎn)。在這種方法中,既對(duì)數(shù)據(jù)進(jìn)行切分,又將模型分割到多個(gè)設(shè)備上。
優(yōu)點(diǎn):充分利用硬件資源,適用于超大規(guī)模的訓(xùn)練任務(wù)。
缺點(diǎn):實(shí)現(xiàn)和調(diào)試更復(fù)雜。
以下是混合并行的示意圖:

Pipeline 并行(Pipeline Parallelism)
Pipeline 并行是一種特殊的模型并行,它將模型的不同層分配到不同設(shè)備上,但同時(shí)允許多個(gè)小批次的數(shù)據(jù)在流水線(xiàn)中流動(dòng)。
優(yōu)點(diǎn):提高了設(shè)備利用率,減少了空閑時(shí)間。
缺點(diǎn):需要處理流水線(xiàn)中的梯度同步問(wèn)題。
以下是Pipeline并行的示意圖:

開(kāi)源框架支持的分布式訓(xùn)練方法
目前主流的深度學(xué)習(xí)框架都支持分布式訓(xùn)練,比如 PyTorch、TensorFlow、DeepSpeed 和 Hugging Face Transformers。以下是一些常用的工具和方法。
1. PyTorch 的分布式訓(xùn)練
PyTorch 提供了torch.distributed模塊,用于實(shí)現(xiàn)分布式訓(xùn)練。
以下是一個(gè)簡(jiǎn)單的數(shù)據(jù)并行示例:
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
# 初始化分布式環(huán)境
dist.init_process_group("nccl", rank=0, world_size=1)
# 模型和數(shù)據(jù)
model = torch.nn.Linear(10, 1).to("cuda:0")
ddp_model = DDP(model, device_ids=[0])
data = torch.randn(20, 10).to("cuda:0")
target = torch.randn(20, 1).to("cuda:0")
# 損失函數(shù)和優(yōu)化器
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)
# 訓(xùn)練
outputs = ddp_model(data)
loss = criterion(outputs, target)
loss.backward()
optimizer.step()
2. DeepSpeed 的零冗余優(yōu)化
DeepSpeed 是一種高效的分布式訓(xùn)練框架,特別適用于超大規(guī)模模型。其核心特性包括 ZeRO(Zero Redundancy Optimizer)。
ZeRO 的關(guān)鍵在于分布式地存儲(chǔ)優(yōu)化器狀態(tài)、梯度和參數(shù),從而顯著降低每個(gè)設(shè)備的內(nèi)存需求。
使用 DeepSpeed 訓(xùn)練模型的示例:
import deepspeed
# 定義模型
class MyModel(torch.nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.linear = torch.nn.Linear(10, 1)
def forward(self, x):
return self.linear(x)
model = MyModel()
# 配置 DeepSpeed
ds_config = {
"train_batch_size": 8,
"fp16": {
"enabled": True
},
"zero_optimization": {
"stage": 2
}
}
# 初始化
model_engine, optimizer, _, _ = deepspeed.initialize(
model=model,
config_params=ds_config
)
# 訓(xùn)練
data = torch.randn(8, 10).to(model_engine.local_rank)
target = torch.randn(8, 1).to(model_engine.local_rank)
loss = torch.nn.MSELoss()(model_engine(data), target)
model_engine.backward(loss)
model_engine.step()
3. Hugging Face Transformers 的 Trainer
Hugging Face Transformers 提供了一個(gè)開(kāi)箱即用的Trainer類(lèi),支持分布式訓(xùn)練。以下是一個(gè)訓(xùn)練 GPT 模型的示例:
from transformers import Trainer, TrainingArguments, GPT2LMHeadModel
model = GPT2LMHeadModel.from_pretrained("gpt2")
training_args = TrainingArguments(
output_dir="./results",
num_train_epochs=3,
per_device_train_batch_size=4,
save_steps=10,
save_total_limit=2,
fp16=True,
deepspeed="./ds_config.json", # 支持 DeepSpeed
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=my_dataset,
)
trainer.train()
Ray:分布式訓(xùn)練中的“全能工具”
在深度學(xué)習(xí)的大規(guī)模分布式訓(xùn)練中,Ray 是一個(gè)不可忽視的工具。它不僅是一個(gè)分布式計(jì)算框架,還通過(guò)高層封裝,提供了許多強(qiáng)大的工具和庫(kù),如 Ray Train、Ray Tune 和 Ray Serve,幫助開(kāi)發(fā)者快速構(gòu)建和管理分布式應(yīng)用。所以單獨(dú)寫(xiě)一章。
什么是 Ray?
Ray 是一個(gè)通用的分布式計(jì)算框架,核心目標(biāo)是讓開(kāi)發(fā)者可以輕松實(shí)現(xiàn)分布式程序。它支持各種場(chǎng)景,從機(jī)器學(xué)習(xí)訓(xùn)練、超參數(shù)調(diào)優(yōu)到大規(guī)模數(shù)據(jù)處理和在線(xiàn)推理。
Ray 的核心特點(diǎn):
- 簡(jiǎn)單易用:開(kāi)發(fā)者只需用 Python 編寫(xiě)代碼,Ray 會(huì)自動(dòng)幫你處理分布式調(diào)度。
- 擴(kuò)展性強(qiáng):可以在單機(jī)上調(diào)試,部署到多節(jié)點(diǎn)集群時(shí)只需簡(jiǎn)單調(diào)整。
- 高效的資源管理:支持動(dòng)態(tài)資源分配和任務(wù)調(diào)度。
- 組件豐富:Ray 包含多個(gè)高層庫(kù),如 Ray Train、Ray Tune 和 Ray Serve,分別對(duì)應(yīng)訓(xùn)練、超參數(shù)調(diào)優(yōu)和在線(xiàn)推理。
Ray 的核心概念
在使用 Ray 時(shí),需要理解以下幾個(gè)核心概念:
-
Task(任務(wù))
一個(gè) Ray Task 是一個(gè)可以異步運(yùn)行的函數(shù)。它會(huì)自動(dòng)分配到集群中的空閑計(jì)算資源。示例代碼:
import ray ray.init() # 初始化 Ray @ray.remote def slow_function(x): import time time.sleep(1) return x ** 2 futures = [slow_function.remote(i) for i in range(10)] results = ray.get(futures) print(results)解釋:上述代碼中,
slow_function被聲明為遠(yuǎn)程任務(wù)(@ray.remote),會(huì)并行執(zhí)行在集群中的不同節(jié)點(diǎn)上。 -
Actor(角色)
Actor 是 Ray 中的有狀態(tài)任務(wù),可以用來(lái)保存中間狀態(tài)。例如,深度學(xué)習(xí)的模型實(shí)例可以作為一個(gè) Actor 存在。示例代碼:
@ray.remote class Counter: def __init__(self): self.count = 0 def increment(self): self.count += 1 return self.count counter = Counter.remote() print(ray.get(counter.increment.remote())) # 輸出 1 print(ray.get(counter.increment.remote())) # 輸出 2 Cluster(集群)
Ray 的集群管理非常靈活,你可以在本地運(yùn)行單節(jié)點(diǎn),也可以擴(kuò)展到上千節(jié)點(diǎn)的分布式集群。
Ray Train:用于分布式訓(xùn)練
Ray Train 是 Ray 為分布式訓(xùn)練任務(wù)設(shè)計(jì)的高層庫(kù)。它支持各種深度學(xué)習(xí)框架(如 PyTorch 和 TensorFlow),并通過(guò)高效的資源管理和分布式調(diào)度簡(jiǎn)化訓(xùn)練過(guò)程。
核心功能
- 自動(dòng)分布式支持:數(shù)據(jù)并行訓(xùn)練。
- 易于集成:與現(xiàn)有的 PyTorch 或 TensorFlow 代碼無(wú)縫對(duì)接。
- 靈活擴(kuò)展:支持 CPU/GPU 混合環(huán)境。
以下是一個(gè)使用 Ray Train 進(jìn)行分布式訓(xùn)練的 PyTorch 示例:
import ray
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig
import torch
import torch.nn as nn
import torch.optim as optim
# 定義簡(jiǎn)單的模型
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc = nn.Linear(10, 1)
def forward(self, x):
return self.fc(x)
# 定義訓(xùn)練函數(shù)
def train_loop_per_worker(config):
model = MyModel()
optimizer = optim.SGD(model.parameters(), lr=config["lr"])
loss_fn = nn.MSELoss()
# 模擬數(shù)據(jù)
data = torch.randn(100, 10)
target = torch.randn(100, 1)
for _ in range(5): # 模擬訓(xùn)練
optimizer.zero_grad()
output = model(data)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
print(f"Loss: {loss.item()}")
# 使用 Ray Train 進(jìn)行分布式訓(xùn)練
ray.init()
trainer = TorchTrainer(
train_loop_per_worker=train_loop_per_worker,
scaling_config=ScalingConfig(num_workers=4), # 使用 4 個(gè) worker
train_loop_config={"lr": 0.01}, # 傳遞訓(xùn)練超參數(shù)
)
trainer.fit()
Ray Tune:超參數(shù)調(diào)優(yōu)
大模型訓(xùn)練中,找到最佳超參數(shù)(如學(xué)習(xí)率、batch size)非常重要。Ray 提供了 Ray Tune,這是一個(gè)分布式超參數(shù)調(diào)優(yōu)框架,支持多種搜索算法和調(diào)度策略。
示例代碼:
from ray import tune
from ray.tune.schedulers import ASHAScheduler
def train(config):
import torch
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=config["lr"])
data = torch.randn(100, 10)
target = torch.randn(100, 1)
loss_fn = torch.nn.MSELoss()
for epoch in range(10):
optimizer.zero_grad()
loss = loss_fn(model(data), target)
loss.backward()
optimizer.step()
tune.report(loss=loss.item()) # 上報(bào)結(jié)果給 Ray Tune
search_space = {"lr": tune.grid_search([0.01, 0.1, 1.0])}
scheduler = ASHAScheduler()
tune.run(
train,
config=search_space,
scheduler=scheduler,
num_samples=3
)
Ray Serve:分布式推理
訓(xùn)練完成后,大模型的推理服務(wù)同樣需要分布式支持。Ray 的 Serve 模塊提供了高效的分布式推理能力。
以下是一個(gè)簡(jiǎn)單的 Ray Serve 示例:
from ray import serve
import ray
ray.init()
serve.start()
@serve.deployment
def predict(request):
return {"message": "Hello from Ray Serve!"}
predict.deploy()
import requests
response = requests.get("http://127.0.0.1:8000/predict")
print(response.json())
為什么選擇 Ray?
Ray 是分布式訓(xùn)練和部署的一站式解決方案:
- 如果你想高效地訓(xùn)練大模型,Ray Train 提供了數(shù)據(jù)并行和資源調(diào)度能力。
- 如果你需要優(yōu)化超參數(shù),Ray Tune 可以讓你輕松實(shí)現(xiàn)大規(guī)模調(diào)優(yōu)。
- 如果你需要部署分布式推理服務(wù),Ray Serve 是理想選擇。
相比其他工具(如 PyTorch DDP、DeepSpeed),Ray 的優(yōu)勢(shì)在于更廣的應(yīng)用場(chǎng)景和更高的靈活性。