筆記:神經(jīng)網(wǎng)絡(luò)中的“訓(xùn)練”+PyTorch構(gòu)建簡(jiǎn)單的二分類模型

一、訓(xùn)練(Training)

流程概覽:數(shù)據(jù)準(zhǔn)備 → 前向傳播 → 損失計(jì)算 → 反向傳播 → 優(yōu)化器更新參數(shù) → 下一批次數(shù)據(jù) → 循環(huán)多次直到模型收斂或達(dá)到最大 Epoch

1. 定義

訓(xùn)練是通過數(shù)據(jù)不斷調(diào)整模型參數(shù),使模型能夠更準(zhǔn)確地完成任務(wù)(分類、回歸等)。

2. 核心流程

  1. 數(shù)據(jù)準(zhǔn)備

    • 數(shù)據(jù)劃分:訓(xùn)練集(Train)、驗(yàn)證集(Validation)、測(cè)試集(Test) 課本-作業(yè)-考試
    • 數(shù)據(jù)處理:歸一化、標(biāo)準(zhǔn)化、數(shù)據(jù)增強(qiáng)
  2. 前向傳播(Forward Propagation)

    • 輸入數(shù)據(jù)經(jīng)過網(wǎng)絡(luò)層計(jì)算,得到模型預(yù)測(cè)值 \hat{y} 做練習(xí)題先寫答案
    • 激活函數(shù)提供非線性能力
  3. 計(jì)算損失(Loss Calculation)

    • 使用損失函數(shù)衡量模型預(yù)測(cè)值 \hat{y} 與真實(shí)標(biāo)簽 y 之間的誤差 查看標(biāo)準(zhǔn)答案,看看錯(cuò)多少
  4. 反向傳播(Backpropagation)

    • 模型根據(jù)損失調(diào)整內(nèi)部參數(shù)。反思
  5. 優(yōu)化器更新參數(shù)(Optimizer Update)

    • 使用優(yōu)化器調(diào)整權(quán)重,使損失 L(\hat{y}, y) 最小化 優(yōu)化器是“學(xué)習(xí)方法”,解析里的思路
  6. 迭代訓(xùn)練 (Iteration)

數(shù)據(jù)分批(batch)處理,每批數(shù)據(jù)重復(fù)上面過程。每批數(shù)據(jù)做一次前向傳播 → 計(jì)算損失 → 反向傳播 → 參數(shù)更新(Iteration
Batch → 數(shù)據(jù)塊
Iteration → 每塊數(shù)據(jù)訓(xùn)練一次、更新一次參數(shù)
Epoch → 完整訓(xùn)練整個(gè)數(shù)據(jù)集一次

  • Epoch :訓(xùn)練的次數(shù)。所有批次(完整數(shù)據(jù)集)都訓(xùn)練完一次 = 1 Epoch
  • 多次 Epoch 直到損失收斂或達(dá)到預(yù)設(shè)條件
  • 批次太大或太小都會(huì)影響學(xué)習(xí)效果。

二、損失函數(shù)(Loss Function)

1. 定義

損失函數(shù)衡量模型預(yù)測(cè)值 \hat{y} 與真實(shí)標(biāo)簽 y 之間的誤差,訓(xùn)練目標(biāo)是最小化損失。

2. 作用

  • 指導(dǎo)模型學(xué)習(xí)方向 : 告訴模型哪部分預(yù)測(cè)錯(cuò)誤,應(yīng)該如何調(diào)整參數(shù)
  • 衡量模型性能 :損失越小,模型預(yù)測(cè)越準(zhǔn)確
  • 定義優(yōu)化目標(biāo) :訓(xùn)練過程中不斷最小化損失

3. 常見類型

(1)回歸任務(wù)

損失函數(shù) 公式 特點(diǎn) 通俗理解
均方誤差(MSE) L(\hat{y}, y) = \frac{1}{n} \sum_{i=1}^{n} (\hat{y}_i - y_i)^2 對(duì)異常值敏感 把每個(gè)預(yù)測(cè)誤差平方,懲罰大的錯(cuò)誤更多
平均絕對(duì)誤差(MAE) L(\hat{y}, y) = \frac{1}{n} \sum_{i=1}^{n} \lvert \hat{y}_i - y_i \rvert 穩(wěn)健,不受異常值影響 計(jì)算每個(gè)預(yù)測(cè)誤差的絕對(duì)值,忽略方向
Huber Loss L_\delta(\hat{y}, y) = \begin{cases} \frac{1}{2} (\hat{y}-y)^2, & \lvert \hat{y}-y \rvert \le \delta \\ \delta \lvert \hat{y}-y \rvert - \frac{1}{2} \delta^2, & \lvert \hat{y}-y \rvert > \delta \end{cases} 綜合 MSE 和 MAE 優(yōu)點(diǎn) 對(duì)小錯(cuò)誤敏感,對(duì)大錯(cuò)誤不過分懲罰

(2)分類任務(wù)

損失函數(shù) 公式 特點(diǎn) 通俗理解
交叉熵(Cross Entropy) L(\hat{y}, y) = -\sum_{i} y_i \log(\hat{y}_i) 最常用分類損失 判斷預(yù)測(cè)概率與真實(shí)類別差異,概率越接近真實(shí)值損失越小
KL散度(KL Divergence) D_{KL}(P|Q) = \sum_i P(i) \log \frac{P(i)}{Q(i)} 測(cè)概率分布差異 衡量?jī)蓚€(gè)概率分布的差異
Focal Loss 對(duì)難分類樣本加權(quán) 適合類別不平衡 對(duì)模型容易錯(cuò)的樣本給更高權(quán)重,讓模型重點(diǎn)學(xué)習(xí)難點(diǎn)

