Alink教程(Python版)

第16章 常用的回归算法

本章包括下面各节:
16.1 回归模型的评估指标
16.2 数据探索
16.3 线性回归
16.4 决策树与随机森林
16.5 GBDT

详细内容请阅读纸质书《Alink权威指南:基于Flink的机器学习实例入门(Python)》,这里为本章对应的示例代码。

from pyalink.alink import *
useLocalEnv(1)

from utils import *
import os
import pandas as pd

DATA_DIR = ROOT_DIR + "wine" + os.sep

# Raw white-wine data file, read from DATA_DIR.
ORIGIN_FILE = "winequality-white.csv"

# Train/test split files (Alink .ak format), also stored under DATA_DIR.
TRAIN_FILE = "train.ak"
TEST_FILE = "test.ak"

# Regression target column, and the column the fitted models write predictions to.
# Defined before FEATURE_COL_NAMES so the feature list is derived from it
# instead of repeating the literal "quality".
LABEL_COL_NAME = "quality"
PREDICTION_COL_NAME = "pred"

COL_NAMES = [
    "fixedAcidity", "volatileAcidity", "citricAcid", "residualSugar", "chlorides",
    "freeSulfurDioxide", "totalSulfurDioxide", "density", "pH", "sulphates",
    "alcohol", "quality"
]

# Every column (features and label alike) is parsed as a double; deriving the
# list from COL_NAMES keeps the two the same length.
COL_TYPES = ["double"] * len(COL_NAMES)

# Features are all columns except the label.
FEATURE_COL_NAMES = [col for col in COL_NAMES if col != LABEL_COL_NAME]

#c_1
# Read the raw CSV: ';'-delimited, first (header) line skipped, schema built
# from COL_NAMES / COL_TYPES.
source = (
    CsvSourceBatchOp()
    .setFilePath(DATA_DIR + ORIGIN_FILE)
    .setSchemaStr(generateSchemaString(COL_NAMES, COL_TYPES))
    .setFieldDelimiter(";")
    .setIgnoreFirstLine(True)
)

# Queue a 5-row preview and the correlation matrix of the columns.
source.lazyPrint(5)
source.link(CorrelationBatchOp().lazyPrintCorrelation())

import matplotlib.pyplot as plt
import seaborn as sns

# Pull the data into pandas and draw its correlation matrix as a grey heatmap.
corr = source.collectToDataframe().corr()
plt.figure(figsize=(15, 5))
sns.heatmap(corr, annot=True, cmap="Greys")

# Queue a per-quality sample count, ordered by quality (100 is passed as the
# orderBy limit).
(
    source
    .groupBy(LABEL_COL_NAME, LABEL_COL_NAME + ", COUNT(*) AS cnt")
    .orderBy(LABEL_COL_NAME, 100)
    .lazyPrint(-1)
)

# Run the pending batch job so the queued lazy output is printed.
BatchOperator.execute()

# Helper from utils; judging by its name it writes an 0.8-ratio train/test
# split to the two .ak paths only when they do not exist yet — confirm in utils.
# The trailing ';' keeps a notebook from echoing the return value.
splitTrainTestIfNotExist(source, DATA_DIR + TRAIN_FILE, DATA_DIR + TEST_FILE, 0.8);

#c_2
# Load the train/test splits from their .ak files.
train_data = AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_FILE)
test_data = AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE)

# Ordinary linear regression: fit on train, predict on test, evaluate the
# predictions; train info and model info are printed lazily as well.
(
    LinearRegression()
    .setFeatureCols(FEATURE_COL_NAMES)
    .setLabelCol(LABEL_COL_NAME)
    .setPredictionCol(PREDICTION_COL_NAME)
    .enableLazyPrintTrainInfo()
    .enableLazyPrintModelInfo()
    .fit(train_data)
    .transform(test_data)
    .link(
        EvalRegressionBatchOp()
        .setLabelCol(LABEL_COL_NAME)
        .setPredictionCol(PREDICTION_COL_NAME)
        .lazyPrintMetrics("LinearRegression")
    )
)

# LASSO regression with lambda = 0.05, evaluated the same way; its model info
# is printed under a "< LASSO model >" title.
(
    LassoRegression()
    .setLambda(0.05)
    .setFeatureCols(FEATURE_COL_NAMES)
    .setLabelCol(LABEL_COL_NAME)
    .setPredictionCol(PREDICTION_COL_NAME)
    .enableLazyPrintTrainInfo()
    .enableLazyPrintModelInfo("< LASSO model >")
    .fit(train_data)
    .transform(test_data)
    .link(
        EvalRegressionBatchOp()
        .setLabelCol(LABEL_COL_NAME)
        .setPredictionCol(PREDICTION_COL_NAME)
        .lazyPrintMetrics("LassoRegression")
    )
)

# Run both pipelines; the trailing ';' keeps a notebook from echoing the result.
BatchOperator.execute();

#c_3
# Load the train/test splits from their .ak files.
train_data = AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_FILE)
test_data = AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE)

# Baseline: a single decision-tree regressor, fitted on train and evaluated on test.
(
    DecisionTreeRegressor()
    .setFeatureCols(FEATURE_COL_NAMES)
    .setLabelCol(LABEL_COL_NAME)
    .setPredictionCol(PREDICTION_COL_NAME)
    .fit(train_data)
    .transform(test_data)
    .link(
        EvalRegressionBatchOp()
        .setLabelCol(LABEL_COL_NAME)
        .setPredictionCol(PREDICTION_COL_NAME)
        .lazyPrintMetrics("DecisionTreeRegressor")
    )
)
BatchOperator.execute()

# Random forests with 2, 4, 8, 16, 32, 64 and 128 trees; each forest size is
# fitted, evaluated and executed as its own job.
for numTrees in (2 ** exponent for exponent in range(1, 8)):
    (
        RandomForestRegressor()
        .setNumTrees(numTrees)
        .setFeatureCols(FEATURE_COL_NAMES)
        .setLabelCol(LABEL_COL_NAME)
        .setPredictionCol(PREDICTION_COL_NAME)
        .fit(train_data)
        .transform(test_data)
        .link(
            EvalRegressionBatchOp()
            .setLabelCol(LABEL_COL_NAME)
            .setPredictionCol(PREDICTION_COL_NAME)
            .lazyPrintMetrics(f"RandomForestRegressor - {numTrees}")
        )
    )
    BatchOperator.execute()

#c_4
# Load the train/test splits from their .ak files.
train_data = AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_FILE)
test_data = AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE)

# GBDT sweep over 16, 32, 64, 128, 256 and 512 trees. The other settings stay
# fixed: learning rate 0.05, at most 256 leaves, feature subsampling ratio 0.3,
# at least 2 samples per leaf, max depth 100.
for numTrees in (2 ** exponent for exponent in range(4, 10)):
    (
        GbdtRegressor()
        .setLearningRate(0.05)
        .setMaxLeaves(256)
        .setFeatureSubsamplingRatio(0.3)
        .setMinSamplesPerLeaf(2)
        .setMaxDepth(100)
        .setNumTrees(numTrees)
        .setFeatureCols(FEATURE_COL_NAMES)
        .setLabelCol(LABEL_COL_NAME)
        .setPredictionCol(PREDICTION_COL_NAME)
        .fit(train_data)
        .transform(test_data)
        .link(
            EvalRegressionBatchOp()
            .setLabelCol(LABEL_COL_NAME)
            .setPredictionCol(PREDICTION_COL_NAME)
            .lazyPrintMetrics(f"GbdtRegressor - {numTrees}")
        )
    )
    BatchOperator.execute()