Alink教程(Java版)

第25.3节 深度回归算法

在Alink教材第16章,以葡萄酒的品质预测为例,演示了线性模型、随机森林、GBDT等算法的回归训练及预测。本节仍以葡萄酒的品质预测为例,重点演示如何使用深度学习进行回归训练及预测。

25.3.1 线性回归算法


首先,我们先使用线性回归算法,得到一个baseline,具体代码如下:

new LinearRegression()
    .setFeatureCols(Chap16.FEATURE_COL_NAMES)
    .setLabelCol("quality")
    .setPredictionCol("pred")
    .enableLazyPrintModelInfo()
    .fit(train_set)
    .transform(test_set)
    .lazyPrintStatistics()
    .link(
        new EvalRegressionBatchOp()
            .setLabelCol("quality")
            .setPredictionCol("pred")
            .lazyPrintMetrics()
    );
BatchOperator.execute();

其中使用了enableLazyPrintModelInfo方法,会打印出模型的信息,如下所示,“intercept”对应的是线性回归的常数项,每个特征列对应一个线性回归的参数。

----------------------------- model meta info -----------------------------
{hasInterception: true, model name: Linear Regression, num feature: 11}
---------------------------- model weight info ----------------------------
|  colName[0,9]| intercept|fixedAcidity|volatileAcidity| citricAcid|residualSugar|  chlorides|freeSulfurDioxide|totalSulfurDioxide|      density|        pH|
|   weight[0,9]|  147.2227|  0.05561785|    -1.88292677|-0.02573757|   0.08013539|-0.32426272|       0.00370713|       -0.00037885|-147.10297328|0.64837239|
|colName[10,11]| sulphates|     alcohol|               |           |             |           |                 |                  |             |          |
| weight[10,11]|0.64099725|  0.19730515|    


线性回归预测结果的统计如下,可以看到quality列的均值为5.8735,方差为0.7235;预测结果的均值为5.8767,方差为0.2083。

Summary: 
|           colName|count|missing|       sum|    mean| variance|   min|   max|
|------------------|-----|-------|----------|--------|---------|------|------|
|      fixedAcidity|  980|      0|    6722.4|  6.8596|     0.73|   4.7|    10|
|   volatileAcidity|  980|      0|    271.38|  0.2769|   0.0099|  0.08|   1.1|
|        citricAcid|  980|      0|    327.46|  0.3341|    0.014|     0|     1|
|     residualSugar|  980|      0|   6280.75|  6.4089|  26.1442|   0.7| 26.05|
|         chlorides|  980|      0|    44.462|  0.0454|   0.0004| 0.012| 0.239|
| freeSulfurDioxide|  980|      0|   34358.5| 35.0597| 286.5621|     2| 138.5|
|totalSulfurDioxide|  980|      0|  135518.5|138.2842|1948.7059|     9|   303|
|           density|  980|      0|  974.1799|  0.9941|        0|0.9872| 1.003|
|                pH|  980|      0|   3125.47|  3.1893|   0.0219|  2.85|  3.82|
|         sulphates|  980|      0|    483.09|  0.4929|   0.0123|  0.27|  0.98|
|           alcohol|  980|      0|10289.4933| 10.4995|   1.5211|   8.4| 14.05|
|           quality|  980|      0|      5756|  5.8735|   0.7235|     3|     9|
|              pred|  980|      0| 5759.1695|  5.8767|   0.2083|4.1547|7.2093|


使用EvalRegressionBatchOp组件,计算显示回归统计指标如下:

-------------------------------- Metrics: --------------------------------
MSE:0.5309	RMSE:0.7286	MAE:0.5748	MAPE:10.0995	R2:0.2655

25.3.2 深度回归算法

使用深度回归模型的代码如下:

new Pipeline()
	.add(
		new StandardScaler()
			.setSelectedCols(Chap16.FEATURE_COL_NAMES)
	)
	.add(
		new VectorAssembler()
			.setSelectedCols(Chap16.FEATURE_COL_NAMES)
			.setOutputCol("vec")
	)
	.add(
		new VectorToTensor()
			.setSelectedCol("vec")
			.setOutputCol("tensor")
			.setReservedCols("quality")
	)
	.add(
		new KerasSequentialRegressor()
			.setTensorCol("tensor")
			.setLabelCol("quality")
			.setPredictionCol("pred")
			.setLayers(
				"Dense(64, activation='relu')",
				"Dense(64, activation='relu')",
				"Dense(64, activation='relu')",
				"Dense(64, activation='relu')",
				"Dense(64, activation='relu')"
			)
			.setNumEpochs(20)
			.setNumWorkers(1)
			.setNumPSs(0)
	)
	.fit(train_set)
	.transform(test_set)
	.lazyPrintStatistics()
	.link(
		new EvalRegressionBatchOp()
			.setLabelCol("quality")
			.setPredictionCol("pred")
			.lazyPrintMetrics()
	);
BatchOperator.execute();

在Pipeline中使用了多个组件:

1、数据标准化组件StandardScaler,因为数据中各列的数值范围差异较大,标准化后,有助于提升深度模型的效果

2、拼接向量组件VectorAssembler,将多列数值数据转化为一列向量数据。

3、向量转化为张量的组件VectorToTensor,后面的Keras组件的输入格式为张量

4、Keras回归器组件KerasSequentialRegressor,定义了深度模型

Pipeline使用fit方法对训练集train_set进行训练,然后使用transform方法对测试集test_set进行预测。预测结果的统计如下

Summary: 
|colName|count|missing|      sum|  mean|variance|   min|   max|
|-------|-----|-------|---------|------|--------|------|------|
|quality|  980|      0|     5756|5.8735|  0.7235|     3|     9|
| tensor|  980|      0|      NaN|   NaN|     NaN|   NaN|   NaN|
|   pred|  980|      0|5766.3137| 5.884|  0.3881|3.8637|7.5729|

相应的评估指标如下,可以看到均方误差MSE和平均绝对误差MAE等指标,相对线性回归算法有明显改进。

-------------------------------- Metrics: --------------------------------
MSE:0.485	RMSE:0.6964	MAE:0.5323	MAPE:9.3756	R2:0.3289