(3)特殊應(yīng)用

場(chǎng)景 常用損失函數(shù) 通俗理解
序列生成 交叉熵、CTC Loss 處理時(shí)間序列或文本預(yù)測(cè)任務(wù)
圖像生成(GAN) 對(duì)抗損失(Adversarial Loss) 模型學(xué)會(huì)生成逼真的圖像
自編碼器 重構(gòu)誤差(Reconstruction Loss) 模型學(xué)習(xí)如何壓縮和還原數(shù)據(jù)
表征學(xué)習(xí) 對(duì)比損失(Contrastive Loss)、Triplet Loss 學(xué)習(xí)數(shù)據(jù)的向量表示,讓相似樣本更接近

4. 梯度(Gradient)

  • 梯度是損失函數(shù)對(duì)模型參數(shù)的導(dǎo)數(shù),表示參數(shù)改變時(shí)損失函數(shù)變化的方向和大小。

4.1 通俗理解

梯度就像地圖上的坡度,告訴你“往哪兒走能下坡最快”,梯度指向下坡最快的方向 ,也就是讓損失函數(shù)變小的方向。梯度為 0 → 到達(dá)山谷底(最優(yōu)點(diǎn))

  • 指導(dǎo)優(yōu)化器更新參數(shù) :梯度大 → 更新步伐大 ; 梯度小 → 更新步伐小

4.2 公式

假設(shè)損失函數(shù)為 L(W),參數(shù)為 W,梯度為:
\nabla_W L(W) = \frac{\partial L(W)}{\partial W}


5. 使用小技巧

  • 任務(wù)匹配:回歸用 MSE/MAE,分類用交叉熵
  • 類別不平衡:用加權(quán)交叉熵或 Focal Loss
  • 多任務(wù)學(xué)習(xí):組合多種損失函數(shù),并調(diào)節(jié)權(quán)重
  • 梯度穩(wěn)定:避免梯度過大或過小,可用 Log 變換、梯度裁剪或平滑處理

三、優(yōu)化器(Optimizer)

1. 定義

優(yōu)化器控制模型參數(shù)的更新方式,通過梯度下降最小化損失函數(shù) L(\hat{y}, y)。


2. 常用優(yōu)化器

