Files
MLPproject/.venv/lib/python3.12/site-packages/xgboost/dask/__init__.py
2025-10-23 15:44:32 +02:00

2218 lines
77 KiB
Python

# pylint: disable=too-many-arguments, too-many-locals
# pylint: disable=missing-class-docstring, invalid-name
# pylint: disable=too-many-lines
"""
Dask extensions for distributed training
----------------------------------------
See :doc:`Distributed XGBoost with Dask </tutorials/dask>` for simple tutorial. Also
:doc:`/python/dask-examples/index` for some examples.

There are two sets of APIs in this module. One is the functional API, including the
``train`` and ``predict`` methods. The other is a stateful Scikit-Learn wrapper
inherited from the single-node Scikit-Learn interface.

The implementation is heavily influenced by dask_xgboost:
https://github.com/dask/dask-xgboost
Optional dask configuration
===========================
- **coll_cfg**:
  Specify the scheduler address along with communicator configurations. This can be
  used as a replacement of the existing global Dask configuration
  `xgboost.scheduler_address` (see below). See :ref:`tracker-ip` for more info. The
  `tracker_host_ip` should specify the IP address of the Dask scheduler node.

  .. versionadded:: 3.0.0

  .. code-block:: python

    from xgboost import dask as dxgb
    from xgboost.collective import Config

    coll_cfg = Config(
        retry=1, timeout=20, tracker_host_ip="10.23.170.98", tracker_port=0
    )

    clf = dxgb.DaskXGBClassifier(coll_cfg=coll_cfg)
    # or
    dxgb.train(client, {}, Xy, num_boost_round=10, coll_cfg=coll_cfg)

- **xgboost.scheduler_address**: Specify the scheduler address

  .. versionadded:: 1.6.0

  .. deprecated:: 3.0.0

  .. code-block:: python

    dask.config.set({"xgboost.scheduler_address": "192.0.0.100"})
    # We can also specify the port.
    dask.config.set({"xgboost.scheduler_address": "192.0.0.100:12345"})

"""
import logging
from collections import defaultdict
from contextlib import contextmanager
from functools import partial, update_wrapper
from threading import Thread
from typing import (
Any,
Awaitable,
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
ParamSpec,
Sequence,
Set,
Tuple,
TypeAlias,
TypedDict,
TypeGuard,
TypeVar,
Union,
)
import dask
import distributed
import numpy
from dask import array as da
from dask import bag as db
from dask import dataframe as dd
from dask.delayed import Delayed
from distributed import Future
from .. import collective, config
from .._data_utils import Categories
from .._typing import FeatureNames, FeatureTypes, IterationRange
from ..callback import TrainingCallback
from ..collective import Config as CollConfig
from ..collective import _Args as CollArgs
from ..collective import _ArgVals as CollArgsVals
from ..compat import _is_cudf_df
from ..core import (
Booster,
DMatrix,
Metric,
Objective,
XGBoostError,
_check_distributed_params,
_deprecate_positional_args,
_expect,
)
from ..data import _is_cudf_ser, _is_cupy_alike
from ..sklearn import (
XGBClassifier,
XGBClassifierBase,
XGBModel,
XGBRanker,
XGBRankerMixIn,
XGBRegressorBase,
_can_use_qdm,
_check_rf_callback,
_cls_predict_proba,
_objective_decorator,
_wrap_evaluation_matrices,
xgboost_model_doc,
)
from ..tracker import RabitTracker
from ..training import train as worker_train
from .data import _get_dmatrices, no_group_split
from .utils import _DASK_2024_12_1, _DASK_2025_3_0, get_address_from_user, get_n_threads
# Dask collections used for labels, meta information and prediction outputs.
_DaskCollection: TypeAlias = Union[da.Array, dd.DataFrame, dd.Series]
# Dask collections accepted as the predictor matrix; a series is not accepted.
_DataT: TypeAlias = Union[da.Array, dd.DataFrame]


class TrainReturnT(TypedDict):
    """Result of :py:func:`train`: the trained booster and its evaluation history."""

    booster: Booster
    history: Dict


# Public API of this module.
__all__ = [
    "CommunicatorContext",
    "DaskDMatrix",
    "DaskQuantileDMatrix",
    "DaskXGBRegressor",
    "DaskXGBClassifier",
    "DaskXGBRanker",
    "DaskXGBRFRegressor",
    "DaskXGBRFClassifier",
    "train",
    "predict",
    "inplace_predict",
]
# TODOs:
#   - CV
#
# Note for developers:
#
#   As of writing, asyncio is still a relatively new feature of Python and in-depth
#   documentation is rare. The best examples of various asyncio tricks are in dask
#   (luckily). Classes like Client and Worker are awaitable. Some general rules for
#   the implementation here:
#
#     - The synchronous world is different from the asynchronous one, and they don't
#       mix well.
#     - Write everything with async, then use the distributed Client `sync` function
#       to do the switch.
#     - Use Any for the type hint when the return value can be a union of Awaitable
#       and a plain value. This is caused by Client.sync returning either type
#       depending on context. Right now there's no good way to silence the type
#       checker for:
#
#         await train(...)
#
#       if train returns a Union type.

# Module-level logger; used e.g. to report tracker address fallbacks.
LOGGER = logging.getLogger("[xgboost.dask]")
def _try_start_tracker(
    n_workers: int,
    addrs: List[Union[Optional[str], Optional[Tuple[str, int]]]],
    timeout: Optional[int],
) -> CollArgs:
    """Start a Rabit tracker on the first address in ``addrs`` that works.

    ``addrs[0]`` is either a ``(host, port)`` pair, a host string, or ``None`` (in
    which case no host is passed on to :py:class:`RabitTracker`). When starting the
    tracker raises :py:class:`XGBoostError` and more candidates remain, a warning is
    logged and the remaining addresses are tried recursively.

    Returns
    -------
    The worker arguments reported by the started tracker.
    """
    env: CollArgs = {}
    tracker_timeout = 0 if timeout is None else timeout
    try:
        candidate = addrs[0]
        if isinstance(candidate, tuple):
            tracker = RabitTracker(
                n_workers=n_workers,
                host_ip=candidate[0],
                port=candidate[1],
                sortby="task",
                timeout=tracker_timeout,
            )
        else:
            assert isinstance(candidate, str) or candidate is None
            tracker = RabitTracker(
                n_workers=n_workers,
                host_ip=candidate,
                sortby="task",
                timeout=tracker_timeout,
            )
        tracker.start()
        # Wait for the workers in a background daemon thread. `wait_for` is called
        # without a timeout since we don't want to abort the training.
        Thread(target=tracker.wait_for, daemon=True).start()
        env.update(tracker.worker_args())
    except XGBoostError as err:
        if len(addrs) < 2:
            raise
        LOGGER.warning(
            "Failed to bind address '%s', trying to use '%s' instead. Error:\n %s",
            str(addrs[0]),
            str(addrs[1]),
            str(err),
        )
        env = _try_start_tracker(n_workers, addrs[1:], timeout)
    return env
def _start_tracker(
    n_workers: int,
    addr_from_dask: Optional[str],
    addr_from_user: Optional[Tuple[str, int]],
    timeout: Optional[int],
) -> CollArgs:
    """Start the Rabit tracker, trying the user address first, then the dask one.

    Delegates to :py:func:`_try_start_tracker`, which falls back to the next
    candidate address when the tracker fails to start.
    """
    candidates: List[Union[Optional[str], Optional[Tuple[str, int]]]] = [
        addr_from_user,
        addr_from_dask,
    ]
    return _try_start_tracker(n_workers, candidates, timeout)
class CommunicatorContext(collective.CommunicatorContext):
    """Collective communicator context for use inside a Dask worker task.

    On top of the base context, it sets ``DMLC_TASK_ID`` from the name and address of
    the current Dask worker.
    """

    def __init__(self, **args: CollArgsVals) -> None:
        super().__init__(**args)
        current = distributed.get_worker()
        # The task ID drives rank assignment (the tracker is started with
        # ``sortby="task"``). Deriving it from the dask worker name keeps the rank
        # consistent with the worker name, though not identical to it: the ID is
        # compared as a string, so "10" sorts before "2". This outsources rank
        # assignment to dask and prevents non-deterministic ranks.
        task_id = f"[xgboost.dask-{current.name}]:{current.address}"
        self.args["DMLC_TASK_ID"] = task_id
def _get_client(client: Optional["distributed.Client"]) -> "distributed.Client":
    """Return ``client``, or the default Dask client when ``client`` is ``None``.

    Parameters
    ----------
    client :
        An explicit client, or ``None`` to use :py:func:`distributed.get_client`.

    Raises
    ------
    TypeError
        If ``client`` is neither ``None`` nor a :py:class:`distributed.Client`.
    """
    if client is None:
        return distributed.get_client()
    # Check against the ``Client`` class directly instead of the type of the default
    # client. Looking up the default client is unnecessary when one is passed in, and
    # it fails if no default client is registered (e.g. the only client was created
    # with ``set_as_default=False``).
    if not isinstance(client, distributed.Client):
        raise TypeError(_expect([distributed.Client, type(None)], type(client)))
    return client
