Alink教程(Python版)

第20章 超参数搜索

本章包括下面各节:
20.1 示例一:尝试正则系数
20.2 示例二:搜索GBDT超参数
20.3 示例三:最佳聚类个数

详细内容请阅读纸质书《Alink权威指南:基于Flink的机器学习实例入门(Python)》,这里为本章对应的示例代码。

from pyalink.alink import *
useLocalEnv(1)

from utils import *
import os
import pandas as pd

#c_1
# Section 20.1: grid search over the logistic-regression regularization
# coefficient on the German credit data.

# os.path.join works whether or not ROOT_DIR already ends with a path
# separator; the final "" keeps a trailing separator so file names can be
# appended directly (DATA_DIR + TRAIN_FILE).
DATA_DIR = os.path.join(ROOT_DIR, "german_credit", "")

# German-credit data files, relative to DATA_DIR.
TRAIN_FILE = "train.ak"
TEST_FILE = "test.ak"


def _indicator_exprs(column, codes):
    """Return one SQL expression per code that turns `column` into a 0/1 flag.

    Each flag is aliased "<column>_<code>", e.g. ``status_A11``.
    """
    template = "(case {0} when '{1}' then 1 else 0 end) as {0}_{1}"
    return [template.format(column, code) for code in codes]


def _flag_expr(column, code):
    """Return a SQL expression replacing `column` by a 0/1 flag of the same name."""
    return "(case {0} when '{1}' then 1 else 0 end) as {0}".format(column, code)


# Select-list items in output order: categorical columns are expanded into
# indicator columns, numeric columns are passed through unchanged.
_SELECT_ITEMS = (
    _indicator_exprs("status", ["A11", "A12", "A13", "A14"])
    + ["duration"]
    + _indicator_exprs("credit_history", ["A30", "A31", "A32", "A33", "A34"])
    + _indicator_exprs("purpose", ["A40", "A41", "A42", "A43", "A44", "A45",
                                   "A46", "A47", "A48", "A49", "A410"])
    + ["credit_amount"]
    + _indicator_exprs("savings", ["A61", "A62", "A63", "A64", "A65"])
    + _indicator_exprs("employment", ["A71", "A72", "A73", "A74", "A75"])
    + ["installment_rate"]
    + _indicator_exprs("marriage_sex", ["A91", "A92", "A93", "A94", "A95"])
    + _indicator_exprs("debtors", ["A101", "A102", "A103"])
    + ["residence"]
    + _indicator_exprs("property", ["A121", "A122", "A123", "A124"])
    + ["age"]
    + _indicator_exprs("other_plan", ["A141", "A142", "A143"])
    + _indicator_exprs("housing", ["A151", "A152", "A153"])
    + ["number_credits"]
    + _indicator_exprs("job", ["A171", "A172", "A173", "A174"])
    + ["maintenance_num"]
    + [_flag_expr("telephone", "A192"), _flag_expr("foreign_worker", "A201")]
)

# Every item is followed by a comma; the label column closes the list
# (the trailing space after "class" is part of the original clause).
CLAUSE_CREATE_FEATURES = "".join(item + "," for item in _SELECT_ITEMS) + "class "

LABEL_COL_NAME = "class"

VEC_COL_NAME = "vec"

PREDICTION_COL_NAME = "pred"

PRED_DETAIL_COL_NAME = "predinfo"


# Train and test splits both go through the same SQL projection, so they share
# one feature schema.
train_data = (
    AkSourceBatchOp()
    .setFilePath(DATA_DIR + TRAIN_FILE)
    .select(CLAUSE_CREATE_FEATURES)
)

test_data = (
    AkSourceBatchOp()
    .setFilePath(DATA_DIR + TEST_FILE)
    .select(CLAUSE_CREATE_FEATURES)
)

# All projected columns except the label are model inputs.
new_features = train_data.getColNames()
new_features.remove(LABEL_COL_NAME)

lr = (
    LogisticRegression()
    .setFeatureCols(new_features)
    .setLabelCol(LABEL_COL_NAME)
    .setPredictionCol(PREDICTION_COL_NAME)
    .setPredictionDetailCol(PRED_DETAIL_COL_NAME)
)

pipeline = Pipeline().add(lr)

# Candidate values for lr's 'L_1' parameter (the regularization coefficient
# this section tries), one per decade from 1e-7 up to 10.
param_grid = ParamGrid().addGrid(
    lr, 'L_1',
    [0.0000001, 0.000001, 0.00001, 0.0001, 0.001, 0.01, 0.1, 1.0, 10.0]
)

# Candidates are compared by AUC computed from the prediction-detail column.
tuning_evaluator = (
    BinaryClassificationTuningEvaluator()
    .setLabelCol(LABEL_COL_NAME)
    .setPredictionDetailCol(PRED_DETAIL_COL_NAME)
    .setTuningBinaryClassMetric('AUC')
)

# 5-fold cross-validation over the grid; the per-candidate training info is
# printed lazily, i.e. when BatchOperator.execute() runs.
gridSearch = (
    GridSearchCV()
    .setNumFolds(5)
    .setEstimator(pipeline)
    .setParamGrid(param_grid)
    .setTuningEvaluator(tuning_evaluator)
    .enableLazyPrintTrainInfo()
)

bestModel = gridSearch.fit(train_data)

