從零開始吃透 Seq2Seq:手把手用 PyTorch 搭建中英翻譯模型

從原理到代碼逐行解析,看完你就能自己寫一個能用的翻譯機器人


1. 序言:Seq2Seq 到底是干嘛的?

Seq2Seq(Sequence-to-Sequence,序列到序列)模型,顧名思義,它的核心任務就是把一個序列映射成另一個序列。這類模型廣泛應用于機器翻譯、文本摘要、對話生成等任務。

想象一個場景:你打開谷歌翻譯,輸入“我喜歡你”,點擊翻譯,輸出“I like you”。這背后就是 Seq2Seq 模型在工作。

那么,Seq2Seq 模型和傳統(tǒng)的分類模型有什么區(qū)別?

  • 文本分類:輸入一段話,輸出一個固定的標簽(比如“正面”或“負面”)。輸出只有幾種可能。
  • 序列標注:輸入一句話,給每個詞打一個標簽(比如“人名”、“地名”)。輸出長度等于輸入長度。
  • Seq2Seq:輸入一個序列,輸出另一個序列,輸入和輸出的長度可以不同。這才符合現(xiàn)實世界的需求——中文“我喜歡你”是3個字,英文“I like you”是3個詞,但有些句子的長度比例并不是1:1。

1.1 什么是“序列”?

在 NLP 領域,“序列”通常指一個由單詞、字符或子詞組成的有序列表。一句話就是一個序列,每個字/詞按順序排列。Seq2Seq 模型的核心就是理解輸入序列的順序結構,并生成與之對應的輸出序列。

1.2 Seq2Seq 的歷史與重要性

2014年,Google 的研究團隊發(fā)表了論文《Sequence to Sequence Learning with Neural Networks》,首次提出了基于 LSTM 的 Seq2Seq 架構,并在 WMT-14 英德翻譯任務上達到了 20.6 的 BLEU 分數(shù),大幅超越當時的統(tǒng)計機器翻譯系統(tǒng)。這一突破標志著神經(jīng)機器翻譯(NMT)時代的正式到來。

從技術演進的角度看,Seq2Seq 模型的意義不僅在于它解決了翻譯問題,更在于它確立了一個通用的端到端學習框架——不需要手工設計復雜的特征工程和對齊規(guī)則,模型能夠自動學習從輸入到輸出的映射關系。


2. 核心架構:編碼器-解碼器框架

2.1 整體結構概覽

Seq2Seq 模型的核心由兩個主要部分組成:編碼器(Encoder)?和解碼器(Decoder),二者通過隱藏狀態(tài)(Hidden State)?進行信息傳遞。這兩個組件通常由 RNN 及其變體(LSTM、GRU)構成。

2.2 編碼器(Encoder):聽懂輸入

編碼器的任務:逐個處理輸入序列中的元素,并將整個序列的信息壓縮成一個固定長度的上下文向量(Context Vector,也叫語義向量或“思想向量”)。

編碼器在輸入序列上的處理過程可以這樣理解:

  1. 初始化一個隱藏狀態(tài) h0(通常為零向量)

  2. 輸入序列的第一個詞 x1 送入 RNN,產(chǎn)生隱藏狀態(tài) h1

  3. 輸入序列的第二個詞 x2h1 一起送入 RNN,產(chǎn)生隱藏狀態(tài) h2

  4. 重復此過程,直到處理完最后一個詞 xn,得到最終隱藏狀態(tài) hn

  5. 這個 hn 就是上下文向量?c,它包含了整個輸入序列的信息

用數(shù)學語言來表達:

ht=f(xt,ht?1)(t=1,2,...,n)ht=f(xt,ht?1)(t=1,2,...,n)

c=hnc=hn

其中 htt 時刻的隱藏狀態(tài),f 是 RNN 單元(如 GRU)的函數(shù),c 就是編碼器輸出的上下文向量。

2.3 解碼器(Decoder):生成輸出

解碼器的任務:接收編碼器傳遞過來的上下文向量 c,并基于它逐步生成目標序列。

解碼器的初始隱藏狀態(tài) s0 被設置為編碼器輸出的上下文向量 c

s0=cs0=c

然后,解碼器在每個時間步執(zhí)行以下操作:

  1. 接收當前步的輸入 yt?1(第一步輸入是特殊的起始標記?<sos>)

  2. 結合當前隱藏狀態(tài) st?1,生成新的隱藏狀態(tài) st

  3. st 通過一個線性層映射到詞表大小,得到每個詞的概率分布

  4. 從概率分布中選擇一個詞作為輸出 yt

  5. 將這個輸出作為下一步的輸入,重復直到生成結束標記?<eos>

用數(shù)學語言表達解碼過程:

st=g(yt?1,st?1,c)st=g(yt?1,st?1,c)

yt=argmax(Wst+b)yt=argmax(Wst+b)

其中 g 是解碼器的 RNN 單元,yt?1 是上一個時間步輸出的詞,st 是當前隱藏狀態(tài),c 是來自編碼器的上下文向量。

2.4 一個完整例子的流程

以翻譯“我喜歡你” → “I like you”為例:

編碼器階段:

  • 輸入:"我" → 隱藏狀態(tài) h1

  • 輸入:"喜" → 結合 h1 → 隱藏狀態(tài) h2

  • 輸入:"歡" → 結合 h2 → 隱藏狀態(tài) h3

  • 輸入:"你" → 結合 h3 → 最終隱藏狀態(tài) h4

解碼器階段(推理時):

  • 初始隱藏狀態(tài) = h4,輸入 = <sos> → 輸出 "I"

  • 隱藏狀態(tài)更新,輸入 = "I" → 輸出 "like"

  • 隱藏狀態(tài)更新,輸入 = "like" → 輸出 "you"

  • 隱藏狀態(tài)更新,輸入 = "you" → 輸出 <eos>(停止)


3. 模型訓練的核心技巧:Teacher Forcing

3.1 什么是 Teacher Forcing?

在訓練 Seq2Seq 模型時,我們使用一種名為Teacher Forcing(教師強制)?的巧妙技術。

為什么需要 Teacher Forcing?

回想一下,解碼器在生成過程中,每一步的輸入是上一步的輸出。在訓練初期,模型對什么都一竅不通,它的預測結果基本是隨機的。如果讓模型用自己的錯誤預測作為下一步的輸入,那么錯誤會不斷累積放大——前面猜錯了一個詞,后面的整個句子可能都變得亂七八糟。這就像一個學騎自行車的人,沒人扶著,一開始就瘋狂摔跤,很難進步。

Teacher Forcing 的核心思想很簡單:在訓練過程中,不喂給解碼器它自己(可能錯誤)的預測,而是直接喂給它真實的目標詞。

具體來說,假設目標句子是 <sos> I like you <eos>:

  • 第一步:輸入 <sos>,真實輸出應該是 "I"。模型輸出 "I"(可能猜對,也可能猜錯)

  • 第二步:不管模型第一步猜的是什么,第二步的輸入直接用真實詞 "I",而不是用模型猜的詞

  • 第三步:輸入用真實詞 "like",而不是模型第二步猜的詞

  • ...以此類推

這種訓練方式被稱為“教師強制”,因為每一步都有一個“教師”(真實目標詞)強制告訴模型“下一步應該輸入什么”,將模型拉回正確的軌道上。

3.2 Teacher Forcing 為什么有效?

Teacher Forcing 帶來了兩個明顯的好處:

① 訓練更快更穩(wěn)定

因為每一步輸入的都是正確詞,模型不會因為前面的錯誤而“跑偏”。梯度傳播更平滑,模型收斂速度顯著提升。

② 誤差不會累積

傳統(tǒng)的自回歸訓練(用預測作輸入)存在嚴重的誤差累積問題——一個時間步的小錯誤會在后續(xù)時間步被不斷放大。Teacher Forcing 從根本上避免了這個問題。

3.3 訓練 vs 推理:兩種完全不同的模式

訓練階段:使用 Teacher Forcing

解碼器輸入: <sos> → 真實詞 I → 真實詞 like → 真實詞 you → ...
解碼器輸出: 預測 I → 預測 like → 預測 you → 預測 <eos> → ...
損失計算: 對比預測值和真實值

推理階段:使用自回歸生成

解碼器輸入: <sos> → 預測 I → 預測 like → 預測 you → ...
解碼器輸出: 預測 I → 預測 like → 預測 you → 預測 <eos> → ...

