/*
本章包括下面各节:
10.1 整体流程
10.1.1 特征哑元化
10.1.2 特征的重要性
10.2 减少模型特征的个数
10.3 离散特征转化
10.3.1 独热编码
10.3.2 特征哈希
详细内容请阅读纸质书《Alink权威指南:基于Flink的机器学习实例入门(Java)》,这里为本章对应的示例代码。
*/
package com.alibaba.alink;
import org.apache.flink.api.java.tuple.Tuple2;
import com.alibaba.alink.common.utils.Stopwatch;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.classification.LogisticRegressionPredictBatchOp;
import com.alibaba.alink.operator.batch.classification.LogisticRegressionTrainBatchOp;
import com.alibaba.alink.operator.batch.evaluation.EvalBinaryClassBatchOp;
import com.alibaba.alink.operator.batch.source.AkSourceBatchOp;
import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
import com.alibaba.alink.operator.common.linear.LinearModelTrainInfo;
import com.alibaba.alink.params.feature.HasEncodeWithoutWoe.Encode;
import com.alibaba.alink.pipeline.Pipeline;
import com.alibaba.alink.pipeline.classification.LogisticRegression;
import com.alibaba.alink.pipeline.dataproc.vector.VectorAssembler;
import com.alibaba.alink.pipeline.feature.FeatureHasher;
import com.alibaba.alink.pipeline.feature.OneHotEncoder;
import org.apache.commons.lang3.ArrayUtils;
import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.function.Consumer;
/**
 * Chapter 10 examples: feature engineering on the German Credit data set.
 *
 * <p>Covers: manual feature dummying via a SQL select clause plus feature-importance
 * inspection (10.1), reducing the feature count with L1 regularization (10.2), and
 * categorical-feature transformation via one-hot encoding and feature hashing (10.3).
 *
 * <p>All sections train a logistic regression binary classifier and evaluate it on a
 * held-out test split (positive label value "2").
 */
public class Chap10 {

	// Working directory for this chapter's data files; Utils.ROOT_DIR is project-defined.
	static final String DATA_DIR = Utils.ROOT_DIR + "german_credit" + File.separator;

	static final String ORIGIN_FILE = "german.data";
	static final String TRAIN_FILE = "train.ak";
	static final String TEST_FILE = "test.ak";

	// Column names of the raw data set; the final column "class" is the label.
	private static final String[] COL_NAMES = new String[] {
		"status", "duration", "credit_history", "purpose", "credit_amount",
		"savings", "employment", "installment_rate", "marriage_sex", "debtors",
		"residence", "property", "age", "other_plan", "housing",
		"number_credits", "job", "maintenance_num", "telephone", "foreign_worker",
		"class"
	};

	// Column types, positionally aligned with COL_NAMES.
	private static final String[] COL_TYPES = new String[] {
		"string", "int", "string", "string", "int",
		"string", "string", "int", "string", "string",
		"int", "string", "int", "string", "string",
		"int", "string", "int", "string", "string",
		"int"
	};

