本章包括下面各节:
12.1 多分类模型评估方法
12.1.1 综合指标
12.1.2 关于每个标签值的二分类指标
12.1.3 Micro、Macro、Weighted计算的指标
12.2 数据探索
12.3 使用朴素贝叶斯进行多分类
12.4 二分类器组合
12.5 Softmax算法
12.6 多层感知器分类器
详细内容请阅读纸质书《Alink权威指南:基于Flink的机器学习实例入门(Java)》,这里为本章对应的示例代码。
package com.alibaba.alink;

import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.classification.NaiveBayesPredictBatchOp;
import com.alibaba.alink.operator.batch.classification.NaiveBayesTrainBatchOp;
import com.alibaba.alink.operator.batch.evaluation.EvalMultiClassBatchOp;
import com.alibaba.alink.operator.batch.source.AkSourceBatchOp;
import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
import com.alibaba.alink.operator.batch.statistics.CorrelationBatchOp;
import com.alibaba.alink.pipeline.classification.GbdtClassifier;
import com.alibaba.alink.pipeline.classification.LinearSvm;
import com.alibaba.alink.pipeline.classification.LogisticRegression;
import com.alibaba.alink.pipeline.classification.MultilayerPerceptronClassifier;
import com.alibaba.alink.pipeline.classification.OneVsRest;
import com.alibaba.alink.pipeline.classification.Softmax;

import java.io.File;

/**
 * Example code for chapter 12: multi-class classification on the iris data files.
 *
 * <p>The {@code c_N} methods are numbered after the chapter sections (12.2 to 12.6) and are
 * run in that order by {@link #main(String[])}. {@link #c_2()} writes the train/test files
 * that the later methods read, so it has to run before them at least once.
 */
public class Chap12 {

    /** Directory holding the iris files, built from {@code Utils.ROOT_DIR}. */
    static final String DATA_DIR = Utils.ROOT_DIR + "iris" + File.separator;

    /** Name of the raw CSV file inside {@link #DATA_DIR}. */
    static final String ORIGIN_FILE = "iris.data";

    /** Name of the training split (Ak format) inside {@link #DATA_DIR}. */
    static final String TRAIN_FILE = "train.ak";

    /** Name of the test split (Ak format) inside {@link #DATA_DIR}. */
    static final String TEST_FILE = "test.ak";

    /** Schema used to parse {@link #ORIGIN_FILE}: four numeric features and a string label. */
    static final String SCHEMA_STRING =
        "sepal_length double, sepal_width double, petal_length double, petal_width double, category string";

    /** Feature columns fed to every classifier in this chapter. */
    static final String[] FEATURE_COL_NAMES = {"sepal_length", "sepal_width", "petal_length", "petal_width"};

    /** Label column. */
    static final String LABEL_COL_NAME = "category";

    /** Column that receives the predicted label. */
    static final String PREDICTION_COL_NAME = "pred";

    /** Column that receives the prediction detail. */
    static final String PRED_DETAIL_COL_NAME = "pred_info";

    /**
     * Runs every example of the chapter sequentially with parallelism 1.
     *
     * @param args ignored
     * @throws Exception if any of the examples fails
     */
    public static void main(String[] args) throws Exception {
        BatchOperator.setParallelism(1);

        c_2();
        c_3();
        c_4();
        c_5();
        c_6();
    }

    /**
     * Section 12.2: prints a sample, statistics, feature correlation and per-label counts of
     * the raw file, then hands the data to {@code Utils.splitTrainTestIfNotExist} together
     * with the train/test paths and the ratio argument 0.9.
     *
     * @throws Exception if the batch job or the split fails
     */
    static void c_2() throws Exception {
        CsvSourceBatchOp iris = new CsvSourceBatchOp()
            .setFilePath(DATA_DIR + ORIGIN_FILE)
            .setSchemaStr(SCHEMA_STRING);

        iris
            .lazyPrint(5, "origin file")
            .lazyPrintStatistics("stat of origin file")
            .link(
                new CorrelationBatchOp()
                    .setSelectedCols(FEATURE_COL_NAMES)
                    .lazyPrintCorrelation()
            );

        String labelWithCount = LABEL_COL_NAME + ", COUNT(*) AS cnt";
        iris.groupBy(LABEL_COL_NAME, labelWithCount).lazyPrint(-1);

        // The lazy prints registered above only produce output once the job is executed.
        BatchOperator.execute();

        String trainPath = DATA_DIR + TRAIN_FILE;
        String testPath = DATA_DIR + TEST_FILE;
        Utils.splitTrainTestIfNotExist(iris, trainPath, testPath, 0.9);
    }

    /**
     * Section 12.3: trains a naive Bayes model on the training split, predicts the test split
     * and prints the model info, one prediction row and the multi-class metrics.
     *
     * @throws Exception if the batch job fails
     */
    static void c_3() throws Exception {
        AkSourceBatchOp trainSet = openAk(TRAIN_FILE);
        AkSourceBatchOp testSet = openAk(TEST_FILE);

        NaiveBayesTrainBatchOp trainer = new NaiveBayesTrainBatchOp()
            .setFeatureCols(FEATURE_COL_NAMES)
            .setLabelCol(LABEL_COL_NAME);

        NaiveBayesPredictBatchOp predictor = new NaiveBayesPredictBatchOp()
            .setPredictionCol(PREDICTION_COL_NAME)
            .setPredictionDetailCol(PRED_DETAIL_COL_NAME);

        trainSet.link(trainer);
        predictor.linkFrom(trainer, testSet);

        trainer.lazyPrintModelInfo();
        predictor.lazyPrint(1, "< Prediction >");

        // Unlike the other sections, this evaluation is also given the prediction detail column.
        BatchOperator<?> evaluation = new EvalMultiClassBatchOp()
            .setLabelCol(LABEL_COL_NAME)
            .setPredictionCol(PREDICTION_COL_NAME)
            .setPredictionDetailCol(PRED_DETAIL_COL_NAME)
            .lazyPrintMetrics("NaiveBayes");
        predictor.link(evaluation);

        BatchOperator.execute();
    }

