import json

from .. import CatBoostError
from ..eval.factor_utils import FactorUtils
from ..core import _NumpyAwareEncoder


class ExecutionCase:
    """
        Instances of this class are cases which will be compared during evaluation.

        A case bundles a dictionary of CatBoost params with a label; two cases are
        considered equal when both their params and their labels are equal.
    """

    def __init__(self, params, label=None, ignored_features=None, learning_rate=None):
        """
            :param params: CatBoost params; copied, so the caller's dict is not modified
            :param label: string which will be used for plots and other visualisations;
                          an empty string is stored when it is None
            :param ignored_features: set of additional feature indices to ignore; merged
                                     with any "ignored_features" already present in params
            :param learning_rate: when not None, overrides "learning_rate" in params
        """
        case_params = dict(params)

        if learning_rate is not None:
            case_params["learning_rate"] = learning_rate

        # Union of the features ignored through params and the additional ones.
        all_ignored_features = set()
        if "ignored_features" in case_params:
            all_ignored_features.update(set(case_params["ignored_features"]))
        if ignored_features is not None:
            all_ignored_features.update(ignored_features)
        case_params["ignored_features"] = list(all_ignored_features)

        self._label = label if label is not None else ""

        # Only the additional features are kept here; their string form is what
        # __str__ shows for unlabelled cases.
        self._ignored_features = ignored_features
        self._ignored_features_str = FactorUtils.factors_to_ranges_string(self._ignored_features)

        self.__set_params(case_params)

    def __set_params(self, params):
        """Store params and refresh the cached hash of their JSON dump (keys sorted)."""
        self._params = params
        self._params_hash = hash(json.dumps(self._params, sort_keys=True, cls=_NumpyAwareEncoder))

    def _set_thread_count(self, thread_count):
        """Write thread_count into the params unless it is None or -1."""
        if thread_count is not None and thread_count != -1:
            params = self._params
            params["thread_count"] = thread_count
            # Go through __set_params so the cached params hash is recomputed.
            self.__set_params(params)

    @staticmethod
    def _validate_ignored_features(ignored_features, eval_features):
        """
            Check that no feature from eval_features is contained in ignored_features.

            :raises CatBoostError: for the first eval feature found in ignored_features
        """
        for eval_feature in eval_features:
            if eval_feature in ignored_features:
                raise CatBoostError(
                    "Feature {} is in ignored set and in tmp-features set at the same time".format(eval_feature))

    def get_params(self):
        """Return a shallow copy of the case params."""
        return dict(self._params)

    def get_label(self):
        """Return the case label (empty string when no label was given)."""
        return self._label

    def __str__(self):
        # Unlabelled cases are described by their additional ignored features.
        if not self._label:
            return "Ignore: {}".format(self._ignored_features_str)
        return '{}'.format(self._label)

    def __repr__(self):
        return self.__str__()

    def __eq__(self, other):
        # Defer to the other operand for foreign types instead of failing with
        # AttributeError on the missing _params / _label attributes.
        if not isinstance(other, ExecutionCase):
            return NotImplemented
        return self._params == other._params and self._label == other._label

    def __hash__(self):
        return hash((self._label, self._params_hash))