Alink教程(Java版)

第20章 超参数搜索

本章包括下面各节:
20.1 示例一:尝试正则系数
20.2 示例二:搜索GBDT超参数
20.3 示例三:最佳聚类个数

详细内容请阅读纸质书《Alink权威指南:基于Flink的机器学习实例入门(Java)》,这里为本章对应的示例代码。

package com.alibaba.alink;

import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.utils.Stopwatch;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.evaluation.EvalBinaryClassBatchOp;
import com.alibaba.alink.operator.batch.evaluation.EvalClusterBatchOp;
import com.alibaba.alink.operator.batch.source.AkSourceBatchOp;
import com.alibaba.alink.operator.common.evaluation.TuningBinaryClassMetric;
import com.alibaba.alink.operator.common.evaluation.TuningClusterMetric;
import com.alibaba.alink.params.shared.clustering.HasKMeansDistanceType.DistanceType;
import com.alibaba.alink.pipeline.Pipeline;
import com.alibaba.alink.pipeline.classification.GbdtClassifier;
import com.alibaba.alink.pipeline.classification.LogisticRegression;
import com.alibaba.alink.pipeline.clustering.KMeans;
import com.alibaba.alink.pipeline.tuning.BinaryClassificationTuningEvaluator;
import com.alibaba.alink.pipeline.tuning.ClusterTuningEvaluator;
import com.alibaba.alink.pipeline.tuning.GridSearchCV;
import com.alibaba.alink.pipeline.tuning.GridSearchCVModel;
import com.alibaba.alink.pipeline.tuning.ParamDist;
import com.alibaba.alink.pipeline.tuning.ParamGrid;
import com.alibaba.alink.pipeline.tuning.RandomSearchTVSplit;
import com.alibaba.alink.pipeline.tuning.RandomSearchTVSplitModel;
import com.alibaba.alink.pipeline.tuning.ValueDist;
import org.apache.commons.lang3.ArrayUtils;

public class Chap20 {

	public static void main(String[] args) throws Exception {

		BatchOperator.setParallelism(1);

		c_1();

		c_2();

		c_3();

	}

	static void c_1() throws Exception {
		BatchOperator <?> train_data =
			new AkSourceBatchOp()
				.setFilePath(Chap10.DATA_DIR + Chap10.TRAIN_FILE)
				.select(Chap10.CLAUSE_CREATE_FEATURES);

		BatchOperator <?> test_data =
			new AkSourceBatchOp()
				.setFilePath(Chap10.DATA_DIR + Chap10.TEST_FILE)
				.select(Chap10.CLAUSE_CREATE_FEATURES);

		final String[] new_features =
			ArrayUtils.removeElement(train_data.getColNames(), Chap10.LABEL_COL_NAME);

		LogisticRegression lr = new LogisticRegression()
			.setFeatureCols(new_features)
			.setLabelCol(Chap10.LABEL_COL_NAME)
			.setPredictionCol(Chap10.PREDICTION_COL_NAME)
			.setPredictionDetailCol(Chap10.PRED_DETAIL_COL_NAME);

		Pipeline pipeline = new Pipeline().add(lr);

		GridSearchCV gridSearch = new GridSearchCV()
			.setNumFolds(5)
			.setEstimator(pipeline)
			.setParamGrid(
				new ParamGrid()
					.addGrid(lr, LogisticRegression.L_1,
						new Double[] {0.0000001, 0.000001, 0.00001, 0.0001, 0.001, 0.01, 0.1, 1.0, 10.0})
			)
			.setTuningEvaluator(
				new BinaryClassificationTuningEvaluator()
					.setLabelCol(Chap10.LABEL_COL_NAME)
					.setPredictionDetailCol(Chap10.PRED_DETAIL_COL_NAME)
					.setTuningBinaryClassMetric(TuningBinaryClassMetric.AUC)
			)
			.enableLazyPrintTrainInfo();

		GridSearchCVModel bestModel = gridSearch.fit(train_data);

		bestModel.transform(test_data)
			.link(
				new EvalBinaryClassBatchOp()
					.setPositiveLabelValueString("2")
					.setLabelCol(Chap10.LABEL_COL_NAME)
					.setPredictionDetailCol(Chap10.PRED_DETAIL_COL_NAME)
					.lazyPrintMetrics("GridSearchCV")
			);

		BatchOperator.execute();

	}

