Alink教程(Java版)

第14章 在线学习

本章包括下面各节:
14.1 整体流程
14.2 数据准备
14.3 特征工程
14.4 特征工程处理数据
14.5 在线训练
14.6 模型过滤

详细内容请阅读纸质书《Alink权威指南:基于Flink的机器学习实例入门(Java)》,这里为本章对应的示例代码。

package com.alibaba.alink;

import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.classification.LogisticRegressionTrainBatchOp;
import com.alibaba.alink.operator.batch.sink.AkSinkBatchOp;
import com.alibaba.alink.operator.batch.source.AkSourceBatchOp;
import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
import com.alibaba.alink.operator.batch.source.TextSourceBatchOp;
import com.alibaba.alink.operator.stream.StreamOperator;
import com.alibaba.alink.operator.stream.dataproc.JsonValueStreamOp;
import com.alibaba.alink.operator.stream.dataproc.SplitStreamOp;
import com.alibaba.alink.operator.stream.evaluation.EvalBinaryClassStreamOp;
import com.alibaba.alink.operator.stream.onlinelearning.FtrlModelFilterStreamOp;
import com.alibaba.alink.operator.stream.onlinelearning.FtrlPredictStreamOp;
import com.alibaba.alink.operator.stream.onlinelearning.FtrlTrainStreamOp;
import com.alibaba.alink.operator.stream.source.CsvSourceStreamOp;
import com.alibaba.alink.pipeline.Pipeline;
import com.alibaba.alink.pipeline.PipelineModel;
import com.alibaba.alink.pipeline.dataproc.StandardScaler;
import com.alibaba.alink.pipeline.feature.FeatureHasher;
import org.apache.commons.lang3.ArrayUtils;

import java.io.File;

public class Chap14 {

	private static final String DATA_DIR = Utils.ROOT_DIR + "ctr_avazu" + File.separator;

	static final String SCHEMA_STRING
		= "id string, click string, dt string, C1 string, banner_pos int, site_id string, site_domain string, "
		+ "site_category string, app_id string, app_domain string, app_category string, device_id string, "
		+ "device_ip string, device_model string, device_type string, device_conn_type string, C14 int, C15 int, "
		+ "C16 int, C17 int, C18 int, C19 int, C20 int, C21 int";

	static final String[] CATEGORY_COL_NAMES = new String[] {
		"C1", "banner_pos", "site_category", "app_domain",
		"app_category", "device_type", "device_conn_type",
		"site_id", "site_domain", "device_id", "device_model"};

	static final String[] NUMERICAL_COL_NAMES = new String[] {
		"C14", "C15", "C16", "C17", "C18", "C19", "C20", "C21"};

	static final String FEATURE_MODEL_FILE = "feature_model.ak";
	static final String INIT_MODEL_FILE = "init_model.ak";

	static final String LABEL_COL_NAME = "click";
	static final String VEC_COL_NAME = "vec";
	static final String PREDICTION_COL_NAME = "pred";
	static final String PRED_DETAIL_COL_NAME = "pred_info";

	static final int NUM_HASH_FEATURES = 30000;

	public static void main(String[] args) throws Exception {

		c_2();

		c_3();

		c_4();

		c_5();

		c_6();

	}

	static void c_2() throws Exception {

		new TextSourceBatchOp()
			.setFilePath("http://alink-release.oss-cn-beijing.aliyuncs.com/data-files/avazu-small.csv")
			.firstN(10)
			.print();

		CsvSourceBatchOp trainBatchData = new CsvSourceBatchOp()
			.setFilePath("http://alink-release.oss-cn-beijing.aliyuncs.com/data-files/avazu-small.csv")
			.setSchemaStr(SCHEMA_STRING);

		trainBatchData.firstN(10).print();

	}

