論文閱讀筆記:(ICLR 2023)EMERGENT WORLD REPRESENTATIONS: EXPLORING A SEQUENCE MODEL TRAINED ON A SYNTHET...

論文閱讀筆記:(ICLR 2023)EMERGENT WORLD REPRESENTATIONS: EXPLORING A SEQUENCE MODEL TRAINED ON A SYNTHETIC TASK

ps:world representation(世界表現(xiàn))是游戲AI中常見的一種術(shù)語,它表示將實(shí)際的地圖和路線表示為節(jié)點(diǎn)&連線構(gòu)造的圖結(jié)構(gòu)。

世界模型:

世界模型的概念(https://zhuanlan.zhihu.com/p/661965660)

摘要:自己的理解:語言模型能力強(qiáng)大但能力來源不明晰->在綜合環(huán)境(黑白棋,博弈場景)下研究->在沒有先驗(yàn)知識前提下找到了棋盤的內(nèi)部表示證據(jù)->反過來用內(nèi)部表示引入來控制輸出。

Part1:Introduction

序列建模任務(wù)(sequence modeling task)為什么可以通過“next token”來得到一些驚人的結(jié)果?關(guān)于這個問題,有很多爭執(zhí):

從哲學(xué)到數(shù)學(xué)方面的爭執(zhí)范圍增加下,產(chǎn)生了對模型性能的疑問:序列模型是否只是“記住了”表面數(shù)據(jù)?而不是去理解其中的因果關(guān)系?

但是又有線索表明,LM能夠構(gòu)建屬可解釋的世界模型。然而,相關(guān)研究并沒有去探索內(nèi)部模型的解釋,因此產(chǎn)生motivation。

本文貢獻(xiàn):

(1)我們?yōu)?GPT 變體中的新興世界模型提供了證據(jù),該模型經(jīng)過訓(xùn)練可以在奧賽羅中產(chǎn)生合法的棋步;?

(2)我們比較了線性和非線性探測方法的性能,發(fā)現(xiàn)非線性探測在這方面更優(yōu)越;?

(3)我們提出了一種干預(yù)技術(shù),表明在某些情況下,可以使用涌現(xiàn)世界模型來控制網(wǎng)絡(luò)的行為;?

(4)我們展示了如何使用探針來生成潛在顯著性圖,以闡明模型的預(yù)測。

Part2 訓(xùn)練一個游戲LM

一、準(zhǔn)備

①游戲選擇:黑白棋(游戲樹夠大避免單純記憶、游戲規(guī)則簡單)

②策略:不提供游戲及棋盤結(jié)構(gòu)的先驗(yàn)知識,僅允許GPT變體觀察游戲記錄。

③數(shù)據(jù)集:

? ? a.Championship:從兩個渠道收集“奧賽羅(Othello)錦標(biāo)賽”比賽數(shù)據(jù)(7,605+132, 921 ),按照8:2分訓(xùn)練集和測試集?!坝胁呗缘?、雙方為了獲勝的”

? ? b.Synthetic:從游戲樹的葉子節(jié)點(diǎn)進(jìn)行均勻采樣的結(jié)果,得到20,000,000(訓(xùn)練集)+3,796,010(測試集)?!盁o策略的、約等于亂下的”

④棋盤建模:用60個單詞來表示除了中間四個以外的格子。

⑤訓(xùn)練模型底座:8-layer GPT model+512維隱藏空間 =>以得到 Othello GPT


二、訓(xùn)練過程:利用了因果掩碼

S1:對于每一個部分游戲(partial game){y_t}_{t=0}^{T-1} 。首先輸入由60個向量組成的可訓(xùn)練詞嵌入(隨機(jī)初始化權(quán)重),每個向量對應(yīng)一個詞,得到初始化特征{x_{t}^0 }_{t=0}^{T-1}。

