Neural Ordinary Differential Equations 神經(jīng)常微分方程

0 摘要

我們引入了一個新的深度神經(jīng)網(wǎng)絡(luò)模型家族. 我們沒有用非連續(xù)的隱藏層, 而是用神經(jīng)網(wǎng)絡(luò)把隱狀態(tài)的導(dǎo)數(shù)參數(shù)化. 網(wǎng)絡(luò)的輸出是通過黑盒微分方程求解器來計算的. 這些連續(xù)層的網(wǎng)絡(luò)的內(nèi)存消耗是穩(wěn)定不變的, 針對每個輸入來設(shè)計估計方法的話, 就能做計算精度和計算速度的權(quán)衡. 我們通過連續(xù)層的ResNet和連續(xù)時間的隱變量模型展示了這些特性. 我們還構(gòu)造了連續(xù)正則化流, 該生成模型可以直接用極大似然來訓(xùn)練, 不需要對數(shù)據(jù)維度進(jìn)行分區(qū)或者排序. 我們展示了訓(xùn)練過程其實不需要了解ODE求解器內(nèi)部的實現(xiàn), 也能對ODE求解器的反向進(jìn)行計算. 這就允許我們構(gòu)造大規(guī)模模型, 并進(jìn)行端到端的訓(xùn)練.

1 引言

像ResNet, RNN解碼器, 正則化流, 它們都組合了隱狀態(tài)的一系列變換, 構(gòu)建出一個復(fù)雜的變換, 如下:

image.png

其中t屬于[0…T], ht屬于Rd, 這些迭代更新可以看作是連續(xù)變換的歐拉離散化. 當(dāng)t趨于0, step趨于無窮時, 可以得到如下的常微分方程(ODE, ordinary differential equation):

image.png

給定h(0), 我們可以把h(T) 作為該方程在T時刻的解. 該解可以用黑盒ODE求解器計算得到, 求解器還能根據(jù)需要的精度自行決定在何處對f進(jìn)行擬合. 圖1對比這一過程:

image.png

圖1

左: ResNet定義了一系列非連續(xù)的有限轉(zhuǎn)換.

右: ODE網(wǎng)絡(luò)定義了一個向量場, 可以對隱狀態(tài)進(jìn)行連續(xù)的轉(zhuǎn)換.

黑點表示估計點.

定義 使用ODE的模型有如下好處:

內(nèi)存優(yōu)化:

在第2節(jié), 我們展示了如何在不涉及ODE求解器黑盒內(nèi)部操作的情況下, 對任意ODE求解過程求反向, 得到標(biāo)量損失的梯度. 不儲存任何前向計算結(jié)果, 就可以讓我們在內(nèi)存占用不變的情況下訓(xùn)練任意深度的模型. 這就解決了深度神經(jīng)網(wǎng)絡(luò)模型訓(xùn)練的主要瓶頸---模型深度.

自適應(yīng)計算法

歐拉法求ODE是比較古老的方法了, 現(xiàn)代ODE求解器可以做到根據(jù)誤差精度要求來調(diào)整求解過程, 監(jiān)控誤差來獲得需要的精度. 這就可以根據(jù)問題復(fù)雜度來調(diào)整模型估值的消耗. 在模型訓(xùn)練結(jié)束后, 還能降低計算精度來滿足程序?qū)崟r性的要求.

可拓展和可逆的標(biāo)準(zhǔn)化流

連續(xù)變換帶了一個意想不到的好處, 變量方程式的變化更加容易計算了. 在第4節(jié), 我們提出這個結(jié)論并組建了一個可逆的密度模型, 該模型可以避免正則化流中單單元的瓶頸, 可以直接用極大似然來進(jìn)行訓(xùn)練.

連續(xù)時間序列模型

RNN需要離散的觀測和發(fā)射間隔, 而定義連續(xù)的模型可以接收任意時間得到的數(shù)據(jù). 此種模型的構(gòu)建和展示詳見第5節(jié).

2 ODE求解器的反向自動微分

訓(xùn)練連續(xù)層網(wǎng)絡(luò)的主要問題就是對ODE求解器的反向微分(也叫反向傳播). 直接根據(jù)求解器內(nèi)部操作來求微分的內(nèi)存占用過大, 并且會引入額外的誤差.

