# coding: utf-8
"""Compatibility library."""

from typing import TYPE_CHECKING, Any, List

# scikit-learn is intentionally imported first here,
# see https://github.com/microsoft/LightGBM/issues/6509

"""sklearn"""
try:
    from sklearn import __version__ as _sklearn_version
    from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
    from sklearn.preprocessing import LabelEncoder
    from sklearn.utils.class_weight import compute_sample_weight
    from sklearn.utils.multiclass import check_classification_targets
    from sklearn.utils.validation import assert_all_finite, check_array, check_X_y

    try:
        from sklearn.exceptions import NotFittedError
        from sklearn.model_selection import BaseCrossValidator, GroupKFold, StratifiedKFold
    except ImportError:
        from sklearn.cross_validation import BaseCrossValidator, GroupKFold, StratifiedKFold
        from sklearn.utils.validation import NotFittedError
    try:
        from sklearn.utils.validation import _check_sample_weight
    except ImportError:
        from sklearn.utils.validation import check_consistent_length

        # dummy function to support older versions of scikit-learn
        def _check_sample_weight(sample_weight: Any, X: Any, dtype: Any = None) -> Any:
            check_consistent_length(sample_weight, X)
            return sample_weight
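
        # A minimal sketch of how this fallback behaves: it only checks that the
        # lengths agree and, unlike scikit-learn's real _check_sample_weight(),
        # does no coercion of 'sample_weight' to a float ndarray. For example:
        #
        #     _check_sample_weight([1.0, 2.0], X=[[0], [1]])  # returns the list unchanged
        #     _check_sample_weight([1.0], X=[[0], [1]])       # raises ValueError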

    try:
        from sklearn.utils.validation import validate_data
    except ImportError:
        # validate_data() was added in scikit-learn 1.6; this function roughly imitates it for older versions.
        # It can be removed when lightgbm's minimum scikit-learn version is at least 1.6.
        def validate_data(
            _estimator,
            X,
            y="no_validation",
            accept_sparse: bool = True,
            # 'force_all_finite' was renamed to 'ensure_all_finite' in scikit-learn 1.6
            ensure_all_finite: bool = False,
            ensure_min_samples: int = 1,
            # trap other keyword arguments that only work on scikit-learn >=1.6, like 'reset'
            **ignored_kwargs,
        ):
            # it's safe to import _num_features unconditionally because:
            #
            # * it was first added in scikit-learn 0.24.2
            # * lightgbm cannot be used with scikit-learn versions older than that
            # * this validate_data() re-implementation will not be called in scikit-learn>=1.6
            #
            from sklearn.utils.validation import _num_features

            # _num_features() raises a TypeError on 1-dimensional input. That's a problem
            # because scikit-learn's 'check_fit1d' estimator check sets the expectation that
            # estimators must raise a ValueError when a 1-dimensional input is passed to fit().
            #
            # So here, lightgbm avoids calling _num_features() on 1-dimensional inputs.
            if hasattr(X, "shape") and len(X.shape) == 1:
                n_features_in_ = 1
            else:
                n_features_in_ = _num_features(X)

            no_val_y = isinstance(y, str) and y == "no_validation"

            # NOTE: check_X_y() calls check_array() internally, so only one or the other needs to be called here
            if no_val_y:
                X = check_array(
                    X,
                    accept_sparse=accept_sparse,
                    force_all_finite=ensure_all_finite,
                    ensure_min_samples=ensure_min_samples,
                )
            else:
                X, y = check_X_y(
                    X,
                    y,
                    accept_sparse=accept_sparse,
                    force_all_finite=ensure_all_finite,
                    ensure_min_samples=ensure_min_samples,
                )

            # this only needs to be updated at fit() time
            _estimator.n_features_in_ = n_features_in_

            # raise the same error that scikit-learn's `validate_data()` does on scikit-learn>=1.6
            if _estimator.__sklearn_is_fitted__() and _estimator._n_features != n_features_in_:
                raise ValueError(
                    f"X has {n_features_in_} features, but {_estimator.__class__.__name__} "
                    f"is expecting {_estimator._n_features} features as input."
                )

            if no_val_y:
                return X
            else:
                return X, y
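
        # A rough usage sketch (hypothetical estimator 'est'; the only contract this
        # shim relies on is a __sklearn_is_fitted__() method and, once fitted, an
        # '_n_features' attribute):
        #
        #     X, y = validate_data(est, X, y, ensure_all_finite=False)
        #     X = validate_data(est, X)  # y="no_validation" skips target validation
        #
        # Note that, unlike scikit-learn's own validate_data(), extra keyword
        # arguments such as 'reset' are accepted here but silently ignored.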

    SKLEARN_INSTALLED = True
    _LGBMBaseCrossValidator = BaseCrossValidator
    _LGBMModelBase = BaseEstimator
    _LGBMRegressorBase = RegressorMixin
    _LGBMClassifierBase = ClassifierMixin
    _LGBMLabelEncoder = LabelEncoder
    LGBMNotFittedError = NotFittedError
    _LGBMStratifiedKFold = StratifiedKFold
    _LGBMGroupKFold = GroupKFold
    _LGBMCheckSampleWeight = _check_sample_weight
    _LGBMAssertAllFinite = assert_all_finite
    _LGBMCheckClassificationTargets = check_classification_targets
    _LGBMComputeSampleWeight = compute_sample_weight
    _LGBMValidateData = validate_data
except ImportError:
    SKLEARN_INSTALLED = False

    class _LGBMModelBase:  # type: ignore
        """Dummy class for sklearn.base.BaseEstimator."""

        pass

    class _LGBMClassifierBase:  # type: ignore
        """Dummy class for sklearn.base.ClassifierMixin."""

        pass

    class _LGBMRegressorBase:  # type: ignore
        """Dummy class for sklearn.base.RegressorMixin."""

        pass

    _LGBMBaseCrossValidator = None
    _LGBMLabelEncoder = None
    LGBMNotFittedError = ValueError
    _LGBMStratifiedKFold = None
    _LGBMGroupKFold = None
    _LGBMCheckSampleWeight = None
    _LGBMAssertAllFinite = None
    _LGBMCheckClassificationTargets = None
    _LGBMComputeSampleWeight = None
    _LGBMValidateData = None
    _sklearn_version = None

# additional scikit-learn imports only for type hints
if TYPE_CHECKING:
    # sklearn.utils.Tags can be imported unconditionally once
    # lightgbm's minimum scikit-learn version is 1.6 or higher
    try:
        from sklearn.utils import Tags as _sklearn_Tags
    except ImportError:
        _sklearn_Tags = None
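
# The aliasing pattern above lets the rest of the package subclass and call these
# names without re-checking for scikit-learn at every use site. A minimal sketch
# of the intended consumption (hypothetical; the real estimators live in
# lightgbm.sklearn):
#
#     class LGBMModel(_LGBMModelBase):
#         def fit(self, X, y):
#             if not SKLEARN_INSTALLED:
#                 raise RuntimeError("scikit-learn is required for this functionality")
#             ...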

"""pandas"""
try:
    from pandas import DataFrame as pd_DataFrame
    from pandas import Series as pd_Series
    from pandas import concat

    try:
        from pandas import CategoricalDtype as pd_CategoricalDtype
    except ImportError:
        from pandas.api.types import CategoricalDtype as pd_CategoricalDtype
    PANDAS_INSTALLED = True
except ImportError:
    PANDAS_INSTALLED = False

    class pd_Series:  # type: ignore
        """Dummy class for pandas.Series."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    class pd_DataFrame:  # type: ignore
        """Dummy class for pandas.DataFrame."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    class pd_CategoricalDtype:  # type: ignore
        """Dummy class for pandas.CategoricalDtype."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    concat = None
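
# Why dummy classes instead of None: these placeholders keep isinstance() checks
# usable even when pandas is missing. For example (sketch):
#
#     data = [[1, 2], [3, 4]]
#     isinstance(data, pd_DataFrame)  # False, with or without pandas installed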

"""matplotlib"""
try:
    import matplotlib  # noqa: F401

    MATPLOTLIB_INSTALLED = True
except ImportError:
    MATPLOTLIB_INSTALLED = False

"""graphviz"""
try:
    import graphviz  # noqa: F401

    GRAPHVIZ_INSTALLED = True
except ImportError:
    GRAPHVIZ_INSTALLED = False

"""datatable"""
try:
    import datatable

    # newer versions of datatable expose Frame; older ones only DataTable
    if hasattr(datatable, "Frame"):
        dt_DataTable = datatable.Frame
    else:
        dt_DataTable = datatable.DataTable
    DATATABLE_INSTALLED = True
except ImportError:
    DATATABLE_INSTALLED = False

    class dt_DataTable:  # type: ignore
        """Dummy class for datatable.DataTable."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

"""dask"""
try:
    from dask import delayed
    from dask.array import Array as dask_Array
    from dask.array import from_delayed as dask_array_from_delayed
    from dask.bag import from_delayed as dask_bag_from_delayed
    from dask.dataframe import DataFrame as dask_DataFrame
    from dask.dataframe import Series as dask_Series
    from dask.distributed import Client, Future, default_client, wait

    DASK_INSTALLED = True
# catching 'ValueError' here because of this:
# https://github.com/microsoft/LightGBM/issues/6365#issuecomment-2002330003
#
# That's potentially risky as dask does some significant import-time processing,
# like loading configuration from environment variables and files, and catching
# ValueError here might hide issues with that config-loading.
#
# But in exchange, it's less likely that 'import lightgbm' will fail for
# dask-related reasons, which is beneficial for any workloads that are using
# lightgbm but not its Dask functionality.
except (ImportError, ValueError):
    DASK_INSTALLED = False

    dask_array_from_delayed = None  # type: ignore[assignment]
    dask_bag_from_delayed = None  # type: ignore[assignment]
    delayed = None
    default_client = None  # type: ignore[assignment]
    wait = None  # type: ignore[assignment]

    class Client:  # type: ignore
        """Dummy class for dask.distributed.Client."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    class Future:  # type: ignore
        """Dummy class for dask.distributed.Future."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    class dask_Array:  # type: ignore
        """Dummy class for dask.array.Array."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    class dask_DataFrame:  # type: ignore
        """Dummy class for dask.dataframe.DataFrame."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    class dask_Series:  # type: ignore
        """Dummy class for dask.dataframe.Series."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass
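
# Downstream, the Dask-facing code is expected to gate on DASK_INSTALLED rather
# than re-importing dask. A minimal sketch of that pattern (hypothetical; the
# real check lives in lightgbm.dask):
#
#     if not DASK_INSTALLED:
#         raise RuntimeError("dask and distributed are required for lightgbm.dask")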

"""pyarrow"""
try:
    import pyarrow.compute as pa_compute
    from pyarrow import Array as pa_Array
    from pyarrow import ChunkedArray as pa_ChunkedArray
    from pyarrow import Table as pa_Table
    from pyarrow import chunked_array as pa_chunked_array
    from pyarrow.types import is_boolean as arrow_is_boolean
    from pyarrow.types import is_floating as arrow_is_floating
    from pyarrow.types import is_integer as arrow_is_integer

    PYARROW_INSTALLED = True
except ImportError:
    PYARROW_INSTALLED = False

    class pa_Array:  # type: ignore
        """Dummy class for pa.Array."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    class pa_ChunkedArray:  # type: ignore
        """Dummy class for pa.ChunkedArray."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    class pa_Table:  # type: ignore
        """Dummy class for pa.Table."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    class pa_compute:  # type: ignore
        """Dummy class for pyarrow.compute module."""

        # placeholders for the pyarrow.compute functions referenced elsewhere in lightgbm
        all = None
        equal = None

    pa_chunked_array = None
    arrow_is_boolean = None
    arrow_is_integer = None
    arrow_is_floating = None

"""cffi"""
try:
    from pyarrow.cffi import ffi as arrow_cffi

    CFFI_INSTALLED = True
except ImportError:
    CFFI_INSTALLED = False

    class arrow_cffi:  # type: ignore
        """Dummy class for pyarrow.cffi.ffi."""

        CData = None

        def __init__(self, *args: Any, **kwargs: Any):
            pass

"""cpu_count()"""
try:
    from joblib import cpu_count

    def _LGBMCpuCount(only_physical_cores: bool = True) -> int:
        return cpu_count(only_physical_cores=only_physical_cores)
except ImportError:
    try:
        from psutil import cpu_count

        def _LGBMCpuCount(only_physical_cores: bool = True) -> int:
            # psutil.cpu_count() can return None, so fall back to 1 in that case
            return cpu_count(logical=not only_physical_cores) or 1
    except ImportError:
        from multiprocessing import cpu_count

        def _LGBMCpuCount(only_physical_cores: bool = True) -> int:
            return cpu_count()
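
# All three fallbacks expose the same call shape, so callers can stay agnostic
# about which library provided the count. For example (sketch):
#
#     n_jobs = _LGBMCpuCount()                           # physical cores where supported
#     n_jobs = _LGBMCpuCount(only_physical_cores=False)  # logical cores
#
# Note that the multiprocessing fallback always reports logical cores and
# ignores the 'only_physical_cores' argument.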

__all__: List[str] = []