MxNet源碼解析(2) symbol

1. 前言

我們?cè)谟?xùn)練之前,先建立好一個(gè)圖,然后我們可以在這個(gè)圖上做我們想做的優(yōu)化,這種形式稱為Symbolic Programs。相對(duì)應(yīng)的是Imperative Programs,也就是每一句代碼都對(duì)應(yīng)著程序的執(zhí)行,在這種情況下,我們可以寫類似于下面的代碼:

a = 2
b= a + 1
d = np.zeros(10)
for i in range(d):
    d += np.zeros(10)

這在symbolic的方式下是做不到的,因?yàn)樵趂or循環(huán)開始時(shí),程序并不知道d的值,也就無(wú)法判斷循環(huán)的次數(shù)。
因此我們可以說(shuō),symbolic更高效,imperative更靈活。

MxNet是一個(gè)異步式的訓(xùn)練框架,它支持上面的兩種形式。我們可以使用NDArray來(lái)進(jìn)行imperative形式的程序編寫,也可以使用symbol來(lái)建立圖。

2. op

先來(lái)了解operator,不了解operator可能就很難理解源碼中占據(jù)了很大一部分的operator的定義。就是通過(guò)這些operator來(lái)將symbol連接成為了一個(gè)圖。

  • OpManager:?jiǎn)卫Y(jié)構(gòu)體,通過(guò)OpManager::Global()總會(huì)返回同一個(gè)結(jié)構(gòu)體。Op的構(gòu)造函數(shù)會(huì)將OpManagerop_counter加一,并且將自己的index_注冊(cè)為當(dāng)前的op_counter
  • add_alias:將別名注冊(cè)到`dmlc::Registry<Op>中
  • Get:根據(jù)name返回Op
  • GetAttrMap

2.1 op

  • name:名字
  • description:該op的描述
  • num_inputs:輸入的個(gè)數(shù)
  • num_outputs:輸出的個(gè)數(shù)
  • get_num_outputs, get_num_inputs:函數(shù),返回輸出,輸入的個(gè)數(shù)
  • attr_parser:函數(shù),用于方便返回該op的參數(shù)
  • Op& Op::describe(const std::string& descr):方法用于將輸入注冊(cè)到description變量中,并返回這個(gè)op,方便接著調(diào)用其他方法。

2.2 幾個(gè)宏

  • #define NNVM_REGISTER_VAR_DEF(OpName):定義OpName
  • #define NNVM_REGISTER_VAR_DEF(TagName):定義TagName
#define NNVM_REGISTER_OP(OpName) \
  DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName, __COUNTER__) = \
    ::dmlc::Register<::nnvm::op>::Get()->__REGISTER_OR_GET(#OpName)

注冊(cè)op,并返回該op

3. Node

Node是組成symbol的基本組件。
結(jié)構(gòu)體NodeEntry包含了:

  • node:指向node的指針
  • index:輸出的索引值
  • version:輸入的version

結(jié)構(gòu)體NodeAttrs包含了:

  • op: 指向operator的指針
  • name: node的名字
  • dict:attributes的字典

Node包含:

  • attrs:結(jié)構(gòu)體NodeAttrs成員,存儲(chǔ)了op, name, attributes等信息。
  • inputs:輸入,是一個(gè)元素為NodeEntry的向量
  • control_deps:保存了應(yīng)該在該node執(zhí)行之前執(zhí)行的node。
  • op():返回該Node的operator,就是返回attrs中保存的op
  • Create():類方法,靜態(tài)方法,用于新建一個(gè)Node,返回指向它的指針
  • num_outputs:如果是變量,輸出為1,否則返回op的輸出

幾個(gè)函數(shù)

定義在文件op_attr_types.h

  • FListinputNames:返回輸入的名字,默認(rèn)return {'data'}
  • FNumVisibleOutputs:用于隱藏一些輸出
  • FListOutputNames:返回輸出的名字
  • FMutateInputs:返回該node會(huì)改變的node的索引值
  • FInferNodeEntryAttr:推理出AttrType
  • FInferShape:推理shape,也就是上面的AttrTypeTshape
  • FInferType:推理類型
  • TIsBackward是否是反向傳播
  • FInplaceOption
  • FGradient:返回node的梯度節(jié)點(diǎn)
  • FSetInputVarAttrOnCompose:為輸入設(shè)置attribute
  • FCorrectLayout:推理layout
  • FInputGraph:返回輸入,解釋為圖而不是數(shù)據(jù)

這些函數(shù)是在定義具體的op時(shí),可以選擇注冊(cè)對(duì)應(yīng)的函數(shù)。

4. Symbol

Symbol是為了使用Node建立Graph。Symbol是我們能夠直接接觸的類,它定義了一系列方法用于更方便地構(gòu)建圖。在symbol的成員outputs中,定義了一組由NodeEntry組成的向量。

  • outputs:該symbol包含的輸出,是一個(gè)元素是NodeEntry的向量
  • Copy:返回一個(gè)深拷貝,方式是通過(guò)遍歷Node,每次訪問(wèn)到的Node保存起來(lái),再建立起node之間的連接,最后將head加入到outputs中。
  • Symbol operator[] (size_t index) const:返回第n個(gè)輸出。
  • ListInputs:返回輸入
  • ListInputNames:返回輸入的名字
  • Compose:組合symbol
  • operator ():調(diào)用compose,來(lái)組合symbol
  • AddControlDeps:加入控制,用于有向圖的構(gòu)建
  • GetInternals:返回一個(gè)symbol,它的輸出是原來(lái)symbol的輸出加上所有中間輸出和輸入
  • GetChildren
  • SetAttrs:設(shè)置attribution
  • GetAttrs
  • CreateFunctor:給定op和attrs,返回一個(gè)symbol
    我認(rèn)為symbol中比較重要的函數(shù)是compose,在調(diào)用的時(shí)候我們是通過(guò)調(diào)用symbol的操作符()函數(shù),也就是operator (),該函數(shù)將參數(shù)傳遞給Compose

5. Graph

Graph就是計(jì)算的時(shí)候使用的圖

  • outputs:和symboloutputs一樣,類型為std::vector<NodeEntry>
  • attrs:定義了圖的一些屬性
  • PostOrderDFSVisit:后序遍歷圖,給定參數(shù)head,進(jìn)行拓?fù)渑判?。算法?strong>貌似,就是拓?fù)渑判蛩惴ā?/li>
  • DFSVisit:調(diào)用PostOrderDFSVisit,對(duì)圖的head進(jìn)行拓?fù)渑判?。參?shù)為:const std::vector<NodeEntry>& heads, FVisit fvisit,其中head是反向傳播時(shí)的頭節(jié)點(diǎn),fvisit是訪問(wèn)時(shí)調(diào)用的函數(shù),該方法將fvisit(*n)作為訪問(wèn)節(jié)點(diǎn)時(shí)的函數(shù),[](GNode n)->Node*{return->get();}作為hash函數(shù),這個(gè)函數(shù)看簽名返回的是一個(gè)指向節(jié)點(diǎn)的指針。圖的節(jié)點(diǎn)入度計(jì)算如下:
[](GNode n)->uint32_t {
  if (!(*n)) return 0;
  return (*n)->input.size() + (*n)->control_deps.size();
}

節(jié)點(diǎn)輸入計(jì)算如下:

[](GNode n, uint32_t index)->GNode {
  if (index < (*n)->input.size()) {
    return &(*n)->input.at(index).node;
  } else {
  return &(*n)->contorl_deps.at(index - (*n)->inputs.size());
}

6. IndexedGraph

IndexedGraphGraph返回,

  • nodes_:成員變量,一個(gè)指向Node結(jié)構(gòu)體的向量,Node定義如下:
struct Node {
  const nnvm::Node* source;
  array_view<NodeEntry> inputs;
  array_view<uint32_t> control_deps;
  std::weak_ptr<nnvm::Node> weak_ref;
};

其中NodeEntry如下:

struct NoodeEntry {
  uint32_t node_id;
  uint32_t index;
  uint32_t version;
};

成員變量:

  • input_nodes_:輸入node的索引
  • mutable_input_nodes_
  • outputs:輸出節(jié)點(diǎn)
  • node2index:node到索引的映射
  • entry_rptr_:
  • input_entries_
  • control_deps_
    方法:
  • DFSVisit
  • PostOrderDFSVisti

7. pass

7.1 gradient.cc

  • Gradientgradient會(huì)根據(jù)屬于的graph,返回一個(gè)帶反向傳播圖的新圖。它主要由executor建立圖的時(shí)候調(diào)用,調(diào)用方式如下:
nnvm::Graph g_grad = nnvm::pass::Gradient(g, 
            symbol.outputs, xs, head_grad_entry_, ArggregateGradient,
            need_mirror, nullptr, zero_ops, "_copy");

調(diào)用該方法會(huì)調(diào)用文件pass_function.h下的Gradient函數(shù)。該函數(shù)將傳入的參數(shù)保存在graph下的attrs中。再通過(guò)applypass調(diào)用Gradient方法。也就是在該文件下定義的方法,簽名:Graph Gradient(Graph src)。

  1. 根據(jù)DFSVisit進(jìn)行拓?fù)渑判?,將序列存?chǔ)到topo_order
  2. 將輸出的梯度保存在output_grads
  3. 根據(jù)mirror_fun在適當(dāng)?shù)牡胤讲迦胄碌墓?jié)點(diǎn),來(lái)實(shí)現(xiàn)內(nèi)存的復(fù)用
  • DefaultAggregateGradient

7.2 plan_memory.cc

7.3 place_device.cc

7.4 correct_layout.cc

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容