# From the implementation point of view, DaskDMatrix complicates a lot of things. A
# large portion of the code base is about syncing and extracting data from
# DaskDMatrix. But having an independent data structure gives us a chance to perform
# some specialized optimizations, like building the histogram index directly.
class DaskDMatrix:
    # pylint: disable=too-many-instance-attributes
    """DMatrix holding on references to Dask DataFrame or Dask Array. Constructing a
    `DaskDMatrix` forces all lazy computation to be carried out. Wait for the input
    data explicitly if you want to see actual computation of constructing `DaskDMatrix`.

    See doc for :py:obj:`xgboost.DMatrix` constructor for other parameters. DaskDMatrix
    accepts only dask collection.

    .. note::

        `DaskDMatrix` does not repartition or move data between workers. It's the
        caller's responsibility to balance the data.

    .. note::

        For aligning partitions with ranking query groups, use the
        :py:class:`DaskXGBRanker` and its ``allow_group_split`` option.

    .. versionadded:: 1.0.0

    Parameters
    ----------
    client :
        Specify the dask client used for training. Use default client returned from
        dask if it's set to None.

    """

    @_deprecate_positional_args
    def __init__(
        self,
        client: Optional["distributed.Client"],
        data: _DataT,
        label: Optional[_DaskCollection] = None,
        *,
        weight: Optional[_DaskCollection] = None,
        base_margin: Optional[_DaskCollection] = None,
        missing: Optional[float] = None,
        silent: bool = False,  # pylint: disable=unused-argument
        feature_names: Optional[FeatureNames] = None,
        feature_types: Optional[FeatureTypes] = None,
        group: Optional[_DaskCollection] = None,
        qid: Optional[_DaskCollection] = None,
        label_lower_bound: Optional[_DaskCollection] = None,
        label_upper_bound: Optional[_DaskCollection] = None,
        feature_weights: Optional[_DaskCollection] = None,
        enable_categorical: bool = False,
    ) -> None:
        client = _get_client(client)

        self.feature_names = feature_names
        self.feature_types = feature_types
        if isinstance(feature_types, Categories):
            raise TypeError(
                "The Dask interface can handle categories from DataFrame automatically."
            )
        self.missing = numpy.nan if missing is None else missing
        self.enable_categorical = enable_categorical

        if qid is not None and weight is not None:
            raise NotImplementedError("per-group weight is not implemented.")
        if group is not None:
            raise NotImplementedError(
                "group structure is not implemented, use qid instead."
            )

        # Check the type before the dimension: a non-dask input may not have a
        # ``shape`` attribute at all, and should be reported as a TypeError.
        if not isinstance(data, (dd.DataFrame, da.Array)):
            raise TypeError(_expect((dd.DataFrame, da.Array), type(data)))
        if len(data.shape) != 2:
            raise ValueError(f"Expecting 2 dimensional input, got: {data.shape}")
        if not isinstance(label, (dd.DataFrame, da.Array, dd.Series, type(None))):
            raise TypeError(_expect((dd.DataFrame, da.Array, dd.Series), type(label)))

        self._n_cols = data.shape[1]
        assert isinstance(self._n_cols, int)
        self.worker_map: Dict[str, List[Future]] = defaultdict(list)
        self.is_quantile: bool = False

        # Depending on the client, ``client.sync`` returns either the result or an
        # awaitable; ``__await__`` awaits this object in the asynchronous case.
        self._init = client.sync(
            self._map_local_data,
            client=client,
            data=data,
            label=label,
            weights=weight,
            base_margin=base_margin,
            qid=qid,
            feature_weights=feature_weights,
            label_lower_bound=label_lower_bound,
            label_upper_bound=label_upper_bound,
        )

    def __await__(self) -> Generator[None, None, "DaskDMatrix"]:
        return self._init.__await__()

    async def _map_local_data(
        self,
        *,
        client: "distributed.Client",
        data: _DataT,
        label: Optional[_DaskCollection] = None,
        weights: Optional[_DaskCollection] = None,
        base_margin: Optional[_DaskCollection] = None,
        qid: Optional[_DaskCollection] = None,
        feature_weights: Optional[_DaskCollection] = None,
        label_lower_bound: Optional[_DaskCollection] = None,
        label_upper_bound: Optional[_DaskCollection] = None,
    ) -> "DaskDMatrix":
        """Obtain references to local data.

        Persists every input collection and groups the i-th partition of each input
        into a single future. It then records in ``self.worker_map`` which worker holds
        each of those futures, and in ``self.partition_order`` the original partition
        index of each future.

        Raises
        ------
        ValueError
            If the data is partitioned along columns, or if a meta-info collection has
            a different number of partitions than ``data``.
        """

        def inconsistent(
            left: List[Any], left_name: str, right: List[Any], right_name: str
        ) -> str:
            msg = (
                f"Partitions between {left_name} and {right_name} are not "
                f"consistent: {len(left)} != {len(right)}. "
                f"Please try to repartition/rechunk your data."
            )
            return msg

        def to_futures(d: _DaskCollection) -> List[Future]:
            """Break a collection into a list of futures, one per partition."""
            d = client.persist(d)
            if (
                hasattr(d.partitions, "shape")
                and len(d.partitions.shape) > 1
                and d.partitions.shape[1] > 1
            ):
                raise ValueError(
                    "Data should be"
                    " partitioned by row. To avoid this specify the number"
                    " of columns for your dask Array explicitly. e.g."
                    " chunks=(partition_size, -1)"
                )
            return client.futures_of(d)

        def flatten_meta(meta: Optional[_DaskCollection]) -> Optional[List[Future]]:
            return None if meta is None else to_futures(meta)

        X_parts = to_futures(data)
        y_parts = flatten_meta(label)
        w_parts = flatten_meta(weights)
        margin_parts = flatten_meta(base_margin)
        qid_parts = flatten_meta(qid)
        ll_parts = flatten_meta(label_lower_bound)
        lu_parts = flatten_meta(label_upper_bound)

        parts: Dict[str, List[Future]] = {"data": X_parts}

        def append_meta(m_parts: Optional[List[Future]], name: str) -> None:
            if m_parts is None:
                return
            # An explicit exception instead of ``assert``: this validates user input
            # and must not be stripped when Python runs with ``-O``.
            if len(X_parts) != len(m_parts):
                raise ValueError(inconsistent(X_parts, "X", m_parts, name))
            parts[name] = m_parts

        append_meta(y_parts, "label")
        append_meta(w_parts, "weight")
        append_meta(margin_parts, "base_margin")
        append_meta(qid_parts, "qid")
        append_meta(ll_parts, "label_lower_bound")
        append_meta(lu_parts, "label_upper_bound")
        # At this point, `parts` looks like:
        # {"data": [x0, x1, ..], "label": [y0, y1, ..], ..} in future form.
        # Turn it into a list of dictionaries, one per partition. All lists have the
        # same length (checked above), so zip does not drop anything.
        names = list(parts.keys())
        packed_parts: List[Dict[str, Future]] = [
            dict(zip(names, row)) for row in zip(*parts.values())
        ]

        # delay the zipped result
        # pylint: disable=no-member
        delayed_parts: List[Delayed] = list(map(dask.delayed, packed_parts))
        # At this point, the mental model should look like:
        # [{"data": x0, "label": y0, ..}, {"data": x1, "label": y1, ..}, ..]

        # Convert delayed objects into futures and make sure they are realized
        #
        # This also makes partitions to align (co-locate) on workers (X_0, y_0 should be
        # on the same worker).
        fut_parts: List[Future] = client.compute(delayed_parts)
        await distributed.wait(fut_parts)  # async wait for parts to be computed

        for part in fut_parts:
            # Each part is [{"data": x0, "label": y0, ..}, ...] in future form.
            assert part.status == "finished", part.status

        # Preserving the partition order for prediction.
        self.partition_order = {part.key: i for i, part in enumerate(fut_parts)}

        key_to_partition = {part.key: part for part in fut_parts}
        who_has: Dict[str, Tuple[str, ...]] = await client.scheduler.who_has(
            keys=[part.key for part in fut_parts]
        )

        # Assign each partition to the first worker reported as holding it.
        worker_map: Dict[str, List[Future]] = defaultdict(list)
        for key, workers in who_has.items():
            worker_map[next(iter(workers))].append(key_to_partition[key])
        self.worker_map = worker_map

        if feature_weights is None:
            self.feature_weights = None
        else:
            self.feature_weights = await client.compute(feature_weights).result()

        return self

    def _create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
        """Create a dictionary of objects that can be pickled for function
        arguments.

        ``parts`` is the list of partition futures held by ``worker_addr``, or
        ``None`` if that worker holds no partition.
        """
        return {
            "feature_names": self.feature_names,
            "feature_types": self.feature_types,
            "feature_weights": self.feature_weights,
            "missing": self.missing,
            "enable_categorical": self.enable_categorical,
            "parts": self.worker_map.get(worker_addr, None),
            "is_quantile": self.is_quantile,
        }

    def num_col(self) -> int:
        """Get the number of columns (features) in the DMatrix.

        Returns
        -------
        number of columns
        """
        return self._n_cols
# Return type of the function mapped by ``map_worker_partitions``.
_MapRetT = TypeVar("_MapRetT")
# Parameter specification of the function mapped by ``map_worker_partitions``.
_P = ParamSpec("_P")


async def map_worker_partitions(
    client: Optional["distributed.Client"],
    func: Callable[_P, _MapRetT],
    *refs: Any,
    workers: Sequence[str],
) -> _MapRetT:
    """Map a function onto partitions of each worker.

    One task is submitted per address in ``workers``, pinned to that worker, and it
    calls ``func(*args)``. Each :py:class:`DaskDMatrix` in ``refs`` is replaced by the
    picklable arguments for that worker (``DaskDMatrix._create_fn_args``). All other
    ``refs`` are passed through unchanged.

    Returns
    -------
    The first non-``None`` per-worker result found by a dask bag reduction, or ``None``
    if every worker returned ``None``.

    Raises
    ------
    ValueError
        Raised inside the task when it runs on a worker other than the one it was
        submitted to.
    """
    # Note for function purity:
    # XGBoost is sensitive to data partition and uses random number generator.
    client = _get_client(client)
    futures = []
    for addr in workers:
        args = []
        for ref in refs:
            if isinstance(ref, DaskDMatrix):
                # pylint: disable=protected-access
                args.append(ref._create_fn_args(addr))
            else:
                args.append(ref)

        def fn(_address: str, *args: _P.args, **kwargs: _P.kwargs) -> List[_MapRetT]:
            worker = distributed.get_worker()

            if worker.address != _address:
                raise ValueError(
                    f"Invalid worker address: {worker.address}, expecting {_address}. "
                    "This is likely caused by one of the workers died and Dask "
                    "re-scheduled a different one. Resilience is not yet supported."
                )
            # Turn result into a list for bag construction
            return [func(*args, **kwargs)]

        # XGBoost requires all workers running training tasks to be unique. Meaning, we
        # can't run 2 training jobs on the same node. This at best leads to an error
        # (NCCL unique check), at worst leads to extremely slow training performance
        # without any warning.
        #
        # See distributed.scheduler.decide_worker for `allow_other_workers`. In
        # summary, the scheduler chooses a worker from the valid set that has the task
        # dependencies. Each XGBoost training task has all dependencies in a single
        # worker. As a result, the right worker should be picked by the scheduler even
        # if `allow_other_workers` is set to True.
        #
        # In addition, the scheduler only discards the valid set (the `workers` arg) if
        # no candidate can be found. This is likely caused by killed workers. In that
        # case, the check in `fn` should be able to stop the task. If we don't relax
        # the constraint and prevent Dask from choosing an invalid worker, the task
        # will simply hang. We prefer a quick error here.
        #
        # `partial` binds the expected worker address as `_address`; `update_wrapper`
        # copies the metadata (e.g. the name) of `fn` onto the partial object.
        fut = client.submit(
            update_wrapper(partial(fn, addr), fn),
            *args,
            pure=False,
            workers=[addr],
            allow_other_workers=True,
        )
        futures.append(fut)

    def first_valid(results: Iterable[Optional[_MapRetT]]) -> Optional[_MapRetT]:
        # Used both per-partition and for the aggregate step of the reduction.
        for v in results:
            if v is not None:
                return v
        return None

    bag = db.from_delayed(futures)
    fut = await bag.reduction(first_valid, first_valid)
    result = await client.compute(fut).result()

    return result
