tensorflow--cifar10數(shù)據(jù)集

Cifar10數(shù)據(jù)集有6w張圖片,每張圖片有32行32列像素點的紅綠藍三通道數(shù)據(jù),其中5w張十分類彩色圖片用于訓練,1w張用于測試。
十分類分別是:


cifar10.png

導入數(shù)據(jù)集:

cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

搭建一個一層卷積、兩側(cè)全連接的網(wǎng)絡(luò)來訓練cifar10數(shù)據(jù)集:

import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Conv2D, BatchNormalization, 
    Activation, MaxPool2D, Dropout, Flatten, Dense
import matplotlib.pyplot as plt

import os


# 加載數(shù)據(jù)集
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

#  構(gòu)建神經(jīng)網(wǎng)路
class BaseLine(Model):
    def __init__(self):
        super(BaseLine, self).__init__()
        #  一層卷積(CBAPD)
        self.c1 = Conv2D(filters=6, kernel_size=(5, 5), padding="same")
        self.b1 = BatchNormalization()
        self.a1 = Activation("relu")
        self.p1 = MaxPool2D(pool_size=(2, 2), strides=2, padding="same")
        self.d1 = Dropout(0.2)

        # 兩層全連接
        self.flatten = Flatten()
        self.f1 = Dense(128, activation="relu")
        self.d2 = Dropout(0.2)
        self.f2 = Dense(10, activation="softmax")

    # 完成神經(jīng)網(wǎng)路的前向傳播
    def call(self, x):
        x = self.c1(x)
        x = self.b1(x)
        x = self.a1(x)
        x = self.p1(x)
        x = self.d1(x)
        x = self.flatten(x)
        x = self.f1(x)
        x = self.d2(x)
        y = self.f2(x)
        return y


model = BaseLine()

# 配置訓練方法
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=["sparse_categorical_accuracy"]
)

# 斷點續(xù)訓,讀取模型
checkpoint_save_path = "cifar10/BaseLine.ckpt"
if os.path.exists(checkpoint_save_path + ".index"):
    print("*******load the model******")
    model.load_weights(checkpoint_save_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=True,
    save_best_only=True
)

# 訓練模型
history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test),
                    validation_freq=1, callbacks=[cp_callback])

# 打印網(wǎng)絡(luò)結(jié)構(gòu)和參數(shù)
model.summary()

# 寫入?yún)?shù)
with open("cifar10_weights.txt", "w") as f:
    for v in model.trainable_variables:
        f.write(str(v.name) + "\n")
        f.write(str(v.shape) + "\n")
        f.write(str(v.numpy()) + "\n")


# 顯示訓練和預(yù)測的acc、loss曲線
acc = history.history["sparse_categorical_accuracy"]
val_acc = history.history["val_sparse_categorical_accuracy"]
loss = history.history["loss"]
val_loss = history.history["val_loss"]
plt.subplot(1, 2, 1)
plt.plot(acc, label="train acc")
plt.plot(val_acc, label="validation acc")
plt.title("train & validation acc")
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(loss, label="train loss")
plt.plot(val_loss, label="validation loss")
plt.title("train & validation loss")
plt.legend()
plt.show()

打印結(jié)果:(有省略)

_________________________________________________________________
flatten (Flatten)            multiple                  0         
_________________________________________________________________
dense (Dense)                multiple                  196736    
_________________________________________________________________
dropout_1 (Dropout)          multiple                  0         
_________________________________________________________________
dense_1 (Dense)              multiple                  1290      
=================================================================
Total params: 198,506
Trainable params: 198,494
Non-trainable params: 12

繪圖結(jié)果:


myplot.png
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容