# Score the selected model on the test split, with class value "2" as the
# positive label.
test_predictions = bestModel.transform(test_data)
test_metrics = (
    EvalBinaryClassBatchOp()
    .setPositiveLabelValueString("2")
    .setLabelCol(LABEL_COL_NAME)
    .setPredictionDetailCol(PRED_DETAIL_COL_NAME)
    .lazyPrintMetrics("GridSearchCV")
)
test_predictions.link(test_metrics)

BatchOperator.execute()
#c_2
# Section 20.2: random search over GBDT hyper-parameters on the tmall data.

# os.path.join works whether or not ROOT_DIR ends with a path separator; the
# final "" keeps a trailing separator so file names can be appended directly.
DATA_DIR = os.path.join(ROOT_DIR, "tmall", "")

ORIGIN_FILE = "tmall.csv"

TRAIN_SAMPLE_FILE = "train_sample.ak"

# Defined here instead of being inherited from section 20.1, so this section
# also runs on its own. The tmall test split uses the same file name.
TEST_FILE = "test.ak"

LABEL_COL_NAME = "label"
PREDICTION_COL_NAME = "pred"
PRED_DETAIL_COL_NAME = "predInfo"

sw = Stopwatch()
sw.start()

train_sample = AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_SAMPLE_FILE)

test_data = AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE)

# All columns of the training sample except the label are features.
featureColNames = train_sample.getColNames()
featureColNames.remove(LABEL_COL_NAME)

gbdt = (
    GbdtClassifier()
    .setFeatureCols(featureColNames)
    .setLabelCol(LABEL_COL_NAME)
    .setPredictionCol(PREDICTION_COL_NAME)
    .setPredictionDetailCol(PRED_DETAIL_COL_NAME)
)

# 20 random draws from the distributions below, each tuned with a
# train/validation split of ratio 0.8 and ranked by F1.
randomSearch = (
    RandomSearchTVSplit()
    .setNumIter(20)
    .setTrainRatio(0.8)
    .setEstimator(gbdt)
    .setParamDist(
        ParamDist()
        .addDist(gbdt, 'NUM_TREES', ValueDist.randArray([50, 100]))
        .addDist(gbdt, 'MAX_DEPTH', ValueDist.randInteger(4, 10))
        .addDist(gbdt, 'MAX_BINS', ValueDist.randArray([64, 128, 256, 512]))
        .addDist(gbdt, 'LEARNING_RATE', ValueDist.randArray([0.3, 0.1, 0.01]))
    )
    .setTuningEvaluator(
        BinaryClassificationTuningEvaluator()
        .setLabelCol(LABEL_COL_NAME)
        .setPredictionDetailCol(PRED_DETAIL_COL_NAME)
        .setTuningBinaryClassMetric('F1')
    )
    .enableLazyPrintTrainInfo()
)

bestModel = randomSearch.fit(train_sample)

# Evaluate the selected model on the test split, with label "1" as positive.
bestModel.transform(test_data).link(
    EvalBinaryClassBatchOp()
    .setPositiveLabelValueString("1")
    .setLabelCol(LABEL_COL_NAME)
    .setPredictionDetailCol(PRED_DETAIL_COL_NAME)
    .lazyPrintMetrics()
)

BatchOperator.execute()

sw.stop()
print(sw.getElapsedTimeSpan())
#c_3
# Section 20.3: pick the number of clusters (and the distance type) for KMeans
# on the iris vectors with a cross-validated grid search.

# os.path.join works whether or not ROOT_DIR ends with a path separator; the
# final "" keeps a trailing separator so file names can be appended directly.
DATA_DIR = os.path.join(ROOT_DIR, "iris", "")

VECTOR_FILE = "iris_vec.ak"

LABEL_COL_NAME = "category"
VECTOR_COL_NAME = "vec"
PREDICTION_COL_NAME = "cluster_id"

sw = Stopwatch()
sw.start()

source = AkSourceBatchOp().setFilePath(DATA_DIR + VECTOR_FILE)

kmeans = (
    KMeans()
    .setVectorCol(VECTOR_COL_NAME)
    .setPredictionCol(PREDICTION_COL_NAME)
)

# 4-fold CV over K in {2..6} x distance type {EUCLIDEAN, COSINE}. The tuning
# evaluator is given the true category column, so 'RI' is presumably the
# Rand index against those labels (confirm in the Alink docs).
cv = (
    GridSearchCV()
    .setNumFolds(4)
    .setEstimator(kmeans)
    .setParamGrid(
        ParamGrid()
        .addGrid(kmeans, 'K', [2, 3, 4, 5, 6])
        .addGrid(kmeans, 'DISTANCE_TYPE', ['EUCLIDEAN', 'COSINE'])
    )
    .setTuningEvaluator(
        ClusterTuningEvaluator()
        .setVectorCol(VECTOR_COL_NAME)
        .setPredictionCol(PREDICTION_COL_NAME)
        .setLabelCol(LABEL_COL_NAME)
        .setTuningClusterMetric('RI')
    )
    .enableLazyPrintTrainInfo()
)

bestModel = cv.fit(source)

# The selected model is evaluated on the same data it was tuned on.
bestModel.transform(source).link(
    EvalClusterBatchOp()
    .setLabelCol(LABEL_COL_NAME)
    .setVectorCol(VECTOR_COL_NAME)
    .setPredictionCol(PREDICTION_COL_NAME)
    .lazyPrintMetrics()
)

BatchOperator.execute()

sw.stop()
print(sw.getElapsedTimeSpan())