class DaskQuantileDMatrix(DaskDMatrix):
    """A dask version of :py:class:`QuantileDMatrix`. See :py:class:`DaskDMatrix` for
    parameter documents.

    """

    @_deprecate_positional_args
    def __init__(
        self,
        client: Optional["distributed.Client"],
        data: _DataT,
        label: Optional[_DaskCollection] = None,
        *,
        weight: Optional[_DaskCollection] = None,
        base_margin: Optional[_DaskCollection] = None,
        missing: Optional[float] = None,
        silent: bool = False,
        feature_names: Optional[FeatureNames] = None,
        feature_types: Optional[Union[Any, List[Any]]] = None,
        max_bin: Optional[int] = None,
        ref: Optional[DaskDMatrix] = None,
        group: Optional[_DaskCollection] = None,
        qid: Optional[_DaskCollection] = None,
        label_lower_bound: Optional[_DaskCollection] = None,
        label_upper_bound: Optional[_DaskCollection] = None,
        feature_weights: Optional[_DaskCollection] = None,
        enable_categorical: bool = False,
        max_quantile_batches: Optional[int] = None,
    ) -> None:
        # Everything except the quantile-specific options is handled by DaskDMatrix.
        super().__init__(
            client=client,
            data=data,
            label=label,
            weight=weight,
            base_margin=base_margin,
            missing=missing,
            silent=silent,
            feature_names=feature_names,
            feature_types=feature_types,
            group=group,
            qid=qid,
            label_lower_bound=label_lower_bound,
            label_upper_bound=label_upper_bound,
            feature_weights=feature_weights,
            enable_categorical=enable_categorical,
        )
        self.max_bin = max_bin
        self.max_quantile_batches = max_quantile_batches
        self.is_quantile = True
        # Only the identity of the reference matrix is kept (and later sent to the
        # workers), not the object itself.
        self._ref: Optional[int] = None if ref is None else id(ref)

    def _create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
        """Extend the base worker arguments with the quantile-specific options."""
        fn_args = {
            **super()._create_fn_args(worker_addr),
            "max_bin": self.max_bin,
            "max_quantile_batches": self.max_quantile_batches,
        }
        if self._ref is not None:
            fn_args["ref"] = self._ref
        return fn_args
async def _get_rabit_args(
    client: "distributed.Client",
    n_workers: int,
    dconfig: Optional[Dict[str, Any]] = None,
    coll_cfg: Optional[CollConfig] = None,
) -> Dict[str, Union[str, int]]:
    """Get rabit context arguments from data distribution in DaskDMatrix.

    Starts the tracker on the scheduler (via ``client.run_on_scheduler``) and returns
    the communicator configuration built by ``coll_cfg.get_comm_config``.

    Parameters
    ----------
    client :
        Client used to reach the scheduler.
    n_workers :
        Number of workers taking part in the collective.
    dconfig :
        The ``xgboost`` section of the Dask configuration, if any.
    coll_cfg :
        Communicator configuration; a default :py:class:`CollConfig` when ``None``.
    """
    # There are 3 possible different addresses:
    # 1. Provided by user via dask.config (or `coll_cfg`)
    # 2. Guessed by xgboost `get_host_ip` function
    # 3. From dask scheduler
    # We try 1 and 3 if 1 is available, otherwise 2 and 3.

    # See if user config is available
    coll_cfg = CollConfig() if coll_cfg is None else coll_cfg
    host_ip, port = get_address_from_user(dconfig, coll_cfg)
    user_addr: Optional[Tuple[str, int]] = (
        None if host_ip is None else (host_ip, port)
    )

    # Try address from dask scheduler, this might not work, see
    # https://github.com/dask/dask-xgboost/pull/40
    sched_addr: Optional[str]
    try:
        sched_addr = distributed.comm.get_address_host(client.scheduler.address)
        sched_addr = sched_addr.strip("/:")
    except Exception:  # pylint: disable=broad-except
        # Best effort only: fall back to the other candidate address.
        sched_addr = None

    # We assume the scheduler is a fair process and run the tracker there.
    env = await client.run_on_scheduler(
        _start_tracker, n_workers, sched_addr, user_addr, coll_cfg.tracker_timeout
    )
    env = coll_cfg.get_comm_config(env)
    assert env is not None
    return env
def _get_dask_config() -> Optional[Dict[str, Any]]:
    """Look up the ``xgboost`` section of the global Dask configuration.

    Returns ``None`` when the section is not set.
    """
    section: Optional[Dict[str, Any]] = dask.config.get("xgboost", default=None)
    return section
# train and predict methods are supposed to be "functional", which meets the
# dask paradigm. But as a side effect, the `evals_result` in single-node API
# is no longer supported since it mutates the input parameter, and it's not
# intuitive to sync the mutation result. Therefore, a dictionary containing
# evaluation history is instead returned.


def _get_workers_from_data(
    dtrain: DaskDMatrix, evals: Optional[Sequence[Tuple[DaskDMatrix, str]]]
) -> List[str]:
    """Collect the addresses of all workers holding partitions of the input data.

    Parameters
    ----------
    dtrain :
        Training matrix.
    evals :
        Optional ``(DaskDMatrix, name)`` pairs used for evaluation.

    Returns
    -------
    Unique worker addresses, in no particular order.

    Raises
    ------
    TypeError
        If an entry of ``evals`` is not a ``(DaskDMatrix, str)`` pair.
    """
    X_worker_map: Set[str] = set(dtrain.worker_map.keys())
    if evals:
        for e in evals:
            # Explicit exception instead of ``assert``: this validates user input and
            # must not be stripped when Python runs with ``-O``.
            if (
                len(e) != 2
                or not isinstance(e[0], DaskDMatrix)
                or not isinstance(e[1], str)
            ):
                raise TypeError(
                    f"Each item of `evals` must be a (DaskDMatrix, str) pair, got: {e!r}"
                )
            if e[0] is dtrain:
                continue
            worker_map = set(e[0].worker_map.keys())
            X_worker_map = X_worker_map.union(worker_map)
    return list(X_worker_map)
async def _check_workers_are_alive(
workers: List[str], client: "distributed.Client"
) -> None:
info = await client.scheduler.identity()
current_workers = info["workers"].keys()
missing_workers = set(workers) - current_workers
if missing_workers:
raise RuntimeError(f"Missing required workers: {missing_workers}")
async def _train_async(
    *,
    client: "distributed.Client",
    global_config: Dict[str, Any],
    dconfig: Optional[Dict[str, Any]],
    params: Dict[str, Any],
    dtrain: DaskDMatrix,
    num_boost_round: int,
    evals: Optional[Sequence[Tuple[DaskDMatrix, str]]],
    obj: Optional[Objective],
    early_stopping_rounds: Optional[int],
    verbose_eval: Union[int, bool],
    xgb_model: Optional[Booster],
    callbacks: Optional[Sequence[TrainingCallback]],
    custom_metric: Optional[Metric],
    coll_cfg: Optional[CollConfig],
) -> Optional[TrainReturnT]:
    """Asynchronous implementation of :py:func:`train`.

    1. Checks that every worker holding training/evaluation partitions is alive.
    2. Starts the tracker.
    3. While holding a ``distributed.MultiLock`` on those workers, runs one
       training task per worker through :py:func:`map_worker_partitions`.

    Returns
    -------
    The ``{"booster", "history"}`` dict reported by a worker with a non-empty
    training matrix, or ``None`` if no worker reported one.
    """
    workers = _get_workers_from_data(dtrain, evals)
    await _check_workers_are_alive(workers, client)
    coll_args = await _get_rabit_args(
        client, len(workers), dconfig=dconfig, coll_cfg=coll_cfg
    )
    _check_distributed_params(params)

    # This function name is displayed in the Dask dashboard task status, let's make it
    # clear that it's XGBoost training.
    def do_train(  # pylint: disable=too-many-positional-arguments
        parameters: Dict,
        coll_args: Dict[str, Union[str, int]],
        train_id: int,
        evals_name: List[str],
        evals_id: List[int],
        train_ref: dict,
        *refs: dict,
    ) -> Optional[TrainReturnT]:
        worker = distributed.get_worker()
        local_param = parameters.copy()
        n_threads = get_n_threads(local_param, worker)
        local_param.update({"nthread": n_threads, "n_jobs": n_threads})
        local_history: TrainingCallback.EvalsLog = {}
        # Build a per-task copy instead of updating the captured ``global_config`` in
        # place. If the closure is not serialized (e.g. in-process workers), every task
        # would otherwise share and overwrite the same dict.
        local_config = {**global_config, "nthread": n_threads}

        with CommunicatorContext(**coll_args), config.config_context(**local_config):
            Xy, local_evals = _get_dmatrices(
                train_ref,
                train_id,
                *refs,
                evals_id=evals_id,
                evals_name=evals_name,
                n_threads=n_threads,
                # We need the model for reference categories.
                model=xgb_model,
            )

            booster = worker_train(
                params=local_param,
                dtrain=Xy,
                num_boost_round=num_boost_round,
                evals_result=local_history,
                evals=local_evals if len(local_evals) != 0 else None,
                obj=obj,
                custom_metric=custom_metric,
                early_stopping_rounds=early_stopping_rounds,
                verbose_eval=verbose_eval,
                xgb_model=xgb_model,
                callbacks=callbacks,
            )
            # Don't return the boosters from empty workers. It's quite difficult to
            # guarantee everything is in sync in the presence of empty workers,
            # especially with complex objectives like quantile.
            if Xy.num_row() != 0:
                ret: Optional[TrainReturnT] = {
                    "booster": booster,
                    "history": local_history,
                }
            else:
                ret = None
            return ret

    async with distributed.MultiLock(workers, client):
        if evals is not None:
            evals_data = [d for d, n in evals]
            evals_name = [n for d, n in evals]
            evals_id = [id(d) for d in evals_data]
        else:
            evals_data = []
            evals_name = []
            evals_id = []

        result = await map_worker_partitions(
            client,
            do_train,
            # extra function parameters
            params,
            coll_args,
            id(dtrain),
            evals_name,
            evals_id,
            *([dtrain] + evals_data),
            # workers to be used for training
            workers=workers,
        )
        return result
