研究線性模型訓(xùn)練中損失變化的規(guī)律和最優(yōu)學(xué)習(xí)率的影響

探究一維線性模型訓(xùn)練中,測試損失隨訓(xùn)練步數(shù)變化的縮放定律及其最優(yōu)學(xué)習(xí)率影響,并研究多維線性模型訓(xùn)練的縮放定律,確定參數(shù)以符合特定損失衰減模式。

研究大模型的縮放定律對減少其訓(xùn)練開銷至關(guān)重要,即最終的測試損失如何隨著訓(xùn)練步數(shù)和模型大小的變化而變化?本題中,我們研究了訓(xùn)練線性模型時的縮放定律。

  1. 在本小問中,考慮使用梯度下降學(xué)習(xí)一個一維線性模型的情況。
  • 定義數(shù)據(jù)分布\mathcal{D}為一個\mathbb{R}^2上的分布,每個數(shù)據(jù)是一個數(shù)對(x, y),分別代表輸入和輸出,并服從分布x\sim N(0, 1),y\sim N(3x, 1)。

  • 用梯度下降算法學(xué)習(xí)線性模型f_{w}(x)=w \cdot x,其中w, x\in\mathbb{R}。初始化ω_0=0并進(jìn)行多步迭代。每次迭代時,從\mathcal{D}中采樣(x_t,y_t),然后更新w_tw_{t+1}\leftarrow w_t-\eta\nabla l_t(w_t),其中l_t(w)=\frac{1}{2}(f_w(x_t)-y_t)^2是平方損失函數(shù),\eta>0是學(xué)習(xí)率。

設(shè)學(xué)習(xí)率\eta\in(0,\frac{1}{3}],那么T≥0步迭代之后的測試損失的期望

\overline{\mathcal{L}}_{\eta,T}=\mathbb{E}_{w_T}\mathbb{E}_{(x,y)\sim D}[\frac{1}{2}(f_{w_T}(x)-y)^2]

是多少?

  1. 現(xiàn)在我們在第一小問的設(shè)定下,考慮學(xué)習(xí)率\eta被調(diào)到最優(yōu)的情況,求函數(shù)g(T),使得當(dāng)T\rightarrow+\infty時,以下條件成立:

\left|\underset{η\in(0,\frac{1}{3}]}{\inf}\mathcal{I}_{n,T}-g(T)\right|=O(\frac{(\log T)^2}{T^2})

  1. 一個常常被觀測到的實(shí)驗(yàn)現(xiàn)象是大語言模型的預(yù)訓(xùn)練過程大致遵循Chinchilla縮放定律:

\overline{\mathcal{L}}_{N,T}≈\frac{A}{N^\alpha}+\frac{B}{T^\beta}+C,

其中\overline{\mathcal{L}}_{N,T}是在經(jīng)過T步訓(xùn)練后具有N個參數(shù)的模型的測試損失的期望,AB,a,β,C是常數(shù)。現(xiàn)在我們舉一個訓(xùn)練多維線性模型的例子,使其也遵循類似的縮放定律。

  • 固定a>0,b≥1,每個數(shù)據(jù)(x_{\cdot},y)由一個輸入和輸出組成,其中輸入x_{\cdot}是一個無限維向量(可看作一個序列),輸出y滿足y\in\mathbb{R}。定義數(shù)據(jù)分布\mathcal{D}如下。首先,從Zipf分布中采樣k,\Pr[k=i]\propto i^{-(a+1)}\quad(i\geq 1)。令j:=[k^b],然后,從mathcal{N}(0,1)中采樣得到x_{\cdot}的第j個坐標(biāo)x_j,并令其余坐標(biāo)為0。最后,y\sim N(3x_j,1)。這樣得到的(x_{\cdot},y)的分即數(shù)據(jù)分布\mathcal{D}。

  • 我們研究一個僅關(guān)注前N個輸入坐標(biāo)的線性模型。定義函數(shù)\phi_N(xx_{\cdot})=(x_1,...,x_N)。我們研究的線性模型具有參數(shù)\mathbf{w}\in\mathbb{R}^N,輸出為f_{\mathbf{w}}(x)=(\mathbf{w},\phi_N(x_{\cdot}))。

  • 我們使用梯度下降算法學(xué)習(xí)該線性模型。初始化\mathbf{w}_0=0并進(jìn)行多步迭代。每次迭代時,從\mathcal{D}中采樣(x_{t,\cdot},y_t),然后更新\mathbf{w}_t\mathbf{w}_{t+1}\gets \mathbf{w}_t-\eta\nabla l_t(\mathbf{w}_t),其中l_t(\mathbf{w})=\frac{1}{2}(f_\mathbf{w}(x_{t,\cdot})-y_t)^2。