這種“訓練時老師扶著,推理時自己走”的設計,是 Seq2Seq 模型能夠成功訓練的關鍵。


4. 數(shù)據(jù)預處理:從原始文本到模型能吃的數(shù)字

計算機不認識“我”、“你”這些漢字,也不認識“I”、“l(fā)ike”這些英文單詞。模型能理解的只有數(shù)字。因此,在訓練之前,我們需要把所有文字轉(zhuǎn)換成數(shù)字——這個過程就是數(shù)據(jù)預處理。

4.1 分詞:把句子切成最小單位

分詞(Tokenization)?是把一個句子切分成一個個“最小語義單元”的過程。

對于中文和英文,分詞策略是不同的:

  • 中文

    :按字粒度切分?!拔蚁矚g你?!?→ ["我","喜","歡","你","。"]。為什么按字分?因為中文的詞之間沒有空格,自動分詞容易出錯,而按字分簡單可靠。

  • 英文

    :按詞粒度切分,使用 NLTK 庫。“I like you.” → ["I","like","you","."]。英文天然有空格分隔,按詞分可以減少序列長度。

項目使用的數(shù)據(jù)集來自阿里云天池,共 29,155 對中英文平行語句,TSV 格式,每行包含英文和中文兩列。

數(shù)據(jù)集下載地址:https://pan.baidu.com/s/1As2fpzjOn4HSQZPNeIHHDw?pwd=es3y

4.2 詞表構建:給每個詞一個編號

我們把訓練集中所有出現(xiàn)過的字/詞收集起來,每個分配一個唯一的整數(shù)編號(索引)。例如:

"我" → 5
"喜" → 6
"歡" → 7
"你" → 8
"。" → 9
...

中英文的詞表是分開構建的,因為兩種語言的詞匯完全不同。詞表大小可以通過 max_size 參數(shù)限制,取頻率最高的那些詞,超過限制的丟棄(這被稱為“截斷”),可以控制模型參數(shù)量和訓練速度。

4.3 特殊標記:為什么需要它們?

除了普通詞匯,詞表中還必須包含四個特殊標記

標記

全稱

用途

<pad>

Padding

填充符:把所有句子統(tǒng)一到相同長度,填充在句子末尾。不參與損失計算。

<unk>

Unknown

未知詞:當遇到詞表中不存在的詞時(如生僻字、拼寫錯誤),用它代替。

<sos>

Start of Sentence

開始標記:告訴解碼器“現(xiàn)在開始生成新句子”。只加在目標句子的開頭。

<eos>

End of Sentence

結束標記:告訴解碼器“生成結束”。模型學會在輸出這個標記時停止生成。

為什么需要 <sos> 和 <eos>?因為解碼器的生成過程必須有明確的起點和終點。沒有 <sos>,解碼器不知道什么時候開始生成;沒有 <eos>,模型會無限地生成下去,不知道什么時候停止。

4.4 統(tǒng)一長度:填充與截斷

RNN 要求一個 batch 內(nèi)的所有序列長度相同(才能以矩陣形式并行計算)。因此我們需要設定一個最大長度?SEQ_LEN(本項目設為 30)。

  • 截斷

    :長度超過 SEQ_LEN 的句子,直接截取前 SEQ_LEN 個詞。

  • 填充

    :長度不足 SEQ_LEN 的句子,在末尾補充 <pad> 符號。

對于編碼器的輸入(中文)?:不加 <sos> 和 <eos>,直接編碼后填充。
對于解碼器的目標(英文)?:需要先添加 <sos> 和 <eos>,再編碼和填充。

4.5 完整代碼實現(xiàn)

config.py:配置文件

所有超參數(shù)集中管理,方便調(diào)整。

# config.py
from pathlib import Path
 
# ---------- 路徑配置 ----------
BASE_DIR = Path(__file__).parent.parent
RAW_DATA_DIR = BASE_DIR / 'data' / 'raw' ? ? ? ? ? ? ?# 原始數(shù)據(jù)存放位置
PROCESSED_DATA_DIR = BASE_DIR / 'data' / 'processed' ?# 預處理后的數(shù)據(jù)
MODELS_DIR = BASE_DIR / 'models' ? ? ? ? ? ? ? ? ? ? ?# 保存模型參數(shù)
LOGS_DIR = BASE_DIR / 'logs' ? ? ? ? ? ? ? ? ? ? ? ? ?# TensorBoard 日志
 
# ---------- 數(shù)據(jù)參數(shù) ----------
MAX_VOCAB_SIZE = 10000 ? ? ? ?# 詞表最大大小(保留頻率最高的詞)
SEQ_LEN = 30 ? ? ? ? ? ? ? ? ?# 序列統(tǒng)一長度(短則補,長則截)
 
# ---------- 模型參數(shù) ----------
EMBEDDING_DIM = 128 ? ? ? ? ? # 詞向量維度(每個詞用128個浮點數(shù)表示)
ENCODER_HIDDEN_DIM = 512 ? ? ?# 編碼器 GRU 隱藏維度(單向)
DECODER_HIDDEN_DIM = 1024 ? ? # 解碼器隱藏維度 = ENCODER_HIDDEN_DIM * 2
ENCODER_LAYERS = 1 ? ? ? ? ? ?# 編碼器層數(shù)(可以加深,1層足夠演示)
 
# ---------- 訓練參數(shù) ----------
BATCH_SIZE = 128
LEARNING_RATE = 0.001
EPOCHS = 30

tokenizer.py:分詞器實現(xiàn)

分詞器負責文本到數(shù)字的轉(zhuǎn)換,是整個預處理的核心。

# tokenizer.py
import json
from collections import Counter
from typing import List, Optional
import nltk
from nltk.tokenize import word_tokenize
from nltk.tokenize.treebank import TreebankWordDetokenizer
 
# 確保 nltk 的分詞數(shù)據(jù)已下載
try:
? ? nltk.data.find('tokenizers/punkt')
except LookupError:
? ? nltk.download('punkt')
 
 
class BaseTokenizer:
? ? """
? ? 分詞器基類,提供通用的編碼/解碼方法和詞表構建邏輯。
? ? 中文分詞器和英文分詞器都繼承自此類。
? ? """
 
? ? # 四個特殊標記
? ? PAD_TOKEN = '<pad>' ? # 填充符
? ? UNK_TOKEN = '<unk>' ? # 未知詞
? ? SOS_TOKEN = '<sos>' ? # 句子開始標記
? ? EOS_TOKEN = '<eos>' ? # 句子結束標記
 
? ? def __init__(self, vocab: List[str]):
? ? ? ? """
? ? ? ? 初始化分詞器
? ? ? ? 參數(shù):
? ? ? ? ? ? vocab: 詞表列表,索引順序就是詞表中詞的順序
? ? ? ? """
? ? ? ? self.vocab = vocab
? ? ? ? self.vocab_size = len(vocab)
? ? ? ? # 建立雙向映射:詞→索引 和 索引→詞
? ? ? ? self.stoi = {word: idx for idx, word in enumerate(vocab)}
? ? ? ? self.itos = {idx: word for word, idx in self.stoi.items()}
 
? ? ? ? # 緩存特殊標記的索引,方便使用
? ? ? ? self.pad_idx = self.stoi[self.PAD_TOKEN]
? ? ? ? self.unk_idx = self.stoi[self.UNK_TOKEN]
? ? ? ? self.sos_idx = self.stoi[self.SOS_TOKEN]
? ? ? ? self.eos_idx = self.stoi[self.EOS_TOKEN]
 
? ? @classmethod
? ? def build_vocab(cls, sentences: List[str], max_size: Optional[int] = None, min_freq: int = 1):
? ? ? ? """
? ? ? ? 從句子列表構建詞表(特殊標記自動添加)
? ? ? ? 參數(shù):
? ? ? ? ? ? sentences: 句子列表
? ? ? ? ? ? max_size: 詞表最大大小(保留頻率最高的詞)
? ? ? ? ? ? min_freq: 最小出現(xiàn)頻率,低于此頻率的詞被丟棄
? ? ? ? 返回:
? ? ? ? ? ? 詞表列表,前4個位置是特殊標記
? ? ? ? """
? ? ? ? # 統(tǒng)計所有詞的出現(xiàn)頻率
? ? ? ? counter = Counter()
? ? ? ? for sent in sentences:
? ? ? ? ? ? tokens = cls.tokenize(sent)
? ? ? ? ? ? counter.update(tokens)
 
