dask-cuda 25.6.0__py3-none-any.whl → 25.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. dask_cuda/GIT_COMMIT +1 -1
  2. dask_cuda/VERSION +1 -1
  3. dask_cuda/benchmarks/common.py +4 -1
  4. dask_cuda/benchmarks/local_cudf_groupby.py +4 -1
  5. dask_cuda/benchmarks/local_cudf_merge.py +5 -2
  6. dask_cuda/benchmarks/local_cudf_shuffle.py +5 -2
  7. dask_cuda/benchmarks/local_cupy.py +4 -1
  8. dask_cuda/benchmarks/local_cupy_map_overlap.py +4 -1
  9. dask_cuda/benchmarks/utils.py +7 -4
  10. dask_cuda/cli.py +21 -15
  11. dask_cuda/cuda_worker.py +27 -57
  12. dask_cuda/device_host_file.py +31 -15
  13. dask_cuda/disk_io.py +7 -4
  14. dask_cuda/explicit_comms/comms.py +11 -7
  15. dask_cuda/explicit_comms/dataframe/shuffle.py +23 -23
  16. dask_cuda/get_device_memory_objects.py +3 -3
  17. dask_cuda/initialize.py +80 -44
  18. dask_cuda/local_cuda_cluster.py +63 -66
  19. dask_cuda/plugins.py +17 -16
  20. dask_cuda/proxify_device_objects.py +12 -10
  21. dask_cuda/proxify_host_file.py +30 -27
  22. dask_cuda/proxy_object.py +20 -17
  23. dask_cuda/tests/conftest.py +41 -0
  24. dask_cuda/tests/test_dask_cuda_worker.py +109 -25
  25. dask_cuda/tests/test_dgx.py +10 -18
  26. dask_cuda/tests/test_explicit_comms.py +30 -12
  27. dask_cuda/tests/test_from_array.py +7 -5
  28. dask_cuda/tests/test_initialize.py +16 -37
  29. dask_cuda/tests/test_local_cuda_cluster.py +159 -52
  30. dask_cuda/tests/test_proxify_host_file.py +19 -3
  31. dask_cuda/tests/test_proxy.py +18 -16
  32. dask_cuda/tests/test_rdd_ucx.py +160 -0
  33. dask_cuda/tests/test_spill.py +7 -0
  34. dask_cuda/tests/test_utils.py +106 -20
  35. dask_cuda/tests/test_worker_spec.py +5 -2
  36. dask_cuda/utils.py +261 -38
  37. dask_cuda/utils_test.py +23 -7
  38. dask_cuda/worker_common.py +196 -0
  39. dask_cuda/worker_spec.py +12 -5
  40. {dask_cuda-25.6.0.dist-info → dask_cuda-25.8.0.dist-info}/METADATA +2 -2
  41. dask_cuda-25.8.0.dist-info/RECORD +63 -0
  42. dask_cuda-25.8.0.dist-info/top_level.txt +6 -0
  43. shared-actions/check_nightly_success/check-nightly-success/check.py +148 -0
  44. shared-actions/telemetry-impls/summarize/bump_time.py +54 -0
  45. shared-actions/telemetry-impls/summarize/send_trace.py +409 -0
  46. dask_cuda-25.6.0.dist-info/RECORD +0 -57
  47. dask_cuda-25.6.0.dist-info/top_level.txt +0 -4
  48. {dask_cuda-25.6.0.dist-info → dask_cuda-25.8.0.dist-info}/WHEEL +0 -0
  49. {dask_cuda-25.6.0.dist-info → dask_cuda-25.8.0.dist-info}/entry_points.txt +0 -0
  50. {dask_cuda-25.6.0.dist-info → dask_cuda-25.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,3 +1,6 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
1
4
  import os
2
5
  from unittest.mock import patch
3
6
 
@@ -15,11 +18,13 @@ from dask_cuda.utils import (
15
18
  get_n_gpus,
16
19
  get_preload_options,
17
20
  get_ucx_config,
21
+ has_device_memory_resource,
18
22
  nvml_device_index,
19
23
  parse_cuda_visible_device,
20
24
  parse_device_memory_limit,
21
25
  unpack_bitmask,
22
26
  )
27
+ from dask_cuda.utils_test import get_ucx_implementation
23
28
 
24
29
 
25
30
  @patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1,2"})
@@ -76,19 +81,19 @@ def test_get_device_total_memory():
76
81
  for i in range(get_n_gpus()):
77
82
  with cuda.gpus[i]:
78
83
  total_mem = get_device_total_memory(i)
79
- assert type(total_mem) is int
80
- assert total_mem > 0
84
+ if has_device_memory_resource():
85
+ assert type(total_mem) is int
86
+ assert total_mem > 0
87
+ else:
88
+ assert total_mem is None
81
89
 
82
90
 
83
91
  @pytest.mark.parametrize(
84
92
  "protocol",
85
- ["ucx", "ucxx"],
93
+ ["ucx", "ucx-old"],
86
94
  )
87
95
  def test_get_preload_options_default(protocol):
88
- if protocol == "ucx":
89
- pytest.importorskip("ucp")
90
- elif protocol == "ucxx":
91
- pytest.importorskip("ucxx")
96
+ get_ucx_implementation(protocol)
92
97
 
93
98
  opts = get_preload_options(
94
99
  protocol=protocol,
@@ -103,16 +108,13 @@ def test_get_preload_options_default(protocol):
103
108
 
104
109
  @pytest.mark.parametrize(
105
110
  "protocol",
106
- ["ucx", "ucxx"],
111
+ ["ucx", "ucx-old"],
107
112
  )
108
113
  @pytest.mark.parametrize("enable_tcp", [True, False])
109
114
  @pytest.mark.parametrize("enable_infiniband", [True, False])
110
115
  @pytest.mark.parametrize("enable_nvlink", [True, False])
111
116
  def test_get_preload_options(protocol, enable_tcp, enable_infiniband, enable_nvlink):
112
- if protocol == "ucx":
113
- pytest.importorskip("ucp")
114
- elif protocol == "ucxx":
115
- pytest.importorskip("ucxx")
117
+ get_ucx_implementation(protocol)
116
118
 
117
119
  opts = get_preload_options(
118
120
  protocol=protocol,
@@ -135,11 +137,17 @@ def test_get_preload_options(protocol, enable_tcp, enable_infiniband, enable_nvl
135
137
  assert "--enable-nvlink" in opts["preload_argv"]
136
138
 
137
139
 
140
+ @pytest.mark.parametrize(
141
+ "protocol",
142
+ ["ucx", "ucx-old"],
143
+ )
138
144
  @pytest.mark.parametrize("enable_tcp_over_ucx", [True, False, None])
139
145
  @pytest.mark.parametrize("enable_nvlink", [True, False, None])
140
146
  @pytest.mark.parametrize("enable_infiniband", [True, False, None])
141
- def test_get_ucx_config(enable_tcp_over_ucx, enable_infiniband, enable_nvlink):
142
- pytest.importorskip("ucp")
147
+ def test_get_ucx_config(
148
+ protocol, enable_tcp_over_ucx, enable_infiniband, enable_nvlink
149
+ ):
150
+ get_ucx_implementation(protocol)
143
151
 
144
152
  kwargs = {
145
153
  "enable_tcp_over_ucx": enable_tcp_over_ucx,
@@ -234,20 +242,98 @@ def test_parse_visible_devices():
234
242
  parse_cuda_visible_device([])
235
243
 
236
244
 
245
+ def test_parse_device_bytes():
246
+ total = get_device_total_memory(0)
247
+
248
+ assert parse_device_memory_limit(None) is None
249
+ assert parse_device_memory_limit(0) is None
250
+ assert parse_device_memory_limit("0") is None
251
+ assert parse_device_memory_limit("0.0") is None
252
+ assert parse_device_memory_limit("0 GiB") is None
253
+
254
+ assert parse_device_memory_limit(1) == 1
255
+ assert parse_device_memory_limit("1") == 1
256
+
257
+ assert parse_device_memory_limit(1000000000) == 1000000000
258
+ assert parse_device_memory_limit("1GB") == 1000000000
259
+
260
+ if has_device_memory_resource(0):
261
+ assert parse_device_memory_limit(1.0) == total
262
+ assert parse_device_memory_limit("1.0") == total
263
+
264
+ assert parse_device_memory_limit(0.8) == int(total * 0.8)
265
+ assert parse_device_memory_limit(0.8, alignment_size=256) == int(
266
+ total * 0.8 // 256 * 256
267
+ )
268
+
269
+ assert parse_device_memory_limit("default") == parse_device_memory_limit(0.8)
270
+ else:
271
+ assert parse_device_memory_limit("default") is None
272
+
273
+ with pytest.raises(ValueError):
274
+ assert parse_device_memory_limit(1.0) == total
275
+ with pytest.raises(ValueError):
276
+ assert parse_device_memory_limit("1.0") == total
277
+ with pytest.raises(ValueError):
278
+ assert parse_device_memory_limit(0.8) == int(total * 0.8)
279
+ with pytest.raises(ValueError):
280
+ assert parse_device_memory_limit(0.8, alignment_size=256) == int(
281
+ total * 0.8 // 256 * 256
282
+ )
283
+
284
+
237
285
  def test_parse_device_memory_limit():
238
286
  total = get_device_total_memory(0)
239
287
 
240
- assert parse_device_memory_limit(None) == total
241
- assert parse_device_memory_limit(0) == total
288
+ assert parse_device_memory_limit(None) is None
289
+ assert parse_device_memory_limit(0) is None
290
+ assert parse_device_memory_limit("0") is None
291
+ assert parse_device_memory_limit(0.0) is None
292
+ assert parse_device_memory_limit("0 GiB") is None
293
+
294
+ assert parse_device_memory_limit(1) == 1
295
+ assert parse_device_memory_limit("1") == 1
296
+
242
297
  assert parse_device_memory_limit("auto") == total
243
298
 
244
- assert parse_device_memory_limit(0.8) == int(total * 0.8)
245
- assert parse_device_memory_limit(0.8, alignment_size=256) == int(
246
- total * 0.8 // 256 * 256
247
- )
248
299
  assert parse_device_memory_limit(1000000000) == 1000000000
249
300
  assert parse_device_memory_limit("1GB") == 1000000000
250
301
 
302
+ if has_device_memory_resource(0):
303
+ assert parse_device_memory_limit(1.0) == total
304
+ assert parse_device_memory_limit("1.0") == total
305
+
306
+ assert parse_device_memory_limit(0.8) == int(total * 0.8)
307
+ assert parse_device_memory_limit(0.8, alignment_size=256) == int(
308
+ total * 0.8 // 256 * 256
309
+ )
310
+ assert parse_device_memory_limit("default") == parse_device_memory_limit(0.8)
311
+ else:
312
+ assert parse_device_memory_limit("default") is None
313
+
314
+ with pytest.raises(ValueError):
315
+ assert parse_device_memory_limit(1.0) == total
316
+ with pytest.raises(ValueError):
317
+ assert parse_device_memory_limit("1.0") == total
318
+ with pytest.raises(ValueError):
319
+ assert parse_device_memory_limit(0.8) == int(total * 0.8)
320
+ with pytest.raises(ValueError):
321
+ assert parse_device_memory_limit(0.8, alignment_size=256) == int(
322
+ total * 0.8 // 256 * 256
323
+ )
324
+
325
+
326
+ def test_has_device_memory_resource():
327
+ has_memory_resource = has_device_memory_resource()
328
+ total = get_device_total_memory(0)
329
+
330
+ if has_memory_resource:
331
+ # Tested only in devices with a memory resource
332
+ assert total == parse_device_memory_limit("auto")
333
+ else:
334
+ # Tested only in devices without a memory resource
335
+ assert total is None
336
+
251
337
 
252
338
  def test_parse_visible_mig_devices():
253
339
  pynvml.nvmlInit()
@@ -1,3 +1,6 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
1
4
  import pytest
2
5
 
3
6
  from distributed import Nanny
@@ -28,7 +31,7 @@ def _check_env_value(spec, k, v):
28
31
  @pytest.mark.parametrize("num_devices", [1, 4])
29
32
  @pytest.mark.parametrize("cls", [Nanny])
30
33
  @pytest.mark.parametrize("interface", [None, "eth0", "enp1s0f0"])
31
- @pytest.mark.parametrize("protocol", [None, "tcp", "ucx"])
34
+ @pytest.mark.parametrize("protocol", [None, "tcp", "ucx", "ucx-old"])
32
35
  @pytest.mark.parametrize("dashboard_address", [None, ":0", ":8787"])
33
36
  @pytest.mark.parametrize("threads_per_worker", [1, 8])
34
37
  @pytest.mark.parametrize("silence_logs", [False, True])
@@ -58,7 +61,7 @@ def test_worker_spec(
58
61
  enable_nvlink=enable_nvlink,
59
62
  )
60
63
 
61
- if (enable_infiniband or enable_nvlink) and protocol != "ucx":
64
+ if (enable_infiniband or enable_nvlink) and protocol not in ("ucx", "ucx-old"):
62
65
  with pytest.raises(
63
66
  TypeError, match="Enabling InfiniBand or NVLink requires protocol='ucx'"
64
67
  ):
dask_cuda/utils.py CHANGED
@@ -1,6 +1,7 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ import importlib
4
5
  import math
5
6
  import operator
6
7
  import os
@@ -43,7 +44,7 @@ def unpack_bitmask(x, mask_bits=64):
43
44
  x: list of int
44
45
  A list of integers
45
46
  mask_bits: int
46
- An integer determining the bitwidth of `x`
47
+ An integer determining the bitwidth of ``x``
47
48
 
48
49
  Examples
49
50
  --------
@@ -220,9 +221,44 @@ def get_device_total_memory(device_index=0):
220
221
  ----------
221
222
  device_index: int or str
222
223
  The index or UUID of the device from which to obtain the CPU affinity.
224
+
225
+ Returns
226
+ -------
227
+ The total memory of the CUDA Device in bytes, or ``None`` for devices that do not
228
+ have a dedicated memory resource, as is usually the case for system on a chip (SoC)
229
+ devices.
223
230
  """
224
231
  handle = get_gpu_handle(device_index)
225
- return pynvml.nvmlDeviceGetMemoryInfo(handle).total
232
+
233
+ try:
234
+ return pynvml.nvmlDeviceGetMemoryInfo(handle).total
235
+ except pynvml.NVMLError_NotSupported:
236
+ return None
237
+
238
+
239
+ def has_device_memory_resource(device_index=0):
240
+ """Determine whether a CUDA device has a dedicated memory resource.
241
+
242
+ Certain devices have no dedicated memory resource, such as system on a chip (SoC)
243
+ devices.
244
+
245
+ Parameters
246
+ ----------
247
+ device_index: int or str
248
+ The index or UUID of the device from which to obtain the CPU affinity.
249
+
250
+ Returns
251
+ -------
252
+ Whether the device has a dedicated memory resource.
253
+ """
254
+ handle = get_gpu_handle(device_index)
255
+
256
+ try:
257
+ pynvml.nvmlDeviceGetMemoryInfo(handle).total
258
+ except pynvml.NVMLError_NotSupported:
259
+ return False
260
+ else:
261
+ return True
226
262
 
227
263
 
228
264
  def get_ucx_config(
@@ -230,9 +266,15 @@ def get_ucx_config(
230
266
  enable_infiniband=None,
231
267
  enable_nvlink=None,
232
268
  enable_rdmacm=None,
269
+ protocol=None,
233
270
  ):
234
271
  ucx_config = dask.config.get("distributed.comm.ucx")
235
272
 
273
+ # TODO: remove along with `protocol` kwarg when UCX-Py is removed, see
274
+ # https://github.com/rapidsai/dask-cuda/issues/1517
275
+ if protocol in ("ucx", "ucxx", "ucx-old"):
276
+ ucx_config[canonical_name("ucx-protocol", ucx_config)] = protocol
277
+
236
278
  ucx_config[canonical_name("create-cuda-context", ucx_config)] = True
237
279
  ucx_config[canonical_name("reuse-endpoints", ucx_config)] = False
238
280
 
@@ -316,7 +358,11 @@ def get_preload_options(
316
358
  if create_cuda_context:
317
359
  preload_options["preload_argv"].append("--create-cuda-context")
318
360
 
319
- if protocol in ["ucx", "ucxx"]:
361
+ try:
362
+ _get_active_ucx_implementation_name(protocol)
363
+ except ValueError:
364
+ pass
365
+ else:
320
366
  initialize_ucx_argv = []
321
367
  if enable_tcp_over_ucx:
322
368
  initialize_ucx_argv.append("--enable-tcp-over-ucx")
@@ -365,7 +411,7 @@ def wait_workers(
365
411
  Instance of client, used to query for number of workers connected.
366
412
  min_timeout: float
367
413
  Minimum number of seconds to wait before timeout. This value may be
368
- overridden by setting the `DASK_CUDA_WAIT_WORKERS_MIN_TIMEOUT` with
414
+ overridden by setting the ``DASK_CUDA_WAIT_WORKERS_MIN_TIMEOUT`` with
369
415
  a positive integer.
370
416
  seconds_per_gpu: float
371
417
  Seconds to wait for each GPU on the system. For example, if its
@@ -374,7 +420,7 @@ def wait_workers(
374
420
  used as timeout when larger than min_timeout.
375
421
  n_gpus: None or int
376
422
  If specified, will wait for a that amount of GPUs (i.e., Dask workers)
377
- to come online, else waits for a total of `get_n_gpus` workers.
423
+ to come online, else waits for a total of ``get_n_gpus`` workers.
378
424
  timeout_callback: None or callable
379
425
  A callback function to be executed if a timeout occurs, ignored if
380
426
  None.
@@ -390,7 +436,7 @@ def wait_workers(
390
436
 
391
437
  start = time.time()
392
438
  while True:
393
- if len(client.scheduler_info()["workers"]) == n_gpus:
439
+ if len(client.scheduler_info(n_workers=-1)["workers"]) == n_gpus:
394
440
  return True
395
441
  elif time.time() - start > timeout:
396
442
  if callable(timeout_callback):
@@ -404,7 +450,7 @@ async def _all_to_all(client):
404
450
  """
405
451
  Trigger all to all communication between workers and scheduler
406
452
  """
407
- workers = list(client.scheduler_info()["workers"])
453
+ workers = list(client.scheduler_info(n_workers=-1)["workers"])
408
454
  futs = []
409
455
  for w in workers:
410
456
  bit_of_data = b"0" * 1
@@ -493,8 +539,8 @@ def nvml_device_index(i, CUDA_VISIBLE_DEVICES):
493
539
  """Get the device index for NVML addressing
494
540
 
495
541
  NVML expects the index of the physical device, unlike CUDA runtime which
496
- expects the address relative to `CUDA_VISIBLE_DEVICES`. This function
497
- returns the i-th device index from the `CUDA_VISIBLE_DEVICES`
542
+ expects the address relative to ``CUDA_VISIBLE_DEVICES``. This function
543
+ returns the i-th device index from the ``CUDA_VISIBLE_DEVICES``
498
544
  comma-separated string of devices or list.
499
545
 
500
546
  Examples
@@ -532,15 +578,125 @@ def nvml_device_index(i, CUDA_VISIBLE_DEVICES):
532
578
  raise ValueError("`CUDA_VISIBLE_DEVICES` must be `str` or `list`")
533
579
 
534
580
 
581
+ def parse_device_bytes(device_bytes, device_index=0, alignment_size=1):
582
+ """Parse bytes relative to a specific CUDA device.
583
+
584
+ Parameters
585
+ ----------
586
+ device_bytes: float, int, str or None
587
+ Can be an integer (bytes), float (fraction of total device memory), string
588
+ (like ``"5GB"`` or ``"5000M"``), ``0`` and ``None`` are special cases
589
+ returning ``None``.
590
+ device_index: int or str
591
+ The index or UUID of the device from which to obtain the total memory amount.
592
+ Default: 0.
593
+ alignment_size: int
594
+ Number of bytes of alignment to use, i.e., allocation must be a multiple of
595
+ that size. RMM pool requires 256 bytes alignment.
596
+
597
+ Returns
598
+ -------
599
+ The parsed bytes value relative to the CUDA device, or ``None`` as convenience if
600
+ ``device_bytes`` is ``None`` or any value that would evaluate to ``0``.
601
+
602
+ Examples
603
+ --------
604
+ >>> # On a 32GB CUDA device
605
+ >>> parse_device_bytes(None)
606
+ None
607
+ >>> parse_device_bytes(0)
608
+ None
609
+ >>> parse_device_bytes(0.0)
610
+ None
611
+ >>> parse_device_bytes("0 MiB")
612
+ None
613
+ >>> parse_device_bytes(1.0)
614
+ 34089730048
615
+ >>> parse_device_bytes(0.8)
616
+ 27271784038
617
+ >>> parse_device_bytes(1000000000)
618
+ 1000000000
619
+ >>> parse_device_bytes("1GB")
620
+ 1000000000
621
+ >>> parse_device_bytes("1GB")
622
+ 1000000000
623
+ """
624
+
625
+ def _align(size, alignment_size):
626
+ return size // alignment_size * alignment_size
627
+
628
+ def parse_fractional(v):
629
+ """Parse fractional value.
630
+
631
+ Ensures ``int(1)`` and ``str("1")`` are not treated as fractionals, but
632
+ ``float(1)`` is.
633
+
634
+ Fractionals must be represented as a ``float`` within the range
635
+ ``0.0 < v <= 1.0``.
636
+
637
+ Parameters
638
+ ----------
639
+ v: int, float or str
640
+ The value to check if fractional.
641
+
642
+ Returns
643
+ -------
644
+ """
645
+ # Check if `x` matches exactly `int(1)` or `str("1")`, and is not a `float(1)`
646
+ is_one = lambda x: not isinstance(x, float) and (x == 1 or x == "1")
647
+
648
+ if not is_one(v):
649
+ with suppress(ValueError, TypeError):
650
+ v = float(v)
651
+ if 0.0 < v <= 1.0:
652
+ return v
653
+
654
+ raise ValueError("The value is not fractional")
655
+
656
+ # Special case for fractional limit. This comes before `0` special cases because
657
+ # the `float` may be passed in a `str`, e.g., from `CUDAWorker`.
658
+ try:
659
+ fractional_device_bytes = parse_fractional(device_bytes)
660
+ except ValueError:
661
+ pass
662
+ else:
663
+ if not has_device_memory_resource():
664
+ raise ValueError(
665
+ "Fractional of total device memory not supported in devices without "
666
+ "a dedicated memory resource."
667
+ )
668
+ return _align(
669
+ int(get_device_total_memory(device_index) * fractional_device_bytes),
670
+ alignment_size,
671
+ )
672
+
673
+ # Special cases that evaluates to `None` or `0`
674
+ if device_bytes is None:
675
+ return None
676
+ elif device_bytes == 0.0:
677
+ return None
678
+ elif not isinstance(device_bytes, float) and parse_bytes(device_bytes) == 0:
679
+ return None
680
+
681
+ if isinstance(device_bytes, str):
682
+ return _align(parse_bytes(device_bytes), alignment_size)
683
+ else:
684
+ return _align(int(device_bytes), alignment_size)
685
+
686
+
535
687
  def parse_device_memory_limit(device_memory_limit, device_index=0, alignment_size=1):
536
688
  """Parse memory limit to be used by a CUDA device.
537
689
 
538
690
  Parameters
539
691
  ----------
540
692
  device_memory_limit: float, int, str or None
541
- This can be a float (fraction of total device memory), an integer (bytes),
542
- a string (like 5GB or 5000M), and "auto", 0 or None for the total device
543
- size.
693
+ Can be an integer (bytes), float (fraction of total device memory), string
694
+ (like ``"5GB"`` or ``"5000M"``), ``"auto"``, ``0`` or ``None`` to disable
695
+ spilling to host (i.e. allow full device memory usage). Another special value
696
+ ``"default"`` is also available and returns the recommended Dask-CUDA's defaults
697
+ and means 80% of the total device memory (analogous to ``0.8``), and disabled
698
+ spilling (analogous to ``auto``/``0``/``None``) on devices without a dedicated
699
+ memory resource, such as system on a chip (SoC) devices.
544
700
  device_index: int or str
545
701
  The index or UUID of the device from which to obtain the total memory amount.
546
702
  Default: 0.
@@ -548,10 +704,23 @@ def parse_device_memory_limit(device_memory_limit, device_index=0, alignment_siz
548
704
  Number of bytes of alignment to use, i.e., allocation must be a multiple of
549
705
  that size. RMM pool requires 256 bytes alignment.
550
706
 
707
+ Returns
708
+ -------
709
+ The parsed memory limit in bytes, or ``None`` as convenience if
710
+ ``device_memory_limit`` is ``None`` or any value that would evaluate to ``0``.
711
+
551
712
  Examples
552
713
  --------
553
714
  >>> # On a 32GB CUDA device
554
715
  >>> parse_device_memory_limit(None)
716
+ None
717
+ >>> parse_device_memory_limit(0)
718
+ None
719
+ >>> parse_device_memory_limit(0.0)
720
+ None
721
+ >>> parse_device_memory_limit("0 MiB")
722
+ None
723
+ >>> parse_device_memory_limit(1.0)
555
724
  34089730048
556
725
  >>> parse_device_memory_limit(0.8)
557
726
  27271784038
@@ -559,26 +728,36 @@ def parse_device_memory_limit(device_memory_limit, device_index=0, alignment_siz
559
728
  1000000000
560
729
  >>> parse_device_memory_limit("1GB")
561
730
  1000000000
731
+ >>> parse_device_memory_limit("1GB")
732
+ 1000000000
733
+ >>> parse_device_memory_limit("auto") == (
734
+ ... parse_device_memory_limit(1.0)
735
+ ... if has_device_memory_resource()
736
+ ... else None
737
+ ... )
738
+ True
739
+ >>> parse_device_memory_limit("default") == (
740
+ ... parse_device_memory_limit(0.8)
741
+ ... if has_device_memory_resource()
742
+ ... else None
743
+ ... )
744
+ True
562
745
  """
563
746
 
564
- def _align(size, alignment_size):
565
- return size // alignment_size * alignment_size
566
-
567
- if device_memory_limit in {0, "0", None, "auto"}:
568
- return _align(get_device_total_memory(device_index), alignment_size)
569
-
570
- with suppress(ValueError, TypeError):
571
- device_memory_limit = float(device_memory_limit)
572
- if isinstance(device_memory_limit, float) and device_memory_limit <= 1:
573
- return _align(
574
- int(get_device_total_memory(device_index) * device_memory_limit),
575
- alignment_size,
576
- )
747
+ # Special cases for "auto" and "default".
748
+ if device_memory_limit in ["auto", "default"]:
749
+ if not has_device_memory_resource():
750
+ return None
751
+ if device_memory_limit == "auto":
752
+ device_memory_limit = get_device_total_memory(device_index)
753
+ else:
754
+ device_memory_limit = 0.8
577
755
 
578
- if isinstance(device_memory_limit, str):
579
- return _align(parse_bytes(device_memory_limit), alignment_size)
580
- else:
581
- return _align(int(device_memory_limit), alignment_size)
756
+ return parse_device_bytes(
757
+ device_bytes=device_memory_limit,
758
+ device_index=device_index,
759
+ alignment_size=alignment_size,
760
+ )
582
761
 
583
762
 
584
763
  def get_gpu_uuid(device_index=0):
@@ -644,20 +823,26 @@ def get_worker_config(dask_worker):
644
823
  ret["device-memory-limit"] = dask_worker.data.manager._device_memory_limit
645
824
  else:
646
825
  has_device = hasattr(dask_worker.data, "device_buffer")
647
- if has_device:
826
+ if has_device and hasattr(dask_worker.data.device_buffer, "n"):
827
+ # If `n` is not an attribute, device spilling is disabled/unavailable.
648
828
  ret["device-memory-limit"] = dask_worker.data.device_buffer.n
649
829
 
650
830
  # using ucx ?
651
831
  scheme, loc = parse_address(dask_worker.scheduler.address)
652
832
  ret["protocol"] = scheme
653
- if scheme == "ucx":
654
- import ucp
833
+ try:
834
+ protocol = _get_active_ucx_implementation_name(scheme)
835
+ except ValueError:
836
+ pass
837
+ else:
838
+ if protocol == "ucxx":
839
+ import ucxx
655
840
 
656
- ret["ucx-transports"] = ucp.get_active_transports()
657
- elif scheme == "ucxx":
658
- import ucxx
841
+ ret["ucx-transports"] = ucxx.get_active_transports()
842
+ elif protocol == "ucx-old":
843
+ import ucp
659
844
 
660
- ret["ucx-transports"] = ucxx.get_active_transports()
845
+ ret["ucx-transports"] = ucp.get_active_transports()
661
846
 
662
847
  # comm timeouts
663
848
  ret["distributed.comm.timeouts"] = dask.config.get("distributed.comm.timeouts")
@@ -689,7 +874,7 @@ async def _get_cluster_configuration(client):
689
874
  if worker_config:
690
875
  w = list(worker_config.values())[0]
691
876
  ret.update(w)
692
- info = client.scheduler_info()
877
+ info = client.scheduler_info(n_workers=-1)
693
878
  workers = info.get("workers", {})
694
879
  ret["nworkers"] = len(workers)
695
880
  ret["nthreads"] = sum(w["nthreads"] for w in workers.values())
@@ -767,7 +952,7 @@ def get_rmm_device_memory_usage() -> Optional[int]:
767
952
  """Get current bytes allocated on current device through RMM
768
953
 
769
954
  Check the current RMM resource stack for resources such as
770
- `StatisticsResourceAdaptor` and `TrackingResourceAdaptor`
955
+ ``StatisticsResourceAdaptor`` and ``TrackingResourceAdaptor``
771
956
  that can report the current allocated bytes. Returns None,
772
957
  if no such resources exist.
773
958
 
@@ -803,3 +988,41 @@ class CommaSeparatedChoice(click.Choice):
803
988
  choices_str = ", ".join(f"'{c}'" for c in self.choices)
804
989
  self.fail(f"invalid choice(s): {v}. (choices are: {choices_str})")
805
990
  return values
991
+
992
+
993
+ def _get_active_ucx_implementation_name(protocol):
994
+ """Get the name of active UCX implementation.
995
+
996
+ Determine what UCX implementation is being activated based on a series of
997
+ conditions. UCXX is selected if:
998
+ - The protocol is `"ucxx"`, or the protocol is `"ucx"` and the `distributed-ucxx`
999
+ package is installed.
1000
+ UCX-Py is selected if:
1001
+ - The protocol is `"ucx-old"`, or the protocol is `"ucx"` and the `distributed-ucxx`
1002
+ package is not installed, in which case a `FutureWarning` is also raised.
1003
+
1004
+ Parameters
1005
+ ----------
1006
+ protocol: str
1007
+ The communication protocol selected.
1008
+
1009
+ Returns
1010
+ -------
1011
+ The selected implementation type, either "ucxx" or "ucx-old".
1012
+
1013
+ Raises
1014
+ ------
1015
+ ValueError
1016
+ If protocol is not a valid UCX protocol.
1017
+ """
1018
+ has_ucxx = importlib.util.find_spec("distributed_ucxx") is not None
1019
+
1020
+ if protocol == "ucxx" or (has_ucxx and protocol == "ucx"):
1021
+ # With https://github.com/rapidsai/rapids-dask-dependency/pull/116,
1022
+ # `protocol="ucx"` now points to UCXX (if distributed-ucxx is installed),
1023
+ # thus call the UCXX initializer.
1024
+ return "ucxx"
1025
+ elif protocol in ("ucx", "ucx-old"):
1026
+ return "ucx-old"
1027
+ else:
1028
+ raise ValueError("Protocol is neither UCXX nor UCX-Py")