DARTS: Differentiable Architecture Search

深度學(xué)習(xí)可以自動(dòng)學(xué)習(xí)出有用的特征,脫離了對(duì)特征工程的依賴,在圖像、語音等任務(wù)上取得了超越其他算法的結(jié)果。這種成功很大程度上得益于新神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)的出現(xiàn),如ResNet、Inception、DenseNet等。但設(shè)計(jì)出高性能的神經(jīng)網(wǎng)絡(luò)需要大量的專業(yè)知識(shí)與反復(fù)試驗(yàn),成本極高,限制了神經(jīng)網(wǎng)絡(luò)在很多問題上的應(yīng)用。神經(jīng)結(jié)構(gòu)搜索(Neural Architecture Search,簡(jiǎn)稱NAS)是一種自動(dòng)設(shè)計(jì)神經(jīng)網(wǎng)絡(luò)的技術(shù),可以通過算法根據(jù)樣本集自動(dòng)設(shè)計(jì)出高性能的網(wǎng)絡(luò)結(jié)構(gòu),在某些任務(wù)上甚至可以媲美人類專家的水準(zhǔn),甚至發(fā)現(xiàn)某些人類之前未曾提出的網(wǎng)絡(luò)結(jié)構(gòu),這可以有效的降低神經(jīng)網(wǎng)絡(luò)的使用和實(shí)現(xiàn)成本。NAS的原理是給定一個(gè)稱為搜索空間的候選神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)集合,用某種策略從中搜索出最優(yōu)網(wǎng)絡(luò)結(jié)構(gòu)。神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)的優(yōu)劣即性能用某些指標(biāo)如精度、速度來度量,稱為性能評(píng)估。

摘要

本文提出通過可微的方式來解決神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)搜索的問題。前人的網(wǎng)絡(luò)搜索方法,要么是基于強(qiáng)化學(xué)習(xí) 的,要么是基于遺傳算法的,都是非常耗時(shí)的,最近的幾個(gè)算法表示他們的計(jì)算時(shí)間可能需要:1800 GPU days 以及 3150 GPU days,上述算法都是在不可微或者離散的空間中求解,本文的方法基于架構(gòu)表示的連續(xù)空間,允許使用梯度下降有效地搜索架構(gòu)。在CIFAR-10,ImageNet,Penn Treebank和WikiText-2上進(jìn)行了大量實(shí)驗(yàn),表明本文的算法擅長于發(fā)現(xiàn)用于圖像分類的高性能卷積結(jié)構(gòu)和用于語言建模的循環(huán)神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu),同時(shí)比現(xiàn)有技術(shù)的非微分搜索技術(shù)要快幾個(gè)數(shù)量級(jí)。

介紹

發(fā)現(xiàn)最先進(jìn)的神經(jīng)網(wǎng)絡(luò)架構(gòu)需要人類專家的大量工作。最近,人們?cè)絹碓接信d趣開發(fā)自動(dòng)化算法解決神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)的設(shè)計(jì)。自動(dòng)搜索架構(gòu)在諸如圖像分類和目標(biāo)檢測(cè)任務(wù)上有著優(yōu)異的性能。

盡管具有卓越的性能,但現(xiàn)有的最佳架構(gòu)搜索算法在計(jì)算上要求很高。在CIFAR-10和ImageNet上獲得最好的神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu),采用增強(qiáng)學(xué)習(xí)需要2000 GPU days,采用遺傳算法需要3150 GPU days。也有一些人提出了加速方法:例如強(qiáng)加搜索空間的特定結(jié)構(gòu),對(duì)每個(gè)單獨(dú)架構(gòu)的權(quán)重或性能預(yù)測(cè),以及跨體系結(jié)構(gòu)的權(quán)重共享,但可擴(kuò)展性的根本問題依然存在。主流方法(RL, evolution, MCTS, SMBO,Bayesian optimization)低效率的原因在于:他們將結(jié)構(gòu)搜索這個(gè)任務(wù)當(dāng)做是一個(gè)離散領(lǐng)域的黑盒優(yōu)化問題,從而導(dǎo)致需要評(píng)價(jià)大量的結(jié)構(gòu)。