? ? ? ? # 按頻率從高到低排序,取前 max_size 個
? ? ? ? most_common = counter.most_common(max_size)
 
? ? ? ? # 構建詞表:先加特殊標記,再加普通詞
? ? ? ? vocab = [cls.PAD_TOKEN, cls.UNK_TOKEN, cls.SOS_TOKEN, cls.EOS_TOKEN]
? ? ? ? for word, freq in most_common:
? ? ? ? ? ? if freq >= min_freq:
? ? ? ? ? ? ? ? vocab.append(word)
? ? ? ? return vocab
 
? ? @staticmethod
? ? def tokenize(sentence: str) -> List[str]:
? ? ? ? """分詞方法,子類必須實現(xiàn)"""
? ? ? ? raise NotImplementedError
 
? ? @staticmethod
? ? def detokenize(tokens: List[str]) -> str:
? ? ? ? """去分詞方法(把 token 列表還原成字符串),子類必須實現(xiàn)"""
? ? ? ? raise NotImplementedError
 
? ? def encode(self, sentence: str, max_len: int, add_sos_eos: bool = False) -> List[int]:
? ? ? ? """
? ? ? ? 把原始句子轉(zhuǎn)換成索引列表,并統(tǒng)一長度
? ? ? ? 參數(shù):
? ? ? ? ? ? sentence: 原始字符串
? ? ? ? ? ? max_len: 目標長度
? ? ? ? ? ? add_sos_eos: 是否添加 <sos> 和 <eos> 標記
? ? ? ? 返回:
? ? ? ? ? ? 長度為 max_len 的整數(shù)列表
? ? ? ? """
? ? ? ? # 1. 分詞
? ? ? ? tokens = self.tokenize(sentence)
? ? ? ? # 2. 轉(zhuǎn)索引,遇到不在詞表中的詞用 unk_idx
? ? ? ? indices = [self.stoi.get(token, self.unk_idx) for token in tokens]
? ? ? ? # 3. 如果需要加開始/結束標記
? ? ? ? if add_sos_eos:
? ? ? ? ? ? indices = [self.sos_idx] + indices + [self.eos_idx]
? ? ? ? # 4. 截斷或填充到 max_len
? ? ? ? if len(indices) > max_len:
? ? ? ? ? ? indices = indices[:max_len]
? ? ? ? else:
? ? ? ? ? ? indices = indices + [self.pad_idx] * (max_len - len(indices))
? ? ? ? return indices
 
? ? def decode(self, indices: List[int], skip_special: bool = True) -> str:
? ? ? ? """
? ? ? ? 把索引列表變回字符串
? ? ? ? 參數(shù):
? ? ? ? ? ? indices: 索引列表
? ? ? ? ? ? skip_special: 是否跳過特殊標記(<pad>, <sos>, <eos>)
? ? ? ? 返回:
? ? ? ? ? ? 還原后的字符串
? ? ? ? """
? ? ? ? if skip_special:
? ? ? ? ? ? special_set = {self.pad_idx, self.sos_idx, self.eos_idx}
? ? ? ? ? ? tokens = [self.itos[i] for i in indices if i not in special_set]
? ? ? ? else:
? ? ? ? ? ? tokens = [self.itos[i] for i in indices]
? ? ? ? return self.detokenize(tokens)
 
 
class ChineseTokenizer(BaseTokenizer):
? ? """中文分詞器:按字符切分"""
 
? ? @staticmethod
? ? def tokenize(sentence: str) -> List[str]:
? ? ? ? # 去除空格,然后按字符拆分
? ? ? ? return [ch for ch in sentence.strip() if ch != ' ']
 
? ? @staticmethod
? ? def detokenize(tokens: List[str]) -> str:
? ? ? ? return ''.join(tokens)
 
 
class EnglishTokenizer(BaseTokenizer):
? ? """英文分詞器:使用 NLTK 的 word_tokenize 按詞切分"""
 
? ? @staticmethod
? ? def tokenize(sentence: str) -> List[str]:
? ? ? ? return word_tokenize(sentence.strip())
 
? ? @staticmethod
? ? def detokenize(tokens: List[str]) -> str:
? ? ? ? return TreebankWordDetokenizer().detokenize(tokens)

process.py:數(shù)據(jù)預處理主程序

# process.py
import pandas as pd
from sklearn.model_selection import train_test_split
from pathlib import Path
import json
import config
from tokenizer import ChineseTokenizer, EnglishTokenizer
 
def main():
? ? # 設置路徑
? ? raw_path = Path(config.RAW_DATA_DIR) / 'cmn.txt'
? ? processed_dir = Path(config.PROCESSED_DATA_DIR)
? ? processed_dir.mkdir(parents=True, exist_ok=True)
 
? ? # 1. 讀取原始數(shù)據(jù)(TSV 格式,兩列:英文、中文)
? ? print("讀取原始數(shù)據(jù)...")
? ? df = pd.read_csv(raw_path, sep='\t', header=None, names=['en', 'zh'])
? ? df = df.dropna() ?# 去掉空行
? ? df = df[df['en'].str.strip().ne('') & df['zh'].str.strip().ne('')]
? ? print(f"總共 {len(df)} 對句子")
 
? ? # 2. 劃分訓練集和測試集(8:2 比例)
? ? train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
? ? print(f"訓練集 {len(train_df)} 對,測試集 {len(test_df)} 對")
 
? ? # 3. 構建詞表(只用訓練集)
? ? print("構建中文詞表...")
? ? zh_vocab = ChineseTokenizer.build_vocab(
? ? ? ? train_df['zh'].tolist(),?
? ? ? ? max_size=config.MAX_VOCAB_SIZE
? ? )
? ? print(f"中文詞表大小: {len(zh_vocab)}")
 
? ? print("構建英文詞表...")
? ? en_vocab = EnglishTokenizer.build_vocab(
? ? ? ? train_df['en'].tolist(),?
? ? ? ? max_size=config.MAX_VOCAB_SIZE
? ? )
? ? print(f"英文詞表大小: {len(en_vocab)}")
 
? ? # 保存詞表到文件(方便以后加載,不用重復構建)
? ? with open(processed_dir / 'zh_vocab.json', 'w', encoding='utf-8') as f:
? ? ? ? json.dump(zh_vocab, f, ensure_ascii=False)
? ? with open(processed_dir / 'en_vocab.json', 'w', encoding='utf-8') as f:
? ? ? ? json.dump(en_vocab, f, ensure_ascii=False)
 
? ? # 初始化分詞器
? ? zh_tokenizer = ChineseTokenizer(zh_vocab)
? ? en_tokenizer = EnglishTokenizer(en_vocab)
 
? ? # 4. 編碼訓練集并保存
? ? print("處理訓練集...")
? ? train_records = []
? ? for _, row in train_df.iterrows():
? ? ? ? # 中文:不加 sos/eos,只編碼后填充
? ? ? ? zh_ids = zh_tokenizer.encode(row['zh'], max_len=config.SEQ_LEN, add_sos_eos=False)
? ? ? ? # 英文:需要加 sos/eos,因為解碼器需要它們來開始和結束生成
? ? ? ? en_ids = en_tokenizer.encode(row['en'], max_len=config.SEQ_LEN, add_sos_eos=True)
? ? ? ? train_records.append({'zh': zh_ids, 'en': en_ids})
 
? ? with open(processed_dir / 'train.json', 'w', encoding='utf-8') as f:
? ? ? ? for rec in train_records:
? ? ? ? ? ? f.write(json.dumps(rec, ensure_ascii=False) + '\n')
 
? ? # 5. 編碼測試集并保存
? ? print("處理測試集...")
? ? test_records = []
? ? for _, row in test_df.iterrows():
? ? ? ? zh_ids = zh_tokenizer.encode(row['zh'], max_len=config.SEQ_LEN, add_sos_eos=False)
? ? ? ? en_ids = en_tokenizer.encode(row['en'], max_len=config.SEQ_LEN, add_sos_eos=True)
? ? ? ? test_records.append({'zh': zh_ids, 'en': en_ids})
 
? ? with open(processed_dir / 'test.json', 'w', encoding='utf-8') as f:
? ? ? ? for rec in test_records:
? ? ? ? ? ? f.write(json.dumps(rec, ensure_ascii=False) + '\n')
 