我們把ODE求解器當(dāng)做黑盒, 用”伴隨靈敏度法”(adjoint sensitivity method)來求梯度. 這種計算法是通過計算另一個參數(shù)化的ODE來實現(xiàn)的. 這種方法的復(fù)雜度會根據(jù)問題的規(guī)模線性變化, 內(nèi)存占用也很低, 并且可以顯式的控制計算精度.

假設(shè)標(biāo)量的損失函數(shù)為L, 輸入是ODE求解器的結(jié)果:

image.png

為最小化L, 就需要求L對θ的梯度, 第一步就是要求L在每一個時刻對隱狀態(tài)z(t)的梯度. 這部分被稱為”伴隨”:

image.png

它也是一個ODE, 可以視作瞬時的鏈?zhǔn)椒▌t:

image.png

這樣, 再調(diào)一次求解器就可以解出
image.png

. 這個求解是反向進(jìn)行的, 初始狀態(tài)是
image.png

解這個ODE就需要知道從t0到t1軌跡上的所有z(t). 所以在求伴隨的過程中需要把z(t)也一并解出, 就可以在中間的軌跡上使用z(t)的值來求a(t)了.

計算L對θ的偏導(dǎo)則需要求第三個積分式:

image.png

這個式子需要知道z(t)和a(t)的值.

image.png

image.png

這兩個向量-jacobian 乘積可以通過一次自動微分直接得到, 時間消耗跟對f的估值差不多. 只要把初始狀態(tài), 伴隨和另一個偏導(dǎo) concat 到一個向量中, 所有求解z,a和
image.png

的積分, 都可以通過調(diào)用一次ODE求解器計算得出. 如下算法1的偽代碼:
image.png

大多數(shù)的ODE求解器都可以輸出中間計算結(jié)果z(t), 當(dāng)loss取決于這些中間狀態(tài)時, 反向偏導(dǎo)的計算也必須拆成一系列的求解. 如圖2所示:


image.png

圖2: ODE求解器的反向過程.

伴隨敏感度法求反向是分時刻實時求解的. 參數(shù)化的系統(tǒng)包括了初始狀態(tài)以及l(fā)oss對狀態(tài)的靈敏度. 如果損失直接依賴于多個時刻的隱狀態(tài)的觀測, 伴隨狀態(tài)也必須在loss對觀測的偏導(dǎo)方向上更新.

在每個觀測處, 伴隨都必須跟著偏導(dǎo)
image.png

的方向調(diào)整.

在附錄C中給出了L關(guān)于t0, t1偏導(dǎo)的解法. 附錄B中給出上面公式的詳細(xì)推導(dǎo)過程. 附錄D給出了上述算法scipy實現(xiàn), 這部分代碼也支持更高階的微分.

https://github.com/rtqichen/torchdiffeq中還給出了pytorch版本的實現(xiàn).

3 用ODE來取代ResNet進(jìn)行有監(jiān)督的訓(xùn)練

本節(jié)嘗試用神經(jīng)ODE進(jìn)行有監(jiān)督訓(xùn)練.

軟件: (作者說自己選取了某某ODE求解器, 還用一個第三方框架實現(xiàn)了求反向, 但是在pytorch版代碼中這些都對不上)

模型結(jié)構(gòu): 使用了一個小的殘差網(wǎng)絡(luò), 對輸入進(jìn)行了2次下采樣, 然后疊了6個標(biāo)準(zhǔn)殘差鏈接層, 這6個殘差連接層替換成ODE求解器模塊. 還測試了一下同樣結(jié)構(gòu), 但是反向直接用鏈?zhǔn)椒▌t求解的網(wǎng)絡(luò), 記為RK-Net. 各網(wǎng)絡(luò)的表現(xiàn)如下:

image.png

可以看到, ODE網(wǎng)絡(luò)和RK網(wǎng)絡(luò)可以達(dá)到和ResNet相同的性能.

ODE****網(wǎng)絡(luò)的誤差控制: ODE求解器可以保證計算誤差在真實解的某個誤差限內(nèi). 更改這個誤差限會改變網(wǎng)絡(luò)的性能表現(xiàn). 圖3a展示了誤差是可控的. 圖3b展示了前向計算時間是跟著函數(shù)估值次數(shù)成比例增加的. 所以降低誤差限可以在計算速度和精度之間做取舍. 你可以在訓(xùn)練時用高精度, 但是在推理時用低精度來加快速度..

image.png

圖3c表明: 反向計算的消耗只有前向計算的一半左右. 這就表明, 伴隨法不但節(jié)省內(nèi)存, 還比直接求反向更加高效.