本文另辟蹊徑,提出DARTS(Differentiable Architecture Search)。不再搜索一組離散的候選架構(gòu),而是將搜索空間放松至連續(xù),從而可以通過梯度下降的方法對(duì)神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)在驗(yàn)證集上的性能進(jìn)行優(yōu)化。因?yàn)榛谔荻鹊膬?yōu)化,與低效的黑盒搜索不同,使得DARTS使用比現(xiàn)有技術(shù)數(shù)量級(jí)較少的計(jì)算資源實(shí)現(xiàn)同具競(jìng)爭(zhēng)力的表現(xiàn)。它也優(yōu)于另一種最近的高效架構(gòu)搜索方法:ENAS,值得注意的是,DARTS比許多現(xiàn)有方法簡(jiǎn)單,因?yàn)樗簧婕叭魏蝐ontrollers、hypernetworks,或表現(xiàn)預(yù)測(cè)因子。

在一個(gè)連續(xù)的領(lǐng)域內(nèi)搜索體系結(jié)構(gòu)的想法并不新鮮,本文與先前的工作有幾個(gè)主要區(qū)別在于:

1、雖然之前的工作試圖對(duì)結(jié)構(gòu)的特定方面進(jìn)行微調(diào),如卷積網(wǎng)絡(luò)中的卷積核形狀或分支模式,但是DARTS能夠在豐富的搜索空間內(nèi)發(fā)現(xiàn)具有復(fù)雜圖形拓?fù)涞母咝阅芗軜?gòu)。

2、DARTS不限于任何特定架構(gòu),能夠搜索卷積神經(jīng)網(wǎng)絡(luò)和循環(huán)神經(jīng)網(wǎng)絡(luò)。

在實(shí)驗(yàn)中,DARTS能設(shè)計(jì)一個(gè)在CIFAR-10分類數(shù)據(jù)集上測(cè)試誤差為2.76\pm 0.09%,參數(shù)量為3.3M的卷積單元,與regularized evolution的方法相比,表現(xiàn)相當(dāng),但是后者使用的計(jì)算資源是前者的三倍。同樣的卷積單元在ImageNet數(shù)據(jù)集上實(shí)現(xiàn)了26.7%的Top_1 Error,同樣與強(qiáng)化學(xué)習(xí)最好的方法相當(dāng)。在語言模型任務(wù)上,DARTS發(fā)現(xiàn)了一個(gè)循環(huán)單元,在Penn Treebank上達(dá)到了55.7的測(cè)試 perplexity,由于現(xiàn)有的LSTM模型和所有自動(dòng)結(jié)構(gòu)搜索(NAS&ENAS)得到的模型。

總結(jié)起來,貢獻(xiàn)有四:

1、引入了一種適用于卷積和循環(huán)結(jié)構(gòu)的可微分網(wǎng)絡(luò)結(jié)構(gòu)搜索的新算法;

2、實(shí)驗(yàn)表明本文的方法具有很強(qiáng)的競(jìng)爭(zhēng)力(CIFAR-10&PTB);

3、實(shí)現(xiàn)了卓越的結(jié)構(gòu)搜索效率(4個(gè)GPU:1天內(nèi)CIFAR-10誤差2.83%; 6小時(shí)內(nèi)PTB誤差56.1),這歸因于使用基于梯度的優(yōu)化而非非微分搜索技術(shù);

4、證明 DARTS 在 CIFAR-10 和 PTB 上學(xué)習(xí)的網(wǎng)絡(luò)結(jié)構(gòu)可以遷移到 ImageNet 和 WikiText-2 上

可微結(jié)構(gòu)搜索

在第一部分,用一般形式描述搜索空間,其中結(jié)構(gòu)(或其中的單元cell)的計(jì)算過程被表示為有向無環(huán)圖。在第二部分,為搜索空間引入一個(gè)簡(jiǎn)單的連續(xù)松弛方案,使得架構(gòu)及其權(quán)重的聯(lián)合優(yōu)化目標(biāo)可微。在第三部分,提出了一種近似技術(shù),使算法在計(jì)算上可行和高效。

搜索空間

