TensorFlow Learning (3): LeNet
Author: Joyner
Training LeNet on the MNIST dataset
1. Dataset
192.168.9.5:/DATACENTER1/zhiwen.wang/tensorflow-wzw/MNIST_data
t10k-images-idx3-ubyte.gz
t10k-labels-idx1-ubyte.gz
train-images-idx3-ubyte.gz
train-labels-idx1-ubyte.gz
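read_data_sets in pre_data.py parses these four files for you. If you want to inspect one directly, the minimal sketch below reads the IDX format they use: a big-endian header (two zero bytes, a type code, the number of dimensions, then one 32-bit size per dimension) followed by raw bytes. It assumes unsigned-byte data (type code 0x08), which is what the MNIST files contain; parse_idx and load_idx_gz are hypothetical helpers, not part of the downloaded repo.

```python
import gzip
import struct

import numpy as np


def parse_idx(data: bytes) -> np.ndarray:
    """Parse an uncompressed IDX buffer into a uint8 NumPy array (hypothetical helper)."""
    # Header: 2 zero bytes, 1 byte type code, 1 byte number of dimensions.
    zero, dtype_code, ndim = struct.unpack(">HBB", data[:4])
    if zero != 0 or dtype_code != 0x08:
        raise ValueError("expected an unsigned-byte IDX file")
    # Each dimension size is a big-endian unsigned 32-bit integer.
    shape = struct.unpack(">" + "I" * ndim, data[4:4 + 4 * ndim])
    offset = 4 + 4 * ndim
    # The remaining bytes are the pixel (or label) values in row-major order.
    return np.frombuffer(data, dtype=np.uint8, offset=offset).reshape(shape)


def load_idx_gz(path: str) -> np.ndarray:
    """Decompress one of the *-ubyte.gz files and parse it (hypothetical helper)."""
    with gzip.open(path, "rb") as f:
        return parse_idx(f.read())
```

For example, load_idx_gz("/DATACENTER1/zhiwen.wang/tensorflow-wzw/MNIST_data/train-images-idx3-ubyte.gz") should give a (60000, 28, 28) array, and the t10k-images file a (10000, 28, 28) one. Note that read_data_sets then splits the 60000 training images into training and validation sets.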
2. Code download
https://github.com/sujaybabruwad/LeNet-in-Tensorflow
3. Modify the dataset path in pre_data.py
# tensorflow.examples.tutorials only exists in TensorFlow 1.x
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np

def pre_data():
    # Read MNIST from the local directory; reshape=False keeps each image as
    # a 28x28x1 array instead of flattening it to a 784-element vector
    mnist = input_data.read_data_sets("/DATACENTER1/zhiwen.wang/tensorflow-wzw/MNIST_data", reshape=False)
    X_train, y_train = mnist.train.images, mnist.train.labels
    X_validation, y_validation = mnist.validation.images, mnist.validation.labels
    X_test, y_test = mnist.test.images, mnist.test.labels
    # Sanity check: every image in each split has exactly one label
    assert len(X_train) == len(y_train)
    assert len(X_validation) == len(y_validation)
    assert len(X_test) == len(y_test)
    print("Image Shape: {}".format(X_train[0].shape))
    print("Training Set: {} samples".format(len(X_train)))
    print("Validation Set: {} samples".format(len(X_validation)))
    print("Test Set: {} samples".format(len(X_test)))
    # Pad images with 0s: 2 pixels on each side of height and width, 28x28 -> 32x32
    X_train = np.pad(X_train, ((0,0),(2,2),(2,2),(0,0)), 'constant')
    X_validation = np.pad(X_validation, ((0,0),(2,2),(2,2),(0,0)), 'constant')
    X_test = np.pad(X_test, ((0,0),(2,2),(2,2),(0,0)), 'constant')
    return X_train, y_train, X_validation, y_validation, X_test, y_test
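The padding step is easy to check in isolation. The sketch below uses a stand-in batch of ones shaped like the reshape=False output, (N, 28, 28, 1), rather than real MNIST data; pad_to_32 is a hypothetical name for the same np.pad call used in pre_data(). The four (before, after) pairs correspond to the batch, height, width and channel axes, so only height and width grow, from 28 to 32, the input size of the original LeNet-5.

```python
import numpy as np


def pad_to_32(images: np.ndarray) -> np.ndarray:
    """Zero-pad a (N, 28, 28, C) batch to (N, 32, 32, C), like pre_data() does."""
    # One (before, after) pair per axis: batch, height, width, channels.
    return np.pad(images, ((0, 0), (2, 2), (2, 2), (0, 0)), "constant")


# Stand-in batch (not real MNIST data) with the reshape=False layout.
batch = np.ones((4, 28, 28, 1), dtype=np.float32)
padded = pad_to_32(batch)
print(padded.shape)                                # (4, 32, 32, 1)
print((padded[:, 2:30, 2:30] == batch).all())      # True: original pixels untouched
```

Padding with zeros rather than resizing keeps the digit pixels unchanged; the border simply gives the first 5x5 convolution room so that its output is 28x28.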
4. Training (run in a Python 3 environment; CUDA_VISIBLE_DEVICES=1 makes only GPU 1 visible to the job)
cd /DATACENTER1/zhiwen.wang/tensorflow-wzw/Lenet-5-tensorflow/src
CUDA_VISIBLE_DEVICES=1 python3 main/train_and_evaluate.py
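The script trains a LeNet-5 style network. As a rough sketch of how the 32x32 input shrinks layer by layer, the code below traces feature-map sizes, assuming the classic layout (5x5 convolutions with 6 and 16 filters, no padding, 2x2 pooling with stride 2, then fully connected layers of 120, 84 and 10 units). These layer sizes are an assumption; check them against the model definition in the repo you actually run. conv_out and lenet5_shapes are hypothetical helpers.

```python
def conv_out(size: int, kernel: int, stride: int = 1) -> int:
    """Side length after an unpadded ("VALID") convolution or pooling window."""
    return (size - kernel) // stride + 1


def lenet5_shapes(side: int = 32) -> list:
    """Trace output shapes through an assumed classic LeNet-5 layout."""
    shapes = []
    side = conv_out(side, 5)          # conv1: 5x5 kernel, 6 filters
    shapes.append(("conv1", side, side, 6))
    side = conv_out(side, 2, 2)       # pool1: 2x2 window, stride 2
    shapes.append(("pool1", side, side, 6))
    side = conv_out(side, 5)          # conv2: 5x5 kernel, 16 filters
    shapes.append(("conv2", side, side, 16))
    side = conv_out(side, 2, 2)       # pool2: 2x2 window, stride 2
    shapes.append(("pool2", side, side, 16))
    shapes.append(("flatten", side * side * 16))
    # Fully connected layers; the last one outputs one logit per digit class.
    shapes.extend([("fc1", 120), ("fc2", 84), ("logits", 10)])
    return shapes


for layer in lenet5_shapes():
    print(layer)
```

With a 32x32 input the maps go 28, 14, 10, 5, so the flattened vector has 5 * 5 * 16 = 400 values; feeding unpadded 28x28 images would instead give 4 * 4 * 16 = 256, which would not match a first fully connected layer built for 400 inputs.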
5. Training results
