Alink教程(Python版)

第9章 朴素贝叶斯模型与决策树模型

本章包括以下各节：
9.1 朴素贝叶斯模型
9.2 决策树模型
9.2.1 决策树的分裂指标定义
9.2.2 常用的决策树算法
9.2.3 指标计算示例
9.2.4 分类树与回归树
9.2.5 经典的决策树示例
9.3 数据探索
9.4 使用朴素贝叶斯方法
9.5 蘑菇分类的决策树

详细内容请阅读纸质书《Alink权威指南：基于Flink的机器学习实例入门（Python）》，这里是本章对应的示例代码。

from pyalink.alink import *
useLocalEnv(1)

from utils import *
import os
import pandas as pd

# Location of the mushroom data: the raw file plus the train/test splits
# produced from it later in this script.
DATA_DIR = f"{ROOT_DIR}mushroom{os.sep}"

ORIGIN_FILE = "agaricus-lepiota.data"
TRAIN_FILE = "train.ak"
TEST_FILE = "test.ak"

# Column names in file order: the label column first, then the feature columns.
COL_NAMES = [
    "class",
    # cap / bruises / odor
    "cap_shape",
    "cap_surface",
    "cap_color",
    "bruises",
    "odor",
    # gill
    "gill_attachment",
    "gill_spacing",
    "gill_size",
    "gill_color",
    # stalk
    "stalk_shape",
    "stalk_root",
    "stalk_surface_above_ring",
    "stalk_surface_below_ring",
    "stalk_color_above_ring",
    "stalk_color_below_ring",
    # veil
    "veil_type",
    "veil_color",
    # ring / spore / population / habitat
    "ring_number",
    "ring_type",
    "spore_print_color",
    "population",
    "habitat",
]

# Every column is read as a string; derive the list from COL_NAMES so the
# two can never drift out of sync.
COL_TYPES = ["string"] * len(COL_NAMES)

LABEL_COL_NAME = "class"

# All columns except the label are used as features.
FEATURE_COL_NAMES = [col for col in COL_NAMES if col != LABEL_COL_NAME]

# Output columns written by the predict operators.
PREDICTION_COL_NAME = "pred"
PRED_DETAIL_COL_NAME = "predInfo"
#c_1
# Load the raw mushroom CSV using the all-string schema built from COL_NAMES/COL_TYPES.
source = (
    CsvSourceBatchOp()
    .setFilePath(DATA_DIR + ORIGIN_FILE)
    .setSchemaStr(generateSchemaString(COL_NAMES, COL_TYPES))
)

# Queue a preview of the first 5 rows; it is printed when the job executes.
source.lazyPrint(5, "< origin data >")

# Split into train/test files with ratio 0.9 (by the helper's name, only when
# the files do not exist yet — TODO confirm in utils).
splitTrainTestIfNotExist(source, DATA_DIR + TRAIN_FILE, DATA_DIR + TEST_FILE, 0.9)

# Chi-square feature selection against the label, keeping the top 3 features.
(
    AkSourceBatchOp()
    .setFilePath(DATA_DIR + TRAIN_FILE)
    .link(
        ChiSqSelectorBatchOp()
        .setSelectorType("NumTopFeatures")
        .setNumTopFeatures(3)
        .setSelectedCols(FEATURE_COL_NAMES)
        .setLabelCol(LABEL_COL_NAME)
        .lazyPrintModelInfo("< Chi-Square Selector >")
    )
)

# List the distinct values of veil_type in the training file (up to 100 rows).
(
    AkSourceBatchOp()
    .setFilePath(DATA_DIR + TRAIN_FILE)
    .select("veil_type")
    .distinct()
    .lazyPrint(100)
)

BatchOperator.execute()
#c_2_1
# Naive Bayes using every feature column, all treated as categorical.
train_data = AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_FILE)
test_data = AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE)

trainer = (
    NaiveBayesTrainBatchOp()
    .setFeatureCols(FEATURE_COL_NAMES)
    .setCategoricalCols(FEATURE_COL_NAMES)
    .setLabelCol(LABEL_COL_NAME)
)

predictor = (
    NaiveBayesPredictBatchOp()
    .setPredictionCol(PREDICTION_COL_NAME)
    .setPredictionDetailCol(PRED_DETAIL_COL_NAME)
)

# Train on the training split, then score the test split with that model.
train_data.link(trainer)
predictor.linkFrom(trainer, test_data)

trainer.lazyPrintModelInfo()


def print_model_info(naiveBayesModelInfo: NaiveBayesModelInfo):
    """Print the categorical-feature info stored in the model for a few features."""
    for name in ("odor", "spore_print_color", "gill_color"):
        print(f"feature: {name}")
        print(naiveBayesModelInfo.getCategoryFeatureInfo().get(name))


trainer.lazyCollectModelInfo(print_model_info)

predictor.lazyPrint(10, "< Prediction >")

# Binary evaluation with "p" as the positive label (presumably "poisonous" in
# this dataset's encoding — verify against the data description).
predictor.link(
    EvalBinaryClassBatchOp()
    .setPositiveLabelValueString("p")
    .setLabelCol(LABEL_COL_NAME)
    .setPredictionDetailCol(PRED_DETAIL_COL_NAME)
    .lazyPrintMetrics()
)

BatchOperator.execute()
#c_2_2
# Naive Bayes restricted to two categorical features: odor and gill_color.
train_data = AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_FILE)
test_data = AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE)

