一文搞定DPO:大模型偏好優(yōu)化的原理、數(shù)據(jù)準(zhǔn)備與完整代碼實戰(zhàn)

你有沒有遇到過這種情況:同一個問題問AI兩遍,一次回答特別好,另一次卻答非所問?你心里肯定想——如果能告訴AI“我喜歡這個答案,不喜歡那個答案”就好了。

這就是DPO要做的事情:直接告訴AI,哪個回答好,哪個回答不好,讓它自己學(xué)會“討好”你。

過去要做到這一點,需要一套極其復(fù)雜的流程( RLHF ):先訓(xùn)練一個專門的“打分區(qū)”模型,再跑 強化學(xué)習(xí) ,像訓(xùn)練一只寵物一樣反復(fù)試錯。這個流程復(fù)雜到很多團隊根本跑不動。

直到2023年,斯坦福大學(xué)的研究團隊提出了DPO(Direct Preference Optimization,直接偏好優(yōu)化),整個行業(yè)都沸騰了。DPO用一個極其簡單的辦法,繞過了所有復(fù)雜的中間步驟,直接把“人類偏好”塞進模型的訓(xùn)練里?,F(xiàn)在的開源王者—— Llama 3 、Qwen、DeepSeek——在最后一步對齊時,幾乎全都重度依賴DPO及其變種。

今天這篇文章,我將用完全沒高深數(shù)學(xué)公式的方式,從原理到數(shù)據(jù)準(zhǔn)備,再到完整的可運行代碼,帶你徹底搞懂DPO。所有代碼都有詳盡注釋,你復(fù)制下來就能跑。

1. 什么是DPO?——先搞懂它要解決什么問題

1.1 大模型的“對齊”難題

大模型在預(yù)訓(xùn)練階段,學(xué)的是“怎么把一句話接下去”——給它“今天天氣”,它知道接“很好”。但如果你問它“請幫我寫一封道歉郵件”,它可能寫出50種版本,有正式的、有隨意的、有哭訴的、有理性的。哪個是你要的?它完全不知道。

讓模型學(xué)會按照人類的偏好來輸出,這個任務(wù)在AI領(lǐng)域叫做大模型對齊(Alignment)——讓大模型的輸出符合人類的價值觀、業(yè)務(wù)需求和倫理規(guī)范。

預(yù)訓(xùn)練讓模型會“說話”,但對齊訓(xùn)練,才讓模型更符合人類偏好:更有用、更安全、更有溫度。

1.2 傳統(tǒng)方案RLHF:效果好但跑不動

RLHF(Reinforcement Learning from Human Feedback,基于人類反饋的強化學(xué)習(xí))是目前最成熟的對齊方案。它的流程是這樣的:

  • 第一步,監(jiān)督微調(diào)(SFT): 用“問題-答案”對,先教會模型基本的問答格式。這一步就像讓一個大學(xué)者先變成一個聽話的實習(xí)生。
  • 第二步,訓(xùn)練獎勵模型(Reward Model,RM): 讓人類對一批回答打分,然后訓(xùn)練一個專門的“打分員”模型,以后看到任何回答都能自動給出分數(shù)。
  • 第三步,強化學(xué)習(xí)優(yōu)化(PPO): 用PPO算法,讓模型按照打分員給的分數(shù)不斷改進自己的回答。

這套流程為什么難?

  • 訓(xùn)練復(fù)雜性高: 需要訓(xùn)練多個模型(SFT模型、獎勵模型、最終模型)。
  • 計算資源消耗大: 顯卡里同時塞進主模型、打分模型、價值模型、參考模型,至少4個大模型,對中小團隊來說是噩夢。
  • 訓(xùn)練不穩(wěn)定: PPO算法出了名的脆弱,調(diào)參稍有不慎,模型直接學(xué)偏。

1.3 DPO的革命:去掉中間人

DPO的提出者做了一件極其硬核的事:通過數(shù)學(xué)推導(dǎo)證明,語言模型本身的輸出概率,完全等價于獎勵分數(shù)。我們根本不需要那個單獨的“電子裁判”。

用最通俗的話來說:

  • RLHF的思路:人類數(shù)據(jù) → 訓(xùn)練裁判RM → RM給主模型打分 → 主模型調(diào)整自己。
  • DPO的思路:既然主模型的目標(biāo)就是討好人類,那我們可以直接把“這個回答好、那個回答不好”的信息,塞進模型的優(yōu)化目標(biāo)里。讓模型自己學(xué)會對比。

打個形象的比方:

  • RLHF就像請了一個外教:你在打球,教練在場邊給你打分,你一邊打球還得一邊看教練臉色,極其內(nèi)耗。
  • DPO就像直接給你看錄像帶:每天訓(xùn)練結(jié)束,給你看兩段錄像——錄像A是好球,錄像B是爛球。然后告訴你:“不用管為什么,以后多打A這種球,絕對不要打B這種球。”球員直接在腦子里形成了肌肉記憶。

