117 lines
3.9 KiB
Python
117 lines
3.9 KiB
Python
import os
|
|
import json
|
|
from threading import Thread, Event
|
|
from traitlets import Unicode, Dict, default
|
|
from IPython.display import display
|
|
from ipywidgets import DOMWidget, Layout, widget_serialization
|
|
|
|
|
|
class MetricWidget(DOMWidget):
    """Base Jupyter widget bound to the ``catboost-widget`` front-end module.

    The private ``_view_*`` / ``_model_*`` traits name the front-end view and
    model classes, the module that provides them and the accepted module
    version range. ``data`` is the synced payload trait.
    """

    # Front-end view: class name, providing module, accepted version range.
    _view_name = Unicode('CatboostWidgetView').tag(sync=True)
    _view_module = Unicode('catboost-widget').tag(sync=True)
    _view_module_version = Unicode('^1.0.0').tag(sync=True)

    # Front-end model: class name, providing module, accepted version range.
    _model_name = Unicode('CatboostWidgetModel').tag(sync=True)
    _model_module = Unicode('catboost-widget').tag(sync=True)
    _model_module_version = Unicode('^1.0.0').tag(sync=True)

    # Payload synced to the front-end; starts out as an empty dict.
    data = Dict({}).tag(sync=True, **widget_serialization)

    @default('layout')
    def _default_layout(self):
        """Return the default layout: 500px high, stretched on the cross axis."""
        layout_options = {'height': '500px', 'align_self': 'stretch'}
        return Layout(**layout_options)
|
|
|
|
|
|
class MetricVisualizer(MetricWidget):
    """Widget that polls training directories and publishes their progress.

    Every watched directory is checked for a ``catboost_training.json`` file.
    Its ``meta`` and ``iterations`` entries are read once per second and
    stored in the synced ``data`` trait, keyed by directory path, until every
    directory reports completion or a stop is requested.
    """

    def __init__(self, train_dirs, subdirs=False):
        """Set up the list of directories to watch.

        Args:
            train_dirs: a single directory path or an iterable of paths.
            subdirs: when true, watch the immediate sub-directories of each
                given path instead of the paths themselves.
        """
        # Name the class explicitly: ``super(self.__class__, self)`` recurses
        # forever as soon as this class is subclassed.
        super(MetricVisualizer, self).__init__()

        if isinstance(train_dirs, str):
            train_dirs = [train_dirs]
        if subdirs:
            train_subdirs = []
            for train_dir in train_dirs:
                train_subdirs.extend(self._get_subdirectories(train_dir))
            train_dirs = train_subdirs

        # Private copy, so later changes to the caller's list do not leak in;
        # ``list()`` also accepts tuples and other iterables.
        self._train_dirs = list(train_dirs)

        # Display name per directory: its base name, or 'current' for the
        # working directory itself.
        curdir = os.path.abspath(os.path.curdir)
        self._names = []
        for train_dir in self._train_dirs:
            abspath = os.path.abspath(train_dir)
            self._names.append(os.path.basename(abspath) if abspath != curdir else 'current')

        self._need_to_stop = Event()
        self._update_after_stop_signal = False
        # Polling thread; created by ``_run_update``.
        self.thread = None

    def start(self):
        """Display the widget and refresh ``data`` until a stop is signalled.

        Blocks the calling thread. ``Event.wait(1.0)`` doubles as the
        one-second polling delay and the stop check.
        """
        display(self)
        self._update_data()
        while not self._need_to_stop.wait(1.0):
            self._update_data()

        # ``_stop_update`` asks for one last refresh so the final state of
        # the training files is not missed.
        if self._update_after_stop_signal:
            self._update_data()

    def _run_update(self):
        """Run ``start`` in a background thread."""
        self.thread = Thread(target=self.start, args=())
        self.thread.start()

    def _stop_update(self):
        """Request a final refresh, signal the polling loop to stop and wait for it."""
        self._update_after_stop_signal = True
        self._need_to_stop.set()
        # Nothing to join if ``_run_update`` was never called.
        if self.thread is not None:
            self.thread.join()

    def _get_subdirectories(self, a_dir):
        """Return the paths of the immediate sub-directories of ``a_dir``."""
        candidates = (os.path.join(a_dir, name) for name in os.listdir(a_dir))
        return [candidate for candidate in candidates if os.path.isdir(candidate)]

    def _update_data(self):
        """Re-read every watched directory and publish the result to ``data``.

        Sets the stop event once every directory has a non-zero total
        iteration count and its last recorded iteration is the final one.
        """
        data = {}
        all_completed = True

        for name, path in zip(self._names, self._train_dirs):
            content = self._update_data_from_dir(path)
            # ``_update_data_from_dir`` always returns a dict here; the guard
            # is kept for overriding implementations that return nothing.
            if not content:
                continue

            data[path] = {
                'path': path,
                'name': name,
                'content': content,
            }

            passed_iterations = content['passed_iterations']
            total_iterations = content['total_iterations']
            # Recorded iteration numbers are compared as zero-based, hence
            # the ``+ 1``; a zero total means nothing has been reported yet.
            all_completed = (
                all_completed
                and total_iterations != 0
                and passed_iterations + 1 >= total_iterations
            )

        if all_completed:
            self._need_to_stop.set()

        self.data = data

    def _update_data_from_dir(self, path):
        """Read ``catboost_training.json`` from ``path`` and summarise it.

        A missing, unreadable, partially written or incomplete file is
        treated as "no data yet" instead of raising, because the file may be
        in the middle of being written while it is polled.

        Returns:
            dict with ``passed_iterations`` (number of the last recorded
            iteration, 0 if none), ``total_iterations`` (``iteration_count``
            from the meta section, 0 if unknown) and ``data`` (the raw
            ``iterations`` list and ``meta`` dict).
        """
        training_json = os.path.join(path, 'catboost_training.json')

        training_data = None
        try:
            with open(training_json, 'r') as json_data:
                training_data = json.load(json_data)
        except (IOError, OSError, ValueError):
            # File absent or removed, not readable, or the JSON is truncated.
            pass

        meta = {}
        iterations = []
        if isinstance(training_data, dict):
            meta = training_data.get('meta') or {}
            iterations = training_data.get('iterations') or []

        passed_iterations = iterations[-1].get('iteration', 0) if iterations else 0
        total_iterations = meta.get('iteration_count', 0)

        return {
            'passed_iterations': passed_iterations,
            'total_iterations': total_iterations,
            'data': {
                'iterations': iterations,
                'meta': meta,
            },
        }

    @staticmethod
    def _get_static_path(file_name):
        """Return the path of ``file_name`` inside this module's directory."""
        return os.path.join(os.path.dirname(__file__), file_name)
|