Keras多輸出模型構(gòu)建

1、多輸出模型

使用keras函數(shù)式API構(gòu)建網(wǎng)絡(luò):

# 輸入層
inputs = tf.keras.layers.Input(shape=(64,64,3))

# 卷積層及全連接層等相關(guān)層
x = tf.keras.layers.Dense(256, activation=tf.nn.relu)(inputs)

# 多輸出,定義兩個輸出,指定名字標識
fc_a=tf.keras.layers.Dense(name='fc_a',units=CLASS_NUM,activation=tf.nn.softmax)(x)
fc_b=tf.keras.layers.Dense(name='fc_b',units=CLASS_NUM,activation=tf.nn.softmax)(x)
# 單輸入多輸出
model = tf.keras.Model(inputs=inputs, outputs=[fc_a, fc_b])

# 目標函數(shù)定義,需與輸出層名字對應(yīng)
losses = {'fc_a': 'categorical_crossentropy',
          'fc_b': 'categorical_crossentropy'}

model.compile(optimizer=tf.train.AdamOptimizer(),
                loss=losses,
                metrics=['accuracy'])

2、自定義loss函數(shù)

def loss_a(y_true, y_pred):
    return tf.keras.losses.categorical_crossentropy(y_true, y_pred)

def loss_b(y_true, y_pred):
    return tf.keras.losses.meas_squared_error(y_true, y_pred)

losses = {'fc_a': loss_a,
          'fc_b': loss_b}

model.compile(optimizer=tf.train.AdamOptimizer(),
                loss=losses,
                metrics=['accuracy'])

3、批量訓(xùn)練

# data_generator返回的標簽形式要是與多輸出的數(shù)量對應(yīng)的數(shù)組
def data_generator(sample_num, batch_size):
    while True:
        max_num = sample_num - (sample_num % batch_size)
        for i in range(0, max_num, batch_size):
            ...
            yield (batch_x, [batch_a, batch_b])

model.fit_generator(generator=data_generator(sample_num, batch_size),
                    steps_per_epoch=sample_num//batch_size,
                    epoches=EPOCHES,
                    verbose=1)

4、調(diào)試

在自定義的loss函數(shù)中,是以Sequence的方式來輸入的,如果想調(diào)試查看loss的計算過程中的輸出,直接print是無法打印值的,這是因為tensorflow的每次op都要以sess為基礎(chǔ)來啟動,如果想調(diào)試,可以用eager_execution模式:

import tensorflow.contrib.eager as tfe
tfe.enable_eager_execution()
np.set_printoptions(threshold=np.nan) # 輸出所有元素
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容