Alink教程(Java版)

第9章 朴素贝叶斯模型与决策树模型

本章包括下面各节:
9.1 朴素贝叶斯模型
9.2 决策树模型
9.2.1 决策树的分裂指标定义
9.2.2 常用的决策树算法
9.2.3 指标计算示例
9.2.4 分类树与回归树
9.2.5 经典的决策树示例
9.3 数据探索
9.4 使用朴素贝叶斯方法
9.5 蘑菇分类的决策树

详细内容请阅读纸质书《Alink权威指南:基于Flink的机器学习实例入门(Java)》,这里为本章对应的示例代码。

package com.alibaba.alink;

import org.apache.flink.types.Row;

import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.classification.C45TrainBatchOp;
import com.alibaba.alink.operator.batch.classification.DecisionTreePredictBatchOp;
import com.alibaba.alink.operator.batch.classification.DecisionTreeTrainBatchOp;
import com.alibaba.alink.operator.batch.classification.NaiveBayesModelInfo;
import com.alibaba.alink.operator.batch.classification.NaiveBayesPredictBatchOp;
import com.alibaba.alink.operator.batch.classification.NaiveBayesTrainBatchOp;
import com.alibaba.alink.operator.batch.evaluation.EvalBinaryClassBatchOp;
import com.alibaba.alink.operator.batch.feature.ChiSqSelectorBatchOp;
import com.alibaba.alink.operator.batch.source.AkSourceBatchOp;
import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
import com.alibaba.alink.operator.batch.source.MemSourceBatchOp;
import com.alibaba.alink.operator.common.tree.TreeModelInfo.DecisionTreeModelInfo;
import com.alibaba.alink.params.feature.BasedChisqSelectorParams.SelectorType;
import com.alibaba.alink.params.shared.tree.HasIndividualTreeType.TreeType;
import org.apache.commons.lang3.ArrayUtils;

import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map.Entry;
import java.util.function.Consumer;

/**
 * Example code for Chapter 9 (naive Bayes and decision tree models) of the Alink tutorial.
 *
 * <p>Each {@code c_*} method corresponds to one section of the chapter and builds its own
 * batch job. Every method ends with {@link BatchOperator#execute()}, which is what triggers
 * the {@code lazyPrint*} / {@code lazyCollect*} outputs queued before it.
 */
public class Chap09 {

	/** Directory of the mushroom data set; all files read or written by this chapter live here. */
	static final String DATA_DIR = Utils.ROOT_DIR + "mushroom" + File.separator;

	/** File name of the test split (Alink {@code .ak} format). */
	static final String TEST_FILE = "test.ak";
	/** File name of the training split (Alink {@code .ak} format). */
	static final String TRAIN_FILE = "train.ak";
	/** File name of the original CSV data. */
	static final String ORIGIN_FILE = "agaricus-lepiota.data";

	/** Column names of the original data, label column first. */
	static final String[] COL_NAMES = new String[] {
		"class",
		"cap_shape", "cap_surface", "cap_color", "bruises", "odor",
		"gill_attachment", "gill_spacing", "gill_size", "gill_color",
		"stalk_shape", "stalk_root", "stalk_surface_above_ring", "stalk_surface_below_ring",
		"stalk_color_above_ring", "stalk_color_below_ring",
		"veil_type", "veil_color",
		"ring_number", "ring_type", "spore_print_color", "population", "habitat"
	};

	/** Column types, position-aligned with {@link #COL_NAMES}; every column is a string. */
	static final String[] COL_TYPES = new String[] {
		"string",
		"string", "string", "string", "string", "string",
		"string", "string", "string", "string", "string",
		"string", "string", "string", "string", "string",
		"string", "string", "string", "string", "string",
		"string", "string"
	};

	/** Name of the label column. */
	static final String LABEL_COL_NAME = "class";
	/** All columns of {@link #COL_NAMES} except the label column. */
	static final String[] FEATURE_COL_NAMES = ArrayUtils.removeElement(COL_NAMES, LABEL_COL_NAME);

	/** Output column receiving the predicted label. */
	static final String PREDICTION_COL_NAME = "pred";
	/** Output column receiving the prediction detail. */
	static final String PRED_DETAIL_COL_NAME = "predInfo";

	/**
	 * Runs all examples of the chapter in order, with the batch parallelism set to 1.
	 *
	 * @param args unused
	 * @throws Exception if any of the example jobs fails
	 */
	public static void main(String[] args) throws Exception {

		BatchOperator.setParallelism(1);

		c_2_5();

		c_3();

		c_4_a();

		c_4_b();

		c_5();

	}

