This chapter covers the following sections:
16.1 Evaluation Metrics for Regression Models
16.2 Data Exploration
16.3 Linear Regression
16.4 Decision Trees and Random Forests
16.5 GBDT Regression
For the full discussion, please refer to the printed book 《Alink权威指南:基于Flink的机器学习实例入门(Java)》; the example code for this chapter is given below.
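Section 16.1 deals with how regression models are evaluated. As a quick, framework-free reference, the sketch below computes the common metrics (MSE, RMSE, MAE, R2) from a pair of label/prediction arrays; the class name and sample values are purely illustrative and not part of the book's code. In the chapter listing that follows, Alink's EvalRegressionBatchOp reports these metrics on the prediction results.

public class RegressionMetricsSketch {
    public static void main(String[] args) {
        // Toy label/prediction pairs; illustrative values only.
        double[] label = {5.0, 6.0, 6.0, 7.0};
        double[] pred = {5.2, 5.8, 6.5, 6.9};

        int n = label.length;
        double mean = 0.0;
        for (double y : label) {
            mean += y / n;
        }

        double sse = 0.0, sae = 0.0, sst = 0.0;
        for (int i = 0; i < n; i++) {
            double err = label[i] - pred[i];
            sse += err * err;                               // sum of squared errors
            sae += Math.abs(err);                           // sum of absolute errors
            sst += (label[i] - mean) * (label[i] - mean);   // total sum of squares
        }

        double mse = sse / n;            // mean squared error
        double rmse = Math.sqrt(mse);    // root mean squared error
        double mae = sae / n;            // mean absolute error
        double r2 = 1.0 - sse / sst;     // coefficient of determination

        System.out.printf("MSE=%.4f  RMSE=%.4f  MAE=%.4f  R2=%.4f%n", mse, rmse, mae, r2);
    }
}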
package com.alibaba.alink;

import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.evaluation.EvalRegressionBatchOp;
import com.alibaba.alink.operator.batch.source.AkSourceBatchOp;
import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
import com.alibaba.alink.operator.batch.statistics.CorrelationBatchOp;
import com.alibaba.alink.pipeline.regression.DecisionTreeRegressor;
import com.alibaba.alink.pipeline.regression.GbdtRegressor;
import com.alibaba.alink.pipeline.regression.LassoRegression;
import com.alibaba.alink.pipeline.regression.LinearRegression;
import com.alibaba.alink.pipeline.regression.RandomForestRegressor;
import org.apache.commons.lang3.ArrayUtils;

import java.io.File;

public class Chap16 {

    // White-wine quality data: 11 physicochemical features, label column "quality".
    static final String DATA_DIR = Utils.ROOT_DIR + "wine" + File.separator;

    private static final String ORIGIN_FILE = "winequality-white.csv";

    static final String TRAIN_FILE = "train.ak";
    static final String TEST_FILE = "test.ak";

    private static final String[] COL_NAMES = new String[] {
        "fixedAcidity", "volatileAcidity", "citricAcid", "residualSugar",
        "chlorides", "freeSulfurDioxide", "totalSulfurDioxide", "density",
        "pH", "sulphates", "alcohol", "quality"
    };

    private static final String[] COL_TYPES = new String[] {
        "double", "double", "double", "double",
        "double", "double", "double", "double",
        "double", "double", "double", "double"
    };

    static final String[] FEATURE_COL_NAMES = ArrayUtils.removeElement(COL_NAMES, "quality");

    static final String LABEL_COL_NAME = "quality";

    private static final String PREDICTION_COL_NAME = "pred";

    public static void main(String[] args) throws Exception {
        BatchOperator.setParallelism(1);

        c_2();
        c_3();
        c_4();
        c_5();
    }

    // 16.2 Data exploration: preview the data, print the correlation matrix and
    // the label distribution, then split into train/test sets (80%/20%).
    static void c_2() throws Exception {
        CsvSourceBatchOp source = new CsvSourceBatchOp()
            .setFilePath(DATA_DIR + ORIGIN_FILE)
            .setSchemaStr(Utils.generateSchemaString(COL_NAMES, COL_TYPES))
            .setFieldDelimiter(";")
            .setIgnoreFirstLine(true);

        source.lazyPrint(5);

        source.link(new CorrelationBatchOp().lazyPrintCorrelation());

        source
            .groupBy(LABEL_COL_NAME, LABEL_COL_NAME + ", COUNT(*) AS cnt")
            .orderBy(LABEL_COL_NAME, 100)
            .lazyPrint(-1);

        BatchOperator.execute();

        Utils.splitTrainTestIfNotExist(
            source,
            DATA_DIR + TRAIN_FILE,
            DATA_DIR + TEST_FILE,
            0.8
        );
    }

    // 16.3 Linear regression and LASSO regression, each evaluated on the test set.
    static void c_3() throws Exception {
        AkSourceBatchOp train_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_FILE);
        AkSourceBatchOp test_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE);

        new LinearRegression()
            .setFeatureCols(FEATURE_COL_NAMES)
            .setLabelCol(LABEL_COL_NAME)
            .setPredictionCol(PREDICTION_COL_NAME)
            .enableLazyPrintTrainInfo()
            .enableLazyPrintModelInfo()
            .fit(train_data)
            .transform(test_data)
            .link(
                new EvalRegressionBatchOp()
                    .setLabelCol(LABEL_COL_NAME)
                    .setPredictionCol(PREDICTION_COL_NAME)
                    .lazyPrintMetrics("LinearRegression")
            );

        new LassoRegression()
            .setLambda(0.05)
            .setFeatureCols(FEATURE_COL_NAMES)
            .setLabelCol(LABEL_COL_NAME)
            .setPredictionCol(PREDICTION_COL_NAME)
            .enableLazyPrintTrainInfo()
            .enableLazyPrintModelInfo("< LASSO model >")
            .fit(train_data)
            .transform(test_data)
            .link(
                new EvalRegressionBatchOp()
                    .setLabelCol(LABEL_COL_NAME)
                    .setPredictionCol(PREDICTION_COL_NAME)
                    .lazyPrintMetrics("LassoRegression")
            );

        BatchOperator.execute();
    }

    // 16.4 A single decision tree regressor, then random forests with
    // increasing numbers of trees.
    static void c_4() throws Exception {
        AkSourceBatchOp train_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_FILE);
        AkSourceBatchOp test_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE);

        new DecisionTreeRegressor()
            .setFeatureCols(FEATURE_COL_NAMES)
            .setLabelCol(LABEL_COL_NAME)
            .setPredictionCol(PREDICTION_COL_NAME)
            .fit(train_data)
            .transform(test_data)
            .link(
                new EvalRegressionBatchOp()
                    .setLabelCol(LABEL_COL_NAME)
                    .setPredictionCol(PREDICTION_COL_NAME)
                    .lazyPrintMetrics("DecisionTreeRegressor")
            );
        BatchOperator.execute();

        for (int numTrees : new int[] {2, 4, 8, 16, 32, 64, 128}) {
            new RandomForestRegressor()
                .setNumTrees(numTrees)
                .setFeatureCols(FEATURE_COL_NAMES)
                .setLabelCol(LABEL_COL_NAME)
                .setPredictionCol(PREDICTION_COL_NAME)
                .fit(train_data)
                .transform(test_data)
                .link(
                    new EvalRegressionBatchOp()
                        .setLabelCol(LABEL_COL_NAME)
                        .setPredictionCol(PREDICTION_COL_NAME)
                        .lazyPrintMetrics("RandomForestRegressor - " + numTrees)
                );
            BatchOperator.execute();
        }
    }

    // 16.5 GBDT regression with increasing numbers of trees.
    static void c_5() throws Exception {
        AkSourceBatchOp train_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_FILE);
        AkSourceBatchOp test_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE);

        for (int numTrees : new int[] {16, 32, 64, 128, 256, 512}) {
            new GbdtRegressor()
                .setLearningRate(0.05)
                .setMaxLeaves(256)
                .setFeatureSubsamplingRatio(0.3)
                .setMinSamplesPerLeaf(2)
                .setMaxDepth(100)
                .setNumTrees(numTrees)
                .setFeatureCols(FEATURE_COL_NAMES)
                .setLabelCol(LABEL_COL_NAME)
                .setPredictionCol(PREDICTION_COL_NAME)
                .fit(train_data)
                .transform(test_data)
                .link(
                    new EvalRegressionBatchOp()
                        .setLabelCol(LABEL_COL_NAME)
                        .setPredictionCol(PREDICTION_COL_NAME)
                        .lazyPrintMetrics("GbdtRegressor - " + numTrees)
                );
            BatchOperator.execute();
        }
    }

}