@_deprecate_positional_args
def train(
    client: "distributed.Client",
    params: Dict[str, Any],
    dtrain: DaskDMatrix,
    num_boost_round: int = 10,
    *,
    evals: Optional[Sequence[Tuple[DaskDMatrix, str]]] = None,
    obj: Optional[Objective] = None,
    early_stopping_rounds: Optional[int] = None,
    xgb_model: Optional[Booster] = None,
    verbose_eval: Union[int, bool] = True,
    callbacks: Optional[Sequence[TrainingCallback]] = None,
    custom_metric: Optional[Metric] = None,
    coll_cfg: Optional[CollConfig] = None,
) -> Any:
    """Train XGBoost model.

    .. versionadded:: 1.0.0

    .. note::

        Other parameters are the same as :py:func:`xgboost.train` except for
        `evals_result`, which is returned as part of function return value instead of
        argument.

    Parameters
    ----------
    client :
        Specify the dask client used for training. Use default client returned from
        dask if it's set to None.
    coll_cfg :
        Configuration for the communicator used during training. See
        :py:class:`~xgboost.collective.Config`.

    Returns
    -------
    results: dict
        A dictionary containing trained booster and evaluation history. `history` field
        is the same as `evals_result` from `xgboost.train`.

        .. code-block:: python

            {'booster': xgboost.Booster,
             'history': {'train': {'logloss': ['0.48253', '0.35953']},
                         'eval': {'logloss': ['0.480385', '0.357756']}}}

    """
    resolved = _get_client(client)
    # Forward every argument explicitly to the async implementation; the client's
    # `sync` either runs it to completion or hands back an awaitable.
    return resolved.sync(
        _train_async,
        global_config=config.get_config(),
        dconfig=_get_dask_config(),
        client=resolved,
        params=params,
        dtrain=dtrain,
        num_boost_round=num_boost_round,
        evals=evals,
        obj=obj,
        early_stopping_rounds=early_stopping_rounds,
        xgb_model=xgb_model,
        verbose_eval=verbose_eval,
        callbacks=callbacks,
        custom_metric=custom_metric,
        coll_cfg=coll_cfg,
    )
def _can_output_df(is_df: bool, output_shape: Tuple) -> bool:
return is_df and len(output_shape) <= 2
def _maybe_dataframe(
    data: Any, prediction: Any, columns: List[int], is_df: bool
) -> Any:
    """Return dataframe for prediction when applicable.

    When ``_can_output_df`` allows it, ``prediction`` is wrapped in a cuDF or pandas
    DataFrame (matching the type of ``data``) with float32 columns, preserving the
    index of ``data`` when it has one. Otherwise ``prediction`` is returned as is.
    """
    if not _can_output_df(is_df, prediction.shape):
        return prediction

    # Need to preserve the index for dataframe.
    # See issue: https://github.com/dmlc/xgboost/issues/6939
    # In older versions of dask, the partition is actually a numpy array when input
    # is dataframe, hence the `getattr` fallback.
    index = getattr(data, "index", None)

    if _is_cudf_df(data):
        import cudf

        if prediction.size == 0:
            return cudf.DataFrame({}, columns=columns, dtype=numpy.float32)
        return cudf.DataFrame(
            prediction, columns=columns, dtype=numpy.float32, index=index
        )

    import pandas as pd

    if prediction.size == 0:
        return pd.DataFrame({}, columns=columns, dtype=numpy.float32, index=index)
    return pd.DataFrame(prediction, columns=columns, dtype=numpy.float32, index=index)
async def _direct_predict_impl(  # pylint: disable=too-many-branches
    *,
    mapped_predict: Callable,
    booster: "distributed.Future",
    data: _DataT,
    base_margin: Optional[_DaskCollection],
    output_shape: Tuple[int, ...],
    meta: Dict[int, str],
) -> _DaskCollection:
    """Predict on a dask collection block-by-block, without a DaskDMatrix.

    Parameters
    ----------
    mapped_predict :
        Function applied to each partition/block. It is called as
        ``mapped_predict(booster, partition, is_df, columns, base_margin)``.
    booster :
        Future holding the model, forwarded to ``mapped_predict``.
    data :
        Input collection.
    base_margin :
        Optional base margin, converted to a dataframe or an array to match the
        mapping function used for ``data`` before being forwarded.
    output_shape :
        Shape of the prediction for a sample input (see ``_infer_predict_output``).
        Only its length and, for 2-dim output, its second dimension are used.
    meta :
        Column index -> dtype mapping used to build the dataframe ``meta``. Its keys
        are also forwarded as the output column names.

    Returns
    -------
    A ``dd.DataFrame`` (or ``dd.Series`` for 1-dim output) when ``data`` is a
    dataframe and the output has at most 2 dimensions, otherwise a ``da.Array``.

    Raises
    ------
    ValueError
        When ``data`` is a dataframe and the output has 3 or more dimensions.
    """
    columns = tuple(meta.keys())
    if len(output_shape) >= 3 and isinstance(data, dd.DataFrame):
        # Without this check, dask will finish the prediction silently even if the
        # output has 3 or more dimensions. But during map_partitions, dask passes a
        # `dd.DataFrame` as local input to xgboost, which is converted to csr_matrix by
        # `_convert_unknown_data` since dd.DataFrame is not known to xgboost native
        # binding.
        raise ValueError(
            "Use `da.Array` or `DaskDMatrix` when output has more than 2 dimensions."
        )
    if _can_output_df(isinstance(data, dd.DataFrame), output_shape):
        if base_margin is not None and isinstance(base_margin, da.Array):
            # Easier for map_partitions
            base_margin_df: Optional[Union[dd.DataFrame, dd.Series]] = (
                base_margin.to_dask_dataframe()
            )
        else:
            base_margin_df = base_margin
        predictions = dd.map_partitions(
            mapped_predict,
            booster,
            data,
            True,
            columns,
            base_margin_df,
            meta=dd.utils.make_meta(meta),
        )
        # classification can return a dataframe, drop 1 dim when it's reg/binary
        if len(output_shape) == 1:
            predictions = predictions.iloc[:, 0]
    else:
        if base_margin is not None and isinstance(
            base_margin, (dd.Series, dd.DataFrame)
        ):
            # Easier for map_blocks
            base_margin_array: Optional[da.Array] = base_margin.to_dask_array()
        else:
            base_margin_array = base_margin
        # Input data is 2-dim array, output can be 1(reg, binary)/2(multi-class,
        # contrib)/3(contrib, interaction)/4(interaction) dims.
        if len(output_shape) == 1:
            drop_axis: Union[int, List[int]] = [1]  # drop from 2 to 1 dim.
            new_axis: Union[int, List[int]] = []
        else:
            drop_axis = []
            # NOTE(review): with the checks above, a dd.DataFrame never reaches this
            # branch (<= 2 dims goes to map_partitions, >= 3 dims raises), so the
            # first case looks unreachable -- confirm before relying on it.
            if isinstance(data, dd.DataFrame):
                new_axis = list(range(len(output_shape) - 2))
            else:
                new_axis = [i + 2 for i in range(len(output_shape) - 2)]
        if len(output_shape) == 2:
            # Somehow dask fails to infer the output shape change for 2-dim prediction,
            # and `chunks = (None, output_shape[1])` doesn't work because None is not
            # supported in map_blocks.

            # data must be an array here as dataframe + 2-dim output predict will return
            # a dataframe instead.
            chunks: Optional[List[Tuple]] = list(data.chunks)
            assert isinstance(chunks, list)
            chunks[1] = (output_shape[1],)
        else:
            chunks = None
        predictions = da.map_blocks(
            mapped_predict,
            booster,
            data,
            False,
            columns,
            base_margin_array,
            chunks=chunks,
            drop_axis=drop_axis,
            new_axis=new_axis,
            dtype=numpy.float32,
        )
    return predictions
def _infer_predict_output(
    booster: Booster, features: int, is_df: bool, inplace: bool, **kwargs: Any
) -> Tuple[Tuple[int, ...], Dict[int, str]]:
    """Create a dummy test sample to infer output shape for prediction.

    Predicts on a single random row with ``features`` columns. When ``inplace`` is
    true, ``kwargs`` must contain ``predict_type``, which is removed and translated to
    ``output_margin=True`` when it equals ``"margin"``.

    Returns
    -------
    The prediction shape for the one-row sample, and a column-index -> ``"f4"``
    mapping usable as dataframe meta (empty when the output cannot be a dataframe).
    """
    assert isinstance(features, int)
    sample = numpy.random.RandomState(1994).randn(1, features)

    predict_kwargs = kwargs
    if inplace:
        predict_kwargs = dict(kwargs)
        if predict_kwargs.pop("predict_type") == "margin":
            predict_kwargs["output_margin"] = True

    dummy = DMatrix(sample, enable_categorical=True)
    # generated DMatrix doesn't have feature name, so no validation.
    shape = booster.predict(dummy, validate_features=False, **predict_kwargs).shape

    meta: Dict[int, str] = {}
    if _can_output_df(is_df, shape):
        width = shape[1] if len(shape) > 1 else 1
        meta = {col: "f4" for col in range(width)}
    return shape, meta
