tensorflow深度學(xué)習(xí)之生成數(shù)據(jù)集(二)

import os

import numpy as np

import tensorflow as tf

import input_data

import model


N_CLASSES = 2 ?#數(shù)據(jù)集分兩類

IMG_W = 64? # 圖片的高

IMG_H = 64?# 圖片的寬

BATCH_SIZE = 16

CAPACITY = 1000

MAX_STEP = 10000 # 學(xué)習(xí)的步長

learning_rate = 0.0001 # 學(xué)習(xí)率


def run_training():


? ? # you need to change the directories to yours.

? ? train_dir = '/Users/Desktop/cd/cd/Far_1/' ?#主要說下這個(gè)文件夾里邊的圖片 分成兩類 一類是帶image的圖片名稱, 一類是不帶。。 ?圖片的名稱叫什么都行,學(xué)習(xí)特征兩類,多類,都可以,需要自行修改代碼。我是參考識(shí)別貓和狗的代碼。。

? ? logs_train_dir = '/Users/Desktop/cd/cd/logs' #生成的日志文件,數(shù)據(jù)集和tensorflow學(xué)習(xí)的效率,可以使用?tensorbord進(jìn)行查看


? ? train, train_label = input_data.get_files(train_dir)


? ? train_batch, train_label_batch = input_data.get_batch(train,

? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? train_label,

? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? IMG_W,

? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? IMG_H,

? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? BATCH_SIZE,

? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? CAPACITY)? ? ?

? ? train_logits = model.inference(train_batch, BATCH_SIZE, N_CLASSES)

? ? train_loss = model.losses(train_logits, train_label_batch)? ? ? ?

? ? train_op = model.trainning(train_loss, learning_rate)

? ? train__acc = model.evaluation(train_logits, train_label_batch)


? ? summary_op = tf.summary.merge_all()

? ? sess = tf.Session()

? ? train_writer = tf.summary.FileWriter(logs_train_dir, sess.graph)

? ? saver = tf.train.Saver()


? ? sess.run(tf.global_variables_initializer())

? ? coord = tf.train.Coordinator()

? ? threads = tf.train.start_queue_runners(sess=sess, coord=coord)


? ? try:

? ? ? ? for step in np.arange(MAX_STEP):

? ? ? ? ? ? if coord.should_stop():

? ? ? ? ? ? ? ? ? ? break

? ? ? ? ? ? _, tra_loss, tra_acc = sess.run([train_op, train_loss, train__acc])


? ? ? ? ? ? if step % 50 == 0:

? ? ? ? ? ? ? ? print('Step %d, train loss = %.2f, train accuracy = %.2f%%' %(step, tra_loss, tra_acc*100.0))

? ? ? ? ? ? ? ? summary_str = sess.run(summary_op)

? ? ? ? ? ? ? ? train_writer.add_summary(summary_str, step)


? ? ? ? ? ? if step % 2000 == 0 or (step + 1) == MAX_STEP:

? ? ? ? ? ? ? ? checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt')

? ? ? ? ? ? ? ? saver.save(sess, checkpoint_path, global_step=step)


? ? except tf.errors.OutOfRangeError:

? ? ? ? print('Done training -- epoch limit reached')

? ? finally:

? ? ? ? coord.request_stop()


? ? coord.join(threads)

? ? sess.close()

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容