Files
MLPproject/.venv/lib/python3.12/site-packages/xgboost/testing/plotting.py
2025-10-23 15:44:32 +02:00

30 lines
957 B
Python

"""Test plotting functions for XGBoost."""
import json
from graphviz import Source
from matplotlib.axes import Axes
from ..plotting import plot_tree, to_graphviz
from ..sklearn import XGBRegressor
from .data import make_categorical
from .utils import Device
def run_categorical(tree_method: str, device: Device) -> None:
    """Tests plotting functions for categorical features.

    Trains a small ``XGBRegressor`` with ``enable_categorical=True`` on
    non-one-hot categorical data, checks the root node of every tree in the
    JSON dump, then renders the last tree with both ``to_graphviz`` and
    ``plot_tree``.

    Parameters
    ----------
    tree_method :
        Tree construction method forwarded to ``XGBRegressor``.
    device :
        Device forwarded to ``XGBRegressor``.
    """
    X, y = make_categorical(1000, 31, 19, onehot=False)
    reg = XGBRegressor(
        enable_categorical=True, n_estimators=10, tree_method=tree_method, device=device
    )
    reg.fit(X, y)

    trees = reg.get_booster().get_dump(dump_format="json")
    # Guard against an empty dump: the plotting index below would otherwise
    # become -1.
    assert trees
    for tree in trees:
        j_tree = json.loads(tree)
        # The root is either a leaf, or a split whose ``split_condition`` is
        # expected to be a list (categorical split) rather than a scalar.
        assert "leaf" in j_tree or isinstance(j_tree["split_condition"], list)

    # Plot the last tree of the model. The index must be derived from the
    # number of dumped trees; ``len(j_tree)`` would be the key count of the
    # last root node's JSON dict, which is not a tree index.
    last_tree_idx = len(trees) - 1
    graph = to_graphviz(reg, tree_idx=last_tree_idx)
    assert isinstance(graph, Source)
    ax = plot_tree(reg, tree_idx=last_tree_idx)
    assert isinstance(ax, Axes)