優(yōu)化器 更新公式(簡(jiǎn)化) 特點(diǎn)
SGD(隨機(jī)梯度下降) W \gets W - \eta \frac{\partial L(\hat{y}, y)}{\partial W} 簡(jiǎn)單,收斂慢,易陷入局部最優(yōu)
Momentum(動(dòng)量法) v = \beta v + \eta \frac{\partial L}{\partial W}, \quad W \gets W - v 利用慣性加速收斂,減少震蕩
AdaGrad W \gets W - \frac{\eta}{\sqrt{G+\epsilon}} \frac{\partial L}{\partial W} 自適應(yīng)學(xué)習(xí)率,頻繁更新參數(shù)較小,適合稀疏數(shù)據(jù)
RMSProp E[g^2]_t = \rho E[g^2]_{t-1} + (1-\rho) g_t^2, \quad W \gets W - \frac{\eta}{\sqrt{E[g^2]_t+\epsilon}} g_t 解決 AdaGrad 學(xué)習(xí)率下降過快問題,深度學(xué)習(xí)常用
Adam 結(jié)合 Momentum 和 RMSProp 高效,適合大多數(shù)場(chǎng)景,默認(rèn)選擇

3. 超參數(shù)關(guān)鍵點(diǎn)

  • 學(xué)習(xí)率(Learning Rate, \eta

    • 控制參數(shù)每次更新的步長(zhǎng)
    • 太大 → 參數(shù)跳過最優(yōu)點(diǎn),訓(xùn)練可能震蕩或發(fā)散
    • 太小 → 收斂慢,訓(xùn)練時(shí)間長(zhǎng)
  • Batch Size(批大小)

    • 每次訓(xùn)練使用的樣本數(shù)量
    • 小 batch → 梯度噪聲大,更新不穩(wěn)定,但可節(jié)省顯存
    • 大 batch → 梯度穩(wěn)定,訓(xùn)練更平滑,但顯存占用高
  • Epoch(訓(xùn)練輪數(shù))

    • 模型完整看一遍訓(xùn)練集的次數(shù)
    • Too few → 欠擬合,模型沒學(xué)夠
    • Too many → 過擬合,模型記住訓(xùn)練集過多細(xì)節(jié),泛化能力差

超參數(shù)是訓(xùn)練性能的關(guān)鍵調(diào)節(jié)器。通常結(jié)合 驗(yàn)證集損失曲線早停法(Early Stopping) 來選擇合適的學(xué)習(xí)率、Batch Size 和 Epoch。


四、訓(xùn)練中的常見問題

問題 原因 表現(xiàn) 解決方法
過擬合 模型容量過大、訓(xùn)練數(shù)據(jù)不足 訓(xùn)練集損失低,驗(yàn)證集損失高 數(shù)據(jù)增強(qiáng)、正則化、Dropout
欠擬合 模型容量不足、訓(xùn)練不夠或特征不夠 訓(xùn)練集和驗(yàn)證集損失都高 增加模型容量、更多數(shù)據(jù)、訓(xùn)練更久
梯度消失 網(wǎng)絡(luò)太深、激活函數(shù)飽和、權(quán)重初始化不合理 參數(shù)幾乎不更新,訓(xùn)練停滯 ReLU 激活、殘差連接、BatchNorm、梯度裁剪、權(quán)重初始化
梯度爆炸 網(wǎng)絡(luò)太深、權(quán)重過大 參數(shù)更新過大,訓(xùn)練發(fā)散 梯度裁剪、權(quán)重初始化、調(diào)整網(wǎng)絡(luò)結(jié)構(gòu)
學(xué)習(xí)率不合適 學(xué)習(xí)率設(shè)置太大或太小 收斂慢或訓(xùn)練發(fā)散 調(diào)整學(xué)習(xí)率,使用學(xué)習(xí)率調(diào)度器(Scheduler)

五、訓(xùn)練策略與技巧-簡(jiǎn)介

1. 學(xué)習(xí)率調(diào)度(Learning Rate Scheduling)

控制訓(xùn)練過程中參數(shù)更新步長(zhǎng)的變化:

  • 固定學(xué)習(xí)率:保持不變,簡(jiǎn)單但可能收斂慢
  • 指數(shù)衰減(Exponential Decay):學(xué)習(xí)率隨訓(xùn)練輪數(shù)逐步減小
  • 余弦退火(Cosine Annealing):學(xué)習(xí)率按余弦曲線下降,提高訓(xùn)練穩(wěn)定性
  • 自適應(yīng)學(xué)習(xí)率:如 Adam、AdaGrad 等自動(dòng)調(diào)整每個(gè)參數(shù)的步長(zhǎng)

2. 正則化(Regularization)

防止模型過擬合,提高泛化能力:

  • L1/L2 正則化:對(duì)參數(shù)加約束,防止過大
  • Dropout:隨機(jī)屏蔽一部分神經(jīng)元,增強(qiáng)模型魯棒性
  • 數(shù)據(jù)增強(qiáng)(Data Augmentation):擴(kuò)充訓(xùn)練數(shù)據(jù),增加樣本多樣性

3. 梯度問題處理(Gradient Issues)

  • 梯度消失:使用 ReLU 激活、殘差網(wǎng)絡(luò)(ResNet)、BatchNorm
  • 梯度爆炸:使用梯度裁剪(Gradient Clipping)

4. 早停法(Early Stopping)

  • 當(dāng)驗(yàn)證集損失連續(xù)若干輪不下降時(shí),停止訓(xùn)練
  • 避免過擬合,節(jié)省訓(xùn)練時(shí)間

六、訓(xùn)練監(jiān)控與評(píng)估

  • 損失曲線:訓(xùn)練損失和驗(yàn)證損失變化,判斷過擬合或欠擬合
  • 指標(biāo)監(jiān)控:分類(Accuracy、F1、AUC)、回歸(MSE、MAE、R2)
  • 模型保存與加載:Checkpoints,便于斷點(diǎn)訓(xùn)練或推理使用

"""
使用 PyTorch 構(gòu)建并訓(xùn)練一個(gè)簡(jiǎn)單的二分類神經(jīng)網(wǎng)絡(luò)

- 網(wǎng)絡(luò)結(jié)構(gòu)包括輸入層、隱藏層和輸出層,使用了 ReLU 激活函數(shù)和 Sigmoid 激活函數(shù)。
- 采用了均方誤差損失函數(shù)和隨機(jī)梯度下降優(yōu)化器。
- 訓(xùn)練過程是通過前向傳播、計(jì)算損失、反向傳播和參數(shù)更新來逐步調(diào)整模型參數(shù)。

代碼邏輯:
1. 定義超參數(shù):
   - n_in: 輸入特征維度 = 10
   - n_h: 隱藏層神經(jīng)元數(shù) = 5
   - n_out: 輸出層大小 = 1(二分類結(jié)果)
   - batch_size: 每批數(shù)據(jù)樣本數(shù) = 10

2. 構(gòu)造數(shù)據(jù):
   - 輸入 x: 隨機(jī)生成 (10,10) 張量,表示 10 個(gè)樣本,每個(gè)樣本有 10 個(gè)特征
   - 標(biāo)簽 y: (10,1) 張量,取值為 0 或 1

3. 定義模型:
   - 結(jié)構(gòu):Linear(10→5) → ReLU → Linear(5→1) → Sigmoid
   - Sigmoid 將輸出壓縮到 (0,1),適合二分類

4. 定義損失函數(shù)與優(yōu)化器:
   - 損失函數(shù):MSELoss(均方誤差,預(yù)測(cè)值與真實(shí)值的平方誤差平均)
   - 優(yōu)化器:SGD,學(xué)習(xí)率 0.01

5. 訓(xùn)練過程:
   - 循環(huán) 50 個(gè) epoch
   - 前向傳播:計(jì)算預(yù)測(cè)值 y_pred
   - 計(jì)算損失:loss = MSE(y_pred, y)
   - 反向傳播:loss.backward()
   - 參數(shù)更新:optimizer.step()
   - 打印每個(gè) epoch 的損失值

"""

# 導(dǎo)入 PyTorch 庫(kù)
import torch
import torch.nn as nn
import matplotlib.pyplot as plt  # 可視化損失曲線

# ==========================
# 1. 定義超參數(shù)
# ==========================
n_in, n_h, n_out, batch_size = 10, 5, 1, 10
learning_rate = 0.01
num_epochs = 50

# ==========================
# 2. 構(gòu)造輸入數(shù)據(jù)和目標(biāo)數(shù)據(jù)
# ==========================
x = torch.randn(batch_size, n_in)  # 隨機(jī)生成輸入數(shù)據(jù)
y = torch.tensor([
    [1.0], [0.0], [0.0], 
    [1.0], [1.0], [1.0], 
    [0.0], [0.0], [1.0], [1.0]
])  # 構(gòu)造標(biāo)簽數(shù)據(jù)

# ==========================
# 3. 定義模型結(jié)構(gòu)
# ==========================
model = nn.Sequential(
   nn.Linear(n_in, n_h),  # 輸入層 → 隱藏層
   nn.ReLU(),             # 激活函數(shù)
   nn.Linear(n_h, n_out), # 隱藏層 → 輸出層
   nn.Sigmoid()           # 輸出壓縮到 (0,1),適合二分類
)

# ==========================
# 4. 定義損失函數(shù)和優(yōu)化器
# ==========================
criterion = nn.MSELoss()  
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# ==========================
# 5. 模型訓(xùn)練循環(huán)
# ==========================
losses = []  # 保存每輪 loss

for epoch in range(num_epochs):
    y_pred = model(x)            # 前向傳播
    loss = criterion(y_pred, y)  # 計(jì)算損失
    losses.append(loss.item())   # 保存損失

    optimizer.zero_grad()        # 清零梯度
    loss.backward()              # 反向傳播
    optimizer.step()             # 參數(shù)更新

    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.4f}')

