大模型的工作原理:分布式訓(xùn)練入門(mén)

我有個(gè)同事的簽名是"大模型是真的大",大模型(如GPT、LLaMA)之所以被稱(chēng)為“大”,不僅因?yàn)樗鼈兊膮?shù)量高達(dá)數(shù)十億甚至上萬(wàn)億,更因?yàn)樗鼈冃枰獜?qiáng)大的計(jì)算資源來(lái)完成訓(xùn)練。這些資源遠(yuǎn)超單機(jī)能承載的范圍,因此,大模型的訓(xùn)練離不開(kāi)分布式訓(xùn)練技術(shù)。

分布式訓(xùn)練概述

分布式訓(xùn)練通過(guò)將訓(xùn)練任務(wù)分?jǐn)偟蕉鄠€(gè)計(jì)算設(shè)備(如GPU、NPU或CPU)上,以加速訓(xùn)練過(guò)程。其主要目標(biāo)包括:

  1. 提高計(jì)算效率:減少訓(xùn)練所需時(shí)間。
  2. 擴(kuò)展模型規(guī)模:支持更大的模型和更復(fù)雜的數(shù)據(jù)集。(單個(gè)設(shè)備的顯存不足以支撐大模型)
  3. 高效利用資源:通過(guò)并行計(jì)算,充分利用硬件能力。

在實(shí)際操作中,分布式訓(xùn)練需要解決以下關(guān)鍵問(wèn)題:

  • 如何將數(shù)據(jù)分配到多個(gè)設(shè)備上?
  • 如何在多個(gè)設(shè)備之間共享和同步模型參數(shù)?
  • 如何保證訓(xùn)練的準(zhǔn)確性和效率?

分布式訓(xùn)練的關(guān)鍵概念

數(shù)據(jù)并行(Data Parallelism)

數(shù)據(jù)并行是最常見(jiàn)的分布式訓(xùn)練方法。其核心思想是將數(shù)據(jù)切分成多個(gè)小批次(mini-batches),并將這些小批次分發(fā)到不同的設(shè)備上進(jìn)行計(jì)算。

流程

  1. 每個(gè)設(shè)備(如GPU)獲取不同的小批次數(shù)據(jù)。
  2. 每個(gè)設(shè)備獨(dú)立計(jì)算其對(duì)應(yīng)數(shù)據(jù)的小批次的梯度。
  3. 匯總所有設(shè)備的梯度,并更新全局模型參數(shù)。
  4. 廣播更新后的參數(shù)給所有設(shè)備。

優(yōu)點(diǎn):實(shí)現(xiàn)簡(jiǎn)單,對(duì)數(shù)據(jù)量較大的場(chǎng)景效果顯著。
缺點(diǎn):對(duì)模型規(guī)模的擴(kuò)展性有限,設(shè)備間通信開(kāi)銷(xiāo)較大。

以下是一個(gè)示意圖:

image.png

模型并行(Model Parallelism)

當(dāng)模型參數(shù)過(guò)大,單個(gè)設(shè)備無(wú)法容納時(shí),可以采用模型并行。模型并行將模型拆分成多個(gè)部分,并分配到不同設(shè)備上。

流程

  1. 將模型的不同層(或塊)分配到不同的設(shè)備。
  2. 每個(gè)設(shè)備只計(jì)算屬于自己的那部分模型。
  3. 設(shè)備間通過(guò)通信共享中間結(jié)果。

優(yōu)點(diǎn):適用于超大模型。
缺點(diǎn):實(shí)現(xiàn)復(fù)雜,設(shè)備間的同步通信成本較高。

以下是模型并行的示意圖:

image.png

混合并行(Hybrid Parallelism)

混合并行結(jié)合了數(shù)據(jù)并行和模型并行的優(yōu)點(diǎn)。在這種方法中,既對(duì)數(shù)據(jù)進(jìn)行切分,又將模型分割到多個(gè)設(shè)備上。

優(yōu)點(diǎn):充分利用硬件資源,適用于超大規(guī)模的訓(xùn)練任務(wù)。
缺點(diǎn):實(shí)現(xiàn)和調(diào)試更復(fù)雜。

以下是混合并行的示意圖:

image.png

Pipeline 并行(Pipeline Parallelism)

Pipeline 并行是一種特殊的模型并行,它將模型的不同層分配到不同設(shè)備上,但同時(shí)允許多個(gè)小批次的數(shù)據(jù)在流水線(xiàn)中流動(dòng)。

優(yōu)點(diǎn):提高了設(shè)備利用率,減少了空閑時(shí)間。
缺點(diǎn):需要處理流水線(xiàn)中的梯度同步問(wèn)題。

以下是Pipeline并行的示意圖:

image.png

開(kāi)源框架支持的分布式訓(xùn)練方法

目前主流的深度學(xué)習(xí)框架都支持分布式訓(xùn)練,比如 PyTorch、TensorFlow、DeepSpeed 和 Hugging Face Transformers。以下是一些常用的工具和方法。

