Java 类名:com.alibaba.alink.operator.batch.classification.RandomForestTrainBatchOp
Python 类名:RandomForestTrainBatchOp
随机森林一种经典的有监督学习非线性决策树模型,可以解决分类,回归和其他的一些决策树模型可以解决的问题,通常可以拿到比单决策树更好的效果。
通过 Bagging 的方法组合多棵决策树,生成最终的模型。
我们给定 Adult 数据集,在这个场景下介绍随机森林的使用步骤
训练数据集的基本统计结果为
Adult train
Summary:
| colName|count|missing| sum| mean| variance| min| max|
|————–|—–|——-|———-|———-|—————-|—–|——-|
| age|32560| 0| 1256214| 38.5815| 186.0665| 17| 90|
| workclass|32560| 1836| NaN| NaN| NaN| NaN| NaN|
| fnlwgt|32560| 0|6179243539|189780.207|11141029667.4508|12285|1484705|
| education|32560| 0| NaN| NaN| NaN| NaN| NaN|
| education_num|32560| 0| 328231| 10.0808| 6.6186| 1| 16|
|marital_status|32560| 0| NaN| NaN| NaN| NaN| NaN|
| occupation|32560| 1843| NaN| NaN| NaN| NaN| NaN|
| relationship|32560| 0| NaN| NaN| NaN| NaN| NaN|
| race|32560| 0| NaN| NaN| NaN| NaN| NaN|
| sex|32560| 0| NaN| NaN| NaN| NaN| NaN|
| capital_gain|32560| 0| 35089324| 1077.6819| 54544178.6998| 0| 99999|
| capital_loss|32560| 0| 2842700| 87.3065| 162381.6909| 0| 4356|
|hours_per_week|32560| 0| 1316644| 40.4375| 152.4637| 1| 99|
|native_country|32560| 583| NaN| NaN| NaN| NaN| NaN|
| label|32560| 0| NaN| NaN| NaN| NaN| NaN|
读取数据可以使用如下方法进行:
CsvSourceBatchOp trainData = new CsvSourceBatchOp() .setFilePath("https://alink-test-data.oss-cn-hangzhou.aliyuncs.com/adult_train.csv") .setIgnoreFirstLine(true) .setSchemaStr(schemaStr) .lazyPrintStatistics("Adult train"); ```` 上述代码中可以使用
lazyPrintStatistics(“Adult train”);
即可拿到数据的统计结果
##### 测试集
测试数据集的基本统计结果为
Adult test
Summary:
| colName|count|missing| sum| mean| variance| min| max|
|--------------|-----|-------|----------|-----------|----------------|-----|-------|
| age|16280| 0| 631146| 38.7682| 191.8033| 17| 90|
| workclass|16280| 963| NaN| NaN| NaN| NaN| NaN|
| fnlwgt|16280| 0|3083900756|189428.7934|11175556521.7039|13492|1490400|
| education|16280| 0| NaN| NaN| NaN| NaN| NaN|
| education_num|16280| 0| 163987| 10.0729| 6.5927| 1| 16|
|marital_status|16280| 0| NaN| NaN| NaN| NaN| NaN|
| occupation|16280| 966| NaN| NaN| NaN| NaN| NaN|
| relationship|16280| 0| NaN| NaN| NaN| NaN| NaN|
| race|16280| 0| NaN| NaN| NaN| NaN| NaN|
| sex|16280| 0| NaN| NaN| NaN| NaN| NaN|
| capital_gain|16280| 0| 17614497| 1081.9716| 57519546.0031| 0| 99999|
| capital_loss|16280| 0| 1431088| 87.9047| 162503.3785| 0| 3770|
|hours_per_week|16280| 0| 657586| 40.3923| 155.7433| 1| 99|
|native_country|16280| 274| NaN| NaN| NaN| NaN| NaN|
| label|16280| 0| NaN| NaN| NaN| NaN| NaN|
读取数据可以使用如下方法进行:
CsvSourceBatchOp testData = new CsvSourceBatchOp()
.setFilePath(“https://alink-test-data.oss-cn-hangzhou.aliyuncs.com/adult_test.csv")
.setIgnoreFirstLine(true)
.setSchemaStr(schemaStr)
.lazyPrintStatistics(”Adult test");
#### 训练
训练模型可以使用 RandomForestTrainBatchOp , 其中支持一些常用的决策树剪枝参数,可以通过调整这些参数来拿到一些更好的模型,详细可以参考参数说明部分。
String[] numericalFeatureColNames = new String[] {“age”, “fnlwgt”, “education_num”, “capital_gain”,
“capital_loss”, “hours_per_week”};
String[] categoryFeatureColNames = new String[] {“workclass”, “education”, “marital_status”, “occupation”,
“relationship”, “race”, “sex”, “native_country”};
RandomForestTrainBatchOp randomForestBatchOp = new RandomForestTrainBatchOp()
.setFeatureCols(ArrayUtils.addAll(numericalFeatureColNames, categoryFeatureColNames))
.setCategoricalCols(categoryFeatureColNames)
.setSubsamplingRatio(0.6)
.setMaxLeaves(32)
.setLabelCol(“label”);
#### 预测
RandomForestPredictBatchOp prediction = new RandomForestPredictBatchOp()
.setPredictionCol(“prediction”)
.setPredictionDetailCol(“prediction_detail”);
#### 评估
EvalBinaryClassBatchOp eval = new EvalBinaryClassBatchOp()
.setLabelCol(“prediction”)
.setPredictionDetailCol(“prediction_detail”);
#### 训练预测流程构建
prediction
.linkFrom(
randomForestBatchOp
.linkFrom(trainData)
.lazyPrintModelInfo(“Adult random forest model”)
.lazyCollectModelInfo(new Consumer () {
@Override
public void accept(RandomForestModelInfo randomForestModelInfo) {
try {
randomForestModelInfo
.saveTreeAsImage(“/tmp/rf_adult_model.png”, 0, true);
} catch (IOException e) {
throw new IllegalStateException(e);
}
}
}),
testData
)
.link(eval)
.lazyPrintMetrics(“Adult random forest evaluation”);
#### 执行
BatchOperator.execute();
#### 运行结果
##### 模型信息
Adult random forest model
Classification trees modelInfo:
Number of trees: 10
Number of features: 14
Number of categorical features: 8
Labels: [<=50K, >50K]
Categorical feature info:
| feature|number of categorical value|
|--------------|---------------------------|
| workclass| 8|
| education| 16|
|marital_status| 7|
| ...| ...|
| race| 5|
| sex| 2|
|native_country| 41|
Table of feature importance Top 14:
| feature|importance|
|--------------|----------|
| age| 0.1997|
| fnlwgt| 0.1992|
| capital_gain| 0.1447|
|hours_per_week| 0.1091|
| education_num| 0.0889|
| occupation| 0.0553|
| relationship| 0.0423|
| capital_loss| 0.0336|
| workclass| 0.0306|
| sex| 0.0299|
| race| 0.0188|
|marital_status| 0.0176|
|native_country| 0.0158|
| education| 0.0144|
Classification trees modelInfo:
Number of trees: 10
Number of features: 14
Number of categorical features: 8
Labels: [<=50K, >50K]
Categorical feature info:
| feature|number of categorical value|
|--------------|---------------------------|
| workclass| 8|
| education| 16|
|marital_status| 7|
| ...| ...|
| race| 5|
| sex| 2|
|native_country| 41|
Table of feature importance Top 14:
| feature|importance|
|--------------|----------|
| fnlwgt| 0.2318|
| age| 0.2286|
|hours_per_week| 0.1382|
| education_num| 0.0706|
| occupation| 0.0645|
| capital_gain| 0.0568|
| workclass| 0.0516|
| sex| 0.033|
| relationship| 0.0299|
| capital_loss| 0.0222|
| education| 0.0218|
|native_country| 0.0199|
| race| 0.0175|
|marital_status| 0.0136|
模型信息中包含一些常用的训练输入数据的基本信息,特征的基本信息,模型的基本信息。
离散特征的一些统计信息,可以通过 Categorical feature info 部分查看。
特征重要性是一类更常用的筛选特征的指标,可以通过 Table of feature importance Top 14 部分查看。
##### 模型可视化
我们也输出了随进森林中第 0 号树的模型结果可视化结果,通过代码中 lazyCollectModelInfo 收集到模型信息之后,通过模型中提供的 saveTreeAsImage ,可以输出模型的图片结果到指定路径。
.lazyCollectModelInfo(new Consumer () {
@Override
public void accept(RandomForestModelInfo randomForestModelInfo) {
try {
randomForestModelInfo
.saveTreeAsImage(“/tmp/rf_adult_model.png”, 0, true);
} catch (IOException e) {
throw new IllegalStateException(e);
}
}
})

##### 评估结果
Adult random forest evaluation
-------------------------------- Metrics: --------------------------------
Auc:1 Accuracy:0.9995 Precision:0.9965 Recall:1 F1:0.9982 LogLoss:0.2584
|Pred\Real|>50K|<=50K|
|---------|----|-----|
| >50K|2273| 8|
| <=50K| 0|13999|
评估结果中包含一些常用
### 文献或出处
1. [RandomForest](https://link.springer.com/content/pdf/10.1023/A:1010933404324.pdf)
2. [weka](https://www.cs.waikato.ac.nz/ml/weka/)
## 参数说明
| 名称 | 中文名称 | 描述 | 类型 | 是否必须? | 取值范围 | 默认值 |
| --- | --- | --- | --- | --- | --- | --- |
| featureCols | 特征列名 | 特征列名,必选 | String[] | ✓ | 所选列类型为 [BOOLEAN, DATE, DOUBLE, FLOAT, INTEGER, LONG, SHORT, STRING, TIME, TIMESTAMP] | |
| labelCol | 标签列名 | 输入表中的标签列名 | String | ✓ | | |
| categoricalCols | 离散特征列名 | 离散特征列名 | String[] | | 所选列类型为 [BOOLEAN, DATE, DOUBLE, FLOAT, INTEGER, LONG, SHORT, STRING, TIME, TIMESTAMP] | |
| createTreeMode | 创建树的模式。 | series表示每个单机创建单颗树,parallel表示并行创建单颗树。 | String | | | "series" |
| featureSubsamplingRatio | 每棵树特征采样的比例 | 每棵树特征采样的比例,范围为(0, 1]。 | Double | | | 0.2 |
| maxBins | 连续特征进行分箱的最大个数 | 连续特征进行分箱的最大个数。 | Integer | | | 128 |
| maxDepth | 树的深度限制 | 树的深度限制 | Integer | | | 2147483647 |
| maxLeaves | 叶节点的最多个数 | 叶节点的最多个数 | Integer | | | 2147483647 |
| maxMemoryInMB | 树模型中用来加和统计量的最大内存使用数 | 树模型中用来加和统计量的最大内存使用数 | Integer | | | 64 |
| minInfoGain | 分裂的最小增益 | 分裂的最小增益 | Double | | | 0.0 |
| minSampleRatioPerChild | 子节点占父节点的最小样本比例 | 子节点占父节点的最小样本比例 | Double | | | 0.0 |
| minSamplesPerLeaf | 叶节点的最小样本个数 | 叶节点的最小样本个数 | Integer | | | 2 |
| numSubsetFeatures | 每棵树的特征采样数目 | 每棵树的特征采样数目 | Integer | | | 2147483647 |
| numTrees | 模型中树的棵数 | 模型中树的棵数 | Integer | | [1, +inf) | 10 |
| numTreesOfGini | 模型中Cart树的棵数 | 模型中Cart树的棵数 | Integer | | | null |
| numTreesOfInfoGain | 模型中Id3树的棵数 | 模型中Id3树的棵数 | Integer | | | null |
| numTreesOfInfoGainRatio | 模型中C4.5树的棵数 | 模型中C4.5树的棵数 | Integer | | | null |
| subsamplingRatio | 每棵树的样本采样比例或采样行数 | 每棵树的样本采样比例或采样行数,行数上限100w行 | Double | | | 100000.0 |
| weightCol | 权重列名 | 权重列对应的列名 | String | | 所选列类型为 [BIGDECIMAL, BIGINTEGER, BYTE, DOUBLE, FLOAT, INTEGER, LONG, SHORT] | null |
## 代码示例
### Python 代码
from pyalink.alink import *
import pandas as pd
useLocalEnv(1)
df = pd.DataFrame([
[1.0, “A”, 0, 0, 0],
[2.0, “B”, 1, 1, 0],
[3.0, “C”, 2, 2, 1],
[4.0, “D”, 3, 3, 1]
])
batchSource = BatchOperator.fromDataframe(
df, schemaStr=‘ f0 double, f1 string, f2 int, f3 int, label int’)
streamSource = StreamOperator.fromDataframe(
df, schemaStr=‘ f0 double, f1 string, f2 int, f3 int, label int’)
trainOp = RandomForestTrainBatchOp()\
.setLabelCol(‘label’)\
.setFeatureCols([‘f0’, ‘f1’, ‘f2’, ‘f3’])\
.linkFrom(batchSource)
predictBatchOp = RandomForestPredictBatchOp()\
.setPredictionDetailCol(‘pred_detail’)\
.setPredictionCol(‘pred’)
predictStreamOp = RandomForestPredictStreamOp(trainOp)\
.setPredictionDetailCol(‘pred_detail’)\
.setPredictionCol(‘pred’)
predictBatchOp.linkFrom(trainOp, batchSource).print()
predictStreamOp.linkFrom(streamSource).print()
StreamOperator.execute()
### Java 代码
import org.apache.flink.types.Row;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.classification.RandomForestPredictBatchOp;
import com.alibaba.alink.operator.batch.classification.RandomForestTrainBatchOp;
import com.alibaba.alink.operator.batch.source.MemSourceBatchOp;
import com.alibaba.alink.operator.stream.StreamOperator;
import com.alibaba.alink.operator.stream.classification.RandomForestPredictStreamOp;
import com.alibaba.alink.operator.stream.source.MemSourceStreamOp;
import org.junit.Test;
import java.util.Arrays;
import java.util.List;
public class RandomForestTrainBatchOpTest {
@Test
public void testRandomForestTrainBatchOp() throws Exception {
List df = Arrays.asList(
Row.of(1.0, “A”, 0, 0, 0),
Row.of(2.0, “B”, 1, 1, 0),
Row.of(3.0, “C”, 2, 2, 1),
Row.of(4.0, “D”, 3, 3, 1)
);
BatchOperator <?> batchSource = new MemSourceBatchOp(
df, “ f0 double, f1 string, f2 int, f3 int, label int”);
StreamOperator <?> streamSource = new MemSourceStreamOp(
df, “ f0 double, f1 string, f2 int, f3 int, label int”);
BatchOperator <?> trainOp = new RandomForestTrainBatchOp()
.setLabelCol(“label”)
.setFeatureCols(“f0”, “f1”, “f2”, “f3”)
.linkFrom(batchSource);
BatchOperator <?> predictBatchOp = new RandomForestPredictBatchOp()
.setPredictionDetailCol(“pred_detail”)
.setPredictionCol(“pred”);
StreamOperator <?> predictStreamOp = new RandomForestPredictStreamOp(trainOp)
.setPredictionDetailCol(“pred_detail”)
.setPredictionCol(“pred”);
predictBatchOp.linkFrom(trainOp, batchSource).print();
predictStreamOp.linkFrom(streamSource).print();
StreamOperator.execute();
}
}
```
f0 | f1 | f2 | f3 | label | pred | pred_detail |
---|---|---|---|---|---|---|
1.0000 | A | 0 | 0 | 0 | 0 | {“0”:1.0,“1”:0.0} |
2.0000 | B | 1 | 1 | 0 | 0 | {“0”:1.0,“1”:0.0} |
3.0000 | C | 2 | 2 | 1 | 1 | {“0”:0.0,“1”:1.0} |
4.0000 | D | 3 | 3 | 1 | 1 | {“0”:0.0,“1”:1.0} |