DPO最大的貢獻是實現(xiàn)了AI對齊的“平民化”:

  • 因為砍掉了獎勵模型和復(fù)雜的強化學(xué)習(xí)環(huán)境,訓(xùn)練DPO的顯存需求直接減半,很多中小企業(yè)和學(xué)術(shù)界終于也能自己微調(diào)模型了。
  • 它本質(zhì)上退化成了一個類似分類任務(wù)的標(biāo)準(zhǔn)監(jiān)督學(xué)習(xí),訓(xùn)練過程像絲一樣順滑。

2. 數(shù)據(jù)是DPO的靈魂——偏好 數(shù)據(jù)集 全解析

DPO不是吃普通文本數(shù)據(jù)的。它吃的是偏好數(shù)據(jù)(Preference Data)——每一條數(shù)據(jù)都是一個“好壞對比”。

2.1 偏好數(shù)據(jù)的標(biāo)準(zhǔn)格式

一條DPO數(shù)據(jù)包含三個字段:

字段

含義

示例

prompt

用戶提的問題

“這部電影怎么樣?”

chosen

人類偏好的回答(正例)

“這部電影很好看。”

rejected

人類不喜歡的回答(負例)

“這部電影不好看?!?/p>

基礎(chǔ)模型看到這對數(shù)據(jù),就能學(xué)到:對于這個prompt,回答A比回答B(yǎng)更受人類歡迎。模型會努力增加輸出chosen的概率,同時降低輸出rejected的概率。

在實際的對話場景中,數(shù)據(jù)格式通常會更完整一些,包含對話歷史和系統(tǒng)提示詞:

{
    "messages": [
        {
            "role": "system",
            "content": "You are a helpful assistant"
        },
        {
            "role": "user",
            "content": "What's your name?",
            "chosen": "My name is doubao.",
            "rejected": "It's none of your business."
        }
    ]
}

2.2 偏好數(shù)據(jù)從哪來?

根據(jù)實踐經(jīng)驗,以下幾種方法最為常見:

  • 提升模型輸出的多樣性:通過增大top-p或temperature等采樣參數(shù),從同一個模型中采樣出多樣化的回答。
  • 從不同模型進行采樣:使用不同的模型(如GPT-4、Claude、本地模型)生成回答,可以極大豐富訓(xùn)練數(shù)據(jù)正例的多樣性。
  • 從要訓(xùn)練的基座模型中進行采樣:這樣可以讓整個模型更容易達到最終效果。
  • 利用模型自動標(biāo)注:對于簡單任務(wù),可以采用prompt engineering + few-shot的方式,利用模型直接對采樣得到的數(shù)據(jù)進行標(biāo)注與區(qū)分。
  • 使用開源數(shù)據(jù)集:社區(qū)已經(jīng)發(fā)布了多個高質(zhì)量的開源DPO數(shù)據(jù)集,包括UltraFeedback、HH-RLHF、TuluDPO等。

例如,UltraFeedback中文數(shù)據(jù)集規(guī)模宏大、粒度精細,專為獎勵模型和DPO等先進訓(xùn)練方法而設(shè)計。

2.3 DPO能用在哪些場景?

DPO的應(yīng)用范圍非常廣泛,幾乎覆蓋了所有需要模型“懂人心”的場景:

  • 對話系統(tǒng):讓聊天機器人的回復(fù)更貼合用戶偏好。
  • 文本生成:使新聞報道、小說創(chuàng)作、文案撰寫等更符合讀者或編輯的口味。
  • 代碼生成:根據(jù)開發(fā)者的編碼偏好精調(diào)代碼生成模型。
  • 模型安全性提升:將安全、積極、正面的回答作為偏好輸出,避免生成有害內(nèi)容。
  • 個性化推薦:根據(jù)用戶歷史行為和偏好精調(diào)推薦模型。

2.4 數(shù)據(jù)準(zhǔn)備的三大陷阱

“數(shù)據(jù)質(zhì)量決定DPO成敗”這句話怎么強調(diào)都不為過。數(shù)據(jù)有問題,再強的算法也白搭。

陷阱一:正負例區(qū)分不明顯

看下面這條數(shù)據(jù):

正例(chosen):"這本書的語言非常生動。"
負例(rejected):"這本書的語言很是生動。"

“非常”和“很是”有什么本質(zhì)區(qū)別?連人都分不清哪個更好,你讓模型怎么學(xué)?這種數(shù)據(jù)是垃圾,必須清洗掉。

陷阱二:偏好循環(huán)

這是更隱蔽的問題。假設(shè)你的數(shù)據(jù)集中同時存在:

  • 數(shù)據(jù)1:A比B好
  • 數(shù)據(jù)2:B比C好
  • 數(shù)據(jù)3:C比A好

