Alink教程(Java版)

第25.4节 运行TensorFlow模型

在本章的第2、3节介绍了使用Alink提供的深度学习组件KerasSequentialClassifier和KerasSequentialRegressor进行分类和回归模型的训练、预测。

实际应用中,经常需要使用TensorFlow或着PyTorch训练好的模型,对流式数据、批式数据进行预测。Alink提供了相应的流式、批式和Pipeline组件适配TensorFlow或着PyTorch模型。

本节重点介绍与TensorFlow模型相关的操作。


25.4.1 生成TensorFlow模型

本节所需的TensorFlow模型压缩文件mnist_model_tf.zip,已经被放到了OSS上,本节后面的实验会直接从网络读取该模型。https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/mnist_model_tf.zip

如果读者有兴趣,可以在TensorFlow环境,运行下面代码便可生成TensorFlow模型,从而被Alink相关组件使用。注意:TensorFlow模型执行完save操作会被保存到一个文件夹,需要将其压缩为zip文件,便于Alink相关组件导入模型。建议的压缩示例代码在下面代码的最后部分。

import tensorflow as tf
from tensorflow import keras

mnist = keras.datasets.mnist

(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

test_images,train_images = test_images.reshape((10000,28,28,1)),train_images.reshape(60000,28,28,1)
test_images,train_images = test_images/255.0,train_images/255.0
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(20,(5,5),padding="SAME",activation="relu"),
    tf.keras.layers.MaxPool2D(2,2,padding="SAME"),
    tf.keras.layers.Conv2D(40,(5,5),padding="SAME",activation="relu"),
    tf.keras.layers.MaxPool2D(2, 2,padding="SAME"),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512,activation="relu"),
    tf.keras.layers.Dense(10,activation="softmax")
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(train_images,train_labels,epochs=5)
test_loss, test_acc = model.evaluate(test_images, test_labels)
print(test_loss)
print(test_acc)

dir_name = "mnist_model_tf"
model.save(dir_name)

import shutil
shutil.make_archive(base_name=dir_name, format='zip', root_dir=dir_name)


该段脚本的执行输出如下,测试集上预测精确率为98.75%.

Train on 60000 samples
Epoch 1/5
60000/60000 [==============================] - 5s 75us/sample - loss: 0.1095 - accuracy: 0.9660
Epoch 2/5
60000/60000 [==============================] - 4s 70us/sample - loss: 0.0376 - accuracy: 0.9883
Epoch 3/5
60000/60000 [==============================] - 4s 70us/sample - loss: 0.0255 - accuracy: 0.9917
Epoch 4/5
60000/60000 [==============================] - 4s 70us/sample - loss: 0.0176 - accuracy: 0.9942
Epoch 5/5
60000/60000 [==============================] - 4s 70us/sample - loss: 0.0140 - accuracy: 0.9951
10000/10000 [==============================] - 1s 59us/sample - loss: 0.0473 - accuracy: 0.9875
0.0473310407480522
0.9875


25.4.2 批式任务中使用TensorFlow模型


使用TFSavedModelPredictBatchOp组件,可以加载TF模型进行批式预测。关于该组件的详细说明参见Alink文档 https://www.yuque.com/pinshu/alink_doc/tfsavedmodelpredictbatchop .

由于TensorFlow模型训练前对每个数据都除以255,所以批式任务也要执行此操作,可以使用VectorFunctionBatchOp组件,设置函数名称(FuncName)为"Scale",系数为1.0 / 255.0。另外,使用TensorFlow模型前,还需要将输入数据列的类型转换为Tensor格式,可以使用VectorToTensorBatchOp组件。具体代码如下所示:

