導(dǎo)入庫(kù)
我們先從導(dǎo)入相關(guān)庫(kù)開(kāi)始。這里我們導(dǎo)入TensorFlow,并將其簡(jiǎn)稱為tf來(lái)方便使用。
然后,我們導(dǎo)入一個(gè)名為numpy的庫(kù),該庫(kù)可以幫助我們輕松地將數(shù)據(jù)表示為列表。
將神經(jīng)網(wǎng)絡(luò)定義為一組順序?qū)拥目蚣芊Q為keras,我們也將其導(dǎo)入。
import tensorflow as tf
import numpy as np
from tensorflow import keras
定義并編譯神經(jīng)網(wǎng)絡(luò)
下面,我們將創(chuàng)建最簡(jiǎn)單的神經(jīng)網(wǎng)絡(luò)。它具有1層,并且該層只有1個(gè)神經(jīng)元,其輸入值也只有一個(gè)。
model = tf.keras.Sequential([keras.layers.Dense(input_shape=[1], units=1)])
現(xiàn)在,我們開(kāi)始編譯神經(jīng)網(wǎng)絡(luò)。在做之前,我們必須指定兩個(gè)函數(shù),一個(gè)是損失函數(shù),一個(gè)是優(yōu)化器。
假設(shè)我們的函數(shù)中,數(shù)字之間的關(guān)系為。當(dāng)計(jì)算機(jī)嘗試“學(xué)習(xí)”時(shí),會(huì)做出猜測(cè),也許認(rèn)為數(shù)字之間滿足
。損失函數(shù)根據(jù)已知的正確答案來(lái)衡量猜測(cè)的答案,并衡量其執(zhí)行的好壞程度。
然后,它使用optimizer函數(shù)進(jìn)行另一個(gè)猜測(cè)。基于損失函數(shù)的運(yùn)行方式,它將嘗試使損失最小化。到那時(shí),也許會(huì)得到類似的結(jié)果,雖然此時(shí)的結(jié)果仍然很糟糕,但是更接近正確的結(jié)果。
它將重復(fù)此操作,最終重復(fù)的數(shù)量由epochs給出。但是首先,我們要告它所要使用的方法,即對(duì)損失使用均方誤差,對(duì)optimizer使用隨機(jī)梯度下降方法。
model.compile(optimizer='sgd', loss='mean_squared_error')
提供數(shù)據(jù)
接下來(lái),我們將輸入一些數(shù)據(jù),取6組輸入與輸出。
x = np.array([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], dtype=float)
y = np.array([-3.0, -1.0, 1.0, 3.0, 5.0, 7.0], dtype=float)
訓(xùn)練神經(jīng)網(wǎng)絡(luò)
訓(xùn)練神經(jīng)網(wǎng)絡(luò)的過(guò)程,是調(diào)用model.fit來(lái)學(xué)習(xí)x與y之間的關(guān)系。
model.fit(x, y, epochs=100)
經(jīng)過(guò)上述步驟之后,我們得到了一個(gè)訓(xùn)練后的神經(jīng)網(wǎng)絡(luò),它已經(jīng)學(xué)習(xí)到了x與y之間隱藏的數(shù)學(xué)關(guān)系。我們使用model.predict方法來(lái)找出未知的x所對(duì)應(yīng)的y。例如,如果x=10,y將會(huì)是什么?根據(jù)我們知道y應(yīng)該是19,那么神經(jīng)網(wǎng)絡(luò)給我們的輸出是多少?
print(model.predict([10.0]))
我們可以看到,預(yù)測(cè)的結(jié)果與我們期望的19非常接近。
code
import tensorflow as tf
import numpy as np
from tensorflow import keras
model = tf.keras.Sequential([keras.layers.Dense(input_shape=[1], units=1)])
model.compile(optimizer='sge', loss='mean_squared_error')
x = np.array([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], dtype=float)
y = np.array([-3.0, -1.0, 1.0, 3.0, 5.0, 7.0], dtype=float)
model.fit(x, y, epochs=100)
print(model.predict([10.0]))