Java 类名:com.alibaba.alink.operator.batch.onnx.OnnxModelPredictBatchOp
Python 类名:OnnxModelPredictBatchOp
加载 ONNX 模型进行预测。
模型路径modelPath需要是 ONNX 模型。
参与模型预测的数据通过参数 selectedCols 设置,需要注意以下几点:
模型输出信息通过参数 outputSchemaStr 指定,包括输出列名以及名称,需要注意以下几点:
组件使用的是 ONNX 1.11.0 版本,当有 GPU 时,自动使用 GPU 进行推理,否则使用 CPU 进行推理。
在 Windows 下运行时,如果遇到 UnsatisfiedLinkError,请下载 Visual C++ 2019 Redistributable Packages 并重启,然后重新运行。
| 名称 | 中文名称 | 描述 | 类型 | 是否必须? | 取值范围 | 默认值 |
|---|---|---|---|---|---|---|
| modelPath | 模型的URL路径 | 模型的URL路径 | String | ✓ | ||
| outputSchemaStr | Schema | Schema。格式为“colname coltype[, colname2, coltype2[, …]]”,例如 “f0 string, f1 bigint, f2 double” | String | ✓ | ||
| inputNames | ONNX 模型输入名 | ONNX 模型输入名,用逗号分隔,需要与输入列一一对应,默认与选择列相同 | String[] | null | ||
| outputNames | ONNX 模型输出名 | ONNX 模型输出名,用逗号分隔,并且与输出 Schema 一一对应,默认与输出 Schema 中的列名相同 | String[] | null | ||
| reservedCols | 算法保留列名 | 算法保留列 | String[] | null | ||
| selectedCols | 选中的列名数组 | 计算列对应的列名列表 | String[] | null |
模型路径可以是以下形式:
** 以下代码仅用于示意,可能需要修改部分代码或者配置环境后才能正常运行!**
test = AkSourceBatchOp()\
.setFilePath("https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/mnist_test_vector.ak");
test = VectorToTensorBatchOp()\
.setTensorDataType("float")\
.setTensorShape([1, 1, 28, 28])\
.setSelectedCol("vec")\
.setOutputCol("tensor")\
.setReservedCols(["label"])\
.linkFrom(test)
predictor = OnnxModelPredictBatchOp() \
.setModelPath("https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/cnn_mnist_pytorch.onnx") \
.setSelectedCols(["tensor"]) \
.setInputNames(["0"]) \
.setOutputNames(["21"]) \
.setOutputSchemaStr("probabilities FLOAT_TENSOR")
test = predictor.linkFrom(test).select("label, probabilities")
test.print()
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.dataproc.VectorToTensorBatchOp;
import com.alibaba.alink.operator.batch.source.AkSourceBatchOp;
import org.junit.Test;
public class OnnxModelPredictBatchOpTest {
@Test
public void testOnnxModelPredictBatchOp() throws Exception {
BatchOperator.setParallelism(1);
BatchOperator <?> test = new AkSourceBatchOp()
.setFilePath("https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/mnist_test_vector.ak");
test = new VectorToTensorBatchOp()
.setTensorDataType("float")
.setTensorShape(1, 1, 28, 28)
.setSelectedCol("vec")
.setOutputCol("tensor")
.setReservedCols("label")
.linkFrom(test);
BatchOperator <?> predictor = new OnnxModelPredictBatchOp()
.setModelPath("https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/cnn_mnist_pytorch.onnx")
.setSelectedCols("tensor")
.setInputNames("0")
.setOutputNames("21")
.setOutputSchemaStr("probabilities FLOAT_TENSOR");
test = predictor.linkFrom(test).select("label, probabilities");
test.print();
}
}