
[Machine Learning] Zhou Zhihua's watermelon book: a Python implementation of a decision tree that selects splits by information entropy


Python: implementing a decision tree that selects splits by information entropy

This post presents my Python implementation of a decision tree that uses information entropy for split selection, following Chapter 4 (Decision Trees) of the watermelon book. Note that the code covers both discrete and continuous attributes, but does not handle missing values or pruning.
First, some of the key theory from the book:

1. The basic decision tree learning algorithm

[Figure: pseudocode of the basic decision tree learning algorithm (watermelon book, Fig. 4.2)]
Clearly, decision tree learning is a recursive algorithm, and the most important part of a recursion is its return conditions. Three cases produce a return here (a runnable sketch of the resulting skeleton follows the list):

  1. All samples at the current node belong to the same class, so no split is needed (i.e. all good melons or all bad melons).
  2. The attribute set is empty, or all samples take the same values on every attribute, so no split is possible.
  3. The sample set at the current node is empty, so no split can be made.
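
A minimal, self-contained sketch of that skeleton (illustrative only: the row format, helper names and the stubbed-out split choice are mine, not the code below):

from collections import Counter

def majority(labels):
    return Counter(labels).most_common(1)[0][0]

# rows: list of (attribute_dict, label); attrs: list of attribute names
def build(rows, attrs, parent_majority):
    if not rows:                                            # case 3: empty branch
        return parent_majority
    labels = [label for _, label in rows]
    if len(set(labels)) == 1:                               # case 1: only one class left
        return labels[0]
    if not attrs or all(r == rows[0][0] for r, _ in rows):  # case 2: nothing to split on
        return majority(labels)
    best = attrs[0]  # placeholder: the real code picks the attribute with maximum information gain
    return {best + '=' + v: build([row for row in rows if row[0][best] == v],
                                  [a for a in attrs if a != best],
                                  majority(labels))
            for v in {row[0][best] for row in rows}}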

2. Information entropy

One step of the algorithm above is crucial: line 8, choosing the optimal splitting attribute a* from A. On what basis do we choose it? Here we use the most basic criterion, information entropy.

1. Definition of information entropy

$$\mathrm{Ent}(D) = -\sum_{k=1}^{|\mathcal{Y}|} p_k \log_2 p_k$$
$|\mathcal{Y}|$ is the number of classes; e.g. with good and bad melons, $|\mathcal{Y}| = 2$.
$p_k$ is the proportion of the k-th class among all samples, e.g. (# good melons) / (# good melons + # bad melons).
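
As a quick sanity check, the book's 17-sample watermelon dataset has 8 good and 9 bad melons, giving Ent(D) ≈ 0.998:

import math

def ent(counts):
    # information entropy of a class distribution given as raw counts
    total = sum(counts)
    return -sum(c / total * math.log2(c / total) for c in counts if c > 0)

print(ent([8, 9]))  # 0.9975..., the 0.998 reported in the book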

2. Definition of information gain

(1) Discrete attributes:
$$\mathrm{Gain}(D,a) = \mathrm{Ent}(D) - \sum_{v=1}^{V} \frac{|D^v|}{|D|}\,\mathrm{Ent}(D^v)$$
Information gain measures the purity improvement obtained by partitioning the sample set on some attribute (e.g. 纹理, texture): the larger the gain, the larger the improvement in purity.
We choose the attribute with the largest information gain as the optimal split; if several attributes tie, any of them will do.
Here $v$ ranges over the values of attribute $a$; e.g. the attribute 色泽 (color) takes the three values 乌黑, 青绿 and 浅白. To evaluate the formula, count the good and bad melons under each value separately.
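
For example (the book's Example 4.1): splitting the 17 samples on 色泽 gives subsets 青绿 (3 good / 3 bad), 乌黑 (4 good / 2 bad) and 浅白 (1 good / 4 bad), so the gain works out to about 0.109. Reusing the ent() helper defined above:

def gain(class_counts, subsets):
    # Ent(D) minus the size-weighted entropy of the subsets D^v
    n = sum(class_counts)
    return ent(class_counts) - sum(sum(s) / n * ent(s) for s in subsets)

print(gain([8, 9], [[3, 3], [4, 2], [1, 4]]))  # ≈ 0.109 for 色泽
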
(2) Continuous attributes:
When an attribute takes continuous values (e.g. 密度, density), the formula above needs a small modification. We bi-partition the continuous range and search for the best split point: sort the attribute's $n$ values in ascending order, form the $n-1$ midpoints of adjacent values, evaluate the information gain at each candidate threshold, and keep the midpoint with the largest gain.
$$T_a = \left\{ \frac{a^i + a^{i+1}}{2} \;\middle|\; 1 \le i \le n-1 \right\}$$

$$\mathrm{Gain}(D,a) = \max_{t \in T_a} \mathrm{Gain}(D,a,t) = \max_{t \in T_a}\left[ \mathrm{Ent}(D) - \sum_{\lambda \in \{-,+\}} \frac{|D_t^{\lambda}|}{|D|}\,\mathrm{Ent}(D_t^{\lambda}) \right]$$

where $D_t^-$ and $D_t^+$ are the samples whose value on attribute $a$ is $\le t$ and $> t$ respectively.
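
A compact sketch of this bi-partition search (again reusing ent() from above; the function name and argument format are mine):

def best_split(values, labels):
    # try every midpoint of adjacent sorted values; return (best gain, best threshold)
    pairs = sorted(zip(values, labels))
    n = len(pairs)
    base = ent([labels.count(c) for c in set(labels)])
    best_gain, best_t = 0.0, None
    for i in range(n - 1):
        t = (pairs[i][0] + pairs[i + 1][0]) / 2
        left = [c for v, c in pairs if v <= t]
        right = [c for v, c in pairs if v > t]
        if not left or not right:  # can happen with duplicate values
            continue
        weighted = sum(len(side) / n * ent([side.count(c) for c in set(side)])
                       for side in (left, right))
        if base - weighted > best_gain:
            best_gain, best_t = base - weighted, t
    return best_gain, best_t

On the 密度 column of the book's dataset this search should reproduce the book's reported result: gain ≈ 0.262 at threshold 0.381.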

3. Code implementation

The complete code first (Python):

import math

class Attribute():
    def __init__(self,name,id,iscon=0):
        self.name = name
        self.kids = []
        self.id = id
        self.iscon = iscon # 1: discrete (text values), 0: continuous (numeric); note the inverted naming

# count how many samples take each value (kid) of the attribute at column 'index'; for a continuous attribute (iscon=0), count samples <= T ('less') and > T ('more')
def count_sample(SampleArray,index,iscon,T=0):
    attribute = {}
    if len(SampleArray)==0:
        return -1 #Sample is NULL
    if iscon == 1:
        for sample in SampleArray:
            samples = sample.split(',')
            if samples[index] not in attribute:
                attribute[samples[index]]=1
            else:
                attribute[samples[index]]+=1
    else:
        for sample in SampleArray:
            samples = sample.split(',')
            if float(samples[index])<=T:
                if 'less' not in attribute:
                    attribute['less'] = 1
                else:
                    attribute['less'] += 1
            else:
                if 'more' not in attribute:
                    attribute['more'] = 1
                else:
                    attribute['more'] += 1
    return attribute
#count, for each value of an attribute, how many samples of each class it contains
def count_attribute(SampleArray,index,iscon,T=0):
    attribute = {}
    if len(SampleArray) == 0:
        return -1  # Sample is NULL
    if iscon==1 :# discrete
        for sample in SampleArray:
            samples = sample.split(',')
            if str(samples[index]+samples[-1]) not in attribute:
                attribute[samples[index]+samples[-1]] = 1
            else:
                attribute[samples[index]+samples[-1]] += 1
    else:# continuous
        for sample in SampleArray:
            samples = sample.split(',')
            if float(samples[index]) <= T:
                if str('less'+samples[-1]) not in attribute.keys(): # bug fix: samples[-1] is the class field; sample[-1] was just the last character of the row
                    attribute['less'+samples[-1]] = 1
                else:
                    attribute['less'+samples[-1]] += 1
            else:
                if str('more'+samples[-1]) not in attribute.keys():
                    attribute['more'+samples[-1]] = 1
                else:
                    attribute['more'+samples[-1]] += 1
    return attribute


def read_file(file_name,SampleArray,AttributeArray):
    with open(file_name) as f:
        contents = f.readline().rstrip('\n')
        flag = 0
        index = -1
        if "编号" in contents:
            flag = 1 # the first column is a running sample id and must be skipped
            index = contents.find(',')
            attributes = contents[index+1:].split(',')
        else:
            attributes = contents.split(',')
        id = 0
        for a in attributes:
            att = Attribute(a, id)
            id += 1
            AttributeArray.append(att)  # record the attribute
        for contents in f:
            contents = contents.rstrip('\n') # strip the newline whether or not the line ends with one
            if not contents:
                continue # skip blank lines
            if flag == 1:
                index = contents.find(',')
            per_att = contents[index+1:].split(',')
            for i in range(len(AttributeArray)):
                if per_att[i] not in AttributeArray[i].kids:
                    AttributeArray[i].kids.append(per_att[i])
                    if per_att[i].isalnum(): # text values pass isalnum(); '0.697' does not, so that attribute stays continuous
                        AttributeArray[i].iscon = 1
            SampleArray.append(contents[index+1:])
    max_mark = count_sample(SampleArray,-1,1)
    max_class = max(max_mark,key=max_mark.get) # the majority class over all samples
    return max_class

#find the best attribute for the node
def find_attribute(SampleArray,AttributeArray):
    entropy_D = 0
    entropy_Dv = 0
    entropy_Dv_total = 0
    max_index = 0
    max_gain = 0
    den = 0
    gains = []
    max_con_middle = 0  # threshold of the best continuous attribute (if any)
    classes = count_sample(SampleArray, -1,1)
    total_nums = sum(classes.values())
    for value in classes.values():
        p = value / total_nums
        entropy_D += p*math.log(p,2)
    entropy_D = -(entropy_D)

    for index in range(len(AttributeArray)-1):# skip the last column, which is the class
        if AttributeArray[index].iscon == 1:# discrete
            total_kids = count_sample(SampleArray,index,1)
            per_kid = count_attribute(SampleArray,index,1)
            for kid in AttributeArray[index].kids:
                for j in AttributeArray[-1].kids:
                    if str(kid+j) not in per_kid.keys():
                        continue # this kid has no samples of class j
                    num = per_kid[str(kid+j)]
                    den = total_kids[kid]
                    p = num / den
                    entropy_Dv += p*math.log(p,2)
                entropy_Dv_total += (den/total_nums)*(entropy_Dv)
                entropy_Dv = 0
            gain = entropy_D + entropy_Dv_total # entropy_Dv summed p*log2(p), the negative of Ent(Dv), so '+' implements Ent(D) - sum(...)
            entropy_Dv_total = 0
            gains.append(gain)

        elif AttributeArray[index].iscon == 0:# continuous
            Ta = []
            att_con_gain = 0 # bug fix: track the best gain per attribute instead of one global maximum
            att_con_middle = 0
            AttributeArray[index].kids.sort(key=float) # bug fix: sort numerically, not lexicographically
            for i in range(len(AttributeArray[index].kids)-1):
                Ta.append((float(AttributeArray[index].kids[i])+float(AttributeArray[index].kids[i+1]))/2)
            for t in Ta:
                total_kids = count_sample(SampleArray, index, 0,t)
                per_kid = count_attribute(SampleArray, index, 0,t)

                for j in AttributeArray[-1].kids:
                    if str('less'+j) not in per_kid.keys():
                        continue
                    num = per_kid['less'+j]
                    den = total_kids['less']
                    p = num / den
                    entropy_Dv += p * math.log(p, 2)
                entropy_Dv_total += (den / total_nums) * (entropy_Dv)
                entropy_Dv = 0
                for j in AttributeArray[-1].kids:
                    if str('more'+j) not in per_kid.keys():
                        continue
                    num = per_kid['more'+j]
                    den = total_kids['more']
                    p = num / den
                    entropy_Dv += p * math.log(p, 2)
                entropy_Dv_total += (den / total_nums) * (entropy_Dv)
                entropy_Dv = 0
                con_gain = entropy_D + entropy_Dv_total
                entropy_Dv_total = 0
                if con_gain > att_con_gain:
                    att_con_gain = con_gain
                    att_con_middle = t
            gain = att_con_gain
            gains.append(gain)

        if gain > max_gain:
            max_gain = gain
            max_index = index
            if AttributeArray[index].iscon == 0:
                max_con_middle = att_con_middle # remember the threshold that belongs to the winning attribute
    return max_index,max_con_middle  #return the index of the best attribute, plus its threshold if it is continuous

treenode = []
#per tree node: [father, father_index, index, kid, result, isleaf]
def tree_generate(SampleArray, AttributeArray,father,father_index,pass_kid,max_class):
    treenode.append([])  # create a new tree node
    index = len(treenode) - 1
    treenode[index].append(father)  # record the father of the node
    treenode[index].append(father_index)
    treenode[index].append(index)
    treenode[index].append(pass_kid)
    '''case 1: judge whether there is only one class in SampleArray'''
    count = count_sample(SampleArray,-1,1)
    if len(count)==1:
        treenode[index].append(max_class)
        treenode[index].append(1)
        return

    '''case 2: AttributeArray is NULL or all the samples have the same attributes.'''
    all_same = True
    for i in range(len(AttributeArray)-1):
        if len(count_sample(SampleArray,i,1))!=1:
            all_same = False # some attribute still varies across the samples
            break
    if all_same:# bug fix: the original test i==(len(AttributeArray)-1) never fired after a completed loop; an empty attribute set (class column only) also lands here
        treenode[index].append(max_class)
        treenode[index].append(1)  # leaf
        return
    treenode[index].append(0)#no result
    treenode[index].append(0)#not the leaf
    '''case 3: find the best attribute.'''
    best_index,best_middle = find_attribute(SampleArray,AttributeArray)
    kid_SampleArray = []
    new_index = 0
    #prepare to create the kid tree
    if AttributeArray[best_index].iscon == 1:
        for kid in AttributeArray[best_index].kids:
            kid_SampleArray.clear()
            for sample in SampleArray:
                samples = sample.split(',')
                if samples[best_index] == kid:
                    kid_SampleArray.append(','.join(samples[:best_index]+samples[best_index+1:])) # drop the used column; str.replace could hit the wrong column
            if len(kid_SampleArray)== 0:
                treenode.append([])  # create a new tree node
                new_index = len(treenode) - 1
                treenode[new_index].append(AttributeArray[best_index].name)  # record the father of the node
                treenode[new_index].append(index)
                treenode[new_index].append(new_index)
                treenode[new_index].append(kid)
                treenode[new_index].append(max_class)
                treenode[new_index].append(1)  # leaf
                continue # bug fix: leaf created for this empty branch, keep processing the remaining kids instead of returning
            else:
                kid_AttributeArray = list(AttributeArray)
                del kid_AttributeArray[best_index]
                kid_max_class = count_sample(kid_SampleArray,-1,1) # bug fix: do not overwrite max_class, later empty kids still need this node's majority
                kid_max_class = max(kid_max_class, key=kid_max_class.get)
                tree_generate(kid_SampleArray,kid_AttributeArray,AttributeArray[best_index].name,index,kid,kid_max_class)
    else:
        kid_less_SampleArray = []
        kid_more_SampleArray = []
        for sample in SampleArray:
            samples = sample.split(',')
            if float(samples[best_index]) <= best_middle:
                kid_less_SampleArray.append(','.join(samples[:best_index]+samples[best_index+1:]))
            else:
                kid_more_SampleArray.append(','.join(samples[:best_index]+samples[best_index+1:]))
        if len(kid_less_SampleArray)== 0:
            treenode.append([])  # create a new tree node
            new_index = len(treenode) - 1
            treenode[new_index].append(AttributeArray[best_index].name)  # record the father of the node
            treenode[new_index].append(index)
            treenode[new_index].append(new_index)
            treenode[new_index].append("<="+str(best_middle))
            treenode[new_index].append(max_class)
            treenode[new_index].append(1)  # leaf
            # bug fix: no return here, the '>' branch below must still be processed
        else:
            kid_AttributeArray = list(AttributeArray)
            del kid_AttributeArray[best_index]
            max_less_class = count_sample(kid_less_SampleArray, -1, 1)
            max_less_class = max(max_less_class, key=max_less_class.get)
            tree_generate(kid_less_SampleArray, kid_AttributeArray, AttributeArray[best_index].name,index, "<="+str(best_middle),max_less_class)
        if len(kid_more_SampleArray)== 0:
            treenode.append([])  # create a new tree node
            new_index = len(treenode) - 1
            treenode[new_index].append(AttributeArray[best_index].name)  # record the father of the node
            treenode[new_index].append(index)
            treenode[new_index].append(new_index)
            treenode[new_index].append(">"+str(best_middle))
            treenode[new_index].append(max_class)
            treenode[new_index].append(1)  # leaf
        else:
            kid_AttributeArray = list(AttributeArray)
            del kid_AttributeArray[best_index]
            max_more_class = count_sample(kid_more_SampleArray, -1, 1)
            max_more_class = max(max_more_class, key=max_more_class.get)
            tree_generate(kid_more_SampleArray, kid_AttributeArray, AttributeArray[best_index].name,index, ">"+str(best_middle),max_more_class)

def main():
    AttributeArray = []  # record attributes
    SampleArray = []  # record samples
    max_class = read_file('data.txt',SampleArray,AttributeArray)
    tree_generate(SampleArray,AttributeArray,-1,-1,-1,max_class)
    print(treenode[1:]) # index 0 is the root's own bookkeeping record

if __name__ =='__main__':
    main()

Example input:
[Figure: screenshots of the two input files, data_con.txt and data.txt]
The left one is the case with continuous attributes, the right one without; both files are listed in full at the end of this post.
Example output:
[Figure: the printed list of tree node records]
Each record reads: [father, father_index, index, kid, result, isleaf].
I haven't implemented visualization of the tree; if I find time later I may add it. For now I can draw the tree by hand from the printed records, heh. (A small printing helper sketch follows the tree_generate notes in section 3.1.)

3.1 Detailed walkthrough:

1. Defining the Attribute class

class Attribute():
    def __init__(self,name,id,iscon=0):
        self.name = name
        self.kids = []
        self.id = id
        self.iscon = iscon # 1: discrete (text values), 0: continuous (numeric); note the inverted naming

name records the attribute's name, kids holds the values seen for this attribute, id isn't really used and could be dropped, and iscon marks whether the attribute is continuous: 0 if it is, 1 if it is discrete. (Yes, the naming is a bit backwards...)
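
For example (values filled in by hand just for illustration):

color = Attribute('色泽', 0)    # discrete: read_file will flip iscon to 1 when it sees text values
density = Attribute('密度', 7)  # continuous: iscon stays 0 because '0.697'.isalnum() is False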

2. The count_sample function

def count_sample(SampleArray,index,iscon,T=0):
    attribute = {}
    if len(SampleArray)==0:
        return -1 #Sample is NULL
    if iscon == 1:
        for sample in SampleArray:
            samples = sample.split(',')
            if samples[index] not in attribute:
                attribute[samples[index]]=1
            else:
                attribute[samples[index]]+=1
    else:
        for sample in SampleArray:
            samples = sample.split(',')
            if float(samples[index])<=T:
                if 'less' not in attribute:
                    attribute['less'] = 1
                else:
                    attribute['less'] += 1
            else:
                if 'more' not in attribute:
                    attribute['more'] = 1
                else:
                    attribute['more'] += 1
    return attribute

The main job of this function is to count, for a given attribute (column index), how many samples take each of its values. For a continuous attribute it instead counts how many samples are at or below the threshold T ('less') and how many are above it ('more'). It returns a dictionary.
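
A quick check on some hand-made rows (same comma-separated format as the data files at the end of the post):

rows = ['青绿,蜷缩,是', '青绿,硬挺,否', '乌黑,蜷缩,是']
print(count_sample(rows, 0, 1))   # {'青绿': 2, '乌黑': 1}
print(count_sample(rows, -1, 1))  # {'是': 2, '否': 1}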

3. The count_attribute function
def count_attribute(SampleArray,index,iscon,T=0):
    attribute = {}
    if len(SampleArray) == 0:
        return -1  # Sample is NULL
    if iscon==1 :# discrete
        for sample in SampleArray:
            samples = sample.split(',')
            if str(samples[index]+samples[-1]) not in attribute:
                attribute[samples[index]+samples[-1]] = 1
            else:
                attribute[samples[index]+samples[-1]] += 1
    else:# continuous
        for sample in SampleArray:
            samples = sample.split(',')
            if float(samples[index]) <= T:
                if str('less'+samples[-1]) not in attribute.keys(): # bug fix: samples[-1] is the class field; sample[-1] was just the last character of the row
                    attribute['less'+samples[-1]] = 1
                else:
                    attribute['less'+samples[-1]] += 1
            else:
                if str('more'+samples[-1]) not in attribute.keys():
                    attribute['more'+samples[-1]] = 1
                else:
                    attribute['more'+samples[-1]] += 1
    return attribute

This function counts, for each value of the given attribute, how many samples of each class it contains. For example, with the attribute 色泽 it counts how many good and bad melons fall under 青绿, 乌黑 and 浅白 respectively. The return value is again a dictionary.
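
Continuing the toy rows from above, the keys are simply value+class concatenated:

rows = ['青绿,蜷缩,是', '青绿,硬挺,否', '乌黑,蜷缩,是']
print(count_attribute(rows, 0, 1))  # {'青绿是': 1, '青绿否': 1, '乌黑是': 1}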

4. The read_file function
def read_file(file_name,SampleArray,AttributeArray):
    with open(file_name) as f:
        contents = f.readline().rstrip('\n')
        flag = 0
        index = -1
        if "编号" in contents:
            flag = 1 # the first column is a running sample id and must be skipped
            index = contents.find(',')
            attributes = contents[index+1:].split(',')
        else:
            attributes = contents.split(',')
        id = 0
        for a in attributes:
            att = Attribute(a, id)
            id += 1
            AttributeArray.append(att)  # record the attribute
        for contents in f:
            contents = contents.rstrip('\n') # strip the newline whether or not the line ends with one
            if not contents:
                continue # skip blank lines
            if flag == 1:
                index = contents.find(',')
            per_att = contents[index+1:].split(',')
            for i in range(len(AttributeArray)):
                if per_att[i] not in AttributeArray[i].kids:
                    AttributeArray[i].kids.append(per_att[i])
                    if per_att[i].isalnum(): # text values pass isalnum(); '0.697' does not, so that attribute stays continuous
                        AttributeArray[i].iscon = 1
            SampleArray.append(contents[index+1:])
    max_mark = count_sample(SampleArray,-1,1)
    max_class = max(max_mark,key=max_mark.get) # the majority class over all samples
    return max_class

This function reads the data in from the given file. Two points deserve attention:
(1) If the first column of the file is a running id (编号), that column must be skipped.
(2) Lines may or may not end with '\n', so the newline is stripped explicitly before parsing.
For convenience I also return the majority class of the full sample set here, since the first call to tree_generate needs it.
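
A minimal smoke test (assuming data.txt from the end of this post sits next to the script):

AttributeArray, SampleArray = [], []
majority = read_file('data.txt', SampleArray, AttributeArray)
print(majority)                          # 否 (9 bad melons vs. 8 good)
print([a.name for a in AttributeArray])  # ['色泽', '根蒂', ..., '好瓜']
print(SampleArray[0])                    # 青绿,蜷缩,浊响,清晰,凹陷,硬滑,是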

5. The find_attribute function
#find the best attribute for the node
def find_attribute(SampleArray,AttributeArray):
    entropy_D = 0
    entropy_Dv = 0
    entropy_Dv_total = 0
    max_index = 0
    max_gain = 0
    den = 0
    gains = []
    max_con_middle = 0  # threshold of the best continuous attribute (if any)
    classes = count_sample(SampleArray, -1,1)
    total_nums = sum(classes.values())
    for value in classes.values():
        p = value / total_nums
        entropy_D += p*math.log(p,2)
    entropy_D = -(entropy_D)

    for index in range(len(AttributeArray)-1):# skip the last column, which is the class
        if AttributeArray[index].iscon == 1:# discrete
            total_kids = count_sample(SampleArray,index,1)
            per_kid = count_attribute(SampleArray,index,1)
            for kid in AttributeArray[index].kids:
                for j in AttributeArray[-1].kids:
                    if str(kid+j) not in per_kid.keys():
                        continue # this kid has no samples of class j
                    num = per_kid[str(kid+j)]
                    den = total_kids[kid]
                    p = num / den
                    entropy_Dv += p*math.log(p,2)
                entropy_Dv_total += (den/total_nums)*(entropy_Dv)
                entropy_Dv = 0
            gain = entropy_D + entropy_Dv_total # entropy_Dv summed p*log2(p), the negative of Ent(Dv), so '+' implements Ent(D) - sum(...)
            entropy_Dv_total = 0
            gains.append(gain)

        elif AttributeArray[index].iscon == 0:# continuous
            Ta = []
            att_con_gain = 0 # bug fix: track the best gain per attribute instead of one global maximum
            att_con_middle = 0
            AttributeArray[index].kids.sort(key=float) # bug fix: sort numerically, not lexicographically
            for i in range(len(AttributeArray[index].kids)-1):
                Ta.append((float(AttributeArray[index].kids[i])+float(AttributeArray[index].kids[i+1]))/2)
            for t in Ta:
                total_kids = count_sample(SampleArray, index, 0,t)
                per_kid = count_attribute(SampleArray, index, 0,t)

                for j in AttributeArray[-1].kids:
                    if str('less'+j) not in per_kid.keys():
                        continue
                    num = per_kid['less'+j]
                    den = total_kids['less']
                    p = num / den
                    entropy_Dv += p * math.log(p, 2)
                entropy_Dv_total += (den / total_nums) * (entropy_Dv)
                entropy_Dv = 0
                for j in AttributeArray[-1].kids:
                    if str('more'+j) not in per_kid.keys():
                        continue
                    num = per_kid['more'+j]
                    den = total_kids['more']
                    p = num / den
                    entropy_Dv += p * math.log(p, 2)
                entropy_Dv_total += (den / total_nums) * (entropy_Dv)
                entropy_Dv = 0
                con_gain = entropy_D + entropy_Dv_total
                entropy_Dv_total = 0
                if con_gain > att_con_gain:
                    att_con_gain = con_gain
                    att_con_middle = t
            gain = att_con_gain
            gains.append(gain)

        if gain > max_gain:
            max_gain = gain
            max_index = index
            if AttributeArray[index].iscon == 0:
                max_con_middle = att_con_middle # remember the threshold that belongs to the winning attribute
    return max_index,max_con_middle  #return the index of the best attribute, plus its threshold if it is continuous

This function implements line 8 of the algorithm: finding the optimal splitting attribute by information entropy, with one branch for discrete attributes and one for continuous ones, directly following the formulas above. One detail worth noting: entropy_Dv accumulates the raw sum of p·log2(p), which is the negative of Ent(D^v), so the gain comes out as entropy_D + entropy_Dv_total.
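
As a cross-check against the worked example in section 2, the root split chosen on data.txt should be 纹理 (texture), whose gain of 0.381 is the largest in the book:

AttributeArray, SampleArray = [], []
read_file('data.txt', SampleArray, AttributeArray)
best_index, _ = find_attribute(SampleArray, AttributeArray)
print(AttributeArray[best_index].name)  # 纹理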

6. The tree_generate function
def tree_generate(SampleArray, AttributeArray,father,father_index,pass_kid,max_class):
    treenode.append([])  # create a new tree node
    index = len(treenode) - 1
    treenode[index].append(father)  # record the father of the node
    treenode[index].append(father_index)
    treenode[index].append(index)
    treenode[index].append(pass_kid)
    '''case 1: judge whether there is only one class in SampleArray'''
    count = count_sample(SampleArray,-1,1)
    if len(count)==1:
        treenode[index].append(max_class)
        treenode[index].append(1)
        return

    '''case 2: AttributeArray is NULL or all the samples have the same attributes.'''
    all_same = True
    for i in range(len(AttributeArray)-1):
        if len(count_sample(SampleArray,i,1))!=1:
            all_same = False # some attribute still varies across the samples
            break
    if all_same:# bug fix: the original test i==(len(AttributeArray)-1) never fired after a completed loop; an empty attribute set (class column only) also lands here
        treenode[index].append(max_class)
        treenode[index].append(1)  # leaf
        return
    treenode[index].append(0)#no result
    treenode[index].append(0)#not the leaf
    '''case 3: find the best attribute.'''
    best_index,best_middle = find_attribute(SampleArray,AttributeArray)
    kid_SampleArray = []
    new_index = 0
    #prepare to create the kid tree
    if AttributeArray[best_index].iscon == 1:
        for kid in AttributeArray[best_index].kids:
            kid_SampleArray.clear()
            for sample in SampleArray:
                samples = sample.split(',')
                if samples[best_index] == kid:
                    kid_SampleArray.append(','.join(samples[:best_index]+samples[best_index+1:])) # drop the used column; str.replace could hit the wrong column
            if len(kid_SampleArray)== 0:
                treenode.append([])  # create a new tree node
                new_index = len(treenode) - 1
                treenode[new_index].append(AttributeArray[best_index].name)  # record the father of the node
                treenode[new_index].append(index)
                treenode[new_index].append(new_index)
                treenode[new_index].append(kid)
                treenode[new_index].append(max_class)
                treenode[new_index].append(1)  # leaf
                continue # bug fix: leaf created for this empty branch, keep processing the remaining kids instead of returning
            else:
                kid_AttributeArray = list(AttributeArray)
                del kid_AttributeArray[best_index]
                kid_max_class = count_sample(kid_SampleArray,-1,1) # bug fix: do not overwrite max_class, later empty kids still need this node's majority
                kid_max_class = max(kid_max_class, key=kid_max_class.get)
                tree_generate(kid_SampleArray,kid_AttributeArray,AttributeArray[best_index].name,index,kid,kid_max_class)
    else:
        kid_less_SampleArray = []
        kid_more_SampleArray = []
        for sample in SampleArray:
            samples = sample.split(',')
            if float(samples[best_index]) <= best_middle:
                kid_less_SampleArray.append(','.join(samples[:best_index]+samples[best_index+1:]))
            else:
                kid_more_SampleArray.append(','.join(samples[:best_index]+samples[best_index+1:]))
        if len(kid_less_SampleArray)== 0:
            treenode.append([])  # create a new tree node
            new_index = len(treenode) - 1
            treenode[new_index].append(AttributeArray[best_index].name)  # record the father of the node
            treenode[new_index].append(index)
            treenode[new_index].append(new_index)
            treenode[new_index].append("<="+str(best_middle))
            treenode[new_index].append(max_class)
            treenode[new_index].append(1)  # leaf
            # bug fix: no return here, the '>' branch below must still be processed
        else:
            kid_AttributeArray = list(AttributeArray)
            del kid_AttributeArray[best_index]
            max_less_class = count_sample(kid_less_SampleArray, -1, 1)
            max_less_class = max(max_less_class, key=max_less_class.get)
            tree_generate(kid_less_SampleArray, kid_AttributeArray, AttributeArray[best_index].name,index, "<="+str(best_middle),max_less_class)
        if len(kid_more_SampleArray)== 0:
            treenode.append([])  # create a new tree node
            new_index = len(treenode) - 1
            treenode[new_index].append(AttributeArray[best_index].name)  # record the father of the node
            treenode[new_index].append(index)
            treenode[new_index].append(new_index)
            treenode[new_index].append(">"+str(best_middle))
            treenode[new_index].append(max_class)
            treenode[new_index].append(1)  # leaf
        else:
            kid_AttributeArray = list(AttributeArray)
            del kid_AttributeArray[best_index]
            max_more_class = count_sample(kid_more_SampleArray, -1, 1)
            max_more_class = max(max_more_class, key=max_more_class.get)
            tree_generate(kid_more_SampleArray, kid_AttributeArray, AttributeArray[best_index].name,index, ">"+str(best_middle),max_more_class)

This function follows the decision tree algorithm itself. On every call it first creates a new node, then checks whether any return condition holds; if not, it finds the optimal splitting attribute for this level of the recursion, checks the return conditions again for each branch, and otherwise keeps recursing into the children.
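
Since I didn't write a proper visualizer, here is a small helper sketch of my own (not part of the program above) that at least prints the flat treenode list as an indented tree via the father_index field:

def print_tree(nodes, father_index=0, depth=0):
    # each record is [father, father_index, index, kid, result, isleaf]
    for node in nodes:
        if node[1] == father_index:
            label = node[4] if node[5] == 1 else '?'  # leaves carry the class in 'result'
            print('  ' * depth + '%s = %s -> %s' % (node[0], node[3], label))
            print_tree(nodes, node[2], depth + 1)

# after main() has filled treenode, the root record sits at index 0:
print_tree(treenode)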


That's my Python implementation of a decision tree~
Appendix:
data.txt
编号,色泽,根蒂,敲声,纹理,脐部,触感,好瓜
1,青绿,蜷缩,浊响,清晰,凹陷,硬滑,是
2,乌黑,蜷缩,沉闷,清晰,凹陷,硬滑,是
3,乌黑,蜷缩,浊响,清晰,凹陷,硬滑,是
4,青绿,蜷缩,沉闷,清晰,凹陷,硬滑,是
5,浅白,蜷缩,浊响,清晰,凹陷,硬滑,是
6,青绿,稍蜷,浊响,清晰,稍凹,软粘,是
7,乌黑,稍蜷,浊响,稍糊,稍凹,软粘,是
8,乌黑,稍蜷,浊响,清晰,稍凹,硬滑,是
9,乌黑,稍蜷,沉闷,稍糊,稍凹,硬滑,否
10,青绿,硬挺,清脆,清晰,平坦,软粘,否
11,浅白,硬挺,清脆,模糊,平坦,硬滑,否
12,浅白,蜷缩,浊响,模糊,平坦,软粘,否
13,青绿,稍蜷,浊响,稍糊,凹陷,硬滑,否
14,浅白,稍蜷,沉闷,稍糊,凹陷,硬滑,否
15,乌黑,稍蜷,浊响,清晰,稍凹,软粘,否
16,浅白,蜷缩,浊响,模糊,平坦,硬滑,否
17,青绿,蜷缩,沉闷,稍糊,稍凹,硬滑,否

data_con.txt
编号,色泽,根蒂,敲声,纹理,脐部,触感,密度,含糖率,好瓜
1,青绿,蜷缩,浊响,清晰,凹陷,硬滑,0.697,0.46,是
2,乌黑,蜷缩,沉闷,清晰,凹陷,硬滑,0.744,0.376,是
3,乌黑,蜷缩,浊响,清晰,凹陷,硬滑,0.634,0.264,是
4,青绿,蜷缩,沉闷,清晰,凹陷,硬滑,0.608,0.318,是
5,浅白,蜷缩,浊响,清晰,凹陷,硬滑,0.556,0.215,是
6,青绿,稍蜷,浊响,清晰,稍凹,软粘,0.403,0.237,是
7,乌黑,稍蜷,浊响,稍糊,稍凹,软粘,0.481,0.149,是
8,乌黑,稍蜷,浊响,清晰,稍凹,硬滑,0.437,0.211,是
9,乌黑,稍蜷,沉闷,稍糊,稍凹,硬滑,0.666,0.091,否
10,青绿,硬挺,清脆,清晰,平坦,软粘,0.243,0.267,否
11,浅白,硬挺,清脆,模糊,平坦,硬滑,0.245,0.057,否
12,浅白,蜷缩,浊响,模糊,平坦,软粘,0.343,0.099,否
13,青绿,稍蜷,浊响,稍糊,凹陷,硬滑,0.639,0.161,否
14,浅白,稍蜷,沉闷,稍糊,凹陷,硬滑,0.657,0.198,否
15,乌黑,稍蜷,浊响,清晰,稍凹,软粘,0.36,0.37,否
16,浅白,蜷缩,浊响,模糊,平坦,硬滑,0.593,0.042,否
17,青绿,蜷缩,沉闷,稍糊,稍凹,硬滑,0.719,0.103,否