網(wǎng)絡(luò)深度: 在ODE中不太好直接定義網(wǎng)絡(luò)層數(shù)這個概念. 有點類似的是隱狀態(tài)方程估值所需的次數(shù), 這依賴于ODE求解器的輸入和初始狀態(tài). 圖3d展示了訓(xùn)練過程中估值次數(shù)的增加, 這對應(yīng)了模型復(fù)雜度的增長.

4 連續(xù)正則化流

還有一個模型也出現(xiàn)了類似式1的非連續(xù)型方程, 那就是正則化流(NF, normalization flows)和NICE framework. 這些模式使用變量代換定理來計算可逆變換之后的概率密度.


image.png

經(jīng)典的正則化流模型: planar normalization flows的公式如下:

image.png

一般來說, 使用變量代換公式的瓶頸是計算雅克比矩陣
image.png

, 它的計算復(fù)雜度要么是z維度的立方, 要么是隱藏單元數(shù)量的立方. 最近的研究都是在NF模型的表達(dá)能力和計算復(fù)雜度做取舍.

令人驚訝的是, 我們把非連續(xù)的模型公式, 用第3節(jié)同樣的思路來轉(zhuǎn)換成連續(xù)模型可以減少計算量.

定理1: 變量瞬時變化

設(shè)z(t)是一個有限連續(xù)隨機變量,概率p(z(t))依賴于時間. 則下式是z(t)隨時間連續(xù)變化的微分方程:

image.png

假設(shè)f在z上均勻Lipschitz連續(xù),在t上連續(xù),那么對數(shù)概率密度的變化也遵循微分方程:


image.png

證明見附錄A. 與式6的log計算不同, 本式只需要計算跡(trace)的操作. 另外, 不像標(biāo)準(zhǔn)的NF模型, 本式不要求f是可逆的, 因為如果滿足唯一性,那么整個轉(zhuǎn)換自然就是可逆的.

應(yīng)用變量瞬時變化定理,我們可以看一下planar normalization flows的連續(xù)模擬版本:

image.png

給定一個初始分布p(z(0),我們可以從p(z(T))中采樣,并通過求解這組ODE來評估其概率密度。

使用多個線性成本的隱藏單元

當(dāng)det(行列式)不是線性方程時, 跡的方程還是線性的, 并且滿足:

image.png

這樣我們的方程就可以由一系列的求和得到, 概率密度的微分方程也是一個求和:

image.png

這意味著我們可以很簡便的評估多隱藏單元的流模型,其成本僅與隱藏單元M的數(shù)量呈線性關(guān)系。使用標(biāo)準(zhǔn)的NF模型評估這種“寬”層的成本是O(M3),這意味著標(biāo)準(zhǔn)NF體系結(jié)構(gòu)的多個層只使用單個隱藏單元.

依賴于時間的動態(tài)方程

我們可以將流的參數(shù)指定為t的函數(shù),使微分方程f(z(t)、t)隨t而變化。這種參數(shù)化的方法是一種超網(wǎng)絡(luò). 我們還為每個隱藏層引入了門機制:

image.png

其中:
image.png

, 是一個神經(jīng)網(wǎng)絡(luò), 可以學(xué)習(xí)到何時使用fn. 我們把該模型稱之為連續(xù)正則化流(CNF, continuous normalizing flows)

4.1 CNF試驗

我們首先比較連續(xù)的和離散的planar正則化流在學(xué)習(xí)樣本從一個已知的分布。我們證明了一個具有M個隱藏單元的連續(xù) planar CNF至少可以與一個具有K層(M = K)的離散 planar NF具有同樣的擬合能力,某些情況下CNF的擬合能力甚至更強.

擬合概率密度

設(shè)置一個前述的CNF, 用adam優(yōu)化器訓(xùn)練10000個step. 對應(yīng)的NF使用RMSprop訓(xùn)練500000個step. 此任務(wù)中損失函數(shù)為KL (q(x)||p(x)), 最小化這個損失函數(shù), 來用q(x)擬合目標(biāo)概率分布p(x). 圖4表明, CNF可以得到更低的損失.

[圖片上傳失敗...(image-7d47a5-1616472352555)]

極大似然訓(xùn)練

CNF一個有用的特性是: 計算反向轉(zhuǎn)換和正向的成本差不多, 這一點是NF模型做不到的. 這樣在用CNF模型做概率密度估計任務(wù)時, 我們可以通過極大似然估計來進(jìn)行訓(xùn)練 也就是最大化log(q(x))的期望值. 其中q是變量代換之后的函數(shù). 然后反向轉(zhuǎn)換CNF來從q(x)中進(jìn)行采樣.

該任務(wù)中, 我們使用64個隱藏單元的CNF和64層的NF來進(jìn)行對比. 圖5展示了最終的訓(xùn)練結(jié)果. 從最初的高斯分布, 到最終學(xué)到的分布, 每一個圖代表時間t的某一步. 有趣的是: 為了擬合兩個圓圈, CNF把planar 流 進(jìn)行了旋轉(zhuǎn), 這樣粒子會均分到兩個圓中. 跟 CNF的平滑可解釋相對的是, NF模型比較反直覺, 并且很難擬合雙月牙的概率分布(見圖5.b)

[圖片上傳失敗...(image-687aaa-1616472365302)]

5 生成式隱方程時間序列模型

將神經(jīng)網(wǎng)絡(luò)應(yīng)用于不規(guī)則采樣的數(shù)據(jù),如醫(yī)療記錄、網(wǎng)絡(luò)流量或神經(jīng)尖峰數(shù)據(jù)是困難的。 通常,觀測被放入固定持續(xù)時間的桶中,隱方程(變量?原文是dynamic)以同樣的方式進(jìn)行離散。如果存在數(shù)據(jù)缺失或隱變量定義不當(dāng)?shù)那闆r, 問題就比較困難. 數(shù)據(jù)缺失可以用數(shù)據(jù)填充和生成時間序列模型來進(jìn)行標(biāo)記. 還有一種方式是給RNN的輸入加時間戳信息.

