This chapter covers the following sections:
13.1 Data Preparation
13.1.1 Reading the MNIST Data Files
13.1.2 Dense Vectors and Sparse Vectors
13.1.3 Summary Statistics of the Labels
13.2 The Softmax Algorithm
13.3 Combining Binary Classifiers
13.4 The Multilayer Perceptron Classifier
13.5 Decision Trees and Random Forests
13.6 The K-Nearest Neighbors Algorithm
For the full discussion, please read the print book 《Alink权威指南:基于Flink的机器学习实例入门(Java)》 (The Definitive Guide to Alink: Machine Learning Examples Based on Flink, Java Edition); what follows is the companion example code for this chapter.
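Note: the code assumes the four original MNIST files (train-images-idx3-ubyte.gz, train-labels-idx1-ubyte.gz, t10k-images-idx3-ubyte.gz and t10k-labels-idx1-ubyte.gz) have already been downloaded into the mnist subdirectory under Utils.ROOT_DIR; on the first run they are converted to Alink's .ak format.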
package com.alibaba.alink;
import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.utils.Stopwatch;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.evaluation.EvalMultiClassBatchOp;
import com.alibaba.alink.operator.batch.sink.AkSinkBatchOp;
import com.alibaba.alink.operator.batch.source.AkSourceBatchOp;
import com.alibaba.alink.operator.batch.source.BaseSourceBatchOp;
import com.alibaba.alink.operator.batch.source.MemSourceBatchOp;
import com.alibaba.alink.operator.batch.statistics.VectorSummarizerBatchOp;
import com.alibaba.alink.params.shared.clustering.HasKMeansDistanceType.DistanceType;
import com.alibaba.alink.params.shared.tree.HasIndividualTreeType.TreeType;
import com.alibaba.alink.pipeline.classification.*;
import com.alibaba.alink.pipeline.dataproc.format.VectorToColumns;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.table.api.Table;
import org.apache.flink.types.Row;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.TreeMap;
import java.util.zip.GZIPInputStream;
public class Chap13 {
static final String DATA_DIR = Utils.ROOT_DIR + "mnist" + File.separator;
static final String DENSE_TRAIN_FILE = "dense_train.ak";
static final String DENSE_TEST_FILE = "dense_test.ak";
static final String SPARSE_TRAIN_FILE = "sparse_train.ak";
static final String SPARSE_TEST_FILE = "sparse_test.ak";
static final String TABLE_TRAIN_FILE = "table_train.ak";
static final String TABLE_TEST_FILE = "table_test.ak";
static final String VECTOR_COL_NAME = "vec";
static final String LABEL_COL_NAME = "label";
static final String PREDICTION_COL_NAME = "id_cluster";
public static void main(String[] args) throws Exception {
AlinkGlobalConfiguration.setPrintProcessInfo(true);
BatchOperator.setParallelism(4);
c_1();
c_2();
c_3();
c_4();
c_5();
c_6();
}
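// 13.1 Data preparation: on the first run, convert the four gzipped MNIST files
// into Alink's .ak format (sparse and dense variants of the train and test sets);
// then print a sample record, the vector summary statistics, and the record
// count per label value.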
static void c_1() throws Exception {
if (!new File(DATA_DIR + SPARSE_TRAIN_FILE).exists()) {
new MnistGzFileSourceBatchOp(
    DATA_DIR + "train-images-idx3-ubyte.gz",
    DATA_DIR + "train-labels-idx1-ubyte.gz",
    true
)
.link(
new AkSinkBatchOp().setFilePath(DATA_DIR + SPARSE_TRAIN_FILE)
);
BatchOperator.execute();
new MnistGzFileSourceBatchOp(
    DATA_DIR + "t10k-images-idx3-ubyte.gz",
    DATA_DIR + "t10k-labels-idx1-ubyte.gz",
    true
)
.link(
new AkSinkBatchOp().setFilePath(DATA_DIR + SPARSE_TEST_FILE)
);
BatchOperator.execute();
new MnistGzFileSourceBatchOp(
    DATA_DIR + "train-images-idx3-ubyte.gz",
    DATA_DIR + "train-labels-idx1-ubyte.gz",
    false
)
.link(
new AkSinkBatchOp().setFilePath(DATA_DIR + DENSE_TRAIN_FILE)
);
BatchOperator.execute();
new MnistGzFileSourceBatchOp(
    DATA_DIR + "t10k-images-idx3-ubyte.gz",
    DATA_DIR + "t10k-labels-idx1-ubyte.gz",
    false
)
.link(
new AkSinkBatchOp().setFilePath(DATA_DIR + DENSE_TEST_FILE)
);
BatchOperator.execute();
}
new AkSourceBatchOp()
.setFilePath(DATA_DIR + DENSE_TRAIN_FILE)
.lazyPrint(1, "MNIST data")
.link(
new VectorSummarizerBatchOp()
.setSelectedCol(VECTOR_COL_NAME)
.lazyPrintVectorSummary()
);
new AkSourceBatchOp()
.setFilePath(DATA_DIR + SPARSE_TRAIN_FILE)
.lazyPrint(1, "MNIST data")
.link(
new VectorSummarizerBatchOp()
.setSelectedCol(VECTOR_COL_NAME)
.lazyPrintVectorSummary()
);
new AkSourceBatchOp()
.setFilePath(DATA_DIR + SPARSE_TRAIN_FILE)
.lazyPrintStatistics()
.groupBy(LABEL_COL_NAME, LABEL_COL_NAME + ", COUNT(*) AS cnt")
.orderBy("cnt", 100)
.lazyPrint(-1);
BatchOperator.execute();
}
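// 13.2 Softmax (multinomial logistic regression): train on the sparse training
// set and evaluate on the test set with multi-class classification metrics.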
static void c_2() throws Exception {
AkSourceBatchOp train_data = new AkSourceBatchOp().setFilePath(DATA_DIR + SPARSE_TRAIN_FILE);
AkSourceBatchOp test_data = new AkSourceBatchOp().setFilePath(DATA_DIR + SPARSE_TEST_FILE);
new Softmax()
.setVectorCol(VECTOR_COL_NAME)
.setLabelCol(LABEL_COL_NAME)
.setPredictionCol(PREDICTION_COL_NAME)
.enableLazyPrintTrainInfo()
.enableLazyPrintModelInfo()
.fit(train_data)
.transform(test_data)
.link(
new EvalMultiClassBatchOp()
.setLabelCol(LABEL_COL_NAME)
.setPredictionCol(PREDICTION_COL_NAME)
.lazyPrintMetrics("Softmax")
);
BatchOperator.execute();
}
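// 13.3 Combining binary classifiers: the one-vs-rest strategy trains one binary
// model per digit (setNumClass(10)), first with logistic regression and then
// with a linear SVM as the base classifier.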
static void c_3() throws Exception {
AkSourceBatchOp train_data = new AkSourceBatchOp().setFilePath(DATA_DIR + SPARSE_TRAIN_FILE);
AkSourceBatchOp test_data = new AkSourceBatchOp().setFilePath(DATA_DIR + SPARSE_TEST_FILE);
BatchOperator.setParallelism(1);
new OneVsRest()
.setClassifier(
new LogisticRegression()
.setVectorCol(VECTOR_COL_NAME)
.setLabelCol(LABEL_COL_NAME)
.setPredictionCol(PREDICTION_COL_NAME)
)
.setNumClass(10)
.fit(train_data)
.transform(test_data)
.link(
new EvalMultiClassBatchOp()
.setLabelCol(LABEL_COL_NAME)
.setPredictionCol(PREDICTION_COL_NAME)
.lazyPrintMetrics("OneVsRest - LogisticRegression")
);
new OneVsRest()
.setClassifier(
new LinearSvm()
.setVectorCol(VECTOR_COL_NAME)
.setLabelCol(LABEL_COL_NAME)
.setPredictionCol(PREDICTION_COL_NAME)
)
.setNumClass(10)
.fit(train_data)
.transform(test_data)
.link(
new EvalMultiClassBatchOp()
.setLabelCol(LABEL_COL_NAME)
.setPredictionCol(PREDICTION_COL_NAME)
.lazyPrintMetrics("OneVsRest - LinearSvm")
);
BatchOperator.execute();
}
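// 13.4 Multilayer perceptron classifiers: first with no hidden layer
// (784 inputs, 10 outputs), then with two hidden layers of 256 and 128 units.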
static void c_4() throws Exception {
BatchOperator.setParallelism(4);
AkSourceBatchOp train_data = new AkSourceBatchOp().setFilePath(DATA_DIR + SPARSE_TRAIN_FILE);
AkSourceBatchOp test_data = new AkSourceBatchOp().setFilePath(DATA_DIR + SPARSE_TEST_FILE);
new MultilayerPerceptronClassifier()
.setLayers(new int[]{784, 10})
.setVectorCol(VECTOR_COL_NAME)
.setLabelCol(LABEL_COL_NAME)
.setPredictionCol(PREDICTION_COL_NAME)
.fit(train_data)
.transform(test_data)
.link(
new EvalMultiClassBatchOp()
.setLabelCol(LABEL_COL_NAME)
.setPredictionCol(PREDICTION_COL_NAME)
.lazyPrintMetrics("MultilayerPerceptronClassifier {784, 10}")
);
BatchOperator.execute();
new MultilayerPerceptronClassifier()
.setLayers(new int[]{784, 256, 128, 10})
.setVectorCol(VECTOR_COL_NAME)
.setLabelCol(LABEL_COL_NAME)
.setPredictionCol(PREDICTION_COL_NAME)
.fit(train_data)
.transform(test_data)
.link(
new EvalMultiClassBatchOp()
.setLabelCol(LABEL_COL_NAME)
.setPredictionCol(PREDICTION_COL_NAME)
.lazyPrintMetrics("MultilayerPerceptronClassifier {784, 256, 128, 10}")
);
BatchOperator.execute();
}
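// 13.5 Decision trees and random forests. The tree operators consume individual
// feature columns rather than a vector column, so the image vector is first
// expanded into 784 columns and cached; decision trees with three split
// criteria and random forests of increasing size are then trained and timed.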
static void c_5() throws Exception {
BatchOperator.setParallelism(4);
if (!new File(DATA_DIR + TABLE_TRAIN_FILE).exists()) {
AkSourceBatchOp train_sparse = new AkSourceBatchOp().setFilePath(DATA_DIR + SPARSE_TRAIN_FILE);
AkSourceBatchOp test_sparse = new AkSourceBatchOp().setFilePath(DATA_DIR + SPARSE_TEST_FILE);
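// Build the schema string "c_0 double, c_1 double, ..., c_783 double" that maps
// each of the 784 (28 x 28) pixels to its own column.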
StringBuilder sbd = new StringBuilder();
sbd.append("c_0 double");
for (int i = 1; i < 784; i++) {
sbd.append(", c_").append(i).append(" double");
}
new VectorToColumns()
.setVectorCol(VECTOR_COL_NAME)
.setSchemaStr(sbd.toString())
.setReservedCols(LABEL_COL_NAME)
.transform(train_sparse)
.link(
new AkSinkBatchOp().setFilePath(DATA_DIR + TABLE_TRAIN_FILE)
);
new VectorToColumns()
.setVectorCol(VECTOR_COL_NAME)
.setSchemaStr(sbd.toString())
.setReservedCols(LABEL_COL_NAME)
.transform(test_sparse)
.link(
new AkSinkBatchOp().setFilePath(DATA_DIR + TABLE_TEST_FILE)
);
BatchOperator.execute();
}
AkSourceBatchOp train_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TABLE_TRAIN_FILE);
AkSourceBatchOp test_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TABLE_TEST_FILE);
final String[] featureColNames = ArrayUtils.removeElement(train_data.getColNames(), LABEL_COL_NAME);
train_data.lazyPrint(5);
Stopwatch sw = new Stopwatch();
for (TreeType treeType : new TreeType[]{TreeType.GINI, TreeType.INFOGAIN, TreeType.INFOGAINRATIO}) {
sw.reset();
sw.start();
new DecisionTreeClassifier()
.setTreeType(treeType)
.setFeatureCols(featureColNames)
.setLabelCol(LABEL_COL_NAME)
.setPredictionCol(PREDICTION_COL_NAME)
.enableLazyPrintModelInfo()
.fit(train_data)
.transform(test_data)
.link(
new EvalMultiClassBatchOp()
.setLabelCol(LABEL_COL_NAME)
.setPredictionCol(PREDICTION_COL_NAME)
.lazyPrintMetrics("DecisionTreeClassifier " + treeType.toString())
);
BatchOperator.execute();
sw.stop();
System.out.println(sw.getElapsedTimeSpan());
}
for (int numTrees : new int[]{2, 4, 8, 16, 32, 64, 128}) {
sw.reset();
sw.start();
new RandomForestClassifier()
.setSubsamplingRatio(0.6)
.setNumTreesOfInfoGain(numTrees)
.setFeatureCols(featureColNames)
.setLabelCol(LABEL_COL_NAME)
.setPredictionCol(PREDICTION_COL_NAME)
.enableLazyPrintModelInfo()
.fit(train_data)
.transform(test_data)
.link(
new EvalMultiClassBatchOp()
.setLabelCol(LABEL_COL_NAME)
.setPredictionCol(PREDICTION_COL_NAME)
.lazyPrintMetrics("RandomForestClassifier : " + numTrees)
);
BatchOperator.execute();
sw.stop();
System.out.println(sw.getElapsedTimeSpan());
}
}
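// 13.6 K-nearest neighbors: K=3 with the default Euclidean distance, K=3 with
// cosine distance, and K=7 with Euclidean distance.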
static void c_6() throws Exception {
BatchOperator.setParallelism(4);
AkSourceBatchOp train_data = new AkSourceBatchOp().setFilePath(DATA_DIR + SPARSE_TRAIN_FILE);
AkSourceBatchOp test_data = new AkSourceBatchOp().setFilePath(DATA_DIR + SPARSE_TEST_FILE);
new KnnClassifier()
.setK(3)
.setVectorCol(VECTOR_COL_NAME)
.setLabelCol(LABEL_COL_NAME)
.setPredictionCol(PREDICTION_COL_NAME)
.fit(train_data)
.transform(test_data)
.link(
new EvalMultiClassBatchOp()
.setLabelCol(LABEL_COL_NAME)
.setPredictionCol(PREDICTION_COL_NAME)
.lazyPrintMetrics("KnnClassifier - 3 - EUCLIDEAN")
);
BatchOperator.execute();
new KnnClassifier()
.setDistanceType(DistanceType.COSINE)
.setK(3)
.setVectorCol(VECTOR_COL_NAME)
.setLabelCol(LABEL_COL_NAME)
.setPredictionCol(PREDICTION_COL_NAME)
.fit(train_data)
.transform(test_data)
.link(
new EvalMultiClassBatchOp()
.setLabelCol(LABEL_COL_NAME)
.setPredictionCol(PREDICTION_COL_NAME)
.lazyPrintMetrics("KnnClassifier - 3 - COSINE")
);
BatchOperator.execute();
new KnnClassifier()
.setK(7)
.setVectorCol(VECTOR_COL_NAME)
.setLabelCol(LABEL_COL_NAME)
.setPredictionCol(PREDICTION_COL_NAME)
.fit(train_data)
.transform(test_data)
.link(
new EvalMultiClassBatchOp()
.setLabelCol(LABEL_COL_NAME)
.setPredictionCol(PREDICTION_COL_NAME)
.lazyPrintMetrics("KnnClassifier - 7 - EUCLIDEAN")
);
BatchOperator.execute();
}
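/**
 * A custom batch source that reads a pair of gzipped MNIST IDX files directly.
 * Each image is emitted as a serialized vector (sparse or dense, depending on
 * the isSparse flag) in the "vec" column, with its digit label in "label".
 */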
public static class MnistGzFileSourceBatchOp extends BaseSourceBatchOp<MnistGzFileSourceBatchOp> {
private final String imageGzFile;
private final String labelGzFile;
private final boolean isSparse;
public MnistGzFileSourceBatchOp(String imageGzFile, String labelGzFile, boolean isSparse) {
super(null, null);
this.imageGzFile = imageGzFile;
this.labelGzFile = labelGzFile;
this.isSparse = isSparse;
}
@Override
protected Table initializeDataSource() {
try {
ArrayList<Row> rows = new ArrayList<>();
String[] images = getImages();
Integer[] labels = getLabels();
int n = images.length;
if (labels.length != n) {
throw new RuntimeException("The size of images IS NOT EQUAL WITH the size of labels.");
}
for (int i = 0; i < n; i++) {
rows.add(Row.of(images[i], labels[i]));
}
return new MemSourceBatchOp(rows, new String[]{"vec", "label"}).getOutputTable();
} catch (Exception ex) {
    ex.printStackTrace();
    // Wrap with the original exception as the cause so the stack trace is preserved.
    throw new RuntimeException(ex);
}
}
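// The IDX file headers store 32-bit integers in big-endian byte order.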
private int getInteger(byte[] bytes) {
return ((bytes[0] & 0xFF) << 24) + ((bytes[1] & 0xFF) << 16) +
((bytes[2] & 0xFF) << 8) + (bytes[3] & 0xFF);
}
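// Label file (IDX1) layout: a 4-byte magic number, a 4-byte record count, then
// one unsigned byte (0-9) per label. The header reads assume a well-formed file.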
private Integer[] getLabels() throws IOException {
BufferedInputStream bis = new BufferedInputStream(
new GZIPInputStream(new FileInputStream(this.labelGzFile)));
byte[] bytes = new byte[4];
bis.read(bytes, 0, 4);
int magic_number = getInteger(bytes);
bis.read(bytes, 0, 4);
int record_number = getInteger(bytes);
Integer[] labels = new Integer[record_number];
for (int i = 0; i < record_number; i++) {
labels[i] = bis.read();
}
bis.close();
return labels;
}
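// Image file (IDX3) layout: a 4-byte magic number, a 4-byte record count, the
// row and column counts (28 x 28 for MNIST), then one unsigned byte (0-255)
// per pixel in row-major order.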
private String[] getImages() throws IOException {
BufferedInputStream bis = new BufferedInputStream(
new GZIPInputStream(new FileInputStream(this.imageGzFile)));
byte[] bytes = new byte[4];
bis.read(bytes, 0, 4);
int magic_number = getInteger(bytes);
bis.read(bytes, 0, 4);
int record_number = getInteger(bytes);
bis.read(bytes, 0, 4);
int xPixels = getInteger(bytes);
bis.read(bytes, 0, 4);
int yPixels = getInteger(bytes);
int nPixels = xPixels * yPixels;
String[] images = new String[record_number];
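// Most MNIST pixels are zero, so the sparse encoding stores only the nonzero ones.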
if (isSparse) {
TreeMap<Integer, Double> pixels = new TreeMap<>();
int val;
for (int i = 0; i < record_number; i++) {
pixels.clear();
for (int j = 0; j < nPixels; j++) {
val = bis.read();
if (0 != val) {
pixels.put(j, (double) val);
}
}
images[i] = VectorUtil.serialize(new SparseVector(nPixels, pixels));
}
} else {
double[] image = new double[nPixels];
for (int i = 0; i < record_number; i++) {
for (int j = 0; j < nPixels; j++) {
image[j] = bis.read();
}
images[i] = VectorUtil.serialize(new DenseVector(image));
}
}
bis.close();
return images;
}
}
}