trainer = (
    NaiveBayesTrainBatchOp()
    .setFeatureCols(["odor", "gill_color"])
    .setCategoricalCols(["odor", "gill_color"])
    .setLabelCol(LABEL_COL_NAME)
)

predictor = (
    NaiveBayesPredictBatchOp()
    .setPredictionCol(PREDICTION_COL_NAME)
    .setPredictionDetailCol(PRED_DETAIL_COL_NAME)
)

# Fit on the training split and score the test split.
train_data.link(trainer)
predictor.linkFrom(trainer, test_data)


def print_model_info(naiveBayesModelInfo: NaiveBayesModelInfo):
    """Print the categorical-feature info stored in the model for both features."""
    for name in ("odor", "gill_color"):
        print(f"feature: {name}")
        print(naiveBayesModelInfo.getCategoryFeatureInfo().get(name))


trainer.lazyCollectModelInfo(print_model_info)

# Queue a 10-row prediction preview, then the binary-classification metrics
# with "p" as the positive label.
predictor.lazyPrint(10, "< Prediction >")
predictor.link(
    EvalBinaryClassBatchOp()
    .setPositiveLabelValueString("p")
    .setLabelCol(LABEL_COL_NAME)
    .setPredictionDetailCol(PRED_DETAIL_COL_NAME)
    .lazyPrintMetrics()
)

BatchOperator.execute()
#c_3_1
# A 14-row weather table: outlook, temperature, humidity, windy -> play.
df = pd.DataFrame([
    ["sunny", 85.0, 85.0, False, "no"],
    ["sunny", 80.0, 90.0, True, "no"],
    ["overcast", 83.0, 78.0, False, "yes"],
    ["rainy", 70.0, 96.0, False, "yes"],
    ["rainy", 68.0, 80.0, False, "yes"],
    ["rainy", 65.0, 70.0, True, "no"],
    ["overcast", 64.0, 65.0, True, "yes"],
    ["sunny", 72.0, 95.0, False, "no"],
    ["sunny", 69.0, 70.0, False, "yes"],
    ["rainy", 75.0, 80.0, False, "yes"],
    ["sunny", 75.0, 70.0, True, "yes"],
    ["overcast", 72.0, 90.0, True, "yes"],
    ["overcast", 81.0, 75.0, False, "yes"],
    ["rainy", 71.0, 80.0, True, "no"],
])

# The DataFrame has default integer column labels; the schema string supplies
# the Alink column names and types.
source = BatchOperator.fromDataframe(
    df,
    schemaStr="Outlook string, Temperature double, Humidity double, Windy boolean, Play string",
)

# Print every row (-1 = no row limit).
source.lazyPrint(-1)

# Train a C4.5 tree (Outlook and Windy categorical, the rest continuous),
# print the model, and save the tree as an image.
source.link(
    C45TrainBatchOp()
    .setFeatureCols(["Outlook", "Temperature", "Humidity", "Windy"])
    .setCategoricalCols(["Outlook", "Windy"])
    .setLabelCol("Play")
    .lazyPrintModelInfo()
    .lazyCollectModelInfo(
        lambda tree_info: tree_info.saveTreeAsImage(
            DATA_DIR + "weather_tree_model.png", True
        )
    )
)

BatchOperator.execute()
#c_3_2
# Train one decision tree per split criterion on the mushroom data, save each
# tree as an image, and evaluate each tree on the held-out test split.
train_data = AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_FILE)
test_data = AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE)


def _tree_image_saver(tree_type):
    """Return a model-info callback that saves the tree image for *tree_type*.

    Callbacks registered with lazyCollectModelInfo run when
    BatchOperator.execute() is called, which here is after the loop below
    has finished. A lambda that referenced the loop variable directly would
    see only its final value, so all trees would be written to the same
    file. Computing the path here binds it per tree type.
    """
    image_path = DATA_DIR + "tree_" + tree_type + ".jpg"
    return lambda decisionTreeModelInfo: decisionTreeModelInfo.saveTreeAsImage(image_path, True)


for treeType in ["GINI", "INFOGAIN", "INFOGAINRATIO"]:
    # Shared title so the model printout and the metrics can be matched up.
    title = "< " + treeType + " >"

    # link() returns the linked train operator, which serves as the model input below.
    model = train_data.link(
        DecisionTreeTrainBatchOp()
        .setTreeType(treeType)
        .setFeatureCols(FEATURE_COL_NAMES)
        .setCategoricalCols(FEATURE_COL_NAMES)
        .setLabelCol(LABEL_COL_NAME)
        .lazyPrintModelInfo(title)
        .lazyCollectModelInfo(_tree_image_saver(treeType))
    )

    predictor = (
        DecisionTreePredictBatchOp()
        .setPredictionCol(PREDICTION_COL_NAME)
        .setPredictionDetailCol(PRED_DETAIL_COL_NAME)
    )
    predictor.linkFrom(model, test_data)

    # Binary evaluation with "p" as the positive label.
    predictor.link(
        EvalBinaryClassBatchOp()
        .setPositiveLabelValueString("p")
        .setLabelCol(LABEL_COL_NAME)
        .setPredictionDetailCol(PRED_DETAIL_COL_NAME)
        .lazyPrintMetrics(title)
    )

BatchOperator.execute()