本章包括下面各节:
20.1 示例一:尝试正则系数
20.2 示例二:搜索GBDT超参数
20.3 示例三:最佳聚类个数
详细内容请阅读纸质书《Alink权威指南:基于Flink的机器学习实例入门(Python)》,这里为本章对应的示例代码。
# Environment setup for this chapter's examples.
# Statement order is preserved: useLocalEnv() must run right after the
# pyalink import, before any operators are built.
from pyalink.alink import *

useLocalEnv(1)

from utils import *  # provides ROOT_DIR and Stopwatch (project helper module)
import os
import pandas as pd
#c_1
# Example 1: grid-search the logistic-regression L1 penalty on the
# German-credit data set, using 5-fold cross validation with AUC as the
# tuning metric, then evaluate the best model on the held-out test set.

DATA_DIR = ROOT_DIR + "german_credit" + os.sep

TRAIN_FILE = "train.ak"
TEST_FILE = "test.ak"


def _one_hot_cases(col, codes):
    """Return one SQL CASE expression per code, mapping `col` to a 0/1
    indicator column named `<col>_<code>`."""
    return [
        "(case " + col + " when '" + code + "' then 1 else 0 end) as "
        + col + "_" + code
        for code in codes
    ]


# Build the feature-creation SELECT clause programmatically instead of a
# ~60-term hand-concatenated string literal. The joined result is
# byte-identical to the original clause.
_feature_exprs = (
    _one_hot_cases("status", ["A11", "A12", "A13", "A14"])
    + ["duration"]
    + _one_hot_cases("credit_history", ["A30", "A31", "A32", "A33", "A34"])
    + _one_hot_cases("purpose", ["A40", "A41", "A42", "A43", "A44", "A45",
                                 "A46", "A47", "A48", "A49", "A410"])
    + ["credit_amount"]
    + _one_hot_cases("savings", ["A61", "A62", "A63", "A64", "A65"])
    + _one_hot_cases("employment", ["A71", "A72", "A73", "A74", "A75"])
    + ["installment_rate"]
    + _one_hot_cases("marriage_sex", ["A91", "A92", "A93", "A94", "A95"])
    + _one_hot_cases("debtors", ["A101", "A102", "A103"])
    + ["residence"]
    + _one_hot_cases("property", ["A121", "A122", "A123", "A124"])
    + ["age"]
    + _one_hot_cases("other_plan", ["A141", "A142", "A143"])
    + _one_hot_cases("housing", ["A151", "A152", "A153"])
    + ["number_credits"]
    + _one_hot_cases("job", ["A171", "A172", "A173", "A174"])
    + ["maintenance_num"]
    # telephone / foreign_worker are binary, so they keep their original
    # column names instead of getting a `_<code>` suffix.
    + ["(case telephone when 'A192' then 1 else 0 end) as telephone"]
    + ["(case foreign_worker when 'A201' then 1 else 0 end) as foreign_worker"]
    + ["class "]  # trailing space kept to match the original clause exactly
)

CLAUSE_CREATE_FEATURES = ",".join(_feature_exprs)

LABEL_COL_NAME = "class"
VEC_COL_NAME = "vec"
PREDICTION_COL_NAME = "pred"
PRED_DETAIL_COL_NAME = "predinfo"

train_data = AkSourceBatchOp()\
    .setFilePath(DATA_DIR + TRAIN_FILE)\
    .select(CLAUSE_CREATE_FEATURES)

test_data = AkSourceBatchOp()\
    .setFilePath(DATA_DIR + TEST_FILE)\
    .select(CLAUSE_CREATE_FEATURES)

# All generated columns except the label are features.
new_features = train_data.getColNames()
new_features.remove(LABEL_COL_NAME)

lr = LogisticRegression()\
    .setFeatureCols(new_features)\
    .setLabelCol(LABEL_COL_NAME)\
    .setPredictionCol(PREDICTION_COL_NAME)\
    .setPredictionDetailCol(PRED_DETAIL_COL_NAME)

pipeline = Pipeline().add(lr)

# Sweep the L1 regularization coefficient over nine orders of magnitude;
# 5-fold CV, selecting by AUC.
gridSearch = GridSearchCV()\
    .setNumFolds(5)\
    .setEstimator(pipeline)\
    .setParamGrid(
        ParamGrid()\
            .addGrid(lr, 'L_1', [0.0000001, 0.000001, 0.00001, 0.0001,
                                 0.001, 0.01, 0.1, 1.0, 10.0])
    )\
    .setTuningEvaluator(
        BinaryClassificationTuningEvaluator()\
            .setLabelCol(LABEL_COL_NAME)\
            .setPredictionDetailCol(PRED_DETAIL_COL_NAME)\
            .setTuningBinaryClassMetric('AUC')
    )\
    .enableLazyPrintTrainInfo()

bestModel = gridSearch.fit(train_data)

