tf.contrib.learn Quickstart
TensorFlow的機器學(xué)習(xí)高級API(tf.contrib.learn)使配置、訓(xùn)練、評估不同的學(xué)習(xí)模型變得更加容易。在這個教程里,你將使用tf.contrib.learn在Iris data set上構(gòu)建一個神經(jīng)網(wǎng)絡(luò)分類器。代碼有一下5個步驟:
- 在TensorFlow數(shù)據(jù)集上加載Iris
- 構(gòu)建神經(jīng)網(wǎng)絡(luò)
- 用訓(xùn)練數(shù)據(jù)擬合
- 評估模型的準(zhǔn)確性
- 在新樣本上分類
Complete Neural Network Source Code
這里是神經(jīng)網(wǎng)絡(luò)的源代碼:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import urllib
import numpy as np
import tensorflow as tf
# Data sets
IRIS_TRAINING = "iris_training.csv"
IRIS_TRAINING_URL = "http://download.tensorflow.org/data/iris_training.csv"
IRIS_TEST = "iris_test.csv"
IRIS_TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"
def main():
# If the training and test sets aren't stored locally, download them.
if not os.path.exists(IRIS_TRAINING):
raw = urllib.urlopen(IRIS_TRAINING_URL).read()
with open(IRIS_TRAINING, "w") as f:
f.write(raw)
if not os.path.exists(IRIS_TEST):
raw = urllib.urlopen(IRIS_TEST_URL).read()
with open(IRIS_TEST, "w") as f:
f.write(raw)
# Load datasets.
training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
filename=IRIS_TRAINING,
target_dtype=np.int,
features_dtype=np.float32)
test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
filename=IRIS_TEST,
target_dtype=np.int,
features_dtype=np.float32)
# Specify that all features have real-value data
feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
# Build 3 layer DNN with 10, 20, 10 units respectively.
classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
hidden_units=[10, 20, 10],
n_classes=3,
model_dir="/tmp/iris_model")
# Define the training inputs
def get_train_inputs():
x = tf.constant(training_set.data)
y = tf.constant(training_set.target)
return x, y
# Fit model.
classifier.fit(input_fn=get_train_inputs, steps=2000)
# Define the test inputs
def get_test_inputs():
x = tf.constant(test_set.data)
y = tf.constant(test_set.target)
return x, y
# Evaluate accuracy.
accuracy_score = classifier.evaluate(input_fn=get_test_inputs,
steps=1)["accuracy"]
print("\nTest Accuracy: {0:f}\n".format(accuracy_score))
# Classify two new flower samples.
def new_samples():
return np.array(
[[6.4, 3.2, 4.5, 1.5],
[5.8, 3.1, 5.0, 1.7]], dtype=np.float32)
predictions = list(classifier.predict(input_fn=new_samples))
print(
"New Samples, Class Predictions: {}\n"
.format(predictions))
if __name__ == "__main__":
main()
Load the Iris CSV data to TensorFlow
Iris data set包含了150行數(shù)據(jù),3個種類:Iris setosa, Iris virginica, and Iris versicolor.
每一行包括了以下的數(shù)據(jù):花萼的寬度,長度,花瓣的寬度,花的種類?;ǖ姆N類有整數(shù)表示,0表示Iris setosa, 1表示Iris virginica, 2表示Iris versicolor.
| Sepal Length | Sepal Width | Petal Length | Petal Width | Species |
|---|---|---|---|---|
| 5.1 | 3.5 | 1.4 | 0.2 | 0 |
| 4.9 | 3.0 | 1.4 | 0.2 | 0 |
| 4.7 | 3.2 | 1.3 | 0.2 | 0 |
| … | … | … | … | … |
| 7.0 | 3.2 | 4.7 | 1.4 | 1 |
| 6.4 | 3.2 | 4.5 | 1.5 | 1 |
| 6.9 | 3.1 | 4.9 | 1.5 | 1 |
| … | … | … | … | … |
| 6.5 | 3.0 | 5.2 | 2.0 | 2 |
| 6.2 | 3.4 | 5.4 | 2.3 | 2 |
| 5.9 | 3.0 | 5.1 | 1.8 | 2 |
這里,Iris數(shù)據(jù)隨機分割成了兩組不同的CSV文件:
- 120個樣本的訓(xùn)練數(shù)據(jù)(iris_training.csv)
- 30個樣本的測試數(shù)據(jù)(iris_test.csv).
開始時,首先引進所有必要的模塊,然后定義下載存儲數(shù)據(jù)集的路徑:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import urllib
import tensorflow as tf
import numpy as np
IRIS_TRAINING = "iris_training.csv"
IRIS_TRAINING_URL = "http://download.tensorflow.org/data/iris_training.csv"
IRIS_TEST = "iris_test.csv"
IRIS_TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"
然后,如果訓(xùn)練和測試集沒有在本地存儲,下載:
if not os.path.exists(IRIS_TRAINING):
raw = urllib.urlopen(IRIS_TRAINING_URL).read()
with open(IRIS_TRAINING,'w') as f:
f.write(raw)
if not os.path.exists(IRIS_TEST):
raw = urllib.urlopen(IRIS_TEST_URL).read()
with open(IRIS_TEST,'w') as f:
f.write(raw)
然后,用learn.datasets.base的load_csv_with_header()方法加載訓(xùn)練集和測試集成Dataset S,load_csv_with_header()包涵一下三個參數(shù):
- filename,CSV文件的路徑
- target_dtype,數(shù)據(jù)集目標(biāo)值的numpy數(shù)據(jù)類型
- features_dtype,數(shù)據(jù)集特征值的numpy數(shù)據(jù)類型
這里,目標(biāo)是花的種類,是0-2的整數(shù),所以數(shù)據(jù)類型是np.int:
# Load datasets.
training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
filename=IRIS_TRAINING,
target_dtype=np.int,
features_dtype=np.float32)
test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
filename=IRIS_TEST,
target_dtype=np.int,
features_dtype=np.float32)
tf.contrib.learn中的Dataset S是tuple,你可以通過data,target來訪問特征值和目標(biāo)值,比如,training_set.data,training_set.target
Construct a Deep Neural Network Classifier
tf.contrib.learn提供了多種預(yù)定義的模型,稱為 Estimator S,你可以用“黑盒子”在你的數(shù)據(jù)上來訓(xùn)練和評估節(jié)點。這里,你講配置深度神經(jīng)網(wǎng)絡(luò)分類器來擬合Iris數(shù)據(jù),你可以用tf.contrib.learn.DNNClassifier作為示例:
# Specify that all features have real-value data
feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
# Build 3 layer DNN with 10, 20, 10 units respectively.
classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
hidden_units=[10, 20, 10],
n_classes=3,
model_dir="/tmp/iris_model")
首先定義特征所在的列,有4個特征,所以dimension設(shè)定為4.
然后,構(gòu)建了DNNClassifier,包含以下參數(shù):
- feature_columns=feature_columns.上面定義的特征的列
- hidden_units=[10, 20, 10]. 三個隱層,分別包含10,20,10個神經(jīng)元
- n_classes=3.三個目標(biāo)
- model_dir=/tmp/iris_model.訓(xùn)練模型時保存的斷點數(shù)據(jù)
Describe the training input pipeline
tf.contrib.learn API使用輸入函數(shù),創(chuàng)建TensorFlow節(jié)點來生成模型數(shù)據(jù)。這里,數(shù)據(jù)比較小,可以放在tf.constant。
# Define the test inputs
def get_train_inputs():
x = tf.constant(training_set.data)
y = tf.constant(training_set.target)
return x, y
Fit the DNNClassifier to the Iris Training Data
配置了DNN分類器,你可以用fit方法來擬合數(shù)據(jù),傳遞get_train_inputs到input_fn參數(shù)中,循環(huán)訓(xùn)練2000次:
# Fit model.
classifier.fit(input_fn=get_train_inputs, steps=2000)
等效于:
classifier.fit(x=training_set.data, y=training_set.target, steps=1000)
classifier.fit(x=training_set.data, y=training_set.target, steps=1000)
如果你想追蹤訓(xùn)練模型,你可以用TensorFlow monitor來執(zhí)行節(jié)點的日志。
“Logging and Monitoring Basics with tf.contrib.learn”
Evaluate Model Accuracy
你已經(jīng)用訓(xùn)練數(shù)據(jù)擬合了模型,現(xiàn)在,你可以用evaluate方法在測試集上評估準(zhǔn)確性。像fit一樣,evaluate也需要一個輸入函數(shù)來構(gòu)建輸入的通道,并返回評估結(jié)果的字典。
# Define the test inputs
def get_test_inputs():
x = tf.constant(test_set.data)
y = tf.constant(test_set.target)
return x, y
# Evaluate accuracy.
accuracy_score = classifier.evaluate(input_fn=get_test_inputs,
steps=1)["accuracy"]
print("\nTest Accuracy: {0:f}\n".format(accuracy_score))
運行整個腳本,打印:
Test Accuracy: 0.966667
Classify New Samples
用predict()方法來分類新的樣本,比如,你有下面的兩個新樣本:
| Sepal Length | Sepal Width | Petal Length | Petal Width |
|---|---|---|---|
| 6.4 | 3.2 | 4.5 | 1.5 |
| 5.8 | 3.1 | 5.0 | 1.7 |
predict方法返回一個generator,可以轉(zhuǎn)換成list
# Classify two new flower samples.
def new_samples():
return np.array(
[[6.4, 3.2, 4.5, 1.5],
[5.8, 3.1, 5.0, 1.7]], dtype=np.float32)
predictions = list(classifier.predict(input_fn=new_samples))
print(
"New Samples, Class Predictions: {}\n"
.format(predictions))
結(jié)果大致如下:
New Samples, Class Predictions: [1 2]