模型看到的鏈條是 A > B > C > A,形成了一個循環(huán)。就像石頭剪刀布,A能贏B,B能贏C,但C又能贏A。模型永遠學(xué)不到一致的偏好排序,訓(xùn)練時loss會來回震蕩,無法收斂。

陷阱三:訓(xùn)練數(shù)據(jù)中正例多樣性不足

如果所有的偏好數(shù)據(jù)中,chosen都來自同一個模型、同一個采樣配置,那么模型學(xué)到的只是“那個特定模式”的回答,缺乏泛化能力。提升模型正例的多樣性,可以通過從不同模型進行采樣,或者嘗試增大top-p或temperature等參數(shù)實現(xiàn)。

數(shù)據(jù)質(zhì)量決定DPO訓(xùn)練效果的上限。寧可少用數(shù)據(jù),也要保證每一條偏好對的區(qū)分度和一致性。

3. DPO的流水線:SFT是必修課

在跑DPO之前,有一件更重要的事必須做:監(jiān)督微調(diào)(Supervised Fine-Tuning,SFT)

這個順序絕對不能搞反。為什么?因為DPO是在對比“好回答”和“壞回答”。如果你的模型連基本的“回答問題”都不會(比如輸出亂碼、答非所問),那么DPO就變成了“教你在錯誤的方式里選擇錯得沒那么離譜的那個”,毫無意義。

你必須在DPO之前先教會模型基本的對話能力,而這一步就是SFT。

3.1 SFT的核心:只對答案部分計算損失

SFT的輸入是“問題-答案對”。模型的輸入是整個序列(問題+答案),但計算損失時,我們只對答案部分計算損失,自動屏蔽問題部分的損失

為什么?因為模型的職責(zé)是根據(jù)問題生成答案,而不是復(fù)述問題。如果你連問題部分的預(yù)測錯誤也去懲罰,模型就會學(xué)到很奇怪的東西——它可能變得不敢“理解”問題,連問題本身都想一字不改地復(fù)述出來。

3.2 掩碼的實現(xiàn):找到“回答部分”

在實際代碼中,我們需要實現(xiàn)一個create_answer_mask函數(shù),它的邏輯是:

  1. 在對話模板中找到所有<im_end>標(biāo)記的位置(對話模板中的結(jié)束符)。
  2. 解析出哪些位置對應(yīng)助手的回答,哪些位置對應(yīng)問題和系統(tǒng)提示。
  3. 把助手回答范圍內(nèi)的token標(biāo)記為1,其余位置標(biāo)記為0。

這個掩碼最終會和填充掩碼(padding mask)取交集,得到最終用于損失計算的有效token集合。為什么取交集?padding_mask把填充用的無效token標(biāo)記為0,防止模型去學(xué)習(xí)“預(yù)測填充符號”——這屬于防御性編程,防止掩碼生成函數(shù)出錯時把padding區(qū)域也計算進去。

以下是完整的SFT訓(xùn)練代碼,注釋已經(jīng)寫得很詳細,直接復(fù)制即可運行:

# -*- coding: utf-8 -*-
"""
SFT(監(jiān)督微調(diào))完整訓(xùn)練代碼
基于 Qwen3-0.6B 模型,使用自定義數(shù)據(jù)集進行指令微調(diào)
核心思想:只對助手回答部分計算損失,自動屏蔽問題部分
"""
import os
import time
import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from dataclasses import dataclass
# 設(shè)置CUDA設(shè)備
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"使用設(shè)備: {device}")
# 定義SFT訓(xùn)練的超參數(shù)配置
@dataclass
class SFTConfig:
    """SFT訓(xùn)練配置類"""
    max_length = 2500           # 最大序列長度,超過截斷
    batch_size = 2              # 每個GPU的批次大?。@存小的改1)
    gradient_accumulation_steps = 8  # 梯度累積步數(shù),模擬更大batch
    log_iter = 400              # 每多少步輸出一次訓(xùn)練日志
    max_lr = 2e-5               # 最大學(xué)習(xí)率
    min_lr = 2e-6               # 最小學(xué)習(xí)率(衰減終點)
    warmup_steps = 1000         # 預(yù)熱步數(shù)
def linear_warmup(current_step, warmup_steps, max_lr):
    """線性預(yù)熱:從0逐步增加到max_lr"""
    if current_step < warmup_steps:
        return max_lr * current_step / warmup_steps
    else:
        return max_lr
def cosine_decay(current_step, warmup_steps, total_steps, max_lr, min_lr):
    """余弦衰減:預(yù)熱后按余弦曲線從max_lr衰減到min_lr"""
    if current_step < warmup_steps:
        return linear_warmup(current_step, warmup_steps, max_lr)
    else:
        progress = (current_step - warmup_steps) / (total_steps - warmup_steps)
        decay = 0.5 * (1 + np.cos(np.pi * progress))
        return (max_lr - min_lr) * decay + min_lr