	/**
	 * SQL select clause that manually dummy-encodes each categorical column
	 * (one 0/1 indicator per category code, e.g. status -> status_A11..status_A14)
	 * while passing numeric columns and the label through unchanged.
	 * "telephone" and "foreign_worker" are collapsed to a single 0/1 indicator.
	 */
	static final String CLAUSE_CREATE_FEATURES
		= "(case status when 'A11' then 1 else 0 end) as status_A11,"
		+ "(case status when 'A12' then 1 else 0 end) as status_A12,"
		+ "(case status when 'A13' then 1 else 0 end) as status_A13,"
		+ "(case status when 'A14' then 1 else 0 end) as status_A14,"
		+ "duration,"
		+ "(case credit_history when 'A30' then 1 else 0 end) as credit_history_A30,"
		+ "(case credit_history when 'A31' then 1 else 0 end) as credit_history_A31,"
		+ "(case credit_history when 'A32' then 1 else 0 end) as credit_history_A32,"
		+ "(case credit_history when 'A33' then 1 else 0 end) as credit_history_A33,"
		+ "(case credit_history when 'A34' then 1 else 0 end) as credit_history_A34,"
		+ "(case purpose when 'A40' then 1 else 0 end) as purpose_A40,"
		+ "(case purpose when 'A41' then 1 else 0 end) as purpose_A41,"
		+ "(case purpose when 'A42' then 1 else 0 end) as purpose_A42,"
		+ "(case purpose when 'A43' then 1 else 0 end) as purpose_A43,"
		+ "(case purpose when 'A44' then 1 else 0 end) as purpose_A44,"
		+ "(case purpose when 'A45' then 1 else 0 end) as purpose_A45,"
		+ "(case purpose when 'A46' then 1 else 0 end) as purpose_A46,"
		+ "(case purpose when 'A47' then 1 else 0 end) as purpose_A47,"
		+ "(case purpose when 'A48' then 1 else 0 end) as purpose_A48,"
		+ "(case purpose when 'A49' then 1 else 0 end) as purpose_A49,"
		+ "(case purpose when 'A410' then 1 else 0 end) as purpose_A410,"
		+ "credit_amount,"
		+ "(case savings when 'A61' then 1 else 0 end) as savings_A61,"
		+ "(case savings when 'A62' then 1 else 0 end) as savings_A62,"
		+ "(case savings when 'A63' then 1 else 0 end) as savings_A63,"
		+ "(case savings when 'A64' then 1 else 0 end) as savings_A64,"
		+ "(case savings when 'A65' then 1 else 0 end) as savings_A65,"
		+ "(case employment when 'A71' then 1 else 0 end) as employment_A71,"
		+ "(case employment when 'A72' then 1 else 0 end) as employment_A72,"
		+ "(case employment when 'A73' then 1 else 0 end) as employment_A73,"
		+ "(case employment when 'A74' then 1 else 0 end) as employment_A74,"
		+ "(case employment when 'A75' then 1 else 0 end) as employment_A75,"
		+ "installment_rate,"
		+ "(case marriage_sex when 'A91' then 1 else 0 end) as marriage_sex_A91,"
		+ "(case marriage_sex when 'A92' then 1 else 0 end) as marriage_sex_A92,"
		+ "(case marriage_sex when 'A93' then 1 else 0 end) as marriage_sex_A93,"
		+ "(case marriage_sex when 'A94' then 1 else 0 end) as marriage_sex_A94,"
		+ "(case marriage_sex when 'A95' then 1 else 0 end) as marriage_sex_A95,"
		+ "(case debtors when 'A101' then 1 else 0 end) as debtors_A101,"
		+ "(case debtors when 'A102' then 1 else 0 end) as debtors_A102,"
		+ "(case debtors when 'A103' then 1 else 0 end) as debtors_A103,"
		+ "residence,"
		+ "(case property when 'A121' then 1 else 0 end) as property_A121,"
		+ "(case property when 'A122' then 1 else 0 end) as property_A122,"
		+ "(case property when 'A123' then 1 else 0 end) as property_A123,"
		+ "(case property when 'A124' then 1 else 0 end) as property_A124,"
		+ "age,"
		+ "(case other_plan when 'A141' then 1 else 0 end) as other_plan_A141,"
		+ "(case other_plan when 'A142' then 1 else 0 end) as other_plan_A142,"
		+ "(case other_plan when 'A143' then 1 else 0 end) as other_plan_A143,"
		+ "(case housing when 'A151' then 1 else 0 end) as housing_A151,"
		+ "(case housing when 'A152' then 1 else 0 end) as housing_A152,"
		+ "(case housing when 'A153' then 1 else 0 end) as housing_A153,"
		+ "number_credits,"
		+ "(case job when 'A171' then 1 else 0 end) as job_A171,"
		+ "(case job when 'A172' then 1 else 0 end) as job_A172,"
		+ "(case job when 'A173' then 1 else 0 end) as job_A173,"
		+ "(case job when 'A174' then 1 else 0 end) as job_A174,"
		+ "maintenance_num,"
		+ "(case telephone when 'A192' then 1 else 0 end) as telephone,"
		+ "(case foreign_worker when 'A201' then 1 else 0 end) as foreign_worker,"
		+ "class ";