async def _get_model_future(
client: "distributed.Client", model: Union[Booster, Dict, "distributed.Future"]
) -> "distributed.Future":
# See https://github.com/dask/dask/issues/11179#issuecomment-2168094529 for the use
# of hash.
# https://github.com/dask/distributed/pull/8796 Don't use broadcast in the `scatter`
# call, otherwise, the predict function might hang.
if isinstance(model, Booster):
booster = await client.scatter(model, hash=False)
elif isinstance(model, dict):
booster = await client.scatter(model["booster"], hash=False)
elif isinstance(model, distributed.Future):
booster = model
t = booster.type
if t is not Booster:
raise TypeError(
f"Underlying type of model future should be `Booster`, got {t}"
)
else:
raise TypeError(_expect([Booster, dict, distributed.Future], type(model)))
return booster
# pylint: disable=too-many-statements
async def _predict_async(
    client: "distributed.Client",
    global_config: Dict[str, Any],
    model: Union[Booster, Dict, "distributed.Future"],
    data: _DataT,
    *,
    output_margin: bool,
    missing: float,
    pred_leaf: bool,
    pred_contribs: bool,
    approx_contribs: bool,
    pred_interactions: bool,
    validate_features: bool,
    iteration_range: IterationRange,
    strict_shape: bool,
) -> _DaskCollection:
    """Asynchronous implementation of prediction on dask data.

    For a ``da.Array``/``dd.DataFrame``, the prediction is mapped over the collection
    through ``_direct_predict_impl``. For a ``DaskDMatrix``, one prediction task is
    submitted per partition on the worker that holds it. The results are then
    concatenated in the original partition order into a ``da.Array``.

    Raises
    ------
    TypeError
        If ``data`` is not a ``DaskDMatrix``, ``da.Array`` or ``dd.DataFrame``.
    """
    _booster = await _get_model_future(client, model)
    if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)):
        raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame], type(data)))

    def mapped_predict(
        booster: Booster, partition: Any, is_df: bool, columns: List[int], _: Any
    ) -> Any:
        # Per-partition/block predictor for `_direct_predict_impl`. The last argument
        # (the base margin in that calling convention) is ignored here.
        with config.config_context(**global_config):
            m = DMatrix(
                data=partition,
                missing=missing,
                enable_categorical=True,
            )
            predt = booster.predict(
                data=m,
                output_margin=output_margin,
                pred_leaf=pred_leaf,
                pred_contribs=pred_contribs,
                approx_contribs=approx_contribs,
                pred_interactions=pred_interactions,
                validate_features=validate_features,
                iteration_range=iteration_range,
                strict_shape=strict_shape,
            )
            predt = _maybe_dataframe(partition, predt, columns, is_df)
            return predt

    # Predict on dask collection directly.
    if isinstance(data, (da.Array, dd.DataFrame)):
        _output_shape, meta = await client.compute(
            client.submit(
                _infer_predict_output,
                _booster,
                features=data.shape[1],
                is_df=isinstance(data, dd.DataFrame),
                inplace=False,
                output_margin=output_margin,
                pred_leaf=pred_leaf,
                pred_contribs=pred_contribs,
                approx_contribs=approx_contribs,
                pred_interactions=pred_interactions,
                strict_shape=strict_shape,
            )
        )
        return await _direct_predict_impl(
            mapped_predict=mapped_predict,
            booster=_booster,
            data=data,
            base_margin=None,
            output_shape=_output_shape,
            meta=meta,
        )

    output_shape, _ = await client.compute(
        client.submit(
            _infer_predict_output,
            booster=_booster,
            features=data.num_col(),
            is_df=False,
            inplace=False,
            output_margin=output_margin,
            pred_leaf=pred_leaf,
            pred_contribs=pred_contribs,
            approx_contribs=approx_contribs,
            pred_interactions=pred_interactions,
            strict_shape=strict_shape,
        )
    )
    # Prediction on dask DMatrix.
    partition_order = data.partition_order
    feature_names = data.feature_names
    feature_types = data.feature_types
    # NOTE: this rebinds the `missing` argument. For DaskDMatrix input, the matrix's
    # own missing value is used and the argument is ignored. `mapped_predict` also
    # reads `missing`, but it is only used in the early-return branch above.
    missing = data.missing

    def dispatched_predict(booster: Booster, part: Dict[str, Any]) -> numpy.ndarray:
        # Runs on the worker holding `part`, a dict of partition data and meta info
        # built by DaskDMatrix. `data` here is local and shadows the outer argument.
        data = part["data"]
        base_margin = part.get("base_margin", None)
        with config.config_context(**global_config):
            m = DMatrix(
                data,
                missing=missing,
                base_margin=base_margin,
                feature_names=feature_names,
                feature_types=feature_types,
                enable_categorical=True,
            )
            predt = booster.predict(
                m,
                output_margin=output_margin,
                pred_leaf=pred_leaf,
                pred_contribs=pred_contribs,
                approx_contribs=approx_contribs,
                pred_interactions=pred_interactions,
                validate_features=validate_features,
                iteration_range=iteration_range,
                strict_shape=strict_shape,
            )
            return predt

    all_parts = []
    all_orders = []
    all_shapes = []
    all_workers: List[str] = []
    workers_address = list(data.worker_map.keys())
    # Flatten the worker -> partitions mapping, remembering each partition's worker
    # and its original position.
    for worker_addr in workers_address:
        list_of_parts = data.worker_map[worker_addr]
        all_parts.extend(list_of_parts)
        all_workers.extend(len(list_of_parts) * [worker_addr])
        all_orders.extend([partition_order[part.key] for part in list_of_parts])
    # Query the number of rows of each partition on its own worker.
    for w, part in zip(all_workers, all_parts):
        s = client.submit(lambda part: part["data"].shape[0], part, workers=[w])
        all_shapes.append(s)

    # Restore the original partition order so the output rows line up with the input.
    parts_with_order = list(zip(all_parts, all_shapes, all_orders, all_workers))
    parts_with_order = sorted(parts_with_order, key=lambda p: p[2])
    all_parts = [part for part, shape, order, w in parts_with_order]
    all_shapes = [shape for part, shape, order, w in parts_with_order]
    all_workers = [w for part, shape, order, w in parts_with_order]

    futures = []
    for w, part in zip(all_workers, all_parts):
        f = client.submit(dispatched_predict, _booster, part, workers=[w])
        futures.append(f)

    # Constructing a dask array from list of numpy arrays
    # See https://docs.dask.org/en/latest/array-creation.html
    arrays = []
    all_shapes = await client.gather(all_shapes)
    for i, rows in enumerate(all_shapes):
        arrays.append(
            da.from_delayed(
                futures[i], shape=(rows,) + output_shape[1:], dtype=numpy.float32
            )
        )
    predictions = da.concatenate(arrays, axis=0)
    return predictions
@_deprecate_positional_args
def predict(
    client: Optional["distributed.Client"],
    model: Union[TrainReturnT, Booster, "distributed.Future"],
    data: Union[DaskDMatrix, _DataT],
    *,
    output_margin: bool = False,
    missing: float = numpy.nan,
    pred_leaf: bool = False,
    pred_contribs: bool = False,
    approx_contribs: bool = False,
    pred_interactions: bool = False,
    validate_features: bool = True,
    iteration_range: IterationRange = (0, 0),
    strict_shape: bool = False,
) -> Any:
    """Run prediction with a trained booster on a dask collection.

    .. note::

        Using ``inplace_predict`` might be faster when some features are not needed.

    See :py:meth:`xgboost.Booster.predict` for the meaning of the individual
    prediction options. When the output has more than 2 dimensions (SHAP values,
    leaf indices with ``strict_shape``), the input should be ``da.Array`` or
    ``DaskDMatrix``.

    .. versionadded:: 1.0.0

    Parameters
    ----------
    client:
        The dask client used for prediction. When ``None``, the default client
        returned by dask is used.
    model:
        The trained model. It can be a ``distributed.Future`` so that the caller can
        pre-scatter it onto all workers.
    data:
        Input data used for prediction. When the input is a dataframe, the
        prediction output is a series.
    missing:
        Only used when ``data`` is not a ``DaskDMatrix``. The value treated as
        missing.

    Returns
    -------
    prediction: dask.array.Array/dask.dataframe.Series
        When ``data`` is a ``dask.array.Array`` or ``DaskDMatrix``, the result is an
        array. When ``data`` is a ``dask.dataframe.DataFrame``, the result is a
        ``dask.dataframe.Series`` or ``dask.dataframe.DataFrame``, depending on the
        output shape.
    """
    resolved_client = _get_client(client)
    # Forward every argument explicitly to the asynchronous implementation; the
    # global config is captured here so workers see the caller's settings.
    return resolved_client.sync(
        _predict_async,
        global_config=config.get_config(),
        client=resolved_client,
        model=model,
        data=data,
        output_margin=output_margin,
        missing=missing,
        pred_leaf=pred_leaf,
        pred_contribs=pred_contribs,
        approx_contribs=approx_contribs,
        pred_interactions=pred_interactions,
        validate_features=validate_features,
        iteration_range=iteration_range,
        strict_shape=strict_shape,
    )
async def _inplace_predict_async(  # pylint: disable=too-many-branches
    *,
    client: "distributed.Client",
    global_config: Dict[str, Any],
    model: Union[Booster, Dict, "distributed.Future"],
    data: _DataT,
    iteration_range: IterationRange,
    predict_type: str,
    missing: float,
    validate_features: bool,
    base_margin: Optional[_DaskCollection],
    strict_shape: bool,
) -> _DaskCollection:
    """Asynchronous implementation of :py:func:`inplace_predict`.

    Validates the input types, infers the prediction output shape and metadata on
    the cluster with ``_infer_predict_output``, then hands a per-partition
    ``mapped_predict`` (which calls :py:meth:`xgboost.Booster.inplace_predict`) to
    ``_direct_predict_impl``.

    Raises
    ------
    TypeError
        If ``data`` is not a dask array/dataframe, or if ``base_margin`` is given but
        is not a dask array, dataframe or series.
    """
    client = _get_client(client)
    booster = await _get_model_future(client, model)
    if not isinstance(data, (da.Array, dd.DataFrame)):
        raise TypeError(_expect([da.Array, dd.DataFrame], type(data)))
    # Validate ``base_margin`` itself. Re-checking ``data`` here would be a no-op
    # because ``data`` has already been validated above.
    if base_margin is not None and not isinstance(
        base_margin, (da.Array, dd.DataFrame, dd.Series)
    ):
        raise TypeError(_expect([da.Array, dd.DataFrame, dd.Series], type(base_margin)))

    def mapped_predict(
        booster: Booster,
        partition: Any,
        is_df: bool,
        columns: List[int],
        base_margin: Any,
    ) -> Any:
        # Runs on a worker: apply the caller's global config, predict on a single
        # partition, then convert back to a dataframe when the input was one.
        with config.config_context(**global_config):
            prediction = booster.inplace_predict(
                partition,
                iteration_range=iteration_range,
                predict_type=predict_type,
                missing=missing,
                base_margin=base_margin,
                validate_features=validate_features,
                strict_shape=strict_shape,
            )
        prediction = _maybe_dataframe(partition, prediction, columns, is_df)
        return prediction

    # await turns future into value.
    shape, meta = await client.compute(
        client.submit(
            _infer_predict_output,
            booster,
            features=data.shape[1],
            is_df=isinstance(data, dd.DataFrame),
            inplace=True,
            predict_type=predict_type,
            iteration_range=iteration_range,
            strict_shape=strict_shape,
        )
    )
    return await _direct_predict_impl(
        mapped_predict=mapped_predict,
        booster=booster,
        data=data,
        base_margin=base_margin,
        output_shape=shape,
        meta=meta,
    )
