暂无图片
暂无图片
暂无图片
暂无图片
暂无图片

Kaggle知识点:dtreeviz 可视化决策树

Coggle数据科学 2023-12-28
786

决策树可视化

决策树是梯度提升机和随机森林的基本构建块,可能是结构化数据中两个最流行的机器学习模型。可视化决策树在学习这些模型的工作原理和解释模型时是一种巨大的帮助。

不幸的是,当前的可视化包较为基础,对初学者来说并不直观。例如,我们找不到一个库能够展示决策节点如何分割特征空间。而且,支持以图形方式展示特定特征向量沿着树的决策节点的情况在库中也是不常见的;我们只能找到一张展示这一点的图片。

决策树回顾

决策树是一种基于二叉树(最多具有左右子节点的树)的机器学习模型。决策树通过检查和压缩训练数据形成一个由内部节点和叶节点组成的二叉树,从而学习训练集中观测之间的关系。

决策树中的每个叶子都负责做出具体的预测。对于回归树,预测是一个值,例如价格。对于分类器树,预测是一个目标类别,例如癌症或非癌症。决策树将特征空间划分为具有相似目标值的观测组,并且每个叶子代表其中的一组。对于回归,叶子中的相似性意味着目标值之间的方差较小,而对于分类,它意味着大多数或所有目标都属于单一类别。

从决策树的根到特定叶子预测的任何路径都通过一系列(内部)决策节点。每个决策节点将特征 x 中的单个特征值 xi 与在训练期间学到的特定分割点值进行比较。

决策树可视化的关键要素

  1. 决策节点特征与目标值分布(在本文中我们称之为特征-目标空间): 我们想要了解基于特征和分割点的目标值是如何可分的。
  2. 决策节点特征名称和特征分割值: 我们需要知道每个决策节点正在测试的特征以及在该空间中节点如何将观测分割。
  3. 叶子节点纯度: 这影响我们对预测的信心。在目标值之间方差较小(回归)或目标类别占绝大多数(分类)的叶子更可靠。
  4. 叶子节点预测值: 这个叶子实际上从目标值的集合中预测什么?
  5. 决策节点中的样本数量: 有时知道大多数样本是如何通过决策节点进行路由的是有用的。
  6. 叶子节点中的样本数量: 我们的目标是具有较少、较大和更纯净的叶子的决策树。样本太少的节点可能是过拟合的迹象。
  7. 特定特征向量如何沿着树运行到达叶子节点: 这有助于解释为什么特定特征向量会得到它所预测的结果。

dtreeviz介绍

dtreeviz
是一个用于决策树可视化和模型解释的 Python 库。决策树是梯度提升机和随机森林(tm)的基本构建块,这两者可能是结构化数据最流行的机器学习模型。在学习这些模型的工作原理和解释模型时,可视化决策树是一种巨大的帮助。

目前,dtreeviz 支持:scikit-learn、XGBoost、Spark MLlib、LightGBM 和 Tensorflow。

安装方法

pip install dtreeviz             # install dtreeviz for sklearn
pip install dtreeviz[xgboost]    # install XGBoost related dependency
pip install dtreeviz[pyspark]    # install pyspark related dependency
pip install dtreeviz[lightgbm]   # install LightGBM related dependency
pip install dtreeviz[tensorflow_decision_forests]   # install tensorflow_decision_forests related dependency
pip install dtreeviz[all]        # install all related dependencies

使用案例

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

import dtreeviz

iris = load_iris()
X = iris.data
y = iris.target

clf = DecisionTreeClassifier(max_depth=4)
clf.fit(X, y)

viz_model = dtreeviz.model(clf,
                           X_train=X, y_train=y,
                           feature_names=iris.feature_names,
                           target_name='iris',
                           class_names=iris.target_names)

v = viz_model.view()     # render as SVG into internal object 
v.show()                 # pop up window
v.save("/tmp/iris.svg")  # optionally save as svg

可视化结果

Tree visualizations

Prediction path explanations

Leaf information

Feature space exploration

学习大模型、推荐系统、算法竞赛
添加👇微信拉你进群
加入了之前的社群不需要重复添加~


文章转载自Coggle数据科学,如果涉嫌侵权,请发送邮件至:contact@modb.pro进行举报,并提供相关证据,一经查实,墨天轮将立刻删除相关内容。

评论