使用tensorflow實現(xiàn)簡單的邏輯回歸

使用tensorflow提供的mnist數(shù)據(jù)集

```

import tensorflowas tf

import? numpyas np

import? matplotlib.pyplotas plt

import input_data

mnist = input_data.read_data_sets('./MNIST_data',one_hot=True)

trainning = mnist.train.images

train_labels = mnist.train.labels

testing = mnist.test.images

test_labels = mnist.test.labels

# print(trainning.shape)

# print(train_labels.shape)

# print(testing.shape)

# print(test_labels.shape)

'''

Extracting ./MNIST_data\train-images-idx3-ubyte.gz

Extracting ./MNIST_data\train-labels-idx1-ubyte.gz

Extracting ./MNIST_data\t10k-images-idx3-ubyte.gz

Extracting ./MNIST_data\t10k-labels-idx1-ubyte.gz

(55000, 784)

(55000, 10)

(10000, 784)

(10000, 10)

'''

#初始化變量x,y

x = tf.placeholder("float32",[None,784])

y = tf.placeholder("float32",[None,10])

#w,b都是為0的矩陣

w = tf.Variable(tf.zeros([784,10]))

b = tf.Variable(tf.zeros([10]))

actv = tf.nn.softmax(tf.matmul(x,w)+b)

cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(actv),reduction_indices=1))

learning_rate =0.01

optm = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

# sess = tf.Session()

# init = tf.global_variables_initializer()

# sess.run(init)

# print(sess.run(cost))

#對比預(yù)測值和真實值的索引是否一樣

pred = tf.equal(tf.argmax(actv,1),tf.argmax(y,1))

accr = tf.reduce_mean(tf.cast(pred,"float32"))

init = tf.global_variables_initializer()

#sess = tf.InteractiveSession()

#迭代次數(shù)

train_epochs =50

#m每次迭代的樣本

batch_size =100

display_step =5

sess= tf.Session()

sess.run(init)

for epochin range(train_epochs):

avg_cost =0

? ? num_batch =int(mnist.train.num_examples/batch_size)

for iin range(num_batch):

batch_x,batch_y =mnist.train.next_batch(batch_size)

sess.run(optm,feed_dict={x:batch_x,y:batch_y})

feeds = {x:batch_x,y:batch_y}

avg_cost +=sess.run(cost,feed_dict=feeds)/num_batch

if (epoch+1) % display_step ==0:

feed_train = {x:batch_x,y:batch_y}

feed_test = {x:mnist.train.images,y:mnist.train.labels}

train_acc = sess.run(accr,feed_dict=feed_train)

test_acc = sess.run(accr,feed_dict=feed_test)

print("Epoch:%03d/%03d Cost:%.9f train_acc:%.3f test_acc :%.3f" % (epoch,train_epochs,avg_cost,train_acc,test_acc))

```

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容