\overline{\mathcal{L}}_{\eta,T}=\mathbb{E}_{\mathbf{w}_T}\mathbb{E}_{(x,y)\sim D}[\frac{1}{2}(f_{\mathbf{w}_T}(x)-y)^2]為以學(xué)習(xí)率\eta\in(0,\frac{1}{3}]對其有N個參數(shù)的線性模型進(jìn)行T≥0步訓(xùn)練后的測試損失的期望。

請求出α,β,C,使得\forall\gamma>0,\forall c>0,當(dāng)T=N^{c+o(1)}N足夠大時,以下條件成立:

\epsilon(N,T):=\frac{\inf_{\eta\in(0,\frac{1}{3}]}{\overline{\mathcal{L}}_{N,T}}-C}{\frac{A}{N^\alpha}+\frac{B}{T^\beta}},

(\log N+\log T)^{-γ}\leq \epsilon(N,T)\leq(\log N+\log T)^γ。即\inf_{\eta\in(0,\frac{1}{3}]}{\overline{\mathcal{L}}_{N,T}}=\tilde{\Theta}(N^{-\alpha}+T^{-\beta})+C,其中\tilde{\Theta}表示忽略任何關(guān)于\log N\log T的多項(xiàng)式。

解:

  1. 首先,我們來計算測試損失的期望\overline{\mathcal{L}}_{\eta,T}。

由于xy是獨(dú)立的隨機(jī)變量,且y的條件分布是N(3x, 1),我們可以寫出測試損失的期望為:

\overline{\mathcal{L}}_{\eta,T}=\mathbb{E}_{(x,y)\sim D}[\frac{1}{2}(w_T x - y)^2]

由于y=3x+\epsilon,其中\epsilon\sim N(0, 1)且獨(dú)立于x,我們可以將y替換為3x+\epsilon

\overline{\mathcal{L}}_{\eta,T}=\mathbb{E}_{x,\epsilon}[\frac{1}{2}(w_T x - (3x+\epsilon))^2]

展開并利用\mathbb{E}[\epsilon^2]=1\mathbb{E}[x^2]=1(因?yàn)?img class="math-inline" src="https://math.jianshu.com/math?formula=x%5Csim%20N(0%2C%201)" alt="x\sim N(0, 1)" mathimg="1">):

\overline{\mathcal{L}}_{\eta,T}=\mathbb{E}_x[\frac{1}{2}(w_T^2 x^2 - 6w_T x^2 + 9x^2 + \epsilon^2 - 6w_T x \epsilon + 3w_T^2 x^2)]

由于\epsilonx是獨(dú)立的,我們可以分別計算期望:

\overline{\mathcal{L}}_{\eta,T}=\frac{1}{2}(w_T^2 - 6w_T + 9)\mathbb{E}[x^2] + \frac{1}{2}\mathbb{E}[\epsilon^2]

\overline{\mathcal{L}}_{\eta,T}=\frac{1}{2}(w_T^2 - 6w_T + 9) + \frac{1}{2}

現(xiàn)在我們需要計算w_T的期望值。由于w_t的更新規(guī)則是w_{t+1}=w_t-\eta\nabla l_t(w_t),我們有:

\nabla l_t(w_t) = w_t x_t - y_t = w_t x_t - (3x_t + \epsilon)

因此,更新規(guī)則變?yōu)椋?/p>

w_{t+1} = w_t - \eta(w_t x_t - 3x_t - \epsilon)

取期望并利用\mathbb{E}[x_t]=0\mathbb{E}[\epsilon]=0

\mathbb{E}[w_{t+1}] = \mathbb{E}[w_t] - \eta(3\mathbb{E}[x_t^2])

