36 lines
1013 B
Python
36 lines
1013 B
Python
"""Helpers for test code."""
|
|
|
|
from typing import Any, Literal, TypeAlias
|
|
|
|
import numpy as np
|
|
|
|
from ..compat import import_cupy
|
|
from ..core import DMatrix
|
|
from ..data import _is_cupy_alike
|
|
|
|
# Device identifier accepted by the helpers in this module: "cpu" or "cuda".
Device: TypeAlias = Literal["cpu", "cuda"]
|
|
|
|
|
|
def assert_allclose(
    device: Device, a: Any, b: Any, *, rtol: float = 1e-7, atol: float = 0
) -> None:
    """Run the ``assert_allclose`` check from numpy or cupy, as appropriate.

    The numpy implementation is used only when ``device`` is ``"cpu"`` and
    neither ``a`` nor ``b`` is reported as cupy-like by ``_is_cupy_alike``.
    In every other case cupy is imported through ``import_cupy`` and its
    ``testing.assert_allclose`` performs the comparison.

    Parameters
    ----------
    device :
        Either ``"cpu"`` or ``"cuda"``.
    a, b :
        The two values to compare.
    rtol :
        Relative tolerance forwarded to the underlying check.
    atol :
        Absolute tolerance forwarded to the underlying check.
    """
    # Same decision as "cpu and not cupy(a) and not cupy(b)", inverted so the
    # cupy path is handled first; evaluation order of the checks is unchanged.
    use_cupy = device != "cpu" or _is_cupy_alike(a) or _is_cupy_alike(b)
    if use_cupy:
        cupy_module = import_cupy()
        cupy_module.testing.assert_allclose(a, b, atol=atol, rtol=rtol)
        return

    np.testing.assert_allclose(a, b, atol=atol, rtol=rtol)
|
|
|
|
|
|
def predictor_equal(lhs: DMatrix, rhs: DMatrix) -> bool:
    """Check whether two DMatrices contain the same predictors.

    The predictors of both inputs are fetched with ``get_data()`` and
    compared through their shape and their sparse components (``data``,
    ``indices`` and ``indptr``). Nothing is asserted; the outcome is returned.

    Parameters
    ----------
    lhs, rhs :
        The two matrices to compare.

    Returns
    -------
    ``True`` when the shape and all three component arrays match exactly,
    ``False`` otherwise.
    """
    lcsr = lhs.get_data()
    rcsr = rhs.get_data()
    # ``indptr`` only pins down the number of rows. Without the shape check,
    # two matrices that differ solely in trailing empty columns would be
    # reported as equal.
    return (
        lcsr.shape == rcsr.shape
        and np.array_equal(lcsr.data, rcsr.data)
        and np.array_equal(lcsr.indices, rcsr.indices)
        and np.array_equal(lcsr.indptr, rcsr.indptr)
    )
|