pytorch訓(xùn)練優(yōu)化-自動(dòng)混合精度訓(xùn)練(AMP)

Pytorch 版本:1.6及以上的版本,支持CUDA
GPU版本:支持 Tensor core的 CUDA(Volta、Turing、Ampere),在較早版本的GPU(Kepler、Maxwell、Pascal)提升一般

PyTorch 通常在 32 位浮點(diǎn)數(shù)據(jù) (FP32) 上進(jìn)行訓(xùn)練,如果你創(chuàng)建一個(gè)Tensor, 默認(rèn)類(lèi)型都是 torch.FloatTensor (32-bit floating point)。

NVIDIA 的工程師開(kāi)發(fā)了混合精度訓(xùn)練(AMP),讓少量操作在 FP32 中的訓(xùn)練,而大部分網(wǎng)絡(luò)在 FP16 中運(yùn)行,因此可以節(jié)省時(shí)間和內(nèi)存。

torch.cuda.amp 提供了混合精度的便捷方法,其中某些操作使用 FP32 ,其他操作使用 FP16。神經(jīng)網(wǎng)絡(luò)訓(xùn)練過(guò)程中的運(yùn)算主要可以分三類(lèi):

  • 可以受益于 FP16 速度提升的數(shù)學(xué)函數(shù)。包括矩陣乘法(線性層)和卷積。
  • 對(duì)于 16 位精度可能不夠的函數(shù),輸入應(yīng)采用 FP32。例如減法。
  • 其他操作,可以在 FP16 中運(yùn)行的函數(shù),但在 FP16 中加速并不顯著,因此它們的 FP32 -> FP16 轉(zhuǎn)換不值得。

混合精度訓(xùn)練將每個(gè)操作與其適當(dāng)?shù)臄?shù)據(jù)類(lèi)型相匹配,這可以減少網(wǎng)絡(luò)的運(yùn)行時(shí)間和內(nèi)存占用。

bfloat16 vs. float16
bfloat16 是一種z專(zhuān)門(mén)用于深度學(xué)習(xí)的 16 位浮點(diǎn)格式,由 1 個(gè)符號(hào)位、8 個(gè)指數(shù)位和 7 個(gè)尾數(shù)位組成。而行業(yè)標(biāo)準(zhǔn) IEEE 16 位浮點(diǎn)是1 個(gè)符號(hào)位、5 個(gè)指數(shù)位和 10 個(gè)尾數(shù)位。
實(shí)驗(yàn)表明使用 bfloat16 可以提高訓(xùn)練效率,因?yàn)樯疃葘W(xué)習(xí)模型通常對(duì)指數(shù)變化更加敏感,而16位使用內(nèi)存更少。
bfloat16 的指數(shù)位和 float32 一樣,在訓(xùn)練過(guò)程中不容易出現(xiàn)下溢,也就不容易出現(xiàn) NaN 或者 Inf 之類(lèi)的錯(cuò)誤。
使用 bfloat16: dtype=torch.bfloat16

一、一般的訓(xùn)練流程

通常自動(dòng)混合精度訓(xùn)練會(huì)同時(shí)使用 torch.autocasttorch.cuda.amp.GradScaler。

假設(shè)我們已經(jīng)定義好了一個(gè)模型, 并寫(xiě)好了其他相關(guān)代碼(懶得寫(xiě)出來(lái)了)。

1. torch.autocast
torch.autocast 實(shí)例作為上下文管理器,允許腳本區(qū)域以混合精度運(yùn)行。
在這些區(qū)域中,CUDA 操作將以 autocast 選擇的 dtype 運(yùn)行,以提高性能,同時(shí)保持準(zhǔn)確性。

autocast應(yīng)該只封裝前向和 loss 計(jì)算, 在 backward() 前退出 autocast,反向計(jì)算時(shí)數(shù)據(jù)類(lèi)型和前向的數(shù)據(jù)類(lèi)型一致。

訓(xùn)練部分的代碼:

for epoch in range(epochs): 
    for input, target in zip(data, targets):
        # 在 ``autocast`` 下進(jìn)行前向
        with torch.autocast(device_type=device, dtype=torch.float16):
            output = net(input)
            # output is float16 because linear layers ``autocast`` to float16.
          
            loss = loss_fn(output, target)
            # loss is float32 because ``mse_loss`` layers ``autocast`` to float32.
           
        # 在 backward() 前退出``autocast``
        # 不建議在“autocast”下進(jìn)行反向傳遞
        # Backward 在相應(yīng)前向操作選擇的相同“dtype”“autocast”中運(yùn)行。
        loss.backward()
        opt.step()
        opt.zero_grad() # set_to_none=True here can modestly improve performance

可以 autocast 到 FP16 的 CUDA 操作:
__matmul__, addbmm, addmm, addmv, addr, baddbmm, bmm, chain_matmul, multi_dot, conv1d, conv2d, conv3d, conv_transpose1d, conv_transpose2d, conv_transpose3d, GRUCell, linear, LSTMCell, matmul, mm, mv, prelu, RNNCell