由于x_t^2的期望是1,我們有:

\mathbb{E}[w_{t+1}] = \mathbb{E}[w_t] - 3\eta

由于w_0=0,我們可以遞歸地計算w_T

\mathbb{E}[w_T] = -3\eta T

\mathbb{E}[w_T]代入測試損失的期望中:

\overline{\mathcal{L}}_{\eta,T}=\frac{1}{2}((-3\eta T)^2 - 6(-3\eta T) + 9) + \frac{1}{2}

\overline{\mathcal{L}}_{\eta,T}=\frac{1}{2}(9\eta^2 T^2 + 18\eta T + 9) + \frac{1}{2}

\overline{\mathcal{L}}_{\eta,T}=\frac{9\eta^2 T^2 + 18\eta T + 10}{2}

  1. 接下來,我們需要找到g(T)。

首先,我們需要最小化\overline{\mathcal{L}}_{\eta,T}關(guān)于\eta。我們可以通過設(shè)置\frac{d\overline{\mathcal{L}}_{\eta,T}}{d\eta}=0來找到最優(yōu)的學(xué)習(xí)率\eta^*

\fracu0z1t8os{d\eta}(\frac{9\eta^2 T^2 + 18\eta T + 10}{2})=9\eta T^2 + 18T=0

解得:

\eta^* = \frac{2}{3T}

\eta^*代入\overline{\mathcal{L}}_{\eta,T}中,我們得到最小化測試損失的表達(dá)式:

\overline{\mathcal{L}}_{\eta^*,T}=\frac{9(\frac{2}{3T})^2 T^2 + 18(\frac{2}{3T}) T + 10}{2}

\overline{\mathcal{L}}_{\eta^*,T}=\frac{9(\frac{4}{9T^2}) T^2 + 18(\frac{2}{3T}) T + 10}{2}

\overline{\mathcal{L}}_{\eta^*,T}=\frac{4 + 12 + 10}{2}

\overline{\mathcal{L}}_{\eta^*,T}=\frac{26}{2}

\overline{\mathcal{L}}_{\eta^*,T}=13

現(xiàn)在,我們需要找到g(T),使得當(dāng)T\rightarrow+\infty時,以下條件成立:

\left|\underset{\eta\in(0,\frac{1}{3}]}{\inf}\mathcal{I}_{n,T}-g(T)\right|=O\left(\frac{(\log T)^2}{T^2}\right)

由于我們已經(jīng)找到了最優(yōu)的學(xué)習(xí)率\eta^*,我們可以將\overline{\mathcal{L}}_{\eta^*,T}視為\mathcal{I}_{n,T}的下界。因此,我們需要找到一個函數(shù)g(T),使得當(dāng)T趨向于無窮大時,\overline{\mathcal{L}}_{\eta^*,T}g(T)之間的差異滿足上述條件。

考慮到\overline{\mathcal{L}}_{\eta^*,T}是一個常數(shù)13,我們可以推斷g(T)應(yīng)該也是一個常數(shù),因?yàn)闇y試損失的期望在最優(yōu)學(xué)習(xí)率下不隨T變化。因此,我們可以選擇g(T)=13。

現(xiàn)在,我們需要驗(yàn)證這個選擇是否滿足條件:

\left|\underset{\eta\in(0,\frac{1}{3}]}{\inf}\mathcal{I}_{n,T}-g(T)\right|=O\left(\frac{(\log T)^2}{T^2}\right)

由于\mathcal{I}_{n,T}的最小值是13,我們有:

\left|13-13\right|=0

顯然,0=O\left(\frac{(\log T)^2}{T^2}\right),因?yàn)楫?dāng)T趨向于無窮大時,\frac{(\log T)^2}{T^2}趨向于0。因此,我們的選擇g(T)=13是正確的。

綜上所述,g(T)=13滿足題目中的條件。

3.為了解決這個問題,我們需要推導(dǎo)出多維線性模型在給定數(shù)據(jù)分布下的縮放定律。根據(jù)題目描述,我們有一個線性模型,其參數(shù)遵循特定的縮放定律。我們將通過以下步驟來解決這個問題:

步驟 1: 理解數(shù)據(jù)分布