S2:在訓(xùn)練過程中,第l層第t個標(biāo)記x_{t}^{l} 只可見x_{\leq t}^{l-1}(也就是前一層的前t個token),并最終得到編碼x_{T-1}^8,并利用線性分類器其預(yù)測y_T。

S3:訓(xùn)練時利用梯度下降算法,以最小化真實(shí)移動和預(yù)測之間的交叉熵?fù)p失。

三、實(shí)驗(yàn)結(jié)果:

提供給GPT一個驗(yàn)證集下的部分棋局,利用檢測top-1是否合法來計(jì)算錯誤率。結(jié)果上,在Synthetic上訓(xùn)練的錯誤率0.1%,在Championship上訓(xùn)練的錯誤率5.17%,未經(jīng)訓(xùn)練的則是93.29%。

這說明Othello-GPT在預(yù)測上做的比原先更好。但也不排除是它記住了所有可能。

進(jìn)一步實(shí)驗(yàn):對Synthetic數(shù)據(jù)集,將開局可能的4個動作(C5,D6,E3,F(xiàn)4)四個步驟下的數(shù)據(jù)集刪掉C5節(jié)點(diǎn)。[構(gòu)造傾斜數(shù)據(jù)集],用該數(shù)據(jù)集訓(xùn)練模型。但是在這種情況下,Othello GPT依然保持住在0.02%的錯誤率(沒有大幅度提升),這說明Othello GPT不是“死記硬背”順序的。

PART3 利用探針探索內(nèi)部展示

1.探針:一個分類器/回歸器,輸入是待探測網(wǎng)絡(luò)的內(nèi)部激活,輸出是對于某個特征的預(yù)測結(jié)果。如果這個探針在訓(xùn)練后能夠準(zhǔn)確預(yù)估某個特征,就說明這個特征的表示在網(wǎng)絡(luò)的激活中。

2.探測目標(biāo):對棋盤狀態(tài)的表示=>探針預(yù)測目標(biāo):預(yù)測棋盤的某個位置是空/持有黑子/持有白子。(動機(jī):函數(shù)好寫)

輸入:x_{t}^l (研究不同層l結(jié)果)

輸出:p\theta (x_{t}^l )(一個三向分類概率分布)

參照概率:52.95%(在驗(yàn)證集中猜測所有圖塊為空的結(jié)果)=>參考錯誤率:47.05%

3.實(shí)驗(yàn)結(jié)果:

①使用線性探頭的結(jié)果:略有提高,性能較差

表示:p_\theta (x_{t}^l )=softmax(Wx_{t}^l),

其中?\theta =W \in R^{F\times 3}

baseline:在隨機(jī)初始化網(wǎng)絡(luò)上訓(xùn)練的探針

表 1:在不同層的不同數(shù)據(jù)集上訓(xùn)練的隨機(jī) Othello-GPT 和 Othello-GPT 上的線性探針的錯誤率 (%)(xi 表示第 i 層之后的內(nèi)部表示)。

②使用非線性探頭的結(jié)果:錯誤率較低

表示:p_\theta (x_{t}^l )=softmax(W_1ReLU(W_2x_{t}^l))

表 2:在不同層的不同數(shù)據(jù)集上訓(xùn)練的隨機(jī) Othello-GPT 和 Othello-GPT 上的非線性探針的錯誤率 (%)。標(biāo)準(zhǔn)差在附錄 H 中報(bào)告。

補(bǔ)充:Othello-GPT是怎么對棋盤建模的?

為每個格子訓(xùn)練的探針本質(zhì)保留了該格子的原型向量知識,認(rèn)為其具有“概念向量”。

利用PCA將維度降到3,連接在棋盤上直接臨近的兩個點(diǎn),同時,如果它們是水平臨近的,就用橙色線來連線。

圖片中,非線性情況下Championship和Synthetic的建模是一種“球形抹布”結(jié)構(gòu),作者認(rèn)為這就是Othello GPT對棋盤的建模方法(我的理解是:因?yàn)槌壬牟糠郑?/p>

