本章包括下面各节:
20.1 示例一:尝试正则系数
20.2 示例二:搜索GBDT超参数
20.3 示例三:最佳聚类个数
详细内容请阅读纸质书《Alink权威指南:基于Flink的机器学习实例入门(Python)》,这里为本章对应的示例代码。
# NOTE: in the original file these five statements were fused onto a single
# line, which is a Python syntax error. They are split here, preserving the
# original execution order (local Flink environment is initialized right
# after the pyalink import, exactly as before).
from pyalink.alink import *
useLocalEnv(1)
from utils import *
import os
import pandas as pd
#c_1
# ==== Example 1: tuning the L1 regularization coefficient of logistic regression ====
# Data directory for the German credit data set (ROOT_DIR is provided by utils).
DATA_DIR = ROOT_DIR + "german_credit" + os.sep
TRAIN_FILE = "train.ak";  # training split, Alink .ak format
TEST_FILE = "test.ak";  # test split, Alink .ak format
# SQL SELECT clause that builds the model's input columns:
# - each categorical column (status, credit_history, purpose, savings,
#   employment, marriage_sex, debtors, property, other_plan, housing, job)
#   is one-hot encoded into 0/1 indicator columns via CASE WHEN expressions;
# - numeric columns (duration, credit_amount, installment_rate, residence,
#   age, number_credits, maintenance_num) pass through unchanged;
# - telephone and foreign_worker collapse to single 0/1 indicators;
# - the label column "class" is kept as the last column.
# (No comments may be placed inside the backslash-continued expression below.)
CLAUSE_CREATE_FEATURES = "(case status when 'A11' then 1 else 0 end) as status_A11,"\
+ "(case status when 'A12' then 1 else 0 end) as status_A12,"\
+ "(case status when 'A13' then 1 else 0 end) as status_A13,"\
+ "(case status when 'A14' then 1 else 0 end) as status_A14,"\
+ "duration,"\
+ "(case credit_history when 'A30' then 1 else 0 end) as credit_history_A30,"\
+ "(case credit_history when 'A31' then 1 else 0 end) as credit_history_A31,"\
+ "(case credit_history when 'A32' then 1 else 0 end) as credit_history_A32,"\
+ "(case credit_history when 'A33' then 1 else 0 end) as credit_history_A33,"\
+ "(case credit_history when 'A34' then 1 else 0 end) as credit_history_A34,"\
+ "(case purpose when 'A40' then 1 else 0 end) as purpose_A40,"\
+ "(case purpose when 'A41' then 1 else 0 end) as purpose_A41,"\
+ "(case purpose when 'A42' then 1 else 0 end) as purpose_A42,"\
+ "(case purpose when 'A43' then 1 else 0 end) as purpose_A43,"\
+ "(case purpose when 'A44' then 1 else 0 end) as purpose_A44,"\
+ "(case purpose when 'A45' then 1 else 0 end) as purpose_A45,"\
+ "(case purpose when 'A46' then 1 else 0 end) as purpose_A46,"\
+ "(case purpose when 'A47' then 1 else 0 end) as purpose_A47,"\
+ "(case purpose when 'A48' then 1 else 0 end) as purpose_A48,"\
+ "(case purpose when 'A49' then 1 else 0 end) as purpose_A49,"\
+ "(case purpose when 'A410' then 1 else 0 end) as purpose_A410,"\
+ "credit_amount,"\
+ "(case savings when 'A61' then 1 else 0 end) as savings_A61,"\
+ "(case savings when 'A62' then 1 else 0 end) as savings_A62,"\
+ "(case savings when 'A63' then 1 else 0 end) as savings_A63,"\
+ "(case savings when 'A64' then 1 else 0 end) as savings_A64,"\
+ "(case savings when 'A65' then 1 else 0 end) as savings_A65,"\
+ "(case employment when 'A71' then 1 else 0 end) as employment_A71,"\
+ "(case employment when 'A72' then 1 else 0 end) as employment_A72,"\
+ "(case employment when 'A73' then 1 else 0 end) as employment_A73,"\
+ "(case employment when 'A74' then 1 else 0 end) as employment_A74,"\
+ "(case employment when 'A75' then 1 else 0 end) as employment_A75,"\
+ "installment_rate,"\
+ "(case marriage_sex when 'A91' then 1 else 0 end) as marriage_sex_A91,"\
+ "(case marriage_sex when 'A92' then 1 else 0 end) as marriage_sex_A92,"\
+ "(case marriage_sex when 'A93' then 1 else 0 end) as marriage_sex_A93,"\
+ "(case marriage_sex when 'A94' then 1 else 0 end) as marriage_sex_A94,"\
+ "(case marriage_sex when 'A95' then 1 else 0 end) as marriage_sex_A95,"\
+ "(case debtors when 'A101' then 1 else 0 end) as debtors_A101,"\
+ "(case debtors when 'A102' then 1 else 0 end) as debtors_A102,"\
+ "(case debtors when 'A103' then 1 else 0 end) as debtors_A103,"\
+ "residence,"\
+ "(case property when 'A121' then 1 else 0 end) as property_A121,"\
+ "(case property when 'A122' then 1 else 0 end) as property_A122,"\
+ "(case property when 'A123' then 1 else 0 end) as property_A123,"\
+ "(case property when 'A124' then 1 else 0 end) as property_A124,"\
+ "age,"\
+ "(case other_plan when 'A141' then 1 else 0 end) as other_plan_A141,"\
+ "(case other_plan when 'A142' then 1 else 0 end) as other_plan_A142,"\
+ "(case other_plan when 'A143' then 1 else 0 end) as other_plan_A143,"\
+ "(case housing when 'A151' then 1 else 0 end) as housing_A151,"\
+ "(case housing when 'A152' then 1 else 0 end) as housing_A152,"\
+ "(case housing when 'A153' then 1 else 0 end) as housing_A153,"\
+ "number_credits,"\
+ "(case job when 'A171' then 1 else 0 end) as job_A171,"\
+ "(case job when 'A172' then 1 else 0 end) as job_A172,"\
+ "(case job when 'A173' then 1 else 0 end) as job_A173,"\
+ "(case job when 'A174' then 1 else 0 end) as job_A174,"\
+ "maintenance_num,"\
+ "(case telephone when 'A192' then 1 else 0 end) as telephone,"\
+ "(case foreign_worker when 'A201' then 1 else 0 end) as foreign_worker,"\
+ "class "
LABEL_COL_NAME = "class";  # label column, kept last in the clause above
VEC_COL_NAME = "vec";  # vector column name (not referenced in this section's visible code)
PREDICTION_COL_NAME = "pred";  # predicted-label output column
PRED_DETAIL_COL_NAME = "predinfo";  # prediction-detail column consumed by the tuning/evaluation ops
# ---- Example 1 driver: 5-fold grid search over the LR L1 coefficient ----