	static void c_3() throws Exception {
		CsvSourceBatchOp trainBatchData = new CsvSourceBatchOp()
			.setFilePath("http://alink-release.oss-cn-beijing.aliyuncs.com/data-files/avazu-small.csv")
			.setSchemaStr(SCHEMA_STRING);

		// setup feature enginerring pipeline
		Pipeline feature_pipeline = new Pipeline()
			.add(
				new StandardScaler()
					.setSelectedCols(NUMERICAL_COL_NAMES)
			)
			.add(
				new FeatureHasher()
					.setSelectedCols(ArrayUtils.addAll(CATEGORY_COL_NAMES, NUMERICAL_COL_NAMES))
					.setCategoricalCols(CATEGORY_COL_NAMES)
					.setOutputCol(VEC_COL_NAME)
					.setNumFeatures(NUM_HASH_FEATURES)
			);

		if (!new File(DATA_DIR + FEATURE_MODEL_FILE).exists()) {
			// fit and save feature pipeline model
			feature_pipeline
				.fit(trainBatchData)
				.save(DATA_DIR + FEATURE_MODEL_FILE);
			BatchOperator.execute();
		}
	}

	static void c_4() throws Exception {
		// load pipeline model
		PipelineModel feature_pipelineModel = PipelineModel.load(DATA_DIR + FEATURE_MODEL_FILE);

		// prepare stream train data
		CsvSourceStreamOp data = new CsvSourceStreamOp()
			.setFilePath("http://alink-release.oss-cn-beijing.aliyuncs.com/data-files/avazu-ctr-train-8M.csv")
			.setSchemaStr(SCHEMA_STRING);

		if (!new File(DATA_DIR + INIT_MODEL_FILE).exists()) {
			CsvSourceBatchOp trainBatchData = new CsvSourceBatchOp()
				.setFilePath("http://alink-release.oss-cn-beijing.aliyuncs.com/data-files/avazu-small.csv")
				.setSchemaStr(SCHEMA_STRING);

			// train initial batch model
			LogisticRegressionTrainBatchOp lr = new LogisticRegressionTrainBatchOp()
				.setVectorCol(VEC_COL_NAME)
				.setLabelCol(LABEL_COL_NAME)
				.setWithIntercept(true)
				.setMaxIter(10);

			feature_pipelineModel
				.transform(trainBatchData)
				.link(lr)
				.link(
					new AkSinkBatchOp().setFilePath(DATA_DIR + INIT_MODEL_FILE)
				);
			BatchOperator.execute();
		}

	}

	static void c_5() throws Exception {

		// load pipeline model
		PipelineModel feature_pipelineModel = PipelineModel.load(DATA_DIR + FEATURE_MODEL_FILE);

		BatchOperator initModel = new AkSourceBatchOp().setFilePath(DATA_DIR + INIT_MODEL_FILE);

		// prepare stream train data
		CsvSourceStreamOp data = new CsvSourceStreamOp()
			.setFilePath("http://alink-release.oss-cn-beijing.aliyuncs.com/data-files/avazu-ctr-train-8M.csv")
			.setSchemaStr(SCHEMA_STRING)
			.setIgnoreFirstLine(true);

		// split stream to train and eval data
		SplitStreamOp spliter = new SplitStreamOp().setFraction(0.5).linkFrom(data);
		StreamOperator train_stream_data = feature_pipelineModel.transform(spliter);
		StreamOperator test_stream_data = feature_pipelineModel.transform(spliter.getSideOutput(0));

		// ftrl train
		FtrlTrainStreamOp model = new FtrlTrainStreamOp(initModel)
			.setVectorCol(VEC_COL_NAME)
			.setLabelCol(LABEL_COL_NAME)
			.setWithIntercept(true)
			.setAlpha(0.1)
			.setBeta(0.1)
			.setL1(0.01)
			.setL2(0.01)
			.setTimeInterval(10)
			.setVectorSize(NUM_HASH_FEATURES)
			.linkFrom(train_stream_data);

		// ftrl predict
		FtrlPredictStreamOp predResult = new FtrlPredictStreamOp(initModel)
			.setVectorCol(VEC_COL_NAME)
			.setPredictionCol(PREDICTION_COL_NAME)
			.setReservedCols(new String[] {LABEL_COL_NAME})
			.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
			.linkFrom(model, test_stream_data);

		predResult
			.sample(0.0001)
			.select("'Pred Sample' AS out_type, *")
			.print();

		// ftrl eval
		predResult
			.link(
				new EvalBinaryClassStreamOp()
					.setLabelCol(LABEL_COL_NAME)
					.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
					.setTimeInterval(10)
			)
			.link(
				new JsonValueStreamOp()
					.setSelectedCol("Data")
					.setReservedCols(new String[] {"Statistics"})
					.setOutputCols(new String[] {"Accuracy", "AUC", "ConfusionMatrix"})
					.setJsonPath(new String[] {"$.Accuracy", "$.AUC", "$.ConfusionMatrix"})
			)
			.select("'Eval Metric' AS out_type, *")
			.print();

		StreamOperator.execute();

	}