但是,它是否真的利用這個棋盤建模來幫助自己進(jìn)行建模呢?

Part4 干預(yù)實(shí)驗(yàn)驗(yàn)證指針

干預(yù)實(shí)驗(yàn):給定Othello-GPT一組激活,探針預(yù)測棋盤狀態(tài)B;然后修改激活,檢測出B’。如果B'符合預(yù)期,而非是和B一樣的結(jié)果,則認(rèn)為是相關(guān)的。

(1)干預(yù)

在這里的例子中,干預(yù)實(shí)驗(yàn)的目的在于修改某個部分的狀態(tài),從而使得預(yù)測的狀態(tài)從某個顏色top1越過邊界到達(dá)另一種顏色的top1(圖A)

而圖B展示了一種例子,左上是實(shí)際的情況,左下是原來的世界狀態(tài)下對于棋盤顏色的預(yù)測結(jié)果(一致);而在參數(shù)改動之后,模型預(yù)測的下一步落子發(fā)生了改變(廢話),且模型的世界狀態(tài)下棋盤上有一顆白子變成了黑子。

圖C則回答了"動哪些地方?”的問題。

其選擇一個初始層,修改其與后續(xù)層的激活。淺藍(lán)色表示未修改的激活;深藍(lán)色代表受干預(yù)影響的激活。

從預(yù)定義層開始,作者對時間上最后的標(biāo)記進(jìn)行干預(yù)。用干預(yù)后的內(nèi)部表示替換原始的內(nèi)部表示,并恢復(fù)下一層的計(jì)算。

部分錯誤信息得到糾正(淺藍(lán)色),但我們交替這種干預(yù)和計(jì)算過程,直到最后一層,從中進(jìn)行下一步預(yù)測。

如果只修改中間層的激活,高層的激活會直接受到干預(yù)的影響。

(簡單說:改變當(dāng)層最后一個->得到下一層->把下一層除了最后一個以外的全部恢復(fù)成原來的表示->繼續(xù)計(jì)算循環(huán))

具體干預(yù)的辦法:訓(xùn)練探針的方法依舊是采用探針預(yù)測的概率和棋盤狀態(tài)之間的交叉熵?fù)p失,但是這里不再優(yōu)化探針的權(quán)重\theta ,而是優(yōu)化x

優(yōu)化x的公式

解釋:使用梯度下降來修改最后一個timestamp的關(guān)鍵激活向量。

(2)實(shí)驗(yàn)

設(shè)計(jì)了一個評估的base。測試包括:游戲、目標(biāo)棋盤圖塊和目標(biāo)狀態(tài)組成。

實(shí)驗(yàn)過程:將部分博弈交給Othello-GPT執(zhí)行,并且進(jìn)行干預(yù)。具體來說,就是計(jì)算時提取模型激活,然后修改激活從而改變棋盤的表示,并將修改后的表示返回要求其進(jìn)行預(yù)測。

基準(zhǔn)集:1000個干預(yù)案例下的兩個子集:

? ? ①自然子集:所有游戲內(nèi)按照規(guī)則可以到達(dá)的地方集合。

? ? ②非自然子集:游戲內(nèi)按照規(guī)則不可以到達(dá)的地方集合(嚴(yán)格測試,遠(yuǎn)離訓(xùn)練分布中的任何內(nèi)容)

在這里,相當(dāng)于先進(jìn)行計(jì)算過程中激活的修改,然后再拿改變了之后的題目來詢問,因此,做過干預(yù)的結(jié)果會比沒做過干預(yù)的錯誤率更低。實(shí)驗(yàn)結(jié)果上,layer4取得了最低的錯誤率(性能最好),且干預(yù)實(shí)驗(yàn)在自然子集和非自然子集上都取得了效果。

Part5 干預(yù)實(shí)驗(yàn)驗(yàn)證指針

擺了,直接用匯報(bào)ppt里的內(nèi)容。

Over。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容