    /**
     * Section 12.4: wraps three binary classifiers (logistic regression, GBDT, linear SVM)
     * in {@link OneVsRest} with three classes and prints the multi-class metrics of each.
     *
     * @throws Exception if the batch job fails
     */
    static void c_4() throws Exception {
        AkSourceBatchOp trainSet = openAk(TRAIN_FILE);
        AkSourceBatchOp testSet = openAk(TEST_FILE);

        LogisticRegression logistic = new LogisticRegression()
            .setFeatureCols(FEATURE_COL_NAMES)
            .setLabelCol(LABEL_COL_NAME)
            .setPredictionCol(PREDICTION_COL_NAME);
        new OneVsRest()
            .setClassifier(logistic)
            .setNumClass(3)
            .fit(trainSet)
            .transform(testSet)
            .link(predictionEvaluator("OneVsRest_LogisticRegression"));

        GbdtClassifier gbdt = new GbdtClassifier()
            .setFeatureCols(FEATURE_COL_NAMES)
            .setLabelCol(LABEL_COL_NAME)
            .setPredictionCol(PREDICTION_COL_NAME);
        new OneVsRest()
            .setClassifier(gbdt)
            .setNumClass(3)
            .fit(trainSet)
            .transform(testSet)
            .link(predictionEvaluator("OneVsRest_GBDT"));

        LinearSvm svm = new LinearSvm()
            .setFeatureCols(FEATURE_COL_NAMES)
            .setLabelCol(LABEL_COL_NAME)
            .setPredictionCol(PREDICTION_COL_NAME);
        new OneVsRest()
            .setClassifier(svm)
            .setNumClass(3)
            .fit(trainSet)
            .transform(testSet)
            .link(predictionEvaluator("OneVsRest_LinearSvm"));

        BatchOperator.execute();
    }

    /**
     * Section 12.5: fits a {@link Softmax} classifier with lazy printing of train and model
     * info enabled, predicts the test split and prints the multi-class metrics.
     *
     * @throws Exception if the batch job fails
     */
    static void c_5() throws Exception {
        AkSourceBatchOp trainSet = openAk(TRAIN_FILE);
        AkSourceBatchOp testSet = openAk(TEST_FILE);

        new Softmax()
            .setFeatureCols(FEATURE_COL_NAMES)
            .setLabelCol(LABEL_COL_NAME)
            .setPredictionCol(PREDICTION_COL_NAME)
            .enableLazyPrintTrainInfo()
            .enableLazyPrintModelInfo()
            .fit(trainSet)
            .transform(testSet)
            .link(predictionEvaluator("Softmax"));

        BatchOperator.execute();
    }

    /**
     * Section 12.6: evaluates two multilayer perceptron layouts, {@code {4, 12, 3}} and
     * {@code {4, 3}}, on the same train/test split.
     *
     * @throws Exception if the batch job fails
     */
    static void c_6() throws Exception {
        AkSourceBatchOp trainSet = openAk(TRAIN_FILE);
        AkSourceBatchOp testSet = openAk(TEST_FILE);

        linkPerceptron(trainSet, testSet, new int[] {4, 12, 3}, "MultilayerPerceptronClassifier [4, 12, 3]");
        linkPerceptron(trainSet, testSet, new int[] {4, 3}, "MultilayerPerceptronClassifier [4, 3]");

        BatchOperator.execute();
    }

    /** Creates an Ak source reading {@code fileName} from {@link #DATA_DIR}. */
    private static AkSourceBatchOp openAk(String fileName) {
        return new AkSourceBatchOp().setFilePath(DATA_DIR + fileName);
    }

    /**
     * Builds a multi-class evaluation on the label and prediction columns that lazily prints
     * its metrics under {@code title}.
     */
    private static BatchOperator<?> predictionEvaluator(String title) {
        return new EvalMultiClassBatchOp()
            .setLabelCol(LABEL_COL_NAME)
            .setPredictionCol(PREDICTION_COL_NAME)
            .lazyPrintMetrics(title);
    }

    /**
     * Fits a multilayer perceptron with the given layer sizes on {@code trainSet}, applies it
     * to {@code testSet} and links the result to an evaluation printed under {@code title}.
     */
    private static void linkPerceptron(
        AkSourceBatchOp trainSet, AkSourceBatchOp testSet, int[] layers, String title) {
        new MultilayerPerceptronClassifier()
            .setLayers(layers)
            .setFeatureCols(FEATURE_COL_NAMES)
            .setLabelCol(LABEL_COL_NAME)
            .setPredictionCol(PREDICTION_COL_NAME)
            .fit(trainSet)
            .transform(testSet)
            .link(predictionEvaluator(title));
    }
}