根據(jù)前人的工作,搜索一個(gè)計(jì)算單元(cell)作為最終架構(gòu)的構(gòu)建塊,學(xué)習(xí)的單元可以堆疊以形成卷積網(wǎng)絡(luò),或者遞歸地連接以形成循環(huán)網(wǎng)絡(luò)。單元是N個(gè)有序節(jié)點(diǎn)構(gòu)成的有向無環(huán)圖,每個(gè)節(jié)點(diǎn)x^{(i)} 都是一個(gè)Latent Representation(如卷積神經(jīng)網(wǎng)絡(luò)中的特征圖),每個(gè)有向邊(i,j)是對(duì)節(jié)點(diǎn)x^{(i)} 的某種運(yùn)算操作o^{(i,j)},假設(shè)每個(gè)單元有兩個(gè)輸入節(jié)點(diǎn)和一個(gè)輸出節(jié)點(diǎn),對(duì)于卷積單元來說,輸入節(jié)點(diǎn)被定義為前兩層的單元輸出;對(duì)于循環(huán)神經(jīng)網(wǎng)絡(luò),輸入節(jié)點(diǎn)被定義為當(dāng)前時(shí)間步的輸入以及上一時(shí)間步中的隱藏狀態(tài)。通過對(duì)所有中間節(jié)點(diǎn)應(yīng)用reduction操作(例如concatenation)來獲得單元的輸出。

每個(gè)中間節(jié)點(diǎn)都是基于所有它之前的節(jié)點(diǎn)進(jìn)行計(jì)算的:x^{j}=\sum_{i<j}o^{(i,j)}(x^{(i)})

特殊的零操作指示兩個(gè)節(jié)點(diǎn)之間缺少連接。 因此,學(xué)習(xí)單元的任務(wù)減少了學(xué)習(xí)其邊緣的操作。

連續(xù)的放松和優(yōu)化

O表示候選操作的集合(如卷積,池化,或者零操作),每個(gè)操作都代表應(yīng)用于x^{(i)} 的一些函數(shù)o(\cdot ),為了令搜索空間連續(xù),將特定操作的分類選擇放寬到所有可能操作的softmax:


其中一對(duì)節(jié)點(diǎn)(i,j)的操作混合權(quán)重由向量\alpha ^{(i,j)}(維數(shù)為\vert O \vert )參數(shù)化表示,結(jié)構(gòu)搜索任務(wù)簡(jiǎn)化為學(xué)習(xí)一系列連續(xù)的變量\alpha =(\alpha ^{(i,j)} ),在搜索的最后,通過將每個(gè)混合操作替換為最可能的操作o^{(i,j)}=argmax_{o\in O} \alpha _{o}^{(i,j)} ,下面將用\alpha 來表示結(jié)構(gòu)的編碼。

在放松得到連續(xù)的優(yōu)化空間之后,我們的目標(biāo)是同時(shí)學(xué)習(xí)網(wǎng)絡(luò)結(jié)構(gòu)\alpha 和所有混合操作的權(quán)重w(例如卷積神經(jīng)網(wǎng)絡(luò)卷積核的權(quán)重),在強(qiáng)化學(xué)習(xí)和遺傳算法中,驗(yàn)證集上的表現(xiàn)被視為獎(jiǎng)勵(lì)或者擬合程度。DARTS用梯度下降來優(yōu)化驗(yàn)證損失。

DARTS的算法流程

在上圖(a)中,每條邊上的操作初始化為未知的;圖(b)中,將候選操作的混合放置于每條邊以此來進(jìn)行放松使得搜索空間連續(xù);圖(c)中,通過求解雙層優(yōu)化問題聯(lián)合優(yōu)化混合概率和網(wǎng)絡(luò)權(quán)重;圖(d)中,從學(xué)習(xí)的混合概率中得到最終的結(jié)構(gòu)。

L_{train} ,L_{val} 分別代表訓(xùn)練和驗(yàn)證損失,上述損失不僅由結(jié)構(gòu)\alpha 定義,也由網(wǎng)絡(luò)中的權(quán)重矩陣w定義,網(wǎng)絡(luò)搜索的目標(biāo)是找到某一結(jié)構(gòu)\alpha ^*使得驗(yàn)證損失L_{val}(w^*,\alpha ^* ) 最小化,意即\alpha ^* =argmin_\alpha  L_{val}(w^*,\alpha  ) ,與結(jié)構(gòu)相關(guān)的權(quán)重w^*通過最小化訓(xùn)練損失得到:w^* =argmin_w L_{train}(w,\alpha ^* ) 。

這是一個(gè)雙層優(yōu)化問題,\alpha 是上層變量,w是下層變量:

min_{\alpha }  L_{val}(w^*(\alpha ),\alpha  )

s.t.   \quad w^*(\alpha ) =argmin_{w}L_{train}(w,\alpha )

嵌套公式也出現(xiàn)在基于梯度的超參數(shù)優(yōu)化中,這在某種意義上是相關(guān)的,即架構(gòu)\alpha 可以被視為一種特殊類型的超參數(shù),盡管它的維度遠(yuǎn)遠(yuǎn)高于標(biāo)量值超參數(shù),例如學(xué)習(xí)率,并且它更難以優(yōu)化。

