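"""Training callbacks for XGBoost and LightGBM that render the CatBoost-style
metrics plotting widget (``MetricsPlotter``) in Jupyter notebooks."""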

# ``xgboost`` is an optional dependency: fall back to a stub base class so the
# module can still be imported (and the LightGBM callback used) without it.
try:
    from xgboost.callback import TrainingCallback as XGBTrainingCallback
except Exception:
    class XGBTrainingCallback:
        pass


from IPython.display import display

from .metrics_plotter import MetricsPlotter


class XGBPlottingCallback(XGBTrainingCallback):
    '''XGBoost callback that renders the CatBoost-style metrics plotting
    widget during training.
    '''

    def __init__(self, total_iterations: int):
        self.plotter = None
        self.total_iterations = total_iterations

    def after_iteration(self, model, epoch, evals_log):
        data_names = evals_log.keys()
        # If more than one split is passed and every split name looks like
        # "valid*" (e.g. XGBoost's default validation_0, validation_1, ...),
        # treat the first split as the train sample.
        first_train = (
            all('valid' in data_name.lower() for data_name in data_names)
            and len(data_names) > 1
        )

        for data_name, metrics_info in evals_log.items():
            if "train" in data_name.lower() or first_train:
                train = True
                first_train = False
            elif "valid" in data_name.lower() or "test" in data_name.lower():
                train = False
            else:
                raise Exception("Unexpected sample name during evaluation")

            # Take the latest value of each metric for this iteration.
            metrics = {name: values[-1] for name, values in metrics_info.items()}

            # Create the plotter lazily, once the metric names are known.
            if self.plotter is None:
                names = list(metrics.keys())
                self.plotter = MetricsPlotter(names, names, self.total_iterations)
                display(self.plotter._widget)

            self.plotter.log(epoch, train, metrics)

        # Return False to indicate that training should not stop.
        return False
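

# Example usage (a sketch, not part of this module): assumes ``params`` is an
# XGBoost parameter dict and ``dtrain``/``dvalid`` are ``xgboost.DMatrix``
# objects defined elsewhere.
#
#     import xgboost as xgb
#
#     xgb.train(
#         params,
#         dtrain,
#         num_boost_round=100,
#         evals=[(dtrain, "train"), (dvalid, "valid")],
#         callbacks=[XGBPlottingCallback(total_iterations=100)],
#     )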


def lgbm_plotting_callback():
    """LightGBM callback that renders the CatBoost-style metrics plotting
    widget during training.
    """
    plotter = None

    def _init(env):
        train_metrics = []
        test_metrics = []

        # Each entry is a (data_name, eval_name, result, is_higher_better)
        # tuple; split the metric names into train and test groups.
        for item in env.evaluation_result_list:
            assert len(item) == 4, "Plotting was run in an unsupported mode"
            data_name, eval_name = item[:2]

            if "train" in data_name.lower():
                train_metrics.append(eval_name)
            elif "valid" in data_name.lower() or "test" in data_name.lower():
                test_metrics.append(eval_name)
            else:
                raise Exception("Unexpected sample name during evaluation")

        nonlocal plotter
        plotter = MetricsPlotter(
            train_metrics, test_metrics, env.end_iteration - env.begin_iteration)

        display(plotter._widget)

    def _callback(env):
        # Build the plotter on the first call, once metric names are known.
        if plotter is None:
            _init(env)

        metrics = {"train": {}, "test": {}}

        for item in env.evaluation_result_list:
            data_name, eval_name, result = item[:3]

            if "train" in data_name.lower():
                metrics["train"][eval_name] = result
            elif "valid" in data_name.lower() or "test" in data_name.lower():
                metrics["test"][eval_name] = result
            else:
                raise Exception("Unexpected sample name during evaluation")

        plotter.log(
            env.iteration - env.begin_iteration,
            train=True,
            metrics=metrics["train"]
        )
        plotter.log(
            env.iteration - env.begin_iteration,
            train=False,
            metrics=metrics["test"]
        )

    # ``order`` determines when this callback runs relative to LightGBM's
    # built-in callbacks (lower values run first).
    _callback.order = 20
    return _callback
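

# Example usage (a sketch, not part of this module): assumes ``params`` is a
# LightGBM parameter dict and ``train_set``/``valid_set`` are
# ``lightgbm.Dataset`` objects defined elsewhere.
#
#     import lightgbm as lgb
#
#     lgb.train(
#         params,
#         train_set,
#         num_boost_round=100,
#         valid_sets=[train_set, valid_set],
#         valid_names=["train", "valid"],
#         callbacks=[lgbm_plotting_callback()],
#     )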