Java class name: com.alibaba.alink.operator.stream.onlinelearning.OnlineLearningStreamOp
Python class name: OnlineLearningStreamOp
This component is an online learning operator built on Pipeline models. It supports several online learning models (LogisticRegression, SVM, LinearReg, FmClassifier, FmRegressor, Softmax) and several online optimization algorithms (FTRL, SGD, ADAM, MOMENTUM, RMSProp, ADAGRAD). In addition, it keeps the feature-engineering model and the online learning model updated in sync by rebasing the Pipeline model.

The framework packages the feature-engineering model and the online learning model into a single Pipeline model. A batch training job first trains this offline Pipeline model, which then serves as the initial model for online learning. During online learning, the latest streaming data continuously updates the online learning model, and the Pipeline model is emitted at a fixed time interval. In each emitted Pipeline model, the feature-engineering model stays unchanged while the online learning model reflects the latest updates. The feature-engineering model itself is refreshed by rebasing the whole Pipeline model: the user periodically schedules a batch job that trains a new Pipeline model and rebases onto it, keeping the feature-engineering stage up to date.
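The core of the online stage, stripped of all Alink machinery, is a per-sample model update. The toy sketch below (plain Python, not Alink's implementation) runs one SGD step of logistic regression for each record arriving on a stream:

```python
import math

def sigmoid(z):
    # clamp the margin to avoid overflow in exp
    return 1.0 / (1.0 + math.exp(-max(min(z, 35.0), -35.0)))

def sgd_step(w, x, y, lr=0.01):
    """One online SGD step for logistic regression.
    w: weight dict, x: sparse sample {feature index: value}, y in {0, 1}."""
    p = sigmoid(sum(w.get(i, 0.0) * v for i, v in x.items()))
    g = p - y  # gradient of the log loss w.r.t. the margin
    for i, v in x.items():
        w[i] = w.get(i, 0.0) - lr * g * v

w = {}
# toy stream: feature 0 is positively correlated with the label
stream = [({0: 1.0}, 1), ({0: -1.0}, 0)] * 200
for x, y in stream:
    sgd_step(w, x, y, lr=0.1)
print(w[0] > 0)  # -> True: the model learned the positive correlation
```

Each optimizer listed above (FTRL, ADAM, MOMENTUM, ...) replaces the plain `w -= lr * g` rule with its own update while keeping this per-sample loop.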
Using this component involves the following steps:

1. Train an offline model on large-scale offline data to serve as the initial model for online learning. Example code:
```python
pipelineModel = Pipeline() \
    .add(StandardScaler() \
        .setSelectedCols(FEATURES)) \
    .add(FeatureHasher() \
        .setNumFeatures(numHashFeatures) \
        .setSelectedCols(FEATURES) \
        .setReservedCols([labelColName]) \
        .setOutputCol(vecColName)) \
    .add(FmClassifier() \
        .setVectorCol(vecColName) \
        .setLabelCol(labelColName) \
        .setWithIntercept(True) \
        .setNumEpochs(2) \
        .setPredictionCol("pred") \
        .setPredictionDetailCol("details")).fit(batchData)
```
Here StandardScaler and FeatureHasher are feature-engineering models; any other feature-engineering model, such as GbdtEncoder, OneHot, or MultiHot, can be used instead. FmClassifier is the online learning model; it can be replaced by SVM, LogisticRegression, LinearReg, FmRegressor, Softmax, etc. For details on these feature-engineering and online learning models, refer to the documentation of the corresponding batch algorithms.
2. Train the online model in real time on streaming training data. Example code:
```python
models = OnlineLearningStreamOp(pipelineModel) \
    .setModelStreamFilePath(REBASE_PATH) \
    .setTimeInterval(10) \
    .setLearningRate(0.01) \
    .setOptimMethod("ADAM") \
    .linkFrom(trainStream)
```
Here, REBASE_PATH is the model-stream directory used for rebasing; step 3 shows how to prepare the rebase model.
3. Train a model on batch data and write it to the rebase model-stream directory. Example code:
```python
pipelineModel = Pipeline() \
    .add(StandardScaler() \
        .setSelectedCols(FEATURES)) \
    .add(FeatureHasher() \
        .setNumFeatures(numHashFeatures) \
        .setSelectedCols(FEATURES) \
        .setReservedCols([labelColName]) \
        .setOutputCol(vecColName)) \
    .add(FmClassifier() \
        .setVectorCol(vecColName) \
        .setLabelCol(labelColName) \
        .setWithIntercept(True) \
        .setNumEpochs(2) \
        .setPredictionCol("pred") \
        .setPredictionDetailCol("details")).fit(batchData)

pipelineModel.save().link(AppendModelStreamFileSinkBatchOp().setFilePath(REBASE_PATH))
BatchOperator.execute()
```
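Conceptually, rebasing works like a directory watch: the batch job appends a new timestamped model to REBASE_PATH, and the online job, scanning every modelStreamScanInterval seconds, picks up the newest model it has not consumed yet and rebases onto it. A minimal sketch (the file naming and layout here are illustrative assumptions, not Alink's actual on-disk format):

```python
import os
import tempfile

def newest_unseen_model(dir_path, seen):
    """Return the newest model file not yet consumed, or None."""
    candidates = sorted(f for f in os.listdir(dir_path)
                        if f.endswith(".model") and f not in seen)
    return candidates[-1] if candidates else None

# Simulate a rebase directory containing two timestamped model files.
rebase_dir = tempfile.mkdtemp()
for ts in ("20240101_000000", "20240102_000000"):
    open(os.path.join(rebase_dir, ts + ".model"), "w").close()

seen = set()
model = newest_unseen_model(rebase_dir, seen)  # -> "20240102_000000.model"
seen.add(model)
```

In Alink, AppendModelStreamFileSinkBatchOp plays the producer role, appending new models into the directory that the online job scans.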
References:

[1] McMahan, H. Brendan, et al. "Ad click prediction: a view from the trenches." Proceedings of the 19th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining. 2013.
[2] Ruder, Sebastian. "An overview of gradient descent optimization algorithms." arXiv preprint arXiv:1609.04747 (2016).
| Name | Chinese Name | Description | Type | Required? | Value Range | Default |
|---|---|---|---|---|---|---|
| alpha | 希腊字母:阿尔法 | Often used as an algorithm-specific parameter. | Double | | | 0.1 |
| beta | 希腊字母:贝塔 | Often used as an algorithm-specific parameter. | Double | | | 1.0 |
| beta1 | beta1 | beta1: parameter for the ADAM optimizer. | Double | | 0.0 <= x <= 1.0 | 0.9 |
| beta2 | beta2 | beta2: parameter for the ADAM optimizer. | Double | | 0.0 <= x <= 1.0 | 0.999 |
| gamma | gamma | gamma: parameter for the RMSProp or MOMENTUM optimizer. | Double | | 0.0 <= x <= 1.0 | 0.9 |
| l1 | L1 正则化系数 | L1 regularization coefficient; defaults to 0.1. | Double | | x >= 0.0 | 0.1 |
| l2 | L2 正则化系数 | L2 regularization coefficient; defaults to 0.1. | Double | | x >= 0.0 | 0.1 |
| learningRate | 学习率 | Learning rate of the optimization algorithm; defaults to 0.1. | Double | | | null |
| optimMethod | 优化方法 | Optimization method used for solving the online learning problem. | String | | "FTRL", "ADAM", "RMSprop", "ADAGRAD", "SGD", "MOMENTUM" | "FTRL" |
| timeInterval | 时间间隔 | Time interval, in seconds, at which models are emitted from the stream. | Integer | | | 1800 |
| modelStreamFilePath | 模型流的文件路径 | File path of the model stream. | String | | | null |
| modelStreamScanInterval | 扫描模型路径的时间间隔 | Interval, in seconds, for scanning the model path. | Integer | | | 10 |
| modelStreamStartTime | 模型流的起始时间 | Start time of the model stream; by default, reading starts from the current time. Uses the yyyy-mm-dd hh:mm:ss.fffffffff format; see Timestamp.valueOf(String s). | String | | | null |
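The alpha, beta, l1, and l2 parameters above are the knobs of the FTRL-Proximal update described in [1]: alpha and beta shape the per-coordinate learning-rate schedule, while l1 and l2 are the regularization terms. A per-coordinate sketch in plain Python (illustrative only, not Alink's implementation):

```python
import math

def ftrl_weight(z_i, n_i, alpha=0.1, beta=1.0, l1=0.1, l2=0.1):
    # L1 regularization clamps coordinates with small accumulated gradients
    # to exactly 0, which is what keeps FTRL models sparse.
    if abs(z_i) <= l1:
        return 0.0
    return -(z_i - math.copysign(l1, z_i)) / ((beta + math.sqrt(n_i)) / alpha + l2)

def ftrl_update(z, n, i, g, alpha=0.1, beta=1.0, l1=0.1, l2=0.1):
    # z, n: per-coordinate accumulators; g: gradient for coordinate i.
    n_i = n.get(i, 0.0)
    sigma = (math.sqrt(n_i + g * g) - math.sqrt(n_i)) / alpha
    w_i = ftrl_weight(z.get(i, 0.0), n_i, alpha, beta, l1, l2)
    z[i] = z.get(i, 0.0) + g - sigma * w_i
    n[i] = n_i + g * g

z, n = {}, {}
ftrl_update(z, n, 0, g=1.0)
print(ftrl_weight(z[0], n[0]))  # a small negative weight
ftrl_update(z, n, 1, g=0.01)
print(ftrl_weight(z[1], n[1]))  # -> 0.0 (pruned by L1)
```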
**The following code is for illustration only; it may require modifications or environment configuration before it runs.**
```python
FEATURE_LABEL = ["c0", "c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "label"]
FEATURES = ["c0", "c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8"]
labelColName = "label"
vecColName = "vec"
REBASE_PATH = "/tmp/rebase"

trainStream = RandomTableSourceStreamOp() \
    .setNumCols(10) \
    .setMaxRows(100000) \
    .setOutputCols(FEATURE_LABEL) \
    .setOutputColConfs("label:weight_set(1.0,1.0,2.0,5.0)") \
    .link(SpeedControlStreamOp().setTimeInterval(0.01))
testStream = trainStream

batchData = RandomTableSourceBatchOp() \
    .setNumCols(10) \
    .setNumRows(1000) \
    .setOutputCols(FEATURE_LABEL) \
    .setOutputColConfs("label:weight_set(1.0,1.0,2.0,5.0)")

pipelineModel = Pipeline() \
    .add(OneHotEncoder() \
        .setSelectedCols(FEATURES) \
        .setOutputCols([vecColName]) \
        .setReservedCols([labelColName])) \
    .add(FmClassifier() \
        .setVectorCol(vecColName) \
        .setLabelCol(labelColName) \
        .setWithIntercept(True) \
        .setNumEpochs(2) \
        .setPredictionCol("pred") \
        .setPredictionDetailCol("details")).fit(batchData)

models = OnlineLearningStreamOp(pipelineModel) \
    .setModelStreamFilePath(REBASE_PATH) \
    .setTimeInterval(10) \
    .setLearningRate(0.01) \
    .setOptimMethod("ADAM") \
    .linkFrom(trainStream)

predResults = PipelinePredictStreamOp(pipelineModel).linkFrom(testStream, models)

EvalBinaryClassStreamOp() \
    .setPredictionDetailCol("details") \
    .setLabelCol(labelColName) \
    .setTimeInterval(10).linkFrom(predResults) \
    .link(JsonValueStreamOp().setSelectedCol("Data") \
        .setReservedCols(["Statistics"]) \
        .setOutputCols(["Accuracy", "AUC", "ConfusionMatrix"]) \
        .setJsonPath(["$.Accuracy", "$.AUC", "$.ConfusionMatrix"])).print()

StreamOperator.execute()
```
package benchmark.online;
import org.alinklab.operator.batch.AssemblePipelineModelBatchOp;
import org.apache.flink.api.java.tuple.Tuple2;
import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.exceptions.AkUnimplementedOperationException;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.sink.AkSinkBatchOp;
import com.alibaba.alink.operator.batch.sink.AppendModelStreamFileSinkBatchOp;
import com.alibaba.alink.operator.batch.source.RandomTableSourceBatchOp;
import com.alibaba.alink.operator.stream.StreamOperator;
import com.alibaba.alink.operator.stream.dataproc.JsonValueStreamOp;
import com.alibaba.alink.operator.stream.dataproc.SpeedControlStreamOp;
import com.alibaba.alink.operator.stream.dataproc.SplitStreamOp;
import com.alibaba.alink.operator.stream.evaluation.EvalBinaryClassStreamOp;
import com.alibaba.alink.operator.stream.evaluation.EvalMultiClassStreamOp;
import com.alibaba.alink.operator.stream.evaluation.EvalRegressionStreamOp;
import com.alibaba.alink.operator.stream.onlinelearning.OnlineLearningStreamOp;
import com.alibaba.alink.operator.stream.PipelinePredictStreamOp;
import com.alibaba.alink.operator.stream.sink.ModelStreamFileSinkStreamOp;
import com.alibaba.alink.operator.stream.source.RandomTableSourceStreamOp;
import com.alibaba.alink.params.onlinelearning.OnlineLearningTrainParams.OptimMethod;
import com.alibaba.alink.pipeline.Pipeline;
import com.alibaba.alink.pipeline.PipelineModel;
import com.alibaba.alink.pipeline.PipelineStageBase;
import com.alibaba.alink.pipeline.classification.FmClassifier;
import com.alibaba.alink.pipeline.classification.LinearSvm;
import com.alibaba.alink.pipeline.classification.LogisticRegression;
import com.alibaba.alink.pipeline.classification.Softmax;
import com.alibaba.alink.pipeline.dataproc.StandardScaler;
import com.alibaba.alink.pipeline.dataproc.vector.VectorAssembler;
import com.alibaba.alink.pipeline.feature.FeatureHasher;
import com.alibaba.alink.pipeline.feature.GbdtEncoder;
import com.alibaba.alink.pipeline.feature.OneHotEncoder;
import com.alibaba.alink.pipeline.regression.FmRegressor;
import com.alibaba.alink.pipeline.regression.LinearRegression;
import org.junit.Test;
public class OnlineLearningTest {
private static final String[] FEATURES = new String[]{
"c0", "c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8"
};
private static final String[] FEATURE_LABEL = new String[]{
"c0", "c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "label"
};
private static final String labelColName = "label";
private static final String vecColName = "vec";
private static final String REBASE_PATH = "/tmp/rebase";
private static final String MODEL_PATH = "/tmp/encoder_web";
private static final String PIPELINE_PATH = "/tmp/pipeline_model.ak";
private static Tuple2<StreamOperator, StreamOperator> getStreamTrainData() {
StreamOperator source = new RandomTableSourceStreamOp()
.setNumCols(10)
.setMaxRows(100000L)
.setOutputCols(FEATURE_LABEL)
.setOutputColConfs("label:weight_set(1.0,1.0,2.0,5.0)")
.link(new SpeedControlStreamOp().setTimeInterval(0.01));
SplitStreamOp splitter = new SplitStreamOp().setFraction(0.5);
source.link(splitter);
return new Tuple2<>(splitter, splitter.getSideOutput(0));
}
private static BatchOperator getBatchSet() {
return new RandomTableSourceBatchOp()
.setNumCols(10)
.setNumRows(1000L)
.setOutputCols(FEATURE_LABEL)
.setOutputColConfs("label:weight_set(1.0,1.0,2.0,5.0)");
}
@Test
public void ClassificationOnlineTrainAndEval() throws Exception {
AlinkGlobalConfiguration.setPrintProcessInfo(true);
Tuple2<StreamOperator, StreamOperator> sources = getStreamTrainData();
StreamOperator trainStream = sources.f0;
StreamOperator testStream = sources.f1;
PipelineModel pipelineModel = getFeaturePipeline(EncoderType.GBDT)
.add(getLastPipelineStage(StageType.LR))
.fit(getBatchSet());
BatchOperator pipelineModelData = new AssemblePipelineModelBatchOp().setPipelineModel(pipelineModel);
StreamOperator.setParallelism(8);
StreamOperator models = new OnlineLearningStreamOp(pipelineModelData)
.setOptimMethod(OptimMethod.ADAGRAD)
.setLearningRate(0.1)
.setTimeInterval(10)
.linkFrom(trainStream);
StreamOperator predResults = new PipelinePredictStreamOp(pipelineModel)
.linkFrom(testStream, models);
new EvalBinaryClassStreamOp()
.setPredictionDetailCol("details").setLabelCol(labelColName).setTimeInterval(10).linkFrom(predResults)
.link(new JsonValueStreamOp().setSelectedCol("Data")
.setReservedCols(new String[]{"Statistics"})
.setOutputCols(new String[]{"Accuracy", "AUC", "ConfusionMatrix"})
.setJsonPath("$.Accuracy", "$.AUC", "$.ConfusionMatrix")).print();
StreamOperator.execute();
}
@Test
public void regressionOnlineTrainAndEval() throws Exception {
AlinkGlobalConfiguration.setPrintProcessInfo(true);
Tuple2<StreamOperator, StreamOperator> sources = getStreamTrainData();
StreamOperator trainStream = sources.f0;
StreamOperator testStream = sources.f1;
PipelineModel pipelineModel = getFeaturePipeline(EncoderType.FEATURE_HASH)
.add(getLastPipelineStage(StageType.FM_REG))
.fit(getBatchSet());
BatchOperator pipelineModelData = new AssemblePipelineModelBatchOp().setPipelineModel(pipelineModel);
StreamOperator.setParallelism(2);
StreamOperator models = new OnlineLearningStreamOp(pipelineModelData)
.setTimeInterval(10)
.linkFrom(trainStream);
StreamOperator predResults = new PipelinePredictStreamOp(pipelineModel)
.linkFrom(testStream, models);
new EvalRegressionStreamOp()
.setPredictionCol("pred").setLabelCol(labelColName).setTimeInterval(10).linkFrom(predResults)
.link(new JsonValueStreamOp().setSelectedCol("regression_eval_result")
.setReservedCols(new String[]{})
.setOutputCols(new String[]{"MAPE", "RMSE", "MAE"})
.setJsonPath("$.MAPE", "$.RMSE", "$.MAE"))
.print();
StreamOperator.execute();
}
@Test
public void saveOnlineModels() throws Exception {
AlinkGlobalConfiguration.setPrintProcessInfo(true);
Tuple2<StreamOperator, StreamOperator> sources = getStreamTrainData();
StreamOperator trainStream = sources.f0;
PipelineModel pipelineModel = getFeaturePipeline(EncoderType.ONE_HOT)
.add(getLastPipelineStage(StageType.LR)).fit(getBatchSet());
BatchOperator pipelineModelData = new AssemblePipelineModelBatchOp().setPipelineModel(pipelineModel);
StreamOperator.setParallelism(2);
StreamOperator models = new OnlineLearningStreamOp(pipelineModelData)
.setTimeInterval(10)
.setModelStreamFilePath(REBASE_PATH)
.linkFrom(trainStream);
models.link(new ModelStreamFileSinkStreamOp().setNumKeepModel(10).setFilePath(MODEL_PATH));
StreamOperator.execute();
}
@Test
public void OnlineSoftmaxTrainAndEval() throws Exception {
BatchOperator batchData = new RandomTableSourceBatchOp()
.setNumCols(10)
.setNumRows(1000L)
.setOutputCols(FEATURE_LABEL)
.setOutputColConfs("label:weight_set(1.0,1.0,2.0,5.0,3.0,3.0)");
StreamOperator<?> streamData = new RandomTableSourceStreamOp()
.setNumCols(10)
.setMaxRows(100000L)
.setOutputCols(FEATURE_LABEL)
.setOutputColConfs("label:weight_set(1.0,1.0,2.0,5.0,3.0,3.0)")
.link(new SpeedControlStreamOp().setTimeInterval(0.01))
.link(new SplitStreamOp(0.5));
AlinkGlobalConfiguration.setPrintProcessInfo(true);
PipelineModel pipelineModel
= getFeaturePipeline(EncoderType.ASSEMBLER)
.add(getLastPipelineStage(StageType.SOFTMAX))
.fit(batchData);
StreamOperator.setParallelism(2);
StreamOperator models = new OnlineLearningStreamOp(pipelineModel)
.setOptimMethod(OptimMethod.ADAM)
.setLearningRate(0.01)
.setTimeInterval(10)
.linkFrom(streamData);
StreamOperator predResults = new PipelinePredictStreamOp(pipelineModel)
.linkFrom(streamData.getSideOutput(0), models);
new EvalMultiClassStreamOp()
.setPredictionDetailCol("details").setLabelCol(labelColName).setTimeInterval(10).linkFrom(predResults)
.link(new JsonValueStreamOp().setSelectedCol("Data")
.setReservedCols(new String[]{"Statistics"})
.setOutputCols(new String[]{"Accuracy", "AUC", "ConfusionMatrix"})
.setJsonPath("$.Accuracy", "$.AUC", "$.ConfusionMatrix")).print();
StreamOperator.execute();
}
@Test
public void saveRebaseModel() throws Exception {
AlinkGlobalConfiguration.setPrintProcessInfo(true);
PipelineModel pipelineModel = getFeaturePipeline(EncoderType.GBDT)
.add(getLastPipelineStage(StageType.LR)).fit(getBatchSet());
BatchOperator pipelineModelData = new AssemblePipelineModelBatchOp().setPipelineModel(pipelineModel);
pipelineModelData.link(new AppendModelStreamFileSinkBatchOp().setFilePath(REBASE_PATH));
BatchOperator.execute();
}
@Test
public void savePipelineModel() throws Exception {
AlinkGlobalConfiguration.setPrintProcessInfo(true);
PipelineModel pipelineModel = getFeaturePipeline(EncoderType.ONE_HOT)
.add(getLastPipelineStage(StageType.LR)).fit(getBatchSet());
BatchOperator pipelineModelData = new AssemblePipelineModelBatchOp().setPipelineModel(pipelineModel);
pipelineModelData.link(new AkSinkBatchOp().setOverwriteSink(true).setFilePath(PIPELINE_PATH));
BatchOperator.execute();
}
enum StageType {
LR,
LINEAR_REG,
FM_CLASSIFIER,
FM_REG,
SOFTMAX,
SVM
}
private PipelineStageBase getLastPipelineStage(StageType type) {
switch (type) {
case LR:
return new LogisticRegression()
.setVectorCol(vecColName)
.setLabelCol(labelColName)
.setMaxIter(1)
.setWithIntercept(true)
.setPredictionCol("pred")
.setPredictionDetailCol("details");
case SVM:
return new LinearSvm()
.setVectorCol(vecColName)
.setLabelCol(labelColName)
.setMaxIter(1)
.setWithIntercept(true)
.setPredictionCol("pred")
.setPredictionDetailCol("details");
case FM_CLASSIFIER:
return new FmClassifier()
.setVectorCol(vecColName)
.setLabelCol(labelColName)
.setNumEpochs(1)
.setWithIntercept(true)
.setNumFactor(10).setPredictionCol("pred")
.setPredictionDetailCol("details");
case FM_REG:
return new FmRegressor()
.setVectorCol(vecColName)
.setLabelCol(labelColName)
.setNumEpochs(1)
.setWithIntercept(true)
.setNumFactor(10).setPredictionCol("pred")
.setPredictionDetailCol("details");
case LINEAR_REG:
return new LinearRegression()
.setVectorCol(vecColName)
.setLabelCol(labelColName)
.setMaxIter(1)
.setWithIntercept(true)
.setPredictionCol("pred");
case SOFTMAX:
return new Softmax()
.setVectorCol(vecColName)
.setLabelCol(labelColName)
.setMaxIter(100)
.setWithIntercept(true)
.setPredictionCol("pred")
.setPredictionDetailCol("details");
default:
throw new AkUnimplementedOperationException("not support yet.");
}
}
enum EncoderType {
ONE_HOT,
FEATURE_HASH,
ASSEMBLER,
GBDT
}
private Pipeline getFeaturePipeline(EncoderType type) {
Pipeline pipeline = new Pipeline();
int numHashFeatures = 30000;
switch (type) {
case FEATURE_HASH:
pipeline
.add(
new StandardScaler()
.setSelectedCols(FEATURES))
.add(
new FeatureHasher()
.setNumFeatures(numHashFeatures)
.setSelectedCols(FEATURES)
.setReservedCols(labelColName)
.setOutputCol(vecColName));
break;
case ONE_HOT:
pipeline
.add(
new OneHotEncoder()
.setSelectedCols(FEATURES)
.setOutputCols(vecColName)
.setReservedCols(labelColName));
break;
case ASSEMBLER:
pipeline
.add(
new StandardScaler()
.setSelectedCols(FEATURES))
.add(
new VectorAssembler()
.setSelectedCols(FEATURES)
.setReservedCols(labelColName)
.setOutputCol(vecColName));
break;
case GBDT:
pipeline
.add(
new StandardScaler()
.setSelectedCols(FEATURES))
.add(
new GbdtEncoder()
.setLabelCol(labelColName)
.setFeatureCols(FEATURES)
.setReservedCols(labelColName)
.setPredictionCol(vecColName));
break;
default:
}
return pipeline;
}
}
Sample evaluation output:

| Statistics | Accuracy | AUC | ConfusionMatrix |
|---|---|---|---|
| all | 0.8404921700223713 | 0.5297035062366982 | [[3757,713],[0,0]] |
| window | 0.8404921700223713 | 0.5297035062366982 | [[3757,713],[0,0]] |
| window | 0.32543617998163454 | 0.5001458256936807 | [[1076,204],[3469,696]] |
| all | 0.5576399394856278 | 0.5128750935507584 | [[4833,917],[3469,696]] |
| window | 0.6837857666911226 | 0.4987630054677052 | [[3547,711],[1013,181]] |
| all | 0.6023947419795666 | 0.5015687641976192 | [[8380,1628],[4482,877]] |
| … | … | … | … |