dask-cuda 25.4.0__py3-none-any.whl → 25.8.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in their respective public registries.
Files changed (53) hide show
  1. dask_cuda/GIT_COMMIT +1 -1
  2. dask_cuda/VERSION +1 -1
  3. dask_cuda/_compat.py +18 -0
  4. dask_cuda/benchmarks/common.py +4 -1
  5. dask_cuda/benchmarks/local_cudf_groupby.py +4 -1
  6. dask_cuda/benchmarks/local_cudf_merge.py +5 -2
  7. dask_cuda/benchmarks/local_cudf_shuffle.py +5 -2
  8. dask_cuda/benchmarks/local_cupy.py +4 -1
  9. dask_cuda/benchmarks/local_cupy_map_overlap.py +4 -1
  10. dask_cuda/benchmarks/utils.py +7 -4
  11. dask_cuda/cli.py +21 -15
  12. dask_cuda/cuda_worker.py +27 -57
  13. dask_cuda/device_host_file.py +31 -15
  14. dask_cuda/disk_io.py +7 -4
  15. dask_cuda/explicit_comms/comms.py +11 -7
  16. dask_cuda/explicit_comms/dataframe/shuffle.py +147 -55
  17. dask_cuda/get_device_memory_objects.py +18 -3
  18. dask_cuda/initialize.py +80 -44
  19. dask_cuda/is_device_object.py +4 -1
  20. dask_cuda/is_spillable_object.py +4 -1
  21. dask_cuda/local_cuda_cluster.py +63 -66
  22. dask_cuda/plugins.py +17 -16
  23. dask_cuda/proxify_device_objects.py +15 -10
  24. dask_cuda/proxify_host_file.py +30 -27
  25. dask_cuda/proxy_object.py +20 -17
  26. dask_cuda/tests/conftest.py +41 -0
  27. dask_cuda/tests/test_dask_cuda_worker.py +114 -27
  28. dask_cuda/tests/test_dgx.py +10 -18
  29. dask_cuda/tests/test_explicit_comms.py +51 -18
  30. dask_cuda/tests/test_from_array.py +7 -5
  31. dask_cuda/tests/test_initialize.py +16 -37
  32. dask_cuda/tests/test_local_cuda_cluster.py +164 -54
  33. dask_cuda/tests/test_proxify_host_file.py +33 -4
  34. dask_cuda/tests/test_proxy.py +18 -16
  35. dask_cuda/tests/test_rdd_ucx.py +160 -0
  36. dask_cuda/tests/test_spill.py +107 -27
  37. dask_cuda/tests/test_utils.py +106 -20
  38. dask_cuda/tests/test_worker_spec.py +5 -2
  39. dask_cuda/utils.py +319 -68
  40. dask_cuda/utils_test.py +23 -7
  41. dask_cuda/worker_common.py +196 -0
  42. dask_cuda/worker_spec.py +12 -5
  43. {dask_cuda-25.4.0.dist-info → dask_cuda-25.8.0.dist-info}/METADATA +5 -4
  44. dask_cuda-25.8.0.dist-info/RECORD +63 -0
  45. {dask_cuda-25.4.0.dist-info → dask_cuda-25.8.0.dist-info}/WHEEL +1 -1
  46. dask_cuda-25.8.0.dist-info/top_level.txt +6 -0
  47. shared-actions/check_nightly_success/check-nightly-success/check.py +148 -0
  48. shared-actions/telemetry-impls/summarize/bump_time.py +54 -0
  49. shared-actions/telemetry-impls/summarize/send_trace.py +409 -0
  50. dask_cuda-25.4.0.dist-info/RECORD +0 -56
  51. dask_cuda-25.4.0.dist-info/top_level.txt +0 -5
  52. {dask_cuda-25.4.0.dist-info → dask_cuda-25.8.0.dist-info}/entry_points.txt +0 -0
  53. {dask_cuda-25.4.0.dist-info → dask_cuda-25.8.0.dist-info}/licenses/LICENSE +0 -0
dask_cuda/proxy_object.py CHANGED
@@ -1,3 +1,6 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
1
4
  import copy as _copy
2
5
  import functools
3
6
  import operator
@@ -52,21 +55,21 @@ def asproxy(
52
55
  serializers: Optional[Iterable[str]] = None,
53
56
  subclass: Optional[Type["ProxyObject"]] = None,
54
57
  ) -> "ProxyObject":
55
- """Wrap `obj` in a ProxyObject object if it isn't already.
58
+ """Wrap ``obj`` in a ProxyObject object if it isn't already.
56
59
 