@_deprecate_positional_args
def inplace_predict(
    client: Optional["distributed.Client"],
    model: Union[TrainReturnT, Booster, "distributed.Future"],
    data: _DataT,
    *,
    iteration_range: IterationRange = (0, 0),
    predict_type: str = "value",
    missing: float = numpy.nan,
    validate_features: bool = True,
    base_margin: Optional[_DaskCollection] = None,
    strict_shape: bool = False,
) -> Any:
    """Inplace prediction on a dask collection. See
    :py:meth:`xgboost.Booster.inplace_predict` for details.

    .. versionadded:: 1.1.0

    Parameters
    ----------
    client:
        The dask client used for prediction. When ``None``, the default client
        returned by dask is used.
    model:
        See :py:func:`xgboost.dask.predict` for details.
    data :
        A dask collection (``dask.array.Array`` or ``dask.dataframe.DataFrame``).
    iteration_range:
        See :py:meth:`xgboost.Booster.predict` for details.
    predict_type:
        See :py:meth:`xgboost.Booster.inplace_predict` for details.
    missing:
        Value in the input data that is treated as missing. Defaults to
        ``numpy.nan``.
    base_margin:
        See :py:obj:`xgboost.DMatrix` for details.

        .. versionadded:: 1.4.0

    strict_shape:
        See :py:meth:`xgboost.Booster.predict` for details.

        .. versionadded:: 1.4.0

    Returns
    -------
    prediction :
        When ``data`` is a ``dask.array.Array``, the result is an array. When
        ``data`` is a ``dask.dataframe.DataFrame``, the result is a
        ``dask.dataframe.Series`` or ``dask.dataframe.DataFrame``, depending on the
        output shape.
    """
    dask_client = _get_client(client)
    # In an asynchronous environment the client's `asynchronous` attribute must be
    # True, in which case `sync` hands back an awaitable. The scikit-learn wrapper
    # is responsible for providing a suitably configured client.
    return dask_client.sync(
        _inplace_predict_async,
        global_config=config.get_config(),
        client=dask_client,
        model=model,
        data=data,
        iteration_range=iteration_range,
        predict_type=predict_type,
        missing=missing,
        validate_features=validate_features,
        base_margin=base_margin,
        strict_shape=strict_shape,
    )
async def _async_wrap_evaluation_matrices(
    client: Optional["distributed.Client"],
    device: Optional[str],
    tree_method: Optional[str],
    max_bin: Optional[int],
    **kwargs: Any,
) -> Tuple[DaskDMatrix, Optional[List[Tuple[DaskDMatrix, str]]]]:
    """A switch function for async environment.

    Builds the training and evaluation matrices through
    ``_wrap_evaluation_matrices`` and awaits each of them. When ``_can_use_qdm``
    accepts the tree method/device, a ``DaskQuantileDMatrix`` is built; otherwise a
    plain ``DaskDMatrix`` is built. Evaluation entries that are the training matrix
    itself are not awaited a second time.
    """

    def _make_dmatrix(ref: Optional[DaskDMatrix], **dm_kwargs: Any) -> DaskDMatrix:
        # `ref` is only forwarded to the quantile matrix.
        if not _can_use_qdm(tree_method, device):
            return DaskDMatrix(client=client, **dm_kwargs)
        return DaskQuantileDMatrix(
            client=client, ref=ref, max_bin=max_bin, **dm_kwargs
        )

    pending_train, pending_evals = _wrap_evaluation_matrices(
        create_dmatrix=_make_dmatrix, **kwargs
    )
    train_dmatrix = await pending_train
    if pending_evals is None:
        return train_dmatrix, None
    # Entries sharing the training matrix were resolved by the await above.
    resolved = [
        entry if entry[0] is train_dmatrix else (await entry[0], entry[1])
        for entry in pending_evals
    ]
    return train_dmatrix, resolved
@contextmanager
def _set_worker_client(
model: "DaskScikitLearnBase", client: "distributed.Client"
) -> Generator:
"""Temporarily set the client for sklearn model."""
try:
model.client = client
yield model
finally:
model.client = None # type:ignore
class DaskScikitLearnBase(XGBModel):
    """Base class for implementing scikit-learn interface with Dask"""

    _client = None

    def __init__(self, *, coll_cfg: Optional[CollConfig] = None, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # Collective (communicator) configuration forwarded to training.
        self.coll_cfg = coll_cfg

    async def _predict_async(
        self,
        data: _DataT,
        *,
        output_margin: bool,
        validate_features: bool,
        base_margin: Optional[_DaskCollection],
        iteration_range: Optional[IterationRange],
    ) -> Any:
        """Predict through :py:func:`inplace_predict`, returning a dask array."""
        iteration_range = self._get_iteration_range(iteration_range)
        # Dask doesn't support gblinear and accepts only Dask collection types (array
        # and dataframe). We can perform inplace predict.
        assert self._can_use_inplace_predict()
        predts = await inplace_predict(
            client=self.client,
            model=self.get_booster(),
            data=data,
            iteration_range=iteration_range,
            predict_type="margin" if output_margin else "value",
            missing=self.missing,
            base_margin=base_margin,
            validate_features=validate_features,
        )
        if isinstance(predts, dd.DataFrame):
            predts = predts.to_dask_array()
        # Make sure the booster is part of the task graph implicitly
        # only needed for certain versions of dask.
        if _DASK_2024_12_1() and not _DASK_2025_3_0():
            # Fixes this issue for dask>=2024.12.1,<2025.3.0
            # Dask==2025.3.0 fails with:
            # RuntimeError: Attempting to use an asynchronous
            # Client in a synchronous context of `dask.compute`
            #
            # Dask==2025.4.0 fails with:
            # TypeError: Value type is not supported for data
            # iterator:<class 'distributed.client.Future'>
            predts = predts.persist()
        return predts

    @_deprecate_positional_args
    def predict(
        self,
        X: _DataT,
        *,
        output_margin: bool = False,
        validate_features: bool = True,
        base_margin: Optional[_DaskCollection] = None,
        iteration_range: Optional[IterationRange] = None,
    ) -> Any:
        # Synchronous entry point; the work happens in `_predict_async`.
        return self.client.sync(
            self._predict_async,
            X,
            output_margin=output_margin,
            validate_features=validate_features,
            base_margin=base_margin,
            iteration_range=iteration_range,
        )

    async def _apply_async(
        self,
        X: _DataT,
        iteration_range: Optional[IterationRange] = None,
    ) -> Any:
        """Return leaf indices by running `predict` with ``pred_leaf=True``."""
        iteration_range = self._get_iteration_range(iteration_range)
        test_dmatrix: DaskDMatrix = await DaskDMatrix(
            self.client,
            data=X,
            missing=self.missing,
            feature_types=self.feature_types,
        )
        predts = await predict(
            self.client,
            model=self.get_booster(),
            data=test_dmatrix,
            pred_leaf=True,
            iteration_range=iteration_range,
        )
        return predts

    def apply(
        self,
        X: _DataT,
        iteration_range: Optional[IterationRange] = None,
    ) -> Any:
        # Synchronous entry point; the work happens in `_apply_async`.
        return self.client.sync(self._apply_async, X, iteration_range=iteration_range)

    def __await__(self) -> Awaitable[Any]:
        # Generate a coroutine wrapper to make this class awaitable.
        async def _() -> Awaitable[Any]:
            return self

        return self._client_sync(_).__await__()

    def __getstate__(self) -> Dict:
        # The dask client cannot be pickled, so it is dropped from the state.
        this = self.__dict__.copy()
        if "_client" in this:
            del this["_client"]
        return this

    @property
    def client(self) -> "distributed.Client":
        """The dask client used in this model. The `Client` object can not be
        serialized for transmission, so if task is launched from a worker instead of
        directly from the client process, this attribute needs to be set at that worker.

        """
        client = _get_client(self._client)
        return client

    @client.setter
    def client(self, clt: "distributed.Client") -> None:
        # calling `worker_client' doesn't return the correct `asynchronous` attribute,
        # so we have to pass it ourselves.
        self._asynchronous = clt.asynchronous if clt is not None else False
        self._client = clt

    def _client_sync(self, func: Callable, **kwargs: Any) -> Any:
        """Get the correct client, when method is invoked inside a worker we
        should use `worker_client' instead of default client.

        When no client is bound and the call happens inside a dask worker, a
        temporary `worker_client` is bound for the duration of ``func`` and the
        previously recorded ``_asynchronous`` flag is used. Otherwise the model's
        ``client`` is used with its own ``asynchronous`` setting.
        """
        if self._client is None:
            asynchronous = getattr(self, "_asynchronous", False)
            try:
                distributed.get_worker()
                in_worker = True
            except ValueError:
                in_worker = False
            if in_worker:
                with distributed.worker_client() as client:
                    with _set_worker_client(self, client) as this:
                        return this.client.sync(
                            func, **kwargs, asynchronous=asynchronous
                        )

        return self.client.sync(func, **kwargs, asynchronous=self.client.asynchronous)
@xgboost_model_doc(
    """Implementation of the Scikit-Learn API for XGBoost.""", ["estimators", "model"]
)
class DaskXGBRegressor(XGBRegressorBase, DaskScikitLearnBase):
    """dummy doc string to workaround pylint, replaced by the decorator."""

    async def _fit_async(
        self,
        X: _DataT,
        y: _DaskCollection,
        *,
        sample_weight: Optional[_DaskCollection],
        base_margin: Optional[_DaskCollection],
        eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]],
        sample_weight_eval_set: Optional[Sequence[_DaskCollection]],
        base_margin_eval_set: Optional[Sequence[_DaskCollection]],
        verbose: Union[int, bool],
        xgb_model: Optional[Union[Booster, XGBModel]],
        feature_weights: Optional[_DaskCollection],
    ) -> "DaskXGBRegressor":
        """Build the training/evaluation matrices and train the booster.

        Stores the trained booster and the evaluation history on ``self`` and
        returns ``self``.
        """
        params = self.get_xgb_params()
        model, metric, params, feature_weights = self._configure_fit(
            xgb_model, params, feature_weights
        )

        dtrain, evals = await _async_wrap_evaluation_matrices(
            client=self.client,
            device=self.device,
            tree_method=self.tree_method,
            max_bin=self.max_bin,
            X=X,
            y=y,
            group=None,
            qid=None,
            sample_weight=sample_weight,
            base_margin=base_margin,
            feature_weights=feature_weights,
            eval_set=eval_set,
            sample_weight_eval_set=sample_weight_eval_set,
            base_margin_eval_set=base_margin_eval_set,
            eval_group=None,
            eval_qid=None,
            missing=self.missing,
            enable_categorical=self.enable_categorical,
            feature_types=self.feature_types,
        )

        # A callable objective is wrapped into the form expected by training.
        if callable(self.objective):
            obj: Optional[Callable] = _objective_decorator(self.objective)
        else:
            obj = None
        results = await self.client.sync(
            _train_async,
            asynchronous=True,
            client=self.client,
            global_config=config.get_config(),
            dconfig=_get_dask_config(),
            params=params,
            dtrain=dtrain,
            num_boost_round=self.get_num_boosting_rounds(),
            evals=evals,
            obj=obj,
            custom_metric=metric,
            verbose_eval=verbose,
            early_stopping_rounds=self.early_stopping_rounds,
            callbacks=self.callbacks,
            coll_cfg=self.coll_cfg,
            xgb_model=model,
        )
        self._Booster = results["booster"]
        self._set_evaluation_result(results["history"])
        return self

    # pylint: disable=missing-docstring, disable=unused-argument
    @_deprecate_positional_args
    def fit(
        self,
        X: _DataT,
        y: _DaskCollection,
        *,
        sample_weight: Optional[_DaskCollection] = None,
        base_margin: Optional[_DaskCollection] = None,
        eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
        verbose: Optional[Union[int, bool]] = True,
        xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
        sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
        base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
        feature_weights: Optional[_DaskCollection] = None,
    ) -> "DaskXGBRegressor":
        # Forward all arguments (minus `self`/`__class__`) to `_fit_async`.
        args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
        return self._client_sync(self._fit_async, **args)
