Alink教程(Java版)

第28.3节 BERT文本分类器


BertTextClassifier为Alink BERT文本分类器的Pipeline组件为,其输入为原始的文本数据,不需要事先做分词、构造特征等操作。Pipeline的构建和训练代码如下:

        new Pipeline()
            .add(
                new Imputer()
                    .setSelectedCols("review")
                    .setStrategy("value")
                    .setFillValue("null")
            )
            .add(
                new BertTextClassifier()
                    .setTextCol("review")
                    .setLabelCol("label")
                    .setPredictionCol("pred")
                    .setPredictionDetailCol("pred_info")
                    .setBertModelName("Base-Chinese")
                    .setNumEpochs(1.0)
            )
            .fit(train_set)
            .save(DATA_DIR + "bert_pipeline_model.ak", true);

        BatchOperator.execute();

由于数据中有缺失值,在运行BERT之前,需要进行缺失值填充,填充为字符串值“null”。通过Pipeline的fit()方法,可以得到整个流程的模型(PipelineModel),并将模型保存到文件“bert_pipeline_model.ak”。

导入上面训练的PipelineModel模型后,可以调用transform方法对批式/流式数据进行预测,最后,使用 EvalBinaryClassBatchOp组件进行二分类模型评估。具体代码如下:

        PipelineModel.load(DATA_DIR + "bert_pipeline_model.ak")
            .transform(test_set)
            .lazyPrint(5)
            .link(
                new EvalBinaryClassBatchOp()
                    .setLabelCol("label")
                    .setPredictionDetailCol("pred_info")
                    .lazyPrintMetrics("BERT")
            );
        BatchOperator.execute();

运行结果为

label|review|pred|pred_info
-----|------|----|---------
0|1。房间很小,我住的是标准间A(2008年春节房价430元),根本不值,比上海市内的还要贵;2。空调的温度调节器很糟糕,搞懂该怎么调会费你许多时间,不信可以去试试,哈哈哈哈哈哈。。。不过房间还是很温暖,温度很高,建议开点窗户睡觉,以防被热醒;3。房间也不太干净,地毯脏兮兮的;4。送的早餐太简单,只能算充饥;5。下次不会再入住顺利大酒店,打算去住“莫泰崇明八一店”。|0|{"0":0.9687044806778431,"1":0.031295519322156906}
0|55555555,我住过最差的酒店之一,奉劝大家还是跑远一点吧。|0|{"0":0.9489788040518761,"1":0.05102119594812393}
0|6月16号就从网上订了房间,晚上到酒店时前台说房间了,也没找到我的订单.后来打电话给携程说已经酒店已经将订单回传过.酒店前台一直让我等待了五十分钟才给了房间.到房间后居然还有别人剩下的垃圾没有清理.而且环境还不好,外面就是马路,吵死了,反正这次弄得相当不愉快,相当差劲点|0|{"0":0.781480684876442,"1":0.21851931512355804}
0|8月4日入住了该酒店,没想到很多地方不如意,不得不缩短行程,提前返京。之前在网上查询,想住一个特色酒店,就选择此处的豪华标准间。但没想到:(1)酒店周围正在拆迁,非常脏和吵。出租车司机让我们三年以后看,说一定变样。分配给我们的房子临马路,面对一大施工场,极其嘈杂。不得以要求将其中的两间房调换。前台很不情愿给换了,告知换的房子会比这个小。当时为求安静,只好同意,但后来发现上大当了。(2)调换的两间房估计为普通标准间,屋子里墙壁和柜橱有斑驳的潮湿痕迹,其中一间屋子抽水马桶经常坏。房间很小,书桌和床之间仅容一人过。最可恶的是窗户没有纱窗,且朝过道,冒着进蚊子的威胁,开窗后发现无法通气,因过道窗户紧紧封闭,房间里一股霉味和油漆味散发不出去。更有甚者,因是围拢的房屋结构,早晚有人在过道说话很闹,且过道的灯笼整夜亮着。透过菲薄的窗帘,不得不忍受光污染。屋子的隔音效果很差,需要忍受隔壁看电视的声音。没有中央空调。(3)服务水平一般。在调换房间的第二天,讯问前台要求换到豪华标间,前台说因是自己要求调换的,不补差价,且说现房即为折后338元的价位。这可是在承德!前台服务很差,虽身着民族服装很漂亮,但态度冷冰冰。但对我之前的一老外可是鞠躬哈腰的。(4)早餐难以下咽。因周围是村庄,只能在酒店用早餐。四五个小咸菜;两种粥、豆包、包子以及奶粉冲得很稀的牛奶。即便这种早餐过了八点半还就没有了。开始打扫卫生,不管是否客人吃完。我非常非常后悔选择了这个酒店,性价比最差的。本想休息几天,但住宿实在太差,干脆一家人,勉强住了两夜就退房走人。|0|{"0":0.963587149977684,"1":0.03641285002231598}
0|三个字:脏、乱、差!房间里面看上去还可以,但是仔细看,很多地方极其脏,服务人员貌似没有培训上岗的,从前台到餐厅的服务人员都是这样,走廊上停着服务车,24小时没有换地方,四周环境乱得很,和照片完全不相符。饮食很差!|0|{"0":0.9859012337401509,"1":0.014098766259849072}
BERT
-------------------------------- Metrics: --------------------------------
Auc:0.9529	Accuracy:0.9009	Precision:0.9089	Recall:0.954	F1:0.9309	LogLoss:0.2638
|Pred\Real|  1|  0|
|---------|---|---|
|        1|519| 52|
|        0| 25|181|


使用LocalPredictor导入前面训练的PipelineModel,可以被集成到预测服务中。具体代码如下,首先是通过导入PipelineModel,构建LocalPredictor实例。

LocalPredictor localPredictor
    = new LocalPredictor(DATA_DIR + "bert_pipeline_model.ak", "review string");


使用 getOutputSchema 方法获得当前预测输出结果的 Schema,代码如下

System.out.println(localPredictor.getOutputSchema());


运行结果如下,第一列为review,预测的文本内容;第二列为pred,预测结果;第三列为pred_info,预测详情信息。

root
 |-- review: STRING
 |-- pred: BIGINT
 |-- pred_info: STRING


下面使用localPredictor进行预测,代码如下。构造了两条预测文本数据,使用localPredictor的map方法进行预测,输入数据的类型为Row,这里使用Row.of方法进行转换。

final int index_pred = TableUtil.findColIndex(localPredictor.getOutputSchema(), "pred");

String[] reviews = new String[] {
	"硬件不错,服务态度也不错,下次到附近的话还会选择住这里",
	"房间还比较干净,交通方便,离外滩很近.但外面声音太大,休息不好",
};

for (String review : reviews) {
	Object[] result = localPredictor.predict(review);
	System.out.println("Pred Result : " + result[index_pred] + " @ " + review);
}

注意:预测结果的类型为Row,即为一行数据,包含多个数据字段。可以通过TableUtil.findColIndex方法,确定预测结果“pred”列的位置索引值index_pred,从Row类型变量result中使用getField方法获取索引值为index_pred的数据字段,就是预测结果。

运行结果为

Pred Result : 1 @ 硬件不错,服务态度也不错,下次到附近的话还会选择住这里
Pred Result : 1 @ 房间还比较干净,交通方便,离外滩很近.但外面声音太大,休息不好