# --------------------------
# 可視化訓(xùn)練損失曲線
# --------------------------
plt.figure(figsize=(8, 5))
plt.plot(range(1, num_epochs+1), losses, marker='o', label='Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss Curve')
plt.grid(True)
plt.legend()
plt.show()


# --------------------------
# 可視化預(yù)測(cè)結(jié)果與實(shí)際目標(biāo)值對(duì)比
# --------------------------
y_pred_final = model(x).detach().numpy()  # 將模型最終預(yù)測(cè)值從 tensor 轉(zhuǎn)為 numpy 數(shù)組
y_actual = y.numpy()                      # 將真實(shí)標(biāo)簽 tensor 轉(zhuǎn)為 numpy 數(shù)組

plt.figure(figsize=(8, 5))
# 繪制實(shí)際值
plt.plot(range(1, batch_size + 1), y_actual, 'o-', label='Actual', color='blue')
# 繪制預(yù)測(cè)值
plt.plot(range(1, batch_size + 1), y_pred_final, 'x--', label='Predicted', color='red')

plt.xlabel('Sample Index')  # 橫軸:樣本編號(hào)
plt.ylabel('Value')         # 縱軸:值(0~1)
plt.title('Actual vs Predicted Values')  # 圖標(biāo)題
plt.legend()  # 顯示圖例
plt.grid()    # 添加網(wǎng)格線
plt.show()
1