	static String LABEL_COL_NAME = "class";

	// All columns except the label.
	static String[] FEATURE_COL_NAMES = ArrayUtils.removeElements(COL_NAMES, new String[] {LABEL_COL_NAME});

	static final String[] NUMERIC_FEATURE_COL_NAMES = new String[] {
		"duration", "credit_amount", "installment_rate", "residence", "age", "number_credits", "maintenance_num"
	};

	// Categorical features = all features minus the numeric ones.
	static final String[] CATEGORY_FEATURE_COL_NAMES =
		ArrayUtils.removeElements(FEATURE_COL_NAMES, NUMERIC_FEATURE_COL_NAMES);

	static String VEC_COL_NAME = "vec";
	static String PREDICTION_COL_NAME = "pred";
	static String PRED_DETAIL_COL_NAME = "predinfo";

	/** Runs all chapter sections in order on a single-parallelism local runtime. */
	public static void main(String[] args) throws Exception {
		BatchOperator.setParallelism(1);
		c_0();
		c_1();
		c_2();
		c_3_1();
		c_3_2();
	}

	/**
	 * Section 10.1 setup: read the raw CSV (space-delimited), print a sample and
	 * summary statistics, then create the 80/20 train/test split if absent.
	 */
	static void c_0() throws Exception {
		CsvSourceBatchOp source = new CsvSourceBatchOp()
			.setFilePath(DATA_DIR + ORIGIN_FILE)
			.setSchemaStr(Utils.generateSchemaString(COL_NAMES, COL_TYPES))
			.setFieldDelimiter(" ");
		source
			.lazyPrint(5, "< origin data >")
			.lazyPrintStatistics();
		BatchOperator.execute();
		Utils.splitTrainTestIfNotExist(source, DATA_DIR + TRAIN_FILE, DATA_DIR + TEST_FILE, 0.8);
	}

	/**
	 * Section 10.1: logistic regression on the manually dummy-encoded features,
	 * with feature-importance ranking printed from the train info.
	 */
	static void c_1() throws Exception {
		trainPredictEval(null);
	}

	/**
	 * Section 10.2: same workflow as {@link #c_1()} but with an L1 penalty of 0.01,
	 * which drives unimportant feature weights toward zero.
	 */
	static void c_2() throws Exception {
		trainPredictEval(0.01);
	}

