本文主要是Pytorch2.0 的小實(shí)驗(yàn),在MacBookPro 上體驗(yàn)一下等優(yōu)化改進(jìn)后的Transformer Self Attention的性能,具體的有 FlashAttention、Memory-Efficient Attention、CausalSelfAttention 等。主要是torch.compile(model) 和 scaled_dot_product_attention的使用。
相關(guān)代碼已上傳GitHub:https://github.com/chensaics/Pytorch2DL

Pytorch2.0版本來了,帶來了很多的新技術(shù)。今天創(chuàng)建了Pytorch2DL倉(cāng)庫(kù),主要是使用Jupyter Notebook 結(jié)合Pytorch2做一些深度學(xué)習(xí)的示例。
Pytorch2.0 技術(shù)亮點(diǎn)

- torch.compile
包裝并返回編譯后的模型
- Accelerated Transformers
我們可以通過調(diào)用新的scaled_dot_product_attention() 函數(shù)直接使用縮放點(diǎn)積注意力 (SPDA)內(nèi)核。以前我們想要加速訓(xùn)練,要使用第三方庫(kù),比如 Flash Attention、xFormers等,現(xiàn)在都被原生支持到框架中了,具體的是在 torch.nn.MultiheadAttention 和 TransformerEncoderLayer 中。
下一節(jié)我們使用上下文管理器顯示調(diào)度不同的內(nèi)核做性能對(duì)比。
- Metal Performance Shaders (MPS后端)
在Mac上也能享受GPU加速的PyTorch訓(xùn)練哦!
在Windows和Linux上使用GPU還是CPU,我們通常加一句:
device = "cuda" if torch.cuda.is_available() else "cpu"
在Mac上:
device = torch.device("mps")
我結(jié)合MPS和scaled_dot_product_attention做一個(gè)示例:

