dask-cuda 25.6.0__py3-none-any.whl → 25.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dask_cuda/GIT_COMMIT +1 -1
- dask_cuda/VERSION +1 -1
- dask_cuda/benchmarks/common.py +4 -1
- dask_cuda/benchmarks/local_cudf_groupby.py +3 -0
- dask_cuda/benchmarks/local_cudf_merge.py +4 -1
- dask_cuda/benchmarks/local_cudf_shuffle.py +4 -1
- dask_cuda/benchmarks/local_cupy.py +3 -0
- dask_cuda/benchmarks/local_cupy_map_overlap.py +3 -0
- dask_cuda/benchmarks/utils.py +6 -3
- dask_cuda/cli.py +21 -15
- dask_cuda/cuda_worker.py +28 -58
- dask_cuda/device_host_file.py +31 -15
- dask_cuda/disk_io.py +7 -4
- dask_cuda/explicit_comms/comms.py +11 -7
- dask_cuda/explicit_comms/dataframe/shuffle.py +23 -23
- dask_cuda/get_device_memory_objects.py +4 -7
- dask_cuda/initialize.py +149 -94
- dask_cuda/local_cuda_cluster.py +52 -70
- dask_cuda/plugins.py +17 -16
- dask_cuda/proxify_device_objects.py +12 -10
- dask_cuda/proxify_host_file.py +30 -27
- dask_cuda/proxy_object.py +20 -17
- dask_cuda/tests/conftest.py +41 -0
- dask_cuda/tests/test_cudf_builtin_spilling.py +3 -1
- dask_cuda/tests/test_dask_cuda_worker.py +109 -25
- dask_cuda/tests/test_dask_setup.py +193 -0
- dask_cuda/tests/test_dgx.py +20 -44
- dask_cuda/tests/test_explicit_comms.py +31 -12
- dask_cuda/tests/test_from_array.py +4 -6
- dask_cuda/tests/test_initialize.py +233 -65
- dask_cuda/tests/test_local_cuda_cluster.py +129 -68
- dask_cuda/tests/test_proxify_host_file.py +28 -7
- dask_cuda/tests/test_proxy.py +15 -13
- dask_cuda/tests/test_spill.py +10 -3
- dask_cuda/tests/test_utils.py +100 -29
- dask_cuda/tests/test_worker_spec.py +6 -0
- dask_cuda/utils.py +211 -42
- dask_cuda/utils_test.py +10 -7
- dask_cuda/worker_common.py +196 -0
- dask_cuda/worker_spec.py +6 -1
- {dask_cuda-25.6.0.dist-info → dask_cuda-25.10.0.dist-info}/METADATA +11 -4
- dask_cuda-25.10.0.dist-info/RECORD +63 -0
- dask_cuda-25.10.0.dist-info/top_level.txt +6 -0
- shared-actions/check_nightly_success/check-nightly-success/check.py +148 -0
- shared-actions/telemetry-impls/summarize/bump_time.py +54 -0
- shared-actions/telemetry-impls/summarize/send_trace.py +409 -0
- dask_cuda-25.6.0.dist-info/RECORD +0 -57
- dask_cuda-25.6.0.dist-info/top_level.txt +0 -4
- {dask_cuda-25.6.0.dist-info → dask_cuda-25.10.0.dist-info}/WHEEL +0 -0
- {dask_cuda-25.6.0.dist-info → dask_cuda-25.10.0.dist-info}/entry_points.txt +0 -0
- {dask_cuda-25.6.0.dist-info → dask_cuda-25.10.0.dist-info}/licenses/LICENSE +0 -0
dask_cuda/tests/test_dask_setup.py
ADDED

@@ -0,0 +1,193 @@
+# SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+import os
+import time
+from contextlib import contextmanager
+from unittest.mock import Mock, patch
+
+import pytest
+
+from distributed import Client
+from distributed.utils import open_port
+from distributed.utils_test import popen
+
+from dask_cuda.initialize import dask_setup
+from dask_cuda.utils import wait_workers
+
+
+def test_dask_setup_function_with_mock_worker():
+    """Test the dask_setup function directly with mock worker."""
+    # Create a mock worker object
+    mock_worker = Mock()
+
+    with patch("dask_cuda.initialize._create_cuda_context") as mock_create_context:
+        # Test with create_cuda_context=True
+        # Call the underlying function directly (the Click decorator wraps the real
+        # function)
+        dask_setup.callback(
+            worker=mock_worker,
+            create_cuda_context=True,
+        )
+
+        mock_create_context.assert_called_once_with()
+
+        mock_create_context.reset_mock()
+
+        # Test with create_cuda_context=False
+        dask_setup.callback(
+            worker=mock_worker,
+            create_cuda_context=False,
+        )
+
+        mock_create_context.assert_not_called()
+
+
+@contextmanager
+def start_dask_scheduler(protocol: str, max_attempts: int = 5, timeout: int = 10):
+    """Start Dask scheduler in subprocess.
+
+    Attempts to start a Dask scheduler in subprocess, if the port is not available
+    retry on a different port up to a maximum of `max_attempts` attempts. The stdout
+    and stderr of the process is read to determine whether the scheduler failed to
+    bind to port or succeeded, and ensures no more than `timeout` seconds are awaited
+    for between reads.
+
+    This is primarily useful because UCX does not release TCP ports immediately. A
+    workaround without the need for this function is setting `UCX_TCP_CM_REUSEADDR=y`,
+    but that requires to be explicitly set when running tests, and that is not very
+    friendly.
+
+    Parameters
+    ----------
+    protocol: str
+        Communication protocol to use.
+    max_attempts: int
+        Maximum attempts to try to open scheduler.
+    timeout: int
+        Time to wait while reading stdout/stderr of subprocess.
+    """
+    port = open_port()
+    for _ in range(max_attempts):
+        with popen(
+            [
+                "dask",
+                "scheduler",
+                "--no-dashboard",
+                "--protocol",
+                protocol,
+                "--port",
+                str(port),
+            ],
+            capture_output=True,  # Capture stdout and stderr
+        ) as scheduler_process:
+            # Check if the scheduler process started successfully by streaming output
+            try:
+                start_time = time.monotonic()
+                while True:
+                    if time.monotonic() - start_time > timeout:
+                        raise TimeoutError("Timeout while waiting for scheduler output")
+
+                    line = scheduler_process.stdout.readline()
+                    if not line:
+                        break  # End of output
+                    print(
+                        line.decode(), end=""
+                    )  # Since capture_output=True, print the line here
+                    if b"Scheduler at:" in line:
+                        # Scheduler is now listening
+                        break
+                    elif b"UCXXBusyError" in line:
+                        raise Exception("UCXXBusyError detected in scheduler output")
+            except Exception:
+                port += 1
+            else:
+                yield scheduler_process, port
+                return
+    else:
+        pytest.fail(f"Failed to start dask scheduler after {max_attempts} attempts.")
+
+
+@pytest.mark.timeout(30)
+@patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"})
+@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
+def test_dask_cuda_worker_cli_integration(protocol, tmp_path):
+    """Test that dask cuda worker CLI correctly passes arguments to dask_setup.
+
+    Verifies the end-to-end integration where the CLI tool actually launches and calls
+    dask_setup with correct args.
+    """
+
+    # Use pytest's tmp_path for file management
+    capture_file_path = tmp_path / "dask_setup_integration_test.json"
+    preload_file = tmp_path / "preload_capture.py"
+
+    # Write the preload script to tmp_path
+    preload_file.write_text(
+        f'''
+import json
+import os
+
+def capture_dask_setup_call(worker, create_cuda_context):
+    """Capture dask_setup arguments and write to file."""
+    result = {{
+        'worker_protocol': getattr(worker, '_protocol', 'unknown'),
+        'create_cuda_context': create_cuda_context,
+        'test_success': True
+    }}
+
+    # Write immediately to ensure it gets captured
+    with open(r"{capture_file_path}", 'w') as f:
+        json.dump(result, f)
+
+# Patch dask_setup callback
+from dask_cuda.initialize import dask_setup
+dask_setup.callback = capture_dask_setup_call
+'''
+    )
+
+    with start_dask_scheduler(protocol=protocol) as scheduler_process_port:
+        scheduler_process, scheduler_port = scheduler_process_port
+        sched_addr = f"{protocol}://127.0.0.1:{scheduler_port}"
+        print(f"{sched_addr=}", flush=True)
+
+        # Build dask cuda worker args
+        dask_cuda_worker_args = [
+            "dask",
+            "cuda",
+            "worker",
+            sched_addr,
+            "--host",
+            "127.0.0.1",
+            "--no-dashboard",
+            "--preload",
+            str(preload_file),
+            "--death-timeout",
+            "10",
+        ]
+
+        with popen(dask_cuda_worker_args):
+            # Wait and check for worker connection
+            with Client(sched_addr) as client:
+                assert wait_workers(client, n_gpus=1)
+
+        # Check if dask_setup was called and captured correctly
+        if capture_file_path.exists():
+            with open(capture_file_path, "r") as cf:
+                captured_args = json.load(cf)
+
+            # Verify the critical arguments were passed correctly
+            assert (
+                captured_args["create_cuda_context"] is True
+            ), "create_cuda_context should be True"
+
+            # Verify worker has a protocol set
+            assert (
+                captured_args["worker_protocol"] == protocol
+            ), "Worker should have a protocol"
+        else:
+            pytest.fail(
+                "capture file not found: dask_setup was not called or "
+                "failed to write to file"
+            )
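The new test module above hinges on Distributed's `--preload` hook: the worker executes the given script at startup, so reassigning `dask_cuda.initialize.dask_setup.callback` there lets a test observe exactly what arguments the CLI passes in. A minimal standalone sketch of the same pattern (the file name and output path here are illustrative, not from the diff):

# preload_log.py -- illustrative preload script; start a worker with:
#   dask cuda worker <scheduler-address> --preload preload_log.py
import json

from dask_cuda.initialize import dask_setup

_original_callback = dask_setup.callback


def _logging_dask_setup(worker, create_cuda_context):
    # Record the arguments, then delegate to the real setup logic.
    with open("/tmp/dask_setup_args.json", "w") as f:
        json.dump({"create_cuda_context": create_cuda_context}, f)
    return _original_callback(worker=worker, create_cuda_context=create_cuda_context)


dask_setup.callback = _logging_dask_setup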
dask_cuda/tests/test_dgx.py
CHANGED

@@ -1,3 +1,6 @@
+# SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
 import multiprocessing as mp
 import os
 from enum import Enum, auto
@@ -15,10 +18,6 @@ mp = mp.get_context("spawn")  # type: ignore
 psutil = pytest.importorskip("psutil")
 
 
-def _is_ucx_116(ucp):
-    return ucp.get_ucx_version()[:2] == (1, 16)
-
-
 class DGXVersion(Enum):
     DGX_1 = auto()
     DGX_2 = auto()
@@ -77,20 +76,17 @@ def test_default():
     assert not p.exitcode
 
 
-def _test_tcp_over_ucx(protocol):
-    if protocol == "ucx":
-        ucp = pytest.importorskip("ucp")
-    elif protocol == "ucxx":
-        ucp = pytest.importorskip("ucxx")
+def _test_tcp_over_ucx():
+    ucxx = pytest.importorskip("ucxx")
 
-    with LocalCUDACluster(protocol=protocol, enable_tcp_over_ucx=True) as cluster:
+    with LocalCUDACluster(protocol="ucx", enable_tcp_over_ucx=True) as cluster:
         with Client(cluster) as client:
             res = da.from_array(numpy.arange(10000), chunks=(1000,))
             res = res.sum().compute()
             assert res == 49995000
 
             def check_ucx_options():
-                conf = ucp.get_config()
+                conf = ucxx.get_config()
                 assert "TLS" in conf
                 assert "tcp" in conf["TLS"]
                 assert "cuda_copy" in conf["TLS"]
@@ -100,19 +96,10 @@ def _test_tcp_over_ucx(protocol):
             assert all(client.run(check_ucx_options).values())
 
 
-@pytest.mark.parametrize(
-    "protocol",
-    ["ucx", "ucxx"],
-)
-def test_tcp_over_ucx(protocol):
-    if protocol == "ucx":
-        ucp = pytest.importorskip("ucp")
-    elif protocol == "ucxx":
-        ucp = pytest.importorskip("ucxx")
-    if _is_ucx_116(ucp):
-        pytest.skip("https://github.com/rapidsai/ucx-py/issues/1037")
-
-    p = mp.Process(target=_test_tcp_over_ucx, args=(protocol,))
+def test_tcp_over_ucx():
+    pytest.importorskip("distributed_ucxx")
+
+    p = mp.Process(target=_test_tcp_over_ucx)
     p.start()
     p.join()
     assert not p.exitcode
@@ -134,25 +121,22 @@ def test_tcp_only():
 
 
 def _test_ucx_infiniband_nvlink(
-    skip_queue, protocol, enable_infiniband, enable_nvlink, enable_rdmacm
+    skip_queue, enable_infiniband, enable_nvlink, enable_rdmacm
 ):
+    ucxx = pytest.importorskip("ucxx")
     cupy = pytest.importorskip("cupy")
-    if protocol == "ucx":
-        ucp = pytest.importorskip("ucp")
-    elif protocol == "ucxx":
-        ucp = pytest.importorskip("ucxx")
 
     if enable_infiniband and not any(
-        [at.startswith("rc") for at in ucp.get_active_transports()]
+        [at.startswith("rc") for at in ucxx.get_active_transports()]
     ):
         skip_queue.put("No support available for 'rc' transport in UCX")
         return
     else:
         skip_queue.put("ok")
 
-    # `ucp.get_active_transports()` call above initializes UCX, we must reset it
+    # `ucxx.get_active_transports()` call above initializes UCX, we must reset it
     # so that Dask doesn't try to initialize it again and raise an exception.
-    ucp.reset()
+    ucxx.reset()
 
     if enable_infiniband is None and enable_nvlink is None and enable_rdmacm is None:
         enable_tcp_over_ucx = None
@@ -168,7 +152,6 @@ def _test_ucx_infiniband_nvlink(
         cm_tls_priority = ["tcp"]
 
     initialize(
-        protocol=protocol,
         enable_tcp_over_ucx=enable_tcp_over_ucx,
         enable_infiniband=enable_infiniband,
         enable_nvlink=enable_nvlink,
@@ -176,7 +159,7 @@ def _test_ucx_infiniband_nvlink(
     )
 
     with LocalCUDACluster(
-        protocol=protocol,
+        protocol="ucx",
         interface="ib0",
         enable_tcp_over_ucx=enable_tcp_over_ucx,
         enable_infiniband=enable_infiniband,
@@ -190,7 +173,7 @@ def _test_ucx_infiniband_nvlink(
             assert res == 49995000
 
             def check_ucx_options():
-                conf = ucp.get_config()
+                conf = ucxx.get_config()
                 assert "TLS" in conf
                 assert all(t in conf["TLS"] for t in cm_tls)
                 assert all(p in conf["SOCKADDR_TLS_PRIORITY"] for p in cm_tls_priority)
@@ -206,7 +189,6 @@ def _test_ucx_infiniband_nvlink(
             assert all(client.run(check_ucx_options).values())
 
 
-@pytest.mark.parametrize("protocol", ["ucx", "ucxx"])
 @pytest.mark.parametrize(
     "params",
     [
@@ -221,13 +203,8 @@ def _test_ucx_infiniband_nvlink(
     _get_dgx_version() == DGXVersion.DGX_A100,
     reason="Automatic InfiniBand device detection Unsupported for %s" % _get_dgx_name(),
 )
-def test_ucx_infiniband_nvlink(protocol, params):
-    if protocol == "ucx":
-        ucp = pytest.importorskip("ucp")
-    elif protocol == "ucxx":
-        ucp = pytest.importorskip("ucxx")
-    if _is_ucx_116(ucp) and params["enable_infiniband"] is False:
-        pytest.skip("https://github.com/rapidsai/ucx-py/issues/1037")
+def test_ucx_infiniband_nvlink(params):
+    pytest.importorskip("distributed_ucxx")
 
     skip_queue = mp.Queue()
 
@@ -235,7 +212,6 @@ def test_ucx_infiniband_nvlink(protocol, params):
         target=_test_ucx_infiniband_nvlink,
         args=(
            skip_queue,
-            protocol,
             params["enable_infiniband"],
             params["enable_nvlink"],
             params["enable_rdmacm"],
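The through-line of the test_dgx.py changes is the migration from ucx-py (`ucp`) to UCXX: the "ucx"/"ucxx" protocol parametrization is gone, `pytest.importorskip` now targets `ucxx` and `distributed_ucxx`, and the ucx-py 1.16 skip logic (`_is_ucx_116`) is dropped along with the `protocol` argument to `initialize()`. A minimal sketch of the cluster setup the rewritten tests exercise, assuming a GPU host with ucxx and distributed-ucxx installed:

import dask.array as da
import numpy
from distributed import Client

from dask_cuda import LocalCUDACluster

if __name__ == "__main__":
    # protocol="ucx" now routes through distributed-ucxx rather than ucx-py
    with LocalCUDACluster(protocol="ucx", enable_tcp_over_ucx=True) as cluster:
        with Client(cluster) as client:
            x = da.from_array(numpy.arange(10000), chunks=(1000,))
            assert x.sum().compute() == 49995000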
dask_cuda/tests/test_explicit_comms.py
CHANGED

@@ -1,4 +1,5 @@
-# Copyright (c) 2021-2025 NVIDIA CORPORATION.
+# SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
 
 import asyncio
 import multiprocessing as mp
@@ -28,7 +29,6 @@ from dask_cuda.explicit_comms.dataframe.shuffle import (
 from dask_cuda.utils_test import IncreasedCloseTimeoutNanny
 
 mp = mp.get_context("spawn")  # type: ignore
-ucp = pytest.importorskip("ucp")
 
 
 # Notice, all of the following tests is executed in a new process such
@@ -53,8 +53,11 @@ def _test_local_cluster(protocol):
     assert sum(c.run(my_rank, 0)) == sum(range(4))
 
 
-@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
+@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
 def test_local_cluster(protocol):
+    if protocol.startswith("ucx"):
+        pytest.importorskip("distributed_ucxx")
+
     p = mp.Process(target=_test_local_cluster, args=(protocol,))
     p.start()
     p.join()
@@ -97,7 +100,7 @@ def test_dataframe_merge_empty_partitions():
 
 
 def check_partitions(df, npartitions):
-    """Check that all values in `df` hashes to the same"""
+    """Check that all values in ``df`` hashes to the same"""
     dtypes = {}
     for col, dtype in df.dtypes.items():
         if pd.api.types.is_numeric_dtype(dtype):
@@ -199,11 +202,13 @@ def _test_dataframe_shuffle(backend, protocol, n_workers, _partitions):
 
 @pytest.mark.parametrize("nworkers", [1, 2, 3])
 @pytest.mark.parametrize("backend", ["pandas", "cudf"])
-@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
+@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
 @pytest.mark.parametrize("_partitions", [True, False])
 def test_dataframe_shuffle(backend, protocol, nworkers, _partitions):
     if backend == "cudf":
         pytest.importorskip("cudf")
+    if protocol.startswith("ucx"):
+        pytest.importorskip("distributed_ucxx")
 
     p = mp.Process(
         target=_test_dataframe_shuffle, args=(backend, protocol, nworkers, _partitions)
@@ -320,10 +325,13 @@ def _test_dataframe_shuffle_merge(backend, protocol, n_workers):
 
 @pytest.mark.parametrize("nworkers", [1, 2, 4])
 @pytest.mark.parametrize("backend", ["pandas", "cudf"])
-@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
+@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
 def test_dataframe_shuffle_merge(backend, protocol, nworkers):
     if backend == "cudf":
         pytest.importorskip("cudf")
+    if protocol.startswith("ucx"):
+        pytest.importorskip("distributed_ucxx")
+
     p = mp.Process(
         target=_test_dataframe_shuffle_merge, args=(backend, protocol, nworkers)
     )
@@ -357,9 +365,14 @@ def _test_jit_unspill(protocol):
     assert_eq(got, expected)
 
 
-@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
+@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
+@pytest.mark.skip_if_no_device_memory(
+    "JIT-Unspill not supported in devices without dedicated memory resource"
+)
 def test_jit_unspill(protocol):
     pytest.importorskip("cudf")
+    if protocol.startswith("ucx"):
+        pytest.importorskip("distributed_ucxx")
 
     p = mp.Process(target=_test_jit_unspill, args=(protocol,))
     p.start()
@@ -384,7 +397,7 @@ def _test_lock_workers(scheduler_address, ranks):
 
 def test_lock_workers():
     """
-    Testing `run(...,lock_workers=True)` by spawning 30 runs with overlapping
+    Testing ``run(...,lock_workers=True)`` by spawning 30 runs with overlapping
     and non-overlapping worker sets.
     """
     try:
@@ -423,7 +436,9 @@ def test_create_destroy_create():
     with LocalCluster(n_workers=1) as cluster:
         with Client(cluster) as client:
             context = comms.default_comms()
-            scheduler_addresses_old = list(client.scheduler_info()["workers"].keys())
+            scheduler_addresses_old = list(
+                client.scheduler_info(n_workers=-1)["workers"].keys()
+            )
             comms_addresses_old = list(comms.default_comms().worker_addresses)
             assert comms.default_comms() is context
             assert len(comms._comms_cache) == 1
@@ -444,7 +459,9 @@ def test_create_destroy_create():
     # because we referenced the old cluster's addresses.
     with LocalCluster(n_workers=1) as cluster:
         with Client(cluster) as client:
-            scheduler_addresses_new = list(client.scheduler_info()["workers"].keys())
+            scheduler_addresses_new = list(
+                client.scheduler_info(n_workers=-1)["workers"].keys()
+            )
             comms_addresses_new = list(comms.default_comms().worker_addresses)
 
             assert scheduler_addresses_new == comms_addresses_new
@@ -485,7 +502,8 @@ def test_scaled_cluster_gets_new_comms_context():
             "n_workers": 2,
         }
         expected_1 = {
-            k: expected_values for k in client.scheduler_info()["workers"]
+            k: expected_values
+            for k in client.scheduler_info(n_workers=-1)["workers"]
         }
         assert result_1 == expected_1
@@ -513,7 +531,8 @@ def test_scaled_cluster_gets_new_comms_context():
             "n_workers": 3,
         }
         expected_2 = {
-            k: expected_values for k in client.scheduler_info()["workers"]
+            k: expected_values
+            for k in client.scheduler_info(n_workers=-1)["workers"]
         }
         assert result_2 == expected_2
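Besides the same ucx-py to UCXX migration, test_explicit_comms.py switches every `client.scheduler_info()` call to `client.scheduler_info(n_workers=-1)`: recent Distributed versions cap how many workers `scheduler_info()` reports by default, so enumerating all worker addresses requires asking for them explicitly. A small illustrative sketch (not code from the diff):

from distributed import Client, LocalCluster

if __name__ == "__main__":
    with LocalCluster(n_workers=3) as cluster, Client(cluster) as client:
        # n_workers=-1 asks the scheduler for every worker, not a capped sample
        workers = list(client.scheduler_info(n_workers=-1)["workers"].keys())
        assert len(workers) == 3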
dask_cuda/tests/test_from_array.py
CHANGED

@@ -1,3 +1,6 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
 import pytest
 
 import dask.array as da
@@ -8,13 +11,8 @@ from dask_cuda import LocalCUDACluster
 cupy = pytest.importorskip("cupy")
 
 
-@pytest.mark.parametrize("protocol", ["ucx", "ucxx", "tcp"])
+@pytest.mark.parametrize("protocol", ["ucx", "tcp"])
 def test_ucx_from_array(protocol):
-    if protocol == "ucx":
-        pytest.importorskip("ucp")
-    elif protocol == "ucxx":
-        pytest.importorskip("ucxx")
-
     N = 10_000
     with LocalCUDACluster(protocol=protocol) as cluster:
         with Client(cluster):