# pylint: disable=too-many-locals
"""Tests for callback functions."""

import json
from itertools import product
from typing import Dict, List, Tuple

from ..callback import LearningRateScheduler
from ..core import Booster, DMatrix
from ..training import cv, train
from .utils import Device


def run_eta_decay(
    tree_method: str, dtrain: DMatrix, dtest: DMatrix, device: Device
) -> None:
    """Test learning rate scheduler, used by both CPU and GPU tests."""
    scheduler = LearningRateScheduler

    watchlist = [(dtest, "eval"), (dtrain, "train")]
    num_round = 4

    # learning_rates passed as a list.
    # Initialize eta with 0 to check that the scheduled rates take effect.
    param = {
        "max_depth": 2,
        "eta": 0,
        "objective": "binary:logistic",
        "eval_metric": "error",
        "tree_method": tree_method,
        "device": device,
    }
    evals_result: Dict[str, Dict] = {}
    bst = train(
        param,
        dtrain,
        num_round,
        evals=watchlist,
        callbacks=[scheduler([0.8, 0.7, 0.6, 0.5])],
        evals_result=evals_result,
    )
    eval_errors_0 = list(map(float, evals_result["eval"]["error"]))
    assert isinstance(bst, Booster)
    # The validation error should decrease when the scheduled eta is positive.
    assert eval_errors_0[0] > eval_errors_0[-1]

    # Initialize learning_rate (the alias of eta) with 0 to check that the
    # scheduled rates take effect under this name as well.
    param = {
        "max_depth": 2,
        "learning_rate": 0,
        "objective": "binary:logistic",
        "eval_metric": "error",
        "tree_method": tree_method,
        "device": device,
    }
    evals_result = {}

    bst = train(
        param,
        dtrain,
        num_round,
        evals=watchlist,
        callbacks=[scheduler([0.8, 0.7, 0.6, 0.5])],
        evals_result=evals_result,
    )
    eval_errors_1 = list(map(float, evals_result["eval"]["error"]))
    assert isinstance(bst, Booster)
    # The validation error should decrease when the scheduled rate is positive.
    assert eval_errors_1[0] > eval_errors_1[-1]

    # Check that the scheduled rates override the default value of
    # eta/learning_rate.
    param = {
        "max_depth": 2,
        "objective": "binary:logistic",
        "eval_metric": "error",
        "tree_method": tree_method,
        "device": device,
    }
    evals_result = {}
    bst = train(
        param,
        dtrain,
        num_round,
        evals=watchlist,
        callbacks=[scheduler([0, 0, 0, 0])],
        evals_result=evals_result,
    )
    eval_errors_2 = list(map(float, evals_result["eval"]["error"]))
    assert isinstance(bst, Booster)
    # The validation error must not decrease when the scheduled rate is 0.
    assert eval_errors_2[0] == eval_errors_2[-1]

    # learning_rates given as a custom decay function.
    def eta_decay(ithround: int, num_boost_round: int = num_round) -> float:
        return num_boost_round / (ithround + 1)

    evals_result = {}
    bst = train(
        param,
        dtrain,
        num_round,
        evals=watchlist,
        callbacks=[scheduler(eta_decay)],
        evals_result=evals_result,
    )
    eval_errors_3 = list(map(float, evals_result["eval"]["error"]))

    assert isinstance(bst, Booster)

    # Both runs reuse the same `param` (no explicit eta), and a scheduled rate
    # only takes effect on the round after it is returned, so the first
    # recorded errors match; once the schedules diverge, so must the errors.
    assert eval_errors_3[0] == eval_errors_2[0]

    for i in range(1, len(eval_errors_0)):
        assert eval_errors_3[i] != eval_errors_2[i]

    # The scheduler should also work with cross validation.
    cv(param, dtrain, num_round, callbacks=[scheduler(eta_decay)])
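

# A minimal sketch of how a concrete test might drive the helper above.  The
# synthetic data, the random seed, and the `hist`/`cpu` choices are
# illustrative assumptions; real tests pass in their own DMatrix fixtures.
def example_run_eta_decay() -> None:
    """Illustrative driver for run_eta_decay (assumed data; not a real test)."""
    import numpy as np

    rng = np.random.default_rng(1994)
    X = rng.normal(size=(128, 8))
    # Make the label depend on a feature so the error can actually decrease.
    y = (X[:, 0] > 0.0).astype(np.int64)
    # Reuse the same data for train and eval; the helper only needs two
    # DMatrix objects and a device string.
    run_eta_decay("hist", DMatrix(X, label=y), DMatrix(X, label=y), "cpu")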


def tree_methods_objs() -> List[Tuple[str, str]]:
    """Grid of (tree_method, objective) pairs for the leaf output test."""
    return list(
        product(
            ["approx", "hist"],
            [
                "binary:logistic",
                "reg:absoluteerror",
                "reg:quantileerror",
            ],
        )
    )
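

# A sketch of consuming the grid above.  Real tests usually hand it to
# `pytest.mark.parametrize`; the plain loop and the `cpu` device below are
# illustrative assumptions only.
def example_leaf_output_grid(dtrain: DMatrix, dtest: DMatrix) -> None:
    """Run the leaf output check for every (tree_method, objective) pair."""
    for tree_method, objective in tree_methods_objs():
        run_eta_decay_leaf_output(tree_method, objective, dtrain, dtest, "cpu")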


def run_eta_decay_leaf_output(
    tree_method: str, objective: str, dtrain: DMatrix, dtest: DMatrix, device: Device
) -> None:
    """Check that learning rate decay has an effect on the leaf output."""
    num_round = 4
    scheduler = LearningRateScheduler

    watchlist = [(dtest, "eval"), (dtrain, "train")]

    param = {
        "max_depth": 2,
        "objective": objective,
        "eval_metric": "error",
        "tree_method": tree_method,
        "device": device,
    }
    if objective == "reg:quantileerror":
        # The quantile objective requires the target quantile to be set.
        param["quantile_alpha"] = 0.3

    def eta_decay_0(i: int) -> float:
        return num_round / (i + 1)

    bst0 = train(
        param,
        dtrain,
        num_round,
        evals=watchlist,
        callbacks=[scheduler(eta_decay_0)],
    )

    def eta_decay_1(i: int) -> float:
        # Identical to eta_decay_0 for epochs 0 and 1; diverges from epoch 2.
        if i > 1:
            return 5.0
        return num_round / (i + 1)

    bst1 = train(
        param,
        dtrain,
        num_round,
        evals=watchlist,
        callbacks=[scheduler(eta_decay_1)],
    )
    bst_json0 = bst0.save_raw(raw_format="json")
    bst_json1 = bst1.save_raw(raw_format="json")

    j0 = json.loads(bst_json0)
    j1 = json.loads(bst_json1)

    # A scheduled rate only takes effect on the round after it is returned, so
    # the tree at index 2 is still built with a rate common to both schedules
    # and the two models must match there.
    tree_2_0 = j0["learner"]["gradient_booster"]["model"]["trees"][2]
    tree_2_1 = j1["learner"]["gradient_booster"]["model"]["trees"][2]
    assert tree_2_0["base_weights"] == tree_2_1["base_weights"]
    assert tree_2_0["split_conditions"] == tree_2_1["split_conditions"]

    # The tree at index 3 is built after the schedules have diverged, which
    # must show up in both the leaf weights and the split conditions.
    tree_3_0 = j0["learner"]["gradient_booster"]["model"]["trees"][3]
    tree_3_1 = j1["learner"]["gradient_booster"]["model"]["trees"][3]
    assert tree_3_0["base_weights"] != tree_3_1["base_weights"]
    assert tree_3_0["split_conditions"] != tree_3_1["split_conditions"]
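

# A small sketch of the two input forms accepted by LearningRateScheduler, as
# exercised throughout this module: a sequence with one rate per boosting
# round, or a callable mapping the round index to a rate.  The constants are
# illustrative only.
def example_scheduler_forms() -> None:
    """Construct the scheduler from a list and from a callable."""
    per_round = LearningRateScheduler([0.5, 0.4, 0.3, 0.2])
    decayed = LearningRateScheduler(lambda epoch: 0.5 * (0.9**epoch))
    assert per_round is not None and decayed is not None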