new AkSourceBatchOp()
	.setFilePath(Chap13.DATA_DIR + Chap13.DENSE_TEST_FILE)
	.link(
		new VectorFunctionBatchOp()
			.setSelectedCol("vec")
			.setFuncName("Scale")
			.setWithVariable(1.0 / 255.0)
	)
	.link(
		new VectorToTensorBatchOp()
			.setTensorDataType("float")
			.setTensorShape(1, 28, 28, 1)
			.setSelectedCol("vec")
			.setOutputCol("input_1")
			.setReservedCols("label")
	)
	.link(
		new TFSavedModelPredictBatchOp()
			.setModelPath("https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/mnist_model_tf.zip")
			.setSelectedCols("input_1")
			.setOutputSchemaStr("output_1 FLOAT_TENSOR")
	)
	.lazyPrint(3)
	.link(
		new UDFBatchOp()
			.setFunc(new GetMaxIndex())
			.setSelectedCols("output_1")
			.setOutputCol("pred")
	)
	.lazyPrint(3)
	.link(
		new EvalMultiClassBatchOp()
			.setLabelCol("label")
			.setPredictionCol("pred")
			.lazyPrintMetrics()
	);

BatchOperator.execute();

这里用到了一个自定义函数,具体定义如下:

public static class GetMaxIndex extends ScalarFunction {

	public int eval(FloatTensor tensor) {
		int k = 0;
		float max = tensor.getFloat(0, 0);
		for (int i = 1; i < 10; i++) {
			if (tensor.getFloat(0, i) > max) {
				k = i;
				max = tensor.getFloat(0, i);
			}
		}
		return k;
	}
}


批式任务的运行结果为:

