Java 类名:com.alibaba.alink.operator.batch.evaluation.EvalClusterBatchOp
Python 类名:EvalClusterBatchOp
对聚类算法的预测结果进行效果评估。
在下面的指标中,用 $C_i$ 表示第 $i$ 个簇, $u_i$ 表示$C_i$ 的中心,$k$ 表示簇的总数。
$\overline{CP_i}=\dfrac{1}{|C_i|}\sum_{x \in C_i}|x_i-u_i|$
$\overline{CP}=\dfrac{1}{k}\sum_{i=1}^{k}\overline{CP_k}$
CP越低意味着类内聚类距离越近。
$SP=\dfrac{2}{k^2-k}\sum_{i=1}^{k}\sum_{j=i+1}^{k}|u_i-u_j|$
SP越高意味类间聚类距离越远。
$DB=\dfrac{1}{k}\sum_{i=1}^{k}\max_{i \not=j}(\dfrac{\overline{CP_i}+\overline{CP_j}}{|u_i-u_j|})$
DB越小意味着类内距离越小,同时类间距离越大。
用 $u$ 表示数据集整体的中心点,
$SSB=\sum_{i=1}^{k}n_i|u_i-u|^2,$
$SSW=\sum_{i=1}^{k}\sum_{x \in C_i}|x_i-u_i|,$
$\mathrm{VRC}=\dfrac{SSB}{SSW}*\dfrac{N-k}{k-1}$
VRC越大意味着聚类质量越好。
另有一部分聚类评价指标称作外部指标(external criterion)。
这些指标在评估时除了有每个样本点所属的簇之外,还假设每个样本点有类别标签。
在下面的指标中,$N$表示簇的总数,用 $\omega_k$ 表示簇$k$所包含的样本点集合,$c_j$ 表示类别标签为 $j$ 的样本点集合。
$Purity(\Omega, C)=\dfrac{1}{N}\sum_{k}\max_{j}|\omega_k \cap c_j|$
取值在 $[0,1]$ 区间内,越接近1表示同一个簇内相同类别的数据点越多,聚类结果越好。
$H(\Omega)=-\sum_{k}\dfrac{\omega_k}{N}log\dfrac{\omega_k}{N},$
$H(C)=-\sum_{j}\dfrac{c_j}{N}log\dfrac{c_j}{N},$
$I(\Omega, C)=\sum_k\sum_j\dfrac{|\omega_k \cap c_j|}{N}\log \dfrac{N|\omega_k \cap c_j|}{|\omega_k||c_j|},$
$\mathrm{NMI}=\dfrac{2 * I(\Omega, C)}{H(\Omega) + H(C)}$
取值在 $[0,1]$ 区间内, 越接近1表示聚类结果越好。
对于任意一对样本点:
- 如果标签相同并且属于相同的簇,则认为是 TP (True Positive);
- 如果标签不同并且属于不同的簇,则认为是 TN (True Negative);
- 如果标签相同并且属于不同的簇,则认为是 FN (False Negative);
- 如果标签不同并且属于相同的簇,则认为是 FP (False Positive)。
用 $TP, TN, FN, FP$ 分别表示属于各自类别的样本点对的个数,$N(k,j)$ 表示簇 $k$ 内类别为 $j$ 的样本点个数,那么有:
$TP+FP=\sum_{j}\binom{c_j}{2},$
$TP+FN=\sum_{k}\binom{\omega_k}{2},$
$TP=\sum_{k}\sum_{j}\binom{N(k,j)}{2},$
$TP+TN+FP+FN= \binom{N}{2},$
$RI=\dfrac{TP+TN}{TP+TN+FP+FN}$
取值在 $[0,1]$ 区间内,越接近1表示聚类结果越好。
$\mathrm{Index}=TP,$
$\mathrm{ExpectedIndex}=\dfrac{(TP+FP)(TP+FN)}{TP+TN+FP+FN},$
$\mathrm{MaxIndex}=\dfrac{TP+FP+TP+FN}{2},$
$ARI=\dfrac{\mathrm{Index} - \mathrm{ExpectedIndex}}{\mathrm{MaxIndex} - \mathrm{ExpectedIndex}}$
E
取值在 $[-1,1]$ 区间内,越接近1表示聚类结果越好。
该组件通常接聚类算法的输出端。
使用时,需要通过 predictionCol
指定预测结果类。 通常还需要通过 vectorCol
指定样本点的坐标,这样才能计算评估指标。否则,只能输出样本点所属簇等基本信息。
另外,可以根据需要指定标签列 labelCol
,这样可以计算外部指标。
名称 | 中文名称 | 描述 | 类型 | 是否必须? | 取值范围 | 默认值 |
---|---|---|---|---|---|---|
predictionCol | 预测结果列名 | 预测结果列名 | String | ✓ | ||
distanceType | 距离度量方式 | 距离类型 | String | “EUCLIDEAN”, “COSINE”, “CITYBLOCK” | “EUCLIDEAN” | |
labelCol | 标签列名 | 输入表中的标签列名 | String | null | ||
vectorCol | 向量列名 | 输入表中的向量列名 | String | 所选列类型为 [DENSE_VECTOR, SPARSE_VECTOR, STRING, VECTOR] | null |
from pyalink.alink import * import pandas as pd useLocalEnv(1) df = pd.DataFrame([ [0, "0 0 0"], [0, "0.1,0.1,0.1"], [0, "0.2,0.2,0.2"], [1, "9 9 9"], [1, "9.1 9.1 9.1"], [1, "9.2 9.2 9.2"] ]) inOp = BatchOperator.fromDataframe(df, schemaStr='id int, vec string') metrics = EvalClusterBatchOp().setVectorCol("vec").setPredictionCol("id").linkFrom(inOp).collectMetrics() print("Total Samples Number:", metrics.getCount()) print("Cluster Number:", metrics.getK()) print("Cluster Array:", metrics.getClusterArray()) print("Cluster Count Array:", metrics.getCountArray()) print("CP:", metrics.getCp()) print("DB:", metrics.getDb()) print("SP:", metrics.getSp()) print("SSB:", metrics.getSsb()) print("SSW:", metrics.getSsw()) print("CH:", metrics.getVrc())
import org.apache.flink.types.Row; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.evaluation.EvalClusterBatchOp; import com.alibaba.alink.operator.batch.source.MemSourceBatchOp; import com.alibaba.alink.operator.common.evaluation.ClusterMetrics; import org.junit.Test; import java.util.Arrays; import java.util.List; public class EvalClusterBatchOpTest { @Test public void testEvalClusterBatchOp() throws Exception { List <Row> df = Arrays.asList( Row.of(0, "0 0 0"), Row.of(0, "0.1,0.1,0.1"), Row.of(0, "0.2,0.2,0.2"), Row.of(1, "9 9 9"), Row.of(1, "9.1 9.1 9.1"), Row.of(1, "9.2 9.2 9.2") ); BatchOperator <?> inOp = new MemSourceBatchOp(df, "id int, vec string"); ClusterMetrics metrics = new EvalClusterBatchOp().setVectorCol("vec").setPredictionCol("id").linkFrom(inOp) .collectMetrics(); System.out.println("Total Samples Number:" + metrics.getCount()); System.out.println("Cluster Number:" + metrics.getK()); System.out.println("Cluster Array:" + Arrays.toString(metrics.getClusterArray())); System.out.println("Cluster Count Array:" + Arrays.toString(metrics.getCountArray())); System.out.println("CP:" + metrics.getCp()); System.out.println("DB:" + metrics.getDb()); System.out.println("SP:" + metrics.getSp()); System.out.println("SSB:" + metrics.getSsb()); System.out.println("SSW:" + metrics.getSsw()); System.out.println("CH:" + metrics.getVrc()); } }
Total Samples Number: 6
Cluster Number: 2
Cluster Array: ['0', '1']
Cluster Count Array: [3.0, 3.0]
CP: 0.11547005383792497
DB: 0.014814814814814791
SP: 15.588457268119896
SSB: 364.5
SSW: 0.1199999999999996
CH: 12150.000000000042