2
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

  • """1.個(gè)性化消息: 將用戶的姓名存到一個(gè)變量中,并向該用戶顯示一條消息。顯示的消息應(yīng)非常簡(jiǎn)單,如“Hello ...
    她即我命閱讀 5,714評(píng)論 0 6
  • 為了讓我有一個(gè)更快速、更精彩、更輝煌的成長(zhǎng),我將開始這段刻骨銘心的自我蛻變之旅!從今天開始,我將每天堅(jiān)持閱...
    李薇帆閱讀 2,274評(píng)論 1 4
  • 似乎最近一直都在路上,每次出來走的時(shí)候感受都會(huì)很不一樣。 1、感恩一直遇到好心人,很幸運(yùn)。在路上總是...
    時(shí)間里的花Lily閱讀 1,778評(píng)論 1 3
  • 1、expected an indented block 冒號(hào)后面是要寫上一定的內(nèi)容的(新手容易遺忘這一點(diǎn)); 縮...
    庵下桃花仙閱讀 1,145評(píng)論 1 2
  • 一、工具箱(多種工具共用一個(gè)快捷鍵的可同時(shí)按【Shift】加此快捷鍵選取)矩形、橢圓選框工具 【M】移動(dòng)工具 【V...
    墨雅丫閱讀 1,749評(píng)論 0 0

友情鏈接更多精彩內(nèi)容