TensorFlow架構與設計:會話生命周期

TensorFlow的系統(tǒng)結構以C API為界,將整個系統(tǒng)分為「前端」和「后端」兩個子系統(tǒng):

  • 前端系統(tǒng):提供編程模型,負責構造計算圖;
  • 后端系統(tǒng):提供運行時環(huán)境,負責執(zhí)行計算圖。
系統(tǒng)架構

前端系統(tǒng)主要扮演Client的角色,主要負責計算圖的構造,并管理Session生命周期過程。

前端系統(tǒng)是一個支持多語言的編程環(huán)境,并提供統(tǒng)一的編程模型支撐用戶構造計算圖。Client通過Session,連接TensorFlow后端的「運行時」,啟動計算圖的執(zhí)行過程。

后端系統(tǒng)是TensorFlow的運行時系統(tǒng),主要負責計算圖的執(zhí)行過程,包括計算圖的剪枝,設備分配,子圖計算等過程。

本文首先以Session創(chuàng)建為例,揭示前端Python與后端C/C++系統(tǒng)實現(xiàn)的通道,闡述TensorFlow多語言編程的奧秘。隨后,以Python前端,C API橋梁,C++后端為生命線,闡述Session的生命周期過程。

Swig: 幕后英雄

前端多語言編程環(huán)境與后端C/C++實現(xiàn)系統(tǒng)的通道歸功于Swig的包裝器。TensorFlow使用Bazel的構建工具,在編譯之前啟動Swig的代碼生成過程,通過tf_session.i自動生成了兩個適配(Wrapper)文件:

  • pywrap_tensorflow.py: 負責對接上層Python調用;
  • pywrap_tensorflow.cpp: 負責對接下層C實現(xiàn)。

此外,pywrap_tensorflow.py模塊首次被加載時,自動地加載_pywrap_tensorflow.so的動態(tài)鏈接庫。從而實現(xiàn)了pywrap_tensorflow.pypywrap_tensorflow.cpp的函數(shù)調用關系。

pywrap_tensorflow.cpp的實現(xiàn)中,靜態(tài)注冊了一個函數(shù)符號表。在運行時,按照Python的函數(shù)名稱,匹配找到對應的C函數(shù)實現(xiàn),最終轉調到c_api.c的具體實現(xiàn)。

Swig代碼生成器

編程接口:Python

當Client要啟動計算圖的執(zhí)行過程時,先創(chuàng)建了一個Session實例,進而調用父類BaseSession的構造函數(shù)。

# tensorflow/python/client/session.py
class Session(BaseSession):
  def __init__(self, target='', graph=None, config=None):
    super(Session, self).__init__(target, graph, config=config)
    # ignoring others

BaseSession的構造函數(shù)中,將調用pywrap_tensorflow模塊中的函數(shù)。其中,pywrap_tensorflow模塊自動由Swig生成。

# tensorflow/python/client/session.py
from tensorflow.python import pywrap_tensorflow as tf_session

class BaseSession(SessionInterface):
  def __init__(self, target='', graph=None, config=None):
    self._session = None
    opts = tf_session.TF_NewSessionOptions(target=self._target, config=config)
    try:
      with errors.raise_exception_on_not_ok_status() as status:
        self._session = tf_session.TF_NewDeprecatedSession(opts, status)
    finally:
      tf_session.TF_DeleteSessionOptions(opts)
    # ignoring others

生成代碼:Swig

pywrap_tensorflow.py

pywrap_tensorflow模塊中,通過_pywrap_tensorflow將在_pywrap_tensorflow.so中調用對應的C++函數(shù)實現(xiàn)。

# tensorflow/bazel-bin/tensorflow/python/pywrap_tensorflow.py
def TF_NewDeprecatedSession(arg1, status):
    return _pywrap_tensorflow.TF_NewDeprecatedSession(arg1, status)

pywrap_tensorflow.cpp

pywrap_tensorflow.cpp的具體實現(xiàn)中,它靜態(tài)注冊了函數(shù)調用的符號表,實現(xiàn)Python的函數(shù)名稱到C++實現(xiàn)函數(shù)的具體映射。

# tensorflow/bazel-bin/tensorflow/python/pywrap_tensorflow.cpp
static PyMethodDef SwigMethods[] = {
    ...
     {"TF_NewDeprecatedSession", _wrap_TF_NewDeprecatedSession, METH_VARARGS, NULL},
}

PyObject *_wrap_TF_NewDeprecatedSession(
  PyObject *self, PyObject *args) {
  TF_SessionOptions* arg1 = ... 
  TF_Status* arg2 = ...
  
  TF_DeprecatedSession* result = TF_NewDeprecatedSession(arg1, arg2);
  // ignoring others implements
}

最終,自動生成的pywrap_tensorflow.cpp僅僅負責函數(shù)調用的轉發(fā),最終將調用底層C系統(tǒng)向上提供的API接口。

C API:橋梁

c_api.h是TensorFlow的后端執(zhí)行系統(tǒng)面向前端開放的公共API接口之一,自此將進入TensorFlow后端系統(tǒng)的浩瀚天空。

// tensorflow/c/c_api.c
TF_DeprecatedSession* TF_NewDeprecatedSession(
  const TF_SessionOptions*, TF_Status* status) {
  Session* session;
  status->status = NewSession(opt->options, &session);
  if (status->status.ok()) {
    return new TF_DeprecatedSession({session});
  } else {
    return NULL;
  }
}

后端系統(tǒng):C++

NewSession將根據(jù)前端傳遞的Session.target,使用SessionFactory多態(tài)創(chuàng)建不同類型的Session(C++)對象。

