import time
import logging

import typing as tp

from IPython.display import display
from copy import deepcopy
from typing import List, Optional, Any, Union

from .ipythonwidget import MetricWidget

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


class MetricsWidget(MetricWidget):
    def __init__(self):
        super().__init__()

    def update_data(self, data: tp.Dict) -> None:
        # deepcopy is crucial here: the widget keeps its own snapshot of the
        # data instead of a reference that the plotter continues to mutate
        self.data = deepcopy(data)


class MetricsPlotter:
    """
    Context manager that displays a widget with learning curves in
    JupyterLab / Jupyter Notebook.
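
    Examples
    --------
    An illustrative sketch (the metric names, values and the epoch loop are
    placeholders, not part of this module)::

        with MetricsPlotter(
                train_metrics=["loss", {"name": "accuracy", "best_value": "Max"}],
                total_iterations=10) as plotter:
            for epoch in range(10):
                plotter.log(epoch, train=True, metrics={"loss": 0.5, "accuracy": 0.8})
                plotter.log(epoch, train=False, metrics={"loss": 0.6, "accuracy": 0.7})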
    """

    def __init__(self, train_metrics: List[Union[str, tp.Dict[str, str]]],
                 test_metrics: Optional[List[Union[str, tp.Dict[str, str]]]] = None,
                 total_iterations: Optional[int] = None) -> None:
"""
|
|
Constructor that defines metrics to be plotted and total iterations count.
|
|
|
|
Parameters
|
|
----------
|
|
train_metrics : list of str or list of dict
|
|
List of train metrics to be tracked.
|
|
Each item in list can be either string with metric name or dict
|
|
with the following format:
|
|
{
|
|
"name": "{metric_name}",
|
|
"best_value": "Max|Min|Undefined",
|
|
}
|
|
|
|
test_metrics : list of str or list of dict, optional (default=None)
|
|
List of test metrics to be tracked.
|
|
Has the same format as train_metrics. Equals to train_metrics, if not defined
|
|
|
|
total_iterations: int, optional (default=None)
|
|
Total number of iterations, allows for remaining time estimation.
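
        Examples
        --------
        An illustrative metrics definition (the metric names are placeholders)::

            train_metrics = ["loss", {"name": "accuracy", "best_value": "Max"}]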
        """

        self._widget = MetricsWidget()

        self._values = {
            "iterations": [],
            "meta": {
                "launch_mode": "Train",
                "parameters": "",
                "name": "experiment",
                "iteration_count": None,  # set later
                "learn_sets": ["learn"],
                "learn_metrics": None,  # set later
                "test_sets": ["test"],
                "test_metrics": None,  # set later
            }
        }

        self._content = {
            "passed_iterations": 0,
            "total_iterations": None,  # set later
            "data": self._values,
        }

        # data propagated to widget class
        self._data = {
            "test_path": {
                "path": "test_path",
                "name": "experiment",
                "content": self._content,
            }
        }

        test_metrics = test_metrics or train_metrics
        train_metrics_meta: List[tp.Dict[str, str]] = self.construct_metrics_meta(train_metrics)
        test_metrics_meta: List[tp.Dict[str, str]] = self.construct_metrics_meta(test_metrics)

        self._train_metrics_positions = {
            meta["name"]: pos for pos, meta in enumerate(train_metrics_meta)}
        self._test_metrics_positions = {
            meta["name"]: pos for pos, meta in enumerate(test_metrics_meta)}

        self._values["meta"].update({
            "learn_metrics": train_metrics_meta,
            "test_metrics": test_metrics_meta,
        })

        self.passed_iterations = 0
        self.total_iterations = total_iterations

        if total_iterations is not None:
            self._values["meta"]["iteration_count"] = total_iterations

        self._content["total_iterations"] = total_iterations or 0

        self._start_time = time.time()

    def __enter__(self) -> 'MetricsPlotter':
        display(self._widget)
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> Any:
        if exc_type is KeyboardInterrupt:
            logger.info(
                f"Learning was stopped manually after {self.passed_iterations} epochs")
            return True

    @staticmethod
    def construct_metrics_meta(metrics: List[Union[str, tp.Dict[str, str]]]) -> List[tp.Dict[str, str]]:
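        """
        Normalize metric definitions into ``{"name", "best_value"}`` dicts.

        Illustrative mapping (the metric names are placeholders)::

            ["loss", {"name": "acc", "best_value": "Max"}]
            -> [{"best_value": "Undefined", "name": "loss"},
                {"best_value": "Max", "name": "acc"}]
        """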
        meta: List[tp.Dict[str, str]] = []
        for item in metrics:
            if isinstance(item, str):
                name, best_value = item, "Undefined"
            elif isinstance(item, dict):
                assert "name" in item and "best_value" in item, \
                    "Wrong metrics definition format: should have " \
                    "`name` and `best_value` fields"
                name, best_value = item["name"], item["best_value"]
            else:
                assert False, "Each metric should be defined as str or as " \
                    "dict with `name` and `best_value` fields"
            meta.append({"best_value": best_value, "name": name})
        return meta

    @staticmethod
    def construct_metrics_array(metrics_positions: tp.Dict[str, int],
                                metrics: tp.Dict[str, float]) -> List[float]:
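        """
        Arrange metric values into a list ordered by the positions fixed in ``__init__``.

        Illustrative mapping (metric names and values are placeholders)::

            metrics_positions={"loss": 0, "accuracy": 1},
            metrics={"accuracy": 0.8, "loss": 0.5}
            -> [0.5, 0.8]
        """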
        array: List[float] = [0.] * len(metrics_positions)

        # data validation
        assert set(metrics.keys()) == set(metrics_positions.keys()), \
            "Not all metrics were passed while logging, expected the " \
            f"following: {', '.join(list(metrics_positions.keys()))}"

        for metric, value in metrics.items():
            assert isinstance(value, float), f"Type of metric {metric} should be float"
            array[metrics_positions[metric]] = value
        return array

    def estimate_remaining_time(self, time_from_start: float) -> Optional[float]:
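        """
        Estimate remaining time by linearly extrapolating the average time per
        passed iteration. For example (illustrative numbers): 30 s spent on 3 of
        10 iterations gives 30 / 3 * 7 = 70 s remaining.
        """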
        if self.total_iterations is None:
            return None
        remaining_iterations: int = self.total_iterations - self.passed_iterations
        return time_from_start / self.passed_iterations * remaining_iterations

    def log(self, epoch: int, train: bool, metrics: tp.Dict[str, float]) -> None:
        """
        Save metrics for a specific training epoch.

        Parameters
        ----------
        epoch : int
            Current (zero-based) epoch number

        train : bool
            Flag that indicates whether the metrics were computed on train or test data

        metrics : dict
            Values for each of the metrics defined in the `__init__` method of this class
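
        Examples
        --------
        An illustrative call (metric names and values are placeholders; the keys
        must match the metrics defined in ``__init__``)::

            plotter.log(epoch=0, train=True, metrics={"loss": 0.52, "accuracy": 0.81})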
        """

        self.passed_iterations = epoch + 1
        self._content["passed_iterations"] = self.passed_iterations
        total_iterations = max(self._content["total_iterations"], self.passed_iterations)

        self._content["total_iterations"] = total_iterations
        self._values["meta"]["iteration_count"] = total_iterations

        assert len(self._values["iterations"]) in (epoch, epoch + 1), \
            "Data for epochs should be passed successively (wrong epoch number)"

        time_from_start: float = time.time() - self._start_time

        # redraw only if an entry for this epoch already exists, i.e. this is the
        # second call for the epoch (e.g. test after train)
        should_redraw: bool = len(self._values["iterations"]) == epoch + 1

        if len(self._values["iterations"]) == epoch:
            self._values["iterations"].append({
                "learn": [],
                "test": [],
                "iteration": epoch,
                "passed_time": time_from_start,
            })

        remaining_time = self.estimate_remaining_time(time_from_start)
        if remaining_time is not None:
            self._values["iterations"][-1]["remaining_time"] = remaining_time

        key: str = "learn" if train else "test"
        value: List[float] = self.construct_metrics_array(
            self._train_metrics_positions if train else self._test_metrics_positions, metrics)

        self._values["iterations"][-1].update({key: value})

        if should_redraw:
            self._widget.update_data(self._data)