1. PyTorch 的分布式訓(xùn)練

PyTorch 提供了torch.distributed模塊,用于實(shí)現(xiàn)分布式訓(xùn)練。

以下是一個(gè)簡(jiǎn)單的數(shù)據(jù)并行示例:

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# 初始化分布式環(huán)境
dist.init_process_group("nccl", rank=0, world_size=1)

# 模型和數(shù)據(jù)
model = torch.nn.Linear(10, 1).to("cuda:0")
ddp_model = DDP(model, device_ids=[0])
data = torch.randn(20, 10).to("cuda:0")
target = torch.randn(20, 1).to("cuda:0")

# 損失函數(shù)和優(yōu)化器
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)

# 訓(xùn)練
outputs = ddp_model(data)
loss = criterion(outputs, target)
loss.backward()
optimizer.step()

2. DeepSpeed 的零冗余優(yōu)化

DeepSpeed 是一種高效的分布式訓(xùn)練框架,特別適用于超大規(guī)模模型。其核心特性包括 ZeRO(Zero Redundancy Optimizer)。

ZeRO 的關(guān)鍵在于分布式地存儲(chǔ)優(yōu)化器狀態(tài)、梯度和參數(shù),從而顯著降低每個(gè)設(shè)備的內(nèi)存需求。

使用 DeepSpeed 訓(xùn)練模型的示例:

import deepspeed

# 定義模型
class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = torch.nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

model = MyModel()

# 配置 DeepSpeed
ds_config = {
    "train_batch_size": 8,
    "fp16": {
        "enabled": True
    },
    "zero_optimization": {
        "stage": 2
    }
}

# 初始化
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    config_params=ds_config
)

# 訓(xùn)練
data = torch.randn(8, 10).to(model_engine.local_rank)
target = torch.randn(8, 1).to(model_engine.local_rank)
loss = torch.nn.MSELoss()(model_engine(data), target)
model_engine.backward(loss)
model_engine.step()

3. Hugging Face Transformers 的 Trainer

Hugging Face Transformers 提供了一個(gè)開(kāi)箱即用的Trainer類(lèi),支持分布式訓(xùn)練。以下是一個(gè)訓(xùn)練 GPT 模型的示例:

from transformers import Trainer, TrainingArguments, GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    save_steps=10,
    save_total_limit=2,
    fp16=True,
    deepspeed="./ds_config.json",  # 支持 DeepSpeed
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=my_dataset,
)

trainer.train()

Ray:分布式訓(xùn)練中的“全能工具”

在深度學(xué)習(xí)的大規(guī)模分布式訓(xùn)練中,Ray 是一個(gè)不可忽視的工具。它不僅是一個(gè)分布式計(jì)算框架,還通過(guò)高層封裝,提供了許多強(qiáng)大的工具和庫(kù),如 Ray Train、Ray Tune 和 Ray Serve,幫助開(kāi)發(fā)者快速構(gòu)建和管理分布式應(yīng)用。所以單獨(dú)寫(xiě)一章。

什么是 Ray?

Ray 是一個(gè)通用的分布式計(jì)算框架,核心目標(biāo)是讓開(kāi)發(fā)者可以輕松實(shí)現(xiàn)分布式程序。它支持各種場(chǎng)景,從機(jī)器學(xué)習(xí)訓(xùn)練、超參數(shù)調(diào)優(yōu)到大規(guī)模數(shù)據(jù)處理和在線(xiàn)推理。

Ray 的核心特點(diǎn):

  1. 簡(jiǎn)單易用:開(kāi)發(fā)者只需用 Python 編寫(xiě)代碼,Ray 會(huì)自動(dòng)幫你處理分布式調(diào)度。
  2. 擴(kuò)展性強(qiáng):可以在單機(jī)上調(diào)試,部署到多節(jié)點(diǎn)集群時(shí)只需簡(jiǎn)單調(diào)整。
  3. 高效的資源管理:支持動(dòng)態(tài)資源分配和任務(wù)調(diào)度。
  4. 組件豐富:Ray 包含多個(gè)高層庫(kù),如 Ray Train、Ray Tune 和 Ray Serve,分別對(duì)應(yīng)訓(xùn)練、超參數(shù)調(diào)優(yōu)和在線(xiàn)推理。

Ray 的核心概念

