Alink教程(Java版)
Alink教程(Python版)

第14章 在线学习 Ftrl Demo

Demo code

package benchmark.online;

import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.table.functions.TableFunction;
import org.apache.flink.types.Row;

import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.AlinkTypes;
import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.classification.LogisticRegressionTrainBatchOp;
import com.alibaba.alink.operator.batch.sink.AkSinkBatchOp;
import com.alibaba.alink.operator.batch.sink.AppendModelStreamFileSinkBatchOp;
import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
import com.alibaba.alink.operator.stream.StreamOperator;
import com.alibaba.alink.operator.stream.classification.LogisticRegressionPredictStreamOp;
import com.alibaba.alink.operator.stream.dataproc.JsonValueStreamOp;
import com.alibaba.alink.operator.stream.dataproc.SplitStreamOp;
import com.alibaba.alink.operator.stream.evaluation.EvalBinaryClassStreamOp;
import com.alibaba.alink.operator.stream.onlinelearning.FtrlTrainStreamOp;
import com.alibaba.alink.operator.stream.sink.ModelStreamFileSinkStreamOp;
import com.alibaba.alink.operator.stream.source.CsvSourceStreamOp;
import com.alibaba.alink.pipeline.LocalPredictor;
import com.alibaba.alink.pipeline.Pipeline;
import com.alibaba.alink.pipeline.PipelineModel;
import com.alibaba.alink.pipeline.classification.LogisticRegression;
import com.alibaba.alink.pipeline.dataproc.StandardScaler;
import com.alibaba.alink.pipeline.feature.FeatureHasher;
import org.apache.commons.lang3.ArrayUtils;
import org.junit.Test;

/**
 * FTRL online-learning demo on the Avazu click-through-rate data set
 * (https://www.kaggle.com/c/avazu-ctr-prediction/data).
 *
 * <p>Every {@code @Test} method is a separate job. {@link #trainFeatureModel()} writes the feature
 * pipeline model that {@link #onlineTrainAndEval()}, {@link #onlineTrainAndSave()} and
 * {@link #BatchTrainAndSaveRebaseModel()} load, so it has to be run before them.
 * {@link #savePipelineModel()} writes the pipeline model that {@link #localPredictor()} loads.
 */
public class FtrlTest {

	/** Column names of the raw CSV file, in file order. */
	private static final String[] ORIGIN_COL_NAMES = new String[] {
		"id", "click", "dt", "C1", "banner_pos",
		"site_id", "site_domain", "site_category", "app_id", "app_domain",
		"app_category", "device_id", "device_ip", "device_model", "device_type",
		"device_conn_type", "C14", "C15", "C16", "C17",
		"C18", "C19", "C20", "C21"
	};

	/** Column types of the raw CSV file; entry {@code i} is the type of {@code ORIGIN_COL_NAMES[i]}. */
	private static final String[] ORIGIN_COL_TYPES = new String[] {
		"string", "string", "string", "string", "int",
		"string", "string", "string", "string", "string",
		"string", "string", "string", "string", "string",
		"string", "int", "int", "int", "int",
		"int", "int", "int", "int"
	};

	/** Names of the four columns that {@link SplitDataTime} produces from the "dt" column. */
	private static final String[] DT_PART_COL_NAMES = new String[] {"dt_year", "dt_month", "dt_day", "dt_hour"};

	/** Column layout handed to the pipelines: the raw columns with "dt" replaced by its four parts. */
	private static final String[] COL_NAMES = new String[] {
		"id", "click",
		"dt_year", "dt_month", "dt_day", "dt_hour",
		"C1", "banner_pos",
		"site_id", "site_domain", "site_category", "app_id", "app_domain",
		"app_category", "device_id", "device_ip", "device_model", "device_type",
		"device_conn_type", "C14", "C15", "C16", "C17",
		"C18", "C19", "C20", "C21"
	};

	private static final String DATA_DIR = "https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/";
	private static final String SMALL_FILE = "avazu-small.csv";
	private static final String LARGE_FILE = "avazu-ctr-train-8M.csv";

	/** Where {@link #trainFeatureModel()} stores the fitted feature pipeline model. */
	private static final String FEATURE_PIPELINE_MODEL_FILE = "/tmp/feature_model.csv";

	/** Where {@link #savePipelineModel()} stores the model that {@link #localPredictor()} loads. */
	private static final String LR_PIPELINE_MODEL_FILE = "/tmp/lr_pipeline.ak";

