ONNX(Open Neural Network Exchange,开放神经网络交换)是为人工智能模型(深度学习和传统ML)提供的一种开放格式,可使模型在不同框架之间进行转移。详见 https://github.com/onnx/onnx
Alink提供了OnnxModelPredictBatchOp、OnnxModelPredictStreamOp和OnnxModelPredictor组件,分别处理批式场景、流式场景和进行Pipeline封装。
各组件都需要指定 ONNX 模型的modelPath
模型路径参数。模型路径可以是以下形式:
file://
加绝对路径,例如 file:///tmp/dnn.py
;res://
加路径,例如 res:///dnn.py
;http://
或 https://
路径;oss://
加路径和 Endpoint 和 access keyoss://bucket/xxx/xxx/xxx.py?host=xxx&access_key_id=xxx&access_key_secret=xxx
;hdfs://
加路径;参与模型预测的数据通过参数 selectedCols
设置,需要注意以下几点:
inputNames
,与 selectedCols
一一对应,表明某列对应某输入桩。inputNames
不填写时,默认与列名一致。Tensor
类型,不支持 Sequences
和 Maps
类型。float, double, int, long, byte, string
类型及其对应的 Alink Tensor
类型。模型输出信息通过参数 outputSchemaStr
指定,包括输出列名以及名称,需要注意以下几点:
outputNames
,与 outputSchemaStr
一一对应,表明某列对应某输入桩。outputNames
不填写时,默认与列名一致。Tensor
类型,不支持 Sequences
和 Maps
类型。outputSchemaStr
填写的输出类型需要是对应的输出桩类型,例如 输出桩类型 为 Float 类型的 Tensor 时,对应的 Alink 类型可以是 TENSOR
或者 FLOAT_TENSOR
,当输出仅包含一个元素时,还可以是 FLOAT
。组件使用的是 ONNX 1.11.0 版本,当有 GPU 时,自动使用 GPU 进行推理,否则使用 CPU 进行推理。
使用OnnxModelPredictBatchOp组件,可以加载ONNX模型进行批式预测。关于该组件的详细说明参见Alink文档 https://www.yuque.com/pinshu/alink_doc/onnxmodelpredictbatchop .
使用ONNX模型前,还需要将输入数据列的类型转换为Tensor格式,可以使用VectorToTensorBatchOp组件。具体代码如下所示:
new AkSourceBatchOp() .setFilePath(Chap13.DATA_DIR + Chap13.DENSE_TEST_FILE) .link( new VectorToTensorBatchOp() .setTensorDataType("float") .setTensorShape(1, 1, 28, 28) .setSelectedCol("vec") .setOutputCol("tensor") .setReservedCols("label") ) .link( new OnnxModelPredictBatchOp() .setModelPath( "https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/cnn_mnist_pytorch.onnx") .setSelectedCols("tensor") .setInputNames("0") .setOutputNames("21") .setOutputSchemaStr("probabilities FLOAT_TENSOR") ) .link( new UDFBatchOp() .setFunc(new GetMaxIndex()) .setSelectedCols("probabilities") .setOutputCol("pred") ) .lazyPrint(3) .link( new EvalMultiClassBatchOp() .setLabelCol("label") .setPredictionCol("pred") .lazyPrintMetrics() ); BatchOperator.execute();
这里用到了一个自定义函数,具体定义如下:
public static class GetMaxIndex extends ScalarFunction { public int eval(FloatTensor tensor) { int k = 0; float max = tensor.getFloat(0, 0); for (int i = 1; i < 10; i++) { if (tensor.getFloat(0, i) > max) { k = i; max = tensor.getFloat(0, i); } } return k; } }
批式任务的运行结果为:
label|tensor|probabilities|pred -----|------|-------------|---- 0|FloatTensor(1,1,28,28) |FloatTensor(1,10) |0 |[[[[0.0 0.0 0.0 ... 0.0 0.0 0.0]|[[0.0 -1434.9991 -1350.7235 ... -1287.3451 -1404.4778 -1434.9991]]| | [0.0 0.0 0.0 ... 0.0 0.0 0.0]| | [0.0 0.0 0.0 ... 0.0 0.0 0.0]| | ... | | [0.0 0.0 0.0 ... 0.0 0.0 0.0]| | ... ... | 9|FloatTensor(1,1,28,28) |FloatTensor(1,10) |9 |[[[[0.0 0.0 0.0 ... 0.0 0.0 0.0]|[[-1105.8378 -1105.8378 -1105.8378 ... -1105.8378 -732.2857 0.0]]| | [0.0 0.0 0.0 ... 0.0 0.0 0.0]| | [0.0 0.0 0.0 ... 0.0 0.0 0.0]| | ... | | [0.0 0.0 0.0 ... 0.0 0.0 0.0]| | ... ... | 1|FloatTensor(1,1,28,28) |FloatTensor(1,10) |1 |[[[[0.0 0.0 0.0 ... 0.0 0.0 0.0]|[[-1503.7417 0.0 -1402.91 ... -1180.878 -1370.2913 -1503.7417]]| | [0.0 0.0 0.0 ... 0.0 0.0 0.0]| | [0.0 0.0 0.0 ... 0.0 0.0 0.0]| | ... | | [0.0 0.0 0.0 ... 0.0 0.0 0.0]| | ... ... | -------------------------------- Metrics: -------------------------------- Accuracy:0.9904 Macro F1:0.9904 Micro F1:0.9904 Kappa:0.9893 |Pred\Real| 9| 8| 7|...| 2| 1| 0| |---------|---|---|----|---|----|----|---| | 9|988| 3| 3|...| 0| 0| 0| | 8| 1|962| 0|...| 1| 0| 1| | 7| 6| 2|1014|...| 3| 0| 1| | ...|...|...| ...|...| ...| ...|...| | 2| 0| 2| 3|...|1021| 0| 0| | 1| 3| 0| 6|...| 3|1134| 0| | 0| 2| 4| 0|...| 1| 0|978|
使用OnnxModelPredictStreamOp组件,可以加载ONNX模型进行批式预测。关于该组件的详细说明参见Alink文档 https://www.yuque.com/pinshu/alink_doc/onnxmodelpredictstreamop .
使用ONNX模型前,还需要将输入数据列的类型转换为Tensor格式,可以使用VectorToTensorStreamOp组件。具体代码如下所示:
new AkSourceStreamOp() .setFilePath(Chap13.DATA_DIR + Chap13.DENSE_TEST_FILE) .link( new VectorToTensorStreamOp() .setTensorDataType("float") .setTensorShape(1, 1, 28, 28) .setSelectedCol("vec") .setOutputCol("tensor") .setReservedCols("label") ) .link( new OnnxModelPredictStreamOp() .setModelPath("https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/cnn_mnist_pytorch" + ".onnx") .setSelectedCols("tensor") .setInputNames("0") .setOutputNames("21") .setOutputSchemaStr("probabilities FLOAT_TENSOR") ) .link( new UDFStreamOp() .setFunc(new GetMaxIndex()) .setSelectedCols("probabilities") .setOutputCol("pred") ) .sample(0.001) .print(); StreamOperator.execute();
运行结果为:
label|tensor|probabilities|pred -----|------|-------------|---- 5|FloatTensor(1,1,28,28) |FloatTensor(1,10) |5 |[[[[0.0 0.0 0.0 ... 0.0 0.0 0.0]|[[-1679.8259 -1679.8259 -1679.8259 ... -1679.8259 -1399.0499 -1008.8595]]| | [0.0 0.0 0.0 ... 0.0 0.0 0.0]| | [0.0 0.0 0.0 ... 0.0 0.0 0.0]| | ... | | [0.0 0.0 0.0 ... 0.0 0.0 0.0]| | ... ... | 8|FloatTensor(1,1,28,28) |FloatTensor(1,10) |8 |[[[[0.0 0.0 0.0 ... 0.0 0.0 0.0]|[[-565.4354 -879.83417 -691.5988 ... -898.3416 0.0 -784.97516]]| | [0.0 0.0 0.0 ... 0.0 0.0 0.0]| | [0.0 0.0 0.0 ... 0.0 0.0 0.0]| | ... | | [0.0 0.0 0.0 ... 0.0 0.0 0.0]| | ... ... | 7|FloatTensor(1,1,28,28) |FloatTensor(1,10) |7 |[[[[0.0 0.0 0.0 ... 0.0 0.0 0.0]|[[-1642.872 -1219.6741 -818.6717 ... 0.0 -1427.4896 -1179.5854]]| | [0.0 0.0 0.0 ... 0.0 0.0 0.0]| | [0.0 0.0 0.0 ... 0.0 0.0 0.0]| | ... | | [0.0 0.0 0.0 ... 0.0 0.0 0.0]| | ... ... | ......
学习了如何在批式任务和流式任务中使用ONNX模型,我们很容易在Pipeline中使用ONNX模型进行预测,只要将其中的批式/流式组件对应到Pipeline组件即可。具体代码如下:
new PipelineModel( new VectorToTensor() .setTensorDataType("float") .setTensorShape(1, 1, 28, 28) .setSelectedCol("vec") .setOutputCol("tensor") .setReservedCols("label"), new OnnxModelPredictor() .setModelPath("https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/cnn_mnist_pytorch.onnx") .setSelectedCols("tensor") .setInputNames("0") .setOutputNames("21") .setOutputSchemaStr("probabilities FLOAT_TENSOR") ).save(Chap13.DATA_DIR + PIPELINE_ONNX_MODEL, true); BatchOperator.execute(); PipelineModel .load(Chap13.DATA_DIR + PIPELINE_ONNX_MODEL) .transform( new AkSourceStreamOp() .setFilePath(Chap13.DATA_DIR + Chap13.DENSE_TEST_FILE) ) .link( new UDFStreamOp() .setFunc(new GetMaxIndex()) .setSelectedCols("probabilities") .setOutputCol("pred") ) .sample(0.001) .print(); StreamOperator.execute();
运行结果为:
label|tensor|probabilities|pred -----|------|-------------|---- 6|FloatTensor(1,1,28,28) |FloatTensor(1,10) |6 |[[[[0.0 0.0 0.0 ... 0.0 0.0 0.0]|[[-1234.3433 -1670.8674 -1670.8674 ... -1670.8674 -1384.003 -1670.8674]]| | [0.0 0.0 0.0 ... 0.0 0.0 0.0]| | [0.0 0.0 0.0 ... 0.0 0.0 0.0]| | ... | | [0.0 0.0 0.0 ... 0.0 0.0 0.0]| | ... ... | 7|FloatTensor(1,1,28,28) |FloatTensor(1,10) |7 |[[[[0.0 0.0 0.0 ... 0.0 0.0 0.0]|[[-615.27924 -152.31528 -290.78244 ... 0.0 -575.1346 -499.75998]]| | [0.0 0.0 0.0 ... 0.0 0.0 0.0]| | [0.0 0.0 0.0 ... 0.0 0.0 0.0]| | ... | | [0.0 0.0 0.0 ... 0.0 0.0 0.0]| | ... ... | 1|FloatTensor(1,1,28,28) |FloatTensor(1,10) |1 |[[[[0.0 0.0 0.0 ... 0.0 0.0 0.0]|[[-1557.5773 0.0 -1557.5773 ... -1363.7996 -1293.8707 -1557.5773]]| | [0.0 0.0 0.0 ... 0.0 0.0 0.0]| | [0.0 0.0 0.0 ... 0.0 0.0 0.0]| | ... | | [0.0 0.0 0.0 ... 0.0 0.0 0.0]| | ... ... | ......
除了通过Alink任务使用ONNX模型,也可以使用LocalPredictor进行嵌入式预测。示例代码如下,首先从数据集中抽取一行数据,输入数据的SchemaStr为“vec string, label int”;然后通过导入上一节保存的Pipeline模型,并设置输入数据的SchemaStr,得到LocalPredictor类型的实例localPredictor;如果不确定预测结果各列的含义,可以打印输出localPredictor的OutputSchema;使用localPredictor的map方法获得预测结果。
AkSourceBatchOp source = new AkSourceBatchOp() .setFilePath(Chap13.DATA_DIR + Chap13.DENSE_TEST_FILE); System.out.println(source.getSchema()); Row row = source.firstN(1).collect().get(0); LocalPredictor localPredictor = new LocalPredictor(Chap13.DATA_DIR + PIPELINE_ONNX_MODEL, "vec string, label int"); System.out.println(localPredictor.getOutputSchema()); Row r = localPredictor.map(row); System.out.println(r.getField(0).toString() + " | " + r.getField(2).toString()); localPredictor.close();
运行结果为:
root |-- vec: STRING |-- label: INT root |-- label: INT |-- tensor: LEGACY(GenericType<com.alibaba.alink.common.linalg.tensor.FloatTensor>) |-- probabilities: LEGACY(GenericType<com.alibaba.alink.common.linalg.tensor.FloatTensor>) 1 | FloatTensor(1,10) [[-1503.7417 0.0 -1402.91 ... -1180.878 -1370.2913 -1503.7417]]