在使用 Ray 時(shí),需要理解以下幾個(gè)核心概念:

  1. Task(任務(wù))
    一個(gè) Ray Task 是一個(gè)可以異步運(yùn)行的函數(shù)。它會(huì)自動(dòng)分配到集群中的空閑計(jì)算資源。

    示例代碼:

    import ray
    
    ray.init()  # 初始化 Ray
    
    @ray.remote
    def slow_function(x):
        import time
        time.sleep(1)
        return x ** 2
    
    futures = [slow_function.remote(i) for i in range(10)]
    results = ray.get(futures)
    print(results)
    

    解釋:上述代碼中,slow_function 被聲明為遠(yuǎn)程任務(wù)(@ray.remote),會(huì)并行執(zhí)行在集群中的不同節(jié)點(diǎn)上。

  2. Actor(角色)
    Actor 是 Ray 中的有狀態(tài)任務(wù),可以用來(lái)保存中間狀態(tài)。例如,深度學(xué)習(xí)的模型實(shí)例可以作為一個(gè) Actor 存在。

    示例代碼:

    @ray.remote
    class Counter:
        def __init__(self):
            self.count = 0
    
        def increment(self):
            self.count += 1
            return self.count
    
    counter = Counter.remote()
    print(ray.get(counter.increment.remote()))  # 輸出 1
    print(ray.get(counter.increment.remote()))  # 輸出 2
    
  3. Cluster(集群)
    Ray 的集群管理非常靈活,你可以在本地運(yùn)行單節(jié)點(diǎn),也可以擴(kuò)展到上千節(jié)點(diǎn)的分布式集群。

Ray Train:用于分布式訓(xùn)練

Ray Train 是 Ray 為分布式訓(xùn)練任務(wù)設(shè)計(jì)的高層庫(kù)。它支持各種深度學(xué)習(xí)框架(如 PyTorch 和 TensorFlow),并通過(guò)高效的資源管理和分布式調(diào)度簡(jiǎn)化訓(xùn)練過(guò)程。

核心功能

  • 自動(dòng)分布式支持:數(shù)據(jù)并行訓(xùn)練。
  • 易于集成:與現(xiàn)有的 PyTorch 或 TensorFlow 代碼無(wú)縫對(duì)接。
  • 靈活擴(kuò)展:支持 CPU/GPU 混合環(huán)境。

以下是一個(gè)使用 Ray Train 進(jìn)行分布式訓(xùn)練的 PyTorch 示例:

import ray
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig
import torch
import torch.nn as nn
import torch.optim as optim

# 定義簡(jiǎn)單的模型
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

# 定義訓(xùn)練函數(shù)
def train_loop_per_worker(config):
    model = MyModel()
    optimizer = optim.SGD(model.parameters(), lr=config["lr"])
    loss_fn = nn.MSELoss()

    # 模擬數(shù)據(jù)
    data = torch.randn(100, 10)
    target = torch.randn(100, 1)

    for _ in range(5):  # 模擬訓(xùn)練
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        print(f"Loss: {loss.item()}")

# 使用 Ray Train 進(jìn)行分布式訓(xùn)練
ray.init()
trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=4),  # 使用 4 個(gè) worker
    train_loop_config={"lr": 0.01},  # 傳遞訓(xùn)練超參數(shù)
)
trainer.fit()

Ray Tune:超參數(shù)調(diào)優(yōu)

大模型訓(xùn)練中,找到最佳超參數(shù)(如學(xué)習(xí)率、batch size)非常重要。Ray 提供了 Ray Tune,這是一個(gè)分布式超參數(shù)調(diào)優(yōu)框架,支持多種搜索算法和調(diào)度策略。

示例代碼:

from ray import tune
from ray.tune.schedulers import ASHAScheduler

def train(config):
    import torch
    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=config["lr"])
    data = torch.randn(100, 10)
    target = torch.randn(100, 1)
    loss_fn = torch.nn.MSELoss()
    for epoch in range(10):
        optimizer.zero_grad()
        loss = loss_fn(model(data), target)
        loss.backward()
        optimizer.step()
        tune.report(loss=loss.item())  # 上報(bào)結(jié)果給 Ray Tune

search_space = {"lr": tune.grid_search([0.01, 0.1, 1.0])}
scheduler = ASHAScheduler()
tune.run(
    train,
    config=search_space,
    scheduler=scheduler,
    num_samples=3
)

Ray Serve:分布式推理

訓(xùn)練完成后,大模型的推理服務(wù)同樣需要分布式支持。Ray 的 Serve 模塊提供了高效的分布式推理能力。

以下是一個(gè)簡(jiǎn)單的 Ray Serve 示例:

from ray import serve
import ray

ray.init()
serve.start()

@serve.deployment
def predict(request):
    return {"message": "Hello from Ray Serve!"}

predict.deploy()

import requests
response = requests.get("http://127.0.0.1:8000/predict")
print(response.json())

為什么選擇 Ray?

Ray 是分布式訓(xùn)練和部署的一站式解決方案:

  • 如果你想高效地訓(xùn)練大模型,Ray Train 提供了數(shù)據(jù)并行和資源調(diào)度能力。
  • 如果你需要優(yōu)化超參數(shù),Ray Tune 可以讓你輕松實(shí)現(xiàn)大規(guī)模調(diào)優(yōu)。
  • 如果你需要部署分布式推理服務(wù),Ray Serve 是理想選擇。

相比其他工具(如 PyTorch DDP、DeepSpeed),Ray 的優(yōu)勢(shì)在于更廣的應(yīng)用場(chǎng)景和更高的靈活性。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

友情鏈接更多精彩內(nèi)容