Alink教程(Java版)

第12章 从二分类到多分类

本章包括以下各节：
12.1 多分类模型评估方法
12.1.1 综合指标
12.1.2 关于每个标签值的二分类指标
12.1.3 Micro、Macro、Weighted计算的指标
12.2 数据探索
12.3 使用朴素贝叶斯进行多分类
12.4 二分类器组合
12.5 Softmax算法
12.6 多层感知器分类器

详细内容请阅读纸质书《Alink权威指南：基于Flink的机器学习实例入门（Java）》，以下为本章对应的示例代码。

package com.alibaba.alink;

import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.classification.NaiveBayesPredictBatchOp;
import com.alibaba.alink.operator.batch.classification.NaiveBayesTrainBatchOp;
import com.alibaba.alink.operator.batch.evaluation.EvalMultiClassBatchOp;
import com.alibaba.alink.operator.batch.source.AkSourceBatchOp;
import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
import com.alibaba.alink.operator.batch.statistics.CorrelationBatchOp;
import com.alibaba.alink.pipeline.classification.GbdtClassifier;
import com.alibaba.alink.pipeline.classification.LinearSvm;
import com.alibaba.alink.pipeline.classification.LogisticRegression;
import com.alibaba.alink.pipeline.classification.MultilayerPerceptronClassifier;
import com.alibaba.alink.pipeline.classification.OneVsRest;
import com.alibaba.alink.pipeline.classification.Softmax;

import java.io.File;
import java.util.Arrays;

/**
 * Example code for chapter 12 ("from binary to multi-class classification") of the Alink tutorial.
 *
 * <p>Each {@code c_N} method corresponds to section 12.N of the chapter: data exploration (12.2),
 * Naive Bayes (12.3), one-vs-rest combinations of binary classifiers (12.4), Softmax (12.5) and the
 * multilayer perceptron classifier (12.6). All of them work on the iris data under {@link #DATA_DIR}.
 */
public class Chap12 {

	static final String DATA_DIR = Utils.ROOT_DIR + "iris" + File.separator;

	/** Raw CSV file, parsed with {@link #SCHEMA_STRING}. */
	static final String ORIGIN_FILE = "iris.data";

	/** Train/test splits written by {@link #c_2()} and read by every later section. */
	static final String TRAIN_FILE = "train.ak";
	static final String TEST_FILE = "test.ak";

	static final String SCHEMA_STRING
		= "sepal_length double, sepal_width double, petal_length double, petal_width double, category string";

	static final String[] FEATURE_COL_NAMES
		= new String[] {"sepal_length", "sepal_width", "petal_length", "petal_width"};

	static final String LABEL_COL_NAME = "category";

	/**
	 * Number of distinct categories in {@link #LABEL_COL_NAME} that the multi-class models are
	 * configured for (OneVsRest class count and size of the MLP output layer).
	 */
	static final int NUM_CLASSES = 3;

	static final String PREDICTION_COL_NAME = "pred";
	static final String PRED_DETAIL_COL_NAME = "pred_info";

	public static void main(String[] args) throws Exception {

		// NOTE(review): presumably single parallelism keeps printed results stable between runs — confirm.
		BatchOperator.setParallelism(1);

		c_2();

		c_3();

		c_4();

		c_5();

		c_6();

	}

	/**
	 * Section 12.2: explores the raw CSV data, then writes the train/test split used by later sections.
	 *
	 * <p>Prints the first 5 rows, summary statistics, the correlation of the feature columns and the
	 * row count per category. It then calls {@code Utils.splitTrainTestIfNotExist} with ratio 0.9.
	 */
	static void c_2() throws Exception {
		CsvSourceBatchOp source =
			new CsvSourceBatchOp()
				.setFilePath(DATA_DIR + ORIGIN_FILE)
				.setSchemaStr(SCHEMA_STRING);

		source
			.lazyPrint(5, "origin file")
			.lazyPrintStatistics("stat of origin file")
			.link(
				new CorrelationBatchOp()
					.setSelectedCols(FEATURE_COL_NAMES)
					.lazyPrintCorrelation()
			);

		// Row count per label value; NOTE(review): -1 presumably means "print every row".
		source.groupBy(LABEL_COL_NAME, LABEL_COL_NAME + ", COUNT(*) AS cnt").lazyPrint(-1);

		// The lazy* outputs registered above are presumably emitted when this job runs.
		BatchOperator.execute();

		// Judging by its name, the helper only writes the split when the files are missing (see Utils);
		// 0.9 is presumably the training fraction — confirm against Utils.
		Utils.splitTrainTestIfNotExist(source, DATA_DIR + TRAIN_FILE, DATA_DIR + TEST_FILE, 0.9);

	}

	/**
	 * Section 12.3: trains Naive Bayes on the training split and evaluates it on the test split.
	 *
	 * <p>Also prints the model info and the first prediction row. Unlike the other sections, the
	 * evaluator here also receives the prediction-detail column.
	 */
	static void c_3() throws Exception {

		AkSourceBatchOp trainData = akSource(TRAIN_FILE);
		AkSourceBatchOp testData = akSource(TEST_FILE);

		NaiveBayesTrainBatchOp trainer =
			new NaiveBayesTrainBatchOp()
				.setFeatureCols(FEATURE_COL_NAMES)
				.setLabelCol(LABEL_COL_NAME);

		NaiveBayesPredictBatchOp predictor =
			new NaiveBayesPredictBatchOp()
				.setPredictionCol(PREDICTION_COL_NAME)
				.setPredictionDetailCol(PRED_DETAIL_COL_NAME);

		trainData.link(trainer);

		// Inputs: the trained model first, then the data to score.
		predictor.linkFrom(trainer, testData);

		trainer.lazyPrintModelInfo();

		predictor.lazyPrint(1, "< Prediction >");

		predictor
			.link(
				new EvalMultiClassBatchOp()
					.setLabelCol(LABEL_COL_NAME)
					.setPredictionCol(PREDICTION_COL_NAME)
					.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
					.lazyPrintMetrics("NaiveBayes")
			);

		BatchOperator.execute();

	}