autocast 到 FP32 的 CUDA 操作:
__pow__, __rdiv__, __rpow__, __rtruediv__, acos, asin, binary_cross_entropy_with_logits, cosh, cosine_embedding_loss, cdist, cosine_similarity, cross_entropy, cumprod, cumsum, dist, erfinv, exp, expm1, group_norm, hinge_embedding_loss, kl_div, l1_loss, layer_norm, log, log_softmax, log10, log1p, log2, margin_ranking_loss, mse_loss, multilabel_margin_loss, multi_margin_loss, nll_loss, norm, normalize, pdist, poisson_nll_loss, pow, prod, reciprocal, rsqrt, sinh, smooth_l1_loss, soft_margin_loss, softmax, softmin, softplus, sum, renorm, tan, triplet_margin_loss

應(yīng)該優(yōu)先選擇 binary_cross_entropy_with_logits 而不是 binary_cross_entropy,因?yàn)?

torch.nn.function.binary_cross_entropy()(以及包裝它的torch.nn.BCELoss)的向后傳遞可以產(chǎn)生無(wú)法在 FP16 中表示的梯度。在啟用 autocast 的區(qū)域中,前向輸入可能是 FP16,這意味著反向梯度必須可以用 FP16 表示。因此,binary_cross_entropy 和 BCELoss 在啟用 autocast 的區(qū)域中會(huì)引發(fā)錯(cuò)誤。
可以使用 torch.nn.function.binary_cross_entropy_with_logits()torch.nn.BCEWithLogitsLoss 來(lái)代替。

2. GradScaler

梯度縮放(gradient scaling)有助于防止在使用混合精度進(jìn)行訓(xùn)練時(shí),出現(xiàn)梯度下溢,也就是在 FP16 下過(guò)小的梯度值會(huì)變成 0,因此相應(yīng)參數(shù)的更新將丟失。同樣的道理,如果網(wǎng)絡(luò)中有過(guò)小的值,比如防止出現(xiàn)除零而加入的 eps 值如果過(guò)?。ū热?1e-8),也會(huì)導(dǎo)致除零錯(cuò)誤出現(xiàn)。

為了防止下溢,梯度縮放將網(wǎng)絡(luò)的損失乘以比例因子,并對(duì)縮放后的損失調(diào)用向后傳遞。然后通過(guò)網(wǎng)絡(luò)向后流動(dòng)的梯度按相同的因子縮放。換句話說(shuō),梯度值具有較大的幅度,因此它們不會(huì)刷新為零。

每個(gè)參數(shù)的梯度(.grad 屬性)應(yīng)該在優(yōu)化器更新參數(shù)之前取消縮放,因此縮放因子不會(huì)干擾學(xué)習(xí)率。

torch.cuda.amp.GradScaler 可以執(zhí)行梯度縮放步驟。

scaler = torch.cuda.amp.GradScaler()

1+2: Automatic Mixed Precision

use_amp = True

net = make_model(in_size, out_size, num_layers)
opt = torch.optim.SGD(net.parameters(), lr=0.001)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

for epoch in range(epochs):
    for input, target in zip(data, targets):
        with torch.autocast(device_type=device, dtype=torch.float16, enabled=use_amp):
            output = net(input)
            loss = loss_fn(output, target)
        # scaler.scale(loss) 縮放梯度,然后進(jìn)行反向計(jì)算
        scaler.scale(loss).backward()
        # scaler.step() 首先取消化器分配參數(shù)梯度的縮放,如果梯度中不包括 ``inf`` ``NaN``,就運(yùn)行 optimizer.step() 
        # 否則會(huì)跳過(guò) optimizer.step() 
        scaler.step(opt)
        scaler.update()
        opt.zero_grad() # set_to_none=True here can modestly improve performance

檢查 loss scale
訓(xùn)練過(guò)程中檢查 scale,避免掉到0.

scaler = torch.cuda.amp.GradScaler()
current_loss_scale = scaler.get_scale()
if step % log_iter == 0:
   print('scale:', current_loss_scale)

保存和加載
如果 checkpoint 是在沒(méi)有 Amp 的情況下保存的,并且你想要使用 Amp 恢復(fù)訓(xùn)練,直接從checkpoint 加載模型和優(yōu)化器狀態(tài),然后用新創(chuàng)建的 GradScaler。
如果checkpoint是通過(guò)使用 Amp 創(chuàng)建的,并且想要在不使用 Amp 的情況下恢復(fù)訓(xùn)練,可以直接從checkpoint 加載模型和優(yōu)化器狀態(tài),忽略保存的 scaler 。

# 保存
checkpoint = {"model": net.state_dict(),
              "optimizer": opt.state_dict(),
              "scaler": scaler.state_dict()}
# Write checkpoint as desired, e.g.,
# torch.save(checkpoint, "filename")

# 加載
dev = torch.cuda.current_device()
checkpoint = torch.load("filename",
                        map_location = lambda storage, loc: storage.cuda(dev))
