# pylint: disable=unbalanced-tuple-unpacking, too-many-locals
"""Tests for federated learning."""
import multiprocessing
import os
import subprocess
import tempfile
import time
from typing import List, cast

from sklearn.datasets import dump_svmlight_file, load_svmlight_file
from sklearn.model_selection import train_test_split

import xgboost as xgb
import xgboost.federated
from xgboost import testing as tm

from .._typing import EvalsLog
from ..collective import _Args as CollArgs

# Certificate and key file names; ``run_federated_learning`` generates these
# with openssl when SSL is enabled.
SERVER_KEY = "server-key.pem"
SERVER_CERT = "server-cert.pem"
CLIENT_KEY = "client-key.pem"
CLIENT_CERT = "client-cert.pem"


def run_server(port: int, world_size: int, with_ssl: bool) -> None:
    """Run federated server for test."""
    if with_ssl:
        xgboost.federated.run_federated_server(
            world_size,
            port,
            server_key_path=SERVER_KEY,
            server_cert_path=SERVER_CERT,
            client_cert_path=CLIENT_CERT,
        )
    else:
        xgboost.federated.run_federated_server(world_size, port)
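

# Note: the server call above is not expected to return on its own, which is
# why ``run_federated`` below runs it in a dedicated process and terminates
# that process explicitly once all workers have joined.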


def run_worker(
    port: int, world_size: int, rank: int, with_ssl: bool, device: str
) -> None:
    """Run federated client worker for test."""
    comm_env: CollArgs = {
        "dmlc_communicator": "federated",
        "federated_server_address": f"localhost:{port}",
        "federated_world_size": world_size,
        "federated_rank": rank,
    }
    if with_ssl:
        comm_env["federated_server_cert_path"] = SERVER_CERT
        comm_env["federated_client_key_path"] = CLIENT_KEY
        comm_env["federated_client_cert_path"] = CLIENT_CERT

    cpu_count = os.cpu_count()
    assert cpu_count is not None
    n_threads = cpu_count // world_size

    # Always call this before using the distributed module.
    with xgb.collective.CommunicatorContext(**comm_env):
        # Load the data; files are not sharded in federated mode, so each
        # worker reads its own pre-split part.
        X, y = load_svmlight_file(f"agaricus.txt-{rank}.train")
        dtrain = xgb.DMatrix(X, y)
        X, y = load_svmlight_file(f"agaricus.txt-{rank}.test")
        dtest = xgb.DMatrix(X, y)

        # Specify parameters via a dict; the definitions are the same as in
        # the C++ version.
        param = {
            "max_depth": 2,
            "eta": 1,
            "objective": "binary:logistic",
            "nthread": n_threads,
            "tree_method": "hist",
            "device": device,
        }

        # Specify the validation sets to watch performance.
        watchlist = [(dtest, "eval"), (dtrain, "train")]
        num_round = 20

        # Run training; all the features of the training API are available.
        results: EvalsLog = {}
        bst = xgb.train(
            param,
            dtrain,
            num_round,
            evals=watchlist,
            early_stopping_rounds=2,
            evals_result=results,
        )
        assert tm.non_increasing(cast(List[float], results["train"]["logloss"]))
        assert tm.non_increasing(cast(List[float], results["eval"]["logloss"]))

        # Save the model; only rank 0 writes it out.
        if xgb.collective.get_rank() == 0:
            with tempfile.TemporaryDirectory() as tmpdir:
                bst.save_model(os.path.join(tmpdir, "model.json"))

        xgb.collective.communicator_print("Finished training\n")
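

# A minimal sketch (illustrative, not part of the upstream module) of what the
# communicator context provides: inside the ``with`` block, collective calls
# reflect the federated configuration passed via ``comm_env``, e.g.:
#
#   with xgb.collective.CommunicatorContext(**comm_env):
#       assert xgb.collective.get_world_size() == world_size
#       assert xgb.collective.get_rank() == rank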


def run_federated(world_size: int, with_ssl: bool, use_gpu: bool) -> None:
    """Launcher for clients and the server."""
    port = 9091

    server = multiprocessing.Process(
        target=run_server, args=(port, world_size, with_ssl)
    )
    server.start()
    # Give the server a moment to come up before checking that it is alive.
    time.sleep(1)
    if not server.is_alive():
        raise ValueError("Error starting Federated Learning server")

    workers = []
    for rank in range(world_size):
        device = f"cuda:{rank}" if use_gpu else "cpu"
        worker = multiprocessing.Process(
            target=run_worker, args=(port, world_size, rank, with_ssl, device)
        )
        workers.append(worker)
        worker.start()
    for worker in workers:
        worker.join()
    server.terminate()
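

# The workers' exit codes are not checked above; a stricter variant (a sketch,
# not upstream behaviour) could assert that each worker succeeded:
#
#   for worker in workers:
#       worker.join()
#       assert worker.exitcode == 0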


def run_federated_learning(with_ssl: bool, use_gpu: bool, test_path: str) -> None:
    """Run federated learning tests."""
    n_workers = 2

    if with_ssl:
        command = "openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout {part}-key.pem -out {part}-cert.pem -subj /C=US/CN=localhost"  # pylint: disable=line-too-long
        server_key = command.format(part="server").split()
        subprocess.check_call(server_key)
        client_key = command.format(part="client").split()
        subprocess.check_call(client_key)

    train_path = os.path.join(tm.data_dir(test_path), "agaricus.txt.train")
    test_path = os.path.join(tm.data_dir(test_path), "agaricus.txt.test")
    X_train, y_train = load_svmlight_file(train_path)
    X_test, y_test = load_svmlight_file(test_path)

    # Split the data in half so each of the two workers gets its own shard.
    X0, X1, y0, y1 = train_test_split(X_train, y_train, test_size=0.5)
    X0_valid, X1_valid, y0_valid, y1_valid = train_test_split(
        X_test, y_test, test_size=0.5
    )
    dump_svmlight_file(X0, y0, "agaricus.txt-0.train")
    dump_svmlight_file(X0_valid, y0_valid, "agaricus.txt-0.test")
    dump_svmlight_file(X1, y1, "agaricus.txt-1.train")
    dump_svmlight_file(X1_valid, y1_valid, "agaricus.txt-1.test")

    run_federated(world_size=n_workers, with_ssl=with_ssl, use_gpu=use_gpu)
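

# A hypothetical usage sketch (not part of this module): a pytest-style test
# would typically pass its own ``__file__`` so that ``tm.data_dir`` can locate
# the agaricus demo data relative to the test file, e.g.:
#
#   def test_federated_learning_with_ssl() -> None:
#       run_federated_learning(with_ssl=True, use_gpu=False, test_path=__file__)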