	/**
	 * Model-stream directory: written by {@link #onlineTrainAndSave()} and configured on the
	 * LogisticRegression stage in {@link #savePipelineModel()}. Both must use the same path.
	 */
	private static final String MODEL_STREAM_DIR = "/tmp/ftrl_models";

	/**
	 * Rebase directory: written by {@link #BatchTrainAndSaveRebaseModel()} and configured on the FTRL
	 * trainer in {@link #onlineTrainAndSave()}. Both must use the same path.
	 */
	private static final String REBASE_MODEL_DIR = "/tmp/rebase_ftrl_models";

	/** Model-stream path configured on the FTRL trainer in {@link #onlineTrainAndEval()}. */
	private static final String EVAL_MODEL_STREAM_DIR = "/tmp/avazu_fm_models";

	private static final String LABEL_COL_NAME = "click";
	private static final String VEC_COL_NAME = "vec";

	/** Output dimension of the feature hasher. */
	private static final int NUM_HASH_FEATURES = 30000;

	static final String[] FEATURE_COL_NAMES =
		ArrayUtils.removeElements(COL_NAMES, LABEL_COL_NAME, "id", "dt_year", "dt_month",
			"site_id", "site_domain", "app_id", "device_id", "device_ip", "device_model");

	static final String[] HIGH_FREQ_FEATURE_COL_NAMES = new String[] {"site_id", "site_domain", "device_id",
		"device_model"};

	static final String[] CATEGORY_FEATURE_COL_NAMES = new String[] {
		"C1", "banner_pos",
		"site_category", "app_domain",
		"app_category", "device_type",
		"device_conn_type"
	};

	/** Feature columns that are not listed as categorical; these are the ones that get standardized. */
	static final String[] NUMERICAL_FEATURE_COL_NAMES =
		ArrayUtils.removeElements(FEATURE_COL_NAMES, CATEGORY_FEATURE_COL_NAMES);

	/**
	 * Batch job: fits the feature pipeline (standard scaling + feature hashing) on the small data set
	 * and saves it to {@code FEATURE_PIPELINE_MODEL_FILE}, overwriting an existing file.
	 */
	@Test
	public void trainFeatureModel() throws Exception {
		MLEnvironmentFactory.getDefault().getStreamExecutionEnvironment().setParallelism(4);
		MLEnvironmentFactory.getDefault().getExecutionEnvironment().setParallelism(1);

		newFeaturePipeline()
			.fit(getSmallBatchSet())
			.save(FEATURE_PIPELINE_MODEL_FILE, true);
		BatchOperator.execute();
	}

	/**
	 * Stream job: trains with FTRL on one half of the stream, predicts on the other half with the
	 * model stream produced by the trainer, and prints Accuracy, AUC and the confusion matrix taken
	 * from the binary-classification evaluation output.
	 */
	@Test
	public void onlineTrainAndEval() throws Exception {
		PipelineModel featurePipelineModel = PipelineModel.load(FEATURE_PIPELINE_MODEL_FILE);
		AlinkGlobalConfiguration.setPrintProcessInfo(true);
		Tuple2 <StreamOperator <?>, StreamOperator <?>> sources = getStreamTrainTestData();

		StreamOperator <?> trainStream = featurePipelineModel.transform(sources.f0);
		StreamOperator <?> testStream = featurePipelineModel.transform(sources.f1);
		BatchOperator <?> trainBatch = featurePipelineModel.transform(getSmallBatchSet());

		// Kept after the sources/feature transforms and before the trainer, exactly as in the
		// original statement order.
		StreamOperator.setParallelism(2);

		BatchOperator <?> model = trainInitialModel(trainBatch);
		StreamOperator <?> models = trainFtrl(model, trainStream, EVAL_MODEL_STREAM_DIR);

		StreamOperator <?> predictResults = new LogisticRegressionPredictStreamOp(model)
			.setPredictionCol("predict")
			.setReservedCols(LABEL_COL_NAME)
			.setPredictionDetailCol("details")
			.linkFrom(testStream, models);

		new EvalBinaryClassStreamOp()
			.setPredictionDetailCol("details")
			.setLabelCol(LABEL_COL_NAME)
			.setTimeInterval(10)
			.linkFrom(predictResults)
			.link(new JsonValueStreamOp()
				.setSelectedCol("Data")
				.setReservedCols(new String[] {"Statistics"})
				.setOutputCols(new String[] {"Accuracy", "AUC", "ConfusionMatrix"})
				// All three paths are rooted at "$." (the third one used to be a bare key).
				.setJsonPath("$.Accuracy", "$.AUC", "$.ConfusionMatrix"))
			.print();

		StreamOperator.execute();
	}

