TensorFlow Example Study, Part 2: fully_connected_feed.py and mnist.py

Please credit the source when reposting.

I. Introduction
1. Compared with the first example, the code is better organized: specific pieces of functionality are wrapped in functions, and the values likely to change are pulled out as parameters, which makes the program structure clearer. Visualization and checkpoint-saving features are also added.
2. Model

[Figure: 網(wǎng)絡(luò)結(jié)構(gòu).png — network structure]

3. Program flow

[Figure: 程序流程.png — program flow]

4. Visualization

Group related ops into a named scope (so they collapse into one node in the graph view): 'with tf.name_scope('name'):'
Collect data: 'tf.summary.scalar('name', name)' and 'tf.summary.merge_all()'
Run in a terminal: tensorboard --logdir=<log_dir>
where <log_dir> is the directory the logs are written to; for fully_connected_feed.py the full command is
tensorboard --logdir=/tmp/tensorflow/mnist/logs/fully_connected_feed/
Once TensorBoard is running, open http://127.0.1.1:6006 in a browser (or whatever address TensorBoard prints, typically http://localhost:6006) to view the graph and the recorded training data.
[Figure: tensorboard.png — TensorBoard screenshot]
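The pieces fit together roughly as in the sketch below. This is a minimal, standalone example, not part of the tutorial code; the scope name, variable name, and log directory /tmp/demo_logs are hypothetical.
import tensorflow as tf

with tf.name_scope('demo'):                  # groups these nodes in the graph view
  x = tf.Variable(0.0, name='x')
tf.summary.scalar('x_value', x)              # register a scalar to track over time
summary_op = tf.summary.merge_all()          # one op that evaluates every registered summary

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  writer = tf.summary.FileWriter('/tmp/demo_logs', sess.graph)  # also writes the graph
  writer.add_summary(sess.run(summary_op), global_step=0)
  writer.flush()
# 'tensorboard --logdir=/tmp/demo_logs' then shows the graph and the scalar.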

II. Example Code

Source: /tensorflow/tensorflow/examples/tutorials/mnist/mnist.py
1. Import modules
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import tensorflow as tf
2. Model constants, defined up front so the program is easier to read and modify
# The MNIST dataset has 10 classes, representing the digits 0 through 9.
NUM_CLASSES = 10

# The MNIST images are always 28x28 pixels.
IMAGE_SIZE = 28
IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE

3. The neural network graph
   Inputs: the input images, the number of units in hidden layer 1, the number of units in hidden layer 2
   Output: the network output (logits)
def inference(images, hidden1_units, hidden2_units):
  # Hidden 1
  with tf.name_scope('hidden1'):
    weights = tf.Variable(
        tf.truncated_normal([IMAGE_PIXELS, hidden1_units],
                            stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))),
        name='weights')
    biases = tf.Variable(tf.zeros([hidden1_units]),
                         name='biases')
    hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases)
  # Hidden 2
  with tf.name_scope('hidden2'):
    weights = tf.Variable(
        tf.truncated_normal([hidden1_units, hidden2_units],
                            stddev=1.0 / math.sqrt(float(hidden1_units))),
        name='weights')
    biases = tf.Variable(tf.zeros([hidden2_units]),
                         name='biases')
    hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases)
  # Linear
  with tf.name_scope('softmax_linear'):
    weights = tf.Variable(
        tf.truncated_normal([hidden2_units, NUM_CLASSES],
                            stddev=1.0 / math.sqrt(float(hidden2_units))),
        name='weights')
    biases = tf.Variable(tf.zeros([NUM_CLASSES]),
                         name='biases')
    logits = tf.matmul(hidden2, weights) + biases
  return logits
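As a quick orientation, inference() only builds the graph; nothing is computed until a session runs it. A hypothetical usage sketch (not part of mnist.py):
# Build the graph for a batch of 100 flattened images.
images = tf.placeholder(tf.float32, shape=(100, IMAGE_PIXELS))
logits = inference(images, hidden1_units=128, hidden2_units=32)
# logits is a Tensor of shape (100, NUM_CLASSES); it is evaluated later via sess.run().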

4. Computing the loss
  Inputs: logits: the network output, float, [batch_size, NUM_CLASSES]
          labels: the target labels, int32, [batch_size]
  Output: the loss, a float scalar
def loss(logits, labels):
  labels = tf.to_int64(labels)
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=logits, name='xentropy')
  return tf.reduce_mean(cross_entropy, name='xentropy_mean')
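Note that sparse_softmax_cross_entropy_with_logits takes plain integer class indices, not one-hot vectors. A small illustration with made-up numbers (not part of mnist.py):
logits_example = tf.constant([[2.0, 0.5, 0.1],
                              [0.2, 0.3, 3.0]])   # batch of 2, 3 classes
labels_example = tf.constant([0, 2])              # integer class indices
xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=labels_example, logits=logits_example)
with tf.Session() as sess:
  print(sess.run(xent))  # per-example losses; tf.reduce_mean() turns them into the scalar loss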

5. Training
  Inputs: the loss and the learning rate
  Output: the training op
  The optimizer is plain gradient descent.
def training(loss, learning_rate):
  # Add a scalar summary for the snapshot loss.
  tf.summary.scalar('loss', loss)
  # Create the gradient descent optimizer with the given learning rate.
  optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  # Create a variable to track the global step.
  global_step = tf.Variable(0, name='global_step', trainable=False)
  # Use the optimizer to apply the gradients that minimize the loss
  # (and also increment the global step counter) as a single training step.
  train_op = optimizer.minimize(loss, global_step=global_step)
  return train_op
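Each call to sess.run(train_op) applies one gradient-descent update and, because global_step is passed to minimize(), also increments the step counter. A toy sketch with a hypothetical loss (not part of the tutorial):
w = tf.Variable(5.0, name='w')
toy_loss = tf.square(w)                      # minimized at w == 0
train_op = training(toy_loss, learning_rate=0.1)
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for _ in range(3):
    sess.run(train_op)                       # w moves toward 0; global_step is now 3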

6. Evaluating the trained model
  Inputs: logits: the network output, float32, [batch_size, NUM_CLASSES]
          labels: the labels, int32, [batch_size]
  Output: the number of correct predictions
def evaluation(logits, labels):
  correct = tf.nn.in_top_k(logits, labels, 1)
  return tf.reduce_sum(tf.cast(correct, tf.int32))
Source: /tensorflow/tensorflow/examples/tutorials/mnist/fully_connected_feed.py
1. Import modules
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=missing-docstring
import argparse
import os.path
import sys
import time

from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.examples.tutorials.mnist import mnist

# Basic model parameters as external flags.
FLAGS = None

2. Placeholders
Purpose: create placeholders for the images and the labels
Input: batch_size
Outputs: images placeholder (float32), labels placeholder (int32)
def placeholder_inputs(batch_size):
  images_placeholder = tf.placeholder(tf.float32, shape=(batch_size,
                                                         mnist.IMAGE_PIXELS))
  labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size))
  return images_placeholder, labels_placeholder

3. Filling the feed dictionary
Purpose: supply the next batch of data for each training step
Inputs: data_set, the data source returned by input_data.read_data_sets()
     images_pl, the images placeholder from placeholder_inputs()
     labels_pl, the labels placeholder from placeholder_inputs()
Output: the feed dictionary feed_dict.
def fill_feed_dict(data_set, images_pl, labels_pl):
  images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size,
                                                 FLAGS.fake_data)
  feed_dict = {
      images_pl: images_feed,
      labels_pl: labels_feed,
  }
  return feed_dict
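The returned dictionary maps each placeholder to a NumPy array of the matching shape; sess.run() substitutes those arrays wherever the placeholders appear in the graph. A minimal sketch of the feed mechanism (hypothetical values, not part of the tutorial):
x = tf.placeholder(tf.float32, shape=(None,))
y = x * 2.0
with tf.Session() as sess:
  print(sess.run(y, feed_dict={x: [1.0, 2.0, 3.0]}))  # [2. 4. 6.]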

4. Evaluation
Purpose: run a full evaluation every 1000 steps and at the end of training.
Inputs: sess: the Session the model was trained in
    eval_correct: the op that counts correct predictions
    images_placeholder: the images placeholder
    labels_placeholder: the labels placeholder
    data_set: the images and labels, from input_data.read_data_sets()
Output: prints the evaluation results.
def do_eval(sess,
            eval_correct,
            images_placeholder,
            labels_placeholder,
            data_set):
  true_count = 0  # Counts the number of correct predictions.
  steps_per_epoch = data_set.num_examples // FLAGS.batch_size
  num_examples = steps_per_epoch * FLAGS.batch_size
  for step in xrange(steps_per_epoch):
    feed_dict = fill_feed_dict(data_set,
                               images_placeholder,
                               labels_placeholder)
    true_count += sess.run(eval_correct, feed_dict=feed_dict)
  precision = float(true_count) / num_examples
  print('  Num examples: %d  Num correct: %d  Precision @ 1: %0.04f' %
        (num_examples, true_count, precision))

5. The training process
def run_training():
  # Load the data
  data_sets = input_data.read_data_sets(FLAGS.input_data_dir, FLAGS.fake_data)
  # Run everything under the default Graph.
  with tf.Graph().as_default():
    # Build the graph
    images_placeholder, labels_placeholder = placeholder_inputs(
        FLAGS.batch_size)
    logits = mnist.inference(images_placeholder,
                             FLAGS.hidden1,
                             FLAGS.hidden2)
    loss = mnist.loss(logits, labels_placeholder)
    train_op = mnist.training(loss, FLAGS.learning_rate)
    eval_correct = mnist.evaluation(logits, labels_placeholder)

    # Merge all summaries into a single tensor
    summary = tf.summary.merge_all()
    # Create the op that initializes the variables
    init = tf.global_variables_initializer()
    # Create the checkpoint saver
    saver = tf.train.Saver()
    # Create the Session
    sess = tf.Session()

    # Create a SummaryWriter to write the merged summaries (and the graph)
    summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)

    # After everything is built, start running the graph

    # Initialize the variables
    sess.run(init)

    # Start the training loop, max_steps (2000 by default) iterations
    for step in xrange(FLAGS.max_steps):
      start_time = time.time()

      # Get the feed data for this step
      feed_dict = fill_feed_dict(data_sets.train,
                                 images_placeholder,
                                 labels_placeholder)

      # Run one training step; the train_op output is discarded, only the loss value is kept
      _, loss_value = sess.run([train_op, loss],
                               feed_dict=feed_dict)

      duration = time.time() - start_time

      # Every 100 steps, print the current loss and write the summaries
      if step % 100 == 0:
        print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
        summary_str = sess.run(summary, feed_dict=feed_dict)
        summary_writer.add_summary(summary_str, step)
        summary_writer.flush()

      # Every 1000 steps, and at the last step, save a checkpoint and evaluate the model
      if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')
        saver.save(sess, checkpoint_file, global_step=step)
        # Evaluate against the training set.
        print('Training Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.train)
        # Evaluate against the validation set.
        print('Validation Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.validation)
        # Evaluate against the test set.
        print('Test Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.test)


def main(_):
  if tf.gfile.Exists(FLAGS.log_dir):
    tf.gfile.DeleteRecursively(FLAGS.log_dir)
  tf.gfile.MakeDirs(FLAGS.log_dir)
  run_training()


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--learning_rate',
      type=float,
      default=0.01,
      help='Initial learning rate.'
  )
  parser.add_argument(
      '--max_steps',
      type=int,
      default=2000,
      help='Number of steps to run trainer.'
  )
  parser.add_argument(
      '--hidden1',
      type=int,
      default=128,
      help='Number of units in hidden layer 1.'
  )
  parser.add_argument(
      '--hidden2',
      type=int,
      default=32,
      help='Number of units in hidden layer 2.'
  )
  parser.add_argument(
      '--batch_size',
      type=int,
      default=100,
      help='Batch size.  Must divide evenly into the dataset sizes.'
  )
  parser.add_argument(
      '--input_data_dir',
      type=str,
      default='/tmp/tensorflow/mnist/input_data',
      help='Directory to put the input data.'
  )
  parser.add_argument(
      '--log_dir',
      type=str,
      default='/tmp/tensorflow/mnist/logs/fully_connected_feed',
      help='Directory to put the log data.'
  )
  parser.add_argument(
      '--fake_data',
      default=False,
      help='If true, uses fake data for unit testing.',
      action='store_true'
  )

  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
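To actually run the example, train first and then point TensorBoard at the log directory. A hedged sketch of the invocation (assuming the current directory contains fully_connected_feed.py; adjust paths to your setup):
python fully_connected_feed.py --max_steps=2000 --batch_size=100
tensorboard --logdir=/tmp/tensorflow/mnist/logs/fully_connected_feed/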
III. API
1. tf.truncated_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None)
      Notes: 1. Returns random values drawn from a truncated normal distribution; samples farther than 2 standard deviations from the mean are dropped and re-drawn.
            2. 'mean' is the mean of the distribution, default 0.0.
            3. 'stddev' is its standard deviation, default 1.0.
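A quick check of the truncation (hypothetical shape and stddev, for illustration only):
w_init = tf.truncated_normal([784, 128], mean=0.0, stddev=0.1)
with tf.Session() as sess:
  values = sess.run(w_init)
  # Every sampled value lies within 2 standard deviations of the mean, i.e. in [-0.2, 0.2].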
2. tf.nn.in_top_k(predictions, targets, k, name=None)
      Notes: 1. Checks, per example, whether the target class is among the top 'k' predictions. For example, if a sample's target is 5 and class 5 also has the highest predicted probability for that sample, then the target is within the top 1 prediction and the result for that sample is True.
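A small example with made-up probabilities: sample 0's target is in its top-1 prediction, sample 1's is not.
predictions = tf.constant([[0.1, 0.8, 0.1],
                           [0.3, 0.4, 0.3]])
targets = tf.constant([1, 2])
correct = tf.nn.in_top_k(predictions, targets, k=1)
with tf.Session() as sess:
  print(sess.run(correct))  # [ True False]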
IV. Summary
      Not much more to add; this example is mainly further practice with the basic TensorFlow workflow.