天天看点

Python编程实现后剪枝的CART决策树

前面实现了不进行剪枝的CART决策树和预剪枝的决策树,本文是对后剪枝的CART决策树的实现,这样关于CART决策树的东西就凑全了。

后剪枝的策略是一种“事后诸葛亮”的策略,因而效果往往要比预剪枝和不剪枝要好。主要的操作方式就是在生成一颗不剪枝的决策树之后,对每一个满足其所有子节点都为叶子结点的结点进行判断,计算如果将其子结点全部删除能不能带来决策树对测试数据分类正确率的提高,如果能则进行剪枝操作,否则,不进行剪枝。

下面是后剪枝决策树的Python代码实现,是基于不进行剪枝的CART决策树的。

from Ch04DecisionTree import TreeNode
from Ch04DecisionTree import cart
from Ch04DecisionTree import Dataset


def current_accuracy(root_node=TreeNode.TreeNode(), test_data=[], test_label=[]):
    """
    计算当前决策树在训练数据集上的正确率
    :param root_node: 决策树的根节点
    :param test_data: 测试数据集
    :param test_label: 测试数据集的label
    :return:
    """
    # root_node = tree_node
    # while not (root_node.parent is None):
    #     root_node = root_node.parent

    accuracy = 0
    for i in range(0, len(test_label)):
        this_label = cart.classify_data(root_node, test_data[i])
        if this_label == test_label[i]:
            accuracy += 1
    
    return accuracy / len(test_label)


def post_pruning(decision_tree=TreeNode.TreeNode(), test_data=[], test_label=[], train_label=[]):
    """
    对决策树进行后剪枝操作
    :param decision_tree: 决策树根节点
    :param test_data: 测试数据集
    :param test_label: 测试数据集的标签
    :param train_label: 训练数据集的标签
    :return:
    """
    leaf_father = []  # 所有的孩子都是叶结点的结点集合

    bianli_list = []
    bianli_list.append(decision_tree)
    while len(bianli_list) > 0:
        current_node = bianli_list[0]
        children = current_node.children
        wanted = True  # 判断当前结点是否满足所有的子结点都是叶子结点
        if not (children is None):
            for child in children:
                bianli_list.append(child)
                temp_bool = (child.children is None)
                wanted = (wanted and temp_bool)
        else:
            wanted = False

        if wanted:
            leaf_father.append(current_node)
        bianli_list.remove(current_node)

    while len(leaf_father) > 0:
        # 如果叶父结点为空,则剪枝完成。对于不需要进行剪枝操作的叶父结点,我们也之间将其从leaf_father中删除
        current_node = leaf_father.pop()
        # 不进行剪枝在测试集上的正确率
        before_accuracy = current_accuracy(root_node=decision_tree, test_data=test_data, test_label=test_label)

        data_index = current_node.data_index
        label_count = {}
        for index in data_index:
            if label_count.__contains__(index):
                label_count[train_label[index]] += 1
            else:
                label_count[train_label[index]] = 1
        current_node.judge = max(label_count, key=label_count.get)  # 如果进行剪枝当前结点应该做出的判断
        later_accuracy = current_accuracy(root_node=decision_tree, test_data=test_data, test_label=test_label)

        if before_accuracy > later_accuracy:  # 不进行剪枝
            current_node.judge = None
        else:  # 进行剪枝
            current_node.children = None
            # 还需要检查是否需要对它的父节点进行判断
            parent_node = current_node.parent
            if not (parent_node is None):
                children_list = parent_node.children
                temp_bool = True
                for child in children_list:
                    if not (child.children is None):
                        temp_bool = False
                        break
                if temp_bool:
                    leaf_father.append(parent_node)
    return decision_tree


def run_test():
    train_watermelon, test_watermelon, title = Dataset.watermelon2()

    # 先处理数据
    train_data = []
    test_data = []
    train_label = []
    test_label = []
    for melon in train_watermelon:
        a_dict = {}
        dim = len(melon) - 1
        for i in range(0, dim):
            a_dict[title[i]] = melon[i]
        train_data.append(a_dict)
        train_label.append(melon[dim])
    for melon in test_watermelon:
        a_dict = {}
        dim = len(melon) - 1
        for i in range(0, dim):
            a_dict[title[i]] = melon[i]
        test_data.append(a_dict)
        test_label.append(melon[dim])

    decision_tree = cart.cart_tree(train_data, title, train_label)
    decision_tree = post_pruning(decision_tree=decision_tree, test_data=test_data, test_label=test_label, train_label=train_label)

    print('剪枝之后的决策树是:')
    cart.print_tree(decision_tree)
    print('\n')

    test_judge = []
    for melon in test_data:
        test_judge.append(cart.classify_data(decision_tree, melon))
    print('决策树在测试数据集上的分类结果是:', test_judge)
    print('测试数据集的正确类别信息应该是:  ', test_label)

    accuracy = 0
    for i in range(0, len(test_label)):
        if test_label[i] == test_judge[i]:
            accuracy += 1
    accuracy /= len(test_label)
    print('决策树在测试数据集上的分类正确率为:'+str(accuracy*100)+"%")


if __name__ == '__main__':
    run_test()
           

以下是程序在西瓜数据集2.0上的运行结果。可以看出相比较于不剪枝的CART决策树,进行后剪枝之后决策树在测试集上的正确率有了明显提升。由于在根节点选择划分属性的时候,“色泽”和“脐部”的基尼指数相同,我的程序中选择了“色泽”作为划分属性,所以可能与有些版本的结果不太一样。也造成了这里后剪枝的效果相比预剪枝的效果并没有得到十分明显的提升。不过一般来讲后剪枝的效果是要比预剪枝的效果好的。另外结果显示决策树的根节点的索引号是5而不是1,这是Python的类计数导致的,在程序的执行过程中有一些临时的TreeNode变量,1~4号是给那些临时变量了,根节点是5号不是剪枝导致的,这里如果有初学者的话不要产生误解。如果发现其它问题欢迎通过QQ进行交流。

剪枝之后的决策树是:
--------------------------------------------
current index : 5;
data : [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
select attribute is : 色泽;
children : [6, 7, 8]
--------------------------------------------
--------------------------------------------
current index : 6;
parent index : 5;
色泽 : 青绿;
data : [0, 3, 5, 9];
label : 是
--------------------------------------------
--------------------------------------------
current index : 7;
parent index : 5;
色泽 : 乌黑;
data : [1, 2, 4, 7];
label : 是
--------------------------------------------
--------------------------------------------
current index : 8;
parent index : 5;
色泽 : 浅白;
data : [6, 8];
label : 否
--------------------------------------------


决策树在测试数据集上的分类结果是: ['是', '否', '是', '是', '否', '否', '是']
测试数据集的正确类别信息应该是:   ['是', '是', '是', '否', '否', '否', '否']
决策树在测试数据集上的分类正确率为:57.14285714285714%