# coding: utf-8
"""Compatibility library."""

from typing import TYPE_CHECKING, Any, List
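
# Every optional dependency below is imported inside a try/except block: on success a
# *_INSTALLED flag is set to True, and on ImportError the flag is set to False and dummy
# placeholder classes / None values are bound instead, so the rest of lightgbm can
# reference the same names (and run isinstance() checks against them) without caring
# which optional packages are actually installed.
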
# scikit-learn is intentionally imported first here,
# see https://github.com/microsoft/LightGBM/issues/6509
"""sklearn"""
try:
    from sklearn import __version__ as _sklearn_version
    from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
    from sklearn.preprocessing import LabelEncoder
    from sklearn.utils.class_weight import compute_sample_weight
    from sklearn.utils.multiclass import check_classification_targets
    from sklearn.utils.validation import assert_all_finite, check_array, check_X_y

    try:
        from sklearn.exceptions import NotFittedError
        from sklearn.model_selection import BaseCrossValidator, GroupKFold, StratifiedKFold
    except ImportError:
        from sklearn.cross_validation import BaseCrossValidator, GroupKFold, StratifiedKFold
        from sklearn.utils.validation import NotFittedError

    try:
        from sklearn.utils.validation import _check_sample_weight
    except ImportError:
        from sklearn.utils.validation import check_consistent_length

        # dummy function to support older version of scikit-learn
        def _check_sample_weight(sample_weight: Any, X: Any, dtype: Any = None) -> Any:
            check_consistent_length(sample_weight, X)
            return sample_weight

    try:
        from sklearn.utils.validation import validate_data
    except ImportError:
        # validate_data() was added in scikit-learn 1.6, this function roughly imitates it for older versions.
        # It can be removed when lightgbm's minimum scikit-learn version is at least 1.6.
        def validate_data(
            _estimator,
            X,
            y="no_validation",
            accept_sparse: bool = True,
            # 'force_all_finite' was renamed to 'ensure_all_finite' in scikit-learn 1.6
            ensure_all_finite: bool = False,
            ensure_min_samples: int = 1,
            # trap other keyword arguments that only work on scikit-learn >=1.6, like 'reset'
            **ignored_kwargs,
        ):
            # it's safe to import _num_features unconditionally because:
            #
            #   * it was first added in scikit-learn 0.24.2
            #   * lightgbm cannot be used with scikit-learn versions older than that
            #   * this validate_data() re-implementation will not be called in scikit-learn>=1.6
            #
            from sklearn.utils.validation import _num_features

            # _num_features() raises a TypeError on 1-dimensional input. That's a problem
            # because scikit-learn's 'check_fit1d' estimator check sets the expectation that
            # estimators must raise a ValueError when a 1-dimensional input is passed to fit().
            #
            # So here, lightgbm avoids calling _num_features() on 1-dimensional inputs.
            if hasattr(X, "shape") and len(X.shape) == 1:
                n_features_in_ = 1
            else:
                n_features_in_ = _num_features(X)

            no_val_y = isinstance(y, str) and y == "no_validation"

            # NOTE: check_X_y() calls check_array() internally, so only need to call one or the other of them here
            if no_val_y:
                X = check_array(
                    X,
                    accept_sparse=accept_sparse,
                    force_all_finite=ensure_all_finite,
                    ensure_min_samples=ensure_min_samples,
                )
            else:
                X, y = check_X_y(
                    X,
                    y,
                    accept_sparse=accept_sparse,
                    force_all_finite=ensure_all_finite,
                    ensure_min_samples=ensure_min_samples,
                )

            # this only needs to be updated at fit() time
            _estimator.n_features_in_ = n_features_in_

            # raise the same error that scikit-learn's `validate_data()` does on scikit-learn>=1.6
            if _estimator.__sklearn_is_fitted__() and _estimator._n_features != n_features_in_:
                raise ValueError(
                    f"X has {n_features_in_} features, but {_estimator.__class__.__name__} "
                    f"is expecting {_estimator._n_features} features as input."
                )

            if no_val_y:
                return X
            else:
                return X, y
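
    # The flag and aliases below are what the rest of lightgbm imports from this module,
    # so no other lightgbm module needs to import scikit-learn directly.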
    SKLEARN_INSTALLED = True
    _LGBMBaseCrossValidator = BaseCrossValidator
    _LGBMModelBase = BaseEstimator
    _LGBMRegressorBase = RegressorMixin
    _LGBMClassifierBase = ClassifierMixin
    _LGBMLabelEncoder = LabelEncoder
    LGBMNotFittedError = NotFittedError
    _LGBMStratifiedKFold = StratifiedKFold
    _LGBMGroupKFold = GroupKFold
    _LGBMCheckSampleWeight = _check_sample_weight
    _LGBMAssertAllFinite = assert_all_finite
    _LGBMCheckClassificationTargets = check_classification_targets
    _LGBMComputeSampleWeight = compute_sample_weight
    _LGBMValidateData = validate_data
except ImportError:
    SKLEARN_INSTALLED = False
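
    # These placeholders keep `import lightgbm` working when scikit-learn is not
    # installed; the scikit-learn-facing code checks SKLEARN_INSTALLED at runtime
    # before relying on these names.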

    class _LGBMModelBase:  # type: ignore
        """Dummy class for sklearn.base.BaseEstimator."""

        pass

    class _LGBMClassifierBase:  # type: ignore
        """Dummy class for sklearn.base.ClassifierMixin."""

        pass

    class _LGBMRegressorBase:  # type: ignore
        """Dummy class for sklearn.base.RegressorMixin."""

        pass

    _LGBMBaseCrossValidator = None
    _LGBMLabelEncoder = None
    LGBMNotFittedError = ValueError
    _LGBMStratifiedKFold = None
    _LGBMGroupKFold = None
    _LGBMCheckSampleWeight = None
    _LGBMAssertAllFinite = None
    _LGBMCheckClassificationTargets = None
    _LGBMComputeSampleWeight = None
    _LGBMValidateData = None
    _sklearn_version = None

# additional scikit-learn imports only for type hints
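# (the block below only runs under static type checkers; TYPE_CHECKING is False at runtime)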
if TYPE_CHECKING:
    # sklearn.utils.Tags can be imported unconditionally once
    # lightgbm's minimum scikit-learn version is 1.6 or higher
    try:
        from sklearn.utils import Tags as _sklearn_Tags
    except ImportError:
        _sklearn_Tags = None

"""pandas"""
try:
    from pandas import DataFrame as pd_DataFrame
    from pandas import Series as pd_Series
    from pandas import concat

    try:
        from pandas import CategoricalDtype as pd_CategoricalDtype
    except ImportError:
        from pandas.api.types import CategoricalDtype as pd_CategoricalDtype

    PANDAS_INSTALLED = True
except ImportError:
    PANDAS_INSTALLED = False

    class pd_Series:  # type: ignore
        """Dummy class for pandas.Series."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    class pd_DataFrame:  # type: ignore
        """Dummy class for pandas.DataFrame."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    class pd_CategoricalDtype:  # type: ignore
        """Dummy class for pandas.CategoricalDtype."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    concat = None

"""matplotlib"""
try:
    import matplotlib  # noqa: F401

    MATPLOTLIB_INSTALLED = True
except ImportError:
    MATPLOTLIB_INSTALLED = False

"""graphviz"""
try:
    import graphviz  # noqa: F401

    GRAPHVIZ_INSTALLED = True
except ImportError:
    GRAPHVIZ_INSTALLED = False

"""datatable"""
try:
    import datatable

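    # 'Frame' is the newer name for datatable's frame class; fall back to the old
    # 'DataTable' name for older versions of the package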
    if hasattr(datatable, "Frame"):
        dt_DataTable = datatable.Frame
    else:
        dt_DataTable = datatable.DataTable

    DATATABLE_INSTALLED = True
except ImportError:
    DATATABLE_INSTALLED = False

    class dt_DataTable:  # type: ignore
        """Dummy class for datatable.DataTable."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

"""dask"""
try:
    from dask import delayed
    from dask.array import Array as dask_Array
    from dask.array import from_delayed as dask_array_from_delayed
    from dask.bag import from_delayed as dask_bag_from_delayed
    from dask.dataframe import DataFrame as dask_DataFrame
    from dask.dataframe import Series as dask_Series
    from dask.distributed import Client, Future, default_client, wait

    DASK_INSTALLED = True
# catching 'ValueError' here because of this:
# https://github.com/microsoft/LightGBM/issues/6365#issuecomment-2002330003
#
# That's potentially risky as dask does some significant import-time processing,
# like loading configuration from environment variables and files, and catching
# ValueError here might hide issues with that config-loading.
#
# But in exchange, it's less likely that 'import lightgbm' will fail for
# dask-related reasons, which is beneficial for any workloads that are using
# lightgbm but not its Dask functionality.
except (ImportError, ValueError):
    DASK_INSTALLED = False

    dask_array_from_delayed = None  # type: ignore[assignment]
    dask_bag_from_delayed = None  # type: ignore[assignment]
    delayed = None
    default_client = None  # type: ignore[assignment]
    wait = None  # type: ignore[assignment]

    class Client:  # type: ignore
        """Dummy class for dask.distributed.Client."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    class Future:  # type: ignore
        """Dummy class for dask.distributed.Future."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    class dask_Array:  # type: ignore
        """Dummy class for dask.array.Array."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    class dask_DataFrame:  # type: ignore
        """Dummy class for dask.dataframe.DataFrame."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    class dask_Series:  # type: ignore
        """Dummy class for dask.dataframe.Series."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

"""pyarrow"""
try:
    import pyarrow.compute as pa_compute
    from pyarrow import Array as pa_Array
    from pyarrow import ChunkedArray as pa_ChunkedArray
    from pyarrow import Table as pa_Table
    from pyarrow import chunked_array as pa_chunked_array
    from pyarrow.types import is_boolean as arrow_is_boolean
    from pyarrow.types import is_floating as arrow_is_floating
    from pyarrow.types import is_integer as arrow_is_integer

    PYARROW_INSTALLED = True
except ImportError:
    PYARROW_INSTALLED = False

    class pa_Array:  # type: ignore
        """Dummy class for pa.Array."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    class pa_ChunkedArray:  # type: ignore
        """Dummy class for pa.ChunkedArray."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    class pa_Table:  # type: ignore
        """Dummy class for pa.Table."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass
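
    # stand-in for the pyarrow.compute module: it only needs to expose the
    # attributes that lightgbm actually references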
    class pa_compute:  # type: ignore
        """Dummy class for pyarrow.compute module."""

        all = None
        equal = None

    pa_chunked_array = None
    arrow_is_boolean = None
    arrow_is_integer = None
    arrow_is_floating = None

"""cffi"""
try:
    from pyarrow.cffi import ffi as arrow_cffi

    CFFI_INSTALLED = True
except ImportError:
    CFFI_INSTALLED = False

    class arrow_cffi:  # type: ignore
        """Dummy class for pyarrow.cffi.ffi."""

        CData = None

        def __init__(self, *args: Any, **kwargs: Any):
            pass

"""cpu_count()"""
try:
    from joblib import cpu_count

    def _LGBMCpuCount(only_physical_cores: bool = True) -> int:
        return cpu_count(only_physical_cores=only_physical_cores)
except ImportError:
    try:
        from psutil import cpu_count

        def _LGBMCpuCount(only_physical_cores: bool = True) -> int:
            return cpu_count(logical=not only_physical_cores) or 1

    except ImportError:
        from multiprocessing import cpu_count

        def _LGBMCpuCount(only_physical_cores: bool = True) -> int:
            return cpu_count()
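
# empty __all__: nothing in this module is exported via wildcard imports; other
# lightgbm modules import these compatibility names explicitly.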
__all__: List[str] = []