@xgboost_model_doc(
    "Implementation of the scikit-learn API for XGBoost classification.",
    ["estimators", "model"],
)
class DaskXGBClassifier(XGBClassifierBase, DaskScikitLearnBase):
    # pylint: disable=missing-class-docstring
    async def _fit_async(
        self,
        X: _DataT,
        y: _DaskCollection,
        *,
        sample_weight: Optional[_DaskCollection],
        base_margin: Optional[_DaskCollection],
        eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]],
        sample_weight_eval_set: Optional[Sequence[_DaskCollection]],
        base_margin_eval_set: Optional[Sequence[_DaskCollection]],
        verbose: Union[int, bool],
        xgb_model: Optional[Union[Booster, XGBModel]],
        feature_weights: Optional[_DaskCollection],
    ) -> "DaskXGBClassifier":
        """Determine the classes from ``y``, pick the objective and train.

        ``classes_`` and ``n_classes_`` are computed from the distinct labels. More
        than two classes selects ``multi:softprob`` (with ``num_class``); otherwise
        ``binary:logistic`` is used. Returns ``self``.
        """
        params = self.get_xgb_params()
        model, metric, params, feature_weights = self._configure_fit(
            xgb_model, params, feature_weights
        )

        dtrain, evals = await _async_wrap_evaluation_matrices(
            self.client,
            device=self.device,
            tree_method=self.tree_method,
            max_bin=self.max_bin,
            X=X,
            y=y,
            group=None,
            qid=None,
            sample_weight=sample_weight,
            base_margin=base_margin,
            feature_weights=feature_weights,
            eval_set=eval_set,
            sample_weight_eval_set=sample_weight_eval_set,
            base_margin_eval_set=base_margin_eval_set,
            eval_group=None,
            eval_qid=None,
            missing=self.missing,
            enable_categorical=self.enable_categorical,
            feature_types=self.feature_types,
        )

        # pylint: disable=attribute-defined-outside-init
        if isinstance(y, da.Array):
            self.classes_ = await self.client.compute(da.unique(y))
        else:
            self.classes_ = await self.client.compute(y.drop_duplicates())
        # Normalize GPU-backed results (cuDF series / CuPy arrays) to host numpy.
        if _is_cudf_ser(self.classes_):
            self.classes_ = self.classes_.to_cupy()
        if _is_cupy_alike(self.classes_):
            self.classes_ = self.classes_.get()
        self.classes_ = numpy.array(self.classes_)
        self.n_classes_ = len(self.classes_)

        if self.n_classes_ > 2:
            params["objective"] = "multi:softprob"
            params["num_class"] = self.n_classes_
        else:
            params["objective"] = "binary:logistic"

        if callable(self.objective):
            obj: Optional[Callable] = _objective_decorator(self.objective)
        else:
            obj = None
        results = await self.client.sync(
            _train_async,
            asynchronous=True,
            client=self.client,
            global_config=config.get_config(),
            dconfig=_get_dask_config(),
            params=params,
            dtrain=dtrain,
            num_boost_round=self.get_num_boosting_rounds(),
            evals=evals,
            obj=obj,
            custom_metric=metric,
            verbose_eval=verbose,
            early_stopping_rounds=self.early_stopping_rounds,
            callbacks=self.callbacks,
            coll_cfg=self.coll_cfg,
            xgb_model=model,
        )
        self._Booster = results["booster"]
        if not callable(self.objective):
            self.objective = params["objective"]
        self._set_evaluation_result(results["history"])
        return self

    # pylint: disable=unused-argument
    def fit(
        self,
        X: _DataT,
        y: _DaskCollection,
        *,
        sample_weight: Optional[_DaskCollection] = None,
        base_margin: Optional[_DaskCollection] = None,
        eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
        verbose: Optional[Union[int, bool]] = True,
        xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
        sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
        base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
        feature_weights: Optional[_DaskCollection] = None,
    ) -> "DaskXGBClassifier":
        # Forward all arguments (minus `self`/`__class__`) to `_fit_async`.
        args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
        return self._client_sync(self._fit_async, **args)

    async def _predict_proba_async(
        self,
        X: _DataT,
        validate_features: bool,
        base_margin: Optional[_DaskCollection],
        iteration_range: Optional[IterationRange],
    ) -> _DaskCollection:
        """Return class probabilities; rejects the ``multi:softmax`` objective."""
        if self.objective == "multi:softmax":
            raise ValueError(
                "multi:softmax doesn't support `predict_proba`. "
                "Switch to `multi:softprob` instead"
            )
        predts = await super()._predict_async(
            data=X,
            output_margin=False,
            validate_features=validate_features,
            base_margin=base_margin,
            iteration_range=iteration_range,
        )
        # Chunk sizes of the dask prediction may be unknown, hence the flag.
        vstack = update_wrapper(
            partial(da.vstack, allow_unknown_chunksizes=True), da.vstack
        )
        return _cls_predict_proba(getattr(self, "n_classes_", 0), predts, vstack)

    # pylint: disable=missing-function-docstring
    def predict_proba(
        self,
        X: _DaskCollection,
        validate_features: bool = True,
        base_margin: Optional[_DaskCollection] = None,
        iteration_range: Optional[IterationRange] = None,
    ) -> Any:
        return self._client_sync(
            self._predict_proba_async,
            X=X,
            validate_features=validate_features,
            base_margin=base_margin,
            iteration_range=iteration_range,
        )

    predict_proba.__doc__ = XGBClassifier.predict_proba.__doc__

    async def _predict_async(
        self,
        data: _DataT,
        *,
        output_margin: bool,
        validate_features: bool,
        base_margin: Optional[_DaskCollection],
        iteration_range: Optional[IterationRange],
    ) -> _DaskCollection:
        """Convert probabilities into class indices.

        Raw margins are returned unchanged when ``output_margin`` is set. A 1-D
        result is thresholded at 0.5; a 2-D result is reduced by a row-wise argmax.
        """
        pred_probs = await super()._predict_async(
            data,
            output_margin=output_margin,
            validate_features=validate_features,
            base_margin=base_margin,
            iteration_range=iteration_range,
        )
        if output_margin:
            return pred_probs

        if len(pred_probs.shape) == 1:
            preds = (pred_probs > 0.5).astype(int)
        else:
            assert len(pred_probs.shape) == 2
            assert isinstance(pred_probs, da.Array)

            # when using da.argmax directly, dask will construct a numpy based return
            # array, which runs into error when computing GPU based prediction.
            def _argmax(x: Any) -> Any:
                return x.argmax(axis=1)

            preds = da.map_blocks(_argmax, pred_probs, drop_axis=1)
        return preds
