dask-cuda 25.8.0__py3-none-any.whl → 25.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dask_cuda/GIT_COMMIT +1 -1
- dask_cuda/VERSION +1 -1
- dask_cuda/benchmarks/local_cudf_groupby.py +1 -1
- dask_cuda/benchmarks/local_cudf_merge.py +1 -1
- dask_cuda/benchmarks/local_cudf_shuffle.py +1 -1
- dask_cuda/benchmarks/local_cupy.py +1 -1
- dask_cuda/benchmarks/local_cupy_map_overlap.py +1 -1
- dask_cuda/benchmarks/utils.py +1 -1
- dask_cuda/cuda_worker.py +1 -1
- dask_cuda/get_device_memory_objects.py +1 -4
- dask_cuda/initialize.py +140 -121
- dask_cuda/local_cuda_cluster.py +10 -25
- dask_cuda/tests/test_cudf_builtin_spilling.py +3 -1
- dask_cuda/tests/test_dask_setup.py +193 -0
- dask_cuda/tests/test_dgx.py +16 -32
- dask_cuda/tests/test_explicit_comms.py +11 -10
- dask_cuda/tests/test_from_array.py +1 -5
- dask_cuda/tests/test_initialize.py +230 -41
- dask_cuda/tests/test_local_cuda_cluster.py +16 -62
- dask_cuda/tests/test_proxify_host_file.py +9 -4
- dask_cuda/tests/test_proxy.py +8 -8
- dask_cuda/tests/test_spill.py +3 -3
- dask_cuda/tests/test_utils.py +8 -23
- dask_cuda/tests/test_worker_spec.py +5 -2
- dask_cuda/utils.py +12 -66
- dask_cuda/utils_test.py +0 -13
- dask_cuda/worker_spec.py +7 -9
- {dask_cuda-25.8.0.dist-info → dask_cuda-25.10.0.dist-info}/METADATA +11 -4
- dask_cuda-25.10.0.dist-info/RECORD +63 -0
- shared-actions/check_nightly_success/check-nightly-success/check.py +1 -1
- dask_cuda/tests/test_rdd_ucx.py +0 -160
- dask_cuda-25.8.0.dist-info/RECORD +0 -63
- {dask_cuda-25.8.0.dist-info → dask_cuda-25.10.0.dist-info}/WHEEL +0 -0
- {dask_cuda-25.8.0.dist-info → dask_cuda-25.10.0.dist-info}/entry_points.txt +0 -0
- {dask_cuda-25.8.0.dist-info → dask_cuda-25.10.0.dist-info}/licenses/LICENSE +0 -0
- {dask_cuda-25.8.0.dist-info → dask_cuda-25.10.0.dist-info}/top_level.txt +0 -0
dask_cuda/tests/test_dask_setup.py
ADDED

@@ -0,0 +1,193 @@
+# SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+import os
+import time
+from contextlib import contextmanager
+from unittest.mock import Mock, patch
+
+import pytest
+
+from distributed import Client
+from distributed.utils import open_port
+from distributed.utils_test import popen
+
+from dask_cuda.initialize import dask_setup
+from dask_cuda.utils import wait_workers
+
+
+def test_dask_setup_function_with_mock_worker():
+    """Test the dask_setup function directly with mock worker."""
+    # Create a mock worker object
+    mock_worker = Mock()
+
+    with patch("dask_cuda.initialize._create_cuda_context") as mock_create_context:
+        # Test with create_cuda_context=True
+        # Call the underlying function directly (the Click decorator wraps the real
+        # function)
+        dask_setup.callback(
+            worker=mock_worker,
+            create_cuda_context=True,
+        )
+
+        mock_create_context.assert_called_once_with()
+
+        mock_create_context.reset_mock()
+
+        # Test with create_cuda_context=False
+        dask_setup.callback(
+            worker=mock_worker,
+            create_cuda_context=False,
+        )
+
+        mock_create_context.assert_not_called()
+
+
+@contextmanager
+def start_dask_scheduler(protocol: str, max_attempts: int = 5, timeout: int = 10):
+    """Start Dask scheduler in subprocess.
+
+    Attempts to start a Dask scheduler in subprocess, if the port is not available
+    retry on a different port up to a maximum of `max_attempts` attempts. The stdout
+    and stderr of the process is read to determine whether the scheduler failed to
+    bind to port or succeeded, and ensures no more than `timeout` seconds are awaited
+    for between reads.
+
+    This is primarily useful because UCX does not release TCP ports immediately. A
+    workaround without the need for this function is setting `UCX_TCP_CM_REUSEADDR=y`,
+    but that requires to be explicitly set when running tests, and that is not very
+    friendly.
+
+    Parameters
+    ----------
+    protocol: str
+        Communication protocol to use.
+    max_attempts: int
+        Maximum attempts to try to open scheduler.
+    timeout: int
+        Time to wait while reading stdout/stderr of subprocess.
+    """
+    port = open_port()
+    for _ in range(max_attempts):
+        with popen(
+            [
+                "dask",
+                "scheduler",
+                "--no-dashboard",
+                "--protocol",
+                protocol,
+                "--port",
+                str(port),
+            ],
+            capture_output=True,  # Capture stdout and stderr
+        ) as scheduler_process:
+            # Check if the scheduler process started successfully by streaming output
+            try:
+                start_time = time.monotonic()
+                while True:
+                    if time.monotonic() - start_time > timeout:
+                        raise TimeoutError("Timeout while waiting for scheduler output")
+
+                    line = scheduler_process.stdout.readline()
+                    if not line:
+                        break  # End of output
+                    print(
+                        line.decode(), end=""
+                    )  # Since capture_output=True, print the line here
+                    if b"Scheduler at:" in line:
+                        # Scheduler is now listening
+                        break
+                    elif b"UCXXBusyError" in line:
+                        raise Exception("UCXXBusyError detected in scheduler output")
+            except Exception:
+                port += 1
+            else:
+                yield scheduler_process, port
+                return
+    else:
+        pytest.fail(f"Failed to start dask scheduler after {max_attempts} attempts.")
+
+
+@pytest.mark.timeout(30)
+@patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"})
+@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
+def test_dask_cuda_worker_cli_integration(protocol, tmp_path):
+    """Test that dask cuda worker CLI correctly passes arguments to dask_setup.
+
+    Verifies the end-to-end integration where the CLI tool actually launches and calls
+    dask_setup with correct args.
+    """
+
+    # Use pytest's tmp_path for file management
+    capture_file_path = tmp_path / "dask_setup_integration_test.json"
+    preload_file = tmp_path / "preload_capture.py"
+
+    # Write the preload script to tmp_path
+    preload_file.write_text(
+        f'''
+import json
+import os
+
+def capture_dask_setup_call(worker, create_cuda_context):
+    """Capture dask_setup arguments and write to file."""
+    result = {{
+        'worker_protocol': getattr(worker, '_protocol', 'unknown'),
+        'create_cuda_context': create_cuda_context,
+        'test_success': True
+    }}
+
+    # Write immediately to ensure it gets captured
+    with open(r"{capture_file_path}", 'w') as f:
+        json.dump(result, f)
+
+# Patch dask_setup callback
+from dask_cuda.initialize import dask_setup
+dask_setup.callback = capture_dask_setup_call
+'''
+    )
+
+    with start_dask_scheduler(protocol=protocol) as scheduler_process_port:
+        scheduler_process, scheduler_port = scheduler_process_port
+        sched_addr = f"{protocol}://127.0.0.1:{scheduler_port}"
+        print(f"{sched_addr=}", flush=True)
+
+        # Build dask cuda worker args
+        dask_cuda_worker_args = [
+            "dask",
+            "cuda",
+            "worker",
+            sched_addr,
+            "--host",
+            "127.0.0.1",
+            "--no-dashboard",
+            "--preload",
+            str(preload_file),
+            "--death-timeout",
+            "10",
+        ]
+
+        with popen(dask_cuda_worker_args):
+            # Wait and check for worker connection
+            with Client(sched_addr) as client:
+                assert wait_workers(client, n_gpus=1)
+
+        # Check if dask_setup was called and captured correctly
+        if capture_file_path.exists():
+            with open(capture_file_path, "r") as cf:
+                captured_args = json.load(cf)
+
+            # Verify the critical arguments were passed correctly
+            assert (
+                captured_args["create_cuda_context"] is True
+            ), "create_cuda_context should be True"
+
+            # Verify worker has a protocol set
+            assert (
+                captured_args["worker_protocol"] == protocol
+            ), "Worker should have a protocol"
+        else:
+            pytest.fail(
+                "capture file not found: dask_setup was not called or "
+                "failed to write to file"
+            )
dask_cuda/tests/test_dgx.py
CHANGED

@@ -13,16 +13,11 @@ from distributed import Client
 
 from dask_cuda import LocalCUDACluster
 from dask_cuda.initialize import initialize
-from dask_cuda.utils_test import get_ucx_implementation
 
 mp = mp.get_context("spawn")  # type: ignore
 psutil = pytest.importorskip("psutil")
 
 
-def _is_ucx_116(ucp):
-    return ucp.get_ucx_version()[:2] == (1, 16)
-
-
 class DGXVersion(Enum):
     DGX_1 = auto()
     DGX_2 = auto()

@@ -81,17 +76,17 @@ def test_default():
     assert not p.exitcode
 
 
-def _test_tcp_over_ucx(protocol):
-    ucp = get_ucx_implementation(protocol)
+def _test_tcp_over_ucx():
+    ucxx = pytest.importorskip("ucxx")
 
-    with LocalCUDACluster(protocol=protocol, enable_tcp_over_ucx=True) as cluster:
+    with LocalCUDACluster(protocol="ucx", enable_tcp_over_ucx=True) as cluster:
         with Client(cluster) as client:
             res = da.from_array(numpy.arange(10000), chunks=(1000,))
             res = res.sum().compute()
             assert res == 49995000
 
             def check_ucx_options():
-                conf = ucp.get_config()
+                conf = ucxx.get_config()
                 assert "TLS" in conf
                 assert "tcp" in conf["TLS"]
                 assert "cuda_copy" in conf["TLS"]

@@ -101,16 +96,10 @@ def _test_tcp_over_ucx(protocol):
             assert all(client.run(check_ucx_options).values())
 
 
-@pytest.mark.parametrize(
-    "protocol",
-    ["ucx", "ucx-old"],
-)
-def test_tcp_over_ucx(protocol):
-    ucp = get_ucx_implementation(protocol)
-    if _is_ucx_116(ucp):
-        pytest.skip("https://github.com/rapidsai/ucx-py/issues/1037")
+def test_tcp_over_ucx():
+    pytest.importorskip("distributed_ucxx")
 
-    p = mp.Process(target=_test_tcp_over_ucx, args=(protocol,))
+    p = mp.Process(target=_test_tcp_over_ucx)
     p.start()
     p.join()
     assert not p.exitcode

@@ -132,22 +121,22 @@ def test_tcp_only():
 
 
 def _test_ucx_infiniband_nvlink(
-    skip_queue, protocol, enable_infiniband, enable_nvlink, enable_rdmacm
+    skip_queue, enable_infiniband, enable_nvlink, enable_rdmacm
 ):
+    ucxx = pytest.importorskip("ucxx")
     cupy = pytest.importorskip("cupy")
-    ucp = get_ucx_implementation(protocol)
 
     if enable_infiniband and not any(
-        [at.startswith("rc") for at in ucp.get_active_transports()]
+        [at.startswith("rc") for at in ucxx.get_active_transports()]
     ):
         skip_queue.put("No support available for 'rc' transport in UCX")
         return
     else:
         skip_queue.put("ok")
 
-    # `ucp.get_active_transports()` call above initializes UCX, we must reset it
+    # `ucxx.get_active_transports()` call above initializes UCX, we must reset it
     # so that Dask doesn't try to initialize it again and raise an exception.
-    ucp.reset()
+    ucxx.reset()
 
     if enable_infiniband is None and enable_nvlink is None and enable_rdmacm is None:
         enable_tcp_over_ucx = None

@@ -163,7 +152,6 @@ def _test_ucx_infiniband_nvlink(
         cm_tls_priority = ["tcp"]
 
     initialize(
-        protocol=protocol,
         enable_tcp_over_ucx=enable_tcp_over_ucx,
         enable_infiniband=enable_infiniband,
         enable_nvlink=enable_nvlink,

@@ -171,7 +159,7 @@ def _test_ucx_infiniband_nvlink(
     )
 
     with LocalCUDACluster(
-        protocol=protocol,
+        protocol="ucx",
         interface="ib0",
         enable_tcp_over_ucx=enable_tcp_over_ucx,
         enable_infiniband=enable_infiniband,

@@ -185,7 +173,7 @@ def _test_ucx_infiniband_nvlink(
         assert res == 49995000
 
         def check_ucx_options():
-            conf = ucp.get_config()
+            conf = ucxx.get_config()
             assert "TLS" in conf
             assert all(t in conf["TLS"] for t in cm_tls)
             assert all(p in conf["SOCKADDR_TLS_PRIORITY"] for p in cm_tls_priority)

@@ -201,7 +189,6 @@ def _test_ucx_infiniband_nvlink(
         assert all(client.run(check_ucx_options).values())
 
 
-@pytest.mark.parametrize("protocol", ["ucx", "ucx-old"])
 @pytest.mark.parametrize(
     "params",
     [

@@ -216,10 +203,8 @@ def _test_ucx_infiniband_nvlink(
     _get_dgx_version() == DGXVersion.DGX_A100,
     reason="Automatic InfiniBand device detection Unsupported for %s" % _get_dgx_name(),
 )
-def test_ucx_infiniband_nvlink(protocol, params):
-    ucp = get_ucx_implementation(protocol)
-    if _is_ucx_116(ucp) and params["enable_infiniband"] is False:
-        pytest.skip("https://github.com/rapidsai/ucx-py/issues/1037")
+def test_ucx_infiniband_nvlink(params):
+    pytest.importorskip("distributed_ucxx")
 
     skip_queue = mp.Queue()
 

@@ -227,7 +212,6 @@ def test_ucx_infiniband_nvlink(protocol, params):
         target=_test_ucx_infiniband_nvlink,
         args=(
             skip_queue,
-            protocol,
             params["enable_infiniband"],
             params["enable_nvlink"],
             params["enable_rdmacm"],
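For context on the `ucxx` calls introduced above: `ucxx.get_config()` returns the active UCX configuration as a dict, and `ucxx.reset()` tears down the implicitly created context so Dask can initialize UCX itself. A small sketch, assuming the `ucxx` package is installed; the reported transports depend on the local UCX build:

import ucxx

# Inspect the transport configuration the tests assert on; "TLS" lists the
# enabled transports, e.g. "all" or "tcp,cuda_copy"
conf = ucxx.get_config()
print(conf["TLS"])

# Touching UCX state this way initializes it as a side effect (the diff notes
# this for get_active_transports()); reset before handing control to Dask
ucxx.reset()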
dask_cuda/tests/test_explicit_comms.py
CHANGED

@@ -26,10 +26,9 @@ from dask_cuda.explicit_comms.dataframe.shuffle import (
     _contains_shuffle_expr,
     shuffle as explicit_comms_shuffle,
 )
-from dask_cuda.utils_test import IncreasedCloseTimeoutNanny, get_ucx_implementation
+from dask_cuda.utils_test import IncreasedCloseTimeoutNanny
 
 mp = mp.get_context("spawn")  # type: ignore
-ucp = pytest.importorskip("ucp")
 
 
 # Notice, all of the following tests is executed in a new process such

@@ -54,10 +53,11 @@ def _test_local_cluster(protocol):
     assert sum(c.run(my_rank, 0)) == sum(range(4))
 
 
-@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucx-old"])
+@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
 def test_local_cluster(protocol):
     if protocol.startswith("ucx"):
-        get_ucx_implementation(protocol)
+        pytest.importorskip("distributed_ucxx")
+
     p = mp.Process(target=_test_local_cluster, args=(protocol,))
     p.start()
     p.join()

@@ -202,13 +202,13 @@ def _test_dataframe_shuffle(backend, protocol, n_workers, _partitions):
 
 @pytest.mark.parametrize("nworkers", [1, 2, 3])
 @pytest.mark.parametrize("backend", ["pandas", "cudf"])
-@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucx-old"])
+@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
 @pytest.mark.parametrize("_partitions", [True, False])
 def test_dataframe_shuffle(backend, protocol, nworkers, _partitions):
     if backend == "cudf":
         pytest.importorskip("cudf")
     if protocol.startswith("ucx"):
-        get_ucx_implementation(protocol)
+        pytest.importorskip("distributed_ucxx")
 
     p = mp.Process(
         target=_test_dataframe_shuffle, args=(backend, protocol, nworkers, _partitions)

@@ -325,12 +325,13 @@ def _test_dataframe_shuffle_merge(backend, protocol, n_workers):
 
 @pytest.mark.parametrize("nworkers", [1, 2, 4])
 @pytest.mark.parametrize("backend", ["pandas", "cudf"])
-@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucx-old"])
+@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
 def test_dataframe_shuffle_merge(backend, protocol, nworkers):
     if backend == "cudf":
         pytest.importorskip("cudf")
     if protocol.startswith("ucx"):
-        get_ucx_implementation(protocol)
+        pytest.importorskip("distributed_ucxx")
+
     p = mp.Process(
         target=_test_dataframe_shuffle_merge, args=(backend, protocol, nworkers)
     )

@@ -364,14 +365,14 @@ def _test_jit_unspill(protocol):
     assert_eq(got, expected)
 
 
-@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucx-old"])
+@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
 @pytest.mark.skip_if_no_device_memory(
     "JIT-Unspill not supported in devices without dedicated memory resource"
 )
 def test_jit_unspill(protocol):
     pytest.importorskip("cudf")
     if protocol.startswith("ucx"):
-        get_ucx_implementation(protocol)
+        pytest.importorskip("distributed_ucxx")
 
     p = mp.Process(target=_test_jit_unspill, args=(protocol,))
     p.start()
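Since these tests import `shuffle as explicit_comms_shuffle`, a rough usage sketch may help. The signature shown (a dataframe plus a list of key columns) is an assumption inferred from these tests, and the column name and partition counts are illustrative only:

import pandas as pd
import dask.dataframe as dd
from distributed import Client, LocalCluster
from dask_cuda.explicit_comms.dataframe.shuffle import shuffle as explicit_comms_shuffle

if __name__ == "__main__":
    # Explicit-comms shuffle needs a running distributed cluster
    with LocalCluster(n_workers=2) as cluster, Client(cluster):
        ddf = dd.from_pandas(pd.DataFrame({"key": range(100)}), npartitions=4)
        # Repartition so that rows with equal keys land in the same partition
        shuffled = explicit_comms_shuffle(ddf, ["key"])
        print(shuffled.npartitions)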
dask_cuda/tests/test_from_array.py
CHANGED

@@ -7,16 +7,12 @@ import dask.array as da
 from distributed import Client
 
 from dask_cuda import LocalCUDACluster
-from dask_cuda.utils_test import get_ucx_implementation
 
 cupy = pytest.importorskip("cupy")
 
 
-@pytest.mark.parametrize("protocol", ["ucx", "ucx-old", "tcp"])
+@pytest.mark.parametrize("protocol", ["ucx", "tcp"])
 def test_ucx_from_array(protocol):
-    if protocol.startswith("ucx"):
-        get_ucx_implementation(protocol)
-
     N = 10_000
     with LocalCUDACluster(protocol=protocol) as cluster:
         with Client(cluster):