Alink教程(Java版)
该文档涉及的组件

第30.3节 LocalPredictor使用线程池

这小节演示LocalPredicator如何通过多线程加速。构造一条测试数据,在本节的实验中,我们会对此数据进行重复计算,重复次数N为10000次,具体代码如下:

String recStr = "$784$129:57.0 130:201.0 131:229.0 132:31.0 157:100.0 158:252.0 159:252.0 160:55.0 185:100.0 "
	+ "186:252.0 187:252.0 188:55.0 212:6.0 213:209.0 214:252.0 215:247.0 216:50.0 240:138.0 241:252.0 "
	+ "242:252.0 243:173.0 267:65.0 268:236.0 269:252.0 270:235.0 271:19.0 295:244.0 296:252.0 297:252.0 "
	+ "298:77.0 322:20.0 323:253.0 324:252.0 325:192.0 326:4.0 350:111.0 351:253.0 352:252.0 353:120.0 "
	+ "377:34.0 378:220.0 379:253.0 380:223.0 381:25.0 405:93.0 406:253.0 407:255.0 408:125.0 432:41.0 "
	+ "433:204.0 434:252.0 435:230.0 436:23.0 460:154.0 461:252.0 462:252.0 463:177.0 487:127.0 488:248.0 "
	+ "489:252.0 490:243.0 491:5.0 514:20.0 515:236.0 516:252.0 517:235.0 518:64.0 541:20.0 542:193.0 "
	+ "543:252.0 544:252.0 545:89.0 569:56.0 570:252.0 571:252.0 572:252.0 573:70.0 597:123.0 598:252.0 "
	+ "599:252.0 600:245.0 601:97.0 625:165.0 626:252.0 627:252.0 628:127.0 653:70.0 654:252.0 655:146.0 "
	+ "656:13.0";

final int N = 10000;


构建localPredictor,具体代码如下,从文件(路径为DATA_DIR + PIPELINE_MODEL_FILE)中载入PipelineModel,并设置输入数据的Schema String参数,输入的数据只有1列(列名为VECTOR_COL_NAME, string类型),并打印输出LocalPredictor实例的输出Schema。

LocalPredictor localPredictor =
	new LocalPredictor(DATA_DIR + PIPELINE_MODEL_FILE, VECTOR_COL_NAME + " string");

System.out.println(localPredictor.getOutputSchema());


运行结果为如下,共有2个输出列,最后一列是预测的分类结果,后面可以localPredictor输出的数据行中,选择索引号为1的数据项,即为预测的分类结果。

root
 |-- vec: STRING
 |-- pred: INT


下面,先做一个基准的计算时间,看看单线程进行预测,多久可以进行N次预测,并把每次的预测结果累加起来,具体代码如下:

sw.reset();
sw.start();
sum = 0;
for (int i = 0; i < N; i++) {
	sum += (Integer) localPredictor.predict(recStr)[index_pred];
}
sw.stop();
System.out.println(sum);
System.out.println(sw.getElapsedTimeSpan());


运行结果如下,计算求和结果为10000,总体计算时间为102.325秒。

10000
1 minutes  42 seconds  325.0 milliseconds.


定义线程池如下,最大线程个数为4。localPredictor不能直接放入线程池,还需要对其进行封装,常用的方式有两个,后面会分别进行介绍演示。

ThreadPoolExecutor threadPoolExecutor = new ThreadPoolExecutor(
	4, 4, 0, TimeUnit.SECONDS,
	new ArrayBlockingQueue <Runnable>(50),
	new ThreadPoolExecutor.CallerRunsPolicy()
);


第一种封装方式,定义类MyRunnableTask实现Runnable 接口,具体代码如下,分为2部分:

  1. 将localPredictor及计算数据作为构造参数。
  2. 覆盖Runnable 接口中的run方法,使用localPredictor对数据进行预测
public static class MyRunnableTask implements Runnable {
	private LocalPredictor localPredictor;
	private String taskData;

	public MyRunnableTask(LocalPredictor localPredictor, String taskData) {
		this.localPredictor = localPredictor;
		this.taskData = taskData;
	}

	@Override
	public void run() {
		try {
			localPredictor.predict(taskData);
		} catch (Exception e) {
			e.printStackTrace();
		}
	}
}

在执行的时候,每次使用localPredictor及当前要计算的数据构建MyRunnableTask的对象,通过threadPoolExecutor的submit方法,让线程池为其绑定池中的线程来执行。

sw.reset();
sw.start();
for (int i = 0; i < N; i++) {
	threadPoolExecutor.submit(new MyRunnableTask(localPredictor, recStr));
}
sw.stop();
System.out.println(sw.getElapsedTimeSpan());

运行结果如下,总体运行时间为23.670秒,约为单线程的运行时间的1/4。

23 seconds  670.0 milliseconds.


再看第二种封装方式,定义类MyCallableTask实现Callable <Row>接口,具体代码如下,分为2部分:

  1. 将localPredictor及计算数据作为构造参数。
  2. 覆盖Runnable 接口中的call方法,使用localPredictor对数据进行预测,返回Row类型结果
public static class MyCallableTask implements Callable <Object[]> {
	private LocalPredictor localPredictor;
	private String taskData;

	public MyCallableTask(LocalPredictor localPredictor, String taskData) {
		this.localPredictor = localPredictor;
		this.taskData = taskData;
	}

	@Override
	public Object[] call() throws Exception {
		return localPredictor.predict(taskData);
	}
}


在执行的时候,每次使用localPredictor及当前要计算的数据构建MyCallableTask对象的列表,通过threadPoolExecutorinvokeAll方法,让线程池为其绑定池中的线程来执行列表中的各个MyCallableTask对象,返回List <Future <Row>>,最后,从返回结果中抽取出预测分类值进行求和。具体代码如下:

sw.reset();
sw.start();
sum = 0;
int K = 1000;
ArrayList <MyCallableTask> tasks = new ArrayList <>(K);
for (int i = 0; i < N / K; i++) {
	tasks.clear();
	for (int k = 0; k < K; k++) {
		tasks.add(new MyCallableTask(localPredictor, recStr));
	}
	List <Future <Object[]>> futures = threadPoolExecutor.invokeAll(tasks);
	for (Future <Object[]> future : futures) {
		sum += (Integer) future.get()[index_pred];
	}
}
System.out.println(sum);
sw.stop();
System.out.println(sw.getElapsedTimeSpan());

运行结果如下,计算求和结果为10000,与单线程实验的求和结果相同;总体运行时间为23.591秒,与第一种封装方式的执行时间大致相同,约为单线程的运行时间的1/4。

10000
23 seconds  591.0 milliseconds.


最后,关闭线程池。

threadPoolExecutor.shutdown();