Status NewSession(const SessionOptions& options, Session** out_session) {
  SessionFactory* factory;
  Status s = SessionFactory::GetFactory(options, &factory);
  if (!s.ok()) {
    *out_session = nullptr;
    LOG(ERROR) << s;
    return s;
  }
  *out_session = factory->NewSession(options);
  if (!*out_session) {
    return errors::Internal("Failed to create session.");
  }
  return Status::OK();
}

會話生命周期

下文以前端Python,橋梁C API,后端C++為生命線,理順三者之間的調用關系,闡述Session的生命周期過程。

在Python前端,Session的生命周期主要體現(xiàn)在:

  • 創(chuàng)建Session(target)
  • 迭代執(zhí)行Session.run(fetches, feed_dict)
    • Session._extend_graph(graph)
    • Session.TF_Run(feeds, fetches, targets)
  • 關閉Session
  • 銷毀Session
sess = Session(target)
for _ in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
sess.close()

相應地,C++后端,Session的生命周期主要體現(xiàn)在:

  • 根據(jù)target多態(tài)創(chuàng)建Session
  • Session.Create(graph):有且僅有一次
  • Session.Extend(graph):零次或多次
  • 迭代執(zhí)行Session.Run(inputs, outputs, targets)
  • 關閉Session.Close
  • 銷毀Session對象
// create/load graph ...
tensorflow::GraphDef graph;

// local runtime, target is ""
tensorflow::SessionOptions options;

// create Session
std::unique_ptr<tensorflow::Session> 
sess(tensorflow::NewSession(options));

// create graph at initialization.
tensorflow::Status s = sess->Create(graph);
if (!s.ok()) { ... }

// run step
std::vector<tensorflow::Tensor> outputs;
s = session->Run(
  {},               // inputs is empty 
  {"output:0"},     // outputs names
  {"update_state"}, // target names
  &outputs);        // output tensors
if (!s.ok()) { ... }

// close
session->Close();

創(chuàng)建會話

上文介紹了Session創(chuàng)建的詳細過程,從Python前端為起點,通過Swig自動生成的Python-C++的包裝器為媒介,實現(xiàn)了Python到TensorFlow的C API的調用。

其中,C API是前端系統(tǒng)與后端系統(tǒng)的分水嶺。后端C++系統(tǒng)根據(jù)前端傳遞的Session.target,使用SessionFactory多態(tài)創(chuàng)建Session(C++)對象。

創(chuàng)建會話

后端C++系統(tǒng)中,Session的創(chuàng)建使用了抽象工廠方法,DirectionSession將啟動本地運行模式,GrpcSession將啟動基于RPC的分布式運行模式。

從嚴格的角色意義上劃分,GrpcSession依然扮演了Client的角色。它使用target,通過RPC協(xié)議與Master建立通信連接,因此,GrpcSession同時扮演了RPC Client的角色。

Session多態(tài)創(chuàng)建

創(chuàng)建/擴展圖

隨后,Python前端將調用Session.run接口,將構造好的計算圖,以GraphDef的形式發(fā)送給C++后端。

其中,前端每次調用Session.run接口時,都會試圖將新增節(jié)點的計算圖發(fā)送給后端系統(tǒng),以便后端系統(tǒng)將新增節(jié)點的計算圖Extend到原來的計算圖中。特殊地,在首次調用Session.run時,將發(fā)送整個計算圖給后端系統(tǒng)。

后端系統(tǒng)首次調用Session.Extend時,轉調(或等價)Session.Create;以后,后端系統(tǒng)每次調用Session.Extend時將真正執(zhí)行Extend的語義,將新增的計算圖的節(jié)點追加至原來的計算圖中。

隨后,后端將啟動計算圖執(zhí)行的準備工作。

創(chuàng)建/擴展圖

迭代運行

接著,Python前端Session.run實現(xiàn)將Feed, Fetch列表準備好,傳遞給后端系統(tǒng)。后端系統(tǒng)調用Session.Run接口。

后端系統(tǒng)的一次Session.Run執(zhí)行常常被稱為一次Step,Step的執(zhí)行過程是TensorFlow運行時的核心。

每次Step,計算圖將正向計算網絡的輸出,反向傳遞梯度,并完成一次訓練參數(shù)的更新。首先,后端系統(tǒng)根據(jù)Feed, Fetch,對計算圖(常稱為Full Graph)進行剪枝,得到一個最小依賴的計算子圖(常稱為Client Graph)。

然后,運行時啟動設備分配算法,如果節(jié)點之間的邊橫跨設備,則將該邊分裂,插入相應的SendRecv節(jié)點,實現(xiàn)跨設備節(jié)點的通信機制。

隨后,將分裂出來的子圖片段(常稱為Partition Graph)注冊到相應的設備上,并在本地設備上啟動子圖片段的執(zhí)行過程。

Run Step

關閉會話

當計算圖執(zhí)行完畢后,需要關閉Session,以便釋放后端的系統(tǒng)資源,包括隊列,IO等。會話關閉流程較為簡單,如下圖所示。

關閉會話

銷毀會話

最后,會話關閉之后,Python前端系統(tǒng)啟動GC,當Session.__del__被調用后,啟動后臺C++的Session對象銷毀過程。

銷毀會話

開源技術書

https://github.com/horance-liu/tensorflow-internals
最后編輯于
?著作權歸作者所有,轉載或內容合作請聯(lián)系作者
【社區(qū)內容提示】社區(qū)部分內容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發(fā)布,文章內容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

相關閱讀更多精彩內容

友情鏈接更多精彩內容