算法1:DARTS:

建立一個(gè)混合操作,由邊(i,j)的向量\alpha ^{(i,j)}參數(shù)化得到:

如果沒有收斂:

就以下式更新結(jié)構(gòu)\alpha (理論上,權(quán)重應(yīng)該已達(dá)最優(yōu),此處為近似表達(dá))

梯度下降更新\alpha

再以下式更新權(quán)重w(對(duì)于確定的結(jié)構(gòu)\alpha ,更新權(quán)重)

梯度下降更新w

收斂之后由學(xué)習(xí)到的\alpha 得到最后的網(wǎng)絡(luò)結(jié)構(gòu)。

估計(jì)結(jié)構(gòu)梯度

由于昂貴的內(nèi)部?jī)?yōu)化,準(zhǔn)確地評(píng)估架構(gòu)梯度可能是令人望而卻步的,因此提出一個(gè)簡(jiǎn)單的梯度估計(jì):

梯度估計(jì)

或者表示為:

step  \quad k,給定當(dāng)前的結(jié)構(gòu) \alpha _{k-1} ,我們通過朝向降低訓(xùn)練損失的方向去移動(dòng)w_{k-1} 來得到 w_{k} 。然后,保持權(quán)重w_{k} 不變,去更新網(wǎng)絡(luò)結(jié)構(gòu),使其可以最小化驗(yàn)證集損失:

驗(yàn)證集損失

w代表算法使用的當(dāng)前權(quán)重,\xi 代表內(nèi)部?jī)?yōu)化的學(xué)習(xí)率,用權(quán)重w一個(gè)訓(xùn)練步驟的更新來估計(jì)最優(yōu)權(quán)重w^*(\alpha ),而不是訓(xùn)練直到收斂以完全求解內(nèi)部?jī)?yōu)化問題。如果w已經(jīng)達(dá)到局部最優(yōu),上式將會(huì)退化為:

退化表達(dá)

雖然我們目前還不能指出該優(yōu)化算法的收斂性,但在實(shí)踐中它能夠通過合適的\xi 達(dá)到一個(gè)固定的點(diǎn)。當(dāng)動(dòng)量被用于權(quán)重優(yōu)化時(shí),上式中的一步展開學(xué)習(xí)目標(biāo)被相應(yīng)地修改,并且我們的所有分析仍然適用。

應(yīng)用鏈?zhǔn)椒▌t:

上述梯度估計(jì)為:

鏈?zhǔn)椒▌t展開式

由于鏈?zhǔn)椒▌t展開后的第二項(xiàng)將涉及到矩陣內(nèi)積運(yùn)算,計(jì)算非常昂貴,因此采用有限差分近似來降低復(fù)雜度:

有限差分近似

其中:

推導(dǎo)

評(píng)估該有限差分僅需要兩次前向傳播即可得到w,兩次反向傳播,就可以得到\alpha ,運(yùn)算復(fù)雜度大大的降低了:由O(\vert w \vert \vert \alpha  \vert )降低為O(\vert w \vert +\vert \alpha  \vert )

一階估計(jì):

當(dāng)內(nèi)部?jī)?yōu)化學(xué)習(xí)率\xi =0時(shí),二階導(dǎo)數(shù)將消失,此時(shí),關(guān)于結(jié)構(gòu)\alpha 的梯度將由下式給出:

一階情況

意即假設(shè)當(dāng)前步的權(quán)重w就是w^*(\alpha ) ,這會(huì)帶來速度的提升,但是性能會(huì)下降。

迭代算法的學(xué)習(xí)表現(xiàn)

當(dāng)L_{val} (w,\alpha )=\alpha w-2\alpha +1L_{train}(w,\alpha ) =\alpha ^2-2\alpha w+w^2時(shí),從(\alpha ^{(0)},w^{(0)})=(2,-2)開始優(yōu)化,兩層優(yōu)化問題的理論最優(yōu)解為(\alpha ^*,w^*)=(1,1),紅色虛線表示精確滿足雙層優(yōu)化數(shù)學(xué)描述的可行組,上圖表明一個(gè)合適的學(xué)習(xí)率\xi 將有助于收斂到較好的局部最優(yōu)點(diǎn)。

推導(dǎo)離散結(jié)構(gòu)

