Files
MLPproject/.venv/lib/python3.12/site-packages/catboost/widget/callbacks.py
2025-10-23 15:44:32 +02:00

108 lines
3.3 KiB
Python

try:
    from xgboost.callback import TrainingCallback as XGBTrainingCallback
except Exception:  # xgboost is optional; any failure while importing it lands here
    class XGBTrainingCallback:
        """Stand-in base class used when importing xgboost fails."""
from IPython.display import display
from .metrics_plotter import MetricsPlotter
class XGBPlottingCallback(XGBTrainingCallback):
    '''XGBoost callback with metrics plotting widget from CatBoost.

    The ``MetricsPlotter`` widget is created lazily on the first
    ``after_iteration`` call, using the metric names reported for the first
    evaluation sample, and is displayed right away.

    Args:
        total_iterations: Expected number of boosting rounds; forwarded to
            ``MetricsPlotter`` (presumably to size the plot's iteration
            axis -- confirm against MetricsPlotter).
    '''

    def __init__(self, total_iterations: int):
        super().__init__()
        # Built on first use, once the metric names are known.
        self.plotter = None
        self.total_iterations = total_iterations

    @staticmethod
    def _last_value(values):
        """Return the most recent entry of a metric history as a scalar.

        History entries may be ``(mean, std)`` tuples (as produced by
        ``xgboost.cv``); in that case only the mean is plotted.
        """
        last = values[-1]
        if isinstance(last, tuple):
            return last[0]
        return last

    def after_iteration(self, model, epoch, evals_log):
        """Log the latest metric values of every evaluation sample.

        Args:
            model: The model being trained (not used here).
            epoch: Iteration index, passed to the plotter as the x position.
            evals_log: Mapping of sample name -> metric name -> history of
                values; only the last value of each history is logged.

        Returns:
            False, so this callback never requests early stopping.

        Raises:
            Exception: If a sample name contains none of "train", "valid"
                or "test" (case-insensitive) and is not promoted to train.
        """
        data_names = evals_log.keys()
        # If several samples are passed and all of them are "valid" ones,
        # consider the first one as the train sample.
        first_train = (
            len(data_names) > 1
            and all('valid' in data_name.lower() for data_name in data_names)
        )
        for data_name, metrics_info in evals_log.items():
            lowered = data_name.lower()
            if "train" in lowered or first_train:
                train = True
                # Only the very first sample may be promoted to train.
                first_train = False
            elif "valid" in lowered or "test" in lowered:
                train = False
            else:
                raise Exception("Unexpected sample name during evaluation")
            metrics = {
                name: self._last_value(values)
                for name, values in metrics_info.items()
            }
            if self.plotter is None:
                # Train and test curves share the first sample's metric names.
                names = list(metrics.keys())
                self.plotter = MetricsPlotter(names, names, self.total_iterations)
                display(self.plotter._widget)
            self.plotter.log(epoch, train, metrics)
        # False to indicate training should not stop.
        return False
def lgbm_plotting_callback():
    """LightGBM callback with metrics plotting widget from CatBoost.

    Returns:
        A callable for LightGBM's ``callbacks`` list (with ``order = 20``).
        On the first iteration of each training run it creates and displays
        a fresh ``MetricsPlotter`` widget. On every iteration it logs the
        current train and test metric values to that widget.
    """
    plotter = None

    def _sample_kind(data_name):
        """Classify an evaluation sample name as "train" or "test"."""
        lowered = data_name.lower()
        if "train" in lowered:
            return "train"
        if "valid" in lowered or "test" in lowered:
            return "test"
        raise Exception("Unexpected sample name during evaluation")

    def _init(env):
        """Create and display a plotter for the metrics reported in ``env``."""
        nonlocal plotter
        metric_names = {"train": [], "test": []}
        for item in env.evaluation_result_list:
            # Only 4-item evaluation entries are handled; longer ones
            # (presumably lgb.cv results carrying a std value -- confirm)
            # are rejected.
            assert len(item) == 4, "Plotting was run in not suppored mode"
            data_name, eval_name = item[:2]
            metric_names[_sample_kind(data_name)].append(eval_name)
        plotter = MetricsPlotter(
            metric_names["train"],
            metric_names["test"],
            env.end_iteration - env.begin_iteration,
        )
        display(plotter._widget)

    def _callback(env):
        """Log the current iteration's metrics, initializing when needed."""
        # Re-initialize at the first iteration of every run (not only the
        # very first call). Otherwise a callback reused for another training
        # run would keep writing into the previous run's widget, which is
        # sized with the previous run's iteration count.
        if plotter is None or env.iteration == env.begin_iteration:
            _init(env)
        metrics = {"train": {}, "test": {}}
        for item in env.evaluation_result_list:
            data_name, eval_name, result = item[:3]
            metrics[_sample_kind(data_name)][eval_name] = result
        step = env.iteration - env.begin_iteration
        plotter.log(step, train=True, metrics=metrics["train"])
        plotter.log(step, train=False, metrics=metrics["test"])

    _callback.order = 20
    return _callback