这小节演示LocalPredicator如何通过多线程加速。构造一条测试数据,在本节的实验中,我们会对此数据进行重复计算,重复次数N为10000次,具体代码如下:
String recStr = "$784$129:57.0 130:201.0 131:229.0 132:31.0 157:100.0 158:252.0 159:252.0 160:55.0 185:100.0 " + "186:252.0 187:252.0 188:55.0 212:6.0 213:209.0 214:252.0 215:247.0 216:50.0 240:138.0 241:252.0 " + "242:252.0 243:173.0 267:65.0 268:236.0 269:252.0 270:235.0 271:19.0 295:244.0 296:252.0 297:252.0 " + "298:77.0 322:20.0 323:253.0 324:252.0 325:192.0 326:4.0 350:111.0 351:253.0 352:252.0 353:120.0 " + "377:34.0 378:220.0 379:253.0 380:223.0 381:25.0 405:93.0 406:253.0 407:255.0 408:125.0 432:41.0 " + "433:204.0 434:252.0 435:230.0 436:23.0 460:154.0 461:252.0 462:252.0 463:177.0 487:127.0 488:248.0 " + "489:252.0 490:243.0 491:5.0 514:20.0 515:236.0 516:252.0 517:235.0 518:64.0 541:20.0 542:193.0 " + "543:252.0 544:252.0 545:89.0 569:56.0 570:252.0 571:252.0 572:252.0 573:70.0 597:123.0 598:252.0 " + "599:252.0 600:245.0 601:97.0 625:165.0 626:252.0 627:252.0 628:127.0 653:70.0 654:252.0 655:146.0 " + "656:13.0"; final int N = 10000;
构建localPredictor,具体代码如下,从文件(路径为DATA_DIR + PIPELINE_MODEL_FILE)中载入PipelineModel,并设置输入数据的Schema String参数,输入的数据只有1列(列名为VECTOR_COL_NAME, string类型),并打印输出LocalPredictor实例的输出Schema。
LocalPredictor localPredictor = new LocalPredictor(DATA_DIR + PIPELINE_MODEL_FILE, VECTOR_COL_NAME + " string"); System.out.println(localPredictor.getOutputSchema());
运行结果为如下,共有2个输出列,最后一列是预测的分类结果,后面可以localPredictor输出的数据行中,选择索引号为1的数据项,即为预测的分类结果。
root |-- vec: STRING |-- pred: INT
下面,先做一个基准的计算时间,看看单线程进行预测,多久可以进行N次预测,并把每次的预测结果累加起来,具体代码如下:
sw.reset(); sw.start(); sum = 0; for (int i = 0; i < N; i++) { sum += (Integer) localPredictor.predict(recStr)[index_pred]; } sw.stop(); System.out.println(sum); System.out.println(sw.getElapsedTimeSpan());
运行结果如下,计算求和结果为10000,总体计算时间为102.325秒。
10000 1 minutes 42 seconds 325.0 milliseconds.
定义线程池如下,最大线程个数为4。localPredictor不能直接放入线程池,还需要对其进行封装,常用的方式有两个,后面会分别进行介绍演示。
ThreadPoolExecutor threadPoolExecutor = new ThreadPoolExecutor( 4, 4, 0, TimeUnit.SECONDS, new ArrayBlockingQueue <Runnable>(50), new ThreadPoolExecutor.CallerRunsPolicy() );
第一种封装方式,定义类MyRunnableTask实现Runnable 接口,具体代码如下,分为2部分:
public static class MyRunnableTask implements Runnable { private LocalPredictor localPredictor; private String taskData; public MyRunnableTask(LocalPredictor localPredictor, String taskData) { this.localPredictor = localPredictor; this.taskData = taskData; } @Override public void run() { try { localPredictor.predict(taskData); } catch (Exception e) { e.printStackTrace(); } } }
在执行的时候,每次使用localPredictor及当前要计算的数据构建MyRunnableTask的对象,通过threadPoolExecutor的submit方法,让线程池为其绑定池中的线程来执行。
sw.reset(); sw.start(); for (int i = 0; i < N; i++) { threadPoolExecutor.submit(new MyRunnableTask(localPredictor, recStr)); } sw.stop(); System.out.println(sw.getElapsedTimeSpan());
运行结果如下,总体运行时间为23.670秒,约为单线程的运行时间的1/4。
23 seconds 670.0 milliseconds.
再看第二种封装方式,定义类MyCallableTask实现Callable <Row>接口,具体代码如下,分为2部分:
public static class MyCallableTask implements Callable <Object[]> { private LocalPredictor localPredictor; private String taskData; public MyCallableTask(LocalPredictor localPredictor, String taskData) { this.localPredictor = localPredictor; this.taskData = taskData; } @Override public Object[] call() throws Exception { return localPredictor.predict(taskData); } }
在执行的时候,每次使用localPredictor及当前要计算的数据构建MyCallableTask对象的列表,通过threadPoolExecutor的invokeAll方法,让线程池为其绑定池中的线程来执行列表中的各个MyCallableTask对象,返回List <Future <Row>>,最后,从返回结果中抽取出预测分类值进行求和。具体代码如下:
sw.reset(); sw.start(); sum = 0; int K = 1000; ArrayList <MyCallableTask> tasks = new ArrayList <>(K); for (int i = 0; i < N / K; i++) { tasks.clear(); for (int k = 0; k < K; k++) { tasks.add(new MyCallableTask(localPredictor, recStr)); } List <Future <Object[]>> futures = threadPoolExecutor.invokeAll(tasks); for (Future <Object[]> future : futures) { sum += (Integer) future.get()[index_pred]; } } System.out.println(sum); sw.stop(); System.out.println(sw.getElapsedTimeSpan());
运行结果如下,计算求和结果为10000,与单线程实验的求和结果相同;总体运行时间为23.591秒,与第一种封装方式的执行时间大致相同,约为单线程的运行时间的1/4。
10000 23 seconds 591.0 milliseconds.
最后,关闭线程池。
threadPoolExecutor.shutdown();