label|input_1|output_1
-----|-------|--------
7|FloatTensor(1,28,28,1)|FloatTensor(1,10)                                                              
 |[[[[0.0]              |[[3.1598278E-13 6.958706E-10 1.0994857E-12 ... 1.0 1.0060469E-12 1.5447695E-9]]
 |   [0.0]              |
 |   [0.0]              |
 |   ...                |
 |   [0.0]              |
 | ... ...              |
4|FloatTensor(1,28,28,1)|FloatTensor(1,10)                                                              
 |[[[[0.0]              |[[3.6378616E-9 6.095424E-8 6.86549E-8 ... 4.792359E-10 2.9463915E-6 4.5094E-4]]
 |   [0.0]              |
 |   [0.0]              |
 |   ...                |
 |   [0.0]              |
 | ... ...              |
1|FloatTensor(1,28,28,1)|FloatTensor(1,10)                                                                 
 |[[[[0.0]              |[[2.4944006E-6 0.99974304 2.2457668E-6 ... 3.907643E-6 1.1800173E-5 3.2095505E-7]]
 |   [0.0]              |
 |   [0.0]              |
 |   ...                |
 |   [0.0]              |
 | ... ...              |
label|input_1|output_1|pred
-----|-------|--------|----
0|FloatTensor(1,28,28,1)|FloatTensor(1,10)                                                               |0
 |[[[[0.0]              |[[0.9999175 6.8047594E-9 3.209264E-8 ... 2.1794841E-8 4.711486E-6 3.0862586E-7]]|
 |   [0.0]              |
 |   [0.0]              |
 |   ...                |
 |   [0.0]              |
 | ... ...              |
9|FloatTensor(1,28,28,1)|FloatTensor(1,10)                                                                  |9
 |[[[[0.0]              |[[7.526831E-13 6.5608413E-12 2.2300215E-9 ... 5.055498E-11 1.2727068E-5 0.9999871]]|
 |   [0.0]              |
 |   [0.0]              |
 |   ...                |
 |   [0.0]              |
 | ... ...              |
6|FloatTensor(1,28,28,1)|FloatTensor(1,10)                                                                     |6
 |[[[[0.0]              |[[1.1784781E-9 9.737324E-12 7.8516065E-12 ... 9.064798E-16 2.4528541E-9 3.852846E-15]]|
 |   [0.0]              |
 |   [0.0]              |
 |   ...                |
 |   [0.0]              |
 | ... ...              |
-------------------------------- Metrics: --------------------------------
Accuracy:0.9917	Macro F1:0.9917	Micro F1:0.9917	Kappa:0.9908	
|Pred\Real|  9|  8|   7|...|   2|   1|  0|
|---------|---|---|----|---|----|----|---|
|        9|995|  2|   1|...|   0|   0|  0|
|        8|  4|965|   1|...|   1|   2|  0|
|        7|  2|  0|1019|...|   8|   1|  1|
|      ...|...|...| ...|...| ...| ...|...|
|        2|  0|  2|   3|...|1022|   1|  0|
|        1|  0|  0|   2|...|   1|1127|  0|
|        0|  0|  1|   1|...|   0|   0|976|

25.4.3 流式任务中使用TensorFlow模型


使用TFSavedModelPredictStreamOp组件,可以加载TF模型进行批式预测。关于该组件的详细说明参见Alink文档 https://www.yuque.com/pinshu/alink_doc/tfsavedmodelpredictstreamop .

由于TensorFlow模型训练前对每个数据都除以255,所以流式任务也要执行此操作,可以使用VectorFunctionStreamOp组件,设置函数名称(FuncName)为"Scale",系数为1.0 / 255.0。另外,使用TensorFlow模型前,还需要将输入数据列的类型转换为Tensor格式,可以使用VectorToTensorStreamOp组件。具体代码如下所示:

new AkSourceStreamOp()
	.setFilePath(Chap13.DATA_DIR + Chap13.DENSE_TEST_FILE)
	.link(
		new VectorFunctionStreamOp()
			.setSelectedCol("vec")
			.setFuncName("Scale")
			.setWithVariable(1.0 / 255.0)
	)
	.link(
		new VectorToTensorStreamOp()
			.setTensorDataType("float")
			.setTensorShape(1, 28, 28, 1)
			.setSelectedCol("vec")
			.setOutputCol("input_1")
			.setReservedCols("label")
	)
	.link(
		new TFSavedModelPredictStreamOp()
			.setModelPath("https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/mnist_model_tf.zip")
			.setSelectedCols("input_1")
			.setOutputSchemaStr("output_1 FLOAT_TENSOR")
	)
	.link(
		new UDFStreamOp()
			.setFunc(new GetMaxIndex())
			.setSelectedCols("output_1")
			.setOutputCol("pred")
	)
	.sample(0.001)
	.print();

StreamOperator.execute();

运行结果为:

label|input_1|output_1|pred
-----|-------|--------|----
5|FloatTensor(1,28,28,1)|FloatTensor(1,10)                                                                   |5
 |[[[[0.0]              |[[6.933754E-8 6.9330003E-6 6.1611705E-10 ... 3.8823796E-6 2.8930677E-5 3.047829E-6]]|
 |   [0.0]              |
 |   [0.0]              |
 |   ...                |
 |   [0.0]              |
 | ... ...              |
8|FloatTensor(1,28,28,1)|FloatTensor(1,10)                                                             |8
 |[[[[0.0]              |[[4.6705283E-13 1.194319E-11 6.325393E-11 ... 5.9661846E-12 1.0 1.941551E-10]]|
 |   [0.0]              |
 |   [0.0]              |
 |   ...                |
 |   [0.0]              |
 | ... ...              |
2|FloatTensor(1,28,28,1)|FloatTensor(1,10)                                                              |2
 |[[[[0.0]              |[[3.792658E-11 6.399531E-10 1.0 ... 1.9501381E-11 2.2754231E-12 3.9148443E-17]]|
 |   [0.0]              |
 |   [0.0]              |
 |   ...                |
 |   [0.0]              |
 | ... ...              |

......

25.4.4 Pipeline中使用TensorFlow模型


学习了如何在批式任务和流式任务中使用TensorFlow模型,我们很容易在Pipeline中使用TensorFlow模型进行预测,只要将其中的批式/流式组件对应到Pipeline组件即可。具体代码如下:

new PipelineModel(
	new VectorFunction()
		.setSelectedCol("vec")
		.setFuncName("Scale")
		.setWithVariable(1.0 / 255.0),
	new VectorToTensor()
		.setTensorDataType("float")
		.setTensorShape(1, 28, 28, 1)
		.setSelectedCol("vec")
		.setOutputCol("input_1")
		.setReservedCols("label"),
	new TFSavedModelPredictor()
		.setModelPath("https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/mnist_model_tf.zip")
		.setSelectedCols("input_1")
		.setOutputSchemaStr("output_1 FLOAT_TENSOR")
).save(Chap13.DATA_DIR + PIPELINE_TF_MODEL, true);
BatchOperator.execute();

PipelineModel
	.load(Chap13.DATA_DIR + PIPELINE_TF_MODEL)
	.transform(
		new AkSourceStreamOp()
			.setFilePath(Chap13.DATA_DIR + Chap13.DENSE_TEST_FILE)
	)
	.link(
		new UDFStreamOp()
			.setFunc(new GetMaxIndex())
			.setSelectedCols("output_1")
			.setOutputCol("pred")
	)
	.sample(0.001)
	.print();
StreamOperator.execute();

运行结果为:

label|input_1|output_1|pred
-----|-------|--------|----
8|FloatTensor(1,28,28,1)|FloatTensor(1,10)                                                                |8
 |[[[[0.0]              |[[4.595701E-8 8.691159E-10 2.010363E-6 ... 3.4370315E-10 0.999998 1.46698165E-8]]|
 |   [0.0]              |
 |   [0.0]              |
 |   ...                |
 |   [0.0]              |
 | ... ...              |
6|FloatTensor(1,28,28,1)|FloatTensor(1,10)                                                                      |6
 |[[[[0.0]              |[[1.1165078E-9 1.0032316E-9 5.1055404E-9 ... 1.7537704E-14 1.8105054E-9 1.2814901E-12]]|
 |   [0.0]              |
 |   [0.0]              |
 |   ...                |
 |   [0.0]              |
 | ... ...              |
4|FloatTensor(1,28,28,1)|FloatTensor(1,10)                                                                   |4
 |[[[[0.0]              |[[2.89732E-13 5.1105764E-9 3.7904546E-9 ... 9.956103E-10 5.4752927E-9 3.5678326E-7]]|
 |   [0.0]              |
 |   [0.0]              |
 |   ...                |
 |   [0.0]              |
 | ... ...              |

......

25.4.5 LocalPredictor中使用TensorFlow模型


除了通过Alink任务使用TensorFlow模型,也可以使用LocalPredictor进行嵌入式预测。示例代码如下,首先从数据集中抽取一行数据,输入数据的SchemaStr为“vec string, label int”;然后通过导入上一节保存的Pipeline模型,并设置输入数据的SchemaStr,得到LocalPredictor类型的实例localPredictor;如果不确定预测结果各列的含义,可以打印输出localPredictor的OutputSchema;使用localPredictor的map方法获得预测结果。

AkSourceBatchOp source = new AkSourceBatchOp()
	.setFilePath(Chap13.DATA_DIR + Chap13.DENSE_TEST_FILE);

System.out.println(source.getSchema());

Row row = source.firstN(1).collect().get(0);

LocalPredictor localPredictor
	= new LocalPredictor(Chap13.DATA_DIR + PIPELINE_TF_MODEL, "vec string, label int");

System.out.println(localPredictor.getOutputSchema());

Row r = localPredictor.map(row);
System.out.println(r.getField(0).toString() + " | " + r.getField(2).toString());

运行结果为:

root
 |-- vec: STRING
 |-- label: INT

root
 |-- label: INT
 |-- input_1: LEGACY(GenericType<com.alibaba.alink.common.linalg.tensor.FloatTensor>)
 |-- output_1: LEGACY(GenericType<com.alibaba.alink.common.linalg.tensor.FloatTensor>)

7 | [[3.1598278E-13 6.958706E-10 1.0994857E-12 ... 1.0 1.0060469E-12 1.5447695E-9]]