本章包括以下各节：
14.1 整体流程
14.2 数据准备
14.3 特征工程
14.4 特征工程处理数据
14.5 在线训练
14.6 模型过滤
详细内容请阅读纸质书《Alink权威指南:基于Flink的机器学习实例入门(Java)》，以下为本章对应的示例代码。
package com.alibaba.alink; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.classification.LogisticRegressionTrainBatchOp; import com.alibaba.alink.operator.batch.sink.AkSinkBatchOp; import com.alibaba.alink.operator.batch.source.AkSourceBatchOp; import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp; import com.alibaba.alink.operator.batch.source.TextSourceBatchOp; import com.alibaba.alink.operator.stream.StreamOperator; import com.alibaba.alink.operator.stream.dataproc.JsonValueStreamOp; import com.alibaba.alink.operator.stream.dataproc.SplitStreamOp; import com.alibaba.alink.operator.stream.evaluation.EvalBinaryClassStreamOp; import com.alibaba.alink.operator.stream.onlinelearning.FtrlModelFilterStreamOp; import com.alibaba.alink.operator.stream.onlinelearning.FtrlPredictStreamOp; import com.alibaba.alink.operator.stream.onlinelearning.FtrlTrainStreamOp; import com.alibaba.alink.operator.stream.source.CsvSourceStreamOp; import com.alibaba.alink.pipeline.Pipeline; import com.alibaba.alink.pipeline.PipelineModel; import com.alibaba.alink.pipeline.dataproc.StandardScaler; import com.alibaba.alink.pipeline.feature.FeatureHasher; import org.apache.commons.lang3.ArrayUtils; import java.io.File; public class Chap14 { private static final String DATA_DIR = Utils.ROOT_DIR + "ctr_avazu" + File.separator; static final String SCHEMA_STRING = "id string, click string, dt string, C1 string, banner_pos int, site_id string, site_domain string, " + "site_category string, app_id string, app_domain string, app_category string, device_id string, " + "device_ip string, device_model string, device_type string, device_conn_type string, C14 int, C15 int, " + "C16 int, C17 int, C18 int, C19 int, C20 int, C21 int"; static final String[] CATEGORY_COL_NAMES = new String[] { "C1", "banner_pos", "site_category", "app_domain", "app_category", "device_type", "device_conn_type", "site_id", "site_domain", "device_id", "device_model"}; static 
final String[] NUMERICAL_COL_NAMES = new String[] { "C14", "C15", "C16", "C17", "C18", "C19", "C20", "C21"}; static final String FEATURE_MODEL_FILE = "feature_model.ak"; static final String INIT_MODEL_FILE = "init_model.ak"; static final String LABEL_COL_NAME = "click"; static final String VEC_COL_NAME = "vec"; static final String PREDICTION_COL_NAME = "pred"; static final String PRED_DETAIL_COL_NAME = "pred_info"; static final int NUM_HASH_FEATURES = 30000; public static void main(String[] args) throws Exception { c_2(); c_3(); c_4(); c_5(); c_6(); } static void c_2() throws Exception { new TextSourceBatchOp() .setFilePath("http://alink-release.oss-cn-beijing.aliyuncs.com/data-files/avazu-small.csv") .firstN(10) .print(); CsvSourceBatchOp trainBatchData = new CsvSourceBatchOp() .setFilePath("http://alink-release.oss-cn-beijing.aliyuncs.com/data-files/avazu-small.csv") .setSchemaStr(SCHEMA_STRING); trainBatchData.firstN(10).print(); } static void c_3() throws Exception { CsvSourceBatchOp trainBatchData = new CsvSourceBatchOp() .setFilePath("http://alink-release.oss-cn-beijing.aliyuncs.com/data-files/avazu-small.csv") .setSchemaStr(SCHEMA_STRING); // setup feature enginerring pipeline Pipeline feature_pipeline = new Pipeline() .add( new StandardScaler() .setSelectedCols(NUMERICAL_COL_NAMES) ) .add( new FeatureHasher() .setSelectedCols(ArrayUtils.addAll(CATEGORY_COL_NAMES, NUMERICAL_COL_NAMES)) .setCategoricalCols(CATEGORY_COL_NAMES) .setOutputCol(VEC_COL_NAME) .setNumFeatures(NUM_HASH_FEATURES) ); if (!new File(DATA_DIR + FEATURE_MODEL_FILE).exists()) { // fit and save feature pipeline model feature_pipeline .fit(trainBatchData) .save(DATA_DIR + FEATURE_MODEL_FILE); BatchOperator.execute(); } } static void c_4() throws Exception { // load pipeline model PipelineModel feature_pipelineModel = PipelineModel.load(DATA_DIR + FEATURE_MODEL_FILE); // prepare stream train data CsvSourceStreamOp data = new CsvSourceStreamOp() 
.setFilePath("http://alink-release.oss-cn-beijing.aliyuncs.com/data-files/avazu-ctr-train-8M.csv") .setSchemaStr(SCHEMA_STRING); if (!new File(DATA_DIR + INIT_MODEL_FILE).exists()) { CsvSourceBatchOp trainBatchData = new CsvSourceBatchOp() .setFilePath("http://alink-release.oss-cn-beijing.aliyuncs.com/data-files/avazu-small.csv") .setSchemaStr(SCHEMA_STRING); // train initial batch model LogisticRegressionTrainBatchOp lr = new LogisticRegressionTrainBatchOp() .setVectorCol(VEC_COL_NAME) .setLabelCol(LABEL_COL_NAME) .setWithIntercept(true) .setMaxIter(10); feature_pipelineModel .transform(trainBatchData) .link(lr) .link( new AkSinkBatchOp().setFilePath(DATA_DIR + INIT_MODEL_FILE) ); BatchOperator.execute(); } } static void c_5() throws Exception { // load pipeline model PipelineModel feature_pipelineModel = PipelineModel.load(DATA_DIR + FEATURE_MODEL_FILE); BatchOperator initModel = new AkSourceBatchOp().setFilePath(DATA_DIR + INIT_MODEL_FILE); // prepare stream train data CsvSourceStreamOp data = new CsvSourceStreamOp() .setFilePath("http://alink-release.oss-cn-beijing.aliyuncs.com/data-files/avazu-ctr-train-8M.csv") .setSchemaStr(SCHEMA_STRING) .setIgnoreFirstLine(true); // split stream to train and eval data SplitStreamOp spliter = new SplitStreamOp().setFraction(0.5).linkFrom(data); StreamOperator train_stream_data = feature_pipelineModel.transform(spliter); StreamOperator test_stream_data = feature_pipelineModel.transform(spliter.getSideOutput(0)); // ftrl train FtrlTrainStreamOp model = new FtrlTrainStreamOp(initModel) .setVectorCol(VEC_COL_NAME) .setLabelCol(LABEL_COL_NAME) .setWithIntercept(true) .setAlpha(0.1) .setBeta(0.1) .setL1(0.01) .setL2(0.01) .setTimeInterval(10) .setVectorSize(NUM_HASH_FEATURES) .linkFrom(train_stream_data); // ftrl predict FtrlPredictStreamOp predResult = new FtrlPredictStreamOp(initModel) .setVectorCol(VEC_COL_NAME) .setPredictionCol(PREDICTION_COL_NAME) .setReservedCols(new String[] {LABEL_COL_NAME}) 
.setPredictionDetailCol(PRED_DETAIL_COL_NAME) .linkFrom(model, test_stream_data); predResult .sample(0.0001) .select("'Pred Sample' AS out_type, *") .print(); // ftrl eval predResult .link( new EvalBinaryClassStreamOp() .setLabelCol(LABEL_COL_NAME) .setPredictionDetailCol(PRED_DETAIL_COL_NAME) .setTimeInterval(10) ) .link( new JsonValueStreamOp() .setSelectedCol("Data") .setReservedCols(new String[] {"Statistics"}) .setOutputCols(new String[] {"Accuracy", "AUC", "ConfusionMatrix"}) .setJsonPath(new String[] {"$.Accuracy", "$.AUC", "$.ConfusionMatrix"}) ) .select("'Eval Metric' AS out_type, *") .print(); StreamOperator.execute(); } static void c_6() throws Exception { // prepare stream train data CsvSourceStreamOp data = new CsvSourceStreamOp() .setFilePath("http://alink-release.oss-cn-beijing.aliyuncs.com/data-files/avazu-ctr-train-8M.csv") .setSchemaStr(SCHEMA_STRING) .setIgnoreFirstLine(true); // load pipeline model PipelineModel feature_pipelineModel = PipelineModel.load(DATA_DIR + FEATURE_MODEL_FILE); // split stream to train and eval data SplitStreamOp spliter = new SplitStreamOp().setFraction(0.5).linkFrom(data); StreamOperator <?> train_stream_data = feature_pipelineModel.transform(spliter); StreamOperator <?> test_stream_data = feature_pipelineModel.transform(spliter.getSideOutput(0)); AkSourceBatchOp initModel = new AkSourceBatchOp().setFilePath(DATA_DIR + INIT_MODEL_FILE); // ftrl train FtrlTrainStreamOp model = new FtrlTrainStreamOp(initModel) .setVectorCol(VEC_COL_NAME) .setLabelCol(LABEL_COL_NAME) .setWithIntercept(true) .setAlpha(0.1) .setBeta(0.1) .setL1(0.01) .setL2(0.01) .setTimeInterval(10) .setVectorSize(NUM_HASH_FEATURES) .linkFrom(train_stream_data); // model filter FtrlModelFilterStreamOp model_filter = new FtrlModelFilterStreamOp() .setPositiveLabelValueString("1") .setVectorCol(VEC_COL_NAME) .setLabelCol(LABEL_COL_NAME) .setAccuracyThreshold(0.83) .setAucThreshold(0.71) .linkFrom(model, train_stream_data); model_filter .select("'Model' AS 
out_type, *") .print(); // ftrl predict FtrlPredictStreamOp predResult = new FtrlPredictStreamOp(initModel) .setVectorCol(VEC_COL_NAME) .setPredictionCol(PREDICTION_COL_NAME) .setReservedCols(new String[] {LABEL_COL_NAME}) .setPredictionDetailCol(PRED_DETAIL_COL_NAME) .linkFrom(model_filter, test_stream_data); predResult .sample(0.0001) .select("'Pred Sample' AS out_type, *") .print(); // ftrl eval predResult .link( new EvalBinaryClassStreamOp() .setPositiveLabelValueString("1") .setLabelCol(LABEL_COL_NAME) .setPredictionDetailCol(PRED_DETAIL_COL_NAME) .setTimeInterval(10) ) .link( new JsonValueStreamOp() .setSelectedCol("Data") .setReservedCols(new String[] {"Statistics"}) .setOutputCols(new String[] {"Accuracy", "AUC", "ConfusionMatrix"}) .setJsonPath(new String[] {"$.Accuracy", "$.AUC", "$.ConfusionMatrix"}) ) .select("'Eval Metric' AS out_type, *") .print(); StreamOperator.execute(); } }