Java調(diào)用Keras、Tensorflow模型

實現(xiàn)python離線訓(xùn)練模型,Java在線預(yù)測部署。查看原文

目前深度學(xué)習(xí)主流使用python訓(xùn)練自己的模型,有非常多的框架提供了能快速搭建神經(jīng)網(wǎng)絡(luò)的功能,其中Keras提供了high-level的語法,底層可以使用tensorflow或者theano。

但是有很多公司后臺應(yīng)用是用Java開發(fā)的,如果用python提供HTTP接口,對業(yè)務(wù)延遲要求比較高的話,仍然會有一定得延遲,所以能不能使用Java調(diào)用模型,python可以離線的訓(xùn)練模型?(tensorflow也提供了成熟的部署方案TensorFlow Serving

手頭上有一個用Keras訓(xùn)練的模型,網(wǎng)上關(guān)于Java調(diào)用Keras模型的資料不是很多,而且大部分是重復(fù)的,并且也沒有講的很詳細(xì)。大致有兩種方案,一種是基于Java的深度學(xué)習(xí)庫導(dǎo)入Keras模型實現(xiàn),另外一種是用tensorflow提供的Java接口調(diào)用。

Deeplearning4J

Eclipse Deeplearning4j is the first commercial-grade, open-source, distributed deep-learning library written for Java and Scala. Integrated with Hadoop and Spark, DL4J brings AIAI to business environments for use on distributed GPUs and CPUs.

Deeplearning4j目前支持導(dǎo)入Keras訓(xùn)練的模型,并且提供了類似python中numpy的一些功能,更方便地處理結(jié)構(gòu)化的數(shù)據(jù)。遺憾的是,Deeplearning4j現(xiàn)在只覆蓋了Keras <2.0版本的大部分Layer,如果你是用Keras 2.0以上的版本,在導(dǎo)入模型的時候可能會報錯。

了解更多:
Keras Model Import: Supported Features
Importing Models From Keras to Deeplearning4j

Tensorflow

文檔,Java的文檔很少,不過調(diào)用模型的過程也很簡單。采用這種方式調(diào)用模型需要先將Keras導(dǎo)出的模型轉(zhuǎn)成tensorflow的protobuf協(xié)議的模型。

1、Keras的h5模型轉(zhuǎn)為pb模型

在Keras中使用model.save(model.h5)保存當(dāng)前模型為HDF5格式的文件中。
Keras的后端框架使用的是tensorflow,所以先把模型導(dǎo)出為pb模型。在Java中只需要調(diào)用模型進(jìn)行預(yù)測,所以將當(dāng)前的graph中的Variable全部變成Constant,并且使用訓(xùn)練后的weight。以下是freeze graph的代碼:

    def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
        """
        :param session: 需要轉(zhuǎn)換的tensorflow的session
        :param keep_var_names:需要保留的variable,默認(rèn)全部轉(zhuǎn)換constant
        :param output_names:output的名字
        :param clear_devices:是否移除設(shè)備指令以獲得更好的可移植性
        :return:
        """
        from tensorflow.python.framework.graph_util import convert_variables_to_constants
        graph = session.graph
        with graph.as_default():
            freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
            output_names = output_names or []
            # 如果指定了output名字,則復(fù)制一個新的Tensor,并且以指定的名字命名
            if len(output_names) > 0:
                for i in range(output_names):
                    # 當(dāng)前graph中復(fù)制一個新的Tensor,指定名字
                    tf.identity(model.model.outputs[i], name=output_names[i])
            output_names += [v.op.name for v in tf.global_variables()]
            input_graph_def = graph.as_graph_def()
            if clear_devices:
                for node in input_graph_def.node:
                    node.device = ""
            frozen_graph = convert_variables_to_constants(session, input_graph_def,
                                                          output_names, freeze_var_names)
            return frozen_graph

該方法可以將tensor為Variable的graph全部轉(zhuǎn)為constant并且使用訓(xùn)練后的weight。注意output_name比較重要,后面Java調(diào)用模型的時候會用到。

在Keras中,模型是這么定義的:

    def create_model(self):
        input_tensor = Input(shape=(self.maxlen,), name="input")
        x = Embedding(len(self.text2id) + 1, 200)(input_tensor)
        x = Bidirectional(LSTM(128))(x)
        x = Dense(256, activation="relu")(x)
        x = Dropout(self.dropout)(x)
        x = Dense(len(self.id2class), activation='softmax', name="output_softmax")(x)
        model = Model(inputs=input_tensor, outputs=x)
        model.compile(loss='categorical_crossentropy',
                      optimizer='adam',
                      metrics=['accuracy'])

下面的代碼可以查看定義好的Keras模型的輸入、輸出的name,這對之后Java調(diào)用有幫助。

print(model.input.op.name)
print(model.output.op.name)

訓(xùn)練好Keras模型后,轉(zhuǎn)換為pb模型:

from keras import backend as K
import tensorflow as tf

model.load_model("model.h5")
print(model.input.op.name)
print(model.output.op.name)
# 自定義output_names
frozen_graph = freeze_session(K.get_session(), output_names=["output"])
tf.train.write_graph(frozen_graph, "./", "model.pb", as_text=False)

### 輸出:
# input
# output_softmax/Softmax
# 如果不自定義output_name,則生成的pb模型的output_name為output_softmax/Softmax,如果自定義則以自定義名為output_name

運行之后會生成model.pb的模型,這將是之后調(diào)用的模型。

2、Java調(diào)用

新建一個maven項目,pom里面導(dǎo)入tensorflow包:

<dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow</artifactId>
            <version>1.6.0</version>
</dependency>

核心代碼:

public void predict() throws Exception {
        try (Graph graph = new Graph()) {
            graph.importGraphDef(Files.readAllBytes(Paths.get(
                    "path/to/model.pb"
            )));
            try (Session sess = new Session(graph)) {
                // 自己構(gòu)造一個輸入
                float[][] input = {{56, 632, 675, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}};
                try (Tensor x = Tensor.create(input);
                    // input是輸入的name,output是輸出的name
                    Tensor y = sess.runner().feed("input", x).fetch("output").run().get(0)) {
                    float[][] result = new float[1][y.shape[1]];
                    y.copyTo(result);
                    System.out.println(Arrays.toString(y.shape()));
                    System.out.println(Arrays.toString(result[0]));
                }
            }
        }
    }

Graph和Tensor對象都是需要通過close()方法顯式地釋放占用的資源,代碼中使用了try-with-resources的方法實現(xiàn)的。

至此,已經(jīng)可以實現(xiàn)Keras離線訓(xùn)練,Java在線預(yù)測的功能。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容