為了在離散架構(gòu)中形成每個(gè)節(jié)點(diǎn),我們保留從所有先前節(jié)點(diǎn)收集的所有非零候選操作中的前k個(gè)強(qiáng)度最高的操作(來自不同節(jié)點(diǎn)),強(qiáng)度定義為:

強(qiáng)度定義

對(duì)于卷積單元,設(shè)置k=2,對(duì)于循環(huán)單元,設(shè)置k=1。

意即:為每個(gè)中間節(jié)點(diǎn)保留k個(gè)最強(qiáng)的前導(dǎo),其中邊的強(qiáng)度定義為上式;通過采用argmax將每個(gè)混合操作替換為最可能的操作。

要將零操作去除的原因是:首先,我們需要每個(gè)節(jié)點(diǎn)恰好有k個(gè)非零入射邊,以便與現(xiàn)有模型進(jìn)行公平比較;其次,零操作的強(qiáng)度是不確定的,因?yàn)樵黾恿悴僮鞯膌ogits僅影響結(jié)果節(jié)點(diǎn)表示的規(guī)模,并且由于存在批量規(guī)范化而不影響最終的分類結(jié)果。

實(shí)驗(yàn)和結(jié)果

在CIFAR-10和PTB數(shù)據(jù)集上的實(shí)驗(yàn)包含兩個(gè)部分:結(jié)構(gòu)搜索和結(jié)構(gòu)評(píng)估,在第一階段,使用DARTS搜索單元體系結(jié)構(gòu),并根據(jù)其驗(yàn)證性能確定最佳單元;在第二階段,使用這些單元構(gòu)建更大的架構(gòu),從頭開始訓(xùn)練并在測(cè)試集上評(píng)估它們的性能。同時(shí)還分別通過在ImageNet和WikiText-2(WT2)上評(píng)估它們來研究在CIFAR-10和PTB上學(xué)習(xí)的最佳單元的可遷移性。

結(jié)構(gòu)搜索

CIFAR-10上的結(jié)構(gòu)搜索

操作集合O中包含3\times 35\times 5可分離卷積,3\times 35\times 5的空洞可分離卷積,3\times 3最大池化,3\times 3平均池化,恒等映射和零操作。所有步長均為1,填充卷積特征圖以保持其空間分辨率。使用ReLU-Conv-BN順序進(jìn)行卷積運(yùn)算,每個(gè)可分離卷積總是應(yīng)用兩次。

一個(gè)卷積單元包含7個(gè)節(jié)點(diǎn),其中輸出節(jié)點(diǎn)被定義為所有中間節(jié)點(diǎn)(排除輸入節(jié)點(diǎn))的深度級(jí)聯(lián)(concat)。然后通過將多個(gè)單元堆疊在一起形成網(wǎng)絡(luò)。單元k的第一和第二個(gè)節(jié)點(diǎn)設(shè)置為單元k-1和單元k-2的輸出,必要時(shí)插入1\times 1卷積。在網(wǎng)絡(luò)深度1/32/3處的單元是降采樣單元,與輸入節(jié)點(diǎn)相鄰的所有操作的步長設(shè)置為2。因此網(wǎng)絡(luò)表示為(\alpha _{normal} ,\alpha _{reduce}),其中所有正常單元共享\alpha_{normal},所有下采樣單元共享\alpha_{reduce}

PTB上的結(jié)構(gòu)搜索

所有可能的操作集合中包含:線性變換+tanh ,線性變換+relu,線性變換+sigmoid,恒等映射,零操作。

循環(huán)單元包含12個(gè)節(jié)點(diǎn),通過對(duì)兩個(gè)輸入節(jié)點(diǎn)進(jìn)行線性變換來獲得第一中間節(jié)點(diǎn),結(jié)果相加后通過tanh激活函數(shù)。并通過一個(gè)旁路來增強(qiáng)每個(gè)操作,單元輸出定義為所有中間節(jié)點(diǎn)的平均值,在進(jìn)行結(jié)構(gòu)搜索時(shí)啟用BN,以防止出現(xiàn)梯度爆炸,在測(cè)試網(wǎng)絡(luò)結(jié)構(gòu)時(shí)禁用BN。循環(huán)網(wǎng)絡(luò)只包含一個(gè)單元,即不會(huì)在循環(huán)結(jié)構(gòu)中假設(shè)任何重復(fù)模式。

Normal Cell Learned on CIFAR -10
Reduce Cell Learned on CIFAR-10
Recurrent Cell Learned on PTB

