/*
 * 本章包括下面各节:
 * 9.1 朴素贝叶斯模型
 * 9.2 决策树模型
 * 9.2.1 决策树的分裂指标定义
 * 9.2.2 常用的决策树算法
 * 9.2.3 指标计算示例
 * 9.2.4 分类树与回归树
 * 9.2.5 经典的决策树示例
 * 9.3 数据探索
 * 9.4 使用朴素贝叶斯方法
 * 9.5 蘑菇分类的决策树
 * 详细内容请阅读纸质书《Alink权威指南:基于Flink的机器学习实例入门(Java)》,这里为本章对应的示例代码。
 */
package com.alibaba.alink;
import org.apache.flink.types.Row;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.classification.C45TrainBatchOp;
import com.alibaba.alink.operator.batch.classification.DecisionTreePredictBatchOp;
import com.alibaba.alink.operator.batch.classification.DecisionTreeTrainBatchOp;
import com.alibaba.alink.operator.batch.classification.NaiveBayesModelInfo;
import com.alibaba.alink.operator.batch.classification.NaiveBayesPredictBatchOp;
import com.alibaba.alink.operator.batch.classification.NaiveBayesTrainBatchOp;
import com.alibaba.alink.operator.batch.evaluation.EvalBinaryClassBatchOp;
import com.alibaba.alink.operator.batch.feature.ChiSqSelectorBatchOp;
import com.alibaba.alink.operator.batch.source.AkSourceBatchOp;
import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
import com.alibaba.alink.operator.batch.source.MemSourceBatchOp;
import com.alibaba.alink.operator.common.tree.TreeModelInfo.DecisionTreeModelInfo;
import com.alibaba.alink.params.feature.BasedChisqSelectorParams.SelectorType;
import com.alibaba.alink.params.shared.tree.HasIndividualTreeType.TreeType;
import org.apache.commons.lang3.ArrayUtils;
import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map.Entry;
import java.util.function.Consumer;
public class Chap09 {
static final String DATA_DIR = Utils.ROOT_DIR + "mushroom" + File.separator;
static final String TEST_FILE = "test.ak";
static final String TRAIN_FILE = "train.ak";
static final String ORIGIN_FILE = "agaricus-lepiota.data";
static final String[] COL_NAMES = new String[] {
"class",
"cap_shape", "cap_surface", "cap_color", "bruises", "odor",
"gill_attachment", "gill_spacing", "gill_size", "gill_color",
"stalk_shape", "stalk_root", "stalk_surface_above_ring", "stalk_surface_below_ring",
"stalk_color_above_ring", "stalk_color_below_ring",
"veil_type", "veil_color",
"ring_number", "ring_type", "spore_print_color", "population", "habitat"
};
static final String[] COL_TYPES = new String[] {
"string",
"string", "string", "string", "string", "string",
"string", "string", "string", "string", "string",
"string", "string", "string", "string", "string",
"string", "string", "string", "string", "string",
"string", "string"
};
static final String LABEL_COL_NAME = "class";
static final String[] FEATURE_COL_NAMES = ArrayUtils.removeElement(COL_NAMES, LABEL_COL_NAME);
static final String PREDICTION_COL_NAME = "pred";
static final String PRED_DETAIL_COL_NAME = "predInfo";
public static void main(String[] args) throws Exception {
BatchOperator.setParallelism(1);
c_2_5();
c_3();
c_4_a();
c_4_b();
c_5();
}
static void c_2_5() throws Exception {
MemSourceBatchOp source = new MemSourceBatchOp(
new Row[] {
Row.of("sunny", 85.0, 85.0, false, "no"),
Row.of("sunny", 80.0, 90.0, true, "no"),
Row.of("overcast", 83.0, 78.0, false, "yes"),
Row.of("rainy", 70.0, 96.0, false, "yes"),
Row.of("rainy", 68.0, 80.0, false, "yes"),
Row.of("rainy", 65.0, 70.0, true, "no"),
Row.of("overcast", 64.0, 65.0, true, "yes"),
Row.of("sunny", 72.0, 95.0, false, "no"),
Row.of("sunny", 69.0, 70.0, false, "yes"),
Row.of("rainy", 75.0, 80.0, false, "yes"),
Row.of("sunny", 75.0, 70.0, true, "yes"),
Row.of("overcast", 72.0, 90.0, true, "yes"),
Row.of("overcast", 81.0, 75.0, false, "yes"),
Row.of("rainy", 71.0, 80.0, true, "no")
},
new String[] {"Outlook", "Temperature", "Humidity", "Windy", "Play"}
);
source.lazyPrint(-1);
source
.link(
new C45TrainBatchOp()
.setFeatureCols("Outlook", "Temperature", "Humidity", "Windy")
.setCategoricalCols("Outlook", "Windy")
.setLabelCol("Play")
.lazyPrintModelInfo()
.lazyCollectModelInfo(new Consumer <DecisionTreeModelInfo>() {
@Override
public void accept(DecisionTreeModelInfo decisionTreeModelInfo) {
try {
decisionTreeModelInfo.saveTreeAsImage(
DATA_DIR + "weather_tree_model.png", true);
} catch (IOException e) {
e.printStackTrace();
}
}
})
);
BatchOperator.execute();
}
static void c_3() throws Exception {
CsvSourceBatchOp source = new CsvSourceBatchOp()
.setFilePath(DATA_DIR + ORIGIN_FILE)
.setSchemaStr(Utils.generateSchemaString(COL_NAMES, COL_TYPES));
source.lazyPrint(5, "< origin data >");
Utils.splitTrainTestIfNotExist(source, DATA_DIR + TRAIN_FILE, DATA_DIR + TEST_FILE, 0.9);
new AkSourceBatchOp()
.setFilePath(DATA_DIR + TRAIN_FILE)
.link(
new ChiSqSelectorBatchOp()
.setSelectorType(SelectorType.NumTopFeatures)
.setNumTopFeatures(3)
.setSelectedCols(FEATURE_COL_NAMES)
.setLabelCol(LABEL_COL_NAME)
.lazyPrintModelInfo("< Chi-Square Selector >")
);
new AkSourceBatchOp()
.setFilePath(DATA_DIR + TRAIN_FILE)
.select("veil_type")
.distinct()
.lazyPrint(100);
BatchOperator.execute();
}
static void c_4_a() throws Exception {
AkSourceBatchOp train_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_FILE);
AkSourceBatchOp test_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE);
NaiveBayesTrainBatchOp trainer =
new NaiveBayesTrainBatchOp()
.setFeatureCols(FEATURE_COL_NAMES)
.setCategoricalCols(FEATURE_COL_NAMES)
.setLabelCol(LABEL_COL_NAME);
NaiveBayesPredictBatchOp predictor = new NaiveBayesPredictBatchOp()
.setPredictionCol(PREDICTION_COL_NAME)
.setPredictionDetailCol(PRED_DETAIL_COL_NAME);
train_data.link(trainer);
predictor.linkFrom(trainer, test_data);
trainer.lazyPrintModelInfo();
trainer.lazyCollectModelInfo(new Consumer <NaiveBayesModelInfo>() {
@Override
public void accept(NaiveBayesModelInfo naiveBayesModelInfo) {
StringBuilder sbd = new StringBuilder();
for (String feature : new String[] {"odor", "spore_print_color", "gill_color"}) {
HashMap <Object, HashMap <Object, Double>> map2 =
naiveBayesModelInfo.getCategoryFeatureInfo().get(feature);
sbd.append("\nfeature:").append(feature);
for (Entry <Object, HashMap <Object, Double>> entry : map2.entrySet()) {
sbd.append("\n").append(entry.getKey()).append(" : ")
.append(entry.getValue().toString());
}
}
System.out.println(sbd.toString());
}
});
predictor.lazyPrint(10, "< Prediction >");
predictor
.link(
new EvalBinaryClassBatchOp()
.setPositiveLabelValueString("p")
.setLabelCol(LABEL_COL_NAME)
.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
.lazyPrintMetrics()
);
BatchOperator.execute();
}
static void c_4_b() throws Exception {
AkSourceBatchOp train_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_FILE);
AkSourceBatchOp test_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE);
NaiveBayesTrainBatchOp trainer =
new NaiveBayesTrainBatchOp()
.setFeatureCols("odor", "gill_color")
.setCategoricalCols("odor", "gill_color")
.setLabelCol(LABEL_COL_NAME);
NaiveBayesPredictBatchOp predictor = new NaiveBayesPredictBatchOp()
.setPredictionCol(PREDICTION_COL_NAME)
.setPredictionDetailCol(PRED_DETAIL_COL_NAME);
train_data.link(trainer);
predictor.linkFrom(trainer, test_data);
trainer.lazyCollectModelInfo(new Consumer <NaiveBayesModelInfo>() {
@Override
public void accept(NaiveBayesModelInfo naiveBayesModelInfo) {
StringBuilder sbd = new StringBuilder();
for (String feature : new String[] {"odor", "gill_color"}) {
HashMap <Object, HashMap <Object, Double>> map2 =
naiveBayesModelInfo.getCategoryFeatureInfo().get(feature);
sbd.append("\nfeature:").append(feature);
for (Entry <Object, HashMap <Object, Double>> entry : map2.entrySet()) {
sbd.append("\n").append(entry.getKey()).append(" : ")
.append(entry.getValue().toString());
}
}
System.out.println(sbd.toString());
}
});
predictor
.lazyPrint(10, "< Prediction >")
.link(
new EvalBinaryClassBatchOp()
.setPositiveLabelValueString("p")
.setLabelCol(LABEL_COL_NAME)
.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
.lazyPrintMetrics()
);
BatchOperator.execute();
}
static void c_5() throws Exception {
BatchOperator train_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_FILE);
BatchOperator test_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE);
for (TreeType treeType : new TreeType[] {TreeType.GINI, TreeType.INFOGAIN, TreeType.INFOGAINRATIO}) {
BatchOperator <?> model = train_data
.link(
new DecisionTreeTrainBatchOp()
.setTreeType(treeType)
.setFeatureCols(FEATURE_COL_NAMES)
.setCategoricalCols(FEATURE_COL_NAMES)
.setLabelCol(LABEL_COL_NAME)
.lazyPrintModelInfo("< " + treeType.toString() + " >")
.lazyCollectModelInfo(new Consumer <DecisionTreeModelInfo>() {
@Override
public void accept(DecisionTreeModelInfo decisionTreeModelInfo) {
try {
decisionTreeModelInfo.saveTreeAsImage(
DATA_DIR + "tree_" + treeType.toString() + ".jpg", true);
} catch (IOException e) {
e.printStackTrace();
}
}
})
);
DecisionTreePredictBatchOp predictor = new DecisionTreePredictBatchOp()
.setPredictionCol(PREDICTION_COL_NAME)
.setPredictionDetailCol(PRED_DETAIL_COL_NAME);
predictor.linkFrom(model, test_data);
predictor.link(
new EvalBinaryClassBatchOp()
.setPositiveLabelValueString("p")
.setLabelCol(LABEL_COL_NAME)
.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
.lazyPrintMetrics("< " + treeType.toString() + " >")
);
}
BatchOperator.execute();
}
}