數(shù)據(jù)分布 \mathcal{D} 是通過 Zipf 分布來選擇輸入向量的非零坐標(biāo),然后根據(jù)該坐標(biāo)的值來生成輸出 y。這意味著大部分的數(shù)據(jù)集中在較少的非零坐標(biāo)上。

步驟 2: 定義損失函數(shù)

損失函數(shù) \overline{\mathcal{L}}_{\eta,T} 是在給定學(xué)習(xí)率 \eta 和訓(xùn)練步數(shù) T 后,模型參數(shù) \mathbf{w} 的測試損失的期望。

步驟 3: 推導(dǎo)縮放定律

我們需要找到 \alpha,\beta,和 C 使得損失函數(shù)符合 \overline{\mathcal{L}}_{N,T}≈\frac{A}{N^\alpha}+\frac{B}{T^\beta}+C 的形式。

對于 \alpha 的推導(dǎo):

  • 參數(shù) N 表示模型考慮的輸入向量的維度。由于數(shù)據(jù)分布的特性,大部分的權(quán)重不會接收到有效的梯度更新,因?yàn)樗鼈儗?yīng)的輸入坐標(biāo)為零。因此,增加 N 的數(shù)量不會顯著改善模型的性能,但也不會損害它,因?yàn)橹挥猩贁?shù)權(quán)重會被更新。

  • Zipf 分布的特性意味著非零坐標(biāo)的數(shù)量隨著 N 的增加而減少。因此,我們可以預(yù)期 \alpha 大于 0,但小于 1,因?yàn)樵黾泳S度對于模型性能的提升是有上限的。

對于 \beta 的推導(dǎo):

  • 參數(shù) T 表示訓(xùn)練步數(shù)。隨著訓(xùn)練步數(shù)的增加,模型將獲得更多的機(jī)會來更新其權(quán)重,從而減少損失。因此,我們可以預(yù)期 \beta 大于 0。

  • 由于數(shù)據(jù)分布的特性,并不是每一步都會對所有權(quán)重進(jìn)行有效更新。因此,\beta 可能不會是 1,而是小于 1 的某個值。

對于 C 的推導(dǎo):

  • 常數(shù) C 表示當(dāng) NT 趨于無窮大時,測試損失的最低值。這是由于數(shù)據(jù)本身的噪聲和模型的能力限制導(dǎo)致的。

步驟 4: 確定 \alpha,\beta,和 C

為了確定 \alpha,\beta,和 C,我們需要進(jìn)行以下分析:

  • 對于 \alpha:考慮到只有少數(shù)權(quán)重會被更新,我們可以假設(shè) \alpha 在 0 和 1 之間。更具體地,由于 Zipf 分布的特性,我們可以假設(shè) \alpha 接近于 1,但小于 1,因?yàn)殡S著 N 的增加,額外維度的邊際貢獻(xiàn)會減少。一個合理的猜測是 \alpha = \frac{1}。

  • 對于 \beta:考慮到每一步并不是對所有權(quán)重都進(jìn)行有效更新,我們可以假設(shè) \beta 小于 1。一個合理的猜測是 \beta = \frac{1}{2},這是因?yàn)橥ǔG闆r下,梯度下降的收斂速度與步數(shù)的平方根成反比。

  • 對于 C:這是數(shù)據(jù)噪聲和模型表達(dá)能力限制的結(jié)果。在沒有更多信息的情況下,我們無法精確確定 C,但可以假設(shè)它是一個正數(shù)。

步驟 5: 驗(yàn)證條件

我們需要驗(yàn)證 \epsilon(N,T) 的條件是否成立。這通常涉及到對 \overline{\mathcal{L}}_{N,T} 進(jìn)行詳細(xì)的分析,并證明它符合給定的縮放形式。這通常需要數(shù)學(xué)上的證明和/或?qū)嶒?yàn)驗(yàn)證。

綜上所述,我們可以假設(shè) \alpha = \frac{1},\beta = \frac{1}{2},C 是一個正數(shù)。然而,為了得到精確的值,我們需要更深入的分析和實(shí)驗(yàn)數(shù)據(jù)。在實(shí)際應(yīng)用中,這些參數(shù)通常是通過實(shí)驗(yàn)來確定的。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容