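Java class name
Python class name: RandomForestPredictBatchOp
Predicts with a random forest model. A random forest combines multiple decision trees through bagging to produce the final model.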
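The following minimal pure-Python sketch illustrates the bagging idea. It is not Alink's implementation. It assumes textbook bootstrap sampling (drawing n rows with replacement) and uses one-feature threshold stumps in place of full decision trees. The names `fit_stump`, `bagging_fit` and `bagging_predict_detail` are hypothetical. Averaging the trees' votes yields a per-class probability map that has the same shape as the prediction detail column in the example output.

```python
import random
from collections import Counter


def fit_stump(xs, ys):
    """Fit a 1-D threshold stump that predicts 1 when x > t.

    Tries every observed x as the threshold t and keeps the one with the
    fewest training errors (ties go to the smallest t).
    """
    best_t, best_err = None, None
    for t in sorted(set(xs)):
        err = sum(int(x > t) != y for x, y in zip(xs, ys))
        if best_err is None or err < best_err:
            best_t, best_err = t, err
    return best_t


def bagging_fit(xs, ys, n_trees=10, seed=0):
    """Train n_trees stumps, each on its own bootstrap sample (assumed scheme)."""
    rng = random.Random(seed)
    n = len(xs)
    stumps = []
    for _ in range(n_trees):
        # Bootstrap: draw n row indices with replacement.
        idx = [rng.randrange(n) for _ in range(n)]
        stumps.append(fit_stump([xs[i] for i in idx], [ys[i] for i in idx]))
    return stumps


def bagging_predict_detail(stumps, x):
    """Average the stumps' 0/1 votes into per-class probabilities."""
    votes = Counter(int(x > t) for t in stumps)
    return {str(c): votes[c] / len(stumps) for c in (0, 1)}


# Toy data mirroring column f0 and the label of the example below.
model = bagging_fit([1.0, 2.0, 3.0, 4.0], [0, 0, 1, 1])
print(bagging_predict_detail(model, 1.0))  # {'0': 1.0, '1': 0.0}
```

Every stump threshold is an observed value of at least 1.0, so no stump ever votes 1 for x = 1.0.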
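Name | Display name | Description | Type | Required? | Value range | Default |
predictionCol | Prediction column name | Name of the prediction result column | String | ✓ | | |
modelFilePath | Model file path | File path of the model | String | | | null |
predictionDetailCol | Prediction detail column name | Name of the prediction detail column | String | | | |
reservedCols | Reserved columns | Names of the input columns retained in the output | String[] | | | null |
numThreads | Number of threads | Number of threads used by the component | Integer | | | 1 |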
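The following sketch shows how the output-related parameters could shape each output row. It assumes, as the parameter names suggest, that the output keeps the reservedCols columns (or every input column when reservedCols is null), then appends the prediction column and, if set, the prediction detail column. `assemble_output` is a hypothetical helper and is not part of Alink.

```python
def assemble_output(row, pred, detail, prediction_col,
                    prediction_detail_col=None, reserved_cols=None):
    """Build one output row (hypothetical helper, not Alink code).

    row: dict of input column -> value. reserved_cols=None keeps every
    input column; otherwise only the listed columns are kept, in order.
    """
    if reserved_cols is None:
        out = dict(row)
    else:
        out = {c: row[c] for c in reserved_cols}
    out[prediction_col] = pred
    if prediction_detail_col is not None:
        out[prediction_detail_col] = detail
    return out


row = {"f0": 1.0, "f1": "A", "f2": 0, "f3": 0, "label": 0}
print(assemble_output(row, 0, '{"0":1.0,"1":0.0}', "pred",
                      prediction_detail_col="pred_detail",
                      reserved_cols=["f0", "label"]))
```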
from pyalink.alink import *
import pandas as pd

# Run Alink in a local environment with parallelism 1.
useLocalEnv(1)

df = pd.DataFrame([
    [1.0, "A", 0, 0, 0],
    [2.0, "B", 1, 1, 0],
    [3.0, "C", 2, 2, 1],
    [4.0, "D", 3, 3, 1]
])

# The same data as a batch source and as a stream source.
batchSource = BatchOperator.fromDataframe(
    df, schemaStr='f0 double, f1 string, f2 int, f3 int, label int')
streamSource = StreamOperator.fromDataframe(
    df, schemaStr='f0 double, f1 string, f2 int, f3 int, label int')

# Train the random forest model on the batch data.
trainOp = RandomForestTrainBatchOp()\
    .setLabelCol('label')\
    .setFeatureCols(['f0', 'f1', 'f2', 'f3'])\
    .linkFrom(batchSource)

# Batch prediction: the model is passed as the first input of linkFrom.
predictBatchOp = RandomForestPredictBatchOp()\
    .setPredictionDetailCol('pred_detail')\
    .setPredictionCol('pred')

# Stream prediction: the model is passed to the constructor.
predictStreamOp = RandomForestPredictStreamOp(trainOp)\
    .setPredictionDetailCol('pred_detail')\
    .setPredictionCol('pred')

predictBatchOp.linkFrom(trainOp, batchSource).print()

predictStreamOp.linkFrom(streamSource).print()

StreamOperator.execute()
import org.apache.flink.types.Row;

import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.classification.RandomForestPredictBatchOp;
import com.alibaba.alink.operator.batch.classification.RandomForestTrainBatchOp;
import com.alibaba.alink.operator.batch.source.MemSourceBatchOp;
import com.alibaba.alink.operator.stream.StreamOperator;
import com.alibaba.alink.operator.stream.classification.RandomForestPredictStreamOp;
import com.alibaba.alink.operator.stream.source.MemSourceStreamOp;
import org.junit.Test;

import java.util.Arrays;
import java.util.List;

public class RandomForestPredictBatchOpTest {
    @Test
    public void testRandomForestPredictBatchOp() throws Exception {
        List<Row> df = Arrays.asList(
            Row.of(1.0, "A", 0, 0, 0),
            Row.of(2.0, "B", 1, 1, 0),
            Row.of(3.0, "C", 2, 2, 1),
            Row.of(4.0, "D", 3, 3, 1)
        );
        BatchOperator<?> batchSource = new MemSourceBatchOp(
            df, "f0 double, f1 string, f2 int, f3 int, label int");
        StreamOperator<?> streamSource = new MemSourceStreamOp(
            df, "f0 double, f1 string, f2 int, f3 int, label int");
        BatchOperator<?> trainOp = new RandomForestTrainBatchOp()
            .setLabelCol("label")
            .setFeatureCols("f0", "f1", "f2", "f3")
            .linkFrom(batchSource);
        BatchOperator<?> predictBatchOp = new RandomForestPredictBatchOp()
            .setPredictionDetailCol("pred_detail")
            .setPredictionCol("pred");
        StreamOperator<?> predictStreamOp = new RandomForestPredictStreamOp(trainOp)
            .setPredictionDetailCol("pred_detail")
            .setPredictionCol("pred");
        predictBatchOp.linkFrom(trainOp, batchSource).print();
        predictStreamOp.linkFrom(streamSource).print();
        StreamOperator.execute();
    }
}
f0 | f1 | f2 | f3 | label | pred | pred_detail |
1.0000 | A | 0 | 0 | 0 | 0 | {"0":1.0,"1":0.0} |
2.0000 | B | 1 | 1 | 0 | 0 | {"0":1.0,"1":0.0} |
3.0000 | C | 2 | 2 | 1 | 1 | {"0":0.0,"1":1.0} |
4.0000 | D | 3 | 3 | 1 | 1 | {"0":0.0,"1":1.0} |
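In the output above, pred is the class with the highest probability in pred_detail. The following minimal sketch recovers it from the JSON string, assuming that argmax relationship. Ties are broken by key order here, which may differ from Alink's own rule. The label comes back as a string because JSON object keys are strings.

```python
import json


def predict_from_detail(detail_json):
    """Return the label whose probability is highest in a pred_detail string."""
    probs = json.loads(detail_json)
    # max() keeps the first key on ties; Alink's tie-breaking may differ.
    return max(probs, key=probs.get)


for detail in ['{"0":1.0,"1":0.0}', '{"0":0.0,"1":1.0}']:
    print(detail, "->", predict_from_detail(detail))
```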
The following walks through the steps of using a random forest on the Adult dataset.
Adult train
colName | count | missing | sum | mean | variance | min | max |
age | 32560 | 0 | 1256214 | 38.5815 | 186.0665 | 17 | 90 |
workclass | 32560 | 1836 | NaN | NaN | NaN | NaN | NaN |
fnlwgt | 32560 | 0 | 6179243539 | 189780.207 | 11141029667.4508 | 12285 | 1484705 |
education | 32560 | 0 | NaN | NaN | NaN | NaN | NaN |
education_num | 32560 | 0 | 328231 | 10.0808 | 6.6186 | 1 | 16 |
marital_status | 32560 | 0 | NaN | NaN | NaN | NaN | NaN |
occupation | 32560 | 1843 | NaN | NaN | NaN | NaN | NaN |
relationship | 32560 | 0 | NaN | NaN | NaN | NaN | NaN |
race | 32560 | 0 | NaN | NaN | NaN | NaN | NaN |
sex | 32560 | 0 | NaN | NaN | NaN | NaN | NaN |
capital_gain | 32560 | 0 | 35089324 | 1077.6819 | 54544178.6998 | 0 | 99999 |
capital_loss | 32560 | 0 | 2842700 | 87.3065 | 162381.6909 | 0 | 4356 |
hours_per_week | 32560 | 0 | 1316644 | 40.4375 | 152.4637 | 1 | 99 |
native_country | 32560 | 583 | NaN | NaN | NaN | NaN | NaN |
label | 32560 | 0 | NaN | NaN | NaN | NaN | NaN |
CsvSourceBatchOp trainData = new CsvSourceBatchOp()
    .setFilePath("")  // path to the Adult training CSV
    .setIgnoreFirstLine(true)
    .setSchemaStr(schemaStr)  // schemaStr: Adult column schema, defined elsewhere (not shown)
    .lazyPrintStatistics("Adult train");
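The following small sketch makes the statistics columns concrete by computing them for one numeric column, with None marking a missing value. It assumes that count includes missing rows, as the workclass row suggests (count 32560 with 1836 missing). It also uses the sample variance with an n - 1 denominator. The source does not say whether lazyPrintStatistics divides by n or by n - 1. `summarize` is a hypothetical helper.

```python
def summarize(values):
    """Summary statistics for one numeric column; None marks a missing value."""
    present = [v for v in values if v is not None]
    n = len(present)
    total = sum(present)
    mean = total / n
    # Sample variance (n - 1 denominator): an assumption, see above.
    variance = sum((v - mean) ** 2 for v in present) / (n - 1)
    return {
        "count": len(values),          # all rows, missing included
        "missing": len(values) - n,
        "sum": total,
        "mean": mean,
        "variance": variance,
        "min": min(present),
        "max": max(present),
    }


print(summarize([1, 2, None, 3]))
# The age row is consistent with mean = sum / count (age has no missing values).
print(round(1256214 / 32560, 4))  # 38.5815
```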
Adult test
colName | count | missing | sum | mean | variance | min | max |
age | 16280 | 0 | 631146 | 38.7682 | 191.8033 | 17 | 90 |
workclass | 16280 | 963 | NaN | NaN | NaN | NaN | NaN |
fnlwgt | 16280 | 0 | 3083900756 | 189428.7934 | 11175556521.7039 | 13492 | 1490400 |
education | 16280 | 0 | NaN | NaN | NaN | NaN | NaN |
education_num | 16280 | 0 | 163987 | 10.0729 | 6.5927 | 1 | 16 |
marital_status | 16280 | 0 | NaN | NaN | NaN | NaN | NaN |
occupation | 16280 | 966 | NaN | NaN | NaN | NaN | NaN |
relationship | 16280 | 0 | NaN | NaN | NaN | NaN | NaN |
race | 16280 | 0 | NaN | NaN | NaN | NaN | NaN |
sex | 16280 | 0 | NaN | NaN | NaN | NaN | NaN |
capital_gain | 16280 | 0 | 17614497 | 1081.9716 | 57519546.0031 | 0 | 99999 |
capital_loss | 16280 | 0 | 1431088 | 87.9047 | 162503.3785 | 0 | 3770 |
hours_per_week | 16280 | 0 | 657586 | 40.3923 | 155.7433 | 1 | 99 |
native_country | 16280 | 274 | NaN | NaN | NaN | NaN | NaN |
label | 16280 | 0 | NaN | NaN | NaN | NaN | NaN |
CsvSourceBatchOp testData = new CsvSourceBatchOp()
    .setFilePath("")  // path to the Adult test CSV
    .setIgnoreFirstLine(true)
    .setSchemaStr(schemaStr)  // schemaStr: Adult column schema, defined elsewhere (not shown)
    .lazyPrintStatistics("Adult test");
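Train the model with RandomForestTrainBatchOp. It supports several common decision-tree pruning parameters, and tuning them can yield better models. See the parameter descriptions for details.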
import org.apache.commons.lang3.ArrayUtils;

String[] numericalFeatureColNames = new String[] {
    "age", "fnlwgt", "education_num", "capital_gain", "capital_loss", "hours_per_week"};
String[] categoryFeatureColNames = new String[] {
    "workclass", "education", "marital_status", "occupation",
    "relationship", "race", "sex", "native_country"};

// Features = numerical columns followed by categorical columns.
RandomForestTrainBatchOp randomForestBatchOp = new RandomForestTrainBatchOp()
    .setFeatureCols(ArrayUtils.addAll(numericalFeatureColNames, categoryFeatureColNames))
    .setCategoricalCols(categoryFeatureColNames)
    .setSubsamplingRatio(0.6)
    .setMaxLeaves(32)
    .setLabelCol("label");
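The feature list passed to setFeatureCols is the numerical names followed by the categorical names. The following sketch mirrors that concatenation in Python, with plain list `+` standing in for ArrayUtils.addAll. It confirms the counts that the model info reports: 14 features, 8 of them categorical.

```python
numerical_feature_cols = ["age", "fnlwgt", "education_num", "capital_gain",
                          "capital_loss", "hours_per_week"]
categorical_feature_cols = ["workclass", "education", "marital_status", "occupation",
                            "relationship", "race", "sex", "native_country"]

# Same order as ArrayUtils.addAll(numerical, categorical) in the Java code.
feature_cols = numerical_feature_cols + categorical_feature_cols

print(len(feature_cols), len(categorical_feature_cols))  # 14 8
# The label column must not be used as a feature.
assert "label" not in feature_cols
```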
RandomForestPredictBatchOp prediction = new RandomForestPredictBatchOp()
    .setPredictionCol("prediction")
    .setPredictionDetailCol("prediction_detail");
EvalBinaryClassBatchOp eval = new EvalBinaryClassBatchOp()
    .setLabelCol("label")  // the true label column, not the prediction column
    .setPredictionDetailCol("prediction_detail");
prediction
    .linkFrom(
        randomForestBatchOp
            .linkFrom(trainData)
            .lazyPrintModelInfo("Adult random forest model")
            .lazyCollectModelInfo(new Consumer<RandomForestModelInfo>() {
                @Override
                public void accept(RandomForestModelInfo randomForestModelInfo) {
                    try {
                        randomForestModelInfo
                            .saveTreeAsImage("/tmp/rf_adult_model.png", 0, true);
                    } catch (IOException e) {
                        throw new IllegalStateException(e);
                    }
                }
            }),
        testData
    )
    .link(eval)
    .lazyPrintMetrics("Adult random forest evaluation");
Adult random forest model
Classification trees modelInfo:
Number of trees: 10
Number of features: 14
Number of categorical features: 8
Labels: [<=50K, >50K]
Categorical feature info:
feature | number of categorical value |
workclass | 8 |
education | 16 |
marital_status | 7 |
… | … |
race | 5 |
sex | 2 |
native_country | 41 |
Table of feature importance Top 14:
feature | importance |
age | 0.1997 |
fnlwgt | 0.1992 |
capital_gain | 0.1447 |
hours_per_week | 0.1091 |
education_num | 0.0889 |
occupation | 0.0553 |
relationship | 0.0423 |
capital_loss | 0.0336 |
workclass | 0.0306 |
sex | 0.0299 |
race | 0.0188 |
marital_status | 0.0176 |
native_country | 0.0158 |
education | 0.0144 |
Classification trees modelInfo:
Number of trees: 10
Number of features: 14
Number of categorical features: 8
Labels: [<=50K, >50K]
Categorical feature info:
feature | number of categorical value |
workclass | 8 |
education | 16 |
marital_status | 7 |
… | … |
race | 5 |
sex | 2 |
native_country | 41 |
Table of feature importance Top 14:
feature | importance |
fnlwgt | 0.2318 |
age | 0.2286 |
hours_per_week | 0.1382 |
education_num | 0.0706 |
occupation | 0.0645 |
capital_gain | 0.0568 |
workclass | 0.0516 |
sex | 0.033 |
relationship | 0.0299 |
capital_loss | 0.0222 |
education | 0.0218 |
native_country | 0.0199 |
race | 0.0175 |
marital_status | 0.0136 |
Statistics on the categorical features are listed in the Categorical feature info section.
Feature importance is a widely used criterion for feature selection. See the Table of feature importance Top 14 section.
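The following sketch shows one way to select features by importance: keep the highest-ranked features until their cumulative importance reaches a threshold. The 0.9 cutoff and the helper name `select_by_importance` are hypothetical choices for illustration. The importances are copied from the first Table of feature importance above and sum to about 1.

```python
importance = {
    "age": 0.1997, "fnlwgt": 0.1992, "capital_gain": 0.1447,
    "hours_per_week": 0.1091, "education_num": 0.0889, "occupation": 0.0553,
    "relationship": 0.0423, "capital_loss": 0.0336, "workclass": 0.0306,
    "sex": 0.0299, "race": 0.0188, "marital_status": 0.0176,
    "native_country": 0.0158, "education": 0.0144,
}


def select_by_importance(importance, threshold=0.9):
    """Take features in descending importance until the running total reaches threshold."""
    selected, total = [], 0.0
    for name, value in sorted(importance.items(), key=lambda kv: kv[1], reverse=True):
        if total >= threshold:
            break
        selected.append(name)
        total += value
    return selected


print(select_by_importance(importance))
```

With these values the running total first passes 0.9 at workclass (0.9034), so nine features are kept.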
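We also visualize tree 0 of the random forest. Once lazyCollectModelInfo in the code above has collected the model info, the model info's saveTreeAsImage method writes an image of the tree to the specified path: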
.lazyCollectModelInfo(new Consumer<RandomForestModelInfo>() {
    @Override
    public void accept(RandomForestModelInfo randomForestModelInfo) {
        try {
            // Save tree 0 to /tmp/rf_adult_model.png; true overwrites an existing file.
            randomForestModelInfo
                .saveTreeAsImage("/tmp/rf_adult_model.png", 0, true);
        } catch (IOException e) {
            throw new IllegalStateException(e);
        }
    }
})
Adult random forest evaluation
-------------------------------- Metrics: --------------------------------
Auc:1 Accuracy:0.9995 Precision:0.9965 Recall:1 F1:0.9982 LogLoss:0.2584
Pred\Real | >50K | <=50K |
>50K | 2273 | 8 |
<=50K | 0 | 13999 |
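The headline metrics follow directly from the confusion matrix, taking >50K as the positive class (an assumption consistent with the numbers). The matrix gives TP = 2273, FP = 8, FN = 0 and TN = 13999. The following short sketch recomputes Accuracy, Precision, Recall and F1 from those counts.

```python
def binary_metrics(tp, fp, fn, tn):
    """Accuracy, precision, recall and F1 from confusion-matrix counts."""
    accuracy = (tp + tn) / (tp + fp + fn + tn)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    return {"Accuracy": accuracy, "Precision": precision, "Recall": recall, "F1": f1}


# Counts read off the table above: rows are predictions, columns are true labels.
metrics = binary_metrics(tp=2273, fp=8, fn=0, tn=13999)
print({k: round(v, 4) for k, v in metrics.items()})
```

The rounded values reproduce the printed Accuracy 0.9995, Precision 0.9965, Recall 1 and F1 0.9982.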