Alink教程(Java版)

第29.6节 模型流的过滤

在实际应用中,并不是每个产生出来的模型都会进入预测组件,用于实际的预测服务。中间需要一个验证过程,通常是使用一个验证数据,评估模型的指标,当指标满足上线要求时,才会进入后续的部署上线预测流程。

从模型流的角度来看,在线学习组件产生原始的模型流,通过一个评估过滤组件,产生待部署上线的模型流,提供给预测组件。


29.6.1 流式组件过滤模型流


使用模型流过滤组件FtrlModelFilterStreamOp,可以对线性二分类模型(FTRL,LR,SVM等)构成的模型流进行过滤,

AlinkGlobalConfiguration.setPrintProcessInfo(true);

StreamOperator <?> source_model_stream = new ModelStreamFileSourceStreamOp()
	.setFilePath(DATA_DIR + FTRL_MODEL_STREAM_DIR)
	.setStartTime("2021-01-01 00:00:00");

CsvSourceStreamOp val_stream_data = new CsvSourceStreamOp()
	.setFilePath("http://alink-release.oss-cn-beijing.aliyuncs.com/data-files/avazu-ctr-train-8M.csv")
	.setSchemaStr(SCHEMA_STRING)
	.setIgnoreFirstLine(true);

FtrlModelFilterStreamOp model_filter = new FtrlModelFilterStreamOp()
	.setPositiveLabelValueString("1")
	.setLabelCol(LABEL_COL_NAME)
	.setAccuracyThreshold(0.8)
	.setAucThreshold(0.6);

model_filter.linkFrom(source_model_stream, val_stream_data);

model_filter
	.link(
		new ModelStreamFileSinkStreamOp()
			.setFilePath(DATA_DIR + FILTERED_MODEL_STREAM_DIR)
			.setNumKeepModel(10)
	);

StreamOperator.execute();


由于代码中设置了打印中间信息,即,下面这行代码:

AlinkGlobalConfiguration.setPrintProcessInfo(true);


我们可以看到如下的输出内容,其中模型2(auc : 0.6068055555555555 accuracy : 0.8520710059171598)、模型4(auc : 0.60706340378198 accuracy : 0.803921568627451)、模型5(auc : 0.6859070464767616 accuracy : 0.8630952380952381)达到了设定的阈值。

load model : 0
load model : 1
auc : 0.5295836616758605     accuracy : 0.8005952380952381
load model : 2
auc : 0.6068055555555555     accuracy : 0.8520710059171598
load model : 3
auc : 0.5661646046261433     accuracy : 0.8402366863905325
load model : 4
auc : 0.60706340378198     accuracy : 0.803921568627451
load model : 5
auc : 0.6859070464767616     accuracy : 0.8630952380952381
load model : 6
auc : 0.5155737704918034     accuracy : 0.8591549295774648
load model : 7
auc : 0.5860325974829791     accuracy : 0.7559523809523809
load model : 8
auc : 0.716765285996055     accuracy : 0.7692307692307693
load model : 9
auc : 0.48494525547445255     accuracy : 0.8165680473372781


再看结果模型流所在的文件夹(路径:DATA_DIR + FILTERED_MODEL_STREAM_DIR),如下图所示,有三个模型子文件夹,正是过滤出来的三个模型。注意:每个子文件夹的名称都是以模型写入模型流的时间戳命名的。



29.6.2 批式评估


List <File> model_dirs = Arrays.asList(
	new File(DATA_DIR + FTRL_MODEL_STREAM_DIR)
		.listFiles(
			new FilenameFilter() {
				@Override
				public boolean accept(File dir, String name) {
					return name.length() > 10;
				}
			}
		)
);

Collections.sort(model_dirs,
	new Comparator <File>() {
		@Override
		public int compare(File o1, File o2) {
			return o1.getName().compareTo(o2.getName());
		}
	}
);

for (File model_dir : model_dirs) {

	CsvSourceBatchOp validation_set = new CsvSourceBatchOp()
		.setFilePath(DATA_DIR + "avazu-small.csv")
		.setSchemaStr(SCHEMA_STRING);

	AkSourceBatchOp model = new AkSourceBatchOp().setFilePath(model_dir.getCanonicalPath());

	BinaryClassMetrics metrics =
		new LogisticRegressionPredictBatchOp()
			.setPredictionCol(PREDICTION_COL_NAME)
			.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
			.linkFrom(model, validation_set)
			.link(
				new EvalBinaryClassBatchOp()
					.setPositiveLabelValueString("1")
					.setLabelCol(LABEL_COL_NAME)
					.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
			)
			.collectMetrics();

	System.out.println(model_dir.getName());
	System.out.println("auc : " + metrics.getAuc() + ",\t accuracy : " + metrics.getAccuracy());

	if (metrics.getAuc() > 0.6 && metrics.getAccuracy() > 0.8) {
		model.link(
			new AppendModelStreamFileSinkBatchOp()
				.setFilePath(DATA_DIR + FILTERED_MODEL_STREAM_DIR)
				.setNumKeepModel(10)
		);
		BatchOperator.execute();
	}

}


下面是运行结果,分别记录了每一个评估模型所在的文件夹名称,以及该模型的两个评估指标:auc和accuracy。显然这10个模型都达到了代码中设置的阈值(auc的阈值为0.6,accuracy的阈值为0.8),这10个模型都会被追加到结果模型流。

20211202204259756
auc : 0.655222746031442,	 accuracy : 0.833274583186458
20211202204321978
auc : 0.6513136624275457,	 accuracy : 0.833274583186458
20211202204336074
auc : 0.6549124168247131,	 accuracy : 0.833274583186458
20211202204349758
auc : 0.651443537132315,	 accuracy : 0.8323495808739522
20211202204403225
auc : 0.6540554208075173,	 accuracy : 0.8323495808739522
20211202204416585
auc : 0.6516985004191098,	 accuracy : 0.8332645831614579
20211202204429884
auc : 0.6548096591610176,	 accuracy : 0.833274583186458
20211202204443127
auc : 0.651967765251886,	 accuracy : 0.8323495808739522
2021120220445639
auc : 0.6556805933418423,	 accuracy : 0.8323495808739522
20211202204509697
auc : 0.6558697976120662,	 accuracy : 0.833274583186458


查看模型流所在的文件夹(路径:DATA_DIR + FILTERED_MODEL_STREAM_DIR),如下图所示,模型所在的子文件夹已经被更新,注意:每个子文件夹的名称都是以模型写入模型流的时间戳命名的。