Alink Tutorial (Java Edition)

Section 29.5 Incremental Training of Linear Models


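Alink's linear models (LR, SVM, Softmax, linear regression, LASSO regression, and others) all support incremental training. In terms of usage, this amounts to training from a supplied initial model. As the pseudocode below shows, the logistic regression training component LogisticRegressionTrainBatchOp takes not only the training data train_set but also an initial model init_model. The second argument of the linkFrom method is optional; without an initial model, the usage is the same as ordinary logistic regression training.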

// Training data and the previously trained model, both stored in Ak format.
AkSourceBatchOp train_set = new AkSourceBatchOp().setFilePath("...");
AkSourceBatchOp init_model = new AkSourceBatchOp().setFilePath("...");

......

// The second input of linkFrom is the optional initial model.
new LogisticRegressionTrainBatchOp()
	.setFeatureCols(NUMERICAL_COL_NAMES)
	.setLabelCol(LABEL_COL_NAME)
	.linkFrom(train_set, init_model);

......


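Incremental training can significantly reduce training time and improve the timeliness of the model. The models it produces can be used to build a model stream, so that the model stream mechanism can dynamically update the model in streaming prediction and in embedded prediction with LocalPredictor.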
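
To illustrate why starting from an existing model can shorten training, here is a minimal sketch that does not use Alink at all. It runs plain gradient descent for logistic regression and accepts an optional set of initial weights. The class WarmStartSketch, its methods, the toy data, and the step size are all hypothetical and chosen only for illustration; Alink's own optimizer may work quite differently.

```java
final class WarmStartSketch {

    // Hypothetical toy data: the first column is a constant 1 that acts as the bias term.
    static final double[][] X = {
        {1, 0.2, 0.1}, {1, 0.9, 0.8}, {1, 0.4, 0.7},
        {1, 0.8, 0.3}, {1, 0.1, 0.9}, {1, 0.7, 0.6}
    };
    static final int[] Y = {0, 1, 0, 1, 1, 0};

    static double dot(double[] w, double[] x) {
        double s = 0.0;
        for (int j = 0; j < w.length; j++) {
            s += w[j] * x[j];
        }
        return s;
    }

    /** Mean logistic loss of weights w on samples x with 0/1 labels y. */
    static double loss(double[] w, double[][] x, int[] y) {
        double sum = 0.0;
        for (int i = 0; i < x.length; i++) {
            double z = dot(w, x[i]);
            // -[y*log(p) + (1-y)*log(1-p)] with p = sigmoid(z) equals log(1 + e^z) - y*z
            sum += Math.log1p(Math.exp(z)) - y[i] * z;
        }
        return sum / x.length;
    }

    /** Runs maxIter steps of batch gradient descent, starting from a copy of init. */
    static double[] train(double[][] x, int[] y, double[] init, int maxIter, double step) {
        double[] w = init.clone(); // leave the caller's initial model untouched
        for (int it = 0; it < maxIter; it++) {
            double[] grad = new double[w.length];
            for (int i = 0; i < x.length; i++) {
                double p = 1.0 / (1.0 + Math.exp(-dot(w, x[i])));
                for (int j = 0; j < w.length; j++) {
                    grad[j] += (p - y[i]) * x[i][j];
                }
            }
            for (int j = 0; j < w.length; j++) {
                w[j] -= step * grad[j] / x.length;
            }
        }
        return w;
    }

    public static void main(String[] args) {
        double[] zero = new double[3];
        double[] base = train(X, Y, zero, 200, 0.5); // stands in for a previously trained model
        double[] warm = train(X, Y, base, 5, 0.5);   // 5 iterations from the previous model
        double[] cold = train(X, Y, zero, 5, 0.5);   // 5 iterations from scratch
        System.out.println("loss after 5 warm-start iterations: " + loss(warm, X, Y));
        System.out.println("loss after 5 cold-start iterations: " + loss(cold, X, Y));
    }
}
```

With the same small iteration budget, the run that starts from the earlier weights continues where the previous training stopped, while the run that starts from zeros has to repeat that work.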


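The experiment below runs ten rounds of incremental logistic regression training. The first round uses the original initial model (model file path: DATA_DIR + INIT_NUMERIC_LR_MODEL_FILE) and appends the training result to the model stream folder. The second round finds the most recently generated subfolder in the model stream folder, reads the logistic regression model from it as its initial model, and again appends the training result to the model stream folder. Every later round follows the same procedure as the second.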

// Round 0 starts from the original initial model file.
String current_base_model_path = DATA_DIR + INIT_NUMERIC_LR_MODEL_FILE;

for (int i = 0; i < 10; i++) {

	// From round 1 on, use the most recently modified subfolder of the
	// model stream folder (skipping "conf") as the initial model.
	if (i > 0) {
		long latest_time = -1L;
		for (File subdir : new File(DATA_DIR + FTRL_MODEL_STREAM_DIR).listFiles()) {
			if (!subdir.getName().equals("conf") && subdir.lastModified() > latest_time) {
				latest_time = subdir.lastModified();
				current_base_model_path = subdir.getCanonicalPath();
			}
		}
	}

	// Print the path of the initial model used in this round.
	System.out.println(current_base_model_path);

	BatchOperator<?> train_set = new CsvSourceBatchOp()
		.setFilePath(DATA_DIR + "avazu-small.csv")
		.setSchemaStr(SCHEMA_STRING);

	AkSourceBatchOp init_model = new AkSourceBatchOp().setFilePath(current_base_model_path);

	// Train from the initial model and append the result to the model stream folder.
	new LogisticRegressionTrainBatchOp()
		.setFeatureCols(NUMERICAL_COL_NAMES)
		.setLabelCol(LABEL_COL_NAME)
		.setMaxIter(5)
		.linkFrom(train_set, init_model)
		.link(
			new AppendModelStreamFileSinkBatchOp()
				.setFilePath(DATA_DIR + FTRL_MODEL_STREAM_DIR)
				.setNumKeepModel(10)
		);
	BatchOperator.execute();

	// Pause for two seconds before the next round.
	Thread.sleep(2000);
}


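The output is shown below. Each line is the path of the initial model used in that round: the first round uses init_numeric_lr_model.ak, and every later round uses the newest model path in the model stream folder.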

/Users/yangxu/alink/data/ctr_avazu/init_numeric_lr_model.ak
/Users/yangxu/alink/data/ctr_avazu/ftrl_model_stream/20211215113020767
/Users/yangxu/alink/data/ctr_avazu/ftrl_model_stream/20211215113045333
/Users/yangxu/alink/data/ctr_avazu/ftrl_model_stream/20211215113101764
/Users/yangxu/alink/data/ctr_avazu/ftrl_model_stream/20211215113116845
/Users/yangxu/alink/data/ctr_avazu/ftrl_model_stream/20211215113131302
/Users/yangxu/alink/data/ctr_avazu/ftrl_model_stream/20211215113145603
/Users/yangxu/alink/data/ctr_avazu/ftrl_model_stream/20211215113159833
/Users/yangxu/alink/data/ctr_avazu/ftrl_model_stream/2021121511321398
/Users/yangxu/alink/data/ctr_avazu/ftrl_model_stream/20211215113228048
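
The lookup of the newest model folder inside the training loop can also be isolated as a small helper. The sketch below uses only java.io.File and follows the same rule as the loop (latest modification time, skipping the conf entry). The class LatestModelDirSketch, the method findLatest, and the fallback parameter are hypothetical names introduced here; the null check and the isDirectory filter are additions that the original loop does not have.

```java
import java.io.File;
import java.io.IOException;

final class LatestModelDirSketch {

    /**
     * Returns the canonical path of the most recently modified subdirectory of
     * streamDir, ignoring the entry named "conf". Returns fallback when streamDir
     * cannot be listed or contains no such subdirectory.
     */
    static String findLatest(File streamDir, String fallback) throws IOException {
        File[] children = streamDir.listFiles(); // null if streamDir is not a readable directory
        if (children == null) {
            return fallback;
        }
        String best = fallback;
        long latestTime = -1L;
        for (File child : children) {
            if (child.isDirectory()
                && !child.getName().equals("conf")
                && child.lastModified() > latestTime) {
                latestTime = child.lastModified();
                best = child.getCanonicalPath();
            }
        }
        return best;
    }

    public static void main(String[] args) throws IOException {
        // Usage: pass the model stream folder as the first argument.
        File streamDir = new File(args.length > 0 ? args[0] : ".");
        System.out.println(findLatest(streamDir, "(no model folder found)"));
    }
}
```

In the training loop, such a helper would be called with the model stream folder and the current base model path as the fallback, so that a missing or empty folder leaves the path unchanged instead of raising a NullPointerException.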

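The code above is the body of the function Chap29.c_5(). Run Chap29.c_5() at the same time as Chap29.c_3_3(), the program from Section 29.3.3 in which LocalPredictor predicts with a model stream; one program then keeps producing new logistic regression models while the other keeps updating the prediction model of localPredictor. Part of the output of Chap29.c_3_3() is shown below. Every prediction is made on the same data record, so the result stays the same as long as the model does not change. The output shows the prediction results changing, which indicates that localPredictor is updating its prediction model from the model stream.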

18	1,0,{"0":"0.8059634797835652","1":"0.1940365202164348"}
19	1,0,{"0":"0.8059634797835652","1":"0.1940365202164348"}
20	1,0,{"0":"0.8059634797835652","1":"0.1940365202164348"}
21	1,0,{"0":"0.8060413109308956","1":"0.1939586890691044"}
22	1,0,{"0":"0.8060413109308956","1":"0.1939586890691044"}
23	1,0,{"0":"0.8060413109308956","1":"0.1939586890691044"}
24	1,0,{"0":"0.8060413109308956","1":"0.1939586890691044"}
25	1,0,{"0":"0.8060413109308956","1":"0.1939586890691044"}
26	1,0,{"0":"0.8065659127557904","1":"0.19343408724420963"}
27	1,0,{"0":"0.8065659127557904","1":"0.19343408724420963"}
28	1,0,{"0":"0.8065659127557904","1":"0.19343408724420963"}
29	1,0,{"0":"0.8065659127557904","1":"0.19343408724420963"}
30	1,0,{"0":"0.8065659127557904","1":"0.19343408724420963"}
31	1,0,{"0":"0.8065659127557904","1":"0.19343408724420963"}
32	1,0,{"0":"0.8065659127557904","1":"0.19343408724420963"}
33	1,0,{"0":"0.8065659127557904","1":"0.19343408724420963"}
34	1,0,{"0":"0.8065659127557904","1":"0.19343408724420963"}
35	1,0,{"0":"0.8065659127557904","1":"0.19343408724420963"}
36	1,0,{"0":"0.8065599998161146","1":"0.1934400001838854"}
37	1,0,{"0":"0.8065599998161146","1":"0.1934400001838854"}
38	1,0,{"0":"0.8065599998161146","1":"0.1934400001838854"}
39	1,0,{"0":"0.8065599998161146","1":"0.1934400001838854"}
40	1,0,{"0":"0.8065599998161146","1":"0.1934400001838854"}
41	1,0,{"0":"0.8069020847473007","1":"0.19309791525269926"}
42	1,0,{"0":"0.8069020847473007","1":"0.19309791525269926"}
43	1,0,{"0":"0.8069020847473007","1":"0.19309791525269926"}
44	1,0,{"0":"0.8069020847473007","1":"0.19309791525269926"}
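
The reasoning above (same input, so any change in the output points to a model update) can be checked mechanically. The following is a minimal sketch that scans a sequence of prediction values and reports the positions where the value differs from the previous one. The class ModelChangeSketch and the method changePoints are hypothetical, and the sample values are a shortened excerpt of the probabilities for label "0" shown in the output above.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

final class ModelChangeSketch {

    /** Returns every index i >= 1 whose value differs from the value at i - 1. */
    static List<Integer> changePoints(List<String> predictions) {
        List<Integer> changes = new ArrayList<>();
        for (int i = 1; i < predictions.size(); i++) {
            if (!predictions.get(i).equals(predictions.get(i - 1))) {
                changes.add(i);
            }
        }
        return changes;
    }

    public static void main(String[] args) {
        // Shortened excerpt of the label-"0" probabilities from the output above.
        List<String> probabilities = Arrays.asList(
            "0.8059634797835652",
            "0.8059634797835652",
            "0.8060413109308956",
            "0.8060413109308956",
            "0.8065659127557904"
        );
        // Each reported index marks a prediction made after the model changed.
        System.out.println(changePoints(probabilities)); // prints [2, 4]
    }
}
```

Each index in the result corresponds to the first prediction made with a newly loaded model, so the number of entries gives a lower bound on how many model updates took place during the observed window.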