本章包括下面各节:
10.1 整体流程
10.1.1 特征哑元化
10.1.2 特征的重要性
10.2 减少模型特征的个数
10.3 离散特征转化
10.3.1 独热编码
10.3.2 特征哈希
详细内容请阅读纸质书《Alink权威指南:基于Flink的机器学习实例入门(Java)》，以下为本章对应的示例代码。
package com.alibaba.alink; import org.apache.flink.api.java.tuple.Tuple2; import com.alibaba.alink.common.utils.Stopwatch; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.classification.LogisticRegressionPredictBatchOp; import com.alibaba.alink.operator.batch.classification.LogisticRegressionTrainBatchOp; import com.alibaba.alink.operator.batch.evaluation.EvalBinaryClassBatchOp; import com.alibaba.alink.operator.batch.source.AkSourceBatchOp; import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp; import com.alibaba.alink.operator.common.linear.LinearModelTrainInfo; import com.alibaba.alink.params.feature.HasEncodeWithoutWoe.Encode; import com.alibaba.alink.pipeline.Pipeline; import com.alibaba.alink.pipeline.classification.LogisticRegression; import com.alibaba.alink.pipeline.dataproc.vector.VectorAssembler; import com.alibaba.alink.pipeline.feature.FeatureHasher; import com.alibaba.alink.pipeline.feature.OneHotEncoder; import org.apache.commons.lang3.ArrayUtils; import java.io.File; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; import java.util.function.Consumer; public class Chap10 { static final String DATA_DIR = Utils.ROOT_DIR + "german_credit" + File.separator; static final String ORIGIN_FILE = "german.data"; static final String TRAIN_FILE = "train.ak"; static final String TEST_FILE = "test.ak"; private static final String[] COL_NAMES = new String[] { "status", "duration", "credit_history", "purpose", "credit_amount", "savings", "employment", "installment_rate", "marriage_sex", "debtors", "residence", "property", "age", "other_plan", "housing", "number_credits", "job", "maintenance_num", "telephone", "foreign_worker", "class" }; private static final String[] COL_TYPES = new String[] { "string", "int", "string", "string", "int", "string", "string", "int", "string", "string", "int", "string", "int", "string", "string", "int", "string", "int", "string", "string", 
"int" }; static final String CLAUSE_CREATE_FEATURES = "(case status when 'A11' then 1 else 0 end) as status_A11," + "(case status when 'A12' then 1 else 0 end) as status_A12," + "(case status when 'A13' then 1 else 0 end) as status_A13," + "(case status when 'A14' then 1 else 0 end) as status_A14," + "duration," + "(case credit_history when 'A30' then 1 else 0 end) as credit_history_A30," + "(case credit_history when 'A31' then 1 else 0 end) as credit_history_A31," + "(case credit_history when 'A32' then 1 else 0 end) as credit_history_A32," + "(case credit_history when 'A33' then 1 else 0 end) as credit_history_A33," + "(case credit_history when 'A34' then 1 else 0 end) as credit_history_A34," + "(case purpose when 'A40' then 1 else 0 end) as purpose_A40," + "(case purpose when 'A41' then 1 else 0 end) as purpose_A41," + "(case purpose when 'A42' then 1 else 0 end) as purpose_A42," + "(case purpose when 'A43' then 1 else 0 end) as purpose_A43," + "(case purpose when 'A44' then 1 else 0 end) as purpose_A44," + "(case purpose when 'A45' then 1 else 0 end) as purpose_A45," + "(case purpose when 'A46' then 1 else 0 end) as purpose_A46," + "(case purpose when 'A47' then 1 else 0 end) as purpose_A47," + "(case purpose when 'A48' then 1 else 0 end) as purpose_A48," + "(case purpose when 'A49' then 1 else 0 end) as purpose_A49," + "(case purpose when 'A410' then 1 else 0 end) as purpose_A410," + "credit_amount," + "(case savings when 'A61' then 1 else 0 end) as savings_A61," + "(case savings when 'A62' then 1 else 0 end) as savings_A62," + "(case savings when 'A63' then 1 else 0 end) as savings_A63," + "(case savings when 'A64' then 1 else 0 end) as savings_A64," + "(case savings when 'A65' then 1 else 0 end) as savings_A65," + "(case employment when 'A71' then 1 else 0 end) as employment_A71," + "(case employment when 'A72' then 1 else 0 end) as employment_A72," + "(case employment when 'A73' then 1 else 0 end) as employment_A73," + "(case employment when 'A74' then 1 
else 0 end) as employment_A74," + "(case employment when 'A75' then 1 else 0 end) as employment_A75," + "installment_rate," + "(case marriage_sex when 'A91' then 1 else 0 end) as marriage_sex_A91," + "(case marriage_sex when 'A92' then 1 else 0 end) as marriage_sex_A92," + "(case marriage_sex when 'A93' then 1 else 0 end) as marriage_sex_A93," + "(case marriage_sex when 'A94' then 1 else 0 end) as marriage_sex_A94," + "(case marriage_sex when 'A95' then 1 else 0 end) as marriage_sex_A95," + "(case debtors when 'A101' then 1 else 0 end) as debtors_A101," + "(case debtors when 'A102' then 1 else 0 end) as debtors_A102," + "(case debtors when 'A103' then 1 else 0 end) as debtors_A103," + "residence," + "(case property when 'A121' then 1 else 0 end) as property_A121," + "(case property when 'A122' then 1 else 0 end) as property_A122," + "(case property when 'A123' then 1 else 0 end) as property_A123," + "(case property when 'A124' then 1 else 0 end) as property_A124," + "age," + "(case other_plan when 'A141' then 1 else 0 end) as other_plan_A141," + "(case other_plan when 'A142' then 1 else 0 end) as other_plan_A142," + "(case other_plan when 'A143' then 1 else 0 end) as other_plan_A143," + "(case housing when 'A151' then 1 else 0 end) as housing_A151," + "(case housing when 'A152' then 1 else 0 end) as housing_A152," + "(case housing when 'A153' then 1 else 0 end) as housing_A153," + "number_credits," + "(case job when 'A171' then 1 else 0 end) as job_A171," + "(case job when 'A172' then 1 else 0 end) as job_A172," + "(case job when 'A173' then 1 else 0 end) as job_A173," + "(case job when 'A174' then 1 else 0 end) as job_A174," + "maintenance_num," + "(case telephone when 'A192' then 1 else 0 end) as telephone," + "(case foreign_worker when 'A201' then 1 else 0 end) as foreign_worker," + "class "; static String LABEL_COL_NAME = "class"; static String[] FEATURE_COL_NAMES = ArrayUtils.removeElements(COL_NAMES, new String[] {LABEL_COL_NAME}); static final String[] 
NUMERIC_FEATURE_COL_NAMES = new String[] { "duration", "credit_amount", "installment_rate", "residence", "age", "number_credits", "maintenance_num" }; static final String[] CATEGORY_FEATURE_COL_NAMES = ArrayUtils.removeElements(FEATURE_COL_NAMES, NUMERIC_FEATURE_COL_NAMES); static String VEC_COL_NAME = "vec"; static String PREDICTION_COL_NAME = "pred"; static String PRED_DETAIL_COL_NAME = "predinfo"; public static void main(String[] args) throws Exception { BatchOperator.setParallelism(1); c_0(); c_1(); c_2(); c_3_1(); c_3_2(); } static void c_0() throws Exception { CsvSourceBatchOp source = new CsvSourceBatchOp() .setFilePath(DATA_DIR + ORIGIN_FILE) .setSchemaStr(Utils.generateSchemaString(COL_NAMES, COL_TYPES)) .setFieldDelimiter(" "); source .lazyPrint(5, "< origin data >") .lazyPrintStatistics(); BatchOperator.execute(); Utils.splitTrainTestIfNotExist(source, DATA_DIR + TRAIN_FILE, DATA_DIR + TEST_FILE, 0.8); } static void c_1() throws Exception { BatchOperator <?> train_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_FILE).select(CLAUSE_CREATE_FEATURES); BatchOperator <?> test_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE).select(CLAUSE_CREATE_FEATURES); String[] new_features = ArrayUtils.removeElement(train_data.getColNames(), LABEL_COL_NAME); train_data.lazyPrint(5, "< new features >"); LogisticRegressionTrainBatchOp trainer = new LogisticRegressionTrainBatchOp() .setFeatureCols(new_features) .setLabelCol(LABEL_COL_NAME); LogisticRegressionPredictBatchOp predictor = new LogisticRegressionPredictBatchOp() .setPredictionCol(PREDICTION_COL_NAME) .setPredictionDetailCol(PRED_DETAIL_COL_NAME); train_data.link(trainer); predictor.linkFrom(trainer, test_data); trainer .lazyPrintTrainInfo() .lazyCollectTrainInfo(new Consumer <LinearModelTrainInfo>() { @Override public void accept(LinearModelTrainInfo linearModelTrainInfo) { printImportance( linearModelTrainInfo.getColNames(), linearModelTrainInfo.getImportance() ); } } ); predictor.link( 
new EvalBinaryClassBatchOp() .setPositiveLabelValueString("2") .setLabelCol(LABEL_COL_NAME) .setPredictionDetailCol(PRED_DETAIL_COL_NAME) .lazyPrintMetrics() ); BatchOperator.execute(); } public static void printImportance(String[] colNames, double[] importance) { ArrayList <Tuple2 <String, Double>> list = new ArrayList <>(); for (int i = 0; i < colNames.length; i++) { list.add(Tuple2.of(colNames[i], importance[i])); } Collections.sort(list, new Comparator <Tuple2 <String, Double>>() { @Override public int compare(Tuple2 <String, Double> o1, Tuple2 <String, Double> o2) { return -(o1.f1).compareTo(o2.f1); } }); StringBuilder sbd = new StringBuilder(); for (int i = 0; i < list.size(); i++) { sbd.append(i + 1).append(" \t") .append(list.get(i).f0).append(" \t") .append(list.get(i).f1).append("\n"); } System.out.print(sbd.toString()); } static void c_2() throws Exception { BatchOperator <?> train_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_FILE).select(CLAUSE_CREATE_FEATURES); BatchOperator <?> test_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE).select(CLAUSE_CREATE_FEATURES); String[] new_features = ArrayUtils.removeElement(train_data.getColNames(), LABEL_COL_NAME); train_data.lazyPrint(5, "< new features >"); LogisticRegressionTrainBatchOp trainer = new LogisticRegressionTrainBatchOp() .setFeatureCols(new_features) .setLabelCol(LABEL_COL_NAME) .setL1(0.01); LogisticRegressionPredictBatchOp predictor = new LogisticRegressionPredictBatchOp() .setPredictionCol(PREDICTION_COL_NAME) .setPredictionDetailCol(PRED_DETAIL_COL_NAME); train_data.link(trainer); predictor.linkFrom(trainer, test_data); trainer .lazyPrintTrainInfo() .lazyCollectTrainInfo( new Consumer <LinearModelTrainInfo>() { @Override public void accept(LinearModelTrainInfo linearModelTrainInfo) { printImportance( linearModelTrainInfo.getColNames(), linearModelTrainInfo.getImportance() ); } } ); predictor.link( new EvalBinaryClassBatchOp() .setPositiveLabelValueString("2") 
.setLabelCol(LABEL_COL_NAME) .setPredictionDetailCol(PRED_DETAIL_COL_NAME) .lazyPrintMetrics() ); BatchOperator.execute(); } static void c_3_1() throws Exception { BatchOperator <?> train_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_FILE); BatchOperator <?> test_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE); Pipeline pipeline = new Pipeline() .add( new OneHotEncoder() .setSelectedCols(CATEGORY_FEATURE_COL_NAMES) .setEncode(Encode.VECTOR) ) .add( new VectorAssembler() .setSelectedCols(FEATURE_COL_NAMES) .setOutputCol(VEC_COL_NAME) ) .add( new LogisticRegression() .setVectorCol(VEC_COL_NAME) .setLabelCol(LABEL_COL_NAME) .setPredictionCol(PREDICTION_COL_NAME) .setPredictionDetailCol(PRED_DETAIL_COL_NAME) ); pipeline .fit(train_data) .transform(test_data) .link( new EvalBinaryClassBatchOp() .setPositiveLabelValueString("2") .setLabelCol(LABEL_COL_NAME) .setPredictionDetailCol(PRED_DETAIL_COL_NAME) .lazyPrintMetrics() ); BatchOperator.execute(); } static void c_3_2() throws Exception { BatchOperator <?> train_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_FILE); BatchOperator <?> test_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE); Pipeline pipeline = new Pipeline() .add( new FeatureHasher() .setSelectedCols(FEATURE_COL_NAMES) .setCategoricalCols(CATEGORY_FEATURE_COL_NAMES) .setOutputCol(VEC_COL_NAME) ) .add( new LogisticRegression() .setVectorCol(VEC_COL_NAME) .setLabelCol(LABEL_COL_NAME) .setPredictionCol(PREDICTION_COL_NAME) .setPredictionDetailCol(PRED_DETAIL_COL_NAME) ); pipeline .fit(train_data) .transform(test_data) .link( new EvalBinaryClassBatchOp() .setPositiveLabelValueString("2") .setLabelCol(LABEL_COL_NAME) .setPredictionDetailCol(PRED_DETAIL_COL_NAME) .lazyPrintMetrics() ); BatchOperator.execute(); } }