Higher Library

Higher is an open-source meta-learning framework from FAIR, aimed primarily at gradient-based meta-learning. Gradient-based meta-learning algorithms typically involve two levels of optimization (bi-level/nested optimization). Taking gradient-based hyperparameter optimization as an example:

  • The first (bottom) level, the inner loop, is training: given hyperparameters \varphi, optimize the model parameters \theta.
  • The second (top) level, the outer loop, is meta-training: optimize the hyperparameters \varphi.

Both optimization sub-steps run at every model-update step. A reader may feel that this two-level optimization looks unremarkable, no different from ordinary model training. That is true as far as it goes, but how the two levels are cascaded determines how well the final algorithm works: the inner loop has to "prepare" something for the outer loop to consume. The paper accompanying the higher library surveys this family of inner-loop/outer-loop cascaded algorithms.
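To make the nested structure concrete, here is a minimal pure-Python sketch of bi-level optimization on a toy scalar problem. Everything in it (the quadratic losses, the constants, the finite-difference outer gradient) is an illustrative assumption; higher instead computes exact hypergradients by differentiating through the unrolled inner loop.

```python
# Toy bi-level optimization in plain Python. All functions and constants
# here are illustrative, not part of the higher library.

def train_loss_grad(theta, phi):
    # d/d theta of the inner objective (theta - 1)^2 + phi * theta^2,
    # where phi is a regularization-weight hyperparameter.
    return 2.0 * (theta - 1.0) + 2.0 * phi * theta

def val_loss(theta):
    # Outer (meta) objective: we want the trained theta close to 0.8.
    return (theta - 0.8) ** 2

def inner_loop(phi, theta=0.0, lr=0.1, steps=50):
    # First level: train theta by SGD with the hyperparameter phi fixed.
    for _ in range(steps):
        theta -= lr * train_loss_grad(theta, phi)
    return theta

def outer_loop(phi=0.0, meta_lr=0.5, meta_steps=100, eps=1e-4):
    # Second level: adjust phi to reduce the validation loss of the
    # inner-loop solution. Finite differences stand in for the exact
    # unrolled hypergradient that higher would compute.
    for _ in range(meta_steps):
        hypergrad = (val_loss(inner_loop(phi + eps))
                     - val_loss(inner_loop(phi - eps))) / (2.0 * eps)
        phi -= meta_lr * hypergrad
    return phi

phi_star = outer_loop()
theta_star = inner_loop(phi_star)
print(phi_star, theta_star)
```

With these losses the inner loop converges to \theta^*(\varphi) = 1/(1+\varphi), so the outer loop drives \varphi toward 0.25, where \theta^* = 0.8.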

Figure 1: The algorithmic framework of the higher library

As shown in Figure 1, the inputs are the model parameters \theta_t and the meta-parameters \varphi_{\tau}; I is the number of meta-parameter updates, and J is the number of inner-loop unroll steps (the number of steps looking ahead). Walking through the first iteration (i = j = 0):

Lines 2-6 describe the virtual update:

Line 2: obtain the current hyperparameters \varphi^{opt}_0, \varphi^{loss}_0.
Line 3: copy the model to get the virtual model \theta_0' = \theta_t, and copy the optimizer to get the virtual optimizer opt'_0 = opt_t.
Line 4: the inner loop.
Line 5: compute the virtual gradient G_0 = \nabla_{\theta_0'} l_{t+0}^{train}(\theta_0', \varphi_0^{loss}), keeping the gradient graph intact (no zero_grad).
Line 6: virtual update \theta_{1}' = opt'_0(\theta_0', \varphi_0^{opt}, G_0).
Line 8: initialize A_0.
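Lines 5-6, the virtual gradient and virtual update, are exactly what makes the outer loop differentiable. A hedged plain-PyTorch sketch of one such step on a hypothetical scalar parameter (the scalar losses here are made up for illustration, not taken from the paper):

```python
import torch

# Hypothetical scalar stand-ins for theta_0' and phi_0^{loss}.
theta = torch.tensor(0.0, requires_grad=True)
phi = torch.tensor(0.5, requires_grad=True)

# Line 5: virtual gradient G_0, with create_graph=True so the gradient
# itself stays on the autograd tape (the "keep the graph" part).
train_loss = phi * (theta - 1.0) ** 2
(g0,) = torch.autograd.grad(train_loss, theta, create_graph=True)

# Line 6: virtual SGD update theta_1' = theta_0' - lr * G_0; theta_1 is a
# new tensor that still depends on both theta and phi.
lr = 0.1
theta_1 = theta - lr * g0

# Because the graph was kept, a meta-loss on theta_1 can be
# differentiated w.r.t. the hyperparameter phi.
meta_loss = (theta_1 - 1.0) ** 2
(dphi,) = torch.autograd.grad(meta_loss, phi)
print(theta_1.item(), dphi.item())
```

The non-zero `dphi` is the hypergradient that an ordinary (in-place, graph-free) `optimizer.step()` would destroy; higher's differentiable optimizers preserve it across many unrolled steps.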

Let's look at how the library is used.

import torch
import higher

model = MyModel()
opt = torch.optim.Adam(model.parameters())

# When you want to branch from the current state of your model and unroll
# optimization, follow this example. This context manager gets a snapshot of the
# current version of the model and optimizer at the point where you want to
# start unrolling and create a functional version `fmodel` which executes the
# forward pass of `model` with implicit fast weights which can be read by doing
# `fmodel.parameters()`, and a differentiable optimizer `diffopt` which ensures
# that at each step, gradient of `fmodel.parameters()` with regard to initial
# fast weights `fmodel.parameters(time=0)` (or any other part of the unrolled
# model history) is defined.

with higher.innerloop_ctx(model, opt) as (fmodel, diffopt):
    for xs, ys in data:
        logits = fmodel(xs)  # modified `params` can also be passed as a kwarg
        loss = loss_function(logits, ys)  # no need to call loss.backward()
        diffopt.step(loss)  # note that `step` must take `loss` as an argument!
        # The line above gets P[t+1] from P[t] and loss[t]. `step` also returns
        # these new parameters, as an alternative to getting them from
        # `fmodel.fast_params` or `fmodel.parameters()` after calling
        # `diffopt.step`.

        # At this point, or at any point in the iteration, you can take the
        # gradient of `fmodel.parameters()` (or equivalently
        # `fmodel.fast_params`) w.r.t. `fmodel.parameters(time=0)` (equivalently
        # `fmodel.init_fast_params`). i.e. `fast_params` will always have
        # `grad_fn` as an attribute, and be part of the gradient tape.

    # At the end of your inner loop you can obtain these e.g. ...
    grad_of_grads = torch.autograd.grad(
        meta_loss_fn(fmodel.parameters()), fmodel.parameters(time=0))
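What to do with `grad_of_grads` is up to the outer loop. A common pattern (e.g. MAML-style meta-learning) is to write it into the `.grad` fields of the real model and take a step with a separate meta-optimizer. A sketch of that plumbing; the model, `meta_opt`, and the meta-gradient (faked with ones, since the inner loop is omitted here) are illustrative assumptions:

```python
import torch

# Hypothetical outer-loop plumbing, not part of higher's API.
model = torch.nn.Linear(2, 1, bias=False)   # stands in for MyModel()
meta_opt = torch.optim.SGD(model.parameters(), lr=0.01)

# In a real run, grad_of_grads would come from torch.autograd.grad as in
# the snippet above; here it is faked with ones to show the mechanics.
grad_of_grads = [torch.ones_like(p) for p in model.parameters()]

before = [p.detach().clone() for p in model.parameters()]
meta_opt.zero_grad()
for p, g in zip(model.parameters(), grad_of_grads):
    p.grad = g.detach()   # write the meta-gradient into .grad
meta_opt.step()           # outer-loop update of the real model weights
```

After `meta_opt.step()`, the real `model.parameters()` have moved against the meta-gradient, completing one outer-loop iteration; the next inner loop then branches from this new state.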