我們提出了一種連續(xù)時間,生成的方法來建模時間序列。我們的模型用一個隱軌跡來表示每個時間序列。每個軌跡都是由一個局部初始狀態(tài)zt0和跨所有時間序列共享的全局隱方程組來確定。給定觀測時間t0、t1、……tN和初始狀態(tài)zt0,ODE求解算器產(chǎn)生zt1,…ztN,描述每個觀測的潛在狀態(tài)。我們通過一個采樣程序正式地定義了這個生成模型:

image.png

函數(shù)f是一個時間無關(guān)的函數(shù),在當(dāng)前時間步長取z并輸出梯度:

image.png

我們用神經(jīng)網(wǎng)絡(luò)來參數(shù)化這個方程. 因為f是時間無關(guān)的, 給定隱狀態(tài)z(t), 整個隱軌跡就是唯一確定的. 推斷隱軌跡可以讓我們在時間上任意向前或后退做出預(yù)測

image.png

訓(xùn)練與預(yù)測

我們可以用觀測的序列將這個潛變量模型訓(xùn)練為變分自動編碼器. 我們的判別模型RNN倒序的接收時間序列數(shù)據(jù), 輸出q φ (z 0 |x 1 ,x 2 ,...,x N ). 詳見附錄E. 使用ODE來做生成模型, 我們就能在已知時間序列的情況下, 在任意時間點做出預(yù)測.

泊松過程似然

觀測本身就給出了一些隱狀態(tài)的信息, 比如說: 得病的人更傾向于做藥物測試. 事件發(fā)生率可以用隱方程來進(jìn)行參數(shù)化:

image.png

給定這個概率函數(shù),非均勻泊松過程給出了區(qū)間[tstart,tend]中獨立觀測的可能性:

image.png

我們可以使用另一個神經(jīng)網(wǎng)絡(luò)來參數(shù)化λ(·)。因此,我們可以調(diào)用一次ODE求解器就評估出隱軌跡和泊松過程概率值。圖7為該模型在數(shù)據(jù)集上學(xué)習(xí)到的事件發(fā)生率。


image.png

觀測時間上的泊松過程似然可以與數(shù)據(jù)似然相結(jié)合,共同模擬所有觀測和時間。

5.1 事件序列隱ODE試驗

我們研究了隱ODE模型的擬合和推斷時間序列的能力。該判別網(wǎng)絡(luò)是一個有25個隱藏單元的RNN。我們使用一個四維的隱空間。我們用一個具有20個隱藏單元的單隱藏層網(wǎng)絡(luò)來參數(shù)化函數(shù)f。解碼器是一個神經(jīng)網(wǎng)絡(luò), 只有一個隱藏層, 20個隱藏單元, 用于計算p(x t i |z t i )。我們的基線是一個有25個隱藏單元的RNN,用最小化負(fù)高斯對數(shù)似然為目標(biāo)函數(shù)訓(xùn)練。我們訓(xùn)練了這個RNN的第二個版本,其輸入與下一個觀測的時間差連接,以幫助RNN進(jìn)行不規(guī)則的觀測。

