dask-cuda 25.4.0__py3-none-any.whl → 25.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53) hide show
  1. dask_cuda/GIT_COMMIT +1 -1
  2. dask_cuda/VERSION +1 -1
  3. dask_cuda/_compat.py +18 -0
  4. dask_cuda/benchmarks/common.py +4 -1
  5. dask_cuda/benchmarks/local_cudf_groupby.py +4 -1
  6. dask_cuda/benchmarks/local_cudf_merge.py +5 -2
  7. dask_cuda/benchmarks/local_cudf_shuffle.py +5 -2
  8. dask_cuda/benchmarks/local_cupy.py +4 -1
  9. dask_cuda/benchmarks/local_cupy_map_overlap.py +4 -1
  10. dask_cuda/benchmarks/utils.py +7 -4
  11. dask_cuda/cli.py +21 -15
  12. dask_cuda/cuda_worker.py +27 -57
  13. dask_cuda/device_host_file.py +31 -15
  14. dask_cuda/disk_io.py +7 -4
  15. dask_cuda/explicit_comms/comms.py +11 -7
  16. dask_cuda/explicit_comms/dataframe/shuffle.py +147 -55
  17. dask_cuda/get_device_memory_objects.py +18 -3
  18. dask_cuda/initialize.py +80 -44
  19. dask_cuda/is_device_object.py +4 -1
  20. dask_cuda/is_spillable_object.py +4 -1
  21. dask_cuda/local_cuda_cluster.py +63 -66
  22. dask_cuda/plugins.py +17 -16
  23. dask_cuda/proxify_device_objects.py +15 -10
  24. dask_cuda/proxify_host_file.py +30 -27
  25. dask_cuda/proxy_object.py +20 -17
  26. dask_cuda/tests/conftest.py +41 -0
  27. dask_cuda/tests/test_dask_cuda_worker.py +114 -27
  28. dask_cuda/tests/test_dgx.py +10 -18
  29. dask_cuda/tests/test_explicit_comms.py +51 -18
  30. dask_cuda/tests/test_from_array.py +7 -5
  31. dask_cuda/tests/test_initialize.py +16 -37
  32. dask_cuda/tests/test_local_cuda_cluster.py +164 -54
  33. dask_cuda/tests/test_proxify_host_file.py +33 -4
  34. dask_cuda/tests/test_proxy.py +18 -16
  35. dask_cuda/tests/test_rdd_ucx.py +160 -0
  36. dask_cuda/tests/test_spill.py +107 -27
  37. dask_cuda/tests/test_utils.py +106 -20
  38. dask_cuda/tests/test_worker_spec.py +5 -2
  39. dask_cuda/utils.py +319 -68
  40. dask_cuda/utils_test.py +23 -7
  41. dask_cuda/worker_common.py +196 -0
  42. dask_cuda/worker_spec.py +12 -5
  43. {dask_cuda-25.4.0.dist-info → dask_cuda-25.8.0.dist-info}/METADATA +5 -4
  44. dask_cuda-25.8.0.dist-info/RECORD +63 -0
  45. {dask_cuda-25.4.0.dist-info → dask_cuda-25.8.0.dist-info}/WHEEL +1 -1
  46. dask_cuda-25.8.0.dist-info/top_level.txt +6 -0
  47. shared-actions/check_nightly_success/check-nightly-success/check.py +148 -0
  48. shared-actions/telemetry-impls/summarize/bump_time.py +54 -0
  49. shared-actions/telemetry-impls/summarize/send_trace.py +409 -0
  50. dask_cuda-25.4.0.dist-info/RECORD +0 -56
  51. dask_cuda-25.4.0.dist-info/top_level.txt +0 -5
  52. {dask_cuda-25.4.0.dist-info → dask_cuda-25.8.0.dist-info}/entry_points.txt +0 -0
  53. {dask_cuda-25.4.0.dist-info → dask_cuda-25.8.0.dist-info}/licenses/LICENSE +0 -0