? ? print("預處理完成!")
 
if __name__ == '__main__':
? ? main()

4.6 自定義 Dataset 與 DataLoader

預處理后的數(shù)據(jù)保存為 JSON Lines 格式(每行一個 JSON 對象),我們需要一個 PyTorch Dataset 來加載它們。

# dataset.py
import torch
from torch.utils.data import Dataset, DataLoader
import json
from pathlib import Path
import config
 
class TranslationDataset(Dataset):
? ? """
? ? 翻譯數(shù)據(jù)集類
? ? 加載預處理后的 JSON 文件,返回中文輸入和英文目標
? ? """
? ? def __init__(self, json_path):
? ? ? ? self.data = []
? ? ? ? with open(json_path, 'r', encoding='utf-8') as f:
? ? ? ? ? ? for line in f:
? ? ? ? ? ? ? ? self.data.append(json.loads(line))
 
? ? def __len__(self):
? ? ? ? return len(self.data)
 
? ? def __getitem__(self, idx):
? ? ? ? item = self.data[idx]
? ? ? ? # 中文輸入(不加 sos/eos)
? ? ? ? src = torch.tensor(item['zh'], dtype=torch.long)
? ? ? ? # 英文目標(已加 sos 和 eos)
? ? ? ? tgt = torch.tensor(item['en'], dtype=torch.long)
? ? ? ? return src, tgt
 
def get_dataloader(train=True, batch_size=None):
? ? """
? ? 獲取 DataLoader
? ? 參數(shù):
? ? ? ? train: True 返回訓練集 DataLoader,F(xiàn)alse 返回測試集 DataLoader
? ? ? ? batch_size: 批次大小,默認使用 config.BATCH_SIZE
? ? """
? ? if batch_size is None:
? ? ? ? batch_size = config.BATCH_SIZE
? ? data_dir = Path(config.PROCESSED_DATA_DIR)
? ? if train:
? ? ? ? path = data_dir / 'train.json'
? ? else:
? ? ? ? path = data_dir / 'test.json'
? ? dataset = TranslationDataset(path)
? ? return DataLoader(dataset, batch_size=batch_size, shuffle=train, drop_last=train)

提示:drop_last=True 是為了確保每個 batch 大小一致,避免最后一個不完整的 batch 導致訓練出錯。如果訓練集大小不是 batch_size 的整數(shù)倍,最后一個不完整的 batch 會被丟棄。


5. 模型實現(xiàn):逐行代碼詳解

5.1 編碼器(Encoder)完整代碼

編碼器包含兩個主要部分:嵌入層(Embedding Layer)?和雙向 GRU 層。

# model.py
import torch
import torch.nn as nn
import config
 
class Encoder(nn.Module):
? ? """
? ? 編碼器:將輸入序列(中文)編碼成上下文向量
? ? """
? ? def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers=1):
? ? ? ? super().__init__()
? ? ? ? self.hidden_dim = hidden_dim
? ? ? ? self.num_layers = num_layers
 
? ? ? ? # ---------- 嵌入層 ----------
? ? ? ? # 作用:把每個字的編號(如 12)變成一個稠密向量(如 128 個浮點數(shù))
? ? ? ? # 為什么需要嵌入層?因為整數(shù)編號沒有語義信息,"我"=12 和 "你"=13 之間的數(shù)值差 1 沒有任何意義。
? ? ? ? # 嵌入層讓模型自己學習每個詞的向量表示,意思相近的詞向量也會相近。
? ? ? ? self.embedding = nn.Embedding(vocab_size, embedding_dim)
 
? ? ? ? # ---------- 雙向 GRU ----------
? ? ? ? # GRU(Gated Recurrent Unit)是 RNN 的改進版,比基本 RNN 效果好,比 LSTM 簡單。
? ? ? ? # bidirectional=True:同時從左到右和從右到左處理序列,讓每個位置都能看到完整的上下文。
? ? ? ? # batch_first=True:輸入形狀為 (batch, seq, feature) 而不是 (seq, batch, feature)
? ? ? ? self.gru = nn.GRU(
? ? ? ? ? ? input_size=embedding_dim,
? ? ? ? ? ? hidden_size=hidden_dim,
? ? ? ? ? ? num_layers=num_layers,
? ? ? ? ? ? batch_first=True,
? ? ? ? ? ? bidirectional=True
? ? ? ? )
 
? ? def forward(self, x):
? ? ? ? """
? ? ? ? 前向傳播
? ? ? ? 參數(shù):
? ? ? ? ? ? x: 輸入張量,形狀 (batch_size, seq_len),里面是每個字的編號
? ? ? ? 返回:
? ? ? ? ? ? hidden: 最終隱藏狀態(tài),形狀 (num_layers*2, batch_size, hidden_dim)
? ? ? ? ? ? ? ? ? ?因為是雙向,num_layers*2 = 2
? ? ? ? """
? ? ? ? # 步驟1:嵌入層轉(zhuǎn)換
? ? ? ? # 輸入 x: (batch_size, seq_len)
? ? ? ? # 輸出 embedded: (batch_size, seq_len, embedding_dim)
? ? ? ? embedded = self.embedding(x)
 
? ? ? ? # 步驟2:雙向 GRU 處理
? ? ? ? # output: 每個時間步的輸出,形狀 (batch_size, seq_len, hidden_dim*2)
? ? ? ? # hidden: 最后一個時間步的隱藏狀態(tài)(雙向),形狀 (2, batch_size, hidden_dim)
? ? ? ? output, hidden = self.gru(embedded)
 
? ? ? ? # 返回 hidden,這就是編碼器的輸出(上下文向量)
? ? ? ? # 注意:這里的 hidden 包含前向和后向兩個方向的最后狀態(tài)
? ? ? ? return hidden

編碼器的關鍵點

  • 嵌入層

    :nn.Embedding(vocab_size, embedding_dim) 創(chuàng)建了一個可訓練的查找表。給定一個整數(shù)索引(詞的編號),它返回對應的向量。這些向量一開始是隨機的,在訓練過程中不斷調(diào)整,最終使得語義相近的詞向量在向量空間中距離更近。

  • 雙向 GRU

    :bidirectional=True 讓 GRU 同時從兩個方向處理序列。輸出維度會翻倍(hidden_dim*2),因為前向和后向的隱藏狀態(tài)拼接在一起。這樣做的好處是:每個位置的隱藏狀態(tài)都能同時看到左邊和右邊的上下文信息,理解更準確。

  • 為什么返回的是 hidden 而不是 output

    :在傳統(tǒng)的 Seq2Seq 中,編碼器的最后一個隱藏狀態(tài)被當作整個輸入序列的語義摘要。這個摘要向量就是傳遞給解碼器的上下文向量。不過嚴格來說,在雙向 GRU 中,我們需要把前向和后向的最后狀態(tài)拼接起來才是完整的上下文向量——這個拼接操作在訓練循環(huán)中完成。

5.2 解碼器(Decoder)完整代碼

解碼器包含三層:嵌入層單向 GRU、全連接層(線性層)。

class Decoder(nn.Module):
? ? """
? ? 解碼器:根據(jù)上下文向量逐步生成目標序列(英文)
? ? """
? ? def __init__(self, vocab_size, embedding_dim, hidden_dim):
? ? ? ? super().__init__()
 
? ? ? ? # ---------- 嵌入層 ----------
? ? ? ? # 和編碼器的嵌入層作用相同,把英文詞的編號轉(zhuǎn)成向量
? ? ? ? self.embedding = nn.Embedding(vocab_size, embedding_dim)
 
? ? ? ? # ---------- 單向 GRU ----------
? ? ? ? # 注意:解碼器是單向的,因為生成時只能看到已經(jīng)生成的詞,不能看到未來的詞
? ? ? ? # 這是自回歸生成的自然要求
? ? ? ? self.gru = nn.GRU(
? ? ? ? ? ? input_size=embedding_dim,
? ? ? ? ? ? hidden_size=hidden_dim,
? ? ? ? ? ? batch_first=True
? ? ? ? )
 
? ? ? ? # ---------- 全連接層 ----------
? ? ? ? # 把 GRU 輸出的向量映射成詞表大小的概率分布
? ? ? ? # 例如 hidden_dim=1024,vocab_size=10000,則線性層將 1024 維向量映射成 10000 維
? ? ? ? # 這 10000 個數(shù)字經(jīng)過 softmax 后就是每個詞的概率
? ? ? ? self.fc = nn.Linear(hidden_dim, vocab_size)
 
