# pylint: disable=too-many-locals
"""Tests for callback functions."""
import json
from itertools import product
from typing import Dict, List, Tuple

from ..callback import LearningRateScheduler
from ..core import Booster, DMatrix
from ..training import cv, train
from .utils import Device
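
# A hypothetical sketch of how a test module drives these helpers; the
# dataset and DMatrix construction below are assumptions, not part of
# this file:
#
#     import xgboost as xgb
#     from sklearn.datasets import make_classification
#     from xgboost.testing.callbacks import run_eta_decay
#
#     X, y = make_classification(random_state=0)
#     dtrain = xgb.DMatrix(X, label=y)
#     dtest = xgb.DMatrix(X, label=y)
#     run_eta_decay("hist", dtrain, dtest, "cpu")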


def run_eta_decay(
    tree_method: str, dtrain: DMatrix, dtest: DMatrix, device: Device
) -> None:
    """Test the learning rate scheduler; used by both CPU and GPU tests."""
    scheduler = LearningRateScheduler

    watchlist = [(dtest, "eval"), (dtrain, "train")]
    num_round = 4

    # learning_rates as a list
    # init eta to 0 to check whether learning_rates takes effect
    param = {
        "max_depth": 2,
        "eta": 0,
        "objective": "binary:logistic",
        "eval_metric": "error",
        "tree_method": tree_method,
        "device": device,
    }
    evals_result: Dict[str, Dict] = {}
    bst = train(
        param,
        dtrain,
        num_round,
        evals=watchlist,
        callbacks=[scheduler([0.8, 0.7, 0.6, 0.5])],
        evals_result=evals_result,
    )
    eval_errors_0 = list(map(float, evals_result["eval"]["error"]))
    assert isinstance(bst, Booster)
    # validation error should decrease if eta > 0
    assert eval_errors_0[0] > eval_errors_0[-1]
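
    # train() filled evals_result with one metric value per boosting round:
    #   {"eval": {"error": [e0, ..., e3]}, "train": {"error": [...]}}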

    # init learning_rate to 0 to check whether learning_rates takes effect
    param = {
        "max_depth": 2,
        "learning_rate": 0,
        "objective": "binary:logistic",
        "eval_metric": "error",
        "tree_method": tree_method,
        "device": device,
    }
    evals_result = {}
    bst = train(
        param,
        dtrain,
        num_round,
        evals=watchlist,
        callbacks=[scheduler([0.8, 0.7, 0.6, 0.5])],
        evals_result=evals_result,
    )
    eval_errors_1 = list(map(float, evals_result["eval"]["error"]))
    assert isinstance(bst, Booster)
    # validation error should decrease if learning_rate > 0
    assert eval_errors_1[0] > eval_errors_1[-1]

    # check that learning_rates overrides the default value of eta/learning_rate
    param = {
        "max_depth": 2,
        "objective": "binary:logistic",
        "eval_metric": "error",
        "tree_method": tree_method,
        "device": device,
    }
    evals_result = {}
    bst = train(
        param,
        dtrain,
        num_round,
        evals=watchlist,
        callbacks=[scheduler([0, 0, 0, 0])],
        evals_result=evals_result,
    )
    eval_errors_2 = list(map(float, evals_result["eval"]["error"]))
    assert isinstance(bst, Booster)
    # validation error should not decrease when eta/learning_rate = 0
    assert eval_errors_2[0] == eval_errors_2[-1]

    # learning_rates as a custom decay function
    def eta_decay(ithround: int, num_boost_round: int = num_round) -> float:
        return num_boost_round / (ithround + 1)

    evals_result = {}
    bst = train(
        param,
        dtrain,
        num_round,
        evals=watchlist,
        callbacks=[scheduler(eta_decay)],
        evals_result=evals_result,
    )
    eval_errors_3 = list(map(float, evals_result["eval"]["error"]))
    assert isinstance(bst, Booster)
    # the first round matches the zero-rate run; every later round differs
    assert eval_errors_3[0] == eval_errors_2[0]
    for i in range(1, len(eval_errors_0)):
        assert eval_errors_3[i] != eval_errors_2[i]

    cv(param, dtrain, num_round, callbacks=[scheduler(eta_decay)])
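

# LearningRateScheduler accepts either a sequence with one rate per
# boosting round, as above, or a callable mapping the round index to a
# rate, so e.g. an exponential decay can be written as:
#
#     LearningRateScheduler(lambda epoch: 0.3 * 0.9**epoch)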


def tree_methods_objs() -> List[Tuple[str, str]]:
    """Test parameters for the leaf output test."""
    return list(
        product(
            ["approx", "hist"],
            [
                "binary:logistic",
                "reg:absoluteerror",
                "reg:quantileerror",
            ],
        )
    )
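

# Hypothetical pytest wiring for these pairs (the `dtrain`/`dtest`
# fixtures are assumed, not provided by this module):
#
#     @pytest.mark.parametrize("tree_method,objective", tree_methods_objs())
#     def test_eta_decay_leaf_output(tree_method, objective, dtrain, dtest):
#         run_eta_decay_leaf_output(tree_method, objective, dtrain, dtest, "cpu")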


def run_eta_decay_leaf_output(
    tree_method: str, objective: str, dtrain: DMatrix, dtest: DMatrix, device: Device
) -> None:
    """Check that learning rate decay has an effect on the leaf output."""
    num_round = 4
    scheduler = LearningRateScheduler

    watchlist = [(dtest, "eval"), (dtrain, "train")]
    param = {
        "max_depth": 2,
        "objective": objective,
        "eval_metric": "error",
        "tree_method": tree_method,
        "device": device,
    }
    if objective == "reg:quantileerror":
        param["quantile_alpha"] = 0.3

    def eta_decay_0(i: int) -> float:
        return num_round / (i + 1)

    bst0 = train(
        param,
        dtrain,
        num_round,
        evals=watchlist,
        callbacks=[scheduler(eta_decay_0)],
    )

    def eta_decay_1(i: int) -> float:
        if i > 1:
            return 5.0
        return num_round / (i + 1)

    bst1 = train(
        param,
        dtrain,
        num_round,
        evals=watchlist,
        callbacks=[scheduler(eta_decay_1)],
    )
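
    # With num_round = 4 the two schedules evaluate, per epoch, to
    #   eta_decay_0 -> 4.0, 2.0, 4/3, 1.0
    #   eta_decay_1 -> 4.0, 2.0, 5.0, 5.0
    # i.e. they first diverge at epoch 2.  The assertions below check
    # where that divergence first becomes visible in the saved models.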
    bst_json0 = bst0.save_raw(raw_format="json")
    bst_json1 = bst1.save_raw(raw_format="json")

    j0 = json.loads(bst_json0)
    j1 = json.loads(bst_json1)

    tree_2th_0 = j0["learner"]["gradient_booster"]["model"]["trees"][2]
    tree_2th_1 = j1["learner"]["gradient_booster"]["model"]["trees"][2]
    assert tree_2th_0["base_weights"] == tree_2th_1["base_weights"]
    assert tree_2th_0["split_conditions"] == tree_2th_1["split_conditions"]

    tree_3th_0 = j0["learner"]["gradient_booster"]["model"]["trees"][3]
    tree_3th_1 = j1["learner"]["gradient_booster"]["model"]["trees"][3]
    assert tree_3th_0["base_weights"] != tree_3th_1["base_weights"]
    assert tree_3th_0["split_conditions"] != tree_3th_1["split_conditions"]