This chapter covers the following sections:
23.1 Using the provided features
23.1.1 Using the Naive Bayes method
23.1.2 Using the logistic regression algorithm
23.2 How to extract features
23.3 Building more features
23.4 Model saving and prediction
23.4.1 Batch/stream prediction jobs
23.4.2 Embedded prediction
For the full discussion, please read the printed book 《Alink权威指南:基于Flink的机器学习实例入门(Java)》 (Alink: The Definitive Guide — An Introduction to Machine Learning Examples Based on Flink, Java edition); this file contains the companion example code for the chapter.
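The examples below work on the IMDB movie-review sentiment dataset (aclImdb): first with the bag-of-words features that ship with the dataset, then with features extracted from the raw review texts, and finally with a saved pipeline model that is reused for batch, stream, and embedded prediction.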
package com.alibaba.alink;
import org.apache.flink.types.Row;
import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.evaluation.EvalBinaryClassBatchOp;
import com.alibaba.alink.operator.batch.sink.AkSinkBatchOp;
import com.alibaba.alink.operator.batch.source.AkSourceBatchOp;
import com.alibaba.alink.operator.batch.source.LibSvmSourceBatchOp;
import com.alibaba.alink.operator.batch.source.MemSourceBatchOp;
import com.alibaba.alink.operator.common.evaluation.TuningBinaryClassMetric;
import com.alibaba.alink.operator.stream.StreamOperator;
import com.alibaba.alink.operator.stream.source.AkSourceStreamOp;
import com.alibaba.alink.pipeline.LocalPredictor;
import com.alibaba.alink.pipeline.Pipeline;
import com.alibaba.alink.pipeline.PipelineModel;
import com.alibaba.alink.pipeline.classification.LogisticRegression;
import com.alibaba.alink.pipeline.classification.NaiveBayesTextClassifier;
import com.alibaba.alink.pipeline.dataproc.vector.VectorAssembler;
import com.alibaba.alink.pipeline.feature.Binarizer;
import com.alibaba.alink.pipeline.nlp.DocCountVectorizer;
import com.alibaba.alink.pipeline.nlp.DocHashCountVectorizer;
import com.alibaba.alink.pipeline.nlp.NGram;
import com.alibaba.alink.pipeline.nlp.RegexTokenizer;
import com.alibaba.alink.pipeline.tuning.BinaryClassificationTuningEvaluator;
import com.alibaba.alink.pipeline.tuning.GridSearchCV;
import com.alibaba.alink.pipeline.tuning.GridSearchCVModel;
import com.alibaba.alink.pipeline.tuning.ParamGrid;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
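// Companion code for Chapter 23. The methods below roughly follow the section
// structure: c_1 uses the bag-of-words features provided with the dataset (23.1),
// c_2 extracts features from the raw review texts (23.2), c_3 builds richer
// n-gram features (23.3), and c_4 saves the pipeline model and runs batch,
// stream, and embedded predictions with it (23.4).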
public class Chap23 {
static String DATA_DIR = Utils.ROOT_DIR + "sentiment_imdb" + File.separator;
static String ORIGIN_DATA_DIR = DATA_DIR + "aclImdb" + File.separator;
static final String TRAIN_FILE = "train.ak";
static final String TEST_FILE = "test.ak";
static String PIPELINE_MODEL = "pipeline_model.ak";
private static final String TXT_COL_NAME = "review";
private static final String LABEL_COL_NAME = "label";
private static final String VECTOR_COL_NAME = "vec";
private static final String PREDICTION_COL_NAME = "pred";
private static final String PRED_DETAIL_COL_NAME = "predinfo";
static String[] COL_NAMES = new String[] {LABEL_COL_NAME, TXT_COL_NAME};
public static void main(String[] args) throws Exception {
BatchOperator.setParallelism(1);
c_1();
c_2();
BatchOperator.setParallelism(4);
c_3();
c_4();
}
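// Section 23.1: classify using labeledBow.feat, the bag-of-words features
// shipped with the IMDB dataset in LibSVM format (feature indices start at 0).
// Ratings above 5 are mapped to 'pos', the rest to 'neg'.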
static void c_1() throws Exception {
BatchOperator <?> train_set = new LibSvmSourceBatchOp()
.setFilePath(ORIGIN_DATA_DIR + "train" + File.separator + "labeledBow.feat")
.setStartIndex(0);
train_set.lazyPrint(1, "train_set");
train_set
.groupBy("label", "label, COUNT(label) AS cnt")
.orderBy("label", 100)
.lazyPrint(-1, "labels of train_set");
BatchOperator <?> test_set = new LibSvmSourceBatchOp()
.setFilePath(ORIGIN_DATA_DIR + "test" + File.separator + "labeledBow.feat")
.setStartIndex(0);
train_set = train_set.select("CASE WHEN label>5 THEN 'pos' ELSE 'neg' END AS label, "
+ "features AS " + VECTOR_COL_NAME);
test_set = test_set.select("CASE WHEN label>5 THEN 'pos' ELSE 'neg' END AS label, "
+ "features AS " + VECTOR_COL_NAME);
train_set.lazyPrint(1, "train_set");
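// Multinomial Naive Bayes directly on the word-count vectors.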
new NaiveBayesTextClassifier()
.setModelType("Multinomial")
.setVectorCol(VECTOR_COL_NAME)
.setLabelCol(LABEL_COL_NAME)
.setPredictionCol(PREDICTION_COL_NAME)
.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
.enableLazyPrintModelInfo()
.fit(train_set)
.transform(test_set)
.link(
new EvalBinaryClassBatchOp()
.setPositiveLabelValueString("pos")
.setLabelCol(LABEL_COL_NAME)
.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
.lazyPrintMetrics("NaiveBayesTextClassifier + Multinomial")
);
BatchOperator.execute();
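// Binarizer turns the counts into 0/1 indicators, as required by the
// Bernoulli variant of Naive Bayes.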
new Pipeline()
.add(
new Binarizer()
.setSelectedCol(VECTOR_COL_NAME)
.enableLazyPrintTransformData(1, "After Binarizer")
)
.add(
new NaiveBayesTextClassifier()
.setModelType("Bernoulli")
.setVectorCol(VECTOR_COL_NAME)
.setLabelCol(LABEL_COL_NAME)
.setPredictionCol(PREDICTION_COL_NAME)
.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
.enableLazyPrintModelInfo()
)
.fit(train_set)
.transform(test_set)
.link(
new EvalBinaryClassBatchOp()
.setPositiveLabelValueString("pos")
.setLabelCol(LABEL_COL_NAME)
.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
.lazyPrintMetrics("Binarizer + NaiveBayesTextClassifier + Bernoulli")
);
BatchOperator.execute();
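// Logistic regression on the same word-count vectors.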
new LogisticRegression()
.setVectorCol(VECTOR_COL_NAME)
.setLabelCol(LABEL_COL_NAME)
.setPredictionCol(PREDICTION_COL_NAME)
.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
.enableLazyPrintTrainInfo("< LR train info >")
.enableLazyPrintModelInfo("< LR model info >")
.fit(train_set)
.transform(test_set)
.link(
new EvalBinaryClassBatchOp()
.setPositiveLabelValueString("pos")
.setLabelCol(LABEL_COL_NAME)
.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
.lazyPrintMetrics("LogisticRegression")
);
BatchOperator.execute();
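// Enable progress logging, then tune maxIter by grid search with 6-fold
// cross-validation, scored by AUC; the best model is evaluated on the test set.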
AlinkGlobalConfiguration.setPrintProcessInfo(true);
LogisticRegression lr = new LogisticRegression()
.setVectorCol(VECTOR_COL_NAME)
.setLabelCol(LABEL_COL_NAME)
.setPredictionCol(PREDICTION_COL_NAME)
.setPredictionDetailCol(PRED_DETAIL_COL_NAME);
GridSearchCV gridSearch = new GridSearchCV()
.setEstimator(
new Pipeline().add(lr)
)
.setParamGrid(
new ParamGrid()
.addGrid(lr, LogisticRegression.MAX_ITER,
new Integer[] {10, 20, 30, 40, 50, 60, 80, 100})
)
.setTuningEvaluator(
new BinaryClassificationTuningEvaluator()
.setLabelCol(LABEL_COL_NAME)
.setPositiveLabelValueString("pos")
.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
.setTuningBinaryClassMetric(TuningBinaryClassMetric.AUC)
)
.setNumFolds(6)
.enableLazyPrintTrainInfo();
GridSearchCVModel bestModel = gridSearch.fit(train_set);
bestModel
.transform(test_set)
.link(
new EvalBinaryClassBatchOp()
.setPositiveLabelValueString("pos")
.setLabelCol(LABEL_COL_NAME)
.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
.lazyPrintMetrics("LogisticRegression")
);
BatchOperator.execute();
}
// Reads one review file into a single string. IMDB stores one review per file,
// so the lines are simply concatenated without a separator.
private static String readFileContent(File f) throws IOException {
    try (BufferedReader reader = new BufferedReader(new FileReader(f))) {
        StringBuilder sbd = new StringBuilder();
        String t;
        while ((t = reader.readLine()) != null) {
            sbd.append(t);
        }
        return sbd.toString();
    }
}
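// Section 23.2: extract features from the raw review texts with a
// tokenizer + document-vectorizer pipeline.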
static void c_2() throws Exception {
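    // One-off conversion: read the raw pos/neg text files and save the train
    // and test sets in Alink's .ak format.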
if (!new File(DATA_DIR + TRAIN_FILE).exists()) {
ArrayList <Row> trainRows = new ArrayList <>();
ArrayList <Row> testRows = new ArrayList <>();
for (String label : new String[] {"pos", "neg"}) {
File subfolder = new File(ORIGIN_DATA_DIR + "train" + File.separator + label);
for (File f : subfolder.listFiles()) {
trainRows.add(Row.of(label, readFileContent(f)));
}
}
for (String label : new String[] {"pos", "neg"}) {
File subfolder = new File(ORIGIN_DATA_DIR + "test" + File.separator + label);
for (File f : subfolder.listFiles()) {
testRows.add(Row.of(label, readFileContent(f)));
}
}
new MemSourceBatchOp(trainRows, COL_NAMES)
.link(
new AkSinkBatchOp()
.setFilePath(DATA_DIR + TRAIN_FILE)
);
new MemSourceBatchOp(testRows, COL_NAMES)
.link(
new AkSinkBatchOp()
.setFilePath(DATA_DIR + TEST_FILE)
);
BatchOperator.execute();
}
AkSourceBatchOp train_set = new AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_FILE);
AkSourceBatchOp test_set = new AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE);
train_set.lazyPrint(2);
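// RegexTokenizer splits each review on non-word characters; DocCountVectorizer
// then builds word-count vectors over the learned vocabulary.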
new Pipeline()
.add(
new RegexTokenizer()
.setPattern("\\W+")
.setSelectedCol(TXT_COL_NAME)
)
.add(
new DocCountVectorizer()
.setFeatureType("WORD_COUNT")
.setSelectedCol(TXT_COL_NAME)
.setOutputCol(VECTOR_COL_NAME)
.enableLazyPrintTransformData(1)
)
.add(
new LogisticRegression()
.setMaxIter(30)
.setVectorCol(VECTOR_COL_NAME)
.setLabelCol(LABEL_COL_NAME)
.setPredictionCol(PREDICTION_COL_NAME)
.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
)
.fit(train_set)
.transform(test_set)
.link(
new EvalBinaryClassBatchOp()
.setPositiveLabelValueString("pos")
.setLabelCol(LABEL_COL_NAME)
.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
.lazyPrintMetrics("DocCountVectorizer")
);
BatchOperator.execute();
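// Same pipeline, but with DocHashCountVectorizer, which hashes words into a
// fixed-size vector instead of building an explicit vocabulary.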
new Pipeline()
.add(
new RegexTokenizer()
.setPattern("\\W+")
.setSelectedCol(TXT_COL_NAME)
)
.add(
new DocHashCountVectorizer()
.setFeatureType("WORD_COUNT")
.setSelectedCol(TXT_COL_NAME)
.setOutputCol(VECTOR_COL_NAME)
.enableLazyPrintTransformData(1)
)
.add(
new LogisticRegression()
.setMaxIter(30)
.setVectorCol(VECTOR_COL_NAME)
.setLabelCol(LABEL_COL_NAME)
.setPredictionCol(PREDICTION_COL_NAME)
.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
)
.fit(train_set)
.transform(test_set)
.link(
new EvalBinaryClassBatchOp()
.setPositiveLabelValueString("pos")
.setLabelCol(LABEL_COL_NAME)
.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
.lazyPrintMetrics("DocHashCountVectorizer")
);
BatchOperator.execute();
}
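// Section 23.3: enrich the unigram features with 2-gram and 3-gram features.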
static void c_3() throws Exception {
AkSourceBatchOp train_set = new AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_FILE);
AkSourceBatchOp test_set = new AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE);
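// Unigram counts plus 2-gram counts, assembled into a single feature vector.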
new Pipeline()
.add(
new RegexTokenizer()
.setPattern("\\W+")
.setSelectedCol(TXT_COL_NAME)
)
.add(
new DocCountVectorizer()
.setFeatureType("WORD_COUNT")
.setSelectedCol(TXT_COL_NAME)
.setOutputCol(VECTOR_COL_NAME)
)
.add(
new NGram()
.setN(2)
.setSelectedCol(TXT_COL_NAME)
.setOutputCol("v_2")
.enableLazyPrintTransformData(1, "2-gram")
)
.add(
new DocCountVectorizer()
.setFeatureType("WORD_COUNT")
.setSelectedCol("v_2")
.setOutputCol("v_2")
)
.add(
new VectorAssembler()
.setSelectedCols(VECTOR_COL_NAME, "v_2")
.setOutputCol(VECTOR_COL_NAME)
)
.add(
new LogisticRegression()
.setMaxIter(30)
.setVectorCol(VECTOR_COL_NAME)
.setLabelCol(LABEL_COL_NAME)
.setPredictionCol(PREDICTION_COL_NAME)
.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
)
.fit(train_set)
.transform(test_set)
.link(
new EvalBinaryClassBatchOp()
.setPositiveLabelValueString("pos")
.setLabelCol(LABEL_COL_NAME)
.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
.lazyPrintMetrics("NGram 2")
);
BatchOperator.execute();
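// Add 3-gram counts as well; the 3-gram vocabulary is capped at 10,000
// entries to keep the assembled vector manageable.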
new Pipeline()
.add(
new RegexTokenizer()
.setPattern("\\W+")
.setSelectedCol(TXT_COL_NAME)
)
.add(
new DocCountVectorizer()
.setFeatureType("WORD_COUNT")
.setSelectedCol(TXT_COL_NAME)
.setOutputCol(VECTOR_COL_NAME)
)
.add(
new NGram()
.setN(2)
.setSelectedCol(TXT_COL_NAME)
.setOutputCol("v_2")
)
.add(
new DocCountVectorizer()
.setFeatureType("WORD_COUNT")
.setSelectedCol("v_2")
.setOutputCol("v_2")
)
.add(
new NGram()
.setN(3)
.setSelectedCol(TXT_COL_NAME)
.setOutputCol("v_3")
)
.add(
new DocCountVectorizer()
.setFeatureType("WORD_COUNT")
.setVocabSize(10000)
.setSelectedCol("v_3")
.setOutputCol("v_3")
)
.add(
new VectorAssembler()
.setSelectedCols(VECTOR_COL_NAME, "v_2", "v_3")
.setOutputCol(VECTOR_COL_NAME)
)
.add(
new LogisticRegression()
.setMaxIter(30)
.setVectorCol(VECTOR_COL_NAME)
.setLabelCol(LABEL_COL_NAME)
.setPredictionCol(PREDICTION_COL_NAME)
.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
)
.fit(train_set)
.transform(test_set)
.link(
new EvalBinaryClassBatchOp()
.setPositiveLabelValueString("pos")
.setLabelCol(LABEL_COL_NAME)
.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
.lazyPrintMetrics("NGram 2 and 3")
);
BatchOperator.execute();
}
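// Section 23.4: train and save the full pipeline model once, then reuse it
// for batch, stream, and embedded prediction.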
static void c_4() throws Exception {
AkSourceBatchOp train_set = new AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_FILE);
if (!new File(DATA_DIR + PIPELINE_MODEL).exists()) {
new Pipeline()
.add(
new RegexTokenizer()
.setPattern("\\W+")
.setSelectedCol(TXT_COL_NAME)
)
.add(
new DocCountVectorizer()
.setFeatureType("WORD_COUNT")
.setSelectedCol(TXT_COL_NAME)
.setOutputCol(VECTOR_COL_NAME)
)
.add(
new NGram()
.setN(2)
.setSelectedCol(TXT_COL_NAME)
.setOutputCol("v_2")
)
.add(
new DocCountVectorizer()
.setFeatureType("WORD_COUNT")
.setVocabSize(50000)
.setSelectedCol("v_2")
.setOutputCol("v_2")
)
.add(
new NGram()
.setN(3)
.setSelectedCol(TXT_COL_NAME)
.setOutputCol("v_3")
)
.add(
new DocCountVectorizer()
.setFeatureType("WORD_COUNT")
.setVocabSize(10000)
.setSelectedCol("v_3")
.setOutputCol("v_3")
)
.add(
new VectorAssembler()
.setSelectedCols(VECTOR_COL_NAME, "v_2", "v_3")
.setOutputCol(VECTOR_COL_NAME)
)
.add(
new LogisticRegression()
.setMaxIter(30)
.setVectorCol(VECTOR_COL_NAME)
.setLabelCol(LABEL_COL_NAME)
.setPredictionCol(PREDICTION_COL_NAME)
.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
)
.fit(train_set)
.save(DATA_DIR + PIPELINE_MODEL);
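// save() only registers the model sink; the training job actually runs
// when BatchOperator.execute() is called.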
BatchOperator.execute();
}
PipelineModel pipeline_model = PipelineModel.load(DATA_DIR + PIPELINE_MODEL);
AkSourceBatchOp test_set = new AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE);
pipeline_model
.transform(test_set)
.link(
new EvalBinaryClassBatchOp()
.setPositiveLabelValueString("pos")
.setLabelCol(LABEL_COL_NAME)
.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
.lazyPrintMetrics("NGram 2 and 3")
);
BatchOperator.execute();
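// Stream prediction (23.4.1): the same PipelineModel transforms a stream
// source; only a 0.1% sample of the predictions is printed.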
AkSourceStreamOp test_stream = new AkSourceStreamOp().setFilePath(DATA_DIR + TEST_FILE);
pipeline_model
.transform(test_stream)
.sample(0.001)
.select(PREDICTION_COL_NAME + ", " + LABEL_COL_NAME + ", " + TXT_COL_NAME)
.print();
StreamOperator.execute();
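// Embedded prediction (23.4.2): LocalPredictor applies the pipeline
// in-process to single rows, without launching a Flink job.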
String str
= "Oh dear. good cast, but to write and direct is an art and to write wit and direct wit is a bit of a "
+ "task. Even doing good comedy you have to get the timing and moment right. Im not putting it all down "
+ "there were parts where i laughed loud but that was at very few times. The main focus to me was on the "
+ "fast free flowing dialogue, that made some people in the film annoying. It may sound great while "
+ "reading the script in your head but getting that out and to the camera is a different task. And the "
+ "hand held camera work does give energy to few parts of the film. Overall direction was good but the "
+ "script was not all that to me, but I'm sure you was reading the script in your head it would sound good"
+ ". Sorry.";
Row pred_row;
LocalPredictor local_predictor = pipeline_model.collectLocalPredictor("review string");
System.out.println(local_predictor.getOutputSchema());
pred_row = local_predictor.map(Row.of(str));
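// Here field 4 of the output row is the prediction column; the full output
// schema was printed above.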
System.out.println(pred_row.getField(4));
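// A LocalPredictor can also be constructed directly from the saved model file.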
LocalPredictor local_predictor_2
= new LocalPredictor(DATA_DIR + PIPELINE_MODEL, "review string");
System.out.println(local_predictor_2.getOutputSchema());
pred_row = local_predictor_2.map(Row.of(str));
System.out.println(pred_row.getField(4));
}
}