Installing and Using sklearn2pmml

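Our company's codebase is Java, but the algorithm part is written with Python's sklearn. The plan is to use sklearn2pmml to export the model as a PMML file and then load that file from Java, so the model can be used across platforms.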

  1. Install sklearn2pmml
pip install sklearn2pmml

Note the following:

  • The scikit-learn version must be <= 0.20.4. Versions later than 0.20.4 fail with:
AttributeError: module 'sklearn.externals.joblib' has no attribute '__version__'

This is because sklearn.externals.joblib was deprecated in 0.21 and is scheduled for removal in 0.23:

DeprecationWarning: sklearn.externals.joblib is deprecated in 0.21 and will be removed in 0.23. Please import this functionality directly from joblib, which can be installed with: pip install joblib. If this warning is raised when loading pickled models, you may need to re-serialize those models with scikit-learn 0.21+
  • The Java version must be >= 1.7.

My configuration:

python: 3.6.8
sklearn: 0.20.4
sklearn.externals.joblib: 0.13.2
pandas: 0.24.1
sklearn_pandas: 1.8.0
sklearn2pmml: 0.48.0
java: 1.8.0_144
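
The version limits above can be checked before attempting a conversion. The following is a minimal sketch, not part of sklearn2pmml: the helper names (`version_tuple`, `sklearn_version_ok`, `java_version_ok`) are made up for illustration, and the 0.20.4 and 1.7 limits are simply the ones stated in this post for this particular setup.

```python
import re
import subprocess


def version_tuple(version):
    """Turn a version string such as '0.20.4' into a tuple of ints: (0, 20, 4)."""
    return tuple(int(part) for part in re.findall(r"\d+", version)[:3])


def sklearn_version_ok(version, maximum="0.20.4"):
    """True when `version` is not newer than `maximum` (the limit noted in this post)."""
    return version_tuple(version) <= version_tuple(maximum)


def java_version_ok(banner, minimum="1.7"):
    """Check the quoted version in `java -version` output, e.g. 'java version "1.8.0_144"'."""
    match = re.search(r'version "([^"]+)"', banner)
    if match is None:
        return False
    # Compare only the major.minor part, e.g. (1, 8) >= (1, 7).
    return version_tuple(match.group(1))[:2] >= version_tuple(minimum)[:2]


if __name__ == "__main__":
    import sklearn

    print("scikit-learn OK:", sklearn_version_ok(sklearn.__version__))
    # `java -version` writes its banner to stderr, not stdout.
    result = subprocess.run(["java", "-version"],
                            stderr=subprocess.PIPE, universal_newlines=True)
    print("java OK:", java_version_ok(result.stderr))
```

If either check prints False, adjust the environment before running the conversion described in the next step.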
  2. Test the Python code
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn import tree
from sklearn2pmml import PMMLPipeline, sklearn2pmml

iris = load_iris()

train, test, train_labels, test_labels = train_test_split(iris.data, iris.target, test_size=0.2, random_state=0)

pipeline = PMMLPipeline([
    ("classifier", tree.DecisionTreeClassifier(random_state=9))
])

pipeline.fit(train, train_labels)

sklearn2pmml(pipeline, 'result.pmml', with_repr=True, debug=True)

The generated PMML file is shown in the figure below:

[Figure: image.png]

When running your own code, you may hit the following error:

RuntimeError: The JPMML-SkLearn conversion application has failed. The Java executable should have printed more information about the failure into its standard output and/or standard error streams

When this error occurs, check the column names of train and train_labels: they must contain no duplicates and must be correctly formatted.
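
The column-name check described above can be sketched as a small helper. This is only an illustration: `find_column_name_problems` is a hypothetical name, and since the post does not define what "correctly formatted" means, the sketch assumes it means a non-empty string without leading or trailing whitespace.

```python
from collections import Counter


def find_column_name_problems(columns):
    """Return (duplicates, malformed) for an iterable of column names.

    Assumption: a well-formed name is a non-empty string with no
    leading or trailing whitespace.
    """
    counts = Counter(columns)
    # Names that occur more than once.
    duplicates = [name for name, n in counts.items() if n > 1]
    # Names that are not strings, are empty, or carry stray whitespace.
    malformed = [
        name for name in counts
        if not isinstance(name, str) or not name or name != name.strip()
    ]
    return duplicates, malformed
```

With a pandas DataFrame, the helper could be called as `find_column_name_problems(df.columns)` before fitting the pipeline; both returned lists should be empty.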

  3. Test the Java code
    Download jpmml-sklearn-executable-1.5.7.jar and pmml-evaluator-1.4.3.jar, then create a new project that references both jars.
    I verified that these two jars work without errors; other versions may fail (screenshot: error.png).

    The Java code is as follows:
package javaTopython;

import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
 
import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.InputField;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.evaluator.TargetField;

public class PmmlFile {
    public static void main(String[] args) throws Exception {
        String  pathxml="tree.pmml";
        Map<String, Double>  map=new HashMap<String, Double>();
        map.put("x1", 5.1);
        map.put("x2", 3.5);
        map.put("x3", 1.4);
        map.put("x4", 0.2);    
        predictLrHeart(map, pathxml);
    }
    
    public static void predictLrHeart(Map<String, Double> irismap,String  pathxml)throws Exception {
 
        PMML pmml;
        // Load the model
        File file = new File(pathxml);
        InputStream inputStream = new FileInputStream(file);
        try (InputStream is = inputStream) {
            pmml = org.jpmml.model.PMMLUtil.unmarshal(is);
 
            ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory
                    .newInstance();
            ModelEvaluator<?> modelEvaluator = modelEvaluatorFactory
                    .newModelEvaluator(pmml);
            Evaluator evaluator = (Evaluator) modelEvaluator;
 
            List<InputField> inputFields = evaluator.getInputFields();
            // Walk the model's raw input features, fetch each value from the input map, and use it as model input
            Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
            for (InputField inputField : inputFields) {
                FieldName inputFieldName = inputField.getName();
                Object rawValue = irismap
                        .get(inputFieldName.getValue());
                FieldValue inputFieldValue = inputField.prepare(rawValue);
                arguments.put(inputFieldName, inputFieldValue);
            }
 
            Map<FieldName, ?> results = evaluator.evaluate(arguments);
            List<TargetField> targetFields = evaluator.getTargetFields();
            // Classification problems and the like have multiple outputs.
            for (TargetField targetField : targetFields) {
                FieldName targetFieldName = targetField.getName();
                Object targetFieldValue = results.get(targetFieldName);
                System.err.println("target: " + targetFieldName.getValue()
                        + " value: " + targetFieldValue);
            }
        }
    }
}

The output is as follows:

target y value: ProbabilityDistribution{result=0, probability_entries=[0=0.8876504283659372, 1=0.11232695495162393, 2=2.2616682438804697E-5]}

Watch out for cases where the model has been simplified: the <DataField> entries in the PMML file may omit columns whose coefficient is zero, so it is best to add a check for this.
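
One way to perform such a check is to list the <DataField> names in the PMML file and compare them with the features you expect to pass in. The sketch below uses only Python's standard library; the function names and the trimmed `SAMPLE_PMML` document are hypothetical and exist only to illustrate the idea.

```python
import xml.etree.ElementTree as ET


def data_field_names(pmml_text):
    """Return the names of all <DataField> elements in a PMML document."""
    root = ET.fromstring(pmml_text)
    names = []
    for element in root.iter():
        # Tags look like '{http://www.dmg.org/PMML-4_3}DataField'; drop the namespace.
        if element.tag.rsplit("}", 1)[-1] == "DataField":
            names.append(element.get("name"))
    return names


def missing_inputs(pmml_text, expected):
    """List the expected feature names that have no <DataField> in the PMML."""
    present = set(data_field_names(pmml_text))
    return [name for name in expected if name not in present]


# Hypothetical, heavily trimmed PMML in which x2 and x3 were dropped.
SAMPLE_PMML = """<PMML xmlns="http://www.dmg.org/PMML-4_3" version="4.3">
  <DataDictionary>
    <DataField name="y" optype="categorical" dataType="integer"/>
    <DataField name="x1" optype="continuous" dataType="double"/>
    <DataField name="x4" optype="continuous" dataType="double"/>
  </DataDictionary>
</PMML>"""

if __name__ == "__main__":
    print(missing_inputs(SAMPLE_PMML, ["x1", "x2", "x3", "x4"]))
```

For a real file, read its contents with `open(path, encoding="utf-8").read()` and pass the text to `missing_inputs`. Note that the data dictionary also contains the target field (y in this sample), so compare against input feature names only.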