? ? def forward(self, x, hidden):
? ? ? ? """
? ? ? ? 單步前向傳播(一個時間步)
? ? ? ? 參數(shù):
? ? ? ? ? ? x: 當前步的輸入,形狀 (batch_size, 1),是一個詞的編號
? ? ? ? ? ? hidden: 上一步的隱藏狀態(tài),形狀 (1, batch_size, hidden_dim)
? ? ? ? ? ? ? ? ? ?初始時來自編碼器的上下文向量
? ? ? ? 返回:
? ? ? ? ? ? output: 當前步的輸出,形狀 (batch_size, 1, vocab_size)
? ? ? ? ? ? ? ? ? ?代表每個候選詞的概率(logits,未經(jīng)過 softmax)
? ? ? ? ? ? hidden: 新的隱藏狀態(tài),用于下一步
? ? ? ? """
? ? ? ? # 步驟1:嵌入
? ? ? ? # x: (batch_size, 1) → embedded: (batch_size, 1, embedding_dim)
? ? ? ? embedded = self.embedding(x)
 
? ? ? ? # 步驟2:GRU 一步
? ? ? ? # 輸入 embedded 和上一步的 hidden,輸出當前步的 output 和新的 hidden
? ? ? ? # output: (batch_size, 1, hidden_dim)
? ? ? ? # hidden: (1, batch_size, hidden_dim)
? ? ? ? output, hidden = self.gru(embedded, hidden)
 
? ? ? ? # 步驟3:全連接層,映射到詞表大小
? ? ? ? # output: (batch_size, 1, hidden_dim) → (batch_size, 1, vocab_size)
? ? ? ? output = self.fc(output)
 
? ? ? ? return output, hidden

解碼器的關鍵點

  • 單向 GRU

    :與編碼器不同,解碼器只能用單向。因為在生成過程中,我們只能利用已經(jīng)生成的信息來預測下一個詞,不能“偷看”未來的詞。這是自回歸生成的核心約束。

  • 自回歸機制

    :解碼器在生成一個詞之后,把它作為輸入去生成下一個詞。這種“把自己上一時刻的輸出當作當前時刻的輸入”的方式,就是自回歸生成。推理階段是這樣,但訓練階段我們用了 Teacher Forcing。

  • 全連接層

    :GRU 的輸出是一個隱藏向量(維度 hidden_dim),這個向量包含了當前生成步驟的語義信息。全連接層把這個向量映射到詞表大小,得到一個分布。分布中概率最大的詞就是我們預測的下一個詞。

5.3 訓練循環(huán)完整代碼

訓練時需要把編碼器和解碼器串聯(lián)起來,并實現(xiàn) Teacher Forcing 策略。

# train.py
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import config
from dataset import get_dataloader
from model import Encoder, Decoder
from tokenizer import ChineseTokenizer, EnglishTokenizer
import json
from pathlib import Path
 
def load_tokenizers():
? ? """從保存的 json 文件加載中英文詞表和分詞器"""
? ? with open(Path(config.PROCESSED_DATA_DIR) / 'zh_vocab.json', 'r', encoding='utf-8') as f:
? ? ? ? zh_vocab = json.load(f)
? ? with open(Path(config.PROCESSED_DATA_DIR) / 'en_vocab.json', 'r', encoding='utf-8') as f:
? ? ? ? en_vocab = json.load(f)
? ? zh_tokenizer = ChineseTokenizer(zh_vocab)
? ? en_tokenizer = EnglishTokenizer(en_vocab)
? ? return zh_tokenizer, en_tokenizer
 
def train_one_epoch(encoder, decoder, dataloader, criterion, optimizer, device):
? ? """
? ? 訓練一個 epoch(所有訓練數(shù)據(jù)過一遍)
? ? 參數(shù):
? ? ? ? encoder: 編碼器模型
? ? ? ? decoder: 解碼器模型
? ? ? ? dataloader: 訓練數(shù)據(jù)加載器
? ? ? ? criterion: 損失函數(shù)
? ? ? ? optimizer: 優(yōu)化器
? ? ? ? device: 設備(cuda/cpu)
? ? 返回:
? ? ? ? 這個 epoch 的平均損失值
? ? """
? ? encoder.train() ?# 設置為訓練模式(啟用 dropout 等)
? ? decoder.train()
? ? total_loss = 0
 
? ? # tqdm 顯示進度條
? ? for src, tgt in tqdm(dataloader, desc="Training"):
? ? ? ? # src: (batch_size, SEQ_LEN) 中文輸入
? ? ? ? # tgt: (batch_size, SEQ_LEN) 英文目標(已包含 <sos> 和 <eos>)
? ? ? ? src = src.to(device)
? ? ? ? tgt = tgt.to(device)
 
? ? ? ? optimizer.zero_grad() ?# 梯度清零(否則會累加)
 
? ? ? ? # ---------- 編碼器 ----------
? ? ? ? encoder_hidden = encoder(src)
? ? ? ? # encoder_hidden: (2, batch_size, ENCODER_HIDDEN_DIM)
? ? ? ? # 2 代表兩個方向(前向和后向)
 
? ? ? ? # 將雙向的隱藏狀態(tài)拼接成上下文向量
? ? ? ? # encoder_hidden[-2]: 前向最后一層的隱藏狀態(tài) (batch_size, ENCODER_HIDDEN_DIM)
? ? ? ? # encoder_hidden[-1]: 后向最后一層的隱藏狀態(tài) (batch_size, ENCODER_HIDDEN_DIM)
? ? ? ? # context: (batch_size, ENCODER_HIDDEN_DIM*2) = (batch_size, 1024)
? ? ? ? forward_hidden = encoder_hidden[-2]
? ? ? ? backward_hidden = encoder_hidden[-1]
? ? ? ? context = torch.cat([forward_hidden, backward_hidden], dim=1)
 
? ? ? ? # ---------- 解碼器(Teacher Forcing)----------
? ? ? ? # 初始化解碼器的隱藏狀態(tài)(需要添加一層維度)
? ? ? ? decoder_hidden = context.unsqueeze(0) ?# (1, batch_size, 1024)
 
? ? ? ? # 解碼器的第一步輸入:<sos>,取目標句子的第一個詞
? ? ? ? decoder_input = tgt[:, 0:1] ?# (batch_size, 1)
 
? ? ? ? decoder_outputs = [] ?# 存放每個時間步的輸出
 
? ? ? ? # 從第1步到第 SEQ_LEN-1 步
? ? ? ? # 因為第0步是 <sos>,最后一步是 <eos>,我們需要預測從第1步到最后一步
? ? ? ? for step in range(1, config.SEQ_LEN):
? ? ? ? ? ? # 解碼器前向傳播
? ? ? ? ? ? decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
? ? ? ? ? ? # decoder_output: (batch_size, 1, vocab_size)
? ? ? ? ? ? decoder_outputs.append(decoder_output)
 
? ? ? ? ? ? # Teacher Forcing:下一步的輸入直接使用真實的目標詞
? ? ? ? ? ? # 不管 decoder 上一步預測了什么,我們都喂給它正確的詞
? ? ? ? ? ? decoder_input = tgt[:, step:step+1] ?# (batch_size, 1)
 
? ? ? ? # 將所有時間步的輸出拼接在一起
? ? ? ? # decoder_outputs 列表中有 SEQ_LEN-1 個元素,每個形狀 (batch, 1, vocab)
? ? ? ? # cat 后形狀: (batch, SEQ_LEN-1, vocab)
? ? ? ? decoder_outputs = torch.cat(decoder_outputs, dim=1)
 
? ? ? ? # 目標序列:去掉 <sos>,從第1個詞到最后一個詞(包括 <eos>)
? ? ? ? targets = tgt[:, 1:] ?# (batch_size, SEQ_LEN-1)
 
? ? ? ? # ---------- 計算損失 ----------
? ? ? ? # CrossEntropyLoss 要求輸入形狀 (N, C) 和 (N,)
? ? ? ? # 其中 N 是所有 token 的總數(shù),C 是詞表大小
? ? ? ? loss = criterion(
? ? ? ? ? ? decoder_outputs.reshape(-1, decoder_outputs.shape[-1]), ?# (batch*(SEQ_LEN-1), vocab)
? ? ? ? ? ? targets.reshape(-1) ?# (batch*(SEQ_LEN-1))
? ? ? ? )
 