結(jié)構(gòu)評(píng)估

為了確定最終評(píng)估的體系結(jié)構(gòu),使用不同的隨機(jī)初始化運(yùn)行DARTS四次,并根據(jù)其在短時(shí)間內(nèi)從頭開始訓(xùn)練獲得的驗(yàn)證性能選擇最佳單元(100 epochs on CIFAR-10 and 300 epochs on PTB)。(循環(huán)單元的優(yōu)化結(jié)果對(duì)初始化很敏感)

為了評(píng)估所選擇的架構(gòu),我們隨機(jī)初始化其權(quán)重(在搜索過程中學(xué)習(xí)的權(quán)重被丟棄),從頭開始訓(xùn)練,并在測(cè)試集上報(bào)告其性能。 注意,測(cè)試集從未用于架構(gòu)搜索。

CIFAR-10
PTB
ImageNet

結(jié)果分析

為了更好地理解雙層優(yōu)化的必要性,本文研究了一種簡(jiǎn)單的搜索策略,\alpha ,w在訓(xùn)練驗(yàn)證聯(lián)合數(shù)據(jù)集上使用協(xié)同過濾進(jìn)行聯(lián)合優(yōu)化,這樣最好的卷積單元測(cè)試錯(cuò)誤率為4.16\pm 0.16%,參數(shù)量為3.1M,比隨機(jī)搜索要差。在第二個(gè)實(shí)驗(yàn)中,在訓(xùn)練驗(yàn)證聯(lián)合數(shù)據(jù)集上使用SGD同時(shí)優(yōu)化\alpha, w,這樣最好的卷積單元測(cè)試錯(cuò)誤率為3.56\pm 0.10%,參數(shù)量為3.0M。猜想是因?yàn)檫@些啟發(fā)式方法會(huì)導(dǎo)致結(jié)構(gòu)\alpha 過度(類似于超參數(shù))擬合訓(xùn)練數(shù)據(jù),從而導(dǎo)致較差的泛化。 請(qǐng)注意,\alpha 未在DARTS中的訓(xùn)練集上直接優(yōu)化。

值得注意的是,隨機(jī)搜索在卷積模型和循環(huán)模型上表現(xiàn)都很良好,這反映了搜索空間設(shè)計(jì)的重要性。



實(shí)驗(yàn)細(xì)節(jié):

在CIFAR-10上的結(jié)構(gòu)搜索:

由于架構(gòu)在整個(gè)搜索過程中會(huì)有所不同,因此始終使用批量特定的統(tǒng)計(jì)信息進(jìn)行批量標(biāo)準(zhǔn)化而不是全局移動(dòng)平均值。在搜索過程中禁用所有批量標(biāo)準(zhǔn)化中可學(xué)習(xí)的仿射參數(shù),以避免重新調(diào)整候選操作的輸出。

為了進(jìn)行結(jié)構(gòu)搜索,將一半的CIFAR-10訓(xùn)練數(shù)據(jù)作為驗(yàn)證集,8個(gè)單元的小網(wǎng)絡(luò)以DARTS訓(xùn)練50輪,batch_size=64,初始通道數(shù)為16,優(yōu)化w,初始化學(xué)習(xí)率\eta _w=0.025,在沒有重啟的情況下按照余弦計(jì)劃退火到零,對(duì)結(jié)構(gòu)變量進(jìn)行零初始化(對(duì)于正常單元和降采樣單元中的\alpha),這意味著在所有可能的操作上都會(huì)有相同的注意力(在采用softmax之后)。在早期階段,這確保了每個(gè)候選操作中的權(quán)重以接收足夠的學(xué)習(xí)信號(hào)(更多探索)。使用Adam算法來優(yōu)化\alpha,初始化學(xué)習(xí)率\eta_{\alpha}=0.0003,結(jié)構(gòu)搜索在單GPU上需要一天。

在CIFAR-10上的結(jié)構(gòu)評(píng)估:

有20個(gè)單元的大網(wǎng)絡(luò)以batch_size=96訓(xùn)練600輪,通道數(shù)初始化為36,使用了一些增強(qiáng)如cutout,path dropout,auxiliary towers with weight 0.4,訓(xùn)練在單GPU上需要1.5天,由于即使設(shè)置完全相同,CIFAR結(jié)果也會(huì)出現(xiàn)很大的差異,因此報(bào)告了完整模型的10次獨(dú)立運(yùn)行的平均值和標(biāo)準(zhǔn)差。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容