package benchmark.online; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.table.functions.TableFunction; import org.apache.flink.types.Row; import com.alibaba.alink.common.AlinkGlobalConfiguration; import com.alibaba.alink.common.AlinkTypes; import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.classification.LogisticRegressionTrainBatchOp; import com.alibaba.alink.operator.batch.sink.AkSinkBatchOp; import com.alibaba.alink.operator.batch.sink.AppendModelStreamFileSinkBatchOp; import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp; import com.alibaba.alink.operator.stream.StreamOperator; import com.alibaba.alink.operator.stream.classification.LogisticRegressionPredictStreamOp; import com.alibaba.alink.operator.stream.dataproc.JsonValueStreamOp; import com.alibaba.alink.operator.stream.dataproc.SplitStreamOp; import com.alibaba.alink.operator.stream.evaluation.EvalBinaryClassStreamOp; import com.alibaba.alink.operator.stream.onlinelearning.FtrlTrainStreamOp; import com.alibaba.alink.operator.stream.sink.ModelStreamFileSinkStreamOp; import com.alibaba.alink.operator.stream.source.CsvSourceStreamOp; import com.alibaba.alink.pipeline.LocalPredictor; import com.alibaba.alink.pipeline.Pipeline; import com.alibaba.alink.pipeline.PipelineModel; import com.alibaba.alink.pipeline.classification.LogisticRegression; import com.alibaba.alink.pipeline.dataproc.StandardScaler; import com.alibaba.alink.pipeline.feature.FeatureHasher; import org.apache.commons.lang3.ArrayUtils; import org.junit.Test; /** * https://www.kaggle.com/c/avazu-ctr-prediction/data */ public class FtrlTest { private static final String[] ORIGIN_COL_NAMES = new String[] { "id", "click", "dt", "C1", "banner_pos", 
"site_id", "site_domain", "site_category", "app_id", "app_domain", "app_category", "device_id", "device_ip", "device_model", "device_type", "device_conn_type", "C14", "C15", "C16", "C17", "C18", "C19", "C20", "C21" }; private static final String[] ORIGIN_COL_TYPES = new String[] { "string", "string", "string", "string", "int", "string", "string", "string", "string", "string", "string", "string", "string", "string", "string", "string", "int", "int", "int", "int", "int", "int", "int", "int" }; private static final String[] COL_NAMES = new String[] { "id", "click", "dt_year", "dt_month", "dt_day", "dt_hour", "C1", "banner_pos", "site_id", "site_domain", "site_category", "app_id", "app_domain", "app_category", "device_id", "device_ip", "device_model", "device_type", "device_conn_type", "C14", "C15", "C16", "C17", "C18", "C19", "C20", "C21" }; private static final String DATA_DIR = "https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/"; private static final String SMALL_FILE = "avazu-small.csv"; private static final String LARGE_FILE = "avazu-ctr-train-8M.csv"; private static final String FEATURE_PIPELINE_MODEL_FILE = "/tmp/feature_model.csv"; private static final String labelColName = "click"; private static final String vecColName = "vec"; static final String[] FEATURE_COL_NAMES = ArrayUtils.removeElements(COL_NAMES, labelColName, "id", "dt_year", "dt_month", "site_id", "site_domain", "app_id", "device_id", "device_ip", "device_model"); static final String[] HIGH_FREQ_FEATURE_COL_NAMES = new String[] {"site_id", "site_domain", "device_id", "device_model"}; static final String[] CATEGORY_FEATURE_COL_NAMES = new String[] { "C1", "banner_pos", "site_category", "app_domain", "app_category", "device_type", "device_conn_type" }; static final String[] NUMERICAL_FEATURE_COL_NAMES = ArrayUtils.removeElements(FEATURE_COL_NAMES, CATEGORY_FEATURE_COL_NAMES); @Test public void trainFeatureModel() throws Exception { 
MLEnvironmentFactory.getDefault().getStreamExecutionEnvironment().setParallelism(4); MLEnvironmentFactory.getDefault().getExecutionEnvironment().setParallelism(1); int numHashFeatures = 30000; Pipeline feature_pipeline = new Pipeline() .add( new StandardScaler() .setSelectedCols(NUMERICAL_FEATURE_COL_NAMES) ) .add( new FeatureHasher() .setSelectedCols(ArrayUtils.addAll(FEATURE_COL_NAMES, HIGH_FREQ_FEATURE_COL_NAMES)) .setCategoricalCols(ArrayUtils.addAll(CATEGORY_FEATURE_COL_NAMES, HIGH_FREQ_FEATURE_COL_NAMES)) .setOutputCol(vecColName) .setNumFeatures(numHashFeatures).setReservedCols("click") ); feature_pipeline.fit(getSmallBatchSet()).save(FEATURE_PIPELINE_MODEL_FILE, true); BatchOperator.execute(); } @Test public void onlineTrainAndEval() throws Exception { PipelineModel featurePipelineModel = PipelineModel.load(FEATURE_PIPELINE_MODEL_FILE); AlinkGlobalConfiguration.setPrintProcessInfo(true); Tuple2 <StreamOperator, StreamOperator> sources = getStreamTrainTestData(); StreamOperator <?> trainStream = sources.f0; StreamOperator <?> testStream = sources.f1; trainStream = featurePipelineModel.transform(trainStream); testStream = featurePipelineModel.transform(testStream); BatchOperator <?> trainBatch = featurePipelineModel.transform(getSmallBatchSet()); StreamOperator.setParallelism(2); BatchOperator <?> model = new LogisticRegressionTrainBatchOp() .setVectorCol(vecColName) .setLabelCol(labelColName) .setWithIntercept(true) .linkFrom(trainBatch); StreamOperator <?> models = new FtrlTrainStreamOp(model) .setVectorCol(vecColName) .setLabelCol(labelColName) .setMiniBatchSize(1024) .setTimeInterval(10) .setWithIntercept(true) .setModelStreamFilePath("/tmp/avazu_fm_models") .linkFrom(trainStream); StreamOperator <?> predictResults = new LogisticRegressionPredictStreamOp(model) .setPredictionCol("predict") .setReservedCols(labelColName) .setPredictionDetailCol("details") .linkFrom(testStream, models); new EvalBinaryClassStreamOp() 
.setPredictionDetailCol("details").setLabelCol(labelColName).setTimeInterval(10).linkFrom(predictResults) .link(new JsonValueStreamOp().setSelectedCol("Data") .setReservedCols(new String[] {"Statistics"}) .setOutputCols(new String[] {"Accuracy", "AUC", "ConfusionMatrix"}) .setJsonPath("$.Accuracy", "$.AUC", "ConfusionMatrix")).print(); StreamOperator.execute(); } @Test public void onlineTrainAndSave() throws Exception { PipelineModel featurePipelineModel = PipelineModel.load(FEATURE_PIPELINE_MODEL_FILE); AlinkGlobalConfiguration.setPrintProcessInfo(true); Tuple2 <StreamOperator, StreamOperator> sources = getStreamTrainTestData(); StreamOperator <?> trainStream = sources.f0; trainStream = featurePipelineModel.transform(trainStream); BatchOperator <?> trainBatch = featurePipelineModel.transform(getSmallBatchSet()); StreamOperator.setParallelism(2); BatchOperator <?> model = new LogisticRegressionTrainBatchOp() .setVectorCol(vecColName) .setLabelCol(labelColName) .setWithIntercept(true) .linkFrom(trainBatch); StreamOperator <?> models = new FtrlTrainStreamOp(model) .setVectorCol(vecColName) .setLabelCol(labelColName) .setMiniBatchSize(1024) .setTimeInterval(10) .setWithIntercept(true) .setModelStreamFilePath("/tmp/rebase_ftrl_models") .linkFrom(trainStream); models.link(new ModelStreamFileSinkStreamOp().setFilePath("/tmp/ftrl_models")); StreamOperator.execute(); } @Test public void BatchTrainAndSaveRebaseModel() throws Exception { PipelineModel featurePipelineModel = PipelineModel.load(FEATURE_PIPELINE_MODEL_FILE); BatchOperator <?> trainBatch = featurePipelineModel.transform(getSmallBatchSet()); StreamOperator.setParallelism(2); BatchOperator <?> model1 = new LogisticRegressionTrainBatchOp() .setVectorCol(vecColName) .setLabelCol(labelColName) .setWithIntercept(true) .linkFrom(trainBatch); model1.link(new AppendModelStreamFileSinkBatchOp().setFilePath("/tmp/rebase_ftrl_models")); BatchOperator.execute(); } @Test public void savePipelineModel() throws Exception { 
BatchOperator <?> trainBatch = getSmallBatchSet(); int numHashFeatures = 30000; PipelineModel pipelineModel = new Pipeline() .add( new StandardScaler() .setSelectedCols(NUMERICAL_FEATURE_COL_NAMES) ) .add( new FeatureHasher() .setSelectedCols(ArrayUtils.addAll(FEATURE_COL_NAMES, HIGH_FREQ_FEATURE_COL_NAMES)) .setCategoricalCols(ArrayUtils.addAll(CATEGORY_FEATURE_COL_NAMES, HIGH_FREQ_FEATURE_COL_NAMES)) .setOutputCol(vecColName) .setNumFeatures(numHashFeatures).setReservedCols("click") ).add( new LogisticRegression() .setVectorCol("vec") .setLabelCol("click") .setPredictionCol("pred") .setModelStreamFilePath("/tmp/ftrl_models") .setPredictionDetailCol("detail") .setMaxIter(10)) .fit(trainBatch); pipelineModel.save().link(new AkSinkBatchOp().setOverwriteSink(true).setFilePath("/tmp/lr_pipeline.ak")); BatchOperator.execute(); } @Test public void localPredictor() throws Exception { LocalPredictor predictor = new LocalPredictor("/tmp/lr_pipeline.ak", TableUtil.schema2SchemaStr(getSmallBatchSet().getSchema())); System.out.println(TableUtil.schema2SchemaStr(getSmallBatchSet().getSchema())); for (int i = 0; i < Integer.MAX_VALUE; ++i) { System.out.println(predictor.map( Row.of("220869541682524752", "0", 14, 10, 21, 2, "1005", 0, "1fbe01fe", "f3845767", "28905ebd", "ecad2386", "7801e8d9", "07d7df22", "a99f214a", "af1c0727", "a0f5f879", "1", "0", 15703, 320, 50, 1722, 0, 35, -1, 79))); Thread.sleep(5000); } } public static class SplitDataTime extends TableFunction <Row> { private Integer parseInt(String s) { if ('0' == s.charAt(0)) { return Integer.parseInt(s.substring(1)); } else { return Integer.parseInt(s); } } public void eval(String str) { collect(Row.of( parseInt(str.substring(0, 2)), parseInt(str.substring(2, 4)), parseInt(str.substring(4, 6)), parseInt(str.substring(6, 8)) )); } @Override public TypeInformation <Row> getResultType() { return new RowTypeInfo(AlinkTypes.INT, AlinkTypes.INT, AlinkTypes.INT, AlinkTypes.INT); } } private static Tuple2 <StreamOperator, 
StreamOperator> getStreamTrainTestData() { StringBuilder sbd = new StringBuilder(); for (int i = 0; i < ORIGIN_COL_NAMES.length; i++) { if (i > 0) { sbd.append(","); } sbd.append(ORIGIN_COL_NAMES[i]).append(" ").append(ORIGIN_COL_TYPES[i]); } StreamOperator <?> source = new CsvSourceStreamOp() .setFilePath(DATA_DIR + FtrlRebaseTest.LARGE_FILE) .setSchemaStr(sbd.toString()) .setIgnoreFirstLine(true) .udtf("dt", new String[] {"dt_year", "dt_month", "dt_day", "dt_hour"}, new SplitDataTime()) .select(COL_NAMES); SplitStreamOp splitter = new SplitStreamOp().setFraction(0.5); source.link(splitter); return new Tuple2 <>(splitter, splitter.getSideOutput(0)); } private static BatchOperator <?> getSmallBatchSet() { StringBuilder sbd = new StringBuilder(); for (int i = 0; i < ORIGIN_COL_NAMES.length; i++) { if (i > 0) { sbd.append(","); } sbd.append(ORIGIN_COL_NAMES[i]).append(" ").append(ORIGIN_COL_TYPES[i]); } return new CsvSourceBatchOp() .setFilePath(DATA_DIR + FtrlRebaseTest.SMALL_FILE) .setSchemaStr(sbd.toString()) .setIgnoreFirstLine(true) .udtf("dt", new String[] {"dt_year", "dt_month", "dt_day", "dt_hour"}, new SplitDataTime()) .select(COL_NAMES); } }
该 demo 使用 FTRL 算法对 Avazu 数据集（https://www.kaggle.com/c/avazu-ctr-prediction/data）进行实时训练并生成模型流，同时将模型流实时加载到推理服务中。另外我们还提供了模型 rebase 的示例代码，能够很容易地定时用一个批训练模型重新校准在线模型，防止模型跑偏。最后还提供了模型训练 + 预测 + 评估的完整示例代码。
| 函数 | 任务类型 | 说明 |
| --- | --- | --- |
| trainFeatureModel() | 批任务 | 训练特征工程模型，这个模型将对训练、预测、推理数据进行特征编码 |