? ? ? ? # 反向傳播
? ? ? ? loss.backward()
? ? ? ? optimizer.step()
 
? ? ? ? total_loss += loss.item()
 
? ? return total_loss / len(dataloader)
 
def main():
? ? device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
? ? print(f"Using device: {device}")
 
? ? # 加載分詞器和詞表
? ? zh_tokenizer, en_tokenizer = load_tokenizers()
 
? ? # 創(chuàng)建模型
? ? encoder = Encoder(
? ? ? ? vocab_size=zh_tokenizer.vocab_size,
? ? ? ? embedding_dim=config.EMBEDDING_DIM,
? ? ? ? hidden_dim=config.ENCODER_HIDDEN_DIM,
? ? ? ? num_layers=config.ENCODER_LAYERS
? ? ).to(device)
 
? ? decoder = Decoder(
? ? ? ? vocab_size=en_tokenizer.vocab_size,
? ? ? ? embedding_dim=config.EMBEDDING_DIM,
? ? ? ? hidden_dim=config.DECODER_HIDDEN_DIM
? ? ).to(device)
 
? ? # 損失函數(shù):交叉熵,忽略 <pad> 位置的損失
? ? criterion = nn.CrossEntropyLoss(ignore_index=en_tokenizer.pad_idx)
 
? ? # 優(yōu)化器:Adam,學習率 0.001
? ? optimizer = optim.Adam(
? ? ? ? list(encoder.parameters()) + list(decoder.parameters()),
? ? ? ? lr=config.LEARNING_RATE
? ? )
 
? ? # 加載數(shù)據(jù)
? ? train_loader = get_dataloader(train=True)
 
? ? # 訓練循環(huán)
? ? for epoch in range(1, config.EPOCHS + 1):
? ? ? ? print(f"\n========== Epoch {epoch} / {config.EPOCHS} ==========")
? ? ? ? avg_loss = train_one_epoch(encoder, decoder, train_loader, criterion, optimizer, device)
? ? ? ? print(f"Average Loss: {avg_loss:.4f}")
 
? ? ? ? # 每 5 個 epoch 保存一次模型
? ? ? ? if epoch % 5 == 0:
? ? ? ? ? ? torch.save(encoder.state_dict(), config.MODELS_DIR / f'encoder_epoch{epoch}.pt')
? ? ? ? ? ? torch.save(decoder.state_dict(), config.MODELS_DIR / f'decoder_epoch{epoch}.pt')
? ? ? ? ? ? print(f"Model saved at epoch {epoch}")
 
? ? # 保存最終模型
? ? torch.save(encoder.state_dict(), config.MODELS_DIR / 'encoder_final.pt')
? ? torch.save(decoder.state_dict(), config.MODELS_DIR / 'decoder_final.pt')
? ? print("Training completed!")
 
if __name__ == '__main__':
? ? main()

輸出結果:

========== 前面45輪省略。。。 ==========
========== EPOCH:45 ==========
訓練:: 100%|██████████| 365/365 [00:01<00:00, 193.33it/s]
本輪訓練損失: 0.06847477024548675
模型保存成功!
========== EPOCH:46 ==========
訓練:: 100%|██████████| 365/365 [00:01<00:00, 191.92it/s]
本輪訓練損失: 0.06932613608261494
========== EPOCH:47 ==========
訓練:: 100%|██████████| 365/365 [00:02<00:00, 179.68it/s]
本輪訓練損失: 0.07184863890176767
========== EPOCH:48 ==========
訓練:: 100%|██████████| 365/365 [00:02<00:00, 178.81it/s]
訓練:: ? 0%| ? ? ? ? ?| 0/365 [00:00<?, ?it/s]本輪訓練損失: 0.09403942644800226
========== EPOCH:49 ==========
訓練:: 100%|██████████| 365/365 [00:01<00:00, 190.23it/s]
本輪訓練損失: 0.12599861240142013
========== EPOCH:50 ==========
訓練:: 100%|██████████| 365/365 [00:02<00:00, 180.61it/s]
本輪訓練損失: 0.10180483621685472

訓練循環(huán)的關鍵點

  • 上下文向量的拼接

    :雙向 GRU 輸出的 hidden 形狀是 (2, batch, hidden_dim)。我們需要把前向(索引 -2)和后向(索引 -1)的最后狀態(tài)拼接起來,形成一個 (batch, hidden_dim*2) 的向量,這才是完整的上下文向量。

  • Teacher Forcing 的實現(xiàn)

    :在 for step in range(1, SEQ_LEN) 循環(huán)中,decoder_input 始終從真實目標 tgt 中取,而不是用上一步的預測。這是 Teacher Forcing 的核心。

  • 損失函數(shù)的 ignore_index

    :設置 ignore_index=en_tokenizer.pad_idx 后,損失函數(shù)會自動忽略所有 <pad> 位置,不會因為這些位置而懲罰模型。這是非常關鍵的,因為我們不希望模型去學習預測那些填充的無意義符號。

  • 梯度清零

    :每個 batch 開始前需要 optimizer.zero_grad(),否則梯度會累加到下一個 batch。

5.4 推理與翻譯完整代碼

訓練完成后,我們需要一個推理函數(shù)來實際翻譯新句子。推理時必須使用自回歸生成,不能使用 Teacher Forcing。

# predict.py
import torch
import config
from model import Encoder, Decoder
from tokenizer import ChineseTokenizer, EnglishTokenizer
import json
from pathlib import Path
 
def load_models(device):
? ? """加載訓練好的模型和分詞器"""
? ? # 加載詞表
? ? with open(Path(config.PROCESSED_DATA_DIR) / 'zh_vocab.json', 'r', encoding='utf-8') as f:
? ? ? ? zh_vocab = json.load(f)
? ? with open(Path(config.PROCESSED_DATA_DIR) / 'en_vocab.json', 'r', encoding='utf-8') as f:
? ? ? ? en_vocab = json.load(f)
 
? ? zh_tokenizer = ChineseTokenizer(zh_vocab)
? ? en_tokenizer = EnglishTokenizer(en_vocab)
 
? ? # 創(chuàng)建模型
? ? encoder = Encoder(
? ? ? ? vocab_size=zh_tokenizer.vocab_size,
? ? ? ? embedding_dim=config.EMBEDDING_DIM,
? ? ? ? hidden_dim=config.ENCODER_HIDDEN_DIM,
? ? ? ? num_layers=config.ENCODER_LAYERS
? ? ).to(device)
 
? ? decoder = Decoder(
? ? ? ? vocab_size=en_tokenizer.vocab_size,
? ? ? ? embedding_dim=config.EMBEDDING_DIM,
? ? ? ? hidden_dim=config.DECODER_HIDDEN_DIM
? ? ).to(device)
 
? ? # 加載訓練好的權重
? ? encoder.load_state_dict(torch.load(config.MODELS_DIR / 'encoder_final.pt', map_location=device))
? ? decoder.load_state_dict(torch.load(config.MODELS_DIR / 'decoder_final.pt', map_location=device))
 
? ? encoder.eval()
? ? decoder.eval()
 
? ? return encoder, decoder, zh_tokenizer, en_tokenizer
 
def translate(sentence, encoder, decoder, zh_tokenizer, en_tokenizer, device, max_len=None):
? ? """
? ? 翻譯單個中文句子
? ? 參數(shù):
? ? ? ? sentence: 原始中文句子,例如 "我喜歡你"
? ? ? ? encoder: 編碼器模型
? ? ? ? decoder: 解碼器模型
? ? ? ? zh_tokenizer: 中文分詞器
? ? ? ? en_tokenizer: 英文分詞器
? ? ? ? device: 設備
? ? ? ? max_len: 最大生成長度,默認使用 config.SEQ_LEN
? ? 返回:
? ? ? ? 英文翻譯句子
? ? """
? ? if max_len is None:
? ? ? ? max_len = config.SEQ_LEN
 
? ? with torch.no_grad(): ?# 推理時不需要計算梯度,節(jié)省內(nèi)存
? ? ? ? # ---------- 編碼 ----------
? ? ? ? # 將中文句子編碼成索引序列(不加 sos/eos)
? ? ? ? input_ids = zh_tokenizer.encode(sentence, max_len=config.SEQ_LEN, add_sos_eos=False)
? ? ? ? src_tensor = torch.tensor([input_ids], device=device) ?# (1, SEQ_LEN)
 