bestModel\
    .transform(test_data)\
    .link(
        EvalBinaryClassBatchOp()\
            .setPositiveLabelValueString("2")\
            .setLabelCol(LABEL_COL_NAME)\
            .setPredictionDetailCol(PRED_DETAIL_COL_NAME)\
            .lazyPrintMetrics("GridSearchCV")
    )

BatchOperator.execute()
#c_2
# Example 2: random search over GBDT hyper-parameters on the Tmall data
# set, tuned on an 80/20 train/validation split with F1 as the metric,
# then evaluated on the test set.

DATA_DIR = ROOT_DIR + "tmall" + os.sep

ORIGIN_FILE = "tmall.csv"
TRAIN_SAMPLE_FILE = "train_sample.ak"
# Fix: the original cell read TEST_FILE without defining it, silently
# relying on the leftover value from the c_1 cell (NameError if this cell
# is run standalone). Same value, now explicit.
TEST_FILE = "test.ak"

LABEL_COL_NAME = "label"
PREDICTION_COL_NAME = "pred"
PRED_DETAIL_COL_NAME = "predInfo"

sw = Stopwatch()
sw.start()

train_sample = AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_SAMPLE_FILE)
test_data = AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE)

# Every column except the label is a feature.
featureColNames = train_sample.getColNames()
featureColNames.remove(LABEL_COL_NAME)

gbdt = GbdtClassifier()\
    .setFeatureCols(featureColNames)\
    .setLabelCol(LABEL_COL_NAME)\
    .setPredictionCol(PREDICTION_COL_NAME)\
    .setPredictionDetailCol(PRED_DETAIL_COL_NAME)

# 20 random draws from the parameter distributions; each candidate is
# trained on 80% of train_sample and scored by F1 on the remaining 20%.
randomSearch = RandomSearchTVSplit()\
    .setNumIter(20)\
    .setTrainRatio(0.8)\
    .setEstimator(gbdt)\
    .setParamDist(
        ParamDist()\
            .addDist(gbdt, 'NUM_TREES', ValueDist.randArray([50, 100]))\
            .addDist(gbdt, 'MAX_DEPTH', ValueDist.randInteger(4, 10))\
            .addDist(gbdt, 'MAX_BINS', ValueDist.randArray([64, 128, 256, 512]))\
            .addDist(gbdt, 'LEARNING_RATE', ValueDist.randArray([0.3, 0.1, 0.01]))
    )\
    .setTuningEvaluator(
        BinaryClassificationTuningEvaluator()\
            .setLabelCol(LABEL_COL_NAME)\
            .setPredictionDetailCol(PRED_DETAIL_COL_NAME)\
            .setTuningBinaryClassMetric('F1')
    )\
    .enableLazyPrintTrainInfo()

bestModel = randomSearch.fit(train_sample)

bestModel\
    .transform(test_data)\
    .link(
        EvalBinaryClassBatchOp()\
            .setPositiveLabelValueString("1")\
            .setLabelCol(LABEL_COL_NAME)\
            .setPredictionDetailCol(PRED_DETAIL_COL_NAME)\
            .lazyPrintMetrics()
    )

BatchOperator.execute()
sw.stop()
print(sw.getElapsedTimeSpan())
#c_3
# Example 3: pick the best number of clusters (and distance metric) for
# K-means on the iris vectors via 4-fold grid-search CV, scored by the
# Rand Index against the true categories, then evaluate the best model.

DATA_DIR = ROOT_DIR + "iris" + os.sep

VECTOR_FILE = "iris_vec.ak"

LABEL_COL_NAME = "category"
VECTOR_COL_NAME = "vec"
PREDICTION_COL_NAME = "cluster_id"

sw = Stopwatch()
sw.start()

source = AkSourceBatchOp().setFilePath(DATA_DIR + VECTOR_FILE)

kmeans = KMeans()\
    .setVectorCol(VECTOR_COL_NAME)\
    .setPredictionCol(PREDICTION_COL_NAME)

# 5 candidate K values x 2 distance types = 10 grid points.
cv = GridSearchCV()\
    .setNumFolds(4)\
    .setEstimator(kmeans)\
    .setParamGrid(
        ParamGrid()\
            .addGrid(kmeans, 'K', [2, 3, 4, 5, 6])\
            .addGrid(kmeans, 'DISTANCE_TYPE', ['EUCLIDEAN', 'COSINE'])
    )\
    .setTuningEvaluator(
        ClusterTuningEvaluator()\
            .setVectorCol(VECTOR_COL_NAME)\
            .setPredictionCol(PREDICTION_COL_NAME)\
            .setLabelCol(LABEL_COL_NAME)\
            .setTuningClusterMetric('RI')
    )\
    .enableLazyPrintTrainInfo()

bestModel = cv.fit(source)

bestModel\
    .transform(source)\
    .link(
        EvalClusterBatchOp()\
            .setLabelCol(LABEL_COL_NAME)\
            .setVectorCol(VECTOR_COL_NAME)\
            .setPredictionCol(PREDICTION_COL_NAME)\
            .lazyPrintMetrics()
    )

BatchOperator.execute()
sw.stop()
print(sw.getElapsedTimeSpan())