57
60
  Parameters
58
61
  ----------
59
62
  obj: object
60
63
  Object to wrap in a ProxyObject object.
61
64
  serializers: Iterable[str], optional
62
- Serializers to use to serialize `obj`. If None, no serialization is done.
65
+ Serializers to use to serialize ``obj``. If None, no serialization is done.
63
66
  subclass: class, optional
64
67
  Specify a subclass of ProxyObject to create instead of ProxyObject.
65
- `subclass` must be pickable.
68
+ ``subclass`` must be pickable.
66
69
 
67
70
  Returns
68
71
  -------
69
- The ProxyObject proxying `obj`
72
+ The ProxyObject proxying ``obj``
70
73
  """
71
74
  if isinstance(obj, ProxyObject): # Already a proxy object
72
75
  ret = obj
@@ -119,7 +122,7 @@ def unproxy(obj):
119
122
 
120
123
  Returns
121
124
  -------
122
- The proxied object or `obj` itself if it isn't a ProxyObject
125
+ The proxied object or ``obj`` itself if it isn't a ProxyObject
123
126
  """
124
127
  try:
125
128
  obj = obj._pxy_deserialize()
@@ -185,16 +188,16 @@ class ProxyDetail:
185
188
  Dictionary of attributes that are accessible without deserializing
186
189
  the proxied object.
187
190
  type_serialized: bytes
188
- Pickled type of `obj`.
191
+ Pickled type of ``obj``.
189
192
  typename: str
190
- Name of the type of `obj`.
193
+ Name of the type of ``obj``.
191
194
  is_cuda_object: boolean
192
- Whether `obj` is a CUDA object or not.
195
+ Whether ``obj`` is a CUDA object or not.
193
196
  subclass: bytes
194
197
  Pickled type to use instead of ProxyObject when deserializing. The type
195
198
  must inherit from ProxyObject.
196
199
  serializers: str, optional
197
- Serializers to use to serialize `obj`. If None, no serialization is done.
200
+ Serializers to use to serialize ``obj``. If None, no serialization is done.
198
201
  explicit_proxy: bool
199
202
  Mark the proxy object as "explicit", which means that the user allows it
200
203
  as input argument to dask tasks even in compatibility-mode.
@@ -258,7 +261,7 @@ class ProxyDetail:
258
261
  return self.serializer is not None
259
262
 
260
263
  def serialize(self, serializers: Iterable[str]) -> Tuple[dict, list]:
261
- """Inplace serialization of the proxied object using the `serializers`
264
+ """Inplace serialization of the proxied object using the ``serializers``
262
265
 
263
266
  Parameters
264
267
  ----------
@@ -333,7 +336,7 @@ class ProxyObject:
333
336
  ProxyObject has some limitations and doesn't mimic the proxied object perfectly.
334
337
  Thus, if encountering problems remember that it is always possible to use unproxy()
335
338
  to access the proxied object directly or disable JIT deserialization completely
336
- with `jit_unspill=False`.
339
+ with ``jit_unspill=False``.
337
340
 
338
341
  Type checking using instance() works as expected but direct type checking
339
342
  doesn't:
@@ -386,7 +389,7 @@ class ProxyObject:
386
389
  serializers: Iterable[str],
387
390
  proxy_detail: Optional[ProxyDetail] = None,
388
391
  ) -> None:
389
- """Inplace serialization of the proxied object using the `serializers`
392
+ """Inplace serialization of the proxied object using the ``serializers``
390
393
 
391
394
  Parameters
392
395
  ----------
@@ -787,8 +790,8 @@ class ProxyObject:
787
790
  def obj_pxy_is_device_object(obj: ProxyObject):
788
791
  """
789
792
  In order to avoid de-serializing the proxied object,
790
- we check `is_cuda_object` instead of the default
791
- `hasattr(o, "__cuda_array_interface__")` check.
793
+ we check ``is_cuda_object`` instead of the default
794
+ ``hasattr(o, "__cuda_array_interface__")`` check.
792
795
  """
793
796
  return obj._pxy_get().is_cuda_object
794
797
 
@@ -830,7 +833,7 @@ def obj_pxy_dask_serialize(obj: ProxyObject):
830
833
 
831
834
  As serializers, it uses "dask" or "pickle", which means that proxied CUDA objects
832
835
  are spilled to main memory before communicated. Deserialization is needed, unless