? ? ? ? encoder_hidden = encoder(src_tensor)
? ? ? ? # encoder_hidden: (2, 1, ENCODER_HIDDEN_DIM)
 
? ? ? ? # 拼接上下文向量
? ? ? ? forward_hidden = encoder_hidden[-2] ? # (1, ENCODER_HIDDEN_DIM)
? ? ? ? backward_hidden = encoder_hidden[-1] ?# (1, ENCODER_HIDDEN_DIM)
? ? ? ? context = torch.cat([forward_hidden, backward_hidden], dim=1) ?# (1, 1024)
? ? ? ? decoder_hidden = context.unsqueeze(0) ?# (1, 1, 1024)
 
? ? ? ? # ---------- 解碼(自回歸生成)----------
? ? ? ? # 初始輸入:<sos> 標記
? ? ? ? decoder_input = torch.tensor([[en_tokenizer.sos_idx]], device=device) ?# (1, 1)
 
? ? ? ? generated_indices = [] ?# 存放生成的 token 索引
 
? ? ? ? for _ in range(max_len - 1): ?# 最多生成 max_len-1 個詞
? ? ? ? ? ? decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
? ? ? ? ? ? # decoder_output: (1, 1, vocab_size)
 
? ? ? ? ? ? # 貪心解碼:選擇概率最高的詞(argmax 找到最大值索引)
? ? ? ? ? ? next_token = decoder_output.argmax(dim=-1) ?# (1, 1)
? ? ? ? ? ? token_id = next_token.item()
 
? ? ? ? ? ? # 如果遇到結束符,停止生成
? ? ? ? ? ? if token_id == en_tokenizer.eos_idx:
? ? ? ? ? ? ? ? break
 
? ? ? ? ? ? generated_indices.append(token_id)
? ? ? ? ? ? # 自回歸:把當前輸出作為下一步的輸入
? ? ? ? ? ? decoder_input = next_token
 
? ? ? ? # 將索引序列解碼成英文句子
? ? ? ? english_sentence = en_tokenizer.decode(generated_indices, skip_special=True)
 
? ? return english_sentence
 
def interactive_translate():
? ? """交互式翻譯程序"""
? ? device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
? ? print(f"Loading models on {device}...")
 
? ? encoder, decoder, zh_tokenizer, en_tokenizer = load_models(device)
? ? print("Models loaded! Enter Chinese sentences to translate (type 'quit' to exit).")
 
? ? while True:
? ? ? ? user_input = input("\n中文: ").strip()
? ? ? ? if user_input.lower() in ['quit', 'exit', 'q']:
? ? ? ? ? ? print("Goodbye!")
? ? ? ? ? ? break
? ? ? ? if not user_input:
? ? ? ? ? ? print("請輸入內(nèi)容")
? ? ? ? ? ? continue
 
? ? ? ? try:
? ? ? ? ? ? result = translate(user_input, encoder, decoder, zh_tokenizer, en_tokenizer, device)
? ? ? ? ? ? print(f"英文: {result}")
? ? ? ? except Exception as e:
? ? ? ? ? ? print(f"翻譯出錯: {e}")
 
if __name__ == '__main__':
? ? interactive_translate()

推理代碼的關鍵點

  • torch.no_grad()

    :推理時不需要計算梯度,用這個上下文管理器可以節(jié)省大量內(nèi)存和計算。

  • 貪心解碼(Greedy Decoding)

    :每個時間步選擇概率最高的詞(argmax)。這是最簡單的解碼策略,優(yōu)點是速度快,缺點是可能不是全局最優(yōu)。

  • 自回歸生成

    :decoder_input = next_token 將當前預測作為下一步的輸入,這是與訓練階段最大的不同。

  • 結束條件

    :當生成 <eos> 標記或達到最大長度時停止生成。


6. 模型評估:BLEU 分數(shù)

訓練完模型,我們怎么知道它翻譯得好不好?總不能人工看幾千條翻譯結果吧。BLEU 分數(shù)就是解決這個問題而設計的。

6.1 BLEU 是什么?

BLEU(Bilingual Evaluation Understudy,雙語評估替補)?是一種用于自動評估機器翻譯文本質(zhì)量的算法,由 IBM 的研究團隊于 2002 年提出。它的核心思想很簡單:機器翻譯越接近專業(yè)人工翻譯,質(zhì)量就越好

BLEU 通過計算候選翻譯(模型生成)與參考翻譯(人工標準答案)之間的?n-gram 匹配數(shù)量來評估翻譯質(zhì)量。n-gram 是指由 n 個詞組成的連續(xù)序列:

6.2 BLEU 的計算方法

BLEU 的計算包含三個步驟:

① 改進的 n-gram 精度(Modified n-gram Precision)

對于 1-4 的每個 n,計算候選翻譯中出現(xiàn)在參考翻譯中的 n-gram 所占的比例。為避免重復詞被過度獎勵,每個 n-gram 的匹配次數(shù)被限制為它在參考翻譯中出現(xiàn)的最大次數(shù)。

② 幾何平均(Geometric Mean)

將 1-gram 到 4-gram 的精度取幾何平均。這樣做是為了確保單一類型 n-gram 的低精度會顯著拉低最終分數(shù)。

③ 簡短懲罰(Brevity Penalty)

如果候選翻譯的長度(c)顯著短于參考翻譯的長度(r),BLEU 會引入一個懲罰因子,防止模型通過生成過短的句子來獲得高精度。

最終 BLEU 分數(shù)范圍:0 到 1,通常乘以 100 以百分比表示。分數(shù)越接近 1(100%),翻譯質(zhì)量越高。

6.3 評估代碼實現(xiàn)

# evaluate.py
import torch
from nltk.translate.bleu_score import corpus_bleu
from tqdm import tqdm
import config
from dataset import get_dataloader
from predict import load_models, translate
import json
 
def evaluate(encoder, decoder, test_loader, zh_tokenizer, en_tokenizer, device):
? ? """
? ? 在測試集上評估模型
? ? 返回:
? ? ? ? BLEU-4 分數(shù)(0-1 之間)
? ? """
? ? references = [] ? # 參考翻譯列表
? ? hypotheses = [] ? # 模型生成的翻譯列表
 
? ? special_indices = {en_tokenizer.pad_idx, en_tokenizer.sos_idx, en_tokenizer.eos_idx}
 
? ? for src, tgt in tqdm(test_loader, desc="Evaluating"):
? ? ? ? src = src.to(device)
? ? ? ? # tgt: (batch_size, SEQ_LEN) 參考翻譯的索引
 
? ? ? ? # 批量翻譯(可以優(yōu)化為批量處理,這里為清晰起見逐句處理)
? ? ? ? for i in range(src.size(0)):
? ? ? ? ? ? # 獲取單句中文字符串(這里需要解碼 src,實際項目中可以在 dataset 中保留原始句子)
? ? ? ? ? ? # 為簡化,這里假設我們直接使用 translate 函數(shù)處理
? ? ? ? ? ? # 實際項目中建議實現(xiàn)批量推理
? ? ? ? ? ? pass
 
? ? ? ? # 收集參考翻譯(去掉特殊標記)
? ? ? ? for seq in tgt.tolist():
? ? ? ? ? ? ref = [idx for idx in seq if idx not in special_indices]
? ? ? ? ? ? references.append([ref]) ?# corpus_bleu 要求每個參考是列表的列表
 
? ? # 計算 BLEU-4 分數(shù)
? ? bleu_score = corpus_bleu(references, hypotheses)
? ? return bleu_score
 
def main():
? ? device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
? ? # 加載模型和分詞器
? ? encoder, decoder, zh_tokenizer, en_tokenizer = load_models(device)
 
? ? # 加載測試數(shù)據(jù)
? ? test_loader = get_dataloader(train=False)
 
? ? # 評估
? ? bleu = evaluate(encoder, decoder, test_loader, zh_tokenizer, en_tokenizer, device)
? ? print(f"\n========== Evaluation Results ==========")
? ? print(f"BLEU-4 Score: {bleu:.4f} ({bleu*100:.2f}%)")
? ? print(f"========================================")
 
if __name__ == '__main__':
? ? main()

提示:BLEU 分數(shù)雖然被廣泛使用,但也有一些局限性。例如,它不善于評估語義相似性——如果模型用了不同的詞表達了相同的意思(“蘋果”vs“這種水果”),BLEU 可能給出低分,而實際上翻譯是對的。因此在做實際項目時,可以結合人工評估或其他指標一起使用。

輸出結果:

詞表加載成功!
模型加載成功!
評估結果:
BLEU 評分: ?0.18761972609375802

7. 傳統(tǒng) Seq2Seq 的局限性與進化之路

7.1 信息瓶頸問題

傳統(tǒng) Seq2Seq 模型的核心問題是信息瓶頸:編碼器必須把整個源句子壓縮成一個固定長度的上下文向量。對于短句子,這沒問題;但對于長句子,信息很容易在壓縮過程中丟失。

這就像讓你聽完一個 1000 字的演講后,只用一句話概括核心內(nèi)容。你可以概括大意,但細節(jié)一定會丟失。

7.2 注意力機制的引入

2015 年,Bahdanau 等人在 Seq2Seq 模型中引入了注意力機制(Attention Mechanism)?,解決了上述問題。注意力機制的核心思想是:允許解碼器在生成每個輸出詞時,動態(tài)地關注輸入序列的不同部分。

注意力機制的工作原理:

  1. 計算注意力分數(shù)

    :解碼器當前的隱藏狀態(tài) st 與編碼器在所有時間步的隱藏狀態(tài) h1,h2,...,hn 進行比較,計算相關性分數(shù) eti=a(st?1,hi)。

  2. 歸一化得到權重

    :使用 Softmax 將分數(shù)歸一化,得到一組權重 αti=exp?(eti)∑k=1nexp?(etk),權重之和為 1,代表了每個源詞對當前目標詞的重要性。

  3. 計算動態(tài)上下文向量

    :用權重對編碼器的所有隱藏狀態(tài)進行加權求和,得到新的動態(tài)上下文向量 ct=∑i=1nαtihi。

  4. 結合預測

    :將解碼器當前隱藏狀態(tài) st 和動態(tài)上下文向量 ct 結合起來,預測下一個詞。

注意力機制帶來約 10% 的 BLEU 分數(shù)提升,尤其在處理長句和復雜語法結構時效果顯著。

注意力機制有多種實現(xiàn)變體,最著名的是?Bahdanau 注意力(加性注意力,使用多層感知機計算分數(shù))和?Luong 注意力(乘性注意力,使用點積或雙線性變換)。兩者主要區(qū)別在于評分函數(shù)和對齊方式的不同,Luong 注意力在計算效率上更有優(yōu)勢。

7.3 從 Seq2Seq 到 Transformer

2017 年,Google 發(fā)表了著名的論文《Attention Is All You Need》,提出了?Transformer?架構。Transformer 在 Seq2Seq 的編碼器-解碼器框架基礎上,完全拋棄了 RNN,只使用自注意力機制(Self-Attention)來捕捉序列中的依賴關系。

Transformer 相比傳統(tǒng) Seq2Seq 的優(yōu)勢:

  • 并行計算

    :RNN 必須按順序處理序列,而 Transformer 可以同時處理整個序列,訓練速度大幅提升。

  • 長距離依賴

    :自注意力機制允許序列中的每個元素直接關注所有其他元素,不受距離限制,解決了 RNN 的梯度消失問題。

盡管 Transformer 已經(jīng)成為主流,但理解 Seq2Seq 的核心思想(編碼-解碼框架、Teacher Forcing、自回歸生成)仍然是掌握現(xiàn)代 NLP 生成模型的重要基礎。


8. 項目完整結構與運行指南

translation-seq2seq/
│
├── data/
│ ? ├── raw/
│ ? │ ? └── cmn.txt ? ? ? ? ? ? ?# 原始數(shù)據(jù)集(TSV 格式,英-中對照)
│ ? └── processed/ ? ? ? ? ? ? ? # 預處理后的文件
│ ? ? ? ├── train.jsonl ? ? ? ? ? # 訓練集(JSON Lines)
│ ? ? ? └── test.jsonl ? ? ? ? ? ?# 測試集(JSON Lines)
│
├── models/ ? ? ? ? ? ? ? ? ? ? ?# 保存訓練好的模型權重
│ ? ├── best_model.pt
│ ? ├── cn_vocab.txt
│ ? └── en_vocab.txt
│
├── logs/ ? ? ? ? ? ? ? ? ? ? ? ?# TensorBoard 日志
│
├── src/ ? ? ? ? ? ? ? ? ? ? ? ? # 源代碼
│ ? ├── config.py ? ? ? ? ? ? ? ?# 配置文件(所有超參數(shù))
│ ? ├── tokenizer.py ? ? ? ? ? ? # 中英文分詞器
│ ? ├── dataset.py ? ? ? ? ? ? ? # PyTorch Dataset
│ ? ├── process.py ? ? ? ? ? ? ? # 數(shù)據(jù)預處理腳本
│ ? ├── model.py ? ? ? ? ? ? ? ? # Encoder 和 Decoder 定義
│ ? ├── train.py ? ? ? ? ? ? ? ? # 訓練腳本
│ ? ├── predict.py ? ? ? ? ? ? ? # 交互式翻譯腳本
│ ? └── evaluate.py ? ? ? ? ? ? ?# BLEU 評估腳本

運行步驟

  1. 安裝依賴

  2. bash

  3. pip install torch pandas nltk scikit-learn tqdm tensorboard

  4. 準備數(shù)據(jù)

    :下載 cmn.txt 放到 data/raw/ 目錄下

  5. 數(shù)據(jù)預處理

  6. bash

  7. python src/process.py

  8. 訓練模型

  9. bash

  10. python src/train.py

  11. 交互翻譯

  12. bash

  13. python src/predict.py

  14. 評估模型

  15. bash

  16. python src/evaluate.py


9. 總結

經(jīng)過一整篇文章的學習,我們把 Seq2Seq 模型從原理到代碼完整過了一遍?,F(xiàn)在回顧一下核心知識點:

9.1 核心概念速查表

概念

一句話解釋

Seq2Seq

把一個序列映射成另一個序列的神經(jīng)網(wǎng)絡架構,輸入輸出長度可以不同

編碼器(Encoder)

用雙向 RNN 把輸入句子壓縮成語義向量

解碼器(Decoder)

用單向 RNN 從語義向量逐步生成目標句子

上下文向量

編碼器輸出的最后一個隱藏狀態(tài),代表整個輸入句子的語義

Teacher Forcing

訓練時用真實目標詞作為解碼器輸入,而不是模型自己的預測

自回歸生成

推理時用自己的上一步輸出作為下一步輸入

貪心解碼

每一步選概率最高的詞,簡單但可能不是全局最優(yōu)

BLEU

通過 n-gram 匹配度評估翻譯質(zhì)量的自動指標

注意力機制

允許解碼器在生成時動態(tài)關注源句的不同部分,解決信息瓶頸

Transformer

在 Seq2Seq 框架上用自注意力取代 RNN,實現(xiàn)并行計算

9.2 關于本文的代碼

本文提供的代碼是完整可運行的,你可以直接復制使用。需要注意幾點:

  • 確保所有文件放在正確的目錄結構下

  • 訓練可能需要一些時間(取決于硬件和 EPOCHS 設置)

  • 如果顯存不足,可以減小 BATCH_SIZE

完整代碼下載:https://pan.baidu.com/s/1-de-mxzjaMzwWSPcGbaZ4Q?pwd=ssah

9.3 進階學習建議

如果你已經(jīng)掌握了本文的內(nèi)容,可以嘗試以下擴展:

  1. 添加注意力機制

    :在解碼器中實現(xiàn) Bahdanau 或 Luong 注意力,觀察 BLEU 分數(shù)的提升。

  2. 換成 LSTM

    :用 LSTM 替換 GRU,對比兩種 RNN 變體的效果。

  3. 增加編碼器層數(shù)

    :設置 ENCODER_LAYERS = 2 或 3,觀察模型性能變化。

  4. 使用更大的數(shù)據(jù)集

    :嘗試 WMT 等大規(guī)模機器翻譯數(shù)據(jù)集,體驗真正的工業(yè)級翻譯。

  5. 學習 Transformer

    :基于本文的編碼器-解碼器框架理解,進一步學習 Transformer 的自注意力機制和位置編碼。

希望這篇文章能幫你真正理解 Seq2Seq。如果有任何問題或建議,歡迎交流討論!

?著作權歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

相關閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容