Scikit-Learn自带了iris数据集合,iris集合具体介绍请参考第3章。完整演示代码请见本书GitHub上的6-1.py。
导入需要的函数库,加载iris数据集:
from sklearn.datasets import load_iris from sklearn import tree import pydotplus iris = load_iris()
使用决策树算法进行训练,并将训练得到的决策树保存成pdf文件:
clf = tree.DecisionTreeClassifier() clf = clf.fit(iris.data, iris.target) dot_data = tree.export_graphviz(clf, out_file=None) graph = pydotplus.graph_from_dot_data(dot_data) graph.write_pdf("../photo/6/iris.pdf")
训练得到的可视化决策树如图6-2所示。
图6-2 iris数据集得到的可视化决策树