@xgboost_model_doc(
    """Implementation of the Scikit-Learn API for XGBoost Ranking.
    .. versionadded:: 1.4.0
""",
    ["estimators", "model"],
    extra_parameters="""
    allow_group_split :
        .. versionadded:: 3.0.0
        Whether a query group can be split among multiple workers. When set to `False`,
        inputs must be Dask dataframes or series. If you have many small query groups,
        this can significantly increase the fragmentation of the data, and the internal
        DMatrix construction can take longer.
""",
    end_note="""
        .. note::
            For the dask implementation, group is not supported, use qid instead.
""",
)
class DaskXGBRanker(XGBRankerMixIn, DaskScikitLearnBase):
    @_deprecate_positional_args
    def __init__(
        self,
        *,
        objective: str = "rank:ndcg",
        allow_group_split: bool = False,
        coll_cfg: Optional[CollConfig] = None,
        **kwargs: Any,
    ) -> None:
        if callable(objective):
            raise ValueError("Custom objective function not supported by XGBRanker.")
        self.allow_group_split = allow_group_split
        super().__init__(objective=objective, coll_cfg=coll_cfg, **kwargs)

    def _wrapper_params(self) -> Set[str]:
        # `allow_group_split` is a wrapper-only parameter, not a booster parameter.
        params = super()._wrapper_params()
        params.add("allow_group_split")
        return params

    async def _fit_async(
        self,
        X: _DataT,
        y: _DaskCollection,
        *,
        qid: Optional[_DaskCollection],
        sample_weight: Optional[_DaskCollection],
        base_margin: Optional[_DaskCollection],
        eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]],
        sample_weight_eval_set: Optional[Sequence[_DaskCollection]],
        base_margin_eval_set: Optional[Sequence[_DaskCollection]],
        eval_qid: Optional[Sequence[_DaskCollection]],
        verbose: Union[int, bool],
        xgb_model: Optional[Union[XGBModel, Booster]],
        feature_weights: Optional[_DaskCollection],
    ) -> "DaskXGBRanker":
        """Build the (qid-aware) training/evaluation matrices and train the booster.

        Returns ``self`` after storing the booster and ``evals_result_``.
        """
        params = self.get_xgb_params()
        model, metric, params, feature_weights = self._configure_fit(
            xgb_model, params, feature_weights
        )
        dtrain, evals = await _async_wrap_evaluation_matrices(
            self.client,
            device=self.device,
            tree_method=self.tree_method,
            max_bin=self.max_bin,
            X=X,
            y=y,
            group=None,
            qid=qid,
            sample_weight=sample_weight,
            base_margin=base_margin,
            feature_weights=feature_weights,
            eval_set=eval_set,
            sample_weight_eval_set=sample_weight_eval_set,
            base_margin_eval_set=base_margin_eval_set,
            eval_group=None,
            eval_qid=eval_qid,
            missing=self.missing,
            enable_categorical=self.enable_categorical,
            feature_types=self.feature_types,
        )
        results = await self.client.sync(
            _train_async,
            asynchronous=True,
            client=self.client,
            global_config=config.get_config(),
            dconfig=_get_dask_config(),
            params=params,
            dtrain=dtrain,
            num_boost_round=self.get_num_boosting_rounds(),
            evals=evals,
            obj=None,
            custom_metric=metric,
            verbose_eval=verbose,
            early_stopping_rounds=self.early_stopping_rounds,
            callbacks=self.callbacks,
            xgb_model=model,
            coll_cfg=self.coll_cfg,
        )
        self._Booster = results["booster"]
        self.evals_result_ = results["history"]
        return self

    # pylint: disable=unused-argument, arguments-differ
    @_deprecate_positional_args
    def fit(
        self,
        X: _DataT,
        y: _DaskCollection,
        *,
        group: Optional[_DaskCollection] = None,
        qid: Optional[_DaskCollection] = None,
        sample_weight: Optional[_DaskCollection] = None,
        base_margin: Optional[_DaskCollection] = None,
        eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
        eval_group: Optional[Sequence[_DaskCollection]] = None,
        eval_qid: Optional[Sequence[_DaskCollection]] = None,
        verbose: Optional[Union[int, bool]] = False,
        xgb_model: Optional[Union[XGBModel, str, Booster]] = None,
        sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
        base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
        feature_weights: Optional[_DaskCollection] = None,
    ) -> "DaskXGBRanker":
        # The dask interface only supports `qid`; `group` is rejected outright.
        msg = "Use the `qid` instead of the `group` with the dask interface."
        if not (group is None and eval_group is None):
            raise ValueError(msg)
        if qid is None:
            raise ValueError("`qid` is required for ranking.")

        def check_df(X: _DaskCollection) -> TypeGuard[dd.DataFrame]:
            if not isinstance(X, dd.DataFrame):
                raise TypeError(
                    "When `allow_group_split` is set to False, X is required to be"
                    " a dataframe."
                )
            return True

        def check_ser(
            qid: Optional[_DaskCollection], name: str
        ) -> TypeGuard[Optional[dd.Series]]:
            if not isinstance(qid, dd.Series) and qid is not None:
                raise TypeError(
                    f"When `allow_group_split` is set to False, {name} is required to "
                    "be a series."
                )
            return True

        if not self.allow_group_split:
            # The helpers raise `TypeError` on invalid input and otherwise return
            # True. They are evaluated in a plain `if` instead of an `assert` so the
            # validation still runs under `python -O`; the condition also lets type
            # checkers narrow the argument types.
            if not (
                check_df(X)
                and check_ser(qid, "qid")
                and check_ser(y, "y")
                and check_ser(sample_weight, "sample_weight")
                and check_ser(base_margin, "base_margin")
            ):
                raise AssertionError("unreachable")  # pragma: no cover
            assert qid is not None and y is not None
            X_id = id(X)
            X, qid, y, sample_weight, base_margin = no_group_split(
                self.device,
                X,
                qid,
                y=y,
                sample_weight=sample_weight,
                base_margin=base_margin,
            )

            if eval_set is not None:
                if not eval_qid:
                    raise ValueError(
                        "`eval_qid` is required when `eval_set` is specified."
                    )
                new_eval_set = []
                new_eval_qid = []
                # Exactly one entry (possibly None) per evaluation set, so these
                # lists stay index-aligned with `eval_set`.
                new_sample_weight_eval_set: List[Optional[_DaskCollection]] = []
                new_base_margin_eval_set: List[Optional[_DaskCollection]] = []
                for i, (Xe, ye) in enumerate(eval_set):
                    we = sample_weight_eval_set[i] if sample_weight_eval_set else None
                    be = base_margin_eval_set[i] if base_margin_eval_set else None
                    qe = eval_qid[i]
                    if not (
                        check_df(Xe)
                        and check_ser(qe, "qid")
                        and check_ser(ye, "y")
                        and check_ser(we, "sample_weight")
                        and check_ser(be, "base_margin")
                    ):
                        raise AssertionError("unreachable")  # pragma: no cover
                    assert qe is not None and ye is not None
                    if id(Xe) != X_id:
                        Xe, qe, ye, we, be = no_group_split(
                            self.device, Xe, qe, ye, we, be
                        )
                    else:
                        # The training data is reused as an evaluation set; reuse the
                        # already-split objects instead of splitting again
                        # (presumably so it can be matched by identity downstream).
                        Xe, qe, ye, we, be = X, qid, y, sample_weight, base_margin

                    new_eval_set.append((Xe, ye))
                    new_eval_qid.append(qe)
                    new_sample_weight_eval_set.append(we)
                    new_base_margin_eval_set.append(be)

                eval_set = new_eval_set
                eval_qid = new_eval_qid
                # Collapse to None only when no evaluation set carries a value.
                sample_weight_eval_set = (
                    new_sample_weight_eval_set
                    if any(w is not None for w in new_sample_weight_eval_set)
                    else None
                )
                base_margin_eval_set = (
                    new_base_margin_eval_set
                    if any(b is not None for b in new_base_margin_eval_set)
                    else None
                )

        return self._client_sync(
            self._fit_async,
            X=X,
            y=y,
            qid=qid,
            sample_weight=sample_weight,
            base_margin=base_margin,
            eval_set=eval_set,
            eval_qid=eval_qid,
            verbose=verbose,
            xgb_model=xgb_model,
            sample_weight_eval_set=sample_weight_eval_set,
            base_margin_eval_set=base_margin_eval_set,
            feature_weights=feature_weights,
        )

    # FIXME(trivialfis): arguments differ due to additional parameters like group and
    # qid.
    fit.__doc__ = XGBRanker.fit.__doc__
@xgboost_model_doc(
    """Implementation of the Scikit-Learn API for XGBoost Random Forest Regressor.
    .. versionadded:: 1.4.0
""",
    ["model", "objective"],
    extra_parameters="""
    n_estimators : int
        Number of trees in random forest to fit.
""",
)
class DaskXGBRFRegressor(DaskXGBRegressor):
    @_deprecate_positional_args
    def __init__(
        self,
        *,
        learning_rate: Optional[float] = 1,
        subsample: Optional[float] = 0.8,
        colsample_bynode: Optional[float] = 0.8,
        reg_lambda: Optional[float] = 1e-5,
        coll_cfg: Optional[CollConfig] = None,
        **kwargs: Any,
    ) -> None:
        # Random-forest defaults, forwarded to the regular dask regressor.
        forest_defaults = {
            "learning_rate": learning_rate,
            "subsample": subsample,
            "colsample_bynode": colsample_bynode,
            "reg_lambda": reg_lambda,
        }
        super().__init__(coll_cfg=coll_cfg, **forest_defaults, **kwargs)

    def get_xgb_params(self) -> Dict[str, Any]:
        # All `n_estimators` trees are grown in parallel within a single round.
        return {**super().get_xgb_params(), "num_parallel_tree": self.n_estimators}

    def get_num_boosting_rounds(self) -> int:
        # A random forest is trained in exactly one boosting round.
        return 1

    def fit(
        self,
        X: _DataT,
        y: _DaskCollection,
        *,
        sample_weight: Optional[_DaskCollection] = None,
        base_margin: Optional[_DaskCollection] = None,
        eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
        verbose: Optional[Union[int, bool]] = True,
        xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
        sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
        base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
        feature_weights: Optional[_DaskCollection] = None,
    ) -> "DaskXGBRFRegressor":
        # Reject early stopping / callbacks incompatible with random forests first.
        _check_rf_callback(self.early_stopping_rounds, self.callbacks)
        super().fit(
            X=X,
            y=y,
            sample_weight=sample_weight,
            base_margin=base_margin,
            eval_set=eval_set,
            verbose=verbose,
            xgb_model=xgb_model,
            sample_weight_eval_set=sample_weight_eval_set,
            base_margin_eval_set=base_margin_eval_set,
            feature_weights=feature_weights,
        )
        return self
@xgboost_model_doc(
    """Implementation of the Scikit-Learn API for XGBoost Random Forest Classifier.
    .. versionadded:: 1.4.0
""",
    ["model", "objective"],
    extra_parameters="""
    n_estimators : int
        Number of trees in random forest to fit.
""",
)
class DaskXGBRFClassifier(DaskXGBClassifier):
    @_deprecate_positional_args
    def __init__(
        self,
        *,
        learning_rate: Optional[float] = 1,
        subsample: Optional[float] = 0.8,
        colsample_bynode: Optional[float] = 0.8,
        reg_lambda: Optional[float] = 1e-5,
        coll_cfg: Optional[CollConfig] = None,
        **kwargs: Any,
    ) -> None:
        # Random-forest defaults, forwarded to the regular dask classifier.
        forest_defaults = {
            "learning_rate": learning_rate,
            "subsample": subsample,
            "colsample_bynode": colsample_bynode,
            "reg_lambda": reg_lambda,
        }
        super().__init__(coll_cfg=coll_cfg, **forest_defaults, **kwargs)

    def get_xgb_params(self) -> Dict[str, Any]:
        # All `n_estimators` trees are grown in parallel within a single round.
        return {**super().get_xgb_params(), "num_parallel_tree": self.n_estimators}

    def get_num_boosting_rounds(self) -> int:
        # A random forest is trained in exactly one boosting round.
        return 1

    def fit(
        self,
        X: _DataT,
        y: _DaskCollection,
        *,
        sample_weight: Optional[_DaskCollection] = None,
        base_margin: Optional[_DaskCollection] = None,
        eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
        verbose: Optional[Union[int, bool]] = True,
        xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
        sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
        base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
        feature_weights: Optional[_DaskCollection] = None,
    ) -> "DaskXGBRFClassifier":
        # Reject early stopping / callbacks incompatible with random forests first.
        _check_rf_callback(self.early_stopping_rounds, self.callbacks)
        super().fit(
            X=X,
            y=y,
            sample_weight=sample_weight,
            base_margin=base_margin,
            eval_set=eval_set,
            verbose=verbose,
            xgb_model=xgb_model,
            sample_weight_eval_set=sample_weight_eval_set,
            base_margin_eval_set=base_margin_eval_set,
            feature_weights=feature_weights,
        )
        return self