Alink教程(Java版)

第25.6节 使用自定义 TensorFlow 脚本

在 Alink 提供的深度学习功能中,可定制程度最高的是自定义脚本类组件。在这类组件中,用户只需要进行很少的代码修改,就可以让已有的 TensorFlow 脚本在 Alink 系统中执行。

这类组件基于的是从 Alink 进程拉起 Python 进程、执行 Python 代码的能力。通过这个能力,Alink 可以将数据传递给 Python 进程,在 Python 进程中执行自定义代码,然后将处理的结果返回给 Alink。具体到 TensorFlow 自定义脚本组件,传递的数据是以 TensorFlow 的 tf.train.Example 的格式,自定义代码主要指 TensorFlow 脚本代码。


25.6.1 选择组件

下表列出了自定义脚本类组件包含的具体组件及其区别:

TF 版本

传入数据

传入参数类型

输出数据

TensorFlowBatchOp

1.15.2

批数据

BatchTaskConfig

自定义

TensorFlow2BatchOp

2.3.1

批数据

BatchTaskConfig

自定义

TensorFlowStreamOp

1.15.2

流数据

StreamTaskConfig

自定义

TensorFlow2StreamOp

2.3.1

流数据

StreamTaskConfig

自定义

TFTableModelTrainBatchOp

1.15.2

批数据

TrainTaskConfig

要求将训练模型保存到指定目录,无其他输出

TF2TableModelTrainBatchOp

2.3.1

批数据

TrainTaskConfig

要求将训练模型保存到指定目录,无其他输出

在实践中,可以根据以下几条规则选择组件:

  1. 是否需要使用自定义脚本组件?
    如果已经有训练好的模型,仅仅需要使用 Alink 做模型推理,那么可以使用 TFSavedModelPredictBatchOp 或者 TFSavedModelPredictStreamOp,具体使用见之前的文档。
  2. 输入数据是批数据还是流数据?
    如果输入数据是流数据,那么接下来根据 TensorFlow 版本选择 TensorFlowStreamOp 或者 TensorFlow2StreamOp 即可。如果是批数据,那么接着看下面的问题。
    这里批/流数据是 Alink/Flink 中的概念,对应 XxxBatchOpXxxStreamOp。在 TensorFlow 中,批数据意味着数据可以反复遍历多次,而流数据只能按顺序遍历一次(不额外缓存时)。
  3. 执行训练任务还是推理任务?
    如果是推理任务,请先检查第一个问题,如果不能满足需求,那么根据 TensorFlow 版本选择 TensorFlowBatchOp 或者 TensorFlow2BatchOp 。如果是训练任务,那么接着看下面的问题。
  4. 是否有额外存储模型的需求?还是希望像其他传统模型一样、看作一张 Table 来进行保存、处理?
    对于前者,根据 TensorFlow 版本选择 TensorFlowBatchOp 或者 TensorFlow2BatchOp :在代码中需要自行使用 python 代码,将模型保存到文件系统中。
    对于后者,根据 TensorFlow 版本选择 TFTableModelTrainBatchOp 或者 TF2TableModelTrainBatchOp:在代码中,必须将训练得到的模型以 SavedModel 格式保存到指定的目录下;之后可以接 TFTableModelPredictBatchOpTFTableModelPredictStreamOp进行预测。这样做的好处是,在使用上和其他传统模型基本一致,并且可以借助 LocalPredictor 部署服务。


25.6.2 代码示例

下面是以 TensorFlowBatchOp 为例,展示了在 Alink 中自定义脚本类组件的使用形式。其他组件的使用方式大同小异,可以参考具体组件的文档。


BatchOperator<?> source = new RandomTableSourceBatchOp()
    .setNumRows(100L)
    .setNumCols(10);

String[] colNames = source.getColNames();

Map<String, Object> userParams = new HashMap<>();
userParams.put("featureCols", JsonConverter.toJson(colNames));
userParams.put("labelCol", "label");
userParams.put("batch_size", 16);
userParams.put("num_epochs", 1);

TensorFlowBatchOp tensorFlowBatchOp = new TensorFlowBatchOp()
    .setUserFiles(new String[]{"https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/tf_dnn_batch.py"})
    .setMainScriptFile("https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/tf_dnn_batch.py")
    .setUserParams(JsonConverter.toJson(userParams))
    .setOutputSchemaStr("model_id long, model_info string")
    .setNumWorkers(1)
    .setNumPSs(0);

source
    .select("*, case when RAND() > 0.5 then 1. else 0. end as label")
    .link(tensorFlowBatchOp)
    .print();


💡 各个组件参数的使用说明,可以参考对应组件的文档。


其中链接 https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/tf_dnn_batch.py 对应的python代码如下:

from akdl.models.tf.dnn import dnn_batch
from akdl.runner.config import BatchTaskConfig


def main(task_config: BatchTaskConfig):
    dnn_batch.main(task_config)


25.6.3 代码编写