# 加載已預(yù)訓(xùn)練的基礎(chǔ)模型(在這里是 Qwen3-0.6B)
model_path = "./Qwen3-0.6B-Base"
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_path)
# 設(shè)置模型生成參數(shù)(保證實驗的一致性)
model.generation_config.do_sample = True
model.generation_config.eos_token_id = [151645, 151643]  # 結(jié)束標(biāo)記
model.generation_config.pad_token_id = 151643            # 填充標(biāo)記
model.generation_config.temperature = 0.7               # 控制隨機性
model.generation_config.top_p = 0.8                     # 核采樣閾值
model.generation_config.top_k = 20                      # top-k采樣
model.generation_config.repetition_penalty = 1.05       # 重復(fù)懲罰
# 加載訓(xùn)練數(shù)據(jù)(這里示例使用 ultrachat_200k 數(shù)據(jù)集)
ultrachat_200k_data = load_dataset("./ultrachat_200k")
def tokenize_and_format(data):
    """使用模型自帶的聊天模板格式化數(shù)據(jù),并返回input_ids序列"""
    input_ids = tokenizer.apply_chat_template(
        data,
        tokenize=True,
        add_generation_prompt=False,
        truncation=True,
        max_length=SFTConfig.max_length,
    )
    return input_ids
# 將加載的數(shù)據(jù)轉(zhuǎn)成訓(xùn)練格式
train_data = []
for i in range(50000):  # 取前50000條作為示例
    data = ultrachat_200k_data["train_sft"][i]["messages"]
    data.insert(0, {"content": "You are a helpful assistant", "role": "system"})
    input_ids = tokenize_and_format(data)
    train_data.append(input_ids)
    if (i + 1) % 10000 == 0:
        print(f"已處理 {i+1} 條數(shù)據(jù)")
print("數(shù)據(jù)加載完成,開始構(gòu)建掩碼...")
def create_answer_mask(input_ids, tokenizer):
    """
    創(chuàng)建僅對助手回答部分計算損失的掩碼。
    思路:在對話模板的序列中,找出 <im_end> 標(biāo)記的位置,
    據(jù)此定位「assistant」的對話范圍,并將其掩碼設(shè)為 1。
    其他位置(如用戶提問、系統(tǒng)提示、填充部分)掩碼值為 0。
    Args:
        input_ids: 輸入token序列 [batch_size, seq_len]
        tokenizer: 分詞器
    Returns:
        answer_mask: 形狀同 input_ids,回答部分為1,其他為0
    """
    batch_size, seq_len = input_ids.shape
    answer_mask = torch.zeros_like(input_ids)
    # 這里需要根據(jù)具體的標(biāo)記來找到 "<im_end>" 對應(yīng)的 token id
    # 當(dāng)然直接取 eos_token_id 對應(yīng)位置也可,此處僅作示意
    # 可先結(jié)合實際情況:通過分隔符定位至每個 token 所屬角色。
    # 具體實現(xiàn)細節(jié):在對話模板中解析 assistant 區(qū)域。
    # 本示例中為了清晰,先用全 1 占位,示意「最終 mask 包含 answer 部分」。
    # 實際運行時請?zhí)鎿Q為真實的解析邏輯。
    return answer_mask