	/**
	 * Shared workflow for sections 10.1 / 10.2: load train/test with the dummy-encoding
	 * clause, train logistic regression (optionally L1-regularized), print train info
	 * and feature importance, then evaluate predictions on the test set.
	 *
	 * @param l1Penalty L1 regularization coefficient, or {@code null} for no L1 term
	 */
	private static void trainPredictEval(Double l1Penalty) throws Exception {
		BatchOperator <?> train_data =
			new AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_FILE).select(CLAUSE_CREATE_FEATURES);
		BatchOperator <?> test_data =
			new AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE).select(CLAUSE_CREATE_FEATURES);

		// Feature columns after dummy encoding: everything except the label.
		String[] new_features = ArrayUtils.removeElement(train_data.getColNames(), LABEL_COL_NAME);

		train_data.lazyPrint(5, "< new features >");

		LogisticRegressionTrainBatchOp trainer = new LogisticRegressionTrainBatchOp()
			.setFeatureCols(new_features)
			.setLabelCol(LABEL_COL_NAME);
		if (l1Penalty != null) {
			trainer.setL1(l1Penalty);
		}

		LogisticRegressionPredictBatchOp predictor = new LogisticRegressionPredictBatchOp()
			.setPredictionCol(PREDICTION_COL_NAME)
			.setPredictionDetailCol(PRED_DETAIL_COL_NAME);

		train_data.link(trainer);
		predictor.linkFrom(trainer, test_data);

		trainer
			.lazyPrintTrainInfo()
			.lazyCollectTrainInfo((Consumer <LinearModelTrainInfo>) trainInfo ->
				printImportance(trainInfo.getColNames(), trainInfo.getImportance())
			);

		predictor.link(
			new EvalBinaryClassBatchOp()
				.setPositiveLabelValueString("2")
				.setLabelCol(LABEL_COL_NAME)
				.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
				.lazyPrintMetrics()
		);
		BatchOperator.execute();
	}

	/**
	 * Prints features ranked by importance, one per line:
	 * {@code rank \t name \t importance}, in descending importance order.
	 *
	 * @param colNames   feature column names
	 * @param importance importance scores, positionally aligned with {@code colNames}
	 */
	public static void printImportance(String[] colNames, double[] importance) {
		ArrayList <Tuple2 <String, Double>> ranking = new ArrayList <>();
		for (int i = 0; i < colNames.length; i++) {
			ranking.add(Tuple2.of(colNames[i], importance[i]));
		}
		// Descending by importance; stable sort keeps input order for ties.
		ranking.sort(Comparator.comparing((Tuple2 <String, Double> t) -> t.f1, Comparator.reverseOrder()));

		StringBuilder sbd = new StringBuilder();
		for (int i = 0; i < ranking.size(); i++) {
			sbd.append(i + 1).append(" \t")
				.append(ranking.get(i).f0).append(" \t")
				.append(ranking.get(i).f1).append("\n");
		}
		System.out.print(sbd.toString());
	}

	/**
	 * Section 10.3.1: one-hot encode the categorical columns into vectors, assemble
	 * all features into one vector column, then train and evaluate logistic regression.
	 */
	static void c_3_1() throws Exception {
		BatchOperator <?> train_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_FILE);
		BatchOperator <?> test_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE);

		Pipeline pipeline = new Pipeline()
			.add(
				new OneHotEncoder()
					.setSelectedCols(CATEGORY_FEATURE_COL_NAMES)
					.setEncode(Encode.VECTOR)
			)
			.add(
				new VectorAssembler()
					.setSelectedCols(FEATURE_COL_NAMES)
					.setOutputCol(VEC_COL_NAME)
			)
			.add(
				new LogisticRegression()
					.setVectorCol(VEC_COL_NAME)
					.setLabelCol(LABEL_COL_NAME)
					.setPredictionCol(PREDICTION_COL_NAME)
					.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
			);

		fitTransformEval(pipeline, train_data, test_data);
	}

	/**
	 * Section 10.3.2: hash all features (categoricals declared explicitly) into a
	 * single vector column, then train and evaluate logistic regression.
	 */
	static void c_3_2() throws Exception {
		BatchOperator <?> train_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_FILE);
		BatchOperator <?> test_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE);

		Pipeline pipeline = new Pipeline()
			.add(
				new FeatureHasher()
					.setSelectedCols(FEATURE_COL_NAMES)
					.setCategoricalCols(CATEGORY_FEATURE_COL_NAMES)
					.setOutputCol(VEC_COL_NAME)
			)
			.add(
				new LogisticRegression()
					.setVectorCol(VEC_COL_NAME)
					.setLabelCol(LABEL_COL_NAME)
					.setPredictionCol(PREDICTION_COL_NAME)
					.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
			);

		fitTransformEval(pipeline, train_data, test_data);
	}

	/**
	 * Shared tail for sections 10.3.x: fit the pipeline on the train set, transform the
	 * test set, and print binary-classification metrics (positive label "2").
	 */
	private static void fitTransformEval(
		Pipeline pipeline, BatchOperator <?> train_data, BatchOperator <?> test_data) throws Exception {
		pipeline
			.fit(train_data)
			.transform(test_data)
			.link(
				new EvalBinaryClassBatchOp()
					.setPositiveLabelValueString("2")
					.setLabelCol(LABEL_COL_NAME)
					.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
					.lazyPrintMetrics()
			);
		BatchOperator.execute();
	}
}