- 其他新技術(shù)
TensorParallel、DTensor、2D parallel、TorchDynamo、AOTAutograd、PrimTorch和TorchInductor
TorchDynamo是借助Python Frame Evaluation Hooks能安全地獲取PyTorch程序;
AOTAutograd重載PyTorch autograd engine,作為一個(gè) tracing autodiff,用于生成超前的backward trace。
PrimTorch簡(jiǎn)化了編寫 PyTorch 功能或后端的流程。將 2000+ PyTorch 算子歸納為約 250 個(gè) primitive operator 閉集 (closed set)。
TorchInductor一個(gè)深度學(xué)習(xí)編譯器,可以為多個(gè)加速器和后端生成 fast code。
性能實(shí)驗(yàn)
目前有三種支持scaled_dot_product_attention的:
- FlashAttention
- Memory-Efficient Attention
- PyTorch C++ 公式實(shí)現(xiàn) (MATH)
他們可以通過這幾個(gè)函數(shù)啟用禁用:
enable_flash_sdp(): 啟用或禁用FlashAttention.
enable_mem_efficient_sdp(): 啟用或禁用 Memory-Efficient Attention.
enable_math_sdp(): 啟用或禁用 PyTorch C++ implementation.
我在Mac上做了一個(gè) scaled_dot_product_attention 結(jié)合 sdp_kernel() 上下文管理器來顯式調(diào)度(指定、啟用/禁用)其中一個(gè)融合內(nèi)核運(yùn)行 的實(shí)驗(yàn):
import torch
import torch.nn as nn
import torch.nn.functional as F
from rich import print
from torch.backends.cuda import sdp_kernel
from enum import IntEnum
import torch.utils.benchmark as benchmark
# Windows和Linux上使用GPU
# device = "cuda" if torch.cuda.is_available() else "cpu"
# Mac 上使用 GPU加速:
# device = torch.device("mps")
device = "mps" if torch.backends.mps.is_built() else "cpu"
# 超參數(shù)定義
batch_size = 64
max_sequence_len = 256
num_heads = 32
embed_dimension = 32
dtype = torch.float16
# 模擬 q k v
query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
# 定義一個(gè)計(jì)時(shí)器:
def torch_timer(f, *args, **kwargs):
t0 = benchmark.Timer(
stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
)
return t0.blocked_autorange().mean * 1e6
# torch.backends.cuda中也實(shí)現(xiàn)了,這里拿出了為了好理解backend_map是啥
class SDPBackend(IntEnum):
r"""
Enum class for the scaled dot product attention backends.
"""
ERROR = -1
MATH = 0
FLASH_ATTENTION = 1
EFFICIENT_ATTENTION = 2
# 使用上下文管理器context manager來
# 其他三種方案,字典映射
backend_map = {
SDPBackend.MATH: {
"enable_math": True,
"enable_flash": False,
"enable_mem_efficient": False},
SDPBackend.FLASH_ATTENTION: {
"enable_math": False,
"enable_flash": True,
"enable_mem_efficient": False},
SDPBackend.EFFICIENT_ATTENTION: {
"enable_math": False,
"enable_flash": False,
"enable_mem_efficient": True}
}
# 基本版,不指定
print(f"基本對(duì)照方案 運(yùn)行時(shí)間: {torch_timer(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
# 基本對(duì)照方案 運(yùn)行時(shí)間: 17542.618 microseconds
with sdp_kernel(**backend_map[SDPBackend.MATH]):
print(f"math 運(yùn)行時(shí)間: {torch_timer(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
# math 運(yùn)行時(shí)間: 18869.076 microseconds
with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
try:
print(f"flash attention 運(yùn)行時(shí)間: {torch_timer(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
except RuntimeError:
print("FlashAttention is not supported")
# flash attention 運(yùn)行時(shí)間: 42313.492 microseconds
with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
try:
print(f"Memory efficient 運(yùn)行時(shí)間: {torch_timer(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
except RuntimeError:
print("EfficientAttention is not supported")
# Memory efficient 運(yùn)行時(shí)間: 42347.333 microseconds
因果自注意力
nanoGPT
中使用了因果自注意力,就是如果我們Pytorch版本>=2.0,torch.nn.functional有 scaled_dot_product_attention 的功能,那么我們就使用它。
接下來,我利用了 scaled_dot_product_attention 和 torch.compile(model) 做一個(gè)性能試驗(yàn)。
這個(gè)是 CausalSelfAttention 模塊的代碼:
class CausalSelfAttention(nn.Module):
def __init__(self, num_heads: int, embed_dimension: int, bias: bool=False, dropout:float=0.0):
super().__init__()
assert embed_dimension % num_heads == 0
# key, query, value projections for all heads, but in a batch
self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias=bias)
# output projection
self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=bias)
# regularization
self.attn_dropout = nn.Dropout(dropout)
self.resid_dropout = nn.Dropout(dropout)
self.num_heads = num_heads
self.embed_dimension = embed_dimension
self.dropout = dropout
# flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
if not self.flash:
print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
# causal mask to ensure that attention is only applied to the left in the input sequence
self.register_buffer("bias", torch.tril(torch.ones(block_size, block_size))
.view(1, 1, block_size, block_size))
def forward(self, x):
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (embed_dimension)
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
q, k ,v = self.c_attn(x).split(self.embed_dimension, dim=2)
k = k.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs)
q = q.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs)
v = v.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs)
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
if self.flash:
# efficient attention using Flash Attention CUDA kernels
y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
else:
# manual implementation of attention
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
att = self.attn_dropout(att)
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
# output projection
y = self.resid_dropout(self.c_proj(y))
return y
其他部分的代碼:
import torch
import torch.nn as nn
import torch.nn.functional as F
from rich import print
import torch.utils.benchmark as benchmark
import math
# Windows和Linux上使用GPU
# device = "cuda" if torch.cuda.is_available() else "cpu"
# Mac 上使用 GPU加速:
# device = torch.device("mps")
device = "mps" if torch.backends.mps.is_built() else "cpu"
# 設(shè)置超參數(shù):
batch_size = 32
max_sequence_len = 128
num_heads = 8
heads_per_dim = 64
embed_dimension = num_heads * heads_per_dim
block_size = 1024
dtype = torch.float16
# 定義計(jì)時(shí)器:
def torch_timer(f, *args, **kwargs):
t0 = benchmark.Timer(
stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
)
return t0.blocked_autorange().mean * 1e6
# 實(shí)例化我們上面的 CausalSelfAttention 類
model = CausalSelfAttention(num_heads=num_heads,
embed_dimension=embed_dimension,
bias=False,
dropout=0.1).to("mps").to(dtype).eval() # mps / cuda
print(model)
# 模擬數(shù)據(jù)
x = torch.rand(batch_size,
max_sequence_len,
embed_dimension,
device=device,
dtype=dtype)
print(f"原始model 運(yùn)行時(shí)間: {torch_timer(model, x):.3f} microseconds")
# 原始model 運(yùn)行時(shí)間: 9169.492 microseconds
# 編譯模型
torch._dynamo.config.suppress_errors = True
torch._dynamo.config.verbose=True
compiled_model = torch.compile(model)
compiled_model(x)
print(f"compiled model 運(yùn)行時(shí)間: {torch_timer(compiled_model, x):.3f} microseconds")
# compiled model 運(yùn)行時(shí)間: 6786.322 microseconds
CausalSelfAttention 結(jié)構(gòu)參數(shù):

從打印的結(jié)果可以看出,torch.compile(model)加速了很多,提高了25%呢!
本次的分享就到這里了,Pytorch 2.x版本的新性能還是讓人很興奮的!能提升大模型訓(xùn)練和推理速度、占用更少算力資源!