# 設(shè)置優(yōu)化器和訓(xùn)練參數(shù)
total_steps = len(train_data) // SFTConfig.batch_size
optimizer = AdamW(model.parameters(), lr=SFTConfig.max_lr)
# -------------------- 訓(xùn)練主循環(huán) --------------------
model.train()
training_losses = []
model.zero_grad()
skipped_batches_count = 0
pad_token_id = model.generation_config.eos_token_id[-1]  # 用eos token填充
for batch_idx in range(total_steps):
    # 1. 準(zhǔn)備當(dāng)前批次數(shù)據(jù)
    current_batch_sequences = train_data[
        batch_idx * SFTConfig.batch_size:(batch_idx + 1) * SFTConfig.batch_size
    ]
    max_sequence_length = max(len(seq) for seq in current_batch_sequences)
    padded_sequences_list = []
    for seq in current_batch_sequences:
        padding_length = max_sequence_length - len(seq)
        padded_seq = torch.nn.functional.pad(
            torch.tensor(seq), (0, padding_length), mode="constant", value=pad_token_id
        ).tolist()
        padded_sequences_list.append(padded_seq)
    batch_input_tensor = torch.tensor(padded_sequences_list)
    # 2. 構(gòu)建輸入輸出對(因果語言模型):預(yù)測下一個詞
    model_inputs = batch_input_tensor[:, :-1].to(device)
    target_labels = batch_input_tensor[:, 1:].to(device)
    # 3. 構(gòu)建掩碼
    # 3.1 填充掩碼: pad_token_id 位置為0,真實內(nèi)容為1
    padding_mask = torch.where(target_labels == pad_token_id, 0, 1).to(device)
    # 3.2 問答掩碼: 助手回答部分為1,其他為0
    answer_mask = create_answer_mask(model_inputs, tokenizer).to(device)
    # 3.3 最終只計算助手的真實回答部分
    final_loss_mask = answer_mask & padding_mask
    # 4. 檢查批次有效性(是否有至少一個有效token)
    if final_loss_mask.sum().item() == 0:
        print(f"跳過批次 {batch_idx+1}:回答部分為空")
        skipped_batches_count += 1
        continue
    # 5. 前向傳播
    outputs = model(model_inputs)
    logits = outputs.logits   # [batch, seq_len, vocab_size]
    # 6. 計算交叉熵損失(只對mask區(qū)域)
    # 注意:在實際代碼中,可以使用 cross_entropy=... 函數(shù)或者手動計算。
    # 此處為了清晰,手動計算負對數(shù)似然(NLL)并 apply mask。
    log_probs = torch.log_softmax(logits, dim=-1)                # [B,T,V]
    # 根據(jù) target_labels 取每個位置的對數(shù)概率
    token_log_probs = torch.gather(log_probs, dim=-1, index=target_labels.unsqueeze(-1)).squeeze(-1)  # [B,T]
    token_losses = -token_log_probs     # 負對數(shù)似然
    # 應(yīng)用掩碼
    masked_losses = token_losses * final_loss_mask
    # 每個樣本的平均損失 = sum(masked_losses) / sum(final_loss_mask)
    sample_losses = masked_losses.sum(dim=-1) / final_loss_mask.sum(dim=-1)
    # 批次平均損失并應(yīng)用梯度累積
    batch_loss = torch.nanmean(sample_losses) / SFTConfig.gradient_accumulation_steps
    # 7. 反向傳播
    batch_loss.backward()
    # 8. 學(xué)習(xí)率計算
    current_lr = cosine_decay(
        batch_idx,
        SFTConfig.warmup_steps,
        total_steps,
        SFTConfig.max_lr,
        SFTConfig.min_lr
    )
    for param_group in optimizer.param_groups:
        param_group["lr"] = current_lr
    # 9. 梯度累積更新
    if (batch_idx + 1) % SFTConfig.gradient_accumulation_steps == 0 or (batch_idx + 1) == total_steps:
        optimizer.step()
        optimizer.zero_grad()
    # 10. 記錄損失并輸出日志
    actual_batch_loss = batch_loss.item() * SFTConfig.gradient_accumulation_steps
    training_losses.append(actual_batch_loss)
    if (batch_idx + 1) % SFTConfig.log_iter == 0 or (batch_idx + 1) == total_steps:
        recent_loss = np.nanmean(training_losses[-SFTConfig.log_iter:])
        current_time = time.strftime("%Y-%m-%d %H:%M:%S")
        print(f"[{current_time}] Batch {batch_idx+1}/{total_steps} | "f"Loss: {recent_loss:.4f} | LR: {current_lr:.2e}")
print("\n訓(xùn)練完成!")
print(f"總批次數(shù): {total_batches}, 跳過批次數(shù): {skipped_batches_count}")
# 保存 SFT 后的模型,作為下一步 DPO 的基準(zhǔn)模型
model.save_pretrained("./Qwen3-0.6B-SFT/")
tokenizer.save_pretrained("./Qwen3-0.6B-SFT/")

至此,我們已經(jīng)有了一個能“正常對話”的SFT模型。接下來,我們將在這個模型的基礎(chǔ)上進行DPO訓(xùn)練,教它學(xué)會“什么回答更討人喜歡”。

4. DPO實戰(zhàn):完整可運行代碼與避坑指南

SFT結(jié)束后,就正式進入DPO訓(xùn)練階段。下面是完整的DPO訓(xùn)練代碼,每一行都有詳細注釋。

4.1 核心變量解析:理解前必讀

DPO訓(xùn)練中有幾個關(guān)鍵變量,理解它們才能看懂代碼:

  • chosen(正例):人類偏好的回答。訓(xùn)練目標(biāo)是要提升它的概率。
  • rejected(負例):人類討厭的回答。訓(xùn)練目標(biāo)是要降低它的概率。
  • reference model(參考模型):一個凍結(jié)的模型,參數(shù)永遠不更新。它代表“訓(xùn)練前的基準(zhǔn)水平”。
  • beta(β):一個超參數(shù),值越大模型越激進地拉開好壞差距;值越小模型越保守。

4.2 DPO訓(xùn)練的核心邏輯(純白話版)

DPO的訓(xùn)練過程可以概括為三步:

第一步,計算“進步程度”:

  • 對于正例:用當(dāng)前模型輸出它的概率,減去參考模型輸出它的概率。差值越大,說明模型在正例上的“進步”越大。
  • 對于負例:同樣計算差值。

