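Random Forest Training (RandomForestTrainBatchOp)

Java class name: com.alibaba.alink.operator.batch.classification.RandomForestTrainBatchOp

Python class name: RandomForestTrainBatchOp

Overview

Random forest is a classic supervised, nonlinear model built from decision trees. It handles classification, regression, and the other problems that decision tree models can solve, and it usually gives better results than a single decision tree.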

Algorithm

Multiple decision trees are combined by bagging to produce the final model.
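To make the bagging idea concrete, here is a minimal pure-Python sketch. It is not Alink's implementation: each "tree" is a hypothetical one-split stump on a single numeric feature, and the per-tree feature subsampling that a real random forest also performs is left out. The function names are illustrative only.

```python
import random
from collections import Counter


def train_stump(values, labels):
    """Fit a one-split tree on one numeric feature by trying every threshold."""
    best = None
    for threshold in sorted(set(values)):
        left = [y for x, y in zip(values, labels) if x <= threshold]
        right = [y for x, y in zip(values, labels) if x > threshold]
        if not left or not right:
            continue
        left_label = Counter(left).most_common(1)[0][0]
        right_label = Counter(right).most_common(1)[0][0]
        errors = sum(y != left_label for y in left) + sum(y != right_label for y in right)
        if best is None or errors < best[0]:
            best = (errors, threshold, left_label, right_label)
    if best is None:
        # The sample had one distinct value: always predict the majority label.
        majority = Counter(labels).most_common(1)[0][0]
        return (float("inf"), majority, majority)
    return best[1:]


def predict_stump(stump, x):
    threshold, left_label, right_label = stump
    return left_label if x <= threshold else right_label


def train_bagged(values, labels, num_trees=10, seed=0):
    """Train each stump on a bootstrap sample (n rows drawn with replacement)."""
    rng = random.Random(seed)
    n = len(values)
    forest = []
    for _ in range(num_trees):
        idx = [rng.randrange(n) for _ in range(n)]
        forest.append(train_stump([values[i] for i in idx], [labels[i] for i in idx]))
    return forest


def predict_bagged(forest, x):
    """Combine the individual trees by majority vote."""
    votes = Counter(predict_stump(stump, x) for stump in forest)
    return votes.most_common(1)[0][0]


f0 = [1.0, 2.0, 3.0, 4.0]
label = [0, 0, 1, 1]
forest = train_bagged(f0, label, num_trees=25, seed=42)
print([predict_bagged(forest, x) for x in f0])
```

Individual stumps trained on a bootstrap sample can be wrong (a sample may contain only one class); the vote across many of them is what stabilizes the prediction.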

Parameters

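| Name | Display name | Description | Type | Required? | Value range | Default |
| --- | --- | --- | --- | --- | --- | --- |
| labelCol | Label column | Name of the label column in the input table | String | Yes | | |
| categoricalCols | Categorical feature columns | Names of the categorical feature columns | String[] | | Column types: [BOOLEAN, DATE, DOUBLE, FLOAT, INTEGER, LONG, SHORT, STRING, TIME, TIMESTAMP] | |
| createTreeMode | Tree creation mode | "series": each worker builds a single tree on its own; "parallel": a single tree is built in parallel | String | | | "series" |
| featureCols | Feature columns | Names of the feature columns; all columns are selected by default | String[] | | Column types: [BOOLEAN, DATE, DOUBLE, FLOAT, INTEGER, LONG, SHORT, STRING, TIME, TIMESTAMP] | null |
| featureSubsamplingRatio | Feature sampling ratio per tree | Fraction of features sampled for each tree, in the range (0, 1] | Double | | | 0.2 |
| maxBins | Maximum number of bins for continuous features | Maximum number of bins used when binning continuous features | Integer | | | 128 |
| maxDepth | Maximum tree depth | Limit on the depth of each tree | Integer | | | 2147483647 |
| maxLeaves | Maximum number of leaves | Maximum number of leaf nodes | Integer | | | 2147483647 |
| maxMemoryInMB | Maximum memory for statistics | Maximum memory, in MB, used to accumulate statistics in the tree model | Integer | | | 64 |
| minInfoGain | Minimum split gain | Minimum gain required for a split | Double | | | 0.0 |
| minSampleRatioPerChild | Minimum child-to-parent sample ratio | Minimum ratio of a child node's samples to its parent's samples | Double | | | 0.0 |
| minSamplesPerLeaf | Minimum samples per leaf | Minimum number of samples in a leaf node | Integer | | | 2 |
| numSubsetFeatures | Number of features sampled per tree | Number of features sampled for each tree | Integer | | | 2147483647 |
| numTrees | Number of trees | Number of trees in the model | Integer | | x >= 1 | 10 |
| numTreesOfGini | Number of CART trees | Number of CART trees in the model | Integer | | | null |
| numTreesOfInfoGain | Number of ID3 trees | Number of ID3 trees in the model | Integer | | | null |
| numTreesOfInfoGainRatio | Number of C4.5 trees | Number of C4.5 trees in the model | Integer | | | null |
| subsamplingRatio | Sample ratio or row count per tree | Fraction of samples, or number of rows, drawn for each tree; the row count is capped at 1,000,000 | Double | | | 100000.0 |
| weightCol | Weight column | Name of the weight column | String | | Column types: [BIGDECIMAL, BIGINTEGER, BYTE, DOUBLE, FLOAT, INTEGER, LONG, SHORT] | null |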
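The numTreesOfGini, numTreesOfInfoGain, and numTreesOfInfoGainRatio parameters refer to CART, ID3, and C4.5 trees, which are conventionally associated with Gini impurity, information gain, and gain ratio as split criteria. The sketch below computes the textbook versions of these three scores for one candidate split; whether Alink uses exactly these formulas is an assumption, and the function names are illustrative.

```python
import math
from collections import Counter


def gini(labels):
    """Gini impurity: 1 - sum(p_k^2)."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def entropy(labels):
    """Shannon entropy in bits: -sum(p_k * log2(p_k))."""
    n = len(labels)
    return -sum((c / n) * math.log2(c / n) for c in Counter(labels).values())


def split_scores(parent, children):
    """Return (Gini decrease, information gain, gain ratio) for one split."""
    n = len(parent)
    weights = [len(child) / n for child in children]
    gini_decrease = gini(parent) - sum(w * gini(c) for w, c in zip(weights, children))
    info_gain = entropy(parent) - sum(w * entropy(c) for w, c in zip(weights, children))
    # Split information penalizes splits that scatter rows over many small children.
    split_info = -sum(w * math.log2(w) for w in weights if w > 0)
    gain_ratio = info_gain / split_info if split_info > 0 else 0.0
    return gini_decrease, info_gain, gain_ratio


# The toy labels from the code example, split perfectly into two pure children.
print(split_scores([0, 0, 1, 1], [[0, 0], [1, 1]]))
```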

Code Examples

Python code

from pyalink.alink import *

import pandas as pd

useLocalEnv(1)

df = pd.DataFrame([
    [1.0, "A", 0, 0, 0],
    [2.0, "B", 1, 1, 0],
    [3.0, "C", 2, 2, 1],
    [4.0, "D", 3, 3, 1]
])
batchSource = BatchOperator.fromDataframe(
    df, schemaStr=' f0 double, f1 string, f2 int, f3 int, label int')
streamSource = StreamOperator.fromDataframe(
    df, schemaStr=' f0 double, f1 string, f2 int, f3 int, label int')

trainOp = RandomForestTrainBatchOp()\
    .setLabelCol('label')\
    .setFeatureCols(['f0', 'f1', 'f2', 'f3'])\
    .linkFrom(batchSource)
predictBatchOp = RandomForestPredictBatchOp()\
    .setPredictionDetailCol('pred_detail')\
    .setPredictionCol('pred')
predictStreamOp = RandomForestPredictStreamOp(trainOp)\
    .setPredictionDetailCol('pred_detail')\
    .setPredictionCol('pred')

predictBatchOp.linkFrom(trainOp, batchSource).print()
predictStreamOp.linkFrom(streamSource).print()

StreamOperator.execute()

Java code

import org.apache.flink.types.Row;

import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.classification.RandomForestPredictBatchOp;
import com.alibaba.alink.operator.batch.classification.RandomForestTrainBatchOp;
import com.alibaba.alink.operator.batch.source.MemSourceBatchOp;
import com.alibaba.alink.operator.stream.StreamOperator;
import com.alibaba.alink.operator.stream.classification.RandomForestPredictStreamOp;
import com.alibaba.alink.operator.stream.source.MemSourceStreamOp;
import org.junit.Test;

import java.util.Arrays;
import java.util.List;

public class RandomForestTrainBatchOpTest {
	@Test
	public void testRandomForestTrainBatchOp() throws Exception {
		List <Row> df = Arrays.asList(
			Row.of(1.0, "A", 0, 0, 0),
			Row.of(2.0, "B", 1, 1, 0),
			Row.of(3.0, "C", 2, 2, 1),
			Row.of(4.0, "D", 3, 3, 1)
		);
		BatchOperator <?> batchSource = new MemSourceBatchOp(
			df, " f0 double, f1 string, f2 int, f3 int, label int");
		StreamOperator <?> streamSource = new MemSourceStreamOp(
			df, " f0 double, f1 string, f2 int, f3 int, label int");
		BatchOperator <?> trainOp = new RandomForestTrainBatchOp()
			.setLabelCol("label")
			.setFeatureCols("f0", "f1", "f2", "f3")
			.linkFrom(batchSource);
		BatchOperator <?> predictBatchOp = new RandomForestPredictBatchOp()
			.setPredictionDetailCol("pred_detail")
			.setPredictionCol("pred");
		StreamOperator <?> predictStreamOp = new RandomForestPredictStreamOp(trainOp)
			.setPredictionDetailCol("pred_detail")
			.setPredictionCol("pred");
		predictBatchOp.linkFrom(trainOp, batchSource).print();
		predictStreamOp.linkFrom(streamSource).print();
		StreamOperator.execute();
	}
}

Output

f0 f1 f2 f3 label pred pred_detail
1.0000 A 0 0 0 0 {"0":1.0,"1":0.0}
2.0000 B 1 1 0 0 {"0":1.0,"1":0.0}
3.0000 C 2 2 1 1 {"0":0.0,"1":1.0}
4.0000 D 3 3 1 1 {"0":0.0,"1":1.0}
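The pred_detail column is a JSON object that maps each label to its predicted probability, and pred is the label with the highest probability. A short sketch of reading it back with the Python standard library (the helper name top_label is illustrative, and the rows are copied from the output above):

```python
import json


def top_label(pred_detail):
    """Return the label (as a string key) with the highest probability."""
    probs = json.loads(pred_detail)
    return max(probs, key=probs.get)


rows = [
    (0, '{"0":1.0,"1":0.0}'),
    (0, '{"0":1.0,"1":0.0}'),
    (1, '{"0":0.0,"1":1.0}'),
    (1, '{"0":0.0,"1":1.0}'),
]

for pred, detail in rows:
    # JSON keys are strings, so compare against str(pred).
    print(pred, top_label(detail), str(pred) == top_label(detail))
```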

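Usage Walkthrough

This section walks through the steps of using random forest on the Adult dataset.

Dataset

Adult

Training set

The basic statistics of the training dataset are: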

Adult train
Summary:

colName count missing sum mean variance min max
age 32560 0 1256214 38.5815 186.0665 17 90
workclass 32560 1836 NaN NaN NaN NaN NaN
fnlwgt 32560 0 6179243539 189780.207 11141029667.4508 12285 1484705
education 32560 0 NaN NaN NaN NaN NaN
education_num 32560 0 328231 10.0808 6.6186 1 16
marital_status 32560 0 NaN NaN NaN NaN NaN
occupation 32560 1843 NaN NaN NaN NaN NaN
relationship 32560 0 NaN NaN NaN NaN NaN
race 32560 0 NaN NaN NaN NaN NaN
sex 32560 0 NaN NaN NaN NaN NaN
capital_gain 32560 0 35089324 1077.6819 54544178.6998 0 99999
capital_loss 32560 0 2842700 87.3065 162381.6909 0 4356
hours_per_week 32560 0 1316644 40.4375 152.4637 1 99
native_country 32560 583 NaN NaN NaN NaN NaN
label 32560 0 NaN NaN NaN NaN NaN

The data can be read as follows:

CsvSourceBatchOp trainData = new CsvSourceBatchOp()
	.setFilePath("https://alink-test-data.oss-cn-hangzhou.aliyuncs.com/adult_train.csv")
	.setIgnoreFirstLine(true)
	.setSchemaStr(schemaStr)
	.lazyPrintStatistics("Adult train");

In the code above, the call

lazyPrintStatistics("Adult train");

prints the summary statistics of the data.
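To clarify what the columns of the summary mean, the sketch below recomputes them for one numeric column in plain Python. It assumes that count includes missing rows (as the workclass row, with count 32560 and missing 1836, suggests) and that variance is the sample variance with divisor n - 1; neither is confirmed by this document. The function name summarize is illustrative.

```python
def summarize(values):
    """Compute count, missing, sum, mean, variance, min, max; None marks a missing value."""
    present = [v for v in values if v is not None]
    stats = {"count": len(values), "missing": len(values) - len(present)}
    if not present:
        nan = float("nan")
        stats.update(sum=nan, mean=nan, variance=nan, min=nan, max=nan)
        return stats
    total = sum(present)
    mean = total / len(present)
    # Assumption: sample variance (divisor n - 1).
    variance = (
        sum((v - mean) ** 2 for v in present) / (len(present) - 1)
        if len(present) > 1 else 0.0
    )
    stats.update(sum=total, mean=mean, variance=variance, min=min(present), max=max(present))
    return stats


print(summarize([1, 2, 3, None]))

# Consistency check against the age row of the training summary: mean = sum / count.
print(round(1256214 / 32560, 4))
```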

Test set

The basic statistics of the test dataset are:

Adult test
Summary:

colName count missing sum mean variance min max
age 16280 0 631146 38.7682 191.8033 17 90
workclass 16280 963 NaN NaN NaN NaN NaN
fnlwgt 16280 0 3083900756 189428.7934 11175556521.7039 13492 1490400
education 16280 0 NaN NaN NaN NaN NaN
education_num 16280 0 163987 10.0729 6.5927 1 16
marital_status 16280 0 NaN NaN NaN NaN NaN
occupation 16280 966 NaN NaN NaN NaN NaN
relationship 16280 0 NaN NaN NaN NaN NaN
race 16280 0 NaN NaN NaN NaN NaN
sex 16280 0 NaN NaN NaN NaN NaN
capital_gain 16280 0 17614497 1081.9716 57519546.0031 0 99999
capital_loss 16280 0 1431088 87.9047 162503.3785 0 3770
hours_per_week 16280 0 657586 40.3923 155.7433 1 99
native_country 16280 274 NaN NaN NaN NaN NaN
label 16280 0 NaN NaN NaN NaN NaN

The data can be read as follows:

CsvSourceBatchOp testData = new CsvSourceBatchOp()
	.setFilePath("https://alink-test-data.oss-cn-hangzhou.aliyuncs.com/adult_test.csv")
	.setIgnoreFirstLine(true)
	.setSchemaStr(schemaStr)
	.lazyPrintStatistics("Adult test");
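Both readers reference a schemaStr variable that is not defined in this section. The sketch below shows one way such a schema string could be assembled. The column names and their order are taken from the summary tables above; the type keywords ("bigint" for the numeric columns, "string" for the rest) are assumptions inferred from the integer-valued sums, not something this document states.

```python
# Columns in the order they appear in the summary tables.
COLUMN_ORDER = [
    "age", "workclass", "fnlwgt", "education", "education_num",
    "marital_status", "occupation", "relationship", "race", "sex",
    "capital_gain", "capital_loss", "hours_per_week", "native_country", "label",
]

# Columns that have numeric statistics in the summary tables.
NUMERIC_COLS = {
    "age", "fnlwgt", "education_num",
    "capital_gain", "capital_loss", "hours_per_week",
}


def build_schema_str(order, numeric, numeric_type="bigint", other_type="string"):
    """Join 'name type' pairs with commas; the type names are assumed."""
    return ", ".join(
        f"{name} {numeric_type if name in numeric else other_type}" for name in order
    )


schema_str = build_schema_str(COLUMN_ORDER, NUMERIC_COLS)
print(schema_str)
```

The resulting string would be passed as schemaStr in the Java snippets; adjust the types if the CSV files contain non-integer values.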

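Training

The model is trained with RandomForestTrainBatchOp, which supports a number of common decision tree pruning parameters. Tuning them can yield a better model; see the Parameters section for details.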

String[] numericalFeatureColNames = new String[] {"age", "fnlwgt", "education_num", "capital_gain",
	"capital_loss", "hours_per_week"};

String[] categoryFeatureColNames = new String[] {"workclass", "education", "marital_status", "occupation",
	"relationship", "race", "sex", "native_country"};

RandomForestTrainBatchOp randomForestBatchOp = new RandomForestTrainBatchOp()
	.setFeatureCols(ArrayUtils.addAll(numericalFeatureColNames, categoryFeatureColNames))
	.setCategoricalCols(categoryFeatureColNames)
	.setSubsamplingRatio(0.6)
	.setMaxLeaves(32)
	.setLabelCol("label");
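The pruning parameters act as stopping rules while a tree grows. The sketch below shows the conventional reading of maxDepth, maxLeaves, minSamplesPerLeaf, minSampleRatioPerChild, and minInfoGain as a single check on a candidate split. How Alink applies these thresholds internally is not described in this document, so the comparison operators here (for example, gain >= min_info_gain) are assumptions, and can_split is a hypothetical helper.

```python
INT_MAX = 2147483647  # default for maxDepth and maxLeaves in the parameter table


def can_split(depth, num_leaves, parent_count, left_count, right_count, gain,
              max_depth=INT_MAX, max_leaves=INT_MAX, min_samples_per_leaf=2,
              min_sample_ratio_per_child=0.0, min_info_gain=0.0):
    """Return True if a candidate split passes every pre-pruning threshold."""
    if depth >= max_depth:
        return False
    # Splitting a leaf replaces one leaf with two, so the leaf count grows by one.
    if num_leaves + 1 > max_leaves:
        return False
    for child_count in (left_count, right_count):
        if child_count < min_samples_per_leaf:
            return False
        if child_count / parent_count < min_sample_ratio_per_child:
            return False
    return gain >= min_info_gain


# With maxLeaves = 32, as in the training code above:
print(can_split(3, 31, 100, 60, 40, gain=0.1, max_leaves=32))  # 32nd leaf is allowed
print(can_split(3, 32, 100, 60, 40, gain=0.1, max_leaves=32))  # a 33rd leaf is not
```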

Prediction

RandomForestPredictBatchOp prediction = new RandomForestPredictBatchOp()
	.setPredictionCol("prediction")
	.setPredictionDetailCol("prediction_detail");

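Evaluation

EvalBinaryClassBatchOp eval = new EvalBinaryClassBatchOp()
	// Evaluate against the ground-truth column "label"; passing "prediction" here
	// would compare the model's output with itself.
	.setLabelCol("label")
	.setPredictionDetailCol("prediction_detail");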

Building the training and prediction pipeline

prediction
	.linkFrom(
		randomForestBatchOp
			.linkFrom(trainData)
			.lazyPrintModelInfo("Adult random forest model")
			.lazyCollectModelInfo(new Consumer <RandomForestModelInfo>() {
				@Override
				public void accept(RandomForestModelInfo randomForestModelInfo) {
					try {
						randomForestModelInfo
							.saveTreeAsImage("/tmp/rf_adult_model.png", 0, true);
					} catch (IOException e) {
						throw new IllegalStateException(e);
					}
				}
			}),
		testData
	)
	.link(eval)
	.lazyPrintMetrics("Adult random forest evaluation");

Execution

BatchOperator.execute();

Results

Model information

Adult random forest model
Classification trees modelInfo:
Number of trees: 10
Number of features: 14
Number of categorical features: 8
Labels: [<=50K, >50K]

Categorical feature info:

feature number of categorical value
workclass 8
education 16
marital_status 7
race 5
sex 2
native_country 41

Table of feature importance Top 14:

feature importance
age 0.1997
fnlwgt 0.1992
capital_gain 0.1447
hours_per_week 0.1091
education_num 0.0889
occupation 0.0553
relationship 0.0423
capital_loss 0.0336
workclass 0.0306
sex 0.0299
race 0.0188
marital_status 0.0176
native_country 0.0158
education 0.0144

Classification trees modelInfo:
Number of trees: 10
Number of features: 14
Number of categorical features: 8
Labels: [<=50K, >50K]

Categorical feature info:

feature number of categorical value
workclass 8
education 16
marital_status 7
race 5
sex 2
native_country 41

Table of feature importance Top 14:

feature importance
fnlwgt 0.2318
age 0.2286
hours_per_week 0.1382
education_num 0.0706
occupation 0.0645
capital_gain 0.0568
workclass 0.0516
sex 0.033
relationship 0.0299
capital_loss 0.0222
education 0.0218
native_country 0.0199
race 0.0175
marital_status 0.0136

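The model information covers basic facts about the training input data, the features, and the model itself.

Statistics for the categorical features are shown in the Categorical feature info section.

Feature importance is a more commonly used indicator for feature selection; it is shown in the Table of feature importance Top 14 section.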
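As an illustration of using the importance table for feature selection, the sketch below keeps the highest-ranked features until their cumulative importance reaches a threshold. The values are copied from the first importance table above; the 0.8 threshold and the helper name are arbitrary choices for the example, not a recommendation from Alink.

```python
IMPORTANCE = [
    ("age", 0.1997), ("fnlwgt", 0.1992), ("capital_gain", 0.1447),
    ("hours_per_week", 0.1091), ("education_num", 0.0889), ("occupation", 0.0553),
    ("relationship", 0.0423), ("capital_loss", 0.0336), ("workclass", 0.0306),
    ("sex", 0.0299), ("race", 0.0188), ("marital_status", 0.0176),
    ("native_country", 0.0158), ("education", 0.0144),
]


def select_by_cumulative_importance(importance, threshold):
    """Keep top-ranked features until their summed importance reaches the threshold."""
    ranked = sorted(importance, key=lambda item: item[1], reverse=True)
    selected, total = [], 0.0
    for name, score in ranked:
        if total >= threshold:
            break
        selected.append(name)
        total += score
    return selected


print(select_by_cumulative_importance(IMPORTANCE, 0.8))
```

The selected names could then be passed to setFeatureCols for a retraining run with a smaller feature set.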

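Model visualization

The example also renders tree 0 of the random forest as an image. After lazyCollectModelInfo collects the model information, the saveTreeAsImage method it provides writes a picture of the tree to the given path.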

.lazyCollectModelInfo(new Consumer <RandomForestModelInfo>() {
	@Override
	public void accept(RandomForestModelInfo randomForestModelInfo) {
		try {
			randomForestModelInfo
				.saveTreeAsImage("/tmp/rf_adult_model.png", 0, true);
		} catch (IOException e) {
			throw new IllegalStateException(e);
		}
	}
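})

Evaluation results

Note: the perfect Auc and the class counts below suggest that these figures were produced with the model's own prediction column used as the evaluator's label column. If so, they measure how consistent the predictions are with their detail scores, not accuracy against the true labels.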

Adult random forest evaluation
-------------------------------- Metrics: --------------------------------
Auc:1 Accuracy:0.9995 Precision:0.9965 Recall:1 F1:0.9982 LogLoss:0.2584

Pred\Real >50K <=50K
>50K 2273 8
<=50K 0 13999
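The Accuracy, Precision, Recall, and F1 values follow directly from the confusion matrix, with >50K as the positive class. The sketch below recomputes them from the four cells using the standard definitions; binary_metrics is an illustrative helper, not an Alink API.

```python
def binary_metrics(tp, fp, fn, tn):
    """Standard binary classification metrics from confusion-matrix counts."""
    total = tp + fp + fn + tn
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {
        "accuracy": (tp + tn) / total,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }


# Rows of the matrix are predictions, columns are the reference labels:
# Pred >50K  / Real >50K  = 2273 (TP), Pred >50K  / Real <=50K = 8     (FP)
# Pred <=50K / Real >50K  = 0    (FN), Pred <=50K / Real <=50K = 13999 (TN)
metrics = binary_metrics(tp=2273, fp=8, fn=0, tn=13999)
print({name: round(value, 4) for name, value in metrics.items()})
# {'accuracy': 0.9995, 'precision': 0.9965, 'recall': 1.0, 'f1': 0.9982}
```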

References

  1. RandomForest
  2. weka