我们在训练过程中需要了解训练所得模型的信息;对于已经存储的模型数据或者PipelineModel数据,需要采用某种方法来了解其具体内容。Alink对各场景提供了相应的方案,如表2-1所示。
在训练过程中如何显示模型信息呢?可参阅本书的8.5节、8.6节、9.2.5节、9.4节等处的示例。本节着重介绍如何针对存储的模型文件,提取模型信息。
比如,已知某个模型文件为单棵决策树模型文件,我们可以使用文件数据源读取模型数据,再连接决策树模型信息组件(DecisionTreeModelInfoBatchOp),并使用lazyPrintModelInfo方法打印模型信息。使用lazyCollectModelInfo方法,自定义抽取模型信息,并将决策树可视化导出到图片。具体代码如下:
new AkSourceBatchOp() .setFilePath(DATA_DIR + TREE_MODEL_FILE) .link( new DecisionTreeModelInfoBatchOp() .lazyPrintModelInfo() .lazyCollectModelInfo(new Consumer <DecisionTreeModelInfo>() { @Override public void accept(DecisionTreeModelInfo decisionTreeModelInfo) { try { decisionTreeModelInfo.saveTreeAsImage( DATA_DIR + "tree_model.png", true); } catch (IOException e) { e.printStackTrace(); } } }) ); BatchOperator.execute();
运行结果如下:
Classification trees modelInfo: Number of trees: 1 Number of features: 4 Number of categorical features: 2 Labels: [no, yes] Categorical feature info: |feature|number of categorical value| |-------|---------------------------| |outlook| 3| | Windy| 2| Table of feature importance Top 4: | feature|importance| |-----------|----------| | Humidity| 0.4637| | Windy| 0.4637| | outlook| 0.0725| |Temperature| 0|
导出的决策树可视化结果如下图所示。
我们再看一下如何从PipelineModel中提取信息,显示模型。首先载入PipelineModel,然后获取Pipeline中各阶段(PipelineStage)的信息,相关代码如下:
PipelineModel pipelineModel = PipelineModel.load(DATA_DIR + PIPELINE_MODEL_FILE); TransformerBase <?>[] stages = pipelineModel.getTransformers(); for (int i = 0; i < stages.length; i++) { System.out.println(String.valueOf(i) + "\t" + stages[i]); }
运行结果如下:
0 com.alibaba.alink.pipeline.sql.Select@19c1f6f4 1 com.alibaba.alink.pipeline.regression.LinearRegressionModel@46fa2a7e
这里共有两个阶段,第一个阶段执行Select操作,第二个阶段为使用LinearRegressionModel进行预测操作。这样,我们就确定了可以使用线性回归(Linear Regression)对应的模型信息组件来查看索引号为1的PipelineStage的信息,具体代码如下:
((LinearRegressionModel) stages[1]).getModelData() .link( new LinearRegModelInfoBatchOp() .lazyPrintModelInfo() ); BatchOperator.execute();
运行结果如下。这里显示了模型的meta信息,在此还可以看到各个权重参数的取值:
----------------------------- model meta info ----------------------------- {hasInterception: true, model name: Linear Regression, num feature: 2} ---------------------------- model weight info ---------------------------- | intercept| x| x2| |--------------|----------------|-----------| |122194787.2123|-121612.28370990|30.25813853|
最后,我们将模型信息(ModelInfo)组件、批式训练组件和PipelineStage的对照关系整理为表格,详见表2-2。