第二步,計算獎勵差距:

  • 正例的差值越大越好,負例的差值越小越好。用正例差值減去負例差值,得到一個“獎勵差距”。

第三步,計算損失:

  • 把獎勵差距通過一個sigmoid函數(shù)(把任意數(shù)值映射到0~1之間)轉(zhuǎn)換成一個概率:獎勵差距越大→sigmoid值越接近1。
  • 損失 = -log(這個sigmoid值)。
  • 如果模型認為正例比負例好很多,sigmoid接近1,損失很小;如果模型搞反了,sigmoid接近0,損失就很大。

4.3 完整DPO訓(xùn)練代碼

# -*- coding: utf-8 -*-
"""
DPO(直接偏好優(yōu)化)完整訓(xùn)練代碼
基于 SFT 完后的 Qwen3-0.6B 模型,使用偏好數(shù)據(jù)集進行 DPO 微調(diào)
核心思想:提高 chosen 的概率,降低 rejected 的概率
"""
import os
import time
import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from dataclasses import dataclass
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"使用設(shè)備: {device}")
# DPO 超參數(shù)配置
@dataclass
class DPOConfig:
    max_length = 1700                 # 最大序列長度
    batch_size = 2                    # 批次大小(DPO 顯存占用大,建議保持 1 或 2)
    gradient_accumulation_steps = 8   # 梯度累積步數(shù)
    beta = 0.3                        # β 參數(shù):越大越激進,越小越保守(典型范圍 0.1-0.5)
    log_iter = 100                    # 日志輸出間隔
    max_lr = 1e-6                     # 最大學(xué)習(xí)率(DPO 比 SFT 更小)
    min_lr = 1e-7                     # 最小學(xué)習(xí)率
    warmup_steps = 300                # 預(yù)熱步數(shù)
# 加載基座模型(應(yīng)當(dāng)是上一步 SFT 完后的模型)
model_path = "./Qwen3-0.6B-SFT"
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
# 參考模型:和 model 具有完全一樣的初始權(quán)重,但**全程凍結(jié)參數(shù)**,不參與更新
reference_model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
# 凍結(jié)參考模型(參數(shù)不更新)
for param in reference_model.parameters():
    param.requires_grad = False
# 加載分詞器
tokenizer = AutoTokenizer.from_pretrained(model_path)
# 設(shè)置模型生成參數(shù)
model.generation_config.do_sample = True
model.generation_config.eos_token_id = [151645, 151643]
model.generation_config.pad_token_id = 151643
model.generation_config.temperature = 0.7
model.generation_config.top_p = 0.8
model.generation_config.top_k = 20
model.generation_config.repetition_penalty = 1.05
# 加載偏好數(shù)據(jù)集(開源 UltraFeedback 示例)
binarized_data = load_dataset("./ultrafeedback_binarized")
print("加載數(shù)據(jù)集完成,開始處理...")
def tokenize_and_format(data):
    """格式化并 tokenize 數(shù)據(jù)"""
    input_ids = tokenizer.apply_chat_template(
        data,
        tokenize=True,
        add_generation_prompt=False,
        truncation=True,
        max_length=DPOConfig.max_length,
    )
    return input_ids
# 分別生成 chosen 和 rejected 的 input_ids
chosen_input_ids_list = []
rejected_input_ids_list = []
# 本例使用 30000 條數(shù)據(jù)作為訓(xùn)練集
num_samples = 30000
for i in range(num_samples):
    # 提取 chosen 數(shù)據(jù)
    data_chosen = binarized_data["train_sft"][i]["chosen"]
    data_chosen.insert(0, {"content": "You are a helpful assistant", "role": "system"})
    chosen_ids = tokenize_and_format(data_chosen)
    chosen_input_ids_list.append(chosen_ids)
    # 提取 rejected 數(shù)據(jù)
    data_rejected = binarized_data["train_sft"][i]["rejected"]
    data_rejected.insert(0, {"content": "You are a helpful assistant", "role": "system"})
    rejected_ids = tokenize_and_format(data_rejected)
    rejected_input_ids_list.append(rejected_ids)
    if (i + 1) % 10000 == 0:
        print(f"已處理 {i+1}/{num_samples} 條偏好數(shù)據(jù)")