	static void c_2() throws Exception {
		Stopwatch sw = new Stopwatch();
		sw.start();

		AlinkGlobalConfiguration.setPrintProcessInfo(true);

		BatchOperator train_sample =
			new AkSourceBatchOp()
				.setFilePath(Chap11.DATA_DIR + Chap11.TRAIN_SAMPLE_FILE);

		BatchOperator test_data =
			new AkSourceBatchOp()
				.setFilePath(Chap11.DATA_DIR + Chap11.TEST_FILE);

		final String[] featuresColNames =
			ArrayUtils.removeElement(train_sample.getColNames(), Chap11.LABEL_COL_NAME);

		GbdtClassifier gbdt = new GbdtClassifier()
			.setFeatureCols(featuresColNames)
			.setLabelCol(Chap11.LABEL_COL_NAME)
			.setPredictionCol(Chap11.PREDICTION_COL_NAME)
			.setPredictionDetailCol(Chap11.PRED_DETAIL_COL_NAME);

		RandomSearchTVSplit randomSearch = new RandomSearchTVSplit()
			.setNumIter(20)
			.setTrainRatio(0.8)
			.setEstimator(gbdt)
			.setParamDist(
				new ParamDist()
					.addDist(gbdt, GbdtClassifier.NUM_TREES, ValueDist.randArray(new Integer[] {50, 100}))
					.addDist(gbdt, GbdtClassifier.MAX_DEPTH, ValueDist.randInteger(4, 10))
					.addDist(gbdt, GbdtClassifier.MAX_BINS, ValueDist.randArray(new Integer[] {64, 128, 256, 512}))
					.addDist(gbdt, GbdtClassifier.LEARNING_RATE, ValueDist.randArray(new Double[] {0.3, 0.1, 0.01}))
			)
			.setTuningEvaluator(
				new BinaryClassificationTuningEvaluator()
					.setLabelCol(Chap11.LABEL_COL_NAME)
					.setPredictionDetailCol(Chap11.PRED_DETAIL_COL_NAME)
					.setTuningBinaryClassMetric(TuningBinaryClassMetric.F1)
			)
			.enableLazyPrintTrainInfo();

		RandomSearchTVSplitModel bestModel = randomSearch.fit(train_sample);

		bestModel.transform(test_data)
			.link(
				new EvalBinaryClassBatchOp()
					.setPositiveLabelValueString("1")
					.setLabelCol(Chap11.LABEL_COL_NAME)
					.setPredictionDetailCol(Chap11.PRED_DETAIL_COL_NAME)
					.lazyPrintMetrics()
			);

		BatchOperator.execute();

		sw.stop();
		System.out.println(sw.getElapsedTimeSpan());
	}

	static void c_3() throws Exception {
		Stopwatch sw = new Stopwatch();
		sw.start();

		AlinkGlobalConfiguration.setPrintProcessInfo(true);

		AkSourceBatchOp source =
			new AkSourceBatchOp()
				.setFilePath(Chap17.DATA_DIR + Chap17.VECTOR_FILE);

		KMeans kmeans = new KMeans()
			.setVectorCol(Chap17.VECTOR_COL_NAME)
			.setPredictionCol(Chap17.PREDICTION_COL_NAME);

		GridSearchCV cv = new GridSearchCV()
			.setNumFolds(4)
			.setEstimator(kmeans)
			.setParamGrid(
				new ParamGrid()
					.addGrid(kmeans, KMeans.K, new Integer[] {2, 3, 4, 5, 6})
					.addGrid(kmeans, KMeans.DISTANCE_TYPE,
						new DistanceType[] {DistanceType.EUCLIDEAN, DistanceType.COSINE})
			)
			.setTuningEvaluator(
				new ClusterTuningEvaluator()
					.setVectorCol(Chap17.VECTOR_COL_NAME)
					.setPredictionCol(Chap17.PREDICTION_COL_NAME)
					.setLabelCol(Chap17.LABEL_COL_NAME)
					.setTuningClusterMetric(TuningClusterMetric.RI)
			)
			.enableLazyPrintTrainInfo();

		GridSearchCVModel bestModel = cv.fit(source);

		bestModel
			.transform(source)
			.link(
				new EvalClusterBatchOp()
					.setLabelCol(Chap17.LABEL_COL_NAME)
					.setVectorCol(Chap17.VECTOR_COL_NAME)
					.setPredictionCol(Chap17.PREDICTION_COL_NAME)
					.lazyPrintMetrics()
			);

		BatchOperator.execute();

		sw.stop();
		System.out.println(sw.getElapsedTimeSpan());
	}

}