	/**
	 * Stream job: trains with FTRL, with {@code REBASE_MODEL_DIR} configured as the trainer's
	 * model-stream path, and sinks the produced model stream into {@code MODEL_STREAM_DIR}.
	 */
	@Test
	public void onlineTrainAndSave() throws Exception {
		PipelineModel featurePipelineModel = PipelineModel.load(FEATURE_PIPELINE_MODEL_FILE);
		AlinkGlobalConfiguration.setPrintProcessInfo(true);
		Tuple2 <StreamOperator <?>, StreamOperator <?>> sources = getStreamTrainTestData();

		StreamOperator <?> trainStream = featurePipelineModel.transform(sources.f0);
		BatchOperator <?> trainBatch = featurePipelineModel.transform(getSmallBatchSet());

		StreamOperator.setParallelism(2);

		BatchOperator <?> model = trainInitialModel(trainBatch);
		StreamOperator <?> models = trainFtrl(model, trainStream, REBASE_MODEL_DIR);

		models.link(new ModelStreamFileSinkStreamOp().setFilePath(MODEL_STREAM_DIR));
		StreamOperator.execute();
	}

	/**
	 * Batch job: trains a logistic regression model on the small data set and appends it to
	 * {@code REBASE_MODEL_DIR}.
	 */
	@Test
	public void BatchTrainAndSaveRebaseModel() throws Exception {
		PipelineModel featurePipelineModel = PipelineModel.load(FEATURE_PIPELINE_MODEL_FILE);
		BatchOperator <?> trainBatch = featurePipelineModel.transform(getSmallBatchSet());
		// NOTE(review): this sets the *stream* parallelism although the job is batch-only; kept
		// unchanged from the original - confirm whether it is intentional.
		StreamOperator.setParallelism(2);
		trainInitialModel(trainBatch)
			.link(new AppendModelStreamFileSinkBatchOp().setFilePath(REBASE_MODEL_DIR));
		BatchOperator.execute();
	}

	/**
	 * Batch job: fits feature pipeline + logistic regression as a single pipeline on the small data
	 * set and saves the resulting model to {@code LR_PIPELINE_MODEL_FILE}. The LogisticRegression
	 * stage is configured with {@code MODEL_STREAM_DIR} as its model-stream path.
	 */
	@Test
	public void savePipelineModel() throws Exception {
		BatchOperator <?> trainBatch = getSmallBatchSet();
		PipelineModel pipelineModel = newFeaturePipeline()
			.add(
				new LogisticRegression()
					.setVectorCol(VEC_COL_NAME)
					.setLabelCol(LABEL_COL_NAME)
					.setPredictionCol("pred")
					.setModelStreamFilePath(MODEL_STREAM_DIR)
					.setPredictionDetailCol("detail")
					.setMaxIter(10))
			.fit(trainBatch);

		pipelineModel.save().link(new AkSinkBatchOp().setOverwriteSink(true).setFilePath(LR_PIPELINE_MODEL_FILE));
		BatchOperator.execute();
	}

	/**
	 * Local inference: loads the saved pipeline model into a {@link LocalPredictor} and predicts the
	 * same hard-coded sample every five seconds, printing each result. The loop is effectively
	 * endless; stop the test manually.
	 */
	@Test
	public void localPredictor() throws Exception {
		// Build the input schema once and reuse it for both the predictor and the log line.
		String inputSchemaStr = TableUtil.schema2SchemaStr(getSmallBatchSet().getSchema());
		LocalPredictor predictor = new LocalPredictor(LR_PIPELINE_MODEL_FILE, inputSchemaStr);
		System.out.println(inputSchemaStr);
		for (int i = 0; i < Integer.MAX_VALUE; ++i) {
			// One value per entry of COL_NAMES, in the same order.
			System.out.println(predictor.map(
				Row.of("220869541682524752", "0", 14, 10, 21, 2, "1005", 0, "1fbe01fe", "f3845767",
					"28905ebd", "ecad2386", "7801e8d9", "07d7df22", "a99f214a", "af1c0727", "a0f5f879", "1", "0",
					15703, 320, 50, 1722, 0, 35, -1, 79)));

			Thread.sleep(5000);
		}
	}

