dask-cuda 25.8.0__py3-none-any.whl → 25.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. dask_cuda/GIT_COMMIT +1 -1
  2. dask_cuda/VERSION +1 -1
  3. dask_cuda/benchmarks/local_cudf_groupby.py +1 -1
  4. dask_cuda/benchmarks/local_cudf_merge.py +1 -1
  5. dask_cuda/benchmarks/local_cudf_shuffle.py +1 -1
  6. dask_cuda/benchmarks/local_cupy.py +1 -1
  7. dask_cuda/benchmarks/local_cupy_map_overlap.py +1 -1
  8. dask_cuda/benchmarks/utils.py +1 -1
  9. dask_cuda/cuda_worker.py +1 -1
  10. dask_cuda/get_device_memory_objects.py +1 -4
  11. dask_cuda/initialize.py +140 -121
  12. dask_cuda/local_cuda_cluster.py +10 -25
  13. dask_cuda/tests/test_cudf_builtin_spilling.py +3 -1
  14. dask_cuda/tests/test_dask_setup.py +193 -0
  15. dask_cuda/tests/test_dgx.py +16 -32
  16. dask_cuda/tests/test_explicit_comms.py +11 -10
  17. dask_cuda/tests/test_from_array.py +1 -5
  18. dask_cuda/tests/test_initialize.py +230 -41
  19. dask_cuda/tests/test_local_cuda_cluster.py +16 -62
  20. dask_cuda/tests/test_proxify_host_file.py +9 -4
  21. dask_cuda/tests/test_proxy.py +8 -8
  22. dask_cuda/tests/test_spill.py +3 -3
  23. dask_cuda/tests/test_utils.py +8 -23
  24. dask_cuda/tests/test_worker_spec.py +5 -2
  25. dask_cuda/utils.py +12 -66
  26. dask_cuda/utils_test.py +0 -13
  27. dask_cuda/worker_spec.py +7 -9
  28. {dask_cuda-25.8.0.dist-info → dask_cuda-25.10.0.dist-info}/METADATA +11 -4
  29. dask_cuda-25.10.0.dist-info/RECORD +63 -0
  30. shared-actions/check_nightly_success/check-nightly-success/check.py +1 -1
  31. dask_cuda/tests/test_rdd_ucx.py +0 -160
  32. dask_cuda-25.8.0.dist-info/RECORD +0 -63
  33. {dask_cuda-25.8.0.dist-info → dask_cuda-25.10.0.dist-info}/WHEEL +0 -0
  34. {dask_cuda-25.8.0.dist-info → dask_cuda-25.10.0.dist-info}/entry_points.txt +0 -0
  35. {dask_cuda-25.8.0.dist-info → dask_cuda-25.10.0.dist-info}/licenses/LICENSE +0 -0
  36. {dask_cuda-25.8.0.dist-info → dask_cuda-25.10.0.dist-info}/top_level.txt +0 -0
```diff
@@ -2,8 +2,14 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import multiprocessing as mp
+import os
+import shutil
+import subprocess
 import sys
+import tempfile
+import textwrap
 
+import cuda.core.experimental
 import numpy
 import psutil
 import pytest
@@ -14,7 +20,7 @@ from distributed.deploy.local import LocalCluster
 
 from dask_cuda.initialize import initialize
 from dask_cuda.utils import get_ucx_config
-from dask_cuda.utils_test import IncreasedCloseTimeoutNanny, get_ucx_implementation
+from dask_cuda.utils_test import IncreasedCloseTimeoutNanny
 
 mp = mp.get_context("spawn")  # type: ignore
 
@@ -24,19 +30,19 @@ mp = mp.get_context("spawn")  # type: ignore
 # of UCX before retrieving the current config.
 
 
-def _test_initialize_ucx_tcp(protocol):
-    ucp = get_ucx_implementation(protocol)
+def _test_initialize_ucx_tcp():
+    ucxx = pytest.importorskip("ucxx")
 
     kwargs = {"enable_tcp_over_ucx": True}
-    initialize(protocol=protocol, **kwargs)
+    initialize(**kwargs)
     with LocalCluster(
-        protocol=protocol,
+        protocol="ucx",
         dashboard_address=None,
         n_workers=1,
         threads_per_worker=1,
         processes=True,
         worker_class=IncreasedCloseTimeoutNanny,
-        config={"distributed.comm.ucx": get_ucx_config(**kwargs)},
+        config={"distributed-ucxx": get_ucx_config(**kwargs)},
     ) as cluster:
         with Client(cluster) as client:
             res = da.from_array(numpy.arange(10000), chunks=(1000,))
```
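The hunks above capture the 25.10 migration in miniature: the `ucx-old` (UCX-Py) code path is gone, UCX is now obtained via `pytest.importorskip("ucxx")`, and the comm settings move from the `distributed.comm.ucx` namespace to `distributed-ucxx`. A minimal sketch of the resulting test setup, assuming `ucxx` and `distributed_ucxx` are installed (names and kwargs are taken from the hunk, nothing else is implied):

```python
import pytest


def sketch_tcp_over_ucx_config():
    # Skip cleanly on machines without the UCXX stack.
    ucxx = pytest.importorskip("ucxx")

    from dask_cuda.initialize import initialize
    from dask_cuda.utils import get_ucx_config

    kwargs = {"enable_tcp_over_ucx": True}
    initialize(**kwargs)  # the old protocol= argument is gone

    # UCX settings now live under "distributed-ucxx",
    # not "distributed.comm.ucx".
    return {"distributed-ucxx": get_ucx_config(**kwargs)}
```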
```diff
@@ -44,7 +50,7 @@ def _test_initialize_ucx_tcp(protocol):
             assert res == 49995000
 
             def check_ucx_options():
-                conf = ucp.get_config()
+                conf = ucxx.get_config()
                 assert "TLS" in conf
                 assert "tcp" in conf["TLS"]
                 assert "cuda_copy" in conf["TLS"]
```
```diff
@@ -55,29 +61,28 @@ def _test_initialize_ucx_tcp(protocol):
             assert all(client.run(check_ucx_options).values())
 
 
-@pytest.mark.parametrize("protocol", ["ucx", "ucx-old"])
-def test_initialize_ucx_tcp(protocol):
-    get_ucx_implementation(protocol)
+def test_initialize_ucx_tcp():
+    pytest.importorskip("distributed_ucxx")
 
-    p = mp.Process(target=_test_initialize_ucx_tcp, args=(protocol,))
+    p = mp.Process(target=_test_initialize_ucx_tcp)
     p.start()
     p.join()
     assert not p.exitcode
 
 
-def _test_initialize_ucx_nvlink(protocol):
-    ucp = get_ucx_implementation(protocol)
+def _test_initialize_ucx_nvlink():
+    ucxx = pytest.importorskip("ucxx")
 
     kwargs = {"enable_nvlink": True}
-    initialize(protocol=protocol, **kwargs)
+    initialize(**kwargs)
     with LocalCluster(
-        protocol=protocol,
+        protocol="ucx",
         dashboard_address=None,
         n_workers=1,
         threads_per_worker=1,
         processes=True,
         worker_class=IncreasedCloseTimeoutNanny,
-        config={"distributed.comm.ucx": get_ucx_config(**kwargs)},
+        config={"distributed-ucxx": get_ucx_config(**kwargs)},
     ) as cluster:
         with Client(cluster) as client:
             res = da.from_array(numpy.arange(10000), chunks=(1000,))
@@ -85,7 +90,7 @@ def _test_initialize_ucx_nvlink(protocol):
             assert res == 49995000
 
             def check_ucx_options():
-                conf = ucp.get_config()
+                conf = ucxx.get_config()
                 assert "TLS" in conf
                 assert "cuda_ipc" in conf["TLS"]
                 assert "tcp" in conf["TLS"]
@@ -97,29 +102,28 @@ def _test_initialize_ucx_nvlink(protocol):
             assert all(client.run(check_ucx_options).values())
 
 
-@pytest.mark.parametrize("protocol", ["ucx", "ucx-old"])
-def test_initialize_ucx_nvlink(protocol):
-    get_ucx_implementation(protocol)
+def test_initialize_ucx_nvlink():
+    pytest.importorskip("distributed_ucxx")
 
-    p = mp.Process(target=_test_initialize_ucx_nvlink, args=(protocol,))
+    p = mp.Process(target=_test_initialize_ucx_nvlink)
     p.start()
     p.join()
     assert not p.exitcode
 
 
-def _test_initialize_ucx_infiniband(protocol):
-    ucp = get_ucx_implementation(protocol)
+def _test_initialize_ucx_infiniband():
+    ucxx = pytest.importorskip("ucxx")
 
     kwargs = {"enable_infiniband": True}
-    initialize(protocol=protocol, **kwargs)
+    initialize(**kwargs)
     with LocalCluster(
-        protocol=protocol,
+        protocol="ucx",
         dashboard_address=None,
         n_workers=1,
         threads_per_worker=1,
         processes=True,
         worker_class=IncreasedCloseTimeoutNanny,
-        config={"distributed.comm.ucx": get_ucx_config(**kwargs)},
+        config={"distributed-ucxx": get_ucx_config(**kwargs)},
     ) as cluster:
         with Client(cluster) as client:
             res = da.from_array(numpy.arange(10000), chunks=(1000,))
@@ -127,7 +131,7 @@ def _test_initialize_ucx_infiniband(protocol):
             assert res == 49995000
 
             def check_ucx_options():
-                conf = ucp.get_config()
+                conf = ucxx.get_config()
                 assert "TLS" in conf
                 assert "rc" in conf["TLS"]
                 assert "tcp" in conf["TLS"]
@@ -142,28 +146,27 @@ def _test_initialize_ucx_infiniband(protocol):
 @pytest.mark.skipif(
     "ib0" not in psutil.net_if_addrs(), reason="Infiniband interface ib0 not found"
 )
-@pytest.mark.parametrize("protocol", ["ucx", "ucx-old"])
-def test_initialize_ucx_infiniband(protocol):
-    get_ucx_implementation(protocol)
+def test_initialize_ucx_infiniband():
+    pytest.importorskip("distributed_ucxx")
 
-    p = mp.Process(target=_test_initialize_ucx_infiniband, args=(protocol,))
+    p = mp.Process(target=_test_initialize_ucx_infiniband)
     p.start()
     p.join()
     assert not p.exitcode
 
 
-def _test_initialize_ucx_all(protocol):
-    ucp = get_ucx_implementation(protocol)
+def _test_initialize_ucx_all():
+    ucxx = pytest.importorskip("ucxx")
 
-    initialize(protocol=protocol)
+    initialize()
     with LocalCluster(
-        protocol=protocol,
+        protocol="ucx",
         dashboard_address=None,
         n_workers=1,
         threads_per_worker=1,
         processes=True,
         worker_class=IncreasedCloseTimeoutNanny,
-        config={"distributed.comm.ucx": get_ucx_config()},
+        config={"distributed-ucxx": get_ucx_config()},
     ) as cluster:
         with Client(cluster) as client:
             res = da.from_array(numpy.arange(10000), chunks=(1000,))
@@ -171,7 +174,7 @@ def _test_initialize_ucx_all(protocol):
             assert res == 49995000
 
             def check_ucx_options():
-                conf = ucp.get_config()
+                conf = ucxx.get_config()
                 assert "TLS" in conf
                 assert conf["TLS"] == "all"
                 assert all(
@@ -186,11 +189,10 @@ def _test_initialize_ucx_all(protocol):
             assert all(client.run(check_ucx_options).values())
 
 
-@pytest.mark.parametrize("protocol", ["ucx", "ucx-old"])
-def test_initialize_ucx_all(protocol):
-    get_ucx_implementation(protocol)
+def test_initialize_ucx_all():
+    pytest.importorskip("distributed_ucxx")
 
-    p = mp.Process(target=_test_initialize_ucx_all, args=(protocol,))
+    p = mp.Process(target=_test_initialize_ucx_all)
     p.start()
     p.join()
     assert not p.exitcode
```
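All four tests share one harness: the module re-binds `mp` to the "spawn" context at import, each public test runs its `_test_*` body in a fresh interpreter so UCX and CUDA state cannot leak between tests, and a non-zero exit code surfaces any assertion that failed in the child. A minimal standalone sketch of that pattern (the `_body` function stands in for the cluster assertions above):

```python
import multiprocessing as mp

mp = mp.get_context("spawn")  # fresh interpreter, no inherited CUDA/UCX state


def _body():
    assert 1 + 1 == 2  # stands in for the LocalCluster assertions


def run_isolated():
    p = mp.Process(target=_body)
    p.start()
    p.join()
    # A failed assert in the child exits non-zero; surface it here.
    assert not p.exitcode


if __name__ == "__main__":
    run_isolated()
```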
```diff
@@ -229,3 +231,190 @@ def test_dask_cuda_import():
     p.start()
     p.join()
     assert not p.exitcode
+
+
+def _test_cuda_context_warning_with_subprocess_warnings(protocol):
+    """Test CUDA context warnings from both parent and worker subprocesses.
+
+    This test creates a standalone script that imports a problematic library
+    and creates LocalCUDACluster with processes=True. This should generate
+    warnings from both the parent process and each worker subprocess, since
+    they all inherit the CUDA context created at import time.
+    """
+    # Create temporary directory for our test files
+    temp_dir = tempfile.mkdtemp()
+
+    # Create the problematic library that creates CUDA context at import
+    problematic_library_code = textwrap.dedent(
+        """
+        # Problematic library that creates CUDA context at import time
+        import os
+
+        import cuda.core.experimental
+
+        try:
+            # Create CUDA context at import time, this will be inherited by subprocesses
+            cuda.core.experimental.Device().set_current()
+            print("Problematic library: Created CUDA context at import time")
+            os.environ['SUBPROCESS_CUDA_CONTEXT_CREATED'] = '1'
+        except Exception as e:
+            raise RuntimeError(
+                f"Problematic library: Failed to create CUDA context({e})"
+            )
+            os.environ['SUBPROCESS_CUDA_CONTEXT_CREATED'] = '0'
+        """
+    )
+
+    problematic_lib_path = os.path.join(temp_dir, "problematic_cuda_library.py")
+    with open(problematic_lib_path, "w") as f:
+        f.write(problematic_library_code)
+
+    # Create the main test script that imports the problematic library
+    # and creates LocalCUDACluster - this will run in a subprocess
+    main_script_code = textwrap.dedent(
+        f"""
+        # Main script that demonstrates the real-world problematic scenario
+        import os
+        import sys
+        import logging
+
+        # Add the temp directory to path so we can import our problematic library
+        sys.path.insert(0, '{temp_dir}')
+
+        print("=== Starting subprocess warnings test ===")
+
+        # This is the key part: import the problematic library BEFORE creating
+        # LocalCUDACluster. This creates a CUDA context that will be inherited
+        # by all worker subprocesses
+        print("Importing problematic library...")
+        import problematic_cuda_library
+
+        context_mode = os.environ.get('SUBPROCESS_CUDA_CONTEXT_CREATED', None)
+        if context_mode == "1":
+            print(f"Context creation successful")
+        else:
+            raise RuntimeError("Context creation failed")
+
+        if __name__ == "__main__":
+            try:
+                from dask_cuda import LocalCUDACluster
+                from dask_cuda.utils_test import IncreasedCloseTimeoutNanny
+
+                cluster = LocalCUDACluster(
+                    dashboard_address=None,
+                    worker_class=IncreasedCloseTimeoutNanny,
+                    protocol=f"{protocol}",
+                )
+                print("LocalCUDACluster created successfully!")
+
+                cluster.close()
+                print("Cluster closed successfully")
+
+            except Exception as e:
+                raise RuntimeError(f"Cluster setup error: {{e}}")
+
+        print("=== Subprocess warnings test completed ===")
+        """
+    )
+
+    main_script_path = os.path.join(temp_dir, "test_subprocess_warnings.py")
+    with open(main_script_path, "w") as f:
+        f.write(main_script_code)
+
+    try:
+        # Run the main script in a subprocess
+        result = subprocess.run(
+            [sys.executable, main_script_path],
+            capture_output=True,
+            text=True,
+            timeout=30,  # Reduced timeout for simpler test
+            cwd=os.getcwd(),
+        )
+
+        # Check for successful test execution regardless of warnings
+        assert (
+            "Context creation successful" in result.stdout
+        ), "Test did not create a CUDA context"
+        assert (
+            "Creating LocalCUDACluster" in result.stdout
+            or "LocalCUDACluster created successfully" in result.stdout
+        ), "LocalCUDACluster was not created"
+
+        # Check the log file for warnings from multiple processes
+        warnings_found = []
+        warnings_assigned_device_found = []
+
+        # Look for CUDA context warnings from different processes
+        lines = result.stderr.split("\n")
+        for line in lines:
+            if "A CUDA context for device" in line and "already exists" in line:
+                warnings_found.append(line)
+            if (
+                "should have a CUDA context assigned to device" in line
+                and "but instead the CUDA context is on device" in line
+            ):
+                warnings_assigned_device_found.append(line)
+
+        num_devices = cuda.core.experimental.system.num_devices
+
+        # Every worker raises the warning once. With protocol="ucx" the warning is
+        # raised once more by the parent process.
+        expected_warnings = num_devices if protocol == "tcp" else num_devices + 1
+        assert len(warnings_found) == expected_warnings, (
+            f"Expected {expected_warnings} CUDA context warnings, "
+            f"but found {len(warnings_assigned_device_found)}"
+        )
+
+        # Can only be tested in multi-GPU test environment, device 0 can never raise
+        # this warning (because it's where all CUDA contexts are created), thus one
+        # warning is raised by every device except 0.
+        expected_assigned_device_warnings = num_devices - 1
+        assert (
+            len(warnings_assigned_device_found) == expected_assigned_device_warnings
+        ), (
+            f"Expected {expected_assigned_device_warnings} warnings assigned to "
+            f"device, but found {len(warnings_assigned_device_found)}"
+        )
+
+        # Verify warnings contents
+        for warning in warnings_found:
+            assert (
+                "This is often the result of a CUDA-enabled library calling a "
+                "CUDA runtime function before Dask-CUDA" in warning
+            ), f"Warning missing explanatory text: {warning}"
+        for warning in warnings_assigned_device_found:
+            assert (
+                "This is often the result of a CUDA-enabled library calling a "
+                "CUDA runtime function before Dask-CUDA" in warning
+            ), f"Warning missing explanatory text: {warning}"
+    finally:
+        # Clean up temporary files
+        try:
+            if os.path.exists(temp_dir):
+                shutil.rmtree(temp_dir)
+        except Exception as e:
+            print(f"Cleanup error: {e}")
+
+
+@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
+def test_cuda_context_warning_with_subprocess_warnings(protocol):
+    """Test CUDA context warnings from parent and worker subprocesses.
+
+    This test creates a standalone script that imports a problematic library at the top
+    level and then creates LocalCUDACluster with processes=True. This replicates the
+    exact real-world scenario where:
+
+    1. User imports a problematic library that creates CUDA context at import time
+    2. User creates LocalCUDACluster with multiple workers
+    3. Each worker subprocess inherits the CUDA context and emits warnings
+    4. Multiple warnings are generated (parent process + each worker subprocess)
+
+    This is the ultimate test as it demonstrates the distributed warning scenario
+    that users actually encounter in production.
+    """
+    p = mp.Process(
+        target=_test_cuda_context_warning_with_subprocess_warnings, args=(protocol,)
+    )
+    p.start()
+    p.join()
+    assert not p.exitcode
```
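The heart of the new test is the warning arithmetic: every worker inherits the import-time context and warns once, UCX adds one parent-process warning, and device 0 can never mismatch its assigned device because the inherited context lives there. A worked example for a hypothetical 8-GPU node (the counts follow the expressions in the hunk; `num_devices` comes from `cuda.core.experimental.system.num_devices` in the real test):

```python
num_devices = 8  # hypothetical 8-GPU node

for protocol in ("tcp", "ucx"):
    # Each worker warns once; with UCX the parent process also warns.
    expected_warnings = num_devices if protocol == "tcp" else num_devices + 1
    # Device 0 holds the inherited context, so only devices 1..N-1 mismatch.
    expected_assigned_device_warnings = num_devices - 1
    print(protocol, expected_warnings, expected_assigned_device_warnings)

# tcp 8 7
# ucx 9 7
```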
```diff
@@ -24,7 +24,7 @@ from dask_cuda.utils import (
     has_device_memory_resource,
     print_cluster_config,
 )
-from dask_cuda.utils_test import MockWorker, get_ucx_implementation
+from dask_cuda.utils_test import MockWorker
 
 
 @gen_test(timeout=20)
@@ -93,53 +93,39 @@ async def test_with_subset_of_cuda_visible_devices():
     }
 
 
-@pytest.mark.parametrize(
-    "protocol",
-    ["ucx", "ucx-old"],
-)
 @gen_test(timeout=20)
-async def test_ucx_protocol(protocol):
-    get_ucx_implementation(protocol)
+async def test_ucx_protocol():
+    pytest.importorskip("distributed_ucxx")
 
     async with LocalCUDACluster(
-        protocol=protocol, asynchronous=True, data=dict
+        protocol="ucx", asynchronous=True, data=dict
     ) as cluster:
         assert all(
-            ws.address.startswith(f"{protocol}://")
-            for ws in cluster.scheduler.workers.values()
+            ws.address.startswith("ucx://") for ws in cluster.scheduler.workers.values()
         )
 
 
-@pytest.mark.parametrize(
-    "protocol",
-    ["ucx", "ucx-old"],
-)
 @gen_test(timeout=20)
-async def test_explicit_ucx_with_protocol_none(protocol):
-    get_ucx_implementation(protocol)
+async def test_explicit_ucx_with_protocol_none():
+    pytest.importorskip("distributed_ucxx")
 
-    initialize(protocol=protocol, enable_tcp_over_ucx=True)
+    initialize(enable_tcp_over_ucx=True)
     async with LocalCUDACluster(
         protocol=None,
         enable_tcp_over_ucx=True,
         asynchronous=True,
     ) as cluster:
         assert all(
-            ws.address.startswith(f"{protocol}://")
-            for ws in cluster.scheduler.workers.values()
+            ws.address.startswith("ucx://") for ws in cluster.scheduler.workers.values()
         )
 
 
 @pytest.mark.filterwarnings("ignore:Exception ignored in")
-@pytest.mark.parametrize(
-    "protocol",
-    ["ucx", "ucx-old"],
-)
 @gen_test(timeout=20)
-async def test_ucx_protocol_type_error(protocol):
-    get_ucx_implementation(protocol)
+async def test_ucx_protocol_type_error():
+    pytest.importorskip("distributed_ucxx")
 
-    initialize(protocol=protocol, enable_tcp_over_ucx=True)
+    initialize(enable_tcp_over_ucx=True)
     with pytest.raises(TypeError):
         async with LocalCUDACluster(
             protocol="tcp", enable_tcp_over_ucx=True, asynchronous=True, data=dict
```
```diff
@@ -602,10 +588,6 @@ async def test_cudf_spill_no_dedicated_memory():
     )
 
 
-@pytest.mark.parametrize(
-    "protocol",
-    ["ucx", "ucx-old"],
-)
 @pytest.mark.parametrize(
     "jit_unspill",
     [False, True],
@@ -614,8 +596,8 @@ async def test_cudf_spill_no_dedicated_memory():
     "device_memory_limit",
     [None, "1B"],
 )
-def test_print_cluster_config(capsys, protocol, jit_unspill, device_memory_limit):
-    get_ucx_implementation(protocol)
+def test_print_cluster_config(capsys, jit_unspill, device_memory_limit):
+    pytest.importorskip("distributed_ucxx")
 
     pytest.importorskip("rich")
 
@@ -640,46 +622,18 @@ def test_print_cluster_config(capsys, protocol, jit_unspill, device_memory_limit
         n_workers=1,
         device_memory_limit=device_memory_limit,
         jit_unspill=jit_unspill,
-        protocol=protocol,
+        protocol="ucx",
     ) as cluster:
         with Client(cluster) as client:
             print_cluster_config(client)
             captured = capsys.readouterr()
             assert "Dask Cluster Configuration" in captured.out
-            assert protocol in captured.out
+            assert "ucx" in captured.out
             if device_memory_limit == "1B":
                 assert "1 B" in captured.out
             assert "[plugin]" in captured.out
             client.shutdown()
 
-    def ucxpy_reset(timeout=20):
-        """Reset UCX-Py with a timeout.
-
-        Attempt to reset UCX-Py, not doing so may cause a deadlock because UCX-Py is
-        not thread-safe and the Dask cluster may still be alive while a new cluster
-        and UCX-Py instances are initalized.
-        """
-        import time
-
-        import ucp
-
-        start = time.monotonic()
-        while True:
-            try:
-                ucp.reset()
-            except ucp._libs.exceptions.UCXError as e:
-                if time.monotonic() - start > timeout:
-                    raise RuntimeError(
-                        f"Could not reset UCX-Py in {timeout} seconds, this may result "
-                        f"in a deadlock. Failure:\n{e}"
-                    )
-                continue
-            else:
-                break
-
-    if protocol == "ucx-old":
-        ucxpy_reset()
-
 
 @pytest.mark.xfail(reason="https://github.com/rapidsai/dask-cuda/issues/1265")
 def test_death_timeout_raises():
```
```diff
@@ -448,14 +448,20 @@ async def test_worker_force_spill_to_disk():
     """Test Dask triggering CPU-to-Disk spilling"""
     cudf = pytest.importorskip("cudf")
 
+    def create_dataframe():
+        return cudf.DataFrame({"key": np.arange(10**8)})
+
     with dask.config.set({"distributed.worker.memory.terminate": False}):
         async with dask_cuda.LocalCUDACluster(
             n_workers=1, device_memory_limit="1MB", jit_unspill=True, asynchronous=True
         ) as cluster:
             async with Client(cluster, asynchronous=True) as client:
                 # Create a df that are spilled to host memory immediately
-                df = cudf.DataFrame({"key": np.arange(10**8)})
-                [ddf] = client.persist([dask.dataframe.from_pandas(df, npartitions=1)])
+                ddf = dask.dataframe.from_delayed(
+                    dask.delayed(create_dataframe)(),
+                    meta=cudf.DataFrame({"key": cupy.arange(0)}),
+                )
+                [ddf] = client.persist([ddf])
                 await ddf
 
                 async def f(dask_worker):
```
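This spill-test change moves construction of the large frame off the client: building it eagerly with `cudf.DataFrame(...)` allocated device memory in the client process, whereas wrapping the constructor in `dask.delayed` defers allocation to the worker, with an empty frame of the same schema supplied as `meta`. The pattern in isolation (mirrors the hunk; assumes cudf and cupy are importable):

```python
import cudf
import cupy
import dask
import dask.dataframe
import numpy as np


def create_dataframe():
    # Runs on a worker, not in the client process.
    return cudf.DataFrame({"key": np.arange(10**8)})


ddf = dask.dataframe.from_delayed(
    dask.delayed(create_dataframe)(),
    # An empty frame with the same schema tells Dask the output type.
    meta=cudf.DataFrame({"key": cupy.arange(0)}),
)
```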
```diff
@@ -466,13 +472,12 @@ async def test_worker_force_spill_to_disk():
                     memory = w.monitor.proc.memory_info().rss
                     w.memory_manager.memory_limit = memory - 10**8
                     w.memory_manager.memory_target_fraction = 1
-                    print(w.memory_manager.data)
                     await w.memory_manager.memory_monitor(w)
                     # Check that host memory are freed
                     assert w.monitor.proc.memory_info().rss < memory - 10**7
                     w.memory_manager.memory_limit = memory * 10  # Un-limit
 
-                client.run(f)
+                await client.run(f)
                 log = str(await client.get_worker_logs())
                 # Check that the worker doesn't complain about unmanaged memory
                 assert "Unmanaged memory use is high" not in log
```
```diff
@@ -26,7 +26,7 @@ from dask_cuda import LocalCUDACluster, proxy_object
 from dask_cuda.disk_io import SpillToDiskFile
 from dask_cuda.proxify_device_objects import proxify_device_objects
 from dask_cuda.proxify_host_file import ProxifyHostFile
-from dask_cuda.utils_test import IncreasedCloseTimeoutNanny, get_ucx_implementation
+from dask_cuda.utils_test import IncreasedCloseTimeoutNanny
 
 # Make the "disk" serializer available and use a directory that are
 # remove on exit.
@@ -407,12 +407,12 @@ class _PxyObjTest(proxy_object.ProxyObject):
 
 
 @pytest.mark.parametrize("send_serializers", [None, ("dask", "pickle"), ("cuda",)])
-@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucx-old"])
+@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
 @gen_test(timeout=120)
 async def test_communicating_proxy_objects(protocol, send_serializers):
     """Testing serialization of cuDF dataframe when communicating"""
-    if protocol.startswith("ucx"):
-        get_ucx_implementation(protocol)
+    if protocol == "ucx":
+        pytest.importorskip("distributed_ucxx")
     cudf = pytest.importorskip("cudf")
 
     def task(x):
@@ -421,7 +421,7 @@ async def test_communicating_proxy_objects(protocol, send_serializers):
         serializers_used = x._pxy_get().serializer
 
         # Check that `x` is serialized with the expected serializers
-        if protocol in ["ucx", "ucx-old"]:
+        if protocol == "ucx":
             if send_serializers is None:
                 assert serializers_used == "cuda"
             else:
@@ -452,13 +452,13 @@ async def test_communicating_proxy_objects(protocol, send_serializers):
             await client.submit(task, df)
 
 
-@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucx-old"])
+@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
 @pytest.mark.parametrize("shared_fs", [True, False])
 @gen_test(timeout=20)
 async def test_communicating_disk_objects(protocol, shared_fs):
     """Testing disk serialization of cuDF dataframe when communicating"""
-    if protocol.startswith("ucx"):
-        get_ucx_implementation(protocol)
+    if protocol == "ucx":
+        pytest.importorskip("distributed_ucxx")
     cudf = pytest.importorskip("cudf")
     ProxifyHostFile._spill_to_disk.shared_filesystem = shared_fs
 
```
```diff
@@ -15,10 +15,10 @@ from distributed.sizeof import sizeof
 from distributed.utils import Deadline
 from distributed.utils_test import gen_cluster, gen_test, loop  # noqa: F401
 
-import dask_cudf
+dask_cudf = pytest.importorskip("dask_cudf")
 
-from dask_cuda import LocalCUDACluster, utils
-from dask_cuda.utils_test import IncreasedCloseTimeoutNanny
+from dask_cuda import LocalCUDACluster, utils  # noqa: E402
+from dask_cuda.utils_test import IncreasedCloseTimeoutNanny  # noqa: E402
 
 if not utils.has_device_memory_resource():
     pytest.skip(
```
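This last hunk swaps a hard `import dask_cudf` for a module-level guard. `pytest.importorskip` returns the imported module, so the assignment both skips collection of the whole file when `dask_cudf` is absent and keeps the name bound for the tests; the imports that now follow executable code get `# noqa: E402`. The pattern in isolation:

```python
import pytest

# Skips collection of this entire module when dask_cudf is not installed,
# and binds the module object when it is.
dask_cudf = pytest.importorskip("dask_cudf")

from dask_cuda import LocalCUDACluster  # noqa: E402  (import after code)
```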