机器学习实战之决策树算法学习心得(下)
本文是接着机器学习实战之决策树算法学习心得(上)写得,所以读者要是偶然读到了这篇文章,可从上篇文章开始读起更容易理解
3. CART分类树和回归树
分类与回归树(classification and regression tree, CART)模型同样由特征选择、树的生成及剪枝组成,既可以用于分类也可以用于回归。 CART算法由以下两步组成
(1)决策树生成:基于训练数据集生成决策树,牛成的决策树要尽量大;
(2)决策树剪枝:用验证数据集对己生成的树进行剪枝并选择最优子树,这时用损失函数最小作为剪枝的标准。
分类树的生成
分类树用基尼指数选择最优特征,同时决定该特征的最优二值切分点.
基尼指数:分类问题中,假设有K个类,样本点属于第k类的概率为pk,则概率分布的基尼指数定义为
对于给定的样本集合D,其基尼指数为
如果样本集合D根据特征A是否取某一可能值a被分割成D1和D2两部分,则在特征A的条件下,集合D的基尼指数定义为
下面贴出源码
import operator
import treePlotter
from itertools import combinations
# 创建本地数据集
def createDataSet():
dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
labels = ['no surfacing', 'flippers']
return dataSet, labels
def calGini(dataSet):
numEntries = len(dataSet)
labelCounts = {}
for featVec in dataSet:
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
gini = 1
for label in labelCounts.keys():
prop = float(labelCounts[label]) / numEntries
gini -= prop * prop
return gini
# 传入的是一个特征值的列表,返回特征值二分的结果
def featuresplit(features):
count = len(features) # 特征值的个数
if count < 2:
print("please check sample's features,only one feature value")
return -1
# 由于需要返回二分结果,所以每个分支至少需要一个特征值,所以要从所有的特征组合中选取1个以上的组合
# itertools的combinations 函数可以返回一个列表选多少个元素的组合结果,例如combinations(list,2)返回的列表元素选2个的组合
# 我们需要选择1-(count-1)的组合
combinationsList = []
resList = []
# 遍历所有的组合
for i in range(1, count):
temp_combination = list(combinations(features, len(features[0:i])))
combinationsList.extend(temp_combination)
print(combinationsList)
combiLen = len(combinationsList)
comb_mean = int(combiLen/2)
# 每次组合的顺序都是一致的,并且也是对称的,所以我们取首尾组合集合
# zip函数提供了两个列表对应位置组合的功能
resList = zip(combinationsList[0:comb_mean], combinationsList[combiLen - 1:comb_mean - 1:-1])
return resList
def splitDataSet(dataSet, axis, values):
retDataSet = []
for featVec in dataSet:
for value in values:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis] # 剔除样本集
reducedFeatVec.extend(featVec[axis + 1:])
retDataSet.append(reducedFeatVec)
return retDataSet
# 返回最好的特征以及二分特征值
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1
bestGiniGain = 1.0
bestFeature = -1
bestBinarySplit = ()
for i in range(numFeatures): # 遍历特征
featList = [example[i] for example in dataSet] # 得到特征列
uniqueVals = list(set(featList)) # 从特征列获取该特征的特征值的set集合
# 三个特征值的二分结果:
# [(('young',), ('old', 'middle')), (('old',), ('young', 'middle')), (('middle',), ('young', 'old'))]
for split in featuresplit(uniqueVals):
GiniGain = 0.0
if len(split) == 1:
continue
(left, right) = split
# 对于每一个可能的二分结果计算gini增益
# 左增益
left_subDataSet = splitDataSet(dataSet, i, left)
left_prob = len(left_subDataSet) / float(len(dataSet))
GiniGain += left_prob * calGini(left_subDataSet)
# 右增益
right_subDataSet = splitDataSet(dataSet, i, right)
right_prob = len(right_subDataSet) / float(len(dataSet))
GiniGain += right_prob * calGini(right_subDataSet)
if GiniGain <= bestGiniGain: # 比较是否是最好的结果
bestGiniGain = GiniGain # 记录最好的结果和最好的特征
bestFeature = i
bestBinarySplit = (left, right)
return bestFeature, bestBinarySplit
def majorityCnt(classList):
classCount = {}
for vote in classList:
if vote not in classCount.keys(): classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
def createTree(dataSet, labels):
classList = [example[-1] for example in dataSet]
# print dataSet
if classList.count(classList[0]) == len(classList):
return classList[0] # 所有的类别都一样,就不用再划分了
if len(dataSet[0]) == 1: # 如果没有继续可以划分的特征,就多数表决决定分支的类别
# print "here"
return majorityCnt(classList)
bestFeat, bestBinarySplit = chooseBestFeatureToSplit(dataSet)
# print bestFeat,bestBinarySplit,labels
bestFeatLabel = labels[bestFeat]
if bestFeat == -1:
return majorityCnt(classList)
myTree = {bestFeatLabel: {}}
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = list(set(featValues))
for value in bestBinarySplit:
subLabels = labels[:] # #拷贝防止其他地方修改
min_value = value
if len(value) < 2:
del (subLabels[bestFeat])
min_value = value[0]
myTree[bestFeatLabel][min_value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
return myTree
if __name__ == '__main__':
fr = open('lenses.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = createTree(lenses, lensesLabels)
print(lensesTree)
treePlotter.createPlot(lensesTree)
下面是与之对应的图
不过对于CART分类树我现在还是有一个地方没弄出来,就是它的剪枝算法到底是该怎么计算。回归算法里是取两个左右叶子结点结果的平均值作为新的叶子结点值,可是对于分类树两个叶子结点均为离散值,怎么计算。如果有哪位大佬偶然看到了这篇文章,还请不吝赐教,感激不尽
CART生成:对回归树用平方误差最小化准则,对分类树用基尼指数(Gini index)最小化准则,进行特征选择。
回归树的生成:
下面就是回归树的源码了
from numpy import *
# 记载本地数据
import treePlotter
def loadDataSet(fileName):
dataMat = []
fr = open(fileName)
for line in fr.readlines():
curLine = line.strip().split('\t')
fltLine = map(float, curLine)
dataMat.append(fltLine)
return dataMat
# 返回切分后的数据集
def binSplitDataSet(dataSet, feature, value):
mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :]
mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
return mat0, mat1
# 叶节点生成函数
def regLeaf(dataSet):
return mean(dataSet[:, -1])
# 误差估计函数
def regErr(dataSet):
return var(dataSet[:, -1]) * shape(dataSet)[0]
# 找到数据的最佳二元切分方式
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
tolS = ops[0]
tolN = ops[1]
a = (dataSet[:, -1].T.tolist()[0])
if len(set(a)) == 1:
return None, leafType(dataSet)
m, n = shape(dataSet)
S = errType(dataSet)
bestS = inf
bestIndex = 0
bestValue = 0
for featIndex in range(n - 1):
for splitVal in (dataSet[:, featIndex]):
mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
continue
newS = errType(mat0) + errType(mat1)
if newS < bestS:
bestIndex = featIndex
bestValue = splitVal
bestS = newS
if (S - bestS) < tolS:
return None, leafType(dataSet)
mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
return None, leafType(dataSet)
return bestIndex, bestValue
# 递归函数创建树
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
if feat is None:
return val
retTree = {'spInd': feat, 'spVal': val}
lSet, rSet = binSplitDataSet(dataSet, feat, val)
retTree['left'] = createTree(lSet, leafType, errType, ops)
retTree['right'] = createTree(rSet, leafType, errType, ops)
return retTree
# 用于测试输入变量是否为一棵树
def isTree(obj):
return type(obj).__name__ == 'dict'
# 递归函数对树进行塌陷处理
def getMean(tree):
if isTree(tree['right']):
tree['right'] = getMean(tree['right'])
if isTree(tree['left']):
tree['left'] = getMean(tree['left'])
return (tree['right'] + tree['left'])/2.0
# 回归树剪枝函数
def prune(tree, testData):
if shape(testData)[0] == 0:
return getMean(tree)
if (isTree(tree['right'])) or (isTree(tree['left'])):
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
if isTree(tree['left']):
tree['left'] = prune(tree['left'], lSet)
if isTree(tree['right']):
tree['right'] = prune(tree['right'], rSet)
if not isTree(tree['left']) and not isTree(tree['right']):
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
errorNoMerge = sum(power(lSet[:, -1]-tree['left'], 2)) + sum(power(rSet[:, -1]-tree['right'], 2))
treeMean = (tree['left'] + tree['right'])/2.0
errorMerge = sum(power(testData[:, -1] - treeMean, 2))
if errorMerge < errorNoMerge:
print("merging")
return treeMean
else:
return tree
else:
return tree
if __name__ == '__main__':
myDat2 = loadDataSet('ex2.txt')
myMat2 = mat(myDat2)
regTree = createTree(myMat2, ops=(0, 1))
print(regTree)
myDataTest = loadDataSet('ex2test.txt')
myMat2Tst = mat(myDataTest)
tree = prune(regTree, myMat2Tst)
print(tree)
对了,这篇文章里所有代码用到的treePlotter类忘了给了,在这里补充出来
import matplotlib.pyplot as plt
import sys
reload(sys)
sys.setdefaultencoding('utf-8')
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
# 预先储存树的相关信息
def retrieveTree(i):
listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
return listOfTrees[i]
# 执行实际的绘图功能
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.axl.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
# 在父子节点中间填充文本信息
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
createPlot.axl.text(xMid, yMid, txtString)
# 执行实际的绘图功能
def plotTree(myTree, parentPt, nodeTxt):
numLeafs = getNumLeafs(myTree)
depth = getTreeDepth(myTree)
firstStr = myTree.keys()[0]
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
plotTree(secondDict[key], cntrPt, str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
# 描述了树节点的常量
# def createPlot():
# fig = plt.figure(1, facecolor='white')
# fig.clf()
# createPlot.axl = plt.subplot(111, frameon=False)
# plotNode('DecisionNode', (0.5, 0.1), (0.1, 0.5), decisionNode)
# plotNode('LeafNode', (0.8, 0.1), (0.3, 0.8), leafNode)
# plt.show()
# 描述了树节点的常量
def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.axl = plt.subplot(111, frameon=False, **axprops)
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5/plotTree.totalW
plotTree.yOff = 1.0
plotTree(inTree, (0.5, 1.0), '')
plt.show()
# 获取叶节点的数目
def getNumLeafs(myTree):
numLeafs = 0
firstStr = myTree.keys()[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
numLeafs += getNumLeafs(secondDict[key])
else:
numLeafs += 1
return numLeafs
# 获取树的层数
def getTreeDepth(myTree):
maxDepth = 0
firstStr = myTree.keys()[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth:
maxDepth = thisDepth
return maxDepth
# 使用决策树的分类函数
def classify(inputTree, featLabels, testVec):
firstStr = inputTree.keys()[0]
secondDict = inputTree[firstStr]
featIndex = featLabels.index(firstStr)
for key in secondDict.keys():
if testVec[featIndex] == key:
if type(secondDict[key]).__name__ == 'dict':
classLabel = classify(secondDict[key], featLabels, testVec)
else:
classLabel = secondDict[key]
return classLabel
好了,这就是自己目前所整理的一些关于决策树的知识,感觉还只是一知半解。我将会在这两天好好消化,争取把问题都解决掉,给大家一篇完整的文章
大家好,抓耳挠腮了两三天之后终于是把CART分类树以及相关的剪枝操作的demo给弄出来了(两三天不眠不休刻苦钻研,真的要奖励自己一盘鸡腿。咳咳咳)
哈哈哈,闲话扯完了,直接进入正题了。上面我给出的CART分类树的demo是基于ID3算法实现的,其与CART回归树最大的区别主要体现在构建树的结构不同。ID3算法以及C4.5算法是通过构建字典来存储树的结构(上面我给出的CART分类树的demo也是如此) 而CART回归树则是通过构建树节点来存储树的结构。这两天我把机器学习实战和统计学习方法这两本书仔细看了之后比较发现:对于CART算法而言,无论是分类树还是回归树使用树节点来存储树的结构效果是最好的,因为CART算法最终构建的树都是二分树,所以用左右节点分别表示即可。具体代码可见下文(大家可要看仔细了,网上基本上很难找到相关的代码了,快点夸我,哈哈哈哈哈)
import operator
from numpy import shape
# 创建本地数据集
def createDataSet():
dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
labels = ['no surfacing', 'flippers']
return dataSet, labels
# 切分数据
def splitDataSet(dataSet, axis, value):
retDataSet = []
subDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis] # 剔除样本集
reducedFeatVec.extend(featVec[axis + 1:])
retDataSet.append(reducedFeatVec)
else:
reducedFeatVec = featVec[:axis] # 剔除样本集
reducedFeatVec.extend(featVec[axis + 1:])
subDataSet.append(reducedFeatVec)
return retDataSet, subDataSet
def calGini(dataSet):
numEntries = len(dataSet)
labelCounts = {}
for featVec in dataSet:
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
gini = 1
for label in labelCounts.keys():
prop = float(labelCounts[label]) / numEntries
gini -= prop * prop
return gini
# 返回最好的特征以及二分特征值
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1
bestGiniGain = 1.0
bestFeature = -1
bestValue = None
# 遍历特征
for i in range(numFeatures):
# 得到特征列
featList = [example[i] for example in dataSet]
# 从特征列获取该特征的特征值的set集合
uniqueVals = list(set(featList))
for value in uniqueVals:
GiniGain = 0.0
retDataSet, subDataSet = splitDataSet(dataSet, i, value)
left_prob = len(retDataSet) / float(len(dataSet))
GiniGain += left_prob * calGini(retDataSet)
right_prob = len(subDataSet) / float(len(dataSet))
GiniGain += right_prob * calGini(subDataSet)
# 比较是否是最好的结果
if GiniGain <= bestGiniGain:
# 记录最好的结果和最好的特征
bestGiniGain = GiniGain
bestFeature = i
bestValue = value
return bestFeature, bestValue
def majorityCnt(classList):
classCount = {}
for vote in classList:
if vote not in classCount.keys(): classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
# 创建树的函数代码
def createTree(dataSet, labels):
classList = [example[-1] for example in dataSet]
bestFeature, bestValue = chooseBestFeatureToSplit(dataSet)
if classList.count(classList[0]) == len(classList):
# 所有的类别都一样,就不用再划分了
return classList[0]
# 如果没有继续可以划分的特征,就多数表决决定分支的类别
if len(dataSet[0]) == 1:
return majorityCnt(classList)
reTree = {'spInd': bestFeature, 'spVal': bestValue}
retDataSet, subDataSet = splitDataSet(dataSet, bestFeature, bestValue)
subLabels = labels[:]
print("bestFeature:", bestFeature)
del(subLabels[bestFeature])
reTree['left'] = createTree(retDataSet, subLabels)
reTree['right'] = createTree(subDataSet, subLabels)
return reTree
# 用于测试输入变量是否为一棵树
def isTree(obj):
return type(obj).__name__ == 'dict'
# 回归树剪枝函数
def prune(tree, testData):
classList = [example[-1] for example in testData]
if shape(testData)[0] == 0:
print("数据集为空")
return majorityCnt(classList)
if (isTree(tree['right'])) or (isTree(tree['left'])):
lSet, rSet = splitDataSet(testData, tree['spInd'], tree['spVal'])
if isTree(tree['left']):
tree['left'] = prune(tree['left'], lSet)
if isTree(tree['right']):
tree['right'] = prune(tree['right'], rSet)
if not isTree(tree['left']) and not isTree(tree['right']):
lSet, rSet = splitDataSet(testData, tree['spInd'], tree['spVal'])
print("lSet[:, -1]:", lSet[0])
print("tree['left']:", tree['left'])
errorNoMerge = 0
if lSet[0] != tree['left']:
errorNoMerge += 1
if rSet[0] != tree['right']:
errorNoMerge += 1
treeMean = majorityCnt(classList)
errorMerge = 0
print("testData[:, -1]:", testData)
for item in testData:
if item[-1] != treeMean:
errorMerge += 1
if errorMerge < errorNoMerge:
print("merging")
return treeMean
else:
return tree
else:
return tree
if __name__ == '__main__':
fr = open('lenses.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = createTree(lenses, lensesLabels)
print("lensesTree:", lensesTree)
fr = open('lenses2.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
cutTree = prune(lensesTree, lenses)
print("cutTree:", cutTree)
相关的测试集(lenses.txt lenses2.txt)我也给出来了,保证大家跑起来没问题
链接:CART决策树相关数据集 密码:y415
到了这儿算是大功告成了,明天终于松口气去看我的<复活>了