	/**
	 * Section 9.2.5: trains a C4.5 tree on the in-memory 14-row weather table, prints the
	 * data and the model info, and saves the tree as {@code weather_tree_model.png} in
	 * {@link #DATA_DIR}.
	 *
	 * @throws Exception if the job fails
	 */
	static void c_2_5() throws Exception {
		MemSourceBatchOp source = new MemSourceBatchOp(
			new Row[] {
				Row.of("sunny", 85.0, 85.0, false, "no"),
				Row.of("sunny", 80.0, 90.0, true, "no"),
				Row.of("overcast", 83.0, 78.0, false, "yes"),
				Row.of("rainy", 70.0, 96.0, false, "yes"),
				Row.of("rainy", 68.0, 80.0, false, "yes"),
				Row.of("rainy", 65.0, 70.0, true, "no"),
				Row.of("overcast", 64.0, 65.0, true, "yes"),
				Row.of("sunny", 72.0, 95.0, false, "no"),
				Row.of("sunny", 69.0, 70.0, false, "yes"),
				Row.of("rainy", 75.0, 80.0, false, "yes"),
				Row.of("sunny", 75.0, 70.0, true, "yes"),
				Row.of("overcast", 72.0, 90.0, true, "yes"),
				Row.of("overcast", 81.0, 75.0, false, "yes"),
				Row.of("rainy", 71.0, 80.0, true, "no")
			},
			new String[] {"Outlook", "Temperature", "Humidity", "Windy", "Play"}
		);

		source.lazyPrint(-1);

		source
			.link(
				new C45TrainBatchOp()
					.setFeatureCols("Outlook", "Temperature", "Humidity", "Windy")
					.setCategoricalCols("Outlook", "Windy")
					.setLabelCol("Play")
					.lazyPrintModelInfo()
					.lazyCollectModelInfo(treeImageSaver(DATA_DIR + "weather_tree_model.png"))
			);

		BatchOperator.execute();
	}

	/**
	 * Section 9.3: data exploration. Reads the original CSV, prints 5 rows, hands the data to
	 * {@code Utils.splitTrainTestIfNotExist} with ratio 0.9 (presumably creating the train/test
	 * files only when missing, as the helper's name suggests), prints the model info of a
	 * chi-square selector configured for the top 3 features, and prints the distinct values
	 * of {@code veil_type} in the training data.
	 *
	 * @throws Exception if the job fails
	 */
	static void c_3() throws Exception {

		CsvSourceBatchOp source = new CsvSourceBatchOp()
			.setFilePath(DATA_DIR + ORIGIN_FILE)
			.setSchemaStr(Utils.generateSchemaString(COL_NAMES, COL_TYPES));

		source.lazyPrint(5, "< origin data >");

		Utils.splitTrainTestIfNotExist(source, DATA_DIR + TRAIN_FILE, DATA_DIR + TEST_FILE, 0.9);

		new AkSourceBatchOp()
			.setFilePath(DATA_DIR + TRAIN_FILE)
			.link(
				new ChiSqSelectorBatchOp()
					.setSelectorType(SelectorType.NumTopFeatures)
					.setNumTopFeatures(3)
					.setSelectedCols(FEATURE_COL_NAMES)
					.setLabelCol(LABEL_COL_NAME)
					.lazyPrintModelInfo("< Chi-Square Selector >")
			);

		new AkSourceBatchOp()
			.setFilePath(DATA_DIR + TRAIN_FILE)
			.select("veil_type")
			.distinct()
			.lazyPrint(100);

		BatchOperator.execute();

	}