dask_cuda/utils.py CHANGED
@@ -1,3 +1,7 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import importlib
1
5
  import math
2
6
  import operator
3
7
  import os
@@ -40,7 +44,7 @@ def unpack_bitmask(x, mask_bits=64):
40
44
  x: list of int
41
45
  A list of integers
42
46
  mask_bits: int
43
- An integer determining the bitwidth of `x`
47
+ An integer determining the bitwidth of ``x``
44
48
 
45
49
  Examples
46
50
  --------
@@ -86,6 +90,45 @@ def get_gpu_count():
86
90
  return pynvml.nvmlDeviceGetCount()
87
91
 
88
92
 
93
+ def get_gpu_handle(device_id=0):
94
+ """Get GPU handle from device index or UUID.
95
+
96
+ Parameters
97
+ ----------
98
+ device_id: int or str
99
+ The index or UUID of the device from which to obtain the handle.
100
+
101
+ Raises
102
+ ------
103
+ ValueError
104
+ If acquiring the device handle for the device specified failed.
105
+ pynvml.NVMLError
106
+ If any NVML error occurred while initializing.
107
+
108
+ Examples
109
+ --------
110
+ >>> get_gpu_handle(device_id=0)
111
+
112
+ >>> get_gpu_handle(device_id="GPU-9fb42d6f-7d6b-368f-f79c-3c3e784c93f6")
113
+ """
114
+ pynvml.nvmlInit()
115
+
116
+ try:
117
+ if device_id and not str(device_id).isnumeric():
118
+ # This means device_id is UUID.
119
+ # This works for both MIG and non-MIG device UUIDs.
120
+ handle = pynvml.nvmlDeviceGetHandleByUUID(str.encode(device_id))
121
+ if pynvml.nvmlDeviceIsMigDeviceHandle(handle):
122
+ # Additionally get parent device handle
123
+ # if the device itself is a MIG instance
124
+ handle = pynvml.nvmlDeviceGetDeviceHandleFromMigDeviceHandle(handle)
125
+ else:
126
+ handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
127
+ return handle
128
+ except pynvml.NVMLError:
129
+ raise ValueError(f"Invalid device index or UUID: {device_id}")
130
+
131
+
89
132
  @toolz.memoize
90
133
  def get_gpu_count_mig(return_uuids=False):
91
134
  """Return the number of MIG instances available
@@ -129,7 +172,7 @@ def get_cpu_affinity(device_index=None):
129
172
  Parameters
130
173
  ----------
131
174
  device_index: int or str
132
- Index or UUID of the GPU device
175
+ The index or UUID of the device from which to obtain the CPU affinity.
133
176
 
134
177
  Examples
135
178
  --------
@@ -148,26 +191,15 @@ def get_cpu_affinity(device_index=None):
148
191
  40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
149
192
  60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79]
150
193
  """
151
- pynvml.nvmlInit()
152
-
153
194
  try:
154
- if device_index and not str(device_index).isnumeric():
155
- # This means device_index is UUID.
156
- # This works for both MIG and non-MIG device UUIDs.
157
- handle = pynvml.nvmlDeviceGetHandleByUUID(str.encode(device_index))
158
- if pynvml.nvmlDeviceIsMigDeviceHandle(handle):
159
- # Additionally get parent device handle
160
- # if the device itself is a MIG instance
161
- handle = pynvml.nvmlDeviceGetDeviceHandleFromMigDeviceHandle(handle)
162
- else:
163
- handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
195
+ handle = get_gpu_handle(device_index)
164
196
  # Result is a list of 64-bit integers, thus ceil(get_cpu_count() / 64)
165
197
  affinity = pynvml.nvmlDeviceGetCpuAffinity(
166
198
  handle,
167
199
  math.ceil(get_cpu_count() / 64),
168
200
  )
169
201
  return unpack_bitmask(affinity)
