77 lines
2.5 KiB
Python
77 lines
2.5 KiB
Python
import json
|
|
|
|
from .. import CatBoostError
|
|
from ..eval.factor_utils import FactorUtils
|
|
from ..core import _NumpyAwareEncoder
|
|
|
|
|
|
class ExecutionCase:
    """A single configuration that is compared against others during evaluation.

    Each case holds a shallow copy of a CatBoost parameter dict. The copy may
    override ``learning_rate`` and always carries an ``ignored_features`` list
    that merges the params' own ignored features with an extra per-case set.
    Cases compare equal when their final params and labels are equal.
    """

    def __init__(self,
                 params,
                 label=None,
                 ignored_features=None,
                 learning_rate=None):
        """
        Instances of this class are cases which will be compared during evaluation

        :param params: CatBoost params; shallow-copied, the caller's dict is not mutated.
        :param label: string used for plots and other visualisations;
            ``None`` becomes the empty string.
        :param ignored_features: set of additional feature indices to ignore; merged
            with any ``ignored_features`` already present in ``params``.
        :param learning_rate: when not None, overrides ``params["learning_rate"]``.
        """
        case_params = dict(params)

        if learning_rate is not None:
            case_params["learning_rate"] = learning_rate

        all_ignored_features = set()
        # An explicit ``"ignored_features": None`` is treated like a missing key
        # instead of failing on set(None).
        params_ignored = case_params.get("ignored_features")
        if params_ignored is not None:
            all_ignored_features.update(params_ignored)
        if ignored_features is not None:
            all_ignored_features.update(ignored_features)

        # Order deterministically: set iteration order depends on insertion
        # history, and __eq__/__hash__ are derived from these params, so equal
        # feature sets must always produce equal lists.
        case_params["ignored_features"] = self._sorted_features(all_ignored_features)

        self._label = label if label is not None else ""
        # Only the *additional* per-case features are kept for the textual
        # description used by __str__; features already in params are not shown.
        self._ignored_features = ignored_features
        self._ignored_features_str = FactorUtils.factors_to_ranges_string(self._ignored_features)

        self.__set_params(case_params)

    @staticmethod
    def _sorted_features(features):
        """Return ``features`` as a list in a deterministic order.

        Mutually comparable elements (e.g. all integer indices) are sorted
        naturally. A mix of non-comparable elements (e.g. indices together with
        feature names) falls back to ordering by ``repr``.
        """
        try:
            return sorted(features)
        except TypeError:
            return sorted(features, key=repr)

    def __set_params(self, params):
        # Store the params together with a hash of their JSON serialization
        # (keys sorted, so dict insertion order does not matter). The encoder
        # presumably handles numpy values, per its name; TODO confirm.
        self._params = params
        self._params_hash = hash(json.dumps(self._params, sort_keys=True, cls=_NumpyAwareEncoder))

    def _set_thread_count(self, thread_count):
        """Set ``thread_count`` in the case params unless it is None or -1.

        The stored params are modified in place and the params hash is
        recomputed. The case's hash and equality therefore change, so avoid
        calling this while the case is a dict key or a set member.
        """
        if thread_count is not None and thread_count != -1:
            params = self._params
            params["thread_count"] = thread_count
            self.__set_params(params)

    @staticmethod
    def _validate_ignored_features(ignored_features, eval_features):
        """Check that no feature being evaluated is also ignored.

        :param ignored_features: container of ignored features.
        :param eval_features: iterable of features under evaluation.
        :raises CatBoostError: if any element of ``eval_features`` is in
            ``ignored_features``.
        """
        for eval_feature in eval_features:
            if eval_feature in ignored_features:
                raise CatBoostError(
                    "Feature {} is in ignored set and in tmp-features set at the same time".format(eval_feature))

    def get_params(self):
        """Return a shallow copy of the case's CatBoost params."""
        return dict(self._params)

    def get_label(self):
        """Return the case label (empty string when none was given)."""
        return self._label

    def __str__(self):
        # Without a label, describe the case by its additional ignored features.
        if not self._label:
            return "Ignore: {}".format(self._ignored_features_str)
        return "{}".format(self._label)

    def __repr__(self):
        return self.__str__()

    def __eq__(self, other):
        # Defer to Python's default handling for foreign types instead of
        # raising AttributeError on ``other._params``.
        if not isinstance(other, ExecutionCase):
            return NotImplemented
        return self._params == other._params and self._label == other._label

    def __hash__(self):
        return hash((self._label, self._params_hash))
|