經常在程序中看到有tf.app.flags和tf.app.run字樣的代碼,這兩段代碼究竟是什么作用,就讓我們從分析源碼的角度來加深理解!
tf.app.flags
tf.app.flags主要用于處理命令行參數的解析工作,其實可以理解為一個封裝好了的argparse包(argparse是一種結構化的數據存儲格式,類似于Json、XML)?,F在我們就從源碼來分析它究竟是怎么解析命令行參數的,應該怎么使用它!
源碼如下:
不出意外的首先導入了argparse包

使用argparse的第一步就是創(chuàng)建一個解析器對象,告訴它將會有些什么參數。當程序運行時,該解析器可以用于處理命令行參數(只能解析參數、獲取參數、設置已有參數的默認值等操作)。argparse中的解析器類是ArgumentParser。

定義了_FlagValues類,如我們前面所說,要處理命令行參數,就要用解析器類_global_parser里的方法來解析,這里使用了parse_known_args()這個函數,其實同parse_args函數差不多(注:這里說的parse_args()函數和此處_FlagValues類中定義的_parse_args()函數不一樣,前者也是argparse中一種解析參數的函數),只是這個函數在接受到多余的命令行參數時不會報錯,會原封不動的以一個list形式將其返回。所以此函數返回的”result“是參數解析完的數據,而”unparsed“是那些未被解析的參數list。將命令行傳入的命令和數據解析出來以字典的形式放到__dict__的['_flags']這個字典中,這么做也是為了方便我們后續(xù)直接訪問命令行輸入的命令,因為可以直接通過字典調用(在tensorflow中其實是通過tf.app.flag.Flags來實現實例化這個類,然后再調用里面解析得到的參數即可)。

初始化完了之后,可以看到源碼里是一些setattr/getattr的方法,也就是一些設置和獲得解析的命令行參數的方法。要注意的是,在獲得參數的時候(getattr),首先要通過解析字典中的’parsed’來檢驗參數是否已經被解析過,因為在_parse_flags方法中,只要解析過參數(也即是運行過該函數),那么self.__dict__[‘__parsed’]就會為True(表明解析過參數)。因為這里是獲取參數,所以出了要判斷參數是否在字典里的基本要求外,還要判斷有沒有解析過參數,沒有就運行_parse_flags解析參數。其它就比較簡單,這里就不介紹了!

后面將上述整個_FlagValues類的實例化,這樣就方便了我們的訪問操作。因為我們要訪問命令行輸入的命令時,就可以直接從這個實例里操作。

注意從這里開始都是在類外定義的方法,所以要調用就只能通過tf.app.flags.XXX來實現了。
下面的_define_helper函數中調用了_global_parser.add_argument完成對命令行參數的添加(傳入flag_name,default_value,docstring,flagtype參數),可以看到添加參數使用的是解析器類_global_parser的方法。仔細看這個函數的參數,第一個參數是‘--’+flag_name這個表示我們定義的命令行參數使用時必須以‘--’開頭,比如--flag_int9(具體看后面例子),而第二個參數default_value是參數的默認值,第三個參數docstring保存幫助信息(命令行中輸入 -h激活該參數),第四個參數表示限定了賦予命令行參數數據的類型。

上面我們已經看到了使用_define_helper參數即可以添加命令行參數,這里源碼中又將其封裝為針對string/int/float/bool類型參數的特定添加方法**。
看DEFINE_string(),這里則由于_define_helper()最后一個type參數是str,上面我們關于_define_helper參數的解釋,說明DEFINE_string()限定了可選參數輸入必須是string,這也就是為什么這個函數定義為DEFINE_string(),同理,DEFINE_interger()限定可選參數必須是int,DEFINE_float()限定可選參數必須是float,DEFINE_boolean()限定可選參數必須是bool。




源碼中最后介紹的方法是在程序運行前先將某些命令行參數加入到”必備參數“(__required_flags)的字典中,以判斷解析完的參數是否滿足這些必備要求!因為mark_flags_as_required方法會調用mark_flag_as_required方法,來將當前傳入的參數加入到__required_flags字典中(_add_required_flag方法),在最上面解析參數的方法_parse_flags中,解析完參數會通過_assert_all_required方法判斷解析到的參數是否都在_required_flags字典中。

講了這么多,具體在tensorflow中我們該怎么使用呢?
首先我們通過tf.app.flags來調用這個flags.py文件,這樣我們就可以用flags.DEFINE_interger/float()來添加命令行參數,而FLAGS=flags.FLAGS可以實例化這個解析參數的類從對應的命令行參數取出參數。
新建test.py文件,并輸入如下代碼,代碼的功能是創(chuàng)建幾個命令行參數,然后把命令行參數輸出顯示
import tensorflow as tf
flags = tf.app.flags
flags.DEFINE_string('data_dir', '/tmp/mnist', 'Directory with the MNIST data.')
flags.DEFINE_integer('batch_size', 5, 'Batch size.')
flags.DEFINE_integer('num_evals', 1000, 'Number of batches to evaluate.')
FLAGS = flags.FLAGS
print(FLAGS.data_dir, FLAGS.batch_size, FLAGS.num_evals)
- 在命令行中輸入
test.py -h就可以查看幫助信息,也就是Directory with the MNIST data.,Batch size和Number of batches to evaluate這樣的消息。 - 在命令行中輸入
test.py --batchsize 10就可以將batch_size的值修改為10!
tf.app.run()
該函數一般都是出現在這種代碼中:
if __name__ == '__main__':
tf.app.run()
上述第一行代碼表示如果當前是從其它模塊調用的該模塊程序,則不會運行main函數!而如果就是直接運行的該模塊程序,則會運行main函數。
具體第二行的功能從源碼開始分析,源碼如下:

flags_passthrough=f._parse_flags(args=args)這里的parse_flags就是我們tf.app.flags源碼中用來解析命令行參數的函數。所以這一行就是解析參數的功能;
下面兩行代碼也就是tf.app.run的核心意思:執(zhí)行程序中main函數,并解析命令行參數!