在使用组件时,用户提供一个或多个 Python 文件(setUserFiles), 其中一个为主文件(setMainScriptFile),作为自定义脚本的入口。

参数传递

在主文件中,必须包含一个名为 main 的函数,接受一个参数。参数的类型根据使用的组件不同可以为

BatchTaskConfigStreamTaskConfig 或者 TrainTaskConfig,具体见前文中的表格。

下面结合这三种 Config 的 源码 ,对这三种 Config 的字段进行说明:

  • 三者共有的字段:
    • tf_context: TFContext 类型,可以调用 flink_stream_dataset() 获取 一个TFRecordDataset,但这个数据集只能扫描一次;
    • num_workers:总的 worker 数;
    • clusterTF_CONFIG 中的 cluster 字段;
    • task_typeTF_CONFIG 中的 task.type 字段,取值有 'chief'、'worker' 或者 'ps'
    • task_indexTF_CONFIG 中的 task.index 字段;
    • work_dir:工作目录;
    • user_params:用户自定义参数,字典类型,对应为组件 setUserParams 的值。
  • BatchTaskConfig 有的字段:
    • dataset_file:将 tf_context.flink_stream_dataset() 得到的数据集写到本地文件中,从而可以读取多次;
    • dataset_length:数据条数;
    • output_writer:一个用于将数据写回 Alink 的工具,见下面说明。
  • StreamTaskConfig 有的字段:
    • dataset_fn:调用后返回的一个 DataSet
    • output_writer:一个用于将数据写回 Alink 的工具,见下面说明。
  • TrainTaskConfig 有的字段:
    • dataset_file:将 tf_context.flink_stream_dataset() 得到的数据集写到本地文件中,从而可以读取多次;
    • dataset_length:数据条数;
    • saved_model_dir:训练完成后,必须将模型以 SavedModel 的格式导出到这个目录下。


数据输入

💡 首先需要说明一下 TensorFlow 进程与输入数据集之间的关系。
当 Alink 作业本身的并发度大于 1 时,会有多个 Worker 同时执行任务,数据会根据任务的配置分布在各个 Worker 上,每个 Worker 只拥有完整数据的一个子集。在进入 TF 组件对应的任务时,各个 Worker 会启动一个 TF 进程,此时各个 Worker 会将其拥有的数据传递给 TF 进程。

每个 TF 进程只能访问到它所在 Worker 的数据,而访问不了其他 Worker 的数据。这一点与某些 TensorFlow 分布式训练的写法不同:在一些 TensorFlow 分布式训练的写法中,数据集中存储在某些共享文件系统(例如 HDFS)上,整体作为模型训练数据,各个 TF 进程通过 shard 的形式读取各自需要的数据。

从 Alink 进程传到 TensorFlow 进程的数据集为 TFRecordDataset 格式,每条数据是序列化后的 tf.train.Example 实例,可以通过 tf.parse_single_example 来进行解析。关于 TFRecordtf.train.Example 的教程可以参考 官方文档
其中,parse_single_examplefeatures 参数与原本数据集的列名和类型对应,由于 tf.train.Example 仅支持 tf.int64tf.float32tf.string,所以对数据类型会有一定的转换。


数据输出

从入口参数(BatchTaskConfig 或者 StreamTaskConfig)中获得output_writer,调用接口 write(self, example: tf.train.Example),向 Alink 进程写回数据。tf.train.Example 实例所含的 features需要与组件参数 OutputSchemaStr 中的列名和类型对应。


分布式训练

在脚本中可以获取环境变量 TF_CONFIG,从而可以写分布式训练的代码,包括 Estimator + PS 与 AllReduce 的模式。

另外需要注意,由于每个 Worker 只拥有部分数据,因此对数据集的处理可能会需要调整,例如不需要 shard 等。


关于 akdl 库

在 Alink 提供的 akdl 库 中,提供了一些便捷调用的函数,方便书写代码。具体写法可以以 alink_dl_predictors/predictor-tf/src/test/resources/tf_dnn_batch.py 为起始点,查看代码。

💡 需要注意的是:您完全可以不引入 akdl 库来实现自己的脚本。引入 akdl 库可能会导致某些代码执行报错,例如 TF2 动态图运行模式等等,这是因为 akdl 库内采用的是 TF1 或者 TF2 中 TF1 兼容模式的写法。(即使仅引入 akdl 包中的头文件,也可能导致运行不了一些纯 TF2 写法的代码。)


25.6.4 运行日志

为了便于查错,需要将一些日志开关打开。

Java 运行时,请设置 AlinkGlobalConfiguration.setPrintProcessInfo(true).

PyAlink 运行时,请设置 AlinkGlobalConfiguration.setPrintProcessInfo(True),并且如果使用 useLocalEnv 执行,还需要在参数中添加 config = {'debug_mode': True},例如:

from pyalink.alink import *
useLocalEnv(2, config={'debug_mode': True})

如果是本地运行,那么相关日志打印到控制台或者类似的地方。如果是集群运行,那么需要在集群的 TaskManager 的日志文件中找到相关日志。