Alink教程(Python版)

第2.6节 模型信息显示

我们在训练过程中需要了解训练所得模型的信息;对于已经存储的模型数据或者PipelineModel数据,需要采用某种方法来了解其具体内容。Alink对各场景提供了相应的方案,如表2-1所示。


在训练过程中如何显示模型信息呢?可参阅本书的8.5节、8.6节、9.2.5节、9.4节等处的示例。本节着重介绍如何针对存储的模型文件,提取模型信息。

比如,已知某个模型文件为单棵决策树模型文件,我们可以使用文件数据源读取模型数据,再连接决策树模型信息组件(DecisionTreeModelInfoBatchOp),并使用lazyPrintModelInfo方法打印模型信息。使用lazyCollectModelInfo方法,自定义抽取模型信息,并将决策树可视化导出到图片。具体代码如下:

AkSourceBatchOp()\
    .setFilePath(DATA_DIR + TREE_MODEL_FILE)\
    .link(
        DecisionTreeModelInfoBatchOp()\
            .lazyPrintModelInfo()\
            .lazyCollectModelInfo(
                lambda decisionTreeModelInfo: 
                    decisionTreeModelInfo.saveTreeAsImage(
                        DATA_DIR + "tree_model.png", True)
            )
    )
BatchOperator.execute()


运行结果如下:

Classification trees modelInfo: 
Number of trees: 1
Number of features: 4
Number of categorical features: 2
Labels: [no, yes]

Categorical feature info:
|feature|number of categorical value|
|-------|---------------------------|
|outlook|                          3|
|  Windy|                          2|

Table of feature importance Top 4: 
|    feature|importance|
|-----------|----------|
|   Humidity|    0.4637|
|      Windy|    0.4637|
|    outlook|    0.0725|
|Temperature|         0|


导出的决策树可视化结果如下图所示。


我们再看一下如何从PipelineModel中提取信息,显示模型。首先载入PipelineModel,然后获取Pipeline中各阶段(PipelineStage)的信息,相关代码如下:

pipelineModel = PipelineModel.load(DATA_DIR + PIPELINE_MODEL_FILE);

stages = pipelineModel.getTransformers()

for i in range(2) :
    print(str(i) + "\t" + str(stages[i]));


运行结果如下:

0	com.alibaba.alink.pipeline.sql.Select@19c1f6f4
1	com.alibaba.alink.pipeline.regression.LinearRegressionModel@46fa2a7e


这里共有两个阶段,第一个阶段执行Select操作,第二个阶段为使用LinearRegressionModel进行预测操作。这样,我们就确定了可以使用线性回归(Linear Regression)对应的模型信息组件来查看索引号为1的PipelineStage的信息,具体代码如下:

stages[1].getModelData()\
    .link(
        LinearRegModelInfoBatchOp().lazyPrintModelInfo()
    )
BatchOperator.execute()


运行结果如下。这里显示了模型的meta信息,在此还可以看到各个权重参数的取值:

----------------------------- model meta info -----------------------------
{hasInterception: true, model name: Linear Regression, num feature: 2}
---------------------------- model weight info ----------------------------
|     intercept|               x|         x2|
|--------------|----------------|-----------|
|122194787.2123|-121612.28370990|30.25813853|


最后,我们将模型信息(ModelInfo)组件、批式训练组件和PipelineStage的对照关系整理为表格,详见表2-2。