This chapter includes the following sections:
13.1 Data Preparation
13.1.1 Reading the MNIST Data Files
13.1.2 Dense Vectors and Sparse Vectors
13.1.3 Statistics of the Label Values
13.2 The Softmax Algorithm
13.3 Combining Binary Classifiers
13.4 The Multilayer Perceptron Classifier
13.5 Decision Trees and Random Forests
13.6 The K-Nearest-Neighbors Algorithm
For the full discussion, please read the print book 《Alink权威指南:基于Flink的机器学习实例入门(Java)》 (The Definitive Guide to Alink: An Introduction to Machine Learning Examples Based on Flink, Java edition); what follows is the example code for this chapter.
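Before the full listing, here is a minimal sketch (not from the book) of the two vector representations compared in Section 13.1.2. It uses only classes and constructors that also appear in the chapter code below (DenseVector, SparseVector with a TreeMap, and VectorUtil.serialize); the exact serialized strings it prints depend on Alink's vector text format.

import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.VectorUtil;

import java.util.TreeMap;

public class VectorDemo {
    public static void main(String[] args) {
        // Dense representation: all 784 pixel values are stored, zeros included.
        double[] pixels = new double[784];
        pixels[100] = 255.0;
        DenseVector dense = new DenseVector(pixels);

        // Sparse representation: only the nonzero pixels are stored as
        // (index, value) pairs, which is much smaller for MNIST images
        // that are mostly background zeros.
        TreeMap<Integer, Double> nonzero = new TreeMap<>();
        nonzero.put(100, 255.0);
        SparseVector sparse = new SparseVector(784, nonzero);

        // VectorUtil.serialize produces the string form that is written into
        // the .ak files below: the dense string lists every component, while
        // the sparse string stays short no matter how large the vector is.
        System.out.println(VectorUtil.serialize(dense));
        System.out.println(VectorUtil.serialize(sparse));
    }
}

This size difference is why the chapter prepares both a dense and a sparse version of the MNIST data and uses the sparse files for most of the classifiers.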
package com.alibaba.alink;

import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.utils.Stopwatch;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.evaluation.EvalMultiClassBatchOp;
import com.alibaba.alink.operator.batch.sink.AkSinkBatchOp;
import com.alibaba.alink.operator.batch.source.AkSourceBatchOp;
import com.alibaba.alink.operator.batch.source.BaseSourceBatchOp;
import com.alibaba.alink.operator.batch.source.MemSourceBatchOp;
import com.alibaba.alink.operator.batch.statistics.VectorSummarizerBatchOp;
import com.alibaba.alink.params.shared.clustering.HasKMeansDistanceType.DistanceType;
import com.alibaba.alink.params.shared.tree.HasIndividualTreeType.TreeType;
import com.alibaba.alink.pipeline.classification.*;
import com.alibaba.alink.pipeline.dataproc.format.VectorToColumns;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.table.api.Table;
import org.apache.flink.types.Row;

import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.TreeMap;
import java.util.zip.GZIPInputStream;

public class Chap13 {

    static final String DATA_DIR = Utils.ROOT_DIR + "mnist" + File.separator;

    static final String DENSE_TRAIN_FILE = "dense_train.ak";
    static final String DENSE_TEST_FILE = "dense_test.ak";
    static final String SPARSE_TRAIN_FILE = "sparse_train.ak";
    static final String SPARSE_TEST_FILE = "sparse_test.ak";
    static final String TABLE_TRAIN_FILE = "table_train.ak";
    static final String TABLE_TEST_FILE = "table_test.ak";

    static final String VECTOR_COL_NAME = "vec";
    static final String LABEL_COL_NAME = "label";
    static final String PREDICTION_COL_NAME = "id_cluster";

    public static void main(String[] args) throws Exception {
        AlinkGlobalConfiguration.setPrintProcessInfo(true);
        BatchOperator.setParallelism(4);
        c_1();
        c_2();
        c_3();
        c_4();
        c_5();
        c_6();
    }

    // Section 13.1: data preparation. Convert the four MNIST gz files into
    // Alink .ak files (sparse and dense vector versions), then print one
    // sample row, the vector summaries, and the per-label counts.
    static void c_1() throws Exception {
        if (!new File(DATA_DIR + SPARSE_TRAIN_FILE).exists()) {
            new MnistGzFileSourceBatchOp(
                DATA_DIR + "train-images-idx3-ubyte.gz",
                DATA_DIR + "train-labels-idx1-ubyte.gz",
                true
            ).link(
                new AkSinkBatchOp().setFilePath(DATA_DIR + SPARSE_TRAIN_FILE)
            );
            BatchOperator.execute();

            new MnistGzFileSourceBatchOp(
                DATA_DIR + "t10k-images-idx3-ubyte.gz",
                DATA_DIR + "t10k-labels-idx1-ubyte.gz",
                true
            ).link(
                new AkSinkBatchOp().setFilePath(DATA_DIR + SPARSE_TEST_FILE)
            );
            BatchOperator.execute();

            new MnistGzFileSourceBatchOp(
                DATA_DIR + "train-images-idx3-ubyte.gz",
                DATA_DIR + "train-labels-idx1-ubyte.gz",
                false
            ).link(
                new AkSinkBatchOp().setFilePath(DATA_DIR + DENSE_TRAIN_FILE)
            );
            BatchOperator.execute();

            new MnistGzFileSourceBatchOp(
                DATA_DIR + "t10k-images-idx3-ubyte.gz",
                DATA_DIR + "t10k-labels-idx1-ubyte.gz",
                false
            ).link(
                new AkSinkBatchOp().setFilePath(DATA_DIR + DENSE_TEST_FILE)
            );
            BatchOperator.execute();
        }

        new AkSourceBatchOp()
            .setFilePath(DATA_DIR + DENSE_TRAIN_FILE)
            .lazyPrint(1, "MNIST data")
            .link(
                new VectorSummarizerBatchOp()
                    .setSelectedCol(VECTOR_COL_NAME)
                    .lazyPrintVectorSummary()
            );

        new AkSourceBatchOp()
            .setFilePath(DATA_DIR + SPARSE_TRAIN_FILE)
            .lazyPrint(1, "MNIST data")
            .link(
                new VectorSummarizerBatchOp()
                    .setSelectedCol(VECTOR_COL_NAME)
                    .lazyPrintVectorSummary()
            );

        new AkSourceBatchOp()
            .setFilePath(DATA_DIR + SPARSE_TRAIN_FILE)
            .lazyPrintStatistics()
            .groupBy(LABEL_COL_NAME, LABEL_COL_NAME + ", COUNT(*) AS cnt")
            .orderBy("cnt", 100)
            .lazyPrint(-1);
        BatchOperator.execute();
    }

    // Section 13.2: the Softmax algorithm (multinomial logistic regression).
    static void c_2() throws Exception {
        AkSourceBatchOp train_data = new AkSourceBatchOp().setFilePath(DATA_DIR + SPARSE_TRAIN_FILE);
        AkSourceBatchOp test_data = new AkSourceBatchOp().setFilePath(DATA_DIR + SPARSE_TEST_FILE);

        new Softmax()
            .setVectorCol(VECTOR_COL_NAME)
            .setLabelCol(LABEL_COL_NAME)
            .setPredictionCol(PREDICTION_COL_NAME)
            .enableLazyPrintTrainInfo()
            .enableLazyPrintModelInfo()
            .fit(train_data)
            .transform(test_data)
            .link(
                new EvalMultiClassBatchOp()
                    .setLabelCol(LABEL_COL_NAME)
                    .setPredictionCol(PREDICTION_COL_NAME)
                    .lazyPrintMetrics("Softmax")
            );
        BatchOperator.execute();
    }

    // Section 13.3: combining binary classifiers. OneVsRest turns the
    // 10-class problem into ten binary problems, first with logistic
    // regression and then with a linear SVM as the base classifier.
    static void c_3() throws Exception {
        AkSourceBatchOp train_data = new AkSourceBatchOp().setFilePath(DATA_DIR + SPARSE_TRAIN_FILE);
        AkSourceBatchOp test_data = new AkSourceBatchOp().setFilePath(DATA_DIR + SPARSE_TEST_FILE);

        BatchOperator.setParallelism(1);

        new OneVsRest()
            .setClassifier(
                new LogisticRegression()
                    .setVectorCol(VECTOR_COL_NAME)
                    .setLabelCol(LABEL_COL_NAME)
                    .setPredictionCol(PREDICTION_COL_NAME)
            )
            .setNumClass(10)
            .fit(train_data)
            .transform(test_data)
            .link(
                new EvalMultiClassBatchOp()
                    .setLabelCol(LABEL_COL_NAME)
                    .setPredictionCol(PREDICTION_COL_NAME)
                    .lazyPrintMetrics("OneVsRest - LogisticRegression")
            );

        new OneVsRest()
            .setClassifier(
                new LinearSvm()
                    .setVectorCol(VECTOR_COL_NAME)
                    .setLabelCol(LABEL_COL_NAME)
                    .setPredictionCol(PREDICTION_COL_NAME)
            )
            .setNumClass(10)
            .fit(train_data)
            .transform(test_data)
            .link(
                new EvalMultiClassBatchOp()
                    .setLabelCol(LABEL_COL_NAME)
                    .setPredictionCol(PREDICTION_COL_NAME)
                    .lazyPrintMetrics("OneVsRest - LinearSvm")
            );
        BatchOperator.execute();
    }

    // Section 13.4: multilayer perceptron classifiers, first with no hidden
    // layer ({784, 10}) and then with two hidden layers ({784, 256, 128, 10}).
    static void c_4() throws Exception {
        BatchOperator.setParallelism(4);

        AkSourceBatchOp train_data = new AkSourceBatchOp().setFilePath(DATA_DIR + SPARSE_TRAIN_FILE);
        AkSourceBatchOp test_data = new AkSourceBatchOp().setFilePath(DATA_DIR + SPARSE_TEST_FILE);

        new MultilayerPerceptronClassifier()
            .setLayers(new int[] {784, 10})
            .setVectorCol(VECTOR_COL_NAME)
            .setLabelCol(LABEL_COL_NAME)
            .setPredictionCol(PREDICTION_COL_NAME)
            .fit(train_data)
            .transform(test_data)
            .link(
                new EvalMultiClassBatchOp()
                    .setLabelCol(LABEL_COL_NAME)
                    .setPredictionCol(PREDICTION_COL_NAME)
                    .lazyPrintMetrics("MultilayerPerceptronClassifier {784, 10}")
            );
        BatchOperator.execute();

        new MultilayerPerceptronClassifier()
            .setLayers(new int[] {784, 256, 128, 10})
            .setVectorCol(VECTOR_COL_NAME)
            .setLabelCol(LABEL_COL_NAME)
            .setPredictionCol(PREDICTION_COL_NAME)
            .fit(train_data)
            .transform(test_data)
            .link(
                new EvalMultiClassBatchOp()
                    .setLabelCol(LABEL_COL_NAME)
                    .setPredictionCol(PREDICTION_COL_NAME)
                    .lazyPrintMetrics("MultilayerPerceptronClassifier {784, 256, 128, 10}")
            );
        BatchOperator.execute();
    }

    // Section 13.5: decision trees and random forests. The tree operators
    // take feature columns rather than a vector column, so the 784-dimensional
    // vectors are first expanded into 784 double columns with VectorToColumns.
    static void c_5() throws Exception {
        BatchOperator.setParallelism(4);

        if (!new File(DATA_DIR + TABLE_TRAIN_FILE).exists()) {
            AkSourceBatchOp train_sparse = new AkSourceBatchOp().setFilePath(DATA_DIR + SPARSE_TRAIN_FILE);
            AkSourceBatchOp test_sparse = new AkSourceBatchOp().setFilePath(DATA_DIR + SPARSE_TEST_FILE);

            // Build the schema string "c_0 double, c_1 double, ..., c_783 double".
            StringBuilder sbd = new StringBuilder();
            sbd.append("c_0 double");
            for (int i = 1; i < 784; i++) {
                sbd.append(", c_").append(i).append(" double");
            }

            new VectorToColumns()
                .setVectorCol(VECTOR_COL_NAME)
                .setSchemaStr(sbd.toString())
                .setReservedCols(LABEL_COL_NAME)
                .transform(train_sparse)
                .link(
                    new AkSinkBatchOp().setFilePath(DATA_DIR + TABLE_TRAIN_FILE)
                );
            new VectorToColumns()
                .setVectorCol(VECTOR_COL_NAME)
                .setSchemaStr(sbd.toString())
                .setReservedCols(LABEL_COL_NAME)
                .transform(test_sparse)
                .link(
                    new AkSinkBatchOp().setFilePath(DATA_DIR + TABLE_TEST_FILE)
                );
            BatchOperator.execute();
        }

        AkSourceBatchOp train_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TABLE_TRAIN_FILE);
        AkSourceBatchOp test_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TABLE_TEST_FILE);

        final String[] featureColNames =
            ArrayUtils.removeElement(train_data.getColNames(), LABEL_COL_NAME);

        train_data.lazyPrint(5);

        Stopwatch sw = new Stopwatch();

        // Compare the three tree-splitting criteria, timing each run.
        for (TreeType treeType : new TreeType[] {TreeType.GINI, TreeType.INFOGAIN, TreeType.INFOGAINRATIO}) {
            sw.reset();
            sw.start();
            new DecisionTreeClassifier()
                .setTreeType(treeType)
                .setFeatureCols(featureColNames)
                .setLabelCol(LABEL_COL_NAME)
                .setPredictionCol(PREDICTION_COL_NAME)
                .enableLazyPrintModelInfo()
                .fit(train_data)
                .transform(test_data)
                .link(
                    new EvalMultiClassBatchOp()
                        .setLabelCol(LABEL_COL_NAME)
                        .setPredictionCol(PREDICTION_COL_NAME)
                        .lazyPrintMetrics("DecisionTreeClassifier " + treeType.toString())
                );
            BatchOperator.execute();
            sw.stop();
            System.out.println(sw.getElapsedTimeSpan());
        }

        // Grow the forest from 2 to 128 trees, timing each run.
        for (int numTrees : new int[] {2, 4, 8, 16, 32, 64, 128}) {
            sw.reset();
            sw.start();
            new RandomForestClassifier()
                .setSubsamplingRatio(0.6)
                .setNumTreesOfInfoGain(numTrees)
                .setFeatureCols(featureColNames)
                .setLabelCol(LABEL_COL_NAME)
                .setPredictionCol(PREDICTION_COL_NAME)
                .enableLazyPrintModelInfo()
                .fit(train_data)
                .transform(test_data)
                .link(
                    new EvalMultiClassBatchOp()
                        .setLabelCol(LABEL_COL_NAME)
                        .setPredictionCol(PREDICTION_COL_NAME)
                        .lazyPrintMetrics("RandomForestClassifier : " + numTrees)
                );
            BatchOperator.execute();
            sw.stop();
            System.out.println(sw.getElapsedTimeSpan());
        }
    }

    // Section 13.6: the K-nearest-neighbors algorithm, varying K and the
    // distance metric.
    static void c_6() throws Exception {
        BatchOperator.setParallelism(4);

        AkSourceBatchOp train_data = new AkSourceBatchOp().setFilePath(DATA_DIR + SPARSE_TRAIN_FILE);
        AkSourceBatchOp test_data = new AkSourceBatchOp().setFilePath(DATA_DIR + SPARSE_TEST_FILE);

        new KnnClassifier()
            .setK(3)
            .setVectorCol(VECTOR_COL_NAME)
            .setLabelCol(LABEL_COL_NAME)
            .setPredictionCol(PREDICTION_COL_NAME)
            .fit(train_data)
            .transform(test_data)
            .link(
                new EvalMultiClassBatchOp()
                    .setLabelCol(LABEL_COL_NAME)
                    .setPredictionCol(PREDICTION_COL_NAME)
                    .lazyPrintMetrics("KnnClassifier - 3 - EUCLIDEAN")
            );
        BatchOperator.execute();

        new KnnClassifier()
            .setDistanceType(DistanceType.COSINE)
            .setK(3)
            .setVectorCol(VECTOR_COL_NAME)
            .setLabelCol(LABEL_COL_NAME)
            .setPredictionCol(PREDICTION_COL_NAME)
            .fit(train_data)
            .transform(test_data)
            .link(
                new EvalMultiClassBatchOp()
                    .setLabelCol(LABEL_COL_NAME)
                    .setPredictionCol(PREDICTION_COL_NAME)
                    .lazyPrintMetrics("KnnClassifier - 3 - COSINE")
            );
        BatchOperator.execute();

        new KnnClassifier()
            .setK(7)
            .setVectorCol(VECTOR_COL_NAME)
            .setLabelCol(LABEL_COL_NAME)
            .setPredictionCol(PREDICTION_COL_NAME)
            .fit(train_data)
            .transform(test_data)
            .link(
                new EvalMultiClassBatchOp()
                    .setLabelCol(LABEL_COL_NAME)
                    .setPredictionCol(PREDICTION_COL_NAME)
                    .lazyPrintMetrics("KnnClassifier - 7 - EUCLIDEAN")
            );
        BatchOperator.execute();
    }

    // Custom batch source that reads a gzipped MNIST image file and its
    // matching label file, emitting one row per example: (vector string, label).
    public static class MnistGzFileSourceBatchOp extends BaseSourceBatchOp <MnistGzFileSourceBatchOp> {
        private final String imageGzFile;
        private final String labelGzFile;
        private final boolean isSparse;

        public MnistGzFileSourceBatchOp(String imageGzFile, String labelGzFile, boolean isSparse) {
            super(null, null);
            this.imageGzFile = imageGzFile;
            this.labelGzFile = labelGzFile;
            this.isSparse = isSparse;
        }

        @Override
        protected Table initializeDataSource() {
            try {
                ArrayList <Row> rows = new ArrayList <>();
                String[] images = getImages();
                Integer[] labels = getLabels();
                int n = images.length;
                if (labels.length != n) {
                    throw new RuntimeException("The number of images does not equal the number of labels.");
                }
                for (int i = 0; i < n; i++) {
                    rows.add(Row.of(images[i], labels[i]));
                }
                return new MemSourceBatchOp(rows, new String[] {"vec", "label"}).getOutputTable();
            } catch (Exception ex) {
                ex.printStackTrace();
                throw new RuntimeException(ex.getMessage());
            }
        }

        // Reassemble a big-endian 32-bit integer from 4 bytes, the header
        // encoding used by the MNIST IDX files.
        private int getInteger(byte[] bytes) {
            return ((bytes[0] & 0xFF) << 24)
                + ((bytes[1] & 0xFF) << 16)
                + ((bytes[2] & 0xFF) << 8)
                + (bytes[3] & 0xFF);
        }

        private Integer[] getLabels() throws IOException {
            BufferedInputStream bis = new BufferedInputStream(
                new GZIPInputStream(new FileInputStream(this.labelGzFile)));
            byte[] bytes = new byte[4];
            bis.read(bytes, 0, 4);
            int magic_number = getInteger(bytes);
            bis.read(bytes, 0, 4);
            int record_number = getInteger(bytes);
            Integer[] labels = new Integer[record_number];
            for (int i = 0; i < record_number; i++) {
                labels[i] = bis.read();
            }
            bis.close();
            return labels;
        }

        private String[] getImages() throws IOException {
            BufferedInputStream bis = new BufferedInputStream(
                new GZIPInputStream(new FileInputStream(this.imageGzFile)));
            byte[] bytes = new byte[4];
            bis.read(bytes, 0, 4);
            int magic_number = getInteger(bytes);
            bis.read(bytes, 0, 4);
            int record_number = getInteger(bytes);
            bis.read(bytes, 0, 4);
            int xPixels = getInteger(bytes);
            bis.read(bytes, 0, 4);
            int yPixels = getInteger(bytes);
            int nPixels = xPixels * yPixels;

            String[] images = new String[record_number];
            if (isSparse) {
                // Keep only the nonzero pixels as (index, value) pairs.
                TreeMap <Integer, Double> pixels = new TreeMap <>();
                int val;
                for (int i = 0; i < record_number; i++) {
                    pixels.clear();
                    for (int j = 0; j < nPixels; j++) {
                        val = bis.read();
                        if (0 != val) {
                            pixels.put(j, (double) val);
                        }
                    }
                    images[i] = VectorUtil.serialize(new SparseVector(nPixels, pixels));
                }
            } else {
                // Store every pixel value, zeros included.
                double[] image = new double[nPixels];
                for (int i = 0; i < record_number; i++) {
                    for (int j = 0; j < nPixels; j++) {
                        image[j] = bis.read();
                    }
                    images[i] = VectorUtil.serialize(new DenseVector(image));
                }
            }
            bis.close();
            return images;
        }
    }
}
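A note on the custom source above: the MNIST files use the standard IDX layout, which is why getInteger reassembles big-endian 32-bit integers. The label file is a 4-byte magic number followed by a 4-byte record count and then one unsigned byte per label; the image file additionally carries 4-byte row and column counts (28 x 28 = 784 pixels), followed by the raw row-major pixel bytes of each image. getImages turns each image into either a DenseVector of all 784 values or a SparseVector of only the nonzero pixels, according to the isSparse flag passed in c_1.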