30 lines
957 B
Python
30 lines
957 B
Python
"""Test plotting functions for XGBoost."""
|
|
|
|
import json
|
|
|
|
from graphviz import Source
|
|
from matplotlib.axes import Axes
|
|
|
|
from ..plotting import plot_tree, to_graphviz
|
|
from ..sklearn import XGBRegressor
|
|
from .data import make_categorical
|
|
from .utils import Device
|
|
|
|
|
|
def run_categorical(tree_method: str, device: Device) -> None:
    """Tests plotting functions for categorical features.

    Fits an ``XGBRegressor`` with ``enable_categorical=True`` on the data returned
    by ``make_categorical``, validates the JSON dump of every tree, and then renders
    the last tree through both ``to_graphviz`` and ``plot_tree``.

    Parameters
    ----------
    tree_method :
        Forwarded to ``XGBRegressor`` as ``tree_method``.
    device :
        Forwarded to ``XGBRegressor`` as ``device``.

    Raises
    ------
    AssertionError
        If the model dumps no trees, a dumped root node is neither a leaf nor has
        a list-valued ``split_condition``, or a plotting function returns an
        object of an unexpected type.
    """
    X, y = make_categorical(1000, 31, 19, onehot=False)
    reg = XGBRegressor(
        enable_categorical=True, n_estimators=10, tree_method=tree_method, device=device
    )
    reg.fit(X, y)

    trees = reg.get_booster().get_dump(dump_format="json")
    # Guard against an empty dump so the tree index computed below is meaningful.
    assert trees, "Expected the fitted model to dump at least one tree."

    for tree in trees:
        j_tree = json.loads(tree)
        # The root is either a leaf or a split node; for a split node the test
        # expects a list-valued condition (presumably the category set of a
        # categorical split).
        assert "leaf" in j_tree or isinstance(j_tree["split_condition"], list)

    # Index by the number of dumped trees. The previous code used the number of
    # keys in the last tree's root-node dict, which only happened to be a valid
    # tree index.
    last_tree_idx = len(trees) - 1

    graph = to_graphviz(reg, tree_idx=last_tree_idx)
    assert isinstance(graph, Source)

    ax = plot_tree(reg, tree_idx=last_tree_idx)
    assert isinstance(ax, Axes)
|