	static void c_6() throws Exception {

		// prepare stream train data
		CsvSourceStreamOp data = new CsvSourceStreamOp()
			.setFilePath("http://alink-release.oss-cn-beijing.aliyuncs.com/data-files/avazu-ctr-train-8M.csv")
			.setSchemaStr(SCHEMA_STRING)
			.setIgnoreFirstLine(true);

		// load pipeline model
		PipelineModel feature_pipelineModel = PipelineModel.load(DATA_DIR + FEATURE_MODEL_FILE);

		// split stream to train and eval data
		SplitStreamOp spliter = new SplitStreamOp().setFraction(0.5).linkFrom(data);
		StreamOperator <?> train_stream_data = feature_pipelineModel.transform(spliter);
		StreamOperator <?> test_stream_data = feature_pipelineModel.transform(spliter.getSideOutput(0));

		AkSourceBatchOp initModel = new AkSourceBatchOp().setFilePath(DATA_DIR + INIT_MODEL_FILE);

		// ftrl train
		FtrlTrainStreamOp model = new FtrlTrainStreamOp(initModel)
			.setVectorCol(VEC_COL_NAME)
			.setLabelCol(LABEL_COL_NAME)
			.setWithIntercept(true)
			.setAlpha(0.1)
			.setBeta(0.1)
			.setL1(0.01)
			.setL2(0.01)
			.setTimeInterval(10)
			.setVectorSize(NUM_HASH_FEATURES)
			.linkFrom(train_stream_data);

		// model filter
		FtrlModelFilterStreamOp model_filter = new FtrlModelFilterStreamOp()
			.setPositiveLabelValueString("1")
			.setVectorCol(VEC_COL_NAME)
			.setLabelCol(LABEL_COL_NAME)
			.setAccuracyThreshold(0.83)
			.setAucThreshold(0.71)
			.linkFrom(model, train_stream_data);

		model_filter
			.select("'Model' AS out_type, *")
			.print();

		// ftrl predict
		FtrlPredictStreamOp predResult = new FtrlPredictStreamOp(initModel)
			.setVectorCol(VEC_COL_NAME)
			.setPredictionCol(PREDICTION_COL_NAME)
			.setReservedCols(new String[] {LABEL_COL_NAME})
			.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
			.linkFrom(model_filter, test_stream_data);

		predResult
			.sample(0.0001)
			.select("'Pred Sample' AS out_type, *")
			.print();

		// ftrl eval
		predResult
			.link(
				new EvalBinaryClassStreamOp()
					.setPositiveLabelValueString("1")
					.setLabelCol(LABEL_COL_NAME)
					.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
					.setTimeInterval(10)
			)
			.link(
				new JsonValueStreamOp()
					.setSelectedCol("Data")
					.setReservedCols(new String[] {"Statistics"})
					.setOutputCols(new String[] {"Accuracy", "AUC", "ConfusionMatrix"})
					.setJsonPath(new String[] {"$.Accuracy", "$.AUC", "$.ConfusionMatrix"})
			)
			.select("'Eval Metric' AS out_type, *")
			.print();

		StreamOperator.execute();

	}

}