# Load both splits and expand the raw columns into one-hot features.
train_data = (
    AkSourceBatchOp()
    .setFilePath(DATA_DIR + TRAIN_FILE)
    .select(CLAUSE_CREATE_FEATURES)
)
test_data = (
    AkSourceBatchOp()
    .setFilePath(DATA_DIR + TEST_FILE)
    .select(CLAUSE_CREATE_FEATURES)
)

# Every generated column except the label is used as a feature.
new_features = [col for col in train_data.getColNames() if col != LABEL_COL_NAME]

lr = (
    LogisticRegression()
    .setFeatureCols(new_features)
    .setLabelCol(LABEL_COL_NAME)
    .setPredictionCol(PREDICTION_COL_NAME)
    .setPredictionDetailCol(PRED_DETAIL_COL_NAME)
)

# Candidate L1 regularization strengths: 1e-7 up to 10, by decades.
l1_candidates = [0.0000001, 0.000001, 0.00001, 0.0001, 0.001, 0.01, 0.1, 1.0, 10.0]

gridSearch = (
    GridSearchCV()
    .setNumFolds(5)
    .setEstimator(Pipeline().add(lr))
    .setParamGrid(ParamGrid().addGrid(lr, 'L_1', l1_candidates))
    .setTuningEvaluator(
        BinaryClassificationTuningEvaluator()
        .setLabelCol(LABEL_COL_NAME)
        .setPredictionDetailCol(PRED_DETAIL_COL_NAME)
        .setTuningBinaryClassMetric('AUC')
    )
    .enableLazyPrintTrainInfo()
)

# Fit on the training split, then score the held-out test split with the winner.
bestModel = gridSearch.fit(train_data)
bestModel.transform(test_data).link(
    EvalBinaryClassBatchOp()
    .setPositiveLabelValueString("2")
    .setLabelCol(LABEL_COL_NAME)
    .setPredictionDetailCol(PRED_DETAIL_COL_NAME)
    .lazyPrintMetrics("GridSearchCV")
)
BatchOperator.execute()
#c_2
# ---- Example 2: random search over GBDT hyper-parameters (Tmall data) ----
DATA_DIR = ROOT_DIR + "tmall" + os.sep
ORIGIN_FILE = "tmall.csv"
TRAIN_SAMPLE_FILE = "train_sample.ak"
LABEL_COL_NAME = "label"
PREDICTION_COL_NAME = "pred"
PRED_DETAIL_COL_NAME = "predInfo"