	/**
	 * Table function that splits the first eight characters of a string into four two-character
	 * integers and emits them as one row of four INT fields.
	 */
	public static class SplitDataTime extends TableFunction <Row> {

		/** Parses {@code s} as an int, dropping a single leading '0' character first if present. */
		private Integer parseInt(String s) {
			if ('0' == s.charAt(0)) {
				return Integer.parseInt(s.substring(1));
			} else {
				return Integer.parseInt(s);
			}
		}

		/**
		 * Emits one row holding the integers parsed from characters [0,2), [2,4), [4,6) and [6,8)
		 * of {@code str}.
		 *
		 * @param str input string; must have at least eight characters
		 */
		public void eval(String str) {
			collect(Row.of(
				parseInt(str.substring(0, 2)),
				parseInt(str.substring(2, 4)),
				parseInt(str.substring(4, 6)),
				parseInt(str.substring(6, 8))
			));
		}

		@Override
		public TypeInformation <Row> getResultType() {
			return new RowTypeInfo(AlinkTypes.INT, AlinkTypes.INT, AlinkTypes.INT, AlinkTypes.INT);
		}

	}

	/** Builds the CSV schema string "name type,name type,..." from the raw column names and types. */
	private static String originSchemaStr() {
		String[] fields = new String[ORIGIN_COL_NAMES.length];
		for (int i = 0; i < fields.length; i++) {
			fields[i] = ORIGIN_COL_NAMES[i] + " " + ORIGIN_COL_TYPES[i];
		}
		return String.join(",", fields);
	}

	/**
	 * Creates the unfitted feature pipeline shared by {@link #trainFeatureModel()} and
	 * {@link #savePipelineModel()}: standard scaling of the numerical columns followed by feature
	 * hashing into the vector column, reserving only the label column.
	 */
	private static Pipeline newFeaturePipeline() {
		return new Pipeline()
			.add(
				new StandardScaler()
					.setSelectedCols(NUMERICAL_FEATURE_COL_NAMES)
			)
			.add(
				new FeatureHasher()
					.setSelectedCols(ArrayUtils.addAll(FEATURE_COL_NAMES, HIGH_FREQ_FEATURE_COL_NAMES))
					.setCategoricalCols(ArrayUtils.addAll(CATEGORY_FEATURE_COL_NAMES, HIGH_FREQ_FEATURE_COL_NAMES))
					.setOutputCol(VEC_COL_NAME)
					.setNumFeatures(NUM_HASH_FEATURES).setReservedCols(LABEL_COL_NAME)
			);
	}

	/** Trains the batch logistic regression model (with intercept) on already-transformed data. */
	private static BatchOperator <?> trainInitialModel(BatchOperator <?> trainBatch) {
		return new LogisticRegressionTrainBatchOp()
			.setVectorCol(VEC_COL_NAME)
			.setLabelCol(LABEL_COL_NAME)
			.setWithIntercept(true)
			.linkFrom(trainBatch);
	}

	/**
	 * Links an FTRL trainer, constructed from {@code initModel}, to {@code trainStream}.
	 *
	 * @param initModel           batch model passed to the trainer's constructor
	 * @param trainStream         already-transformed training stream
	 * @param modelStreamFilePath value for the trainer's model-stream file path
	 * @return the trainer operator, whose output is used as the model stream by the callers
	 */
	private static StreamOperator <?> trainFtrl(BatchOperator <?> initModel, StreamOperator <?> trainStream,
												String modelStreamFilePath) {
		return new FtrlTrainStreamOp(initModel)
			.setVectorCol(VEC_COL_NAME)
			.setLabelCol(LABEL_COL_NAME)
			.setMiniBatchSize(1024)
			.setTimeInterval(10)
			.setWithIntercept(true)
			.setModelStreamFilePath(modelStreamFilePath)
			.linkFrom(trainStream);
	}