雙向螺旋數(shù)據(jù)集

我們生成了一個1000個二維螺旋的數(shù)據(jù)集,每個螺旋從一個不同的點開始,在100個相同間隔的時間步長采樣。 數(shù)據(jù)集包含兩種類型的螺旋:一半是順時針方向,另一半是逆時針方向。 為了模擬真實情況,我們在觀測中加入高斯噪聲。

具有不規(guī)則時間點的時間序列

為了生成不規(guī)則的時間戳,我們不替換的從每個軌跡隨機采樣 (n={30,50,100}). 訓(xùn)練數(shù)據(jù)之外, 我們展示了100個時間點的預(yù)測均方根誤差(RMSE)。 表2顯示,隱ODE預(yù)測時的RMSE明顯較低.

image.png

圖8展示了用下采樣的30個點來擬合螺旋的結(jié)果.

[圖片上傳失敗...(image-53050-1616472422476)]

隱ODE的重構(gòu)是通過對潛在軌跡的后驗采樣并將其解碼為數(shù)據(jù)空間得到的. 附錄F展示了更多不同數(shù)據(jù)點的情況. 我們發(fā)現(xiàn), 不管多少個點的下采樣, 不管有沒有高斯噪聲, 重建和推斷都和真實情況一致.

隱空間推斷

圖8c展示了隱軌跡投影到隱空間前2個維度的結(jié)果. 這是兩個軌跡群, 一個順時針一個逆時針. 圖9展示了: 初始狀態(tài)隱軌跡方程為順時針, 而后轉(zhuǎn)變?yōu)槟鏁r針, 這一轉(zhuǎn)變過程是非常連續(xù)的.


image.png

6 應(yīng)用范圍與限制.

Mini-Batch

Mini-Batch的使用不如標(biāo)準(zhǔn)神經(jīng)網(wǎng)絡(luò)那么直觀。我們?nèi)匀豢梢酝ㄟ^將每個batch的狀態(tài)連接在一起,創(chuàng)建維度D×K的ODE方程組,通過ODE求解器來計算。In some cases, controlling

error on all batch elements together might require evaluating the combined system K times more

often than if each system was solved individually(不太懂什么意思)。不過,在實踐中使用Mini-Batch時,計算量并沒有大幅增加.

唯一性

什么情況下連續(xù)方程有唯一解? 皮卡存在定理限定了, 當(dāng)微分方程Lipschitz連續(xù)并且z在t上連續(xù)時, 初值問題的解存在且唯一. 這就對我們使用的神經(jīng)網(wǎng)絡(luò)有所限制, 模型的權(quán)重有限, 且不能使用非Lipschitz連續(xù)的激活函數(shù), 比如tanh或者relu.

設(shè)置計算精度

模型允許用戶在計算精度和速度之間做trade-off, 需要用戶在訓(xùn)練的前向和反向中設(shè)置誤差限. 對于序列模型, 默認(rèn)值為1.5e-8. 在分類和概率密度擬合問題中, 不降低模型性能的情況下, 默認(rèn)值可設(shè)置為1e-3和1e-5.

重建前向軌跡

如果重建的軌跡偏離了原軌跡,則通過向后運行的方程來重建狀態(tài)軌跡會帶來額外的數(shù)值誤差。這個問題可以通過checkpoint來解決:將z的中間值存儲在前向過程中,并通過從這些點重新積分來重建精確的前向軌跡。不過在實際計算中這不是一個問題,多層CNF的反向可以恢復(fù)到初始狀態(tài).

7 相關(guān)工作

8 結(jié)語

我們探索了黑盒ODE求解器作為模型的一部分, 并用它開發(fā)了新模型可以用于時間序列問題, 監(jiān)督學(xué)習(xí)問題, 概率密度估計問題. 這些模型可以自適應(yīng)的進(jìn)行估值計算, 并且允許用戶顯式的在計算速度和精度之間做取舍. 最終, 我們提出了連續(xù)版本的變量代換模型, 命名為CNF, 該模型的層可以擴展到比較大的尺度.

9 注:

我沒有對附錄和參考文獻(xiàn)做翻譯, 這部分大家請下載論文原文查看: https://arxiv.org/pdf/1806.07366

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容