Files
MLPproject/.venv/lib/python3.12/site-packages/xgboost/dask/utils.py
2025-10-23 15:44:32 +02:00

120 lines
3.5 KiB
Python

# pylint: disable=invalid-name
"""Utilities for the XGBoost Dask interface."""
import logging
import warnings
from functools import cache as fcache
from typing import Any, Dict, Optional, Tuple
import dask
import distributed
from packaging.version import Version
from packaging.version import parse as parse_version
from ..collective import Config
LOGGER = logging.getLogger("[xgboost.dask]")
def get_n_threads(local_param: Dict[str, Any], worker: "distributed.Worker") -> int:
"""Get the number of threads from a worker and the user-supplied parameters."""
# dask worker nthreads, "state" is available in 2022.6.1
dwnt = worker.state.nthreads if hasattr(worker, "state") else worker.nthreads
n_threads = None
for p in ["nthread", "n_jobs"]:
if local_param.get(p, None) is not None and local_param.get(p, dwnt) != dwnt:
LOGGER.info("Overriding `nthreads` defined in dask worker.")
n_threads = local_param[p]
break
if n_threads == 0 or n_threads is None:
n_threads = dwnt
return n_threads
def get_address_from_user(
dconfig: Optional[Dict[str, Any]], coll_cfg: Config
) -> Tuple[Optional[str], int]:
"""Get the tracker address from the optional user configuration.
Parameters
----------
dconfig :
Dask global configuration.
coll_cfg :
Collective configuration.
Returns
-------
The IP address along with the port number.
"""
valid_config = ["scheduler_address"]
host_ip = None
port = 0
if dconfig is not None:
for k in dconfig:
if k not in valid_config:
raise ValueError(f"Unknown configuration: {k}")
warnings.warn(
(
"Use `coll_cfg` instead of the Dask global configuration store"
f" for the XGBoost tracker configuration: {k}."
),
FutureWarning,
)
else:
dconfig = {}
host_ip = dconfig.get("scheduler_address", None)
if host_ip is not None and host_ip.startswith("[") and host_ip.endswith("]"):
# convert dask bracket format to proper IPv6 address.
host_ip = host_ip[1:-1]
if host_ip is not None:
try:
host_ip, port = distributed.comm.get_address_host_port(host_ip)
except ValueError:
pass
if coll_cfg is None:
coll_cfg = Config()
if coll_cfg.tracker_host_ip is not None:
if host_ip is not None and coll_cfg.tracker_host_ip != host_ip:
raise ValueError(
"Conflicting host IP addresses from the dask configuration and the "
f"collective configuration: {host_ip} v.s. {coll_cfg.tracker_host_ip}."
)
host_ip = coll_cfg.tracker_host_ip
if coll_cfg.tracker_port is not None:
if (
port != 0
and port is not None
and coll_cfg.tracker_port != 0
and port != coll_cfg.tracker_port
):
raise ValueError(
"Conflicting ports from the dask configuration and the "
f"collective configuration: {port} v.s. {coll_cfg.tracker_port}."
)
port = coll_cfg.tracker_port
return host_ip, port
@fcache
def _DASK_VERSION() -> Version:
return parse_version(dask.__version__)
@fcache
def _DASK_2024_12_1() -> bool:
return _DASK_VERSION() >= parse_version("2024.12.1")
@fcache
def _DASK_2025_3_0() -> bool:
return _DASK_VERSION() >= parse_version("2025.3.0")