問(wèn)題
深度學(xué)習(xí)煉丹師們大都在面對(duì)某項(xiàng)任務(wù)是都會(huì)在github上搜索SOTA的模型實(shí)現(xiàn),clone下來(lái),嘗試魔改一番以適應(yīng)當(dāng)前任務(wù),評(píng)測(cè)指標(biāo)達(dá)標(biāo)可能就準(zhǔn)備上線(xiàn)了,然后遇到下一個(gè)任務(wù)就再來(lái)一遍。這樣會(huì)遇到兩個(gè)問(wèn)題:
- github的模型實(shí)現(xiàn)往往都是基于貢獻(xiàn)者的喜好選擇自微分框架,不同任務(wù)之間不能共用模塊,例如優(yōu)化器,學(xué)習(xí)率策略等,說(shuō)的玄幻點(diǎn)兒就是沒(méi)有技術(shù)沉淀。
- 多人協(xié)作更是不太可能,只能各自為戰(zhàn),由于沒(méi)有統(tǒng)一的構(gòu)建準(zhǔn)則,有了bug也只能自己硬gang了。
介紹
tensor2tensor為以上兩個(gè)問(wèn)題提供了很好的解決方案。Attention Is All You Need所提出的Transformer的官方實(shí)現(xiàn)就是基于tensor2tensor(T2T)的。T2T將一個(gè)深度學(xué)習(xí)任務(wù)抽象成一個(gè)T2TExperiment,其中包括Problem、T2TModel、hparams。
-
Problem主要負(fù)責(zé)預(yù)處理原始數(shù)據(jù)和輸入、輸出的數(shù)據(jù)格式定義。一方面根據(jù)tf.Example協(xié)議將原始數(shù)據(jù)轉(zhuǎn)寫(xiě)成TFRecord,通過(guò)problem.input_fn為模型train、eval階段提供dataset。另一方面,利用problem.hparams中預(yù)設(shè)的Modality轉(zhuǎn)化模型輸入與輸出和計(jì)算損失值。 -
T2TModel通過(guò)bottom、body、top構(gòu)建模型的核心運(yùn)算,其中bottom和top是Modality轉(zhuǎn)化數(shù)據(jù)階段。 -
hparams主要設(shè)置模型的超參,包括層數(shù)、層寬、優(yōu)化器、學(xué)習(xí)率策略等,common_hparams提供了一個(gè)基本配置。 -
T2TExperiment是對(duì)tf.estimator.Estimator一個(gè)封裝,根據(jù)模型的不同階段(train, eval, predict)通過(guò)T2TModel.make_estimator_model_fn中獲取不同的tf.estimator.EstimatorSpec,所以模型的訓(xùn)練是借助Estimator的train方法。create_run_config設(shè)置了模型訓(xùn)練的參數(shù),包括訓(xùn)練步數(shù),分布式訓(xùn)練策略、早聽(tīng)策略、模型保存策略等。
其中T2T是通過(guò)工廠(chǎng)模式管理Problem、T2TModel和hparams的,自定義的模塊可以借助registry注冊(cè)到相應(yīng)工廠(chǎng)。同樣的方式之后在基于pytorch的fairseq中也得到了應(yīng)用。
優(yōu)點(diǎn)
- 解耦深度學(xué)習(xí)任務(wù),每個(gè)階段只需要關(guān)注特定問(wèn)題。例如針對(duì)一個(gè)新的任務(wù),只需要構(gòu)建相應(yīng)的
Problem,復(fù)用已有的Model就可以直接訓(xùn)練了。 - 通過(guò)繼承
T2T中抽象的很多基類(lèi),例如Problem(Text2TextProblem、Text2ClassProblem),Model(Transformer、Resnet),來(lái)快速構(gòu)建自定義的任務(wù)。 - 自動(dòng)管理模型訓(xùn)練中的可視化監(jiān)控、驗(yàn)證集指標(biāo)、混合精度訓(xùn)練等,加快任務(wù)版本迭代,將有限的精力用在構(gòu)建模型主干上。