	/**
	 * Reads the large CSV file as a stream, expands "dt" into its four parts and splits the stream
	 * with fraction 0.5.
	 *
	 * @return f0 = main output of the splitter (used for training), f1 = its side output 0 (used for
	 * testing)
	 */
	private static Tuple2 <StreamOperator <?>, StreamOperator <?>> getStreamTrainTestData() {
		StreamOperator <?> source = new CsvSourceStreamOp()
			.setFilePath(DATA_DIR + LARGE_FILE)
			.setSchemaStr(originSchemaStr())
			.setIgnoreFirstLine(true)
			.udtf("dt", DT_PART_COL_NAMES, new SplitDataTime())
			.select(COL_NAMES);

		SplitStreamOp splitter = new SplitStreamOp().setFraction(0.5);

		source.link(splitter);

		return new Tuple2 <>(splitter, splitter.getSideOutput(0));
	}

	/** Reads the small CSV file as a batch source and expands "dt" into its four parts. */
	private static BatchOperator <?> getSmallBatchSet() {
		return new CsvSourceBatchOp()
			.setFilePath(DATA_DIR + SMALL_FILE)
			.setSchemaStr(originSchemaStr())
			.setIgnoreFirstLine(true)
			.udtf("dt", DT_PART_COL_NAMES, new SplitDataTime())
			.select(COL_NAMES);
	}
}

Demo 功能介绍

该demo使用Ftrl 算法对Avazu 数据(https://www.kaggle.com/c/avazu-ctr-prediction/data)进行实时训练并生成模型流,并将模型流实时加载到推理服务中。另外我们还增加了模型rebase 的示例代码,能够很容易的完成用一个批模型定时重新拉回模型,防止模型跑偏。最后还提供了一个模型训练+预测+评估的示例代码。

函数说明

函数

任务类型

说明

trainFeatureModel()

批任务

训练特征工程模型,这个模型将对训练、预测、推理数据进行特征编码

savePipelineModel()

批任务

训练PipelineModel,该模型是部署到线上服务的模型

onlineTrainAndSave()

流任务

使用Ftrl实时训练模型,并定时将模型写出到指定目录

BatchTrainAndSaveRebaseModel()

批任务

训练用来重新初始化的模型,用来拉回模型,防止模型跑偏

localPredictor()

本地推理

本地搭建一个服务,对同一条样本预测,用来验证模型更新

onlineTrainAndEval()

流任务

Ftrl 训练模型,并对模型进行预测评估

执行步骤

  • 首先我们需要执行 trainFeatureModel() 函数,生成特征工程模型,并存储到目录“/tmp/feature_model.csv”,后面的函数都需要该模型
  • 第二步,执行savePipelineModel() 函数,生成部署到线上服务的模型,该模型会通过 setModelStreamFilePath("/tmp/ftrl_models") 设置模型流的目录,设置完再部署这个模型时,在推理的同时会实时监控这个目录,当有新模型产生时会自动加载新模型,用最新模型进行推理。
  • 第三步,执行localPredictor() 函数,将第二步产生的模型部署成本地服务。
  • 第四步,执行onlineTrainAndSave(),用Ftrl算法实时训练在线模型,并以固定频率输出到目录"/tmp/ftrl_models",这个目录与第二步的目录是同一个目录。另外这一步还要设置一个目录setModelStreamFilePath("/tmp/rebase_ftrl_models"),这个目录是用来做模型rebase的,在线学习过程中会实时监控这个目录,当有新的模型出现在这个目录中,会重新加载这个模型作为base模型继续进行训练。
  • 第五步,隔段时间(1小时 or 1天)执行函数BatchTrainAndSaveRebaseModel(),这个函数将训练一个新的批模型,并将其写入到rebase目录"/tmp/rebase_ftrl_models",这样,第四步中的函数就会监测到这个模型,并进行模型 rebase。
  • onlineTrainAndEval() 函数是一个独立的函数,用来评估在线学习算法生成的模型怎么样,并打印评估结果。

备注

  • 特征工程这里我们对数据做标准化和FeatureHash,其实这里可以使用任何其他Alink 的特征工程算法,类似OneHot编码,GBDT编码,多热编码,归一化、分桶算法等
  • 实时加载在线训练的模型是通过一个给定的文件系统的目录("/tmp/ftrl_models")完成的,这个目录可以是本地目录,也可以是网络文件系统的目录,类似OSS 和HDFS等。
  • 模型重新初始化和在线模型的加载是类似的,也是通过一个给定的文件系统的目录("/tmp/rebase_ftrl_models")完成的,同样,这个目录可以是本地目录,也可以是网络文件系统的目录,类似OSS 和HDFS等。