# catboost/widget/metrics_plotter.py
import logging
import time
import typing as tp
from copy import deepcopy
from typing import Any, List, Optional, Union

from IPython.display import display

from .ipythonwidget import MetricWidget

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


class MetricsWidget(MetricWidget):
    def __init__(self):
        super().__init__()

    def update_data(self, data: tp.Dict) -> None:
        # deepcopy is crucial here: without it `self.data` would alias the
        # caller's dict, so later updates would assign the very same (already
        # mutated) object and traitlets would detect no change to sync.
        self.data = deepcopy(data)


class MetricsPlotter:
    """
    Context manager that displays a widget with learning curves in
    JupyterLab / Jupyter Notebook.
    """

def __init__(self, train_metrics: List[Union[str, tp.Dict[str, str]]],
test_metrics: Optional[List[Union[str, tp.Dict[str, str]]]] = None,
total_iterations: Optional[int] = None) -> None:
"""
Constructor that defines metrics to be plotted and total iterations count.
Parameters
----------
train_metrics : list of str or list of dict
List of train metrics to be tracked.
Each item in list can be either string with metric name or dict
with the following format:
{
"name": "{metric_name}",
"best_value": "Max|Min|Undefined",
}
test_metrics : list of str or list of dict, optional (default=None)
List of test metrics to be tracked.
Has the same format as train_metrics. Equals to train_metrics, if not defined
total_iterations: int, optional (default=None)
Total number of iterations, allows for remaining time estimation.
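
        Examples
        --------
        A minimal sketch; the metric names and values are illustrative only::

            with MetricsPlotter(["Accuracy", {"name": "Loss", "best_value": "Min"}],
                                total_iterations=10) as plotter:
                for epoch in range(10):
                    # the metric values would come from the model being trained
                    plotter.log(epoch, train=True,
                                metrics={"Accuracy": 0.9, "Loss": 0.3})
                    plotter.log(epoch, train=False,
                                metrics={"Accuracy": 0.8, "Loss": 0.4})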
"""
self._widget = MetricsWidget()
self._values = {
"iterations": [],
"meta": {
"launch_mode": "Train",
"parameters": "",
"name": "experiment",
"iteration_count": None, # set later
"learn_sets": ["learn"],
"learn_metrics": None, # set later
"test_sets": ["test"],
"test_metrics": None, # set later
}
}
self._content = {
"passed_iterations": 0,
"total_iterations": None, # set later
"data": self._values,
}
# data propagated to widget class
self._data = {
"test_path": {
"path": "test_path",
"name": "experiment",
"content": self._content,
}
}
test_metrics = test_metrics or train_metrics
train_metrics_meta: List[tp.Dict[str, str]] = self.construct_metrics_meta(train_metrics)
test_metrics_meta: List[tp.Dict[str, str]] = self.construct_metrics_meta(test_metrics)
self._train_metrics_positions = {
meta["name"]: pos for pos, meta in enumerate(train_metrics_meta)}
self._test_metrics_positions = {
meta["name"]: pos for pos, meta in enumerate(test_metrics_meta)}
self._values["meta"].update({
"learn_metrics": train_metrics_meta,
"test_metrics": test_metrics_meta,
})
self.passed_iterations = 0
self.total_iterations = total_iterations
if total_iterations is not None:
self._values["meta"]["iteration_count"] = total_iterations
self._content["total_iterations"] = total_iterations or 0
self._start_time = time.time()
def __enter__(self) -> 'MetricsPlotter':
display(self._widget)
return self

    def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
        if exc_type is KeyboardInterrupt:
            logger.info(
                f"Learning was stopped manually after {self.passed_iterations} epochs")
            # Returning True suppresses the KeyboardInterrupt, so interrupting
            # the kernel leaves the partially drawn curves in place.
            return True
@staticmethod
def construct_metrics_meta(metrics: List[Union[str, tp.Dict[str, str]]]) -> List[tp.Dict[str, str]]:
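        """
        Normalize metric definitions to dicts; a doctest-style sketch:

        >>> MetricsPlotter.construct_metrics_meta(
        ...     ["Accuracy", {"name": "Loss", "best_value": "Min"}])
        [{'best_value': 'Undefined', 'name': 'Accuracy'}, {'best_value': 'Min', 'name': 'Loss'}]
        """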
meta: List[tp.Dict[str, str]] = []
for item in metrics:
if isinstance(item, str):
name, best_value = item, "Undefined"
elif isinstance(item, dict):
assert "name" in item and "best_value" in item, \
"Wrong metrics definition format: should have " \
"`name` and `best_value` fields"
name, best_value = item["name"], item["best_value"]
else:
                assert False, "Each metric should be defined as str or as " \
                    "dict with `name` and `best_value` fields"
meta.append({"best_value": best_value, "name": name})
return meta
@staticmethod
def construct_metrics_array(metrics_positions: tp.Dict[str, int],
metrics: tp.Dict[str, float]) -> List[float]:
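        """
        Place metric values into the order fixed by `metrics_positions`;
        a doctest-style sketch:

        >>> MetricsPlotter.construct_metrics_array(
        ...     {"Loss": 0, "Accuracy": 1}, {"Accuracy": 0.9, "Loss": 0.1})
        [0.1, 0.9]
        """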
array: List[float] = [0.] * len(metrics_positions)
# data validation
assert set(metrics.keys()) == set(metrics_positions.keys()), \
f"Not all metrics were passed while logging, expected " \
f"following: {', '.join(list(metrics_positions.keys()))}"
for metric, value in metrics.items():
            assert isinstance(value, float), f"Type of metric {metric} should be float"
array[metrics_positions[metric]] = value
return array
def estimate_remaining_time(self, time_from_start: float) -> Optional[float]:
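        """
        Linear extrapolation (assumes a constant per-iteration cost): the
        average time per passed iteration multiplied by the iterations left.
        Returns None when total_iterations was not provided.
        """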
if self.total_iterations is None:
return None
remaining_iterations: int = self.total_iterations - self.passed_iterations
return time_from_start / self.passed_iterations * remaining_iterations
def log(self, epoch: int, train: bool, metrics: tp.Dict[str, float]) -> None:
"""
Save metrics at specific training epoch.
Parameters
----------
epoch : int
Current epoch
train : bool
Flag that indicates whether metrics are calculated on train or test data
metrics: dict
Values for each of metrics defined in `__init__` method of this class
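
        Examples
        --------
        Illustrative call for epoch 0 (metric names must match `__init__`)::

            plotter.log(0, train=True, metrics={"Accuracy": 0.9, "Loss": 0.3})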
"""
self.passed_iterations = epoch + 1
self._content["passed_iterations"] = self.passed_iterations
total_iterations = max(self._content["total_iterations"], self.passed_iterations)
self._content["total_iterations"] = total_iterations
self._values["meta"]["iteration_count"] = total_iterations
assert len(self._values["iterations"]) in (epoch, epoch + 1), \
"Data for epochs should be passed successively (wrong epoch number)"
time_from_start: float = time.time() - self._start_time
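        # The first log() call for an epoch appends a fresh record; the second
        # call (typically the test pass) completes it and triggers a redraw.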
should_redraw: bool = len(self._values["iterations"]) == epoch + 1
if len(self._values["iterations"]) == epoch:
self._values["iterations"].append({
"learn": [],
"test": [],
"iteration": epoch,
"passed_time": time_from_start,
})
remaining_time = self.estimate_remaining_time(time_from_start)
if remaining_time is not None:
self._values["iterations"][-1]["remaining_time"] = remaining_time
key: str = "learn" if train else "test"
value: List[float] = self.construct_metrics_array(
self._train_metrics_positions if train else self._test_metrics_positions, metrics)
self._values["iterations"][-1].update({key: value})
if should_redraw:
self._widget.update_data(self._data)