dask-cuda 25.8.0__py3-none-any.whl → 25.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. dask_cuda/GIT_COMMIT +1 -1
  2. dask_cuda/VERSION +1 -1
  3. dask_cuda/benchmarks/local_cudf_groupby.py +1 -1
  4. dask_cuda/benchmarks/local_cudf_merge.py +1 -1
  5. dask_cuda/benchmarks/local_cudf_shuffle.py +1 -1
  6. dask_cuda/benchmarks/local_cupy.py +1 -1
  7. dask_cuda/benchmarks/local_cupy_map_overlap.py +1 -1
  8. dask_cuda/benchmarks/utils.py +1 -1
  9. dask_cuda/cuda_worker.py +1 -1
  10. dask_cuda/get_device_memory_objects.py +1 -4
  11. dask_cuda/initialize.py +140 -121
  12. dask_cuda/local_cuda_cluster.py +10 -25
  13. dask_cuda/tests/test_cudf_builtin_spilling.py +3 -1
  14. dask_cuda/tests/test_dask_setup.py +193 -0
  15. dask_cuda/tests/test_dgx.py +16 -32
  16. dask_cuda/tests/test_explicit_comms.py +11 -10
  17. dask_cuda/tests/test_from_array.py +1 -5
  18. dask_cuda/tests/test_initialize.py +230 -41
  19. dask_cuda/tests/test_local_cuda_cluster.py +16 -62
  20. dask_cuda/tests/test_proxify_host_file.py +9 -4
  21. dask_cuda/tests/test_proxy.py +8 -8
  22. dask_cuda/tests/test_spill.py +3 -3
  23. dask_cuda/tests/test_utils.py +8 -23
  24. dask_cuda/tests/test_worker_spec.py +5 -2
  25. dask_cuda/utils.py +12 -66
  26. dask_cuda/utils_test.py +0 -13
  27. dask_cuda/worker_spec.py +7 -9
  28. {dask_cuda-25.8.0.dist-info → dask_cuda-25.10.0.dist-info}/METADATA +11 -4
  29. dask_cuda-25.10.0.dist-info/RECORD +63 -0
  30. shared-actions/check_nightly_success/check-nightly-success/check.py +1 -1
  31. dask_cuda/tests/test_rdd_ucx.py +0 -160
  32. dask_cuda-25.8.0.dist-info/RECORD +0 -63
  33. {dask_cuda-25.8.0.dist-info → dask_cuda-25.10.0.dist-info}/WHEEL +0 -0
  34. {dask_cuda-25.8.0.dist-info → dask_cuda-25.10.0.dist-info}/entry_points.txt +0 -0
  35. {dask_cuda-25.8.0.dist-info → dask_cuda-25.10.0.dist-info}/licenses/LICENSE +0 -0
  36. {dask_cuda-25.8.0.dist-info → dask_cuda-25.10.0.dist-info}/top_level.txt +0 -0
dask_cuda/tests/test_dask_setup.py
@@ -0,0 +1,193 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
+ import json
+ import os
+ import time
+ from contextlib import contextmanager
+ from unittest.mock import Mock, patch
+
+ import pytest
+
+ from distributed import Client
+ from distributed.utils import open_port
+ from distributed.utils_test import popen
+
+ from dask_cuda.initialize import dask_setup
+ from dask_cuda.utils import wait_workers
+
+
+ def test_dask_setup_function_with_mock_worker():
+     """Test the dask_setup function directly with a mock worker."""
+     # Create a mock worker object
+     mock_worker = Mock()
+
+     with patch("dask_cuda.initialize._create_cuda_context") as mock_create_context:
+         # Test with create_cuda_context=True.
+         # Call the underlying function directly (the Click decorator wraps the
+         # real function).
+         dask_setup.callback(
+             worker=mock_worker,
+             create_cuda_context=True,
+         )
+
+         mock_create_context.assert_called_once_with()
+
+         mock_create_context.reset_mock()
+
+         # Test with create_cuda_context=False
+         dask_setup.callback(
+             worker=mock_worker,
+             create_cuda_context=False,
+         )
+
+         mock_create_context.assert_not_called()
+
+
+ @contextmanager
+ def start_dask_scheduler(protocol: str, max_attempts: int = 5, timeout: int = 10):
+     """Start a Dask scheduler in a subprocess.
+
+     Attempts to start a Dask scheduler in a subprocess; if the port is not
+     available, retries on a different port, up to a maximum of `max_attempts`
+     attempts. The stdout and stderr of the process are read to determine whether
+     the scheduler failed to bind to the port or succeeded, waiting no more than
+     `timeout` seconds between reads.
+
+     This is primarily useful because UCX does not release TCP ports immediately.
+     A workaround that makes this function unnecessary is setting
+     `UCX_TCP_CM_REUSEADDR=y`, but that must be set explicitly when running tests,
+     which is not very friendly.
+
+     Parameters
+     ----------
+     protocol: str
+         Communication protocol to use.
+     max_attempts: int
+         Maximum number of attempts to start the scheduler.
+     timeout: int
+         Time to wait while reading stdout/stderr of the subprocess.
+     """
+     port = open_port()
+     for _ in range(max_attempts):
+         with popen(
+             [
+                 "dask",
+                 "scheduler",
+                 "--no-dashboard",
+                 "--protocol",
+                 protocol,
+                 "--port",
+                 str(port),
+             ],
+             capture_output=True,  # Capture stdout and stderr
+         ) as scheduler_process:
+             # Check if the scheduler process started successfully by streaming output
+             try:
+                 start_time = time.monotonic()
+                 while True:
+                     if time.monotonic() - start_time > timeout:
+                         raise TimeoutError("Timeout while waiting for scheduler output")
+
+                     line = scheduler_process.stdout.readline()
+                     if not line:
+                         break  # End of output
+                     print(
+                         line.decode(), end=""
+                     )  # Since capture_output=True, print the line here
+                     if b"Scheduler at:" in line:
+                         # Scheduler is now listening
+                         break
+                     elif b"UCXXBusyError" in line:
+                         raise Exception("UCXXBusyError detected in scheduler output")
+             except Exception:
+                 port += 1
+             else:
+                 yield scheduler_process, port
+                 return
+     else:
+         pytest.fail(f"Failed to start dask scheduler after {max_attempts} attempts.")
+
+
+ @pytest.mark.timeout(30)
+ @patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"})
+ @pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
+ def test_dask_cuda_worker_cli_integration(protocol, tmp_path):
+     """Test that the dask cuda worker CLI correctly passes arguments to dask_setup.
+
+     Verifies the end-to-end integration where the CLI tool actually launches and
+     calls dask_setup with the correct arguments.
+     """
+
+     # Use pytest's tmp_path for file management
+     capture_file_path = tmp_path / "dask_setup_integration_test.json"
+     preload_file = tmp_path / "preload_capture.py"
+
+     # Write the preload script to tmp_path
+     preload_file.write_text(
+         f'''
+ import json
+ import os
+
+ def capture_dask_setup_call(worker, create_cuda_context):
+     """Capture dask_setup arguments and write them to a file."""
+     result = {{
+         'worker_protocol': getattr(worker, '_protocol', 'unknown'),
+         'create_cuda_context': create_cuda_context,
+         'test_success': True
+     }}
+
+     # Write immediately to ensure it gets captured
+     with open(r"{capture_file_path}", 'w') as f:
+         json.dump(result, f)
+
+ # Patch the dask_setup callback
+ from dask_cuda.initialize import dask_setup
+ dask_setup.callback = capture_dask_setup_call
+ '''
+     )
+
+     with start_dask_scheduler(protocol=protocol) as scheduler_process_port:
+         scheduler_process, scheduler_port = scheduler_process_port
+         sched_addr = f"{protocol}://127.0.0.1:{scheduler_port}"
+         print(f"{sched_addr=}", flush=True)
+
+         # Build the dask cuda worker arguments
+         dask_cuda_worker_args = [
+             "dask",
+             "cuda",
+             "worker",
+             sched_addr,
+             "--host",
+             "127.0.0.1",
+             "--no-dashboard",
+             "--preload",
+             str(preload_file),
+             "--death-timeout",
+             "10",
+         ]
+
+         with popen(dask_cuda_worker_args):
+             # Wait and check for the worker connection
+             with Client(sched_addr) as client:
+                 assert wait_workers(client, n_gpus=1)
+
+         # Check that dask_setup was called and its arguments captured correctly
+         if capture_file_path.exists():
+             with open(capture_file_path, "r") as cf:
+                 captured_args = json.load(cf)
+
+             # Verify the critical arguments were passed correctly
+             assert (
+                 captured_args["create_cuda_context"] is True
+             ), "create_cuda_context should be True"
+
+             # Verify the worker has the expected protocol set
+             assert (
+                 captured_args["worker_protocol"] == protocol
+             ), "Worker should have a protocol"
+         else:
+             pytest.fail(
+                 "capture file not found: dask_setup was not called or "
+                 "failed to write to file"
+             )
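
Note: the unit test above calls `dask_setup.callback` rather than `dask_setup` itself. The following is a minimal stand-alone sketch (a hypothetical example, not dask-cuda code) of the Click behavior it relies on: `@click.command()` replaces the decorated function with a `click.Command` object, and the original function remains reachable through the command's `callback` attribute.

    import click

    @click.command()
    @click.option("--create-cuda-context", default=True, type=bool)
    def dask_setup(create_cuda_context):
        print(f"create_cuda_context={create_cuda_context}")

    # Invoking the command would parse sys.argv; calling `.callback` bypasses
    # the CLI layer and runs the undecorated function, which is exactly what
    # the mock-based unit test above does.
    dask_setup.callback(create_cuda_context=False)
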
dask_cuda/tests/test_dgx.py
@@ -13,16 +13,11 @@ from distributed import Client
 
  from dask_cuda import LocalCUDACluster
  from dask_cuda.initialize import initialize
- from dask_cuda.utils_test import get_ucx_implementation
 
  mp = mp.get_context("spawn")  # type: ignore
  psutil = pytest.importorskip("psutil")
 
 
- def _is_ucx_116(ucp):
-     return ucp.get_ucx_version()[:2] == (1, 16)
-
-
  class DGXVersion(Enum):
      DGX_1 = auto()
      DGX_2 = auto()
@@ -81,17 +76,17 @@ def test_default():
      assert not p.exitcode
 
 
- def _test_tcp_over_ucx(protocol):
-     ucp = get_ucx_implementation(protocol)
+ def _test_tcp_over_ucx():
+     ucxx = pytest.importorskip("ucxx")
 
-     with LocalCUDACluster(protocol=protocol, enable_tcp_over_ucx=True) as cluster:
+     with LocalCUDACluster(protocol="ucx", enable_tcp_over_ucx=True) as cluster:
          with Client(cluster) as client:
              res = da.from_array(numpy.arange(10000), chunks=(1000,))
              res = res.sum().compute()
              assert res == 49995000
 
              def check_ucx_options():
-                 conf = ucp.get_config()
+                 conf = ucxx.get_config()
                  assert "TLS" in conf
                  assert "tcp" in conf["TLS"]
                  assert "cuda_copy" in conf["TLS"]
@@ -101,16 +96,10 @@ def _test_tcp_over_ucx(protocol):
              assert all(client.run(check_ucx_options).values())
 
 
- @pytest.mark.parametrize(
-     "protocol",
-     ["ucx", "ucx-old"],
- )
- def test_tcp_over_ucx(protocol):
-     ucp = get_ucx_implementation(protocol)
-     if _is_ucx_116(ucp):
-         pytest.skip("https://github.com/rapidsai/ucx-py/issues/1037")
+ def test_tcp_over_ucx():
+     pytest.importorskip("distributed_ucxx")
 
-     p = mp.Process(target=_test_tcp_over_ucx, args=(protocol,))
+     p = mp.Process(target=_test_tcp_over_ucx)
      p.start()
      p.join()
      assert not p.exitcode
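
Note: `_test_tcp_over_ucx` above and `_test_ucx_infiniband_nvlink` below both end with `assert all(client.run(check_ucx_options).values())`. As a minimal illustrative sketch (not taken from the diff): `Client.run` executes a function on every connected worker and returns a dict mapping worker address to the function's return value, so `all(...)` over the values verifies the UCX options on each worker.

    from distributed import Client, LocalCluster

    def worker_ok():
        return True  # stand-in for check_ucx_options()

    if __name__ == "__main__":
        with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
            results = client.run(worker_ok)  # e.g. {"tcp://127.0.0.1:...": True, ...}
            assert all(results.values())
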
@@ -132,22 +121,22 @@ def test_tcp_only():
 
 
  def _test_ucx_infiniband_nvlink(
-     skip_queue, protocol, enable_infiniband, enable_nvlink, enable_rdmacm
+     skip_queue, enable_infiniband, enable_nvlink, enable_rdmacm
  ):
+     ucxx = pytest.importorskip("ucxx")
      cupy = pytest.importorskip("cupy")
-     ucp = get_ucx_implementation(protocol)
 
      if enable_infiniband and not any(
-         [at.startswith("rc") for at in ucp.get_active_transports()]
+         [at.startswith("rc") for at in ucxx.get_active_transports()]
      ):
          skip_queue.put("No support available for 'rc' transport in UCX")
          return
      else:
          skip_queue.put("ok")
 
-     # The `ucp.get_active_transports()` call above initializes UCX; we must reset it
+     # The `ucxx.get_active_transports()` call above initializes UCX; we must reset it
      # so that Dask doesn't try to initialize it again and raise an exception.
-     ucp.reset()
+     ucxx.reset()
 
      if enable_infiniband is None and enable_nvlink is None and enable_rdmacm is None:
          enable_tcp_over_ucx = None
@@ -163,7 +152,6 @@ def _test_ucx_infiniband_nvlink(
          cm_tls_priority = ["tcp"]
 
      initialize(
-         protocol=protocol,
          enable_tcp_over_ucx=enable_tcp_over_ucx,
          enable_infiniband=enable_infiniband,
          enable_nvlink=enable_nvlink,
@@ -171,7 +159,7 @@ def _test_ucx_infiniband_nvlink(
      )
 
      with LocalCUDACluster(
-         protocol=protocol,
+         protocol="ucx",
          interface="ib0",
          enable_tcp_over_ucx=enable_tcp_over_ucx,
          enable_infiniband=enable_infiniband,
@@ -185,7 +173,7 @@ def _test_ucx_infiniband_nvlink(
              assert res == 49995000
 
              def check_ucx_options():
-                 conf = ucp.get_config()
+                 conf = ucxx.get_config()
                  assert "TLS" in conf
                  assert all(t in conf["TLS"] for t in cm_tls)
                  assert all(p in conf["SOCKADDR_TLS_PRIORITY"] for p in cm_tls_priority)
@@ -201,7 +189,6 @@ def _test_ucx_infiniband_nvlink(
              assert all(client.run(check_ucx_options).values())
 
 
- @pytest.mark.parametrize("protocol", ["ucx", "ucx-old"])
  @pytest.mark.parametrize(
      "params",
      [
@@ -216,10 +203,8 @@ def _test_ucx_infiniband_nvlink(
      _get_dgx_version() == DGXVersion.DGX_A100,
      reason="Automatic InfiniBand device detection Unsupported for %s" % _get_dgx_name(),
  )
- def test_ucx_infiniband_nvlink(protocol, params):
-     ucp = get_ucx_implementation(protocol)
-     if _is_ucx_116(ucp) and params["enable_infiniband"] is False:
-         pytest.skip("https://github.com/rapidsai/ucx-py/issues/1037")
+ def test_ucx_infiniband_nvlink(params):
+     pytest.importorskip("distributed_ucxx")
 
      skip_queue = mp.Queue()
 
@@ -227,7 +212,6 @@ def test_ucx_infiniband_nvlink(protocol, params):
          target=_test_ucx_infiniband_nvlink,
          args=(
              skip_queue,
-             protocol,
              params["enable_infiniband"],
              params["enable_nvlink"],
              params["enable_rdmacm"],
dask_cuda/tests/test_explicit_comms.py
@@ -26,10 +26,9 @@ from dask_cuda.explicit_comms.dataframe.shuffle import (
      _contains_shuffle_expr,
      shuffle as explicit_comms_shuffle,
  )
- from dask_cuda.utils_test import IncreasedCloseTimeoutNanny, get_ucx_implementation
+ from dask_cuda.utils_test import IncreasedCloseTimeoutNanny
 
  mp = mp.get_context("spawn")  # type: ignore
- ucp = pytest.importorskip("ucp")
 
 
  # Notice, all of the following tests is executed in a new process such
@@ -54,10 +53,11 @@ def _test_local_cluster(protocol):
      assert sum(c.run(my_rank, 0)) == sum(range(4))
 
 
- @pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucx-old"])
+ @pytest.mark.parametrize("protocol", ["tcp", "ucx"])
  def test_local_cluster(protocol):
      if protocol.startswith("ucx"):
-         get_ucx_implementation(protocol)
+         pytest.importorskip("distributed_ucxx")
+
      p = mp.Process(target=_test_local_cluster, args=(protocol,))
      p.start()
      p.join()
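
Note: the recurring replacement of `get_ucx_implementation(protocol)` with `pytest.importorskip("distributed_ucxx")` uses the standard pytest gating idiom: `importorskip` imports and returns the named module, or skips the test (rather than failing it) when the module is missing. A small illustrative sketch:

    import pytest

    def test_requires_ucxx():
        # Skips, instead of erroring, on machines without the optional package.
        distributed_ucxx = pytest.importorskip("distributed_ucxx")
        assert distributed_ucxx is not None
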
@@ -202,13 +202,13 @@ def _test_dataframe_shuffle(backend, protocol, n_workers, _partitions):
 
  @pytest.mark.parametrize("nworkers", [1, 2, 3])
  @pytest.mark.parametrize("backend", ["pandas", "cudf"])
- @pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucx-old"])
+ @pytest.mark.parametrize("protocol", ["tcp", "ucx"])
  @pytest.mark.parametrize("_partitions", [True, False])
  def test_dataframe_shuffle(backend, protocol, nworkers, _partitions):
      if backend == "cudf":
          pytest.importorskip("cudf")
      if protocol.startswith("ucx"):
-         get_ucx_implementation(protocol)
+         pytest.importorskip("distributed_ucxx")
 
      p = mp.Process(
          target=_test_dataframe_shuffle, args=(backend, protocol, nworkers, _partitions)
@@ -325,12 +325,13 @@ def _test_dataframe_shuffle_merge(backend, protocol, n_workers):
 
  @pytest.mark.parametrize("nworkers", [1, 2, 4])
  @pytest.mark.parametrize("backend", ["pandas", "cudf"])
- @pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucx-old"])
+ @pytest.mark.parametrize("protocol", ["tcp", "ucx"])
  def test_dataframe_shuffle_merge(backend, protocol, nworkers):
      if backend == "cudf":
          pytest.importorskip("cudf")
      if protocol.startswith("ucx"):
-         get_ucx_implementation(protocol)
+         pytest.importorskip("distributed_ucxx")
+
      p = mp.Process(
          target=_test_dataframe_shuffle_merge, args=(backend, protocol, nworkers)
      )
@@ -364,14 +365,14 @@ def _test_jit_unspill(protocol):
      assert_eq(got, expected)
 
 
- @pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucx-old"])
+ @pytest.mark.parametrize("protocol", ["tcp", "ucx"])
  @pytest.mark.skip_if_no_device_memory(
      "JIT-Unspill not supported in devices without dedicated memory resource"
  )
  def test_jit_unspill(protocol):
      pytest.importorskip("cudf")
      if protocol.startswith("ucx"):
-         get_ucx_implementation(protocol)
+         pytest.importorskip("distributed_ucxx")
 
      p = mp.Process(target=_test_jit_unspill, args=(protocol,))
      p.start()
dask_cuda/tests/test_from_array.py
@@ -7,16 +7,12 @@ import dask.array as da
  from distributed import Client
 
  from dask_cuda import LocalCUDACluster
- from dask_cuda.utils_test import get_ucx_implementation
 
  cupy = pytest.importorskip("cupy")
 
 
- @pytest.mark.parametrize("protocol", ["ucx", "ucx-old", "tcp"])
+ @pytest.mark.parametrize("protocol", ["ucx", "tcp"])
  def test_ucx_from_array(protocol):
-     if protocol.startswith("ucx"):
-         get_ucx_implementation(protocol)
-
      N = 10_000
      with LocalCUDACluster(protocol=protocol) as cluster:
          with Client(cluster):
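
Note: the hunk is truncated here. As a hedged sketch of the round trip test_ucx_from_array exercises (assuming a GPU with CuPy installed; protocol="ucx" additionally requires distributed-ucxx), not the test's exact body:

    import cupy
    import dask.array as da
    from distributed import Client
    from dask_cuda import LocalCUDACluster

    if __name__ == "__main__":
        N = 10_000
        with LocalCUDACluster(protocol="tcp") as cluster, Client(cluster):
            darr = da.from_array(cupy.arange(N), chunks=(1_000,))
            # Sum of 0..N-1 is N * (N - 1) / 2 = 49_995_000 for N = 10_000
            assert darr.sum().compute() == N * (N - 1) // 2
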