net.load_state_dict(checkpoint["model"])
opt.load_state_dict(checkpoint["optimizer"])
scaler.load_state_dict(checkpoint["scaler"])

二、 多個(gè)XX

多個(gè) model,loss, optimizer
如果有多個(gè)損失,則必須分別對(duì)每個(gè)損失調(diào)用 scaler.scale。如果網(wǎng)絡(luò)有多個(gè)優(yōu)化器,可以分別對(duì)其中任何一個(gè)優(yōu)化器調(diào)用scaler.unscale_,并且必須對(duì)每個(gè)優(yōu)化器單獨(dú)調(diào)用 scaler.step。
但是,scaler.update 只能調(diào)用一次.

scaler = torch.cuda.amp.GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer0.zero_grad()
        optimizer1.zero_grad()
        with autocast(device_type='cuda', dtype=torch.float16):
            output0 = model0(input)
            output1 = model1(input)
            loss0 = loss_fn(2 * output0 + 3 * output1, target)
            loss1 = loss_fn(3 * output0 - 5 * output1, target)

        scaler.scale(loss0).backward()
        scaler.scale(loss1).backward()

        # You can choose which optimizers receive explicit unscaling, if you
        # want to inspect or modify the gradients of the params they own.
        scaler.unscale_(optimizer0)

        scaler.step(optimizer0)
        scaler.step(optimizer1)

        scaler.update()

多個(gè) GPU
autocast 狀態(tài)會(huì)在每個(gè)線程上傳播,不管是在單個(gè)進(jìn)程的多線程,還是每個(gè) GPU一個(gè)進(jìn)程。(和原來(lái)的沒(méi)什么區(qū)別)

model = MyModel()
dp_model = nn.DataParallel(model)

# Sets autocast in the main thread
with autocast(device_type='cuda', dtype=torch.float16):
    # dp_model's internal threads will autocast.
    output = dp_model(input)
    # loss_fn also autocast
    loss = loss_fn(output)

多個(gè)GPU一個(gè)進(jìn)程,這里 torch.nn.parallel.DistributedDataParallel可能會(huì)產(chǎn)生一個(gè)側(cè)線程來(lái)在每個(gè)設(shè)備上運(yùn)行前向傳遞,就像 torch.nn.DataParallel 一樣。修復(fù)方法是相同的:將自動(dòng)轉(zhuǎn)換作為模型前向方法的一部分應(yīng)用,以確保它在側(cè)線程中啟用。

MyModel(nn.Module):
    ...
    @autocast()
    def forward(self, input):
       ...

# Alternatively
MyModel(nn.Module):
    ...
    def forward(self, input):
        with autocast():
            ...

三、常見(jiàn)問(wèn)題

  1. 加速有限,可能的原因有:
  • 顯卡不支持
  • GPU飽和
  • FP32 -> FP16 的轉(zhuǎn)換消耗了過(guò)多時(shí)間,應(yīng)該避免多個(gè)小的 CUDA操作
  • 過(guò)多的CPU和GPU的通信
  • matmul 操作的尺寸應(yīng)該是 8 的倍數(shù)
  1. loss 是 inf/NaN
  • 如果網(wǎng)絡(luò)中有較小的數(shù)字,轉(zhuǎn)成 FP16 就會(huì)變成0,導(dǎo)致出現(xiàn)inf/NaN,先去掉 GradScaler 檢查前向過(guò)程中是不是會(huì)有這種問(wèn)題。
  • 如果前向過(guò)程中出現(xiàn) NaN,一般是前向過(guò)程中某些步驟蘊(yùn)含求和求平均的操作導(dǎo)致了上溢,找到這些可能出現(xiàn)上溢的地方,手動(dòng)固定為 FP32 就可以了。
  1. loss scale 掉到了0
    通常也是因?yàn)樯弦纾业缴弦绲膶庸潭ǖ?FP32 就可以了。

  2. 混合精度下 transformer 的位置編碼碰撞問(wèn)題
    目前廣泛采用的位置編碼算法比如 Rope 和 Alibi, 需要為每個(gè)位置生成一個(gè)整型的 position_id,在 float16/bfloat16 下浮點(diǎn)數(shù)精度不足,導(dǎo)致整數(shù)范圍超過(guò) 256 時(shí), bfloat16 無(wú)法準(zhǔn)確表示每個(gè)整數(shù),因此相鄰的若干個(gè) token 會(huì)共享一個(gè)位置編碼。
    解決思路也是保證 position_id 的精度在 FP32上就可以了。
    (在圖像里ViT上下文沒(méi)這么長(zhǎng)的就不用擔(dān)心這個(gè)問(wèn)題)

參考:

  1. Pytorch AMP 教程
  2. https://pytorch.org/docs/stable/notes/amp_examples.html
  3. https://pytorch.org/docs/stable/amp.html#autocast-op-reference
  4. Pytorch中混合精度訓(xùn)練的使用和debug
  5. Llama也中招,混合精度下位置編碼竟有大坑,百川智能給出修復(fù)方案
  6. To Bfloat or not to Bfloat? That is the Question!
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容