導(dǎo)入Keras函數(shù)模型入門
假設(shè)你使用Keras的函數(shù)API開始定義一個簡單的MLP:
from keras.models import Model
from keras.layers import Dense, Input
inputs = Input(shape=(100,))
x = Dense(64, activation='relu')(inputs)
predictions = Dense(10, activation='softmax')(x)
model = Model(inputs=inputs, outputs=predictions)
model.compile(loss='categorical_crossentropy',optimizer='sgd', metrics=['accuracy'])

在Keras,有幾種保存模型的方法。你可以將整個模型(模型定義、權(quán)重和訓(xùn)練配置)存儲為HDF5文件,僅存儲模型配置(作為JSON或YAML文件)或僅存儲權(quán)重(作為HDF5文件)。以下是你如何做每一件事:
model.save('full_model.h5') # save everything in HDF5 format
model_json = model.to_json() # save just the config. replace with "to_yaml" for YAML serialization
with open("model_config.json", "w") as f:
f.write(model_json)
model.save_weights('model_weights.h5') # save just the weights.

如果你決定保存完整的模型,那么你將能夠訪問模型的訓(xùn)練配置,否則你將不訪問。因此,如果你想在導(dǎo)入之后在DL4J中進(jìn)一步訓(xùn)練模型,請記住這一點,并使用model.save(...)來持久化你的模型。
載加你的Keras模型
讓我們從推薦的方法開始,將完整模型加載回DL4J(我們假設(shè)它在類路徑上):
String fullModel = new ClassPathResource("full_model.h5").getFile().getPath();
ComputationGraph model = KerasModelImport.importKerasModelAndWeights(fullModel);

萬一你沒有編譯你的Keras模型,它就不會有一個訓(xùn)練配置。在這種情況下,你需要顯式地告訴模型導(dǎo)入忽略訓(xùn)練配置,方法是將enforceTrainingConfig標(biāo)志設(shè)置為false,如下所示:
ComputationGraph model = KerasModelImport.importKerasModelAndWeights(fullModel, false);

若要僅從JSON加載模型配置,請按如下使用KerasModelImport
String modelJson = new ClassPathResource("model_config.json").getFile().getPath();
ComputationGraphConfiguration modelConfig = KerasModelImport.importKerasModelConfiguration(modelJson)

如果另外你還想加載模型權(quán)重與配置,那么以下是你要做的:
String modelWeights = new ClassPathResource("model_weights.h5").getFile().getPath();
MultiLayerNetwork network = KerasModelImport.importKerasModelAndWeights(modelJson, modelWeights)

在后面兩種情況下,將不讀取訓(xùn)練配置。
KerasModel
從Keras(函數(shù)API)模型或序列模型配置構(gòu)建計算圖。
KerasModel
public KerasModel(KerasModelBuilder modelBuilder)
throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException

(建議)(函數(shù)API)模型的構(gòu)建器模式構(gòu)造器。
- 參數(shù) modelBuilder 構(gòu)建器對象
- 拋出 IOException IO 異常
- 拋出 InvalidKerasConfigurationException 無效的 Keras 配置
- 拋出 UnsupportedKerasConfigurationException 不支持的 Keras 配置
getComputationGraphConfiguration
public ComputationGraphConfiguration getComputationGraphConfiguration()
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException

(不推薦)來自模型配置(JSON或YAML)、訓(xùn)練配置(JSON)、權(quán)重和“訓(xùn)練模式”布爾指示符的(函數(shù) API)模型的構(gòu)造器。當(dāng)內(nèi)置在訓(xùn)練模式時,某些不支持的配置(例如,未知的正則化器)將拋出異常。當(dāng)強(qiáng)制TrainingConfig= false時,這些將生成警告,但將被忽略。
- 參數(shù) modelJson 模型配置JSON 字符串
- 參數(shù) modelYaml 模型配置 YAML 字符串
- 參數(shù) enforceTrainingConfig 是否實施訓(xùn)練相關(guān)配置
- 拋出 IOException IO 異常
- 拋出 InvalidKerasConfigurationException 無效的 Keras 配置
- 拋出 UnsupportedKerasConfigurationException 不支持的 Keras 配置
getComputationGraph
public ComputationGraph getComputationGraph()
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException

從這個Keras模型配置構(gòu)建計算圖并導(dǎo)入權(quán)重。
- 返回 ComputationGraph
getComputationGraph
public ComputationGraph getComputationGraph(boolean importWeights)
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException

從這個Keras模型配置構(gòu)建計算圖并(可選的)導(dǎo)入權(quán)重。
- 參數(shù) importWeights 是否導(dǎo)入權(quán)重
- 返回 ComputationGraph
翻譯:風(fēng)一樣的男子

如果您覺得我的文章給了您幫助,請為我買一杯飲料吧!以下是我的支付寶,意思一下我將非常感激!