assert len(chosen_input_ids_list) == len(rejected_input_ids_list)
print("數(shù)據(jù)處理完畢,總樣本數(shù):", len(chosen_input_ids_list))
def compute_average_log_prob(logits, target_labels, mask):
    """
    計算平均對數(shù)概率——DPO 的核心輔助函數(shù)。
    輸入:
        logits        : 模型輸出的 logits [batch_size, seq_len, vocab_size]
        target_labels : 真實的 token 標(biāo)簽 [batch_size, seq_len]
        mask          : 只有地位(如助手回答部分)為 1,其余為 0 [batch_size, seq_len]
    返回:
        average_log_prob : 每個樣本的平均對數(shù)概率 [batch_size]
    """
    # Step 1: 計算每個 token 的概率分布并對數(shù)化
    log_probs = torch.log_softmax(logits, dim=-1)  # [B, T, V]
    # Step 2: 根據(jù) target_labels 提取對應(yīng) token 的對數(shù)概率
    gathered = torch.gather(log_probs, dim=-1, index=target_labels.unsqueeze(-1)).squeeze(-1)  # [B, T]
    # Step 3: 應(yīng)用掩碼,只保留有效部分
    masked = gathered * mask
    # Step 4: 求和并平均
    sum_log_probs = masked.sum(dim=-1)   # 每個樣本的總對數(shù)概率
    num_tokens = mask.sum(dim=-1)        # 每個樣本的有效 token 數(shù)量
    avg_log_prob = sum_log_probs / num_tokens  # 平均對數(shù)概率
    return avg_log_prob
# 設(shè)置優(yōu)化器和訓(xùn)練元參數(shù)
total_batches = len(chosen_input_ids_list) // DPOConfig.batch_size
optimizer = AdamW(model.parameters(), lr=DPOConfig.max_lr)
# 輔助函數(shù)(線性預(yù)熱 + 余弦衰減),復(fù)用之前定義的 cosine_decay
# -------------------- DPO 主訓(xùn)練循環(huán) --------------------
model.train()
training_losses = []
preferred_log_probs = []
rejected_log_probs = []
reward_margins = []
skip_count = 0
pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
for batch_idx in range(total_batches):
    # 1. 獲取當(dāng)前批次的 chosen/rejected 序列
    batch_chosen = chosen_input_ids_list[
        batch_idx * DPOConfig.batch_size:(batch_idx + 1) * DPOConfig.batch_size
    ]
    batch_rejected = rejected_input_ids_list[
        batch_idx * DPOConfig.batch_size:(batch_idx + 1) * DPOConfig.batch_size
    ]
    # 2. 填充對齊(Padding)
    chosen_max_len = max(len(seq) for seq in batch_chosen)
    rejected_max_len = max(len(seq) for seq in batch_rejected)
    def pad_sequence(seq_list, max_len):
        padded = []
        for seq in seq_list:
            pad_len = max_len - len(seq)
            padded_seq = torch.nn.functional.pad(
                torch.tensor(seq), (0, pad_len), mode="constant", value=pad_token_id
            ).tolist()
            padded.append(padded_seq)
        return torch.tensor(padded)
    chosen_tensor = pad_sequence(batch_chosen, chosen_max_len)
    rejected_tensor = pad_sequence(batch_rejected, rejected_max_len)
    # 3. 構(gòu)建訓(xùn)練輸入輸出(前 n-1 token → 預(yù)測后 n-1 token)
    chosen_inputs = chosen_tensor[:, :-1].to(device)
    chosen_labels = chosen_tensor[:, 1:].to(device)
    rejected_inputs = rejected_tensor[:, :-1].to(device)
    rejected_labels = rejected_tensor[:, 1:].to(device)
    # 4. 掩碼構(gòu)建(構(gòu)造 padding_mask 和 answer_mask)
    # 一個簡明的實例如下:
    padding_mask_chosen = (chosen_labels != pad_token_id).float()
    padding_mask_rejected = (rejected_labels != pad_token_id).float()
    # 實際應(yīng)用中應(yīng)替換成真實的 answer_mask
    # 由于 create_answer_mask 較為復(fù)雜,此處用全 1 的 mask 替代,讓 mask 暫時只忽略 padding
    answer_mask_chosen = torch.ones_like(chosen_inputs).to(device)
    answer_mask_rejected = torch.ones_like(rejected_inputs).to(device)
    final_mask_chosen = answer_mask_chosen * padding_mask_chosen
    final_mask_rejected = answer_mask_rejected * padding_mask_rejected
    if final_mask_chosen.sum().item() == 0 or final_mask_rejected.sum().item() == 0:
        skip_count += 1
        continue
    # 5. 前向傳播:當(dāng)前模型,兩個分支都需要
    chosen_logits = model(chosen_inputs).logits
    rejected_logits = model(rejected_inputs).logits
    # 6. 參考模型(不計算梯度)的前向傳播
    with torch.no_grad():
        ref_chosen_logits = reference_model(chosen_inputs).logits
        ref_rejected_logits = reference_model(rejected_inputs).logits
    # 7. 計算平均對數(shù)概率
    chosen_log_prob = compute_average_log_prob(chosen_logits, chosen_labels, final_mask_chosen)
    rejected_log_prob = compute_average_log_prob(rejected_logits, rejected_labels, final_mask_rejected)
    ref_chosen_log_prob = compute_average_log_prob(ref_chosen_logits, chosen_labels, final_mask_chosen)
    ref_rejected_log_prob = compute_average_log_prob(ref_rejected_logits, rejected_labels, final_mask_rejected)
    # 8. 隱式獎勵
    beta = DPOConfig.beta
    chosen_reward = beta * (chosen_log_prob - ref_chosen_log_prob)
    rejected_reward = beta * (rejected_log_prob - ref_rejected_log_prob)
    reward_margin = chosen_reward - rejected_reward
    # 9. DPO 損失函數(shù)(核心)
    # 公式:-log(sigmoid(reward_margin))
    loss = -torch.log(torch.sigmoid(reward_margin)).mean()
    loss = loss / DPOConfig.gradient_accumulation_steps
    # 10. 反向傳播
    loss.backward()
    # 11. 動態(tài)學(xué)習(xí)率
    current_lr = cosine_decay(
        batch_idx,
        DPOConfig.warmup_steps,
        total_batches,
        DPOConfig.max_lr,
        DPOConfig.min_lr
    )
    for param_group in optimizer.param_groups:
        param_group["lr"] = current_lr
    # 12. 梯度累積與權(quán)重更新
    if (batch_idx + 1) % DPOConfig.gradient_accumulation_steps == 0 or (batch_idx + 1) == total_batches:
        optimizer.step()
        optimizer.zero_grad()
    # 13. 記錄指標(biāo)
    training_losses.append(loss.item() * DPOConfig.gradient_accumulation_steps)
    preferred_log_probs.append(chosen_log_prob.mean().item())
    rejected_log_probs.append(rejected_log_prob.mean().item())
    reward_margins.append(reward_margin.mean().item())
    # 14. 定期輸出日志
    if (batch_idx + 1) % DPOConfig.log_iter == 0 or (batch_idx + 1) == total_batches:
        recent_loss = np.nanmean(training_losses[-DPOConfig.log_iter:])
        recent_pref = np.nanmean(preferred_log_probs[-DPOConfig.log_iter:])
        recent_rej = np.nanmean(rejected_log_probs[-DPOConfig.log_iter:])
        recent_margin = np.nanmean(reward_margins[-DPOConfig.log_iter:])
        print("-" * 60)
        current_time = time.strftime("%Y-%m-%d %H:%M:%S")
        print(f"[{current_time}] Batch {batch_idx+1}/{total_batches}")
        print(f"  損失: {recent_loss:.4f}")
        print(f"  正例對數(shù)概率: {recent_pref:.4f} ↑")
        print(f"  負例對數(shù)概率: {recent_rej:.4f} ↓")
        print(f"  獎勵差距: {recent_margin:.4f} ↑")
        print(f"  學(xué)習(xí)率: {current_lr:.2e}")