	/**
	 * Section 9.4 (a): naive Bayes on all feature columns. Prints the model info, the
	 * per-value statistics of three selected features, 10 prediction rows and the binary
	 * classification metrics on the test split.
	 *
	 * @throws Exception if the job fails
	 */
	static void c_4_a() throws Exception {

		AkSourceBatchOp train_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_FILE);
		AkSourceBatchOp test_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE);

		NaiveBayesTrainBatchOp trainer =
			new NaiveBayesTrainBatchOp()
				.setFeatureCols(FEATURE_COL_NAMES)
				.setCategoricalCols(FEATURE_COL_NAMES)
				.setLabelCol(LABEL_COL_NAME);

		NaiveBayesPredictBatchOp predictor = new NaiveBayesPredictBatchOp()
			.setPredictionCol(PREDICTION_COL_NAME)
			.setPredictionDetailCol(PRED_DETAIL_COL_NAME);

		train_data.link(trainer);
		predictor.linkFrom(trainer, test_data);

		trainer.lazyPrintModelInfo();

		trainer.lazyCollectModelInfo(categoryFeatureInfoPrinter("odor", "spore_print_color", "gill_color"));

		predictor.lazyPrint(10, "< Prediction >");

		predictor.link(newBinaryEvaluator().lazyPrintMetrics());

		BatchOperator.execute();

	}

	/**
	 * Section 9.4 (b): naive Bayes restricted to the two features {@code odor} and
	 * {@code gill_color}. Prints their per-value statistics, 10 prediction rows and the
	 * binary classification metrics on the test split.
	 *
	 * @throws Exception if the job fails
	 */
	static void c_4_b() throws Exception {

		AkSourceBatchOp train_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_FILE);
		AkSourceBatchOp test_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE);

		NaiveBayesTrainBatchOp trainer =
			new NaiveBayesTrainBatchOp()
				.setFeatureCols("odor", "gill_color")
				.setCategoricalCols("odor", "gill_color")
				.setLabelCol(LABEL_COL_NAME);

		NaiveBayesPredictBatchOp predictor = new NaiveBayesPredictBatchOp()
			.setPredictionCol(PREDICTION_COL_NAME)
			.setPredictionDetailCol(PRED_DETAIL_COL_NAME);

		train_data.link(trainer);
		predictor.linkFrom(trainer, test_data);

		trainer.lazyCollectModelInfo(categoryFeatureInfoPrinter("odor", "gill_color"));

		predictor
			.lazyPrint(10, "< Prediction >")
			.link(newBinaryEvaluator().lazyPrintMetrics());

		BatchOperator.execute();
	}

	/**
	 * Section 9.5: decision trees for mushroom classification. For each of the GINI,
	 * INFOGAIN and INFOGAINRATIO tree types, trains a tree on all feature columns, prints the
	 * model info, saves the tree as {@code tree_<type>.jpg} in {@link #DATA_DIR}, and prints
	 * the binary classification metrics on the test split. All three pipelines are built
	 * first and run by a single {@link BatchOperator#execute()} call.
	 *
	 * @throws Exception if the job fails
	 */
	static void c_5() throws Exception {

		// Typed like the sources in c_4_a/c_4_b; the raw BatchOperator type made link() unchecked.
		AkSourceBatchOp train_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_FILE);
		AkSourceBatchOp test_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE);

		for (TreeType treeType : new TreeType[] {TreeType.GINI, TreeType.INFOGAIN, TreeType.INFOGAINRATIO}) {
			// The same title labels both the model info and the metrics of this tree type.
			String title = "< " + treeType.toString() + " >";

			BatchOperator <?> model = train_data
				.link(
					new DecisionTreeTrainBatchOp()
						.setTreeType(treeType)
						.setFeatureCols(FEATURE_COL_NAMES)
						.setCategoricalCols(FEATURE_COL_NAMES)
						.setLabelCol(LABEL_COL_NAME)
						.lazyPrintModelInfo(title)
						.lazyCollectModelInfo(
							treeImageSaver(DATA_DIR + "tree_" + treeType.toString() + ".jpg"))
				);

			DecisionTreePredictBatchOp predictor = new DecisionTreePredictBatchOp()
				.setPredictionCol(PREDICTION_COL_NAME)
				.setPredictionDetailCol(PRED_DETAIL_COL_NAME);

			predictor.linkFrom(model, test_data);

			predictor.link(newBinaryEvaluator().lazyPrintMetrics(title));
		}

		BatchOperator.execute();

	}

	/**
	 * Creates a callback that saves a decision tree model as an image file.
	 *
	 * <p>Saving is best-effort: an {@link IOException} is reported on standard error together
	 * with the target path and does not abort the job.
	 *
	 * @param imagePath path of the image file to write
	 * @return the callback to pass to {@code lazyCollectModelInfo}
	 */
	private static Consumer <DecisionTreeModelInfo> treeImageSaver(final String imagePath) {
		return modelInfo -> {
			try {
				// Second argument kept from the original calls; presumably "overwrite existing file".
				modelInfo.saveTreeAsImage(imagePath, true);
			} catch (IOException e) {
				System.err.println("Failed to save decision tree image to " + imagePath);
				e.printStackTrace();
			}
		};
	}

	/**
	 * Creates a callback that prints, for each given feature, every entry of the model's
	 * category feature info for that feature, one entry per line.
	 *
	 * @param features names of the features to print
	 * @return the callback to pass to {@code lazyCollectModelInfo}
	 */
	private static Consumer <NaiveBayesModelInfo> categoryFeatureInfoPrinter(final String... features) {
		return modelInfo -> {
			StringBuilder sbd = new StringBuilder();
			for (String feature : features) {
				HashMap <Object, HashMap <Object, Double>> valueInfo =
					modelInfo.getCategoryFeatureInfo().get(feature);
				sbd.append("\nfeature:").append(feature);
				if (valueInfo == null) {
					// The model has no entry for this feature; print only the header
					// instead of failing with a NullPointerException.
					continue;
				}
				for (Entry <Object, HashMap <Object, Double>> entry : valueInfo.entrySet()) {
					sbd.append("\n").append(entry.getKey()).append(" : ")
						.append(entry.getValue().toString());
				}
			}
			System.out.println(sbd.toString());
		};
	}

	/**
	 * Creates the binary classification evaluator shared by all examples: positive label
	 * value {@code "p"}, label column {@link #LABEL_COL_NAME}, prediction detail column
	 * {@link #PRED_DETAIL_COL_NAME}.
	 *
	 * @return a new, not yet linked evaluator
	 */
	private static EvalBinaryClassBatchOp newBinaryEvaluator() {
		return new EvalBinaryClassBatchOp()
			.setPositiveLabelValueString("p")
			.setLabelCol(LABEL_COL_NAME)
			.setPredictionDetailCol(PRED_DETAIL_COL_NAME);
	}

}