	/**
	 * Section 12.4: builds multi-class models from binary classifiers with {@link OneVsRest}.
	 *
	 * <p>Logistic regression, GBDT and linear SVM are each wrapped in a one-vs-rest model for
	 * {@link #NUM_CLASSES} classes. Each model is fitted on the training split and evaluated on the test split.
	 */
	static void c_4() throws Exception {

		AkSourceBatchOp trainData = akSource(TRAIN_FILE);
		AkSourceBatchOp testData = akSource(TEST_FILE);

		new OneVsRest()
			.setClassifier(
				new LogisticRegression()
					.setFeatureCols(FEATURE_COL_NAMES)
					.setLabelCol(LABEL_COL_NAME)
					.setPredictionCol(PREDICTION_COL_NAME)
			)
			.setNumClass(NUM_CLASSES)
			.fit(trainData)
			.transform(testData)
			.link(newEvaluator("OneVsRest_LogisticRegression"));

		new OneVsRest()
			.setClassifier(
				new GbdtClassifier()
					.setFeatureCols(FEATURE_COL_NAMES)
					.setLabelCol(LABEL_COL_NAME)
					.setPredictionCol(PREDICTION_COL_NAME)
			)
			.setNumClass(NUM_CLASSES)
			.fit(trainData)
			.transform(testData)
			.link(newEvaluator("OneVsRest_GBDT"));

		new OneVsRest()
			.setClassifier(
				new LinearSvm()
					.setFeatureCols(FEATURE_COL_NAMES)
					.setLabelCol(LABEL_COL_NAME)
					.setPredictionCol(PREDICTION_COL_NAME)
			)
			.setNumClass(NUM_CLASSES)
			.fit(trainData)
			.transform(testData)
			.link(newEvaluator("OneVsRest_LinearSvm"));

		BatchOperator.execute();

	}

	/**
	 * Section 12.5: trains a Softmax model, printing its training and model info, and evaluates it
	 * on the test split.
	 */
	static void c_5() throws Exception {

		AkSourceBatchOp trainData = akSource(TRAIN_FILE);
		AkSourceBatchOp testData = akSource(TEST_FILE);

		new Softmax()
			.setFeatureCols(FEATURE_COL_NAMES)
			.setLabelCol(LABEL_COL_NAME)
			.setPredictionCol(PREDICTION_COL_NAME)
			.enableLazyPrintTrainInfo()
			.enableLazyPrintModelInfo()
			.fit(trainData)
			.transform(testData)
			.link(newEvaluator("Softmax"));

		BatchOperator.execute();

	}

	/**
	 * Section 12.6: compares two multilayer perceptron classifiers on the test split.
	 *
	 * <p>In both layer configurations the first entry equals the number of feature columns and the
	 * last equals {@link #NUM_CLASSES}. The first configuration adds a layer of 12 in between.
	 */
	static void c_6() throws Exception {

		AkSourceBatchOp trainData = akSource(TRAIN_FILE);
		AkSourceBatchOp testData = akSource(TEST_FILE);

		evalMultilayerPerceptron(trainData, testData, FEATURE_COL_NAMES.length, 12, NUM_CLASSES);

		evalMultilayerPerceptron(trainData, testData, FEATURE_COL_NAMES.length, NUM_CLASSES);

		BatchOperator.execute();

	}

	/** Creates an Ak-format source reading {@code fileName} from {@link #DATA_DIR}. */
	private static AkSourceBatchOp akSource(String fileName) {
		return new AkSourceBatchOp().setFilePath(DATA_DIR + fileName);
	}

	/**
	 * Creates a multi-class evaluator on the label and prediction columns.
	 *
	 * @param title title passed to {@code lazyPrintMetrics}
	 * @return the evaluator, ready to be linked after a prediction operator
	 */
	private static BatchOperator<?> newEvaluator(String title) {
		return new EvalMultiClassBatchOp()
			.setLabelCol(LABEL_COL_NAME)
			.setPredictionCol(PREDICTION_COL_NAME)
			.lazyPrintMetrics(title);
	}

	/**
	 * Fits a multilayer perceptron with the given layer sizes on {@code trainData} and scores {@code testData}.
	 * It also links an evaluator whose metrics title is derived from the layer sizes, so the title
	 * always matches the configuration, e.g. "MultilayerPerceptronClassifier [4, 12, 3]".
	 */
	private static void evalMultilayerPerceptron(AkSourceBatchOp trainData, AkSourceBatchOp testData,
												 int... layers) {
		new MultilayerPerceptronClassifier()
			.setLayers(layers)
			.setFeatureCols(FEATURE_COL_NAMES)
			.setLabelCol(LABEL_COL_NAME)
			.setPredictionCol(PREDICTION_COL_NAME)
			.fit(trainData)
			.transform(testData)
			.link(newEvaluator("MultilayerPerceptronClassifier " + Arrays.toString(layers)));
	}

}