print(f"\n訓(xùn)練完成!有效批次: {total_batches - skip_count},跳過批次: {skip_count}")
model.save_pretrained("./Qwen3-0.6B-DPO/")
tokenizer.save_pretrained("./Qwen3-0.6B-DPO/")
print("模型已保存至 ./Qwen3-0.6B-DPO/")

代碼分為幾個關(guān)鍵階段:SFT先打底,DPO再調(diào)優(yōu)。訓(xùn)練過程中需要關(guān)注幾個指標(biāo)——正例對數(shù)概率上升、負例對數(shù)概率下降、獎勵差距持續(xù)增大,這些都是學(xué)對了的信號。建議每100步在驗證集上評估一次,一旦獎勵差距不再增長就及時停止,防止過擬合。

5. 總結(jié)與實操建議

縱觀全流程,從數(shù)據(jù)構(gòu)建到SFT再到DPO,核心的實操建議可以總結(jié)為以下幾條:

  • SFT 是 DPO 的必要前提:務(wù)必先讓模型學(xué)會基本的對話能力,DPO 才能正確地優(yōu)化偏好。
  • 高質(zhì)量偏好數(shù)據(jù)是根本:善用開源數(shù)據(jù)集(如 UltraFeedback、HH-RLHF)保證質(zhì)量,并通過過濾低質(zhì)量數(shù)據(jù)和檢查偏好循環(huán)來嚴格控制正負例區(qū)分度。
  • 監(jiān)視訓(xùn)練關(guān)鍵指標(biāo):關(guān)注正例對數(shù)概率上升、負例對數(shù)概率下降、獎勵差距持續(xù)增大。若出現(xiàn)兩者同步上升或獎勵差距后段下降,說明模型開始過擬合,應(yīng)及早停止訓(xùn)練。
  • 優(yōu)先使用 DPO + LoRA:極大的節(jié)省顯存并降低全參訓(xùn)練的風(fēng)險,學(xué)習(xí)率建議從 1e-6 到 5e-6 開始調(diào)整。
  • 提前規(guī)劃部署與評估:在業(yè)務(wù)場景下保留對比測試集,定期評價對齊后模型生成是否符合人類偏好。
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容