sw = Stopwatch()
sw.start()

# NOTE(review): TEST_FILE ("test.ak") is inherited from the previous section
# and is read from the tmall directory here — presumably tmall/test.ak exists;
# confirm against the data preparation step.
train_sample = AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_SAMPLE_FILE)
test_data = AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE)

# All columns except the label are features.
featureColNames = [col for col in train_sample.getColNames() if col != LABEL_COL_NAME]

gbdt = (
    GbdtClassifier()
    .setFeatureCols(featureColNames)
    .setLabelCol(LABEL_COL_NAME)
    .setPredictionCol(PREDICTION_COL_NAME)
    .setPredictionDetailCol(PRED_DETAIL_COL_NAME)
)

# Search space: tree count, depth, histogram bins and learning rate.
searchSpace = (
    ParamDist()
    .addDist(gbdt, 'NUM_TREES', ValueDist.randArray([50, 100]))
    .addDist(gbdt, 'MAX_DEPTH', ValueDist.randInteger(4, 10))
    .addDist(gbdt, 'MAX_BINS', ValueDist.randArray([64, 128, 256, 512]))
    .addDist(gbdt, 'LEARNING_RATE', ValueDist.randArray([0.3, 0.1, 0.01]))
)

# 20 random draws, each evaluated on an 80/20 train/validation split by F1.
randomSearch = (
    RandomSearchTVSplit()
    .setNumIter(20)
    .setTrainRatio(0.8)
    .setEstimator(gbdt)
    .setParamDist(searchSpace)
    .setTuningEvaluator(
        BinaryClassificationTuningEvaluator()
        .setLabelCol(LABEL_COL_NAME)
        .setPredictionDetailCol(PRED_DETAIL_COL_NAME)
        .setTuningBinaryClassMetric('F1')
    )
    .enableLazyPrintTrainInfo()
)

bestModel = randomSearch.fit(train_sample)
bestModel.transform(test_data).link(
    EvalBinaryClassBatchOp()
    .setPositiveLabelValueString("1")
    .setLabelCol(LABEL_COL_NAME)
    .setPredictionDetailCol(PRED_DETAIL_COL_NAME)
    .lazyPrintMetrics()
)
BatchOperator.execute()

sw.stop()
print(sw.getElapsedTimeSpan())
#c_3
# ---- Example 3: choose the best number of clusters for KMeans (iris) ----
DATA_DIR = ROOT_DIR + "iris" + os.sep
VECTOR_FILE = "iris_vec.ak"
LABEL_COL_NAME = "category"
VECTOR_COL_NAME = "vec"
PREDICTION_COL_NAME = "cluster_id"

sw = Stopwatch()
sw.start()

source = AkSourceBatchOp().setFilePath(DATA_DIR + VECTOR_FILE)

kmeans = KMeans().setVectorCol(VECTOR_COL_NAME).setPredictionCol(PREDICTION_COL_NAME)

# 4-fold grid search over k in 2..6 and two distance metrics,
# scored with the Rand index ('RI') against the true iris categories.
cv = (
    GridSearchCV()
    .setNumFolds(4)
    .setEstimator(kmeans)
    .setParamGrid(
        ParamGrid()
        .addGrid(kmeans, 'K', [2, 3, 4, 5, 6])
        .addGrid(kmeans, 'DISTANCE_TYPE', ['EUCLIDEAN', 'COSINE'])
    )
    .setTuningEvaluator(
        ClusterTuningEvaluator()
        .setVectorCol(VECTOR_COL_NAME)
        .setPredictionCol(PREDICTION_COL_NAME)
        .setLabelCol(LABEL_COL_NAME)
        .setTuningClusterMetric('RI')
    )
    .enableLazyPrintTrainInfo()
)

# Fit and then evaluate the winning model on the same source data.
bestModel = cv.fit(source)
bestModel.transform(source).link(
    EvalClusterBatchOp()
    .setLabelCol(LABEL_COL_NAME)
    .setVectorCol(VECTOR_COL_NAME)
    .setPredictionCol(PREDICTION_COL_NAME)
    .lazyPrintMetrics()
)
BatchOperator.execute()

sw.stop()
print(sw.getElapsedTimeSpan())