170
- except pynvml.NVMLError:
202
+ except (pynvml.NVMLError, ValueError):
171
203
  warnings.warn(
172
204
  "Cannot get CPU affinity for device with index %d, setting default affinity"
173
205
  % device_index
@@ -182,19 +214,51 @@ def get_n_gpus():
182
214
  return get_gpu_count()
183
215
 
184
216
 
185
- def get_device_total_memory(index=0):
217
+ def get_device_total_memory(device_index=0):
218
+ """Return total memory of CUDA device with index or with device identifier UUID.
219
+
220
+ Parameters
221
+ ----------
222
+ device_index: int or str
223
+ The index or UUID of the device from which to obtain the total memory.
224
+
225
+ Returns
226
+ -------
227
+ The total memory of the CUDA Device in bytes, or ``None`` for devices that do not
228
+ have a dedicated memory resource, as is usually the case for system on a chip (SoC)
229
+ devices.
186
230
  """
187
- Return total memory of CUDA device with index or with device identifier UUID
231
+ handle = get_gpu_handle(device_index)
232
+
233
+ try:
234
+ return pynvml.nvmlDeviceGetMemoryInfo(handle).total
235
+ except pynvml.NVMLError_NotSupported:
236
+ return None
237
+
238
+
239
+ def has_device_memory_resource(device_index=0):
240
+ """Determine wheter CUDA device has dedicated memory resource.
241
+
242
+ Certain devices have no dedicated memory resource, such as system on a chip (SoC)
243
+ devices.
244
+
245
+ Parameters
246
+ ----------
247
+ device_index: int or str
248
+ The index or UUID of the device whose memory resource to check.
249
+
250
+ Returns
251
+ -------
252
+ Whether the device has a dedicated memory resource.
188
253
  """
189
- pynvml.nvmlInit()
254
+ handle = get_gpu_handle(device_index)
190
255
 
191
- if index and not str(index).isnumeric():
192
- # This means index is UUID. This works for both MIG and non-MIG device UUIDs.
193
- handle = pynvml.nvmlDeviceGetHandleByUUID(str.encode(str(index)))
256
+ try:
257
+ pynvml.nvmlDeviceGetMemoryInfo(handle).total
258
+ except pynvml.NVMLError_NotSupported:
259
+ return False
194
260
  else:
195
- # This is a device index
196
- handle = pynvml.nvmlDeviceGetHandleByIndex(index)
197
- return pynvml.nvmlDeviceGetMemoryInfo(handle).total
261
+ return True
198
262
 
199
263
 
200
264
  def get_ucx_config(
@@ -202,9 +266,15 @@ def get_ucx_config(
202
266
  enable_infiniband=None,
203
267
  enable_nvlink=None,
204
268
  enable_rdmacm=None,
269
+ protocol=None,
205
270
  ):
206
271
  ucx_config = dask.config.get("distributed.comm.ucx")
207
272
 
273
+ # TODO: remove along with `protocol` kwarg when UCX-Py is removed, see
274
+ # https://github.com/rapidsai/dask-cuda/issues/1517
275
+ if protocol in ("ucx", "ucxx", "ucx-old"):
276
+ ucx_config[canonical_name("ucx-protocol", ucx_config)] = protocol
277
+
208
278
  ucx_config[canonical_name("create-cuda-context", ucx_config)] = True
209
279
  ucx_config[canonical_name("reuse-endpoints", ucx_config)] = False
210
280
 
@@ -288,7 +358,11 @@ def get_preload_options(
288
358
  if create_cuda_context:
289
359
  preload_options["preload_argv"].append("--create-cuda-context")
290
360
 
291
- if protocol in ["ucx", "ucxx"]:
361
+ try:
362
+ _get_active_ucx_implementation_name(protocol)
363
+ except ValueError:
364
+ pass
365
+ else:
292
366
  initialize_ucx_argv = []
293
367
  if enable_tcp_over_ucx:
294
368
  initialize_ucx_argv.append("--enable-tcp-over-ucx")
@@ -337,7 +411,7 @@ def wait_workers(
337
411
  Instance of client, used to query for number of workers connected.
338
412
  min_timeout: float
339
413
  Minimum number of seconds to wait before timeout. This value may be
340
- overridden by setting the `DASK_CUDA_WAIT_WORKERS_MIN_TIMEOUT` with
414
+ overridden by setting the ``DASK_CUDA_WAIT_WORKERS_MIN_TIMEOUT`` with
341
415
  a positive integer.
342
416
  seconds_per_gpu: float
343
417
  Seconds to wait for each GPU on the system. For example, if its
@@ -346,7 +420,7 @@ def wait_workers(
346
420
  used as timeout when larger than min_timeout.
347
421
  n_gpus: None or int
348
422
  If specified, will wait for that amount of GPUs (i.e., Dask workers)
349
- to come online, else waits for a total of `get_n_gpus` workers.
423
+ to come online, else waits for a total of ``get_n_gpus`` workers.
350
424
  timeout_callback: None or callable
351
425
  A callback function to be executed if a timeout occurs, ignored if
352
426
  None.
@@ -362,7 +436,7 @@ def wait_workers(
362
436
 
363
437
  start = time.time()
364
438
  while True:
365
- if len(client.scheduler_info()["workers"]) == n_gpus:
439
+ if len(client.scheduler_info(n_workers=-1)["workers"]) == n_gpus:
366
440
  return True
367
441
  elif time.time() - start > timeout:
368
442
  if callable(timeout_callback):
@@ -376,7 +450,7 @@ async def _all_to_all(client):
376
450
  """
377
451
  Trigger all to all communication between workers and scheduler
378
452
  """
379
- workers = list(client.scheduler_info()["workers"])
453
+ workers = list(client.scheduler_info(n_workers=-1)["workers"])
380
454
  futs = []
381
455
  for w in workers:
382
456
  bit_of_data = b"0" * 1
@@ -465,8 +539,8 @@ def nvml_device_index(i, CUDA_VISIBLE_DEVICES):
465
539
  """Get the device index for NVML addressing
466
540
 
467
541
  NVML expects the index of the physical device, unlike CUDA runtime which
468
- expects the address relative to `CUDA_VISIBLE_DEVICES`. This function
469
- returns the i-th device index from the `CUDA_VISIBLE_DEVICES`
542
+ expects the address relative to ``CUDA_VISIBLE_DEVICES``. This function
543
+ returns the i-th device index from the ``CUDA_VISIBLE_DEVICES``
470
544
  comma-separated string of devices or list.
471
545
 
472
546
  Examples
@@ -504,15 +578,125 @@ def nvml_device_index(i, CUDA_VISIBLE_DEVICES):
504
578
  raise ValueError("`CUDA_VISIBLE_DEVICES` must be `str` or `list`")
505
579
 
506
580
 
581
+ def parse_device_bytes(device_bytes, device_index=0, alignment_size=1):
582
+ """Parse bytes relative to a specific CUDA device.
583
+
584
+ Parameters
585
+ ----------
586
+ device_bytes: float, int, str or None
587
+ Can be an integer (bytes), float (fraction of total device memory), string
588
+ (like ``"5GB"`` or ``"5000M"``), ``0`` and ``None`` are special cases
589
+ returning ``None``.
590
+ device_index: int or str
591
+ The index or UUID of the device from which to obtain the total memory amount.
592
+ Default: 0.
593
+ alignment_size: int
594
+ Number of bytes of alignment to use, i.e., allocation must be a multiple of
595
+ that size. RMM pool requires 256 bytes alignment.
596
+
597
+ Returns
598
+ -------
599
+ The parsed bytes value relative to the CUDA device, or ``None`` as convenience if
600
+ ``device_bytes`` is ``None`` or any value that would evaluate to ``0``.
601
+
602
+ Examples
603
+ --------
604
+ >>> # On a 32GB CUDA device
605
+ >>> parse_device_bytes(None)
606
+ None
607
+ >>> parse_device_bytes(0)
608
+ None
609
+ >>> parse_device_bytes(0.0)
610
+ None
611
+ >>> parse_device_bytes("0 MiB")
612
+ None
613
+ >>> parse_device_bytes(1.0)
614
+ 34089730048
615
+ >>> parse_device_bytes(0.8)
616
+ 27271784038
617
+ >>> parse_device_bytes(1000000000)
618
+ 1000000000
619
+ >>> parse_device_bytes("1GB")
620
+ 1000000000
621
+ >>> parse_device_bytes("1GB")
622
+ 1000000000
623
+ """
624
+
625
+ def _align(size, alignment_size):
626
+ return size // alignment_size * alignment_size
627
+
628
+ def parse_fractional(v):
629
+ """Parse fractional value.
630
+
631
+ Ensures ``int(1)`` and ``str("1")`` are not treated as fractionals, but
632
+ ``float(1)`` is.
633
+
634
+ Fractionals must be represented as a ``float`` within the range
635
+ ``0.0 < v <= 1.0``.
636
+
637
+ Parameters
638
+ ----------
639
+ v: int, float or str
640
+ The value to check if fractional.
641
+
642
+ Returns
643
+ -------
644
+ """
645
+ # Check if `x` matches exactly `int(1)` or `str("1")`, and is not a `float(1)`
646
+ is_one = lambda x: not isinstance(x, float) and (x == 1 or x == "1")
647
+
648
+ if not is_one(v):
649
+ with suppress(ValueError, TypeError):
650
+ v = float(v)
651
+ if 0.0 < v <= 1.0:
652
+ return v
653
+
654
+ raise ValueError("The value is not fractional")
655
+
656
+ # Special case for fractional limit. This comes before `0` special cases because
657
+ # the `float` may be passed in a `str`, e.g., from `CUDAWorker`.
658
+ try:
659
+ fractional_device_bytes = parse_fractional(device_bytes)
660
+ except ValueError:
661
+ pass
662
+ else:
663
+ if not has_device_memory_resource():
664
+ raise ValueError(
665
+ "Fractional of total device memory not supported in devices without "
666
+ "a dedicated memory resource."
667
+ )
668
+ return _align(
669
+ int(get_device_total_memory(device_index) * fractional_device_bytes),
670
+ alignment_size,
671
+ )
672
+
673
+ # Special cases that evaluates to `None` or `0`
674
+ if device_bytes is None:
675
+ return None
676
+ elif device_bytes == 0.0:
677
+ return None
678
+ elif not isinstance(device_bytes, float) and parse_bytes(device_bytes) == 0:
679
+ return None
680
+
681
+ if isinstance(device_bytes, str):
682
+ return _align(parse_bytes(device_bytes), alignment_size)
683
+ else:
684
+ return _align(int(device_bytes), alignment_size)
685
+
686
+
507
687
  def parse_device_memory_limit(device_memory_limit, device_index=0, alignment_size=1):
508
688
  """Parse memory limit to be used by a CUDA device.
509
689
 
510
690
  Parameters
511
691
  ----------
512
692
  device_memory_limit: float, int, str or None
513
- This can be a float (fraction of total device memory), an integer (bytes),
514
- a string (like 5GB or 5000M), and "auto", 0 or None for the total device
515
- size.
693
+ Can be an integer (bytes), float (fraction of total device memory), string
694
+ (like ``"5GB"`` or ``"5000M"``), ``"auto"``, ``0`` or ``None`` to disable
695
+ spilling to host (i.e. allow full device memory usage). Another special value
696
+ ``"default"`` is also available and returns the recommended Dask-CUDA's defaults
697
+ and means 80% of the total device memory (analogous to ``0.8``), and disabled
698
+ spilling (analogous to ``auto``/``0``/``None``) on devices without a dedicated
699
+ memory resource, such as system on a chip (SoC) devices.
516
700
  device_index: int or str
517
701
  The index or UUID of the device from which to obtain the total memory amount.
518
702
  Default: 0.
@@ -520,10 +704,23 @@ def parse_device_memory_limit(device_memory_limit, device_index=0, alignment_siz
520
704
  Number of bytes of alignment to use, i.e., allocation must be a multiple of
521
705
  that size. RMM pool requires 256 bytes alignment.
522
706
 
707
+ Returns
708
+ -------
709
+ The parsed memory limit in bytes, or ``None`` as convenience if
710
+ ``device_memory_limit`` is ``None`` or any value that would evaluate to ``0``.
711
+
523
712
  Examples
524
713
  --------
525
714
  >>> # On a 32GB CUDA device
526
715
  >>> parse_device_memory_limit(None)
716
+ None
717
+ >>> parse_device_memory_limit(0)
718
+ None
719
+ >>> parse_device_memory_limit(0.0)
720
+ None
721
+ >>> parse_device_memory_limit("0 MiB")
722
+ None
723
+ >>> parse_device_memory_limit(1.0)
527
724
  34089730048
528
725
  >>> parse_device_memory_limit(0.8)
529
726
  27271784038
@@ -531,48 +728,58 @@ def parse_device_memory_limit(device_memory_limit, device_index=0, alignment_siz
531
728
  1000000000
532
729
  >>> parse_device_memory_limit("1GB")
533
730
  1000000000
731
+ >>> parse_device_memory_limit("1GB")
732
+ 1000000000
733
+ >>> parse_device_memory_limit("auto") == (
734
+ ... parse_device_memory_limit(1.0)
735
+ ... if has_device_memory_resource()
736
+ ... else None
737
+ ... )
738
+ True
739
+ >>> parse_device_memory_limit("default") == (
740
+ ... parse_device_memory_limit(0.8)
741
+ ... if has_device_memory_resource()
742
+ ... else None
743
+ ... )
744
+ True
534
745
  """
535
746
 
536
- def _align(size, alignment_size):
537
- return size // alignment_size * alignment_size
538
-
539
- if device_memory_limit in {0, "0", None, "auto"}:
540
- return _align(get_device_total_memory(device_index), alignment_size)
541
-
542
- with suppress(ValueError, TypeError):
543
- device_memory_limit = float(device_memory_limit)
544
- if isinstance(device_memory_limit, float) and device_memory_limit <= 1:
545
- return _align(
546
- int(get_device_total_memory(device_index) * device_memory_limit),
547
- alignment_size,
548
- )
747
+ # Special cases for "auto" and "default".
748
+ if device_memory_limit in ["auto", "default"]:
749
+ if not has_device_memory_resource():
750
+ return None
751
+ if device_memory_limit == "auto":
752
+ device_memory_limit = get_device_total_memory(device_index)
753
+ else:
754
+ device_memory_limit = 0.8
549
755
 
550
- if isinstance(device_memory_limit, str):
551
- return _align(parse_bytes(device_memory_limit), alignment_size)
552
- else:
553
- return _align(int(device_memory_limit), alignment_size)
756
+ return parse_device_bytes(
757
+ device_bytes=device_memory_limit,
758
+ device_index=device_index,
759
+ alignment_size=alignment_size,
760
+ )
554
761
 
555
762
 
556
- def get_gpu_uuid_from_index(device_index=0):
763
+ def get_gpu_uuid(device_index=0):
557
764
  """Get GPU UUID from CUDA device index.
558
765
 
559
766
  Parameters
560
767
  ----------
561
768
  device_index: int or str
562
- The index of the device from which to obtain the UUID. Default: 0.
769
+ The index or UUID of the device from which to obtain the UUID.
563
770
 
564
771
  Examples
565
772
  --------
566
- >>> get_gpu_uuid_from_index()
773
+ >>> get_gpu_uuid()
567
774
  'GPU-9baca7f5-0f2f-01ac-6b05-8da14d6e9005'
568
775
 
569
- >>> get_gpu_uuid_from_index(3)
776
+ >>> get_gpu_uuid(3)
570
777
  'GPU-9fb42d6f-7d6b-368f-f79c-3c3e784c93f6'
571
- """
572
- import pynvml
573
778
 
574
- pynvml.nvmlInit()
575
- handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
779
+ >>> get_gpu_uuid("GPU-9fb42d6f-7d6b-368f-f79c-3c3e784c93f6")
780
+ 'GPU-9fb42d6f-7d6b-368f-f79c-3c3e784c93f6'
781
+ """
782
+ handle = get_gpu_handle(device_index)
576
783
  try:
577
784
  return pynvml.nvmlDeviceGetUUID(handle).decode("utf-8")
578
785
  except AttributeError:
@@ -616,20 +823,26 @@ def get_worker_config(dask_worker):
616
823
  ret["device-memory-limit"] = dask_worker.data.manager._device_memory_limit
617
824
  else:
618
825
  has_device = hasattr(dask_worker.data, "device_buffer")
619
- if has_device:
826
+ if has_device and hasattr(dask_worker.data.device_buffer, "n"):
827
+ # If `n` is not an attribute, device spilling is disabled/unavailable.
620
828
  ret["device-memory-limit"] = dask_worker.data.device_buffer.n
621
829
 
622
830
  # using ucx ?
623
831
  scheme, loc = parse_address(dask_worker.scheduler.address)
624
832
  ret["protocol"] = scheme
625
- if scheme == "ucx":
626
- import ucp
833
+ try:
834
+ protocol = _get_active_ucx_implementation_name(scheme)
835
+ except ValueError:
836
+ pass
837
+ else:
838
+ if protocol == "ucxx":
839
+ import ucxx
627
840
 
628
- ret["ucx-transports"] = ucp.get_active_transports()
629
- elif scheme == "ucxx":
630
- import ucxx
841
+ ret["ucx-transports"] = ucxx.get_active_transports()
842
+ elif protocol == "ucx-old":
843
+ import ucp
631
844
 
632
- ret["ucx-transports"] = ucxx.get_active_transports()
845
+ ret["ucx-transports"] = ucp.get_active_transports()
633
846
 
634
847
  # comm timeouts
635
848
  ret["distributed.comm.timeouts"] = dask.config.get("distributed.comm.timeouts")
@@ -661,7 +874,7 @@ async def _get_cluster_configuration(client):
661
874
  if worker_config:
662
875
  w = list(worker_config.values())[0]
663
876
  ret.update(w)
664
- info = client.scheduler_info()
877
+ info = client.scheduler_info(n_workers=-1)
665
878
  workers = info.get("workers", {})
666
879
  ret["nworkers"] = len(workers)
667
880
  ret["nthreads"] = sum(w["nthreads"] for w in workers.values())
@@ -739,7 +952,7 @@ def get_rmm_device_memory_usage() -> Optional[int]:
739
952
  """Get current bytes allocated on current device through RMM
740
953
 
741
954
  Check the current RMM resource stack for resources such as
742
- `StatisticsResourceAdaptor` and `TrackingResourceAdaptor`
955
+ ``StatisticsResourceAdaptor`` and ``TrackingResourceAdaptor``
743
956
  that can report the current allocated bytes. Returns None,
744
957
  if no such resources exist.
745
958
 
@@ -775,3 +988,41 @@ class CommaSeparatedChoice(click.Choice):
775
988
  choices_str = ", ".join(f"'{c}'" for c in self.choices)
776
989
  self.fail(f"invalid choice(s): {v}. (choices are: {choices_str})")
777
990
  return values
991
+
992
+
993
+ def _get_active_ucx_implementation_name(protocol):
994
+ """Get the name of active UCX implementation.
995
+
996
+ Determine what UCX implementation is being activated based on a series of
997
+ conditions. UCXX is selected if:
998
+ - The protocol is `"ucxx"`, or the protocol is `"ucx"` and the `distributed-ucxx`
999
+ package is installed.
1000
+ UCX-Py is selected if:
1001
+ - The protocol is `"ucx-old"`, or the protocol is `"ucx"` and the `distributed-ucxx`
1002
+ package is not installed, in which case a `FutureWarning` is also raised.
1003
+
1004
+ Parameters
1005
+ ----------
1006
+ protocol: str
1007
+ The communication protocol selected.
1008
+
1009
+ Returns
1010
+ -------
1011
+ The selected implementation type, either "ucxx" or "ucx-old".
1012
+
1013
+ Raises
1014
+ ------
1015
+ ValueError
1016
+ If protocol is not a valid UCX protocol.
1017
+ """
1018
+ has_ucxx = importlib.util.find_spec("distributed_ucxx") is not None
1019
+
1020
+ if protocol == "ucxx" or (has_ucxx and protocol == "ucx"):
1021
+ # With https://github.com/rapidsai/rapids-dask-dependency/pull/116,
1022
+ # `protocol="ucx"` now points to UCXX (if distributed-ucxx is installed),
1023
+ # thus call the UCXX initializer.
1024
+ return "ucxx"
1025
+ elif protocol in ("ucx", "ucx-old"):
1026
+ return "ucx-old"
1027
+ else:
1028
+ raise ValueError("Protocol is neither UCXX nor UCX-Py")
dask_cuda/utils_test.py CHANGED
@@ -1,14 +1,19 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
1
4
  from typing import Literal
2
5
 
3
6
  import distributed
4
7
  from distributed import Nanny, Worker
5
8
 
9
+ from .utils import _get_active_ucx_implementation_name
10
+
6
11
 
7
12
  class MockWorker(Worker):
8
13
  """Mock Worker class preventing NVML from getting used by SystemMonitor.
9
14
 
10
15
  By preventing the Worker from initializing NVML in the SystemMonitor, we can
11
- mock test multiple devices in `CUDA_VISIBLE_DEVICES` behavior with single-GPU
16
+ mock test multiple devices in ``CUDA_VISIBLE_DEVICES`` behavior with single-GPU
12
17
  machines.
13
18
  """
14
19
 
@@ -26,20 +31,31 @@ class MockWorker(Worker):
26
31
 
27
32
 
28
33
  class IncreasedCloseTimeoutNanny(Nanny):
29
- """Increase `Nanny`'s close timeout.
34
+ """Increase ``Nanny``'s close timeout.
30
35
 
31
- The internal close timeout mechanism of `Nanny` recomputes the time left to kill
32
- the `Worker` process based on elapsed time of the close task, which may leave
36
+ The internal close timeout mechanism of ``Nanny`` recomputes the time left to kill
37
+ the ``Worker`` process based on elapsed time of the close task, which may leave
33
38
  very little time for the subprocess to shutdown cleanly, which may cause tests
34
39
  to fail when the system is under higher load. This class increases the default
35
- close timeout of 5.0 seconds that `Nanny` sets by default, which can be overriden
40
+ close timeout of 5.0 seconds that ``Nanny`` sets by default, which can be overridden
36
41
  via Distributed's public API.
37
42
 
38
- This class can be used with the `worker_class` argument of `LocalCluster` or
39
- `LocalCUDACluster` to provide a much higher default of 30.0 seconds.
43
+ This class can be used with the ``worker_class`` argument of ``LocalCluster`` or
44
+ ``LocalCUDACluster`` to provide a much higher default of 30.0 seconds.
40
45
  """
41
46
 
42
47
  async def close( # type:ignore[override]
43
48
  self, timeout: float = 30.0, reason: str = "nanny-close"
44
49
  ) -> Literal["OK"]:
45
50
  return await super().close(timeout=timeout, reason=reason)
51
+
52
+
53
+ def get_ucx_implementation(protocol):
54
+ import pytest
55
+
56
+ protocol = _get_active_ucx_implementation_name(protocol)
57
+
58
+ if protocol == "ucxx":
59
+ return pytest.importorskip("ucxx")
60
+ else:
61
+ return pytest.importorskip("ucp")