深度學習長文|使用 JAX 進行 AI 模型訓練

引言

在人工智能模型的開發(fā)旅程中,選擇正確的機器學習開發(fā)框架是一項至關重要的決策。歷史上,眾多庫都曾競相爭奪“人工智能開發(fā)者首選框架”這一令人垂涎的稱號。(你是否還記得 Caffe 和 Theano?)在過去的幾年里,TensorFlow 以其對高效率、基于圖的計算的重視,似乎已經成為了領頭羊(這是根據作者對學術論文提及次數(shù)和社區(qū)支持力度的觀察得出的結論)。而在近十年的轉折點上,PyTorch 以其對用戶友好的 Python 風格接口的強調,似乎已經穩(wěn)坐了霸主之位。但是,近年來,一個新興的競爭者迅速崛起,其受歡迎程度已經到了不容忽視的地步。JAX 以其對提升人工智能模型訓練和推理性能的追求,同時不犧牲用戶體驗,正逐步向頂尖位置發(fā)起挑戰(zhàn)。

本文中,我們將對這個新興框架進行評估,展示其應用,并分享我們對其優(yōu)勢和不足的一些個人見解。雖然我們的焦點將集中在人工智能模型的訓練上,但也應當注意,JAX 在人工智能/機器學習領域乃至更廣的范圍內都有著廣泛的應用。目前,已有多個高級機器學習庫基于 JAX 構建。在本文中,我們將使用 Flax,據本文撰寫時的觀察,它似乎是最受歡迎的選擇。

JAX 幕后花絮 — XLA 編譯

JAX 的強大之處在于它利用了 XLA 編譯技術。JAX 所展現(xiàn)出的卓越運行性能,歸功于 XLA 提供的硬件特定優(yōu)化。而許多與 JAX 緊密相關的功能,比如即時編譯(JIT)和“函數(shù)式編程”范式,實際上都是 XLA 的衍生物。實際上,XLA 編譯并非 JAX 獨有,TensorFlow 和 PyTorch 也都提供了使用 XLA 的選項。不過,與其它流行框架相比,JAX 從設計之初就全面擁抱了 XLA。這使得 JIT 編譯、自動微分、向量化、并行化、分片處理以及其他特性與 XLA 庫的底層設計和實現(xiàn)緊密相連,這些特性都值得我們高度尊重。

XLA JIT 編譯器會對模型的計算圖進行全面分析,將連續(xù)的張量操作合并為單一內核,剔除冗余的圖組件,并生成最適合底層硬件加速器的機器代碼。這不僅減少了每次訓練步驟所需的總體機器級操作數(shù),也降低了主機與加速器之間的通信開銷,減少了內存占用,提高了專用加速器引擎的利用率。

除了運行時性能的優(yōu)化,XLA 的另一個關鍵特性是其可擴展的基礎設施,它允許擴展對更多 AI 加速器的支持。XLA 是 OpenXLA 項目的一部分,由 ML 領域的多個參與者共同開發(fā)。

依賴 XLA 也帶來了一些局限性和潛在問題。特別是,許多 AI 模型,包括那些具有動態(tài)張量形狀的模型,在 XLA 中可能無法達到最佳運行效果。需要特別注意避免圖斷裂和重新編譯的問題。同時,你也應該考慮到這對你的代碼調試可能帶來的影響。

JAX 實際應用

在本節(jié)內容中,我們將展示如何在 JAX 環(huán)境下利用單個 GPU 來訓練一個簡單的人工智能模型,并對它與 PyTorch 的性能進行對比。目前,存在許多提供多種機器學習框架后端支持的高級機器學習開發(fā)平臺,這使我們能夠對 JAX 的性能進行橫向比較。

本節(jié)中,我們將利用 HuggingFace 的 Transformers 庫,該庫為許多常見的基于 Transformer 架構的模型提供了 PyTorch 和 JAX 的實現(xiàn)版本。具體來說,我們將定義一個基于 Vision Transformer(ViT)的圖像分類模型,分別使用 PyTorch 的 ViTForImageClassification 和 JAX 的 FlaxViTForImageClassification 模塊來實現(xiàn)。

下面的代碼示例展示了模型的定義過程。

import torch
import jax, flax, optax
import jax.numpy as jnp

def get_model(use_jax=False):
    from transformers import ViTConfig

    if use_jax:
        from transformers import FlaxViTForImageClassification as ViTModel
    else:
        from transformers import ViTForImageClassification as ViTModel

    vit_config = ViTConfig(
        num_labels = 1000,
        _attn_implementation = 'eager'  # this disables flash attention
    )
    
    return ViTModel(vit_config)

請注意,我們決定不使用 "flash-attention" 功能,因為據我們所知,這項優(yōu)化目前只適用于 PyTorch 模型(至少在本文撰寫時是這樣)。

鑒于本文關注的是運行時性能,我們選擇在一個隨機生成的數(shù)據集上訓練我們的模型。我們利用了 JAX 支持 PyTorch 數(shù)據加載器的特性:

def get_data_loader(batch_size, use_jax=False):
    from torch.utils.data import Dataset, DataLoader, default_collate

    # create dataset of random image and label data
    class FakeDataset(Dataset):
        def __len__(self):
            return 1000000

        def __getitem__(self, index):
            if use_jax: # use nhwc
                rand_image = torch.randn([224, 224, 3], dtype=torch.float32)
            else: # use nchw
                rand_image = torch.randn([3, 224, 224], dtype=torch.float32)
            label = torch.tensor(data=[index % 1000], dtype=torch.int64)
            return rand_image, label

    ds = FakeDataset()
    
    if use_jax:  # convert torch tensors to numpy arrays
        def numpy_collate(batch):
            from jax.tree_util import tree_map
            import jax.numpy as jnp
            return tree_map(jnp.asarray, default_collate(batch))
        collate_fn = numpy_collate
    else:
        collate_fn = default_collate
 
    ds = FakeDataset()
    dl = DataLoader(ds, batch_size=batch_size,
                    collate_fn=collate_fn)
    return dl

接下來,我們定義 PyTorch 和 JAX 訓練循環(huán)。 JAX 訓練循環(huán)依賴于 Flax TrainState 對象,其定義遵循在 Flax 中訓練 ML 模型的基本教程:

@jax.jit
def train_step_jax(train_state, batch):
    with jax.default_matmul_precision('tensorfloat32'):
        def forward(params):
            logits = train_state.apply_fn({'params': params}, batch[0])
            loss = optax.softmax_cross_entropy(
                logits=logits.logits, labels=batch[1]).mean()
            return loss

        grad_fn = jax.grad(forward)
        grads = grad_fn(train_state.params)
        train_state = train_state.apply_gradients(grads=grads)
        return train_state

def train_step_torch(batch, model, optimizer, loss_fn, device):
    inputs = batch[0].to(device=device, non_blocking=True)
    label = batch[1].squeeze(-1).to(device=device, non_blocking=True)
    outputs = model(inputs)
    loss = loss_fn(outputs.logits, label)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

現(xiàn)在讓我們把所有東西放在一起。在下面的腳本中,我們包含了使用 PyTorch 基于圖形的 JIT 編譯選項的控件,使用 torch.compile 和 torch_xla:

def train(batch_size, mode, compile_model):
    print(f"Mode: {mode} \n"
          f"Batch size: {batch_size} \n"
          f"Compile model: {compile_model}")

    # init model and data loader
    use_jax = mode == 'jax'
    use_torch_xla = mode == 'torch_xla'
    model = get_model(use_jax)
    train_loader = get_data_loader(batch_size, use_jax)

    if use_jax:
        # init jax settings
        from flax.training import train_state
        params = model.module.init(jax.random.key(0), 
                                   jnp.ones([1, 224, 224, 3]))['params']
        optimizer = optax.sgd(learning_rate=1e-3)
        state = train_state.TrainState.create(apply_fn=model.module.apply,
                                              params=params, tx=optimizer)
    else:
        if use_torch_xla:
            import torch_xla
            import torch_xla.core.xla_model as xm
            import torch_xla.distributed.parallel_loader as pl
            torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
                use_full_mat_mul_precision=False)
       
            device = xm.xla_device()
            backend = 'openxla'
        
            # wrap data loader
            train_loader = pl.MpDeviceLoader(train_loader, device)
        else:
            device = torch.device('cuda')
            backend = 'inductor'
    
        model = model.to(device)
        if compile_model:
            model = torch.compile(model, backend=backend)
        model.train()
        optimizer = torch.optim.SGD(model.parameters())
        loss_fn = torch.nn.CrossEntropyLoss()

    import time
    t0 = time.perf_counter()
    summ = 0
    count = 0

    for step, data in enumerate(train_loader):
        if use_jax:
            state = train_step_jax(state, data)
        else:
            train_step_torch(data, model, optimizer, loss_fn, device)

        # capture step time
        batch_time = time.perf_counter() - t0
        if step > 10:  # skip first steps
            summ += batch_time
        count += 1
        t0 = time.perf_counter()
        if step > 50:
            break

    print(f'average step time: {summ / count}')


if __name__ == '__main__':
    import argparse
    torch.set_float32_matmul_precision('high')
    
    parser = argparse.ArgumentParser(description='Toy Training Script.')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='input batch size for training (default: 2)')
    parser.add_argument('--mode', choices=['pytorch', 'jax', 'torch_xla'],
                        default='jax',
                        help='choose training mode')
    parser.add_argument('--compile-model', action='store_true', default=False,
                        help='whether to apply torch.compile to the model')
    args = parser.parse_args()

    train(**vars(args))

性能基準測試

在進行基準測試對比分析時,我們務必要非常謹慎和嚴格,仔細審視測試的執(zhí)行方式。這一點在人工智能模型開發(fā)領域尤為重要,因為如果基于不準確的數(shù)據做出決策,可能會導致極其嚴重的后果。在評估訓練模型的運行時性能時,有幾個關鍵因素可能會極大地影響我們的測量結果,例如浮點數(shù)的精度、矩陣乘法的精度、數(shù)據加載方式,以及是否采用了 flash/fused 注意力機制等。舉例來說,如果 PyTorch 默認的矩陣乘法精度是 float32,而 JAX 使用的是 tensorfloat32,那么單純比較它們的性能可能不會給我們帶來太多有價值的信息。這些精度設置可以通過相應的 API 進行調整,例如使用 jax.default_matmul_precision 和 torch.set_float32_matmul_precision。在我們的腳本中,我們已經盡力去識別并排除這些可能的問題,但我們無法保證我們的嘗試一定能夠完全成功。

測試結果

我們在 Google Cloud 的兩臺虛擬機上執(zhí)行了訓練腳本,一臺配置為 g2-standard-16(配備了一塊 NVIDIA L4 GPU),另一臺是 a2-highgpu-1g(配備了一塊 NVIDIA A100 GPU)。無論哪種情況,我們都選用了專為深度學習定制的虛擬機鏡像(common-cu121-v20240514-ubuntu-2204-py310),并預裝了 PyTorch(版本 2.3.0)、PyTorch/XLA(版本 2.3.0)、JAX(版本 0.4.28)、Flax(版本 0.8.4)、Optax(版本 0.2.2)以及 HuggingFace 的 Transformers 庫(版本 4.41.1)。

以下表格匯總了多項實驗的運行時間數(shù)據。需要提醒的是,模型架構和運行環(huán)境的不同可能會導致性能比較結果有顯著差異。同時,代碼中的一些細微調整也可能對這些結果產生顯著影響。

盡管 JAX 在 L4 GPU 上展現(xiàn)出了明顯超越其他選項的性能,但在 A100 GPU 上,它與 PyTorch/XLA 的表現(xiàn)卻旗鼓相當。這種情況并不出人意料,因為它們共享了 XLA 后端。理論上,JAX 生成的任何 XLA(高級線性優(yōu)化)圖都應該能夠被 PyTorch/XLA 同樣實現(xiàn)。在這兩種平臺上,torch.compile 功能的表現(xiàn)都不盡如人意??紤]到我們選擇了全精度浮點數(shù)進行計算,這種情況在一定程度上是可以預見的。

那么為什么要使用 JAX?

  • 性能優(yōu)化

JAX 訓練的一個主要吸引力在于 JIT 編譯可能帶來的運行時性能提升。然而,隨著 PyTorch 新增的 JIT 編譯功能(PyTorch/XLA)以及更進一步的 torch.compile 選項,JAX 的這一優(yōu)勢可能遭到質疑。實際上,考慮到 PyTorch 背后龐大的開發(fā)者社區(qū),以及 PyTorch 所原生支持而 JAX/FLAX 尚未涵蓋的眾多特性(例如自動混合精度、先進的注意力機制層,至少在本文撰寫時),有人可能會強烈主張沒有必要投入時間去掌握 JAX。除了可能的性能提升之外,還有一些其他的動力因素:

  • XLA友好性

與 PyTorch 后來通過 PyTorch/XLA 實現(xiàn)的“函數(shù)化”不同,JAX 從設計之初就內嵌了 XLA 的支持。這表明在 PyTorch/XLA 中可能顯得復雜或混亂的操作,在 JAX 中可以更加簡潔優(yōu)雅地實現(xiàn)。例如,在訓練過程中混合使用 JIT 和非 JIT 函數(shù),在 JAX 中是直接可行的,而在 PyTorch/XLA 中可能需要一些巧妙的技巧。

正如之前提到的,理論上,PyTorch/XLA 和 TensorFlow 都能夠生成與 JAX 相同的 XLA(高級線性優(yōu)化)圖,從而實現(xiàn)同等的性能。然而,在實際操作中,生成的圖的優(yōu)劣取決于框架實現(xiàn)如何轉化為 XLA 代碼。更高效的轉換將帶來更佳的運行時性能。由于 JAX 原生支持 XLA,它可能在與其他框架的競爭中占據優(yōu)勢。

JAX 對 XLA 的友好性使其對專用 AI 加速器的開發(fā)人員尤其有吸引力,例如 Google Cloud TPU、Intel Gaudi 和 AWS Trainium 芯片,這些加速器通常被稱為“XLA 設備”。特別是在 TPU 上進行訓練的團隊可能會發(fā)現(xiàn) JAX 的支持生態(tài)系統(tǒng)比 PyTorch/XLA 更先進。

  • 高級特性

近年來,JAX 中發(fā)布了許多高級功能,遠遠早于同行。例如,SPMD 是一種先進的設備并行技術,提供最先進的模型分片機會,幾年前在 JAX 中引入,最近才被轉移到 PyTorch。另一個例子是 Pallas(終于)能夠為 XLA 設備構建自定義內核。

開源模型

隨著 JAX 框架的日益普及,越來越多的開源 AI 模型正在 JAX 中發(fā)布。一些經典的例子是 Google 的開源 MaxText (LLM) 和 AlphaFold v2(蛋白質結構預測)模型。要充分利用此類模型,您需要學習 JAX,或者承擔將其移植到另一種語言的重要任務。

總結

本文我們深入探討了正在崛起的 JAX 機器學習開發(fā)框架。我們闡述了它依托于 XLA 編譯器,并在一個示例中演示了其應用。雖然 JAX 常因其快速的運行時執(zhí)行速度而備受矚目,但 PyTorch 的 JIT 編譯功能(包括 torch.compile 和 PyTorch/XLA)同樣具備性能優(yōu)化的巨大潛力。每種選擇的性能表現(xiàn),將極大程度上依賴于模型的具體細節(jié)和運行環(huán)境。

值得注意的是,每個機器學習開發(fā)框架都可能擁有其獨到的特性(例如,截至本文撰寫時,JAX 的 SPMD 自動分片和 PyTorch 的 SDPA 注意力機制),這些特性可能在性能比較中起到關鍵作用。因此,選擇最佳框架的決定因素可能是你的模型能夠多大程度上利用這些特性。

本文由mdnice多平臺發(fā)布

?著作權歸作者所有,轉載或內容合作請聯(lián)系作者
【社區(qū)內容提示】社區(qū)部分內容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發(fā)布,文章內容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

相關閱讀更多精彩內容

友情鏈接更多精彩內容