833
- obj is serialized to disk on a shared filesystem see `handle_disk_serialized()`.
836
+ obj is serialized to disk on a shared filesystem see ``handle_disk_serialized()``.
834
837
  """
835
838
  pxy = obj._pxy_get(copy=True)
836
839
  if pxy.serializer == "disk":
@@ -851,7 +854,7 @@ def obj_pxy_cuda_serialize(obj: ProxyObject):
851
854
 
852
855
  As serializers, it uses "cuda", which means that proxied CUDA objects are _not_
853
856
  spilled to main memory before communicated. However, we still have to handle disk
854
- serialized proxied like in `obj_pxy_dask_serialize()`
857
+ serialized proxied like in ``obj_pxy_dask_serialize()``
855
858
  """
856
859
  pxy = obj._pxy_get(copy=True)
857
860
  if pxy.serializer in ("dask", "pickle"):
@@ -897,7 +900,7 @@ def obj_pxy_dask_deserialize(header, frames):
897
900
 
898
901
 
899
902
  def unproxify_input_wrapper(func):
900
- """Unproxify the input of `func`"""
903
+ """Unproxify the input of ``func``"""
901
904
 
902
905
  @functools.wraps(func)
903
906
  def wrapper(*args, **kwargs):
@@ -0,0 +1,41 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import pytest
5
+
6
+ from dask_cuda.utils import has_device_memory_resource
7
+
8
+
9
+ def pytest_configure(config):
10
+ """Register custom markers."""
11
+ config.addinivalue_line(
12
+ "markers",
13
+ "skip_if_no_device_memory: mark test to skip if device has no dedicated memory "
14
+ "resource",
15
+ )
16
+ config.addinivalue_line(
17
+ "markers",
18
+ "skip_if_device_memory: mark test to skip if device has dedicated memory "
19
+ "resource",
20
+ )
21
+
22
+
23
+ def pytest_collection_modifyitems(items):
24
+ """Handle skip_if_no_device_memory marker."""
25
+ for item in items:
26
+ if item.get_closest_marker("skip_if_no_device_memory"):
27
+ skip_marker = item.get_closest_marker("skip_if_no_device_memory")
28
+ reason = skip_marker.kwargs.get(
29
+ "reason", "Test requires device with dedicated memory resource"
30
+ )
31
+ item.add_marker(
32
+ pytest.mark.skipif(not has_device_memory_resource(), reason=reason)
33
+ )
34
+ if item.get_closest_marker("skip_if_device_memory"):
35
+ skip_marker = item.get_closest_marker("skip_if_device_memory")
36
+ reason = skip_marker.kwargs.get(
37
+ "reason", "Test requires device without dedicated memory resource"
38
+ )
39
+ item.add_marker(
40
+ pytest.mark.skipif(has_device_memory_resource(), reason=reason)
41
+ )
@@ -1,3 +1,6 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
1
4
  from __future__ import absolute_import, division, print_function
2
5
 
3
6
  import os
@@ -16,15 +19,18 @@ from dask_cuda.utils import (
16
19
  get_cluster_configuration,
17
20
  get_device_total_memory,
18
21
  get_gpu_count_mig,
19
- get_gpu_uuid_from_index,
22
+ get_gpu_uuid,
20
23
  get_n_gpus,
24
+ has_device_memory_resource,
21
25
  wait_workers,
22
26
  )
23
27
 
24
28
 
25
- @patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,3,7,8"})
26
- def test_cuda_visible_devices_and_memory_limit_and_nthreads(loop): # noqa: F811
27
- nthreads = 4
29
+ @patch.dict(
30
+ os.environ,
31
+ {"CUDA_VISIBLE_DEVICES": "0,3,7,8", "DASK_CUDA_TEST_DISABLE_DEVICE_SPECIFIC": "1"},
32
+ )
33
+ def test_cuda_visible_devices(loop): # noqa: F811
28
34
  with popen(["dask", "scheduler", "--port", "9359", "--no-dashboard"]):
29
35
  with popen(
30
36
  [
@@ -34,14 +40,10 @@ def test_cuda_visible_devices_and_memory_limit_and_nthreads(loop): # noqa: F811
34
40
  "127.0.0.1:9359",
35
41
  "--host",
36
42
  "127.0.0.1",
37
- "--device-memory-limit",
38
- "1 MB",
39
- "--nthreads",
40
- str(nthreads),
41
43
  "--no-dashboard",
42
44
  "--worker-class",
43
45
  "dask_cuda.utils_test.MockWorker",
44
- ]
46
+ ],
45
47
  ):
46
48
  with Client("127.0.0.1:9359", loop=loop) as client:
47
49
  assert wait_workers(client, n_gpus=4)
@@ -55,12 +57,43 @@ def test_cuda_visible_devices_and_memory_limit_and_nthreads(loop): # noqa: F811
55
57
  for v in result.values():
56
58
  del expected[v]
57
59
 
58
- workers = client.scheduler_info()["workers"]
60
+ assert len(expected) == 0
61
+
62
+
63
+ def test_memory_limit_and_nthreads(loop): # noqa: F811
64
+ nthreads = 4
65
+
66
+ device_memory_limit_args = []
67
+ if has_device_memory_resource():
68
+ device_memory_limit_args = ["--device-memory-limit", "1 MB"]
69
+
70
+ with popen(["dask", "scheduler", "--port", "9359", "--no-dashboard"]):
71
+ with popen(
72
+ [
73
+ "dask",
74
+ "cuda",
75
+ "worker",
76
+ "127.0.0.1:9359",
77
+ "--host",
78
+ "127.0.0.1",
79
+ *device_memory_limit_args,
80
+ "--nthreads",
81
+ str(nthreads),
82
+ "--no-dashboard",
83
+ "--worker-class",
84
+ "dask_cuda.utils_test.MockWorker",
85
+ ],
86
+ ):
87
+ with Client("127.0.0.1:9359", loop=loop) as client:
88
+ assert wait_workers(client, n_gpus=get_n_gpus())
89
+
90
+ def get_visible_devices():
91
+ return os.environ["CUDA_VISIBLE_DEVICES"]
92
+
93
+ workers = client.scheduler_info(n_workers=-1)["workers"]
59
94
  for w in workers.values():
60
95
  assert w["memory_limit"] == MEMORY_LIMIT // len(workers)
61
96
 
62
- assert len(expected) == 0
63
-
64
97
 
65
98
  def test_rmm_pool(loop): # noqa: F811
66
99
  rmm = pytest.importorskip("rmm")
@@ -116,11 +149,6 @@ def test_rmm_managed(loop): # noqa: F811
116
149
  def test_rmm_async(loop): # noqa: F811
117
150
  rmm = pytest.importorskip("rmm")
118
151
 
119
- driver_version = rmm._cuda.gpu.driverGetVersion()
120
- runtime_version = rmm._cuda.gpu.runtimeGetVersion()
121
- if driver_version < 11020 or runtime_version < 11020:
122
- pytest.skip("cudaMallocAsync not supported")
123
-
124
152
  with popen(["dask", "scheduler", "--port", "9369", "--no-dashboard"]):
125
153
  with popen(
126
154
  [
@@ -156,11 +184,6 @@ def test_rmm_async(loop): # noqa: F811
156
184
  def test_rmm_async_with_maximum_pool_size(loop): # noqa: F811
157
185
  rmm = pytest.importorskip("rmm")
158
186
 
159
- driver_version = rmm._cuda.gpu.driverGetVersion()
160
- runtime_version = rmm._cuda.gpu.runtimeGetVersion()
161
- if driver_version < 11020 or runtime_version < 11020:
162
- pytest.skip("cudaMallocAsync not supported")
163
-
164
187
  with popen(["dask", "scheduler", "--port", "9369", "--no-dashboard"]):
165
188
  with popen(
166
189
  [
@@ -260,8 +283,12 @@ def test_cudf_spill_disabled(loop): # noqa: F811
260
283
  assert v == 0
261
284
 
262
285
 
286
+ @pytest.mark.skip_if_no_device_memory(
287
+ "Devices without dedicated memory resources cannot enable cuDF spill"
288
+ )
263
289
  def test_cudf_spill(loop): # noqa: F811
264
290
  cudf = pytest.importorskip("cudf")
291
+
265
292
  with popen(["dask", "scheduler", "--port", "9369", "--no-dashboard"]):
266
293
  with popen(
267
294
  [
@@ -289,6 +316,24 @@ def test_cudf_spill(loop): # noqa: F811
289
316
  assert v == 2
290
317
 
291
318
 
319
+ @pytest.mark.skip_if_device_memory(
320
+ "Devices with dedicated memory resources cannot test error"
321
+ )
322
+ def test_cudf_spill_no_dedicated_memory_error():
323
+ pytest.importorskip("cudf")
324
+
325
+ ret = subprocess.run(
326
+ ["dask", "cuda", "worker", "127.0.0.1:9369", "--enable-cudf-spill"],
327
+ capture_output=True,
328
+ )
329
+
330
+ assert ret.returncode != 0
331
+ assert (
332
+ b"cuDF spilling is not supported on devices without dedicated memory"
333
+ in ret.stderr
334
+ )
335
+
336
+
292
337
  @patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"})
293
338
  def test_dashboard_address(loop): # noqa: F811
294
339
  with popen(["dask", "scheduler", "--port", "9369", "--no-dashboard"]):
@@ -409,7 +454,7 @@ def test_cuda_mig_visible_devices_and_memory_limit_and_nthreads(loop): # noqa:
409
454
 
410
455
 
411
456
  def test_cuda_visible_devices_uuid(loop): # noqa: F811
412
- gpu_uuid = get_gpu_uuid_from_index(0)
457
+ gpu_uuid = get_gpu_uuid(0)
413
458
 
414
459
  with patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": gpu_uuid}):
415
460
  with popen(["dask", "scheduler", "--port", "9359", "--no-dashboard"]):
@@ -469,6 +514,11 @@ def test_rmm_track_allocations(loop): # noqa: F811
469
514
  @patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"})
470
515
  def test_get_cluster_configuration(loop): # noqa: F811
471
516
  pytest.importorskip("rmm")
517
+
518
+ device_memory_limit_args = []
519
+ if has_device_memory_resource():
520
+ device_memory_limit_args += ["--device-memory-limit", "30 B"]
521
+
472
522
  with popen(["dask", "scheduler", "--port", "9369", "--no-dashboard"]):
473
523
  with popen(
474
524
  [
@@ -478,8 +528,7 @@ def test_get_cluster_configuration(loop): # noqa: F811
478
528
  "127.0.0.1:9369",
479
529
  "--host",
480
530
  "127.0.0.1",
481
- "--device-memory-limit",
482
- "30 B",
531
+ *device_memory_limit_args,
483
532
  "--rmm-pool-size",
484
533
  "2 GB",
485
534
  "--rmm-maximum-pool-size",
@@ -496,12 +545,17 @@ def test_get_cluster_configuration(loop): # noqa: F811
496
545
  assert ret["[plugin] RMMSetup"]["initial_pool_size"] == 2000000000
497
546
  assert ret["[plugin] RMMSetup"]["maximum_pool_size"] == 3000000000
498
547
  assert ret["jit-unspill"] is False
499
- assert ret["device-memory-limit"] == 30
548
+ if has_device_memory_resource():
549
+ assert ret["device-memory-limit"] == 30
500
550
 
501
551
 
502
552
  @patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"})
553
+ @pytest.mark.skip_if_no_device_memory(
554
+ "Devices without dedicated memory resources do not support fractional limits"
555
+ )
503
556
  def test_worker_fraction_limits(loop): # noqa: F811
504
557
  pytest.importorskip("rmm")
558
+
505
559
  with popen(["dask", "scheduler", "--port", "9369", "--no-dashboard"]):
506
560
  with popen(
507
561
  [
@@ -542,6 +596,33 @@ def test_worker_fraction_limits(loop): # noqa: F811
542
596
  )
543
597
 
544
598
 
599
+ @pytest.mark.parametrize(
600
+ "argument", ["pool_size", "maximum_pool_size", "release_threshold"]
601
+ )
602
+ @pytest.mark.skip_if_device_memory(
603
+ "Devices with dedicated memory resources cannot test error"
604
+ )
605
+ def test_worker_fraction_limits_no_dedicated_memory(argument):
606
+ if argument == "pool_size":
607
+ argument_list = ["--rmm-pool-size", "0.1"]
608
+ elif argument == "maximum_pool_size":
609
+ argument_list = ["--rmm-pool-size", "1 GiB", "--rmm-maximum-pool-size", "0.1"]
610
+ else:
611
+ argument_list = ["--rmm-async", "--rmm-release-threshold", "0.1"]
612
+
613
+ with popen(["dask", "scheduler", "--port", "9369", "--no-dashboard"]):
614
+ ret = subprocess.run(
615
+ ["dask", "cuda", "worker", "127.0.0.1:9369", *argument_list],
616
+ capture_output=True,
617
+ )
618
+
619
+ assert ret.returncode != 0
620
+ assert (
621
+ b"Fractional of total device memory not supported in devices without a "
622
+ b"dedicated memory resource" in ret.stderr
623
+ )
624
+
625
+
545
626
  @patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"})
546
627
  def test_worker_timeout():
547
628
  ret = subprocess.run(
@@ -592,6 +673,12 @@ def test_worker_cudf_spill_warning(enable_cudf_spill_warning): # noqa: F811
592
673
  capture_output=True,
593
674
  )
594
675
  if enable_cudf_spill_warning:
595
- assert b"UserWarning: cuDF spilling is enabled" in ret.stderr
676
+ if has_device_memory_resource():
677
+ assert b"UserWarning: cuDF spilling is enabled" in ret.stderr
678
+ else:
679
+ assert (
680
+ b"cuDF spilling is not supported on devices without dedicated "
681
+ b"memory" in ret.stderr
682
+ )
596
683
  else:
597
684
  assert b"UserWarning: cuDF spilling is enabled" not in ret.stderr
@@ -1,3 +1,6 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
1
4
  import multiprocessing as mp
2
5
  import os
3
6
  from enum import Enum, auto
@@ -10,6 +13,7 @@ from distributed import Client
10
13
 
11
14
  from dask_cuda import LocalCUDACluster
12
15
  from dask_cuda.initialize import initialize
16
+ from dask_cuda.utils_test import get_ucx_implementation
13
17
 
14
18
  mp = mp.get_context("spawn") # type: ignore
15
19
  psutil = pytest.importorskip("psutil")
@@ -78,10 +82,7 @@ def test_default():
78
82
 
79
83
 
80
84
  def _test_tcp_over_ucx(protocol):
81
- if protocol == "ucx":
82
- ucp = pytest.importorskip("ucp")
83
- elif protocol == "ucxx":
84
- ucp = pytest.importorskip("ucxx")
85
+ ucp = get_ucx_implementation(protocol)
85
86
 
86
87
  with LocalCUDACluster(protocol=protocol, enable_tcp_over_ucx=True) as cluster:
87
88
  with Client(cluster) as client:
@@ -102,13 +103,10 @@ def _test_tcp_over_ucx(protocol):
102
103
 
103
104
  @pytest.mark.parametrize(
104
105
  "protocol",
105
- ["ucx", "ucxx"],
106
+ ["ucx", "ucx-old"],
106
107
  )
107
108
  def test_tcp_over_ucx(protocol):
108
- if protocol == "ucx":
109
- ucp = pytest.importorskip("ucp")
110
- elif protocol == "ucxx":
111
- ucp = pytest.importorskip("ucxx")
109
+ ucp = get_ucx_implementation(protocol)
112
110
  if _is_ucx_116(ucp):
113
111
  pytest.skip("https://github.com/rapidsai/ucx-py/issues/1037")
114
112
 
@@ -137,10 +135,7 @@ def _test_ucx_infiniband_nvlink(
137
135
  skip_queue, protocol, enable_infiniband, enable_nvlink, enable_rdmacm
138
136
  ):
139
137
  cupy = pytest.importorskip("cupy")
140
- if protocol == "ucx":
141
- ucp = pytest.importorskip("ucp")
142
- elif protocol == "ucxx":
143
- ucp = pytest.importorskip("ucxx")
138
+ ucp = get_ucx_implementation(protocol)
144
139
 
145
140
  if enable_infiniband and not any(
146
141
  [at.startswith("rc") for at in ucp.get_active_transports()]
@@ -206,7 +201,7 @@ def _test_ucx_infiniband_nvlink(
206
201
  assert all(client.run(check_ucx_options).values())
207
202
 
208
203
 
209
- @pytest.mark.parametrize("protocol", ["ucx", "ucxx"])
204
+ @pytest.mark.parametrize("protocol", ["ucx", "ucx-old"])
210
205
  @pytest.mark.parametrize(
211
206
  "params",
212
207
  [
@@ -222,10 +217,7 @@ def _test_ucx_infiniband_nvlink(
222
217
  reason="Automatic InfiniBand device detection Unsupported for %s" % _get_dgx_name(),
223
218
  )
224
219
  def test_ucx_infiniband_nvlink(protocol, params):
225
- if protocol == "ucx":
226
- ucp = pytest.importorskip("ucp")
227
- elif protocol == "ucxx":
228
- ucp = pytest.importorskip("ucxx")
220
+ ucp = get_ucx_implementation(protocol)
229
221
  if _is_ucx_116(ucp) and params["enable_infiniband"] is False:
230
222
  pytest.skip("https://github.com/rapidsai/ucx-py/issues/1037")
231
223
 
@@ -1,4 +1,5 @@
1
- # Copyright (c) 2021-2025 NVIDIA CORPORATION.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-License-Identifier: Apache-2.0
2
3
 
3
4
  import asyncio
4
5
  import multiprocessing as mp
@@ -21,18 +22,16 @@ from distributed.deploy.local import LocalCluster
21
22
 
22
23
  import dask_cuda
23
24
  from dask_cuda.explicit_comms import comms
24
- from dask_cuda.explicit_comms.dataframe.shuffle import shuffle as explicit_comms_shuffle
25
- from dask_cuda.utils_test import IncreasedCloseTimeoutNanny
25
+ from dask_cuda.explicit_comms.dataframe.shuffle import (
26
+ _contains_shuffle_expr,
27
+ shuffle as explicit_comms_shuffle,
28
+ )
29
+ from dask_cuda.utils_test import IncreasedCloseTimeoutNanny, get_ucx_implementation
26
30
 
27
31
  mp = mp.get_context("spawn") # type: ignore
28
32
  ucp = pytest.importorskip("ucp")
29
33
 
30
34
 
31
- # Set default shuffle method to "tasks"
32
- if dask.config.get("dataframe.shuffle.method", None) is None:
33
- dask.config.set({"dataframe.shuffle.method": "tasks"})
34
-
35
-
36
35
  # Notice, all of the following tests is executed in a new process such
37
36
  # that UCX options of the different tests doesn't conflict.
38
37
 
@@ -55,8 +54,10 @@ def _test_local_cluster(protocol):
55
54
  assert sum(c.run(my_rank, 0)) == sum(range(4))
56
55
 
57
56
 
58
- @pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
57
+ @pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucx-old"])
59
58
  def test_local_cluster(protocol):
59
+ if protocol.startswith("ucx"):
60
+ get_ucx_implementation(protocol)
60
61
  p = mp.Process(target=_test_local_cluster, args=(protocol,))
61
62
  p.start()
62
63
  p.join()
@@ -99,7 +100,7 @@ def test_dataframe_merge_empty_partitions():
99
100
 
100
101
 
101
102
  def check_partitions(df, npartitions):
102
- """Check that all values in `df` hashes to the same"""
103
+ """Check that all values in ``df`` hashes to the same"""
103
104
  dtypes = {}
104
105
  for col, dtype in df.dtypes.items():
105
106
  if pd.api.types.is_numeric_dtype(dtype):
@@ -201,11 +202,13 @@ def _test_dataframe_shuffle(backend, protocol, n_workers, _partitions):
201
202
 
202
203
  @pytest.mark.parametrize("nworkers", [1, 2, 3])
203
204
  @pytest.mark.parametrize("backend", ["pandas", "cudf"])
204
- @pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
205
+ @pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucx-old"])
205
206
  @pytest.mark.parametrize("_partitions", [True, False])
206
207
  def test_dataframe_shuffle(backend, protocol, nworkers, _partitions):
207
208
  if backend == "cudf":
208
209
  pytest.importorskip("cudf")
210
+ if protocol.startswith("ucx"):
211
+ get_ucx_implementation(protocol)
209
212
 
210
213
  p = mp.Process(
211
214
  target=_test_dataframe_shuffle, args=(backend, protocol, nworkers, _partitions)
@@ -322,10 +325,12 @@ def _test_dataframe_shuffle_merge(backend, protocol, n_workers):
322
325
 
323
326
  @pytest.mark.parametrize("nworkers", [1, 2, 4])
324
327
  @pytest.mark.parametrize("backend", ["pandas", "cudf"])
325
- @pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
328
+ @pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucx-old"])
326
329
  def test_dataframe_shuffle_merge(backend, protocol, nworkers):
327
330
  if backend == "cudf":
328
331
  pytest.importorskip("cudf")
332
+ if protocol.startswith("ucx"):
333
+ get_ucx_implementation(protocol)
329
334
  p = mp.Process(
330
335
  target=_test_dataframe_shuffle_merge, args=(backend, protocol, nworkers)
331
336
  )
@@ -359,9 +364,14 @@ def _test_jit_unspill(protocol):
359
364
  assert_eq(got, expected)
360
365
 
361
366
 
362
- @pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
367
+ @pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucx-old"])
368
+ @pytest.mark.skip_if_no_device_memory(
369
+ "JIT-Unspill not supported in devices without dedicated memory resource"
370
+ )
363
371
  def test_jit_unspill(protocol):
364
372
  pytest.importorskip("cudf")
373
+ if protocol.startswith("ucx"):
374
+ get_ucx_implementation(protocol)
365
375
 
366
376
  p = mp.Process(target=_test_jit_unspill, args=(protocol,))
367
377
  p.start()
@@ -386,7 +396,7 @@ def _test_lock_workers(scheduler_address, ranks):
386
396
 
387
397
  def test_lock_workers():
388
398
  """
389
- Testing `run(...,lock_workers=True)` by spawning 30 runs with overlapping
399
+ Testing ``run(...,lock_workers=True)`` by spawning 30 runs with overlapping
390
400
  and non-overlapping worker sets.
391
401
  """
392
402
  try:
@@ -425,7 +435,9 @@ def test_create_destroy_create():
425
435
  with LocalCluster(n_workers=1) as cluster:
426
436
  with Client(cluster) as client:
427
437
  context = comms.default_comms()
428
- scheduler_addresses_old = list(client.scheduler_info()["workers"].keys())
438
+ scheduler_addresses_old = list(
439
+ client.scheduler_info(n_workers=-1)["workers"].keys()
440
+ )
429
441
  comms_addresses_old = list(comms.default_comms().worker_addresses)
430
442
  assert comms.default_comms() is context
431
443
  assert len(comms._comms_cache) == 1
@@ -446,7 +458,9 @@ def test_create_destroy_create():
446
458
  # because we referenced the old cluster's addresses.
447
459
  with LocalCluster(n_workers=1) as cluster:
448
460
  with Client(cluster) as client:
449
- scheduler_addresses_new = list(client.scheduler_info()["workers"].keys())
461
+ scheduler_addresses_new = list(
462
+ client.scheduler_info(n_workers=-1)["workers"].keys()
463
+ )
450
464
  comms_addresses_new = list(comms.default_comms().worker_addresses)
451
465
 
452
466
  assert scheduler_addresses_new == comms_addresses_new
@@ -487,7 +501,8 @@ def test_scaled_cluster_gets_new_comms_context():
487
501
  "n_workers": 2,
488
502
  }
489
503
  expected_1 = {
490
- k: expected_values for k in client.scheduler_info()["workers"]
504
+ k: expected_values
505
+ for k in client.scheduler_info(n_workers=-1)["workers"]
491
506
  }
492
507
  assert result_1 == expected_1
493
508
 
@@ -515,7 +530,8 @@ def test_scaled_cluster_gets_new_comms_context():
515
530
  "n_workers": 3,
516
531
  }
517
532
  expected_2 = {
518
- k: expected_values for k in client.scheduler_info()["workers"]
533
+ k: expected_values
534
+ for k in client.scheduler_info(n_workers=-1)["workers"]
519
535
  }
520
536
  assert result_2 == expected_2
521
537
 
@@ -530,3 +546,20 @@ def test_scaled_cluster_gets_new_comms_context():
530
546
  expected = shuffled.compute()
531
547
 
532
548
  assert_eq(result, expected)
549
+
550
+
551
+ def test_contains_shuffle_expr():
552
+ df = dd.from_pandas(pd.DataFrame({"key": np.arange(10)}), npartitions=2)
553
+ assert not _contains_shuffle_expr(df)
554
+
555
+ with dask.config.set(explicit_comms=True):
556
+ shuffled = df.shuffle(on="key")
557
+
558
+ assert _contains_shuffle_expr(shuffled)
559
+ assert not _contains_shuffle_expr(df)
560
+
561
+ # this requires an active client.
562
+ with LocalCluster(n_workers=1) as cluster:
563
+ with Client(cluster):
564
+ explict_shuffled = explicit_comms_shuffle(df, ["key"])
565
+ assert not _contains_shuffle_expr(explict_shuffled)
@@ -1,19 +1,21 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
1
4
  import pytest
2
5
 
3
6
  import dask.array as da
4
7
  from distributed import Client
5
8
 
6
9
  from dask_cuda import LocalCUDACluster
10
+ from dask_cuda.utils_test import get_ucx_implementation
7
11
 
8
12
  cupy = pytest.importorskip("cupy")
9
13
 
10
14
 
11
- @pytest.mark.parametrize("protocol", ["ucx", "ucxx", "tcp"])
15
+ @pytest.mark.parametrize("protocol", ["ucx", "ucx-old", "tcp"])
12
16
  def test_ucx_from_array(protocol):
13
- if protocol == "ucx":
14
- pytest.importorskip("ucp")
15
- elif protocol == "ucxx":
16
- pytest.importorskip("ucxx")
17
+ if protocol.startswith("ucx"):
18
+ get_ucx_implementation(protocol)
17
19
 
18
20
  N = 10_000
19
21
  with LocalCUDACluster(protocol=protocol) as cluster: