dask-cuda 25.6.0__py3-none-any.whl → 25.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. dask_cuda/GIT_COMMIT +1 -1
  2. dask_cuda/VERSION +1 -1
  3. dask_cuda/benchmarks/common.py +4 -1
  4. dask_cuda/benchmarks/local_cudf_groupby.py +3 -0
  5. dask_cuda/benchmarks/local_cudf_merge.py +4 -1
  6. dask_cuda/benchmarks/local_cudf_shuffle.py +4 -1
  7. dask_cuda/benchmarks/local_cupy.py +3 -0
  8. dask_cuda/benchmarks/local_cupy_map_overlap.py +3 -0
  9. dask_cuda/benchmarks/utils.py +6 -3
  10. dask_cuda/cli.py +21 -15
  11. dask_cuda/cuda_worker.py +28 -58
  12. dask_cuda/device_host_file.py +31 -15
  13. dask_cuda/disk_io.py +7 -4
  14. dask_cuda/explicit_comms/comms.py +11 -7
  15. dask_cuda/explicit_comms/dataframe/shuffle.py +23 -23
  16. dask_cuda/get_device_memory_objects.py +4 -7
  17. dask_cuda/initialize.py +149 -94
  18. dask_cuda/local_cuda_cluster.py +52 -70
  19. dask_cuda/plugins.py +17 -16
  20. dask_cuda/proxify_device_objects.py +12 -10
  21. dask_cuda/proxify_host_file.py +30 -27
  22. dask_cuda/proxy_object.py +20 -17
  23. dask_cuda/tests/conftest.py +41 -0
  24. dask_cuda/tests/test_cudf_builtin_spilling.py +3 -1
  25. dask_cuda/tests/test_dask_cuda_worker.py +109 -25
  26. dask_cuda/tests/test_dask_setup.py +193 -0
  27. dask_cuda/tests/test_dgx.py +20 -44
  28. dask_cuda/tests/test_explicit_comms.py +31 -12
  29. dask_cuda/tests/test_from_array.py +4 -6
  30. dask_cuda/tests/test_initialize.py +233 -65
  31. dask_cuda/tests/test_local_cuda_cluster.py +129 -68
  32. dask_cuda/tests/test_proxify_host_file.py +28 -7
  33. dask_cuda/tests/test_proxy.py +15 -13
  34. dask_cuda/tests/test_spill.py +10 -3
  35. dask_cuda/tests/test_utils.py +100 -29
  36. dask_cuda/tests/test_worker_spec.py +6 -0
  37. dask_cuda/utils.py +211 -42
  38. dask_cuda/utils_test.py +10 -7
  39. dask_cuda/worker_common.py +196 -0
  40. dask_cuda/worker_spec.py +6 -1
  41. {dask_cuda-25.6.0.dist-info → dask_cuda-25.10.0.dist-info}/METADATA +11 -4
  42. dask_cuda-25.10.0.dist-info/RECORD +63 -0
  43. dask_cuda-25.10.0.dist-info/top_level.txt +6 -0
  44. shared-actions/check_nightly_success/check-nightly-success/check.py +148 -0
  45. shared-actions/telemetry-impls/summarize/bump_time.py +54 -0
  46. shared-actions/telemetry-impls/summarize/send_trace.py +409 -0
  47. dask_cuda-25.6.0.dist-info/RECORD +0 -57
  48. dask_cuda-25.6.0.dist-info/top_level.txt +0 -4
  49. {dask_cuda-25.6.0.dist-info → dask_cuda-25.10.0.dist-info}/WHEEL +0 -0
  50. {dask_cuda-25.6.0.dist-info → dask_cuda-25.10.0.dist-info}/entry_points.txt +0 -0
  51. {dask_cuda-25.6.0.dist-info → dask_cuda-25.10.0.dist-info}/licenses/LICENSE +0 -0
dask_cuda/tests/test_utils.py CHANGED
@@ -1,3 +1,6 @@
+# SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
 import os
 from unittest.mock import patch
 
@@ -15,6 +18,7 @@ from dask_cuda.utils import (
     get_n_gpus,
     get_preload_options,
     get_ucx_config,
+    has_device_memory_resource,
     nvml_device_index,
     parse_cuda_visible_device,
     parse_device_memory_limit,
@@ -76,22 +80,18 @@ def test_get_device_total_memory():
     for i in range(get_n_gpus()):
         with cuda.gpus[i]:
             total_mem = get_device_total_memory(i)
-            assert type(total_mem) is int
-            assert total_mem > 0
+            if has_device_memory_resource():
+                assert type(total_mem) is int
+                assert total_mem > 0
+            else:
+                assert total_mem is None
 
 
-@pytest.mark.parametrize(
-    "protocol",
-    ["ucx", "ucxx"],
-)
-def test_get_preload_options_default(protocol):
-    if protocol == "ucx":
-        pytest.importorskip("ucp")
-    elif protocol == "ucxx":
-        pytest.importorskip("ucxx")
+def test_get_preload_options_default():
+    pytest.importorskip("distributed_ucxx")
 
     opts = get_preload_options(
-        protocol=protocol,
+        protocol="ucx",
         create_cuda_context=True,
     )
 
@@ -101,21 +101,14 @@ def test_get_preload_options_default(protocol):
     assert opts["preload_argv"] == ["--create-cuda-context"]
 
 
-@pytest.mark.parametrize(
-    "protocol",
-    ["ucx", "ucxx"],
-)
 @pytest.mark.parametrize("enable_tcp", [True, False])
 @pytest.mark.parametrize("enable_infiniband", [True, False])
 @pytest.mark.parametrize("enable_nvlink", [True, False])
-def test_get_preload_options(protocol, enable_tcp, enable_infiniband, enable_nvlink):
-    if protocol == "ucx":
-        pytest.importorskip("ucp")
-    elif protocol == "ucxx":
-        pytest.importorskip("ucxx")
+def test_get_preload_options(enable_tcp, enable_infiniband, enable_nvlink):
+    pytest.importorskip("distributed_ucxx")
 
     opts = get_preload_options(
-        protocol=protocol,
+        protocol="ucx",
         create_cuda_context=True,
         enable_tcp_over_ucx=enable_tcp,
         enable_infiniband=enable_infiniband,
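
As context for the assertions above, a minimal sketch of how these preload options are consumed. The flags shown are taken from the assertions in this test file; any dict keys beyond "preload_argv" are an assumption:

    from dask_cuda.utils import get_preload_options

    opts = get_preload_options(
        protocol="ucx",
        create_cuda_context=True,
        enable_tcp_over_ucx=True,
    )
    # Requested features surface as CLI-style arguments handed to the
    # worker preload module.
    assert "--create-cuda-context" in opts["preload_argv"]
    assert "--enable-tcp-over-ucx" in opts["preload_argv"]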
@@ -139,7 +132,7 @@ def test_get_preload_options(protocol, enable_tcp, enable_infiniband, enable_nvl
 @pytest.mark.parametrize("enable_nvlink", [True, False, None])
 @pytest.mark.parametrize("enable_infiniband", [True, False, None])
 def test_get_ucx_config(enable_tcp_over_ucx, enable_infiniband, enable_nvlink):
-    pytest.importorskip("ucp")
+    pytest.importorskip("distributed_ucxx")
 
     kwargs = {
         "enable_tcp_over_ucx": enable_tcp_over_ucx,
@@ -234,20 +227,98 @@ def test_parse_visible_devices():
         parse_cuda_visible_device([])
 
 
+def test_parse_device_bytes():
+    total = get_device_total_memory(0)
+
+    assert parse_device_memory_limit(None) is None
+    assert parse_device_memory_limit(0) is None
+    assert parse_device_memory_limit("0") is None
+    assert parse_device_memory_limit("0.0") is None
+    assert parse_device_memory_limit("0 GiB") is None
+
+    assert parse_device_memory_limit(1) == 1
+    assert parse_device_memory_limit("1") == 1
+
+    assert parse_device_memory_limit(1000000000) == 1000000000
+    assert parse_device_memory_limit("1GB") == 1000000000
+
+    if has_device_memory_resource(0):
+        assert parse_device_memory_limit(1.0) == total
+        assert parse_device_memory_limit("1.0") == total
+
+        assert parse_device_memory_limit(0.8) == int(total * 0.8)
+        assert parse_device_memory_limit(0.8, alignment_size=256) == int(
+            total * 0.8 // 256 * 256
+        )
+
+        assert parse_device_memory_limit("default") == parse_device_memory_limit(0.8)
+    else:
+        assert parse_device_memory_limit("default") is None
+
+        with pytest.raises(ValueError):
+            assert parse_device_memory_limit(1.0) == total
+        with pytest.raises(ValueError):
+            assert parse_device_memory_limit("1.0") == total
+        with pytest.raises(ValueError):
+            assert parse_device_memory_limit(0.8) == int(total * 0.8)
+        with pytest.raises(ValueError):
+            assert parse_device_memory_limit(0.8, alignment_size=256) == int(
+                total * 0.8 // 256 * 256
+            )
+
+
 def test_parse_device_memory_limit():
     total = get_device_total_memory(0)
 
-    assert parse_device_memory_limit(None) == total
-    assert parse_device_memory_limit(0) == total
+    assert parse_device_memory_limit(None) is None
+    assert parse_device_memory_limit(0) is None
+    assert parse_device_memory_limit("0") is None
+    assert parse_device_memory_limit(0.0) is None
+    assert parse_device_memory_limit("0 GiB") is None
+
+    assert parse_device_memory_limit(1) == 1
+    assert parse_device_memory_limit("1") == 1
+
     assert parse_device_memory_limit("auto") == total
 
-    assert parse_device_memory_limit(0.8) == int(total * 0.8)
-    assert parse_device_memory_limit(0.8, alignment_size=256) == int(
-        total * 0.8 // 256 * 256
-    )
     assert parse_device_memory_limit(1000000000) == 1000000000
     assert parse_device_memory_limit("1GB") == 1000000000
 
+    if has_device_memory_resource(0):
+        assert parse_device_memory_limit(1.0) == total
+        assert parse_device_memory_limit("1.0") == total
+
+        assert parse_device_memory_limit(0.8) == int(total * 0.8)
+        assert parse_device_memory_limit(0.8, alignment_size=256) == int(
+            total * 0.8 // 256 * 256
+        )
+        assert parse_device_memory_limit("default") == parse_device_memory_limit(0.8)
+    else:
+        assert parse_device_memory_limit("default") is None
+
+        with pytest.raises(ValueError):
+            assert parse_device_memory_limit(1.0) == total
+        with pytest.raises(ValueError):
+            assert parse_device_memory_limit("1.0") == total
+        with pytest.raises(ValueError):
+            assert parse_device_memory_limit(0.8) == int(total * 0.8)
+        with pytest.raises(ValueError):
+            assert parse_device_memory_limit(0.8, alignment_size=256) == int(
+                total * 0.8 // 256 * 256
+            )
+
+
+def test_has_device_memory_resource():
+    has_memory_resource = has_device_memory_resource()
+    total = get_device_total_memory(0)
+
+    if has_memory_resource:
+        # Tested only in devices with a memory resource
+        assert total == parse_device_memory_limit("auto")
+    else:
+        # Tested only in devices without a memory resource
+        assert total is None
+
 
 def test_parse_visible_mig_devices():
     pynvml.nvmlInit()
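
Taken together, these tests pin down the new device-memory-limit semantics; a short sketch of the behavior they encode (all values mirror the assertions above):

    from dask_cuda.utils import (
        has_device_memory_resource,
        parse_device_memory_limit,
    )

    # 0/None now mean "spilling disabled" rather than "total device memory".
    assert parse_device_memory_limit(None) is None
    assert parse_device_memory_limit("0 GiB") is None

    # Fractions of total memory require a dedicated memory resource.
    if has_device_memory_resource(0):
        limit = parse_device_memory_limit(0.8)  # 80% of total device memory
    else:
        limit = parse_device_memory_limit("default")  # None on SoC-style devices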
dask_cuda/tests/test_worker_spec.py CHANGED
@@ -1,3 +1,6 @@
+# SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
 import pytest
 
 from distributed import Nanny
@@ -45,6 +48,9 @@ def test_worker_spec(
     enable_infiniband,
     enable_nvlink,
 ):
+    if protocol == "ucx":
+        pytest.importorskip("distributed_ucxx")
+
     def _test():
         return worker_spec(
             CUDA_VISIBLE_DEVICES=list(range(num_devices)),
dask_cuda/utils.py CHANGED
@@ -18,7 +18,6 @@ import pynvml
 import toolz
 
 import dask
-import distributed  # noqa: required for dask.config.get("distributed.comm.ucx")
 from dask.config import canonical_name
 from dask.utils import format_bytes, parse_bytes
 from distributed import wait
@@ -43,7 +42,7 @@ def unpack_bitmask(x, mask_bits=64):
     x: list of int
         A list of integers
     mask_bits: int
-        An integer determining the bitwidth of `x`
+        An integer determining the bitwidth of ``x``
 
     Examples
     --------
@@ -220,9 +219,44 @@ def get_device_total_memory(device_index=0):
     ----------
     device_index: int or str
         The index or UUID of the device from which to obtain the total memory.
+
+    Returns
+    -------
+    The total memory of the CUDA device in bytes, or ``None`` for devices that do not
+    have a dedicated memory resource, as is usually the case for system on a chip (SoC)
+    devices.
+    """
+    handle = get_gpu_handle(device_index)
+
+    try:
+        return pynvml.nvmlDeviceGetMemoryInfo(handle).total
+    except pynvml.NVMLError_NotSupported:
+        return None
+
+
+def has_device_memory_resource(device_index=0):
+    """Determine whether a CUDA device has a dedicated memory resource.
+
+    Certain devices have no dedicated memory resource, such as system on a chip (SoC)
+    devices.
+
+    Parameters
+    ----------
+    device_index: int or str
+        The index or UUID of the device to check.
+
+    Returns
+    -------
+    Whether the device has a dedicated memory resource.
     """
     handle = get_gpu_handle(device_index)
-    return pynvml.nvmlDeviceGetMemoryInfo(handle).total
+
+    try:
+        pynvml.nvmlDeviceGetMemoryInfo(handle).total
+    except pynvml.NVMLError_NotSupported:
+        return False
+    else:
+        return True
 
 
 def get_ucx_config(
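
The relationship between these two helpers, as a small sketch (here "SoC-style" means any device for which NVML raises NVMLError_NotSupported on memory queries):

    from dask_cuda.utils import get_device_total_memory, has_device_memory_resource

    if has_device_memory_resource(0):
        # Dedicated memory resource: NVML reports a positive byte count.
        assert isinstance(get_device_total_memory(0), int)
    else:
        # No dedicated memory resource: total memory is reported as None.
        assert get_device_total_memory(0) is None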
@@ -231,10 +265,15 @@
     enable_nvlink=None,
     enable_rdmacm=None,
 ):
-    ucx_config = dask.config.get("distributed.comm.ucx")
+    try:
+        import distributed_ucxx
+    except ImportError:
+        return None
+
+    distributed_ucxx.config.setup_config()
+    ucx_config = dask.config.get("distributed-ucxx")
 
     ucx_config[canonical_name("create-cuda-context", ucx_config)] = True
-    ucx_config[canonical_name("reuse-endpoints", ucx_config)] = False
 
     # If any transport is explicitly disabled (`False`) by the user, others that
     # are not specified should be enabled (`True`). If transports are explicitly
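
Because get_ucx_config now returns None when distributed-ucxx is not installed, callers need a guard; a hedged sketch of the pattern (the fallback shown is illustrative, not taken from this diff):

    from dask_cuda.utils import get_ucx_config

    ucx_config = get_ucx_config(enable_tcp_over_ucx=True)
    if ucx_config is None:
        # distributed_ucxx is unavailable, so no UCX configuration applies.
        ucx_config = {}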
@@ -316,7 +355,7 @@
     if create_cuda_context:
         preload_options["preload_argv"].append("--create-cuda-context")
 
-    if protocol in ["ucx", "ucxx"]:
+    if protocol in ("ucx", "ucxx"):
         initialize_ucx_argv = []
         if enable_tcp_over_ucx:
             initialize_ucx_argv.append("--enable-tcp-over-ucx")
@@ -365,7 +404,7 @@ def wait_workers(
         Instance of client, used to query for number of workers connected.
     min_timeout: float
         Minimum number of seconds to wait before timeout. This value may be
-        overridden by setting the `DASK_CUDA_WAIT_WORKERS_MIN_TIMEOUT` with
+        overridden by setting the ``DASK_CUDA_WAIT_WORKERS_MIN_TIMEOUT`` with
         a positive integer.
     seconds_per_gpu: float
         Seconds to wait for each GPU on the system. For example, if its
@@ -374,7 +413,7 @@ def wait_workers(
         used as timeout when larger than min_timeout.
     n_gpus: None or int
         If specified, will wait for that amount of GPUs (i.e., Dask workers)
-        to come online, else waits for a total of `get_n_gpus` workers.
+        to come online, else waits for a total of ``get_n_gpus`` workers.
     timeout_callback: None or callable
         A callback function to be executed if a timeout occurs, ignored if
         None.
@@ -390,7 +429,7 @@
 
     start = time.time()
     while True:
-        if len(client.scheduler_info()["workers"]) == n_gpus:
+        if len(client.scheduler_info(n_workers=-1)["workers"]) == n_gpus:
            return True
        elif time.time() - start > timeout:
            if callable(timeout_callback):
@@ -404,7 +443,7 @@ async def _all_to_all(client):
     """
     Trigger all to all communication between workers and scheduler
     """
-    workers = list(client.scheduler_info()["workers"])
+    workers = list(client.scheduler_info(n_workers=-1)["workers"])
     futs = []
     for w in workers:
         bit_of_data = b"0" * 1
@@ -493,8 +532,8 @@ def nvml_device_index(i, CUDA_VISIBLE_DEVICES):
     """Get the device index for NVML addressing
 
     NVML expects the index of the physical device, unlike CUDA runtime which
-    expects the address relative to `CUDA_VISIBLE_DEVICES`. This function
-    returns the i-th device index from the `CUDA_VISIBLE_DEVICES`
+    expects the address relative to ``CUDA_VISIBLE_DEVICES``. This function
+    returns the i-th device index from the ``CUDA_VISIBLE_DEVICES``
     comma-separated string of devices or list.
 
     Examples
@@ -532,15 +571,125 @@
         raise ValueError("`CUDA_VISIBLE_DEVICES` must be `str` or `list`")
 
 
+def parse_device_bytes(device_bytes, device_index=0, alignment_size=1):
+    """Parse bytes relative to a specific CUDA device.
+
+    Parameters
+    ----------
+    device_bytes: float, int, str or None
+        Can be an integer (bytes), a float (fraction of total device memory), or a
+        string (like ``"5GB"`` or ``"5000M"``); ``0`` and ``None`` are special cases
+        returning ``None``.
+    device_index: int or str
+        The index or UUID of the device from which to obtain the total memory amount.
+        Default: 0.
+    alignment_size: int
+        Number of bytes of alignment to use, i.e., allocation must be a multiple of
+        that size. RMM pool requires 256 bytes alignment.
+
+    Returns
+    -------
+    The parsed bytes value relative to the CUDA device, or ``None`` as convenience if
+    ``device_bytes`` is ``None`` or any value that would evaluate to ``0``.
+
+    Examples
+    --------
+    >>> # On a 32GB CUDA device
+    >>> parse_device_bytes(None)
+    None
+    >>> parse_device_bytes(0)
+    None
+    >>> parse_device_bytes(0.0)
+    None
+    >>> parse_device_bytes("0 MiB")
+    None
+    >>> parse_device_bytes(1.0)
+    34089730048
+    >>> parse_device_bytes(0.8)
+    27271784038
+    >>> parse_device_bytes(1000000000)
+    1000000000
+    >>> parse_device_bytes("1GB")
+    1000000000
+    """
+
+    def _align(size, alignment_size):
+        return size // alignment_size * alignment_size
+
+    def parse_fractional(v):
+        """Parse fractional value.
+
+        Ensures ``int(1)`` and ``str("1")`` are not treated as fractionals, but
+        ``float(1)`` is.
+
+        Fractionals must be represented as a ``float`` within the range
+        ``0.0 < v <= 1.0``.
+
+        Parameters
+        ----------
+        v: int, float or str
+            The value to check if fractional.
+
+        Returns
+        -------
+        The fractional value as a ``float``; raises ``ValueError`` if ``v`` is not
+        fractional.
+        """
+        # Check if `x` matches exactly `int(1)` or `str("1")`, and is not a `float(1)`
+        is_one = lambda x: not isinstance(x, float) and (x == 1 or x == "1")
+
+        if not is_one(v):
+            with suppress(ValueError, TypeError):
+                v = float(v)
+                if 0.0 < v <= 1.0:
+                    return v
+
+        raise ValueError("The value is not fractional")
+
+    # Special case for fractional limit. This comes before `0` special cases because
+    # the `float` may be passed in a `str`, e.g., from `CUDAWorker`.
+    try:
+        fractional_device_bytes = parse_fractional(device_bytes)
+    except ValueError:
+        pass
+    else:
+        if not has_device_memory_resource():
+            raise ValueError(
+                "Fraction of total device memory not supported in devices without "
+                "a dedicated memory resource."
+            )
+        return _align(
+            int(get_device_total_memory(device_index) * fractional_device_bytes),
+            alignment_size,
+        )
+
+    # Special cases that evaluate to `None` or `0`
+    if device_bytes is None:
+        return None
+    elif device_bytes == 0.0:
+        return None
+    elif not isinstance(device_bytes, float) and parse_bytes(device_bytes) == 0:
+        return None
+
+    if isinstance(device_bytes, str):
+        return _align(parse_bytes(device_bytes), alignment_size)
+    else:
+        return _align(int(device_bytes), alignment_size)
+
+
 def parse_device_memory_limit(device_memory_limit, device_index=0, alignment_size=1):
     """Parse memory limit to be used by a CUDA device.
 
     Parameters
     ----------
     device_memory_limit: float, int, str or None
-        This can be a float (fraction of total device memory), an integer (bytes),
-        a string (like 5GB or 5000M), and "auto", 0 or None for the total device
-        size.
+        Can be an integer (bytes), a float (fraction of total device memory), a string
+        (like ``"5GB"`` or ``"5000M"``), or ``"auto"``, ``0`` or ``None`` to disable
+        spilling to host (i.e., allow full device memory usage). The special value
+        ``"default"`` applies Dask-CUDA's recommended defaults: 80% of the total
+        device memory (analogous to ``0.8``), or disabled spilling (analogous to
+        ``"auto"``/``0``/``None``) on devices without a dedicated memory resource,
+        such as system on a chip (SoC) devices.
     device_index: int or str
         The index or UUID of the device from which to obtain the total memory amount.
         Default: 0.
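
The _align helper above floors a size to a multiple of alignment_size; a worked example of the arithmetic (values chosen for illustration):

    def _align(size, alignment_size):
        return size // alignment_size * alignment_size

    assert _align(1000, 256) == 768  # 1000 // 256 == 3, and 3 * 256 == 768
    assert _align(1000000000, 256) == 1000000000  # already a multiple of 256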
@@ -548,10 +697,23 @@ def parse_device_memory_limit(device_memory_limit, device_index=0, alignment_siz
         Number of bytes of alignment to use, i.e., allocation must be a multiple of
         that size. RMM pool requires 256 bytes alignment.
 
+    Returns
+    -------
+    The parsed memory limit in bytes, or ``None`` as convenience if
+    ``device_memory_limit`` is ``None`` or any value that would evaluate to ``0``.
+
     Examples
     --------
     >>> # On a 32GB CUDA device
     >>> parse_device_memory_limit(None)
+    None
+    >>> parse_device_memory_limit(0)
+    None
+    >>> parse_device_memory_limit(0.0)
+    None
+    >>> parse_device_memory_limit("0 MiB")
+    None
+    >>> parse_device_memory_limit(1.0)
     34089730048
     >>> parse_device_memory_limit(0.8)
     27271784038
@@ -559,26 +721,36 @@ def parse_device_memory_limit(device_memory_limit, device_index=0, alignment_siz
     1000000000
     >>> parse_device_memory_limit("1GB")
     1000000000
+    >>> parse_device_memory_limit("auto") == (
+    ...     parse_device_memory_limit(1.0)
+    ...     if has_device_memory_resource()
+    ...     else None
+    ... )
+    True
+    >>> parse_device_memory_limit("default") == (
+    ...     parse_device_memory_limit(0.8)
+    ...     if has_device_memory_resource()
+    ...     else None
+    ... )
+    True
     """
 
-    def _align(size, alignment_size):
-        return size // alignment_size * alignment_size
-
-    if device_memory_limit in {0, "0", None, "auto"}:
-        return _align(get_device_total_memory(device_index), alignment_size)
-
-    with suppress(ValueError, TypeError):
-        device_memory_limit = float(device_memory_limit)
-        if isinstance(device_memory_limit, float) and device_memory_limit <= 1:
-            return _align(
-                int(get_device_total_memory(device_index) * device_memory_limit),
-                alignment_size,
-            )
+    # Special cases for "auto" and "default".
+    if device_memory_limit in ["auto", "default"]:
+        if not has_device_memory_resource():
+            return None
+        if device_memory_limit == "auto":
+            device_memory_limit = get_device_total_memory(device_index)
+        else:
+            device_memory_limit = 0.8
 
-    if isinstance(device_memory_limit, str):
-        return _align(parse_bytes(device_memory_limit), alignment_size)
-    else:
-        return _align(int(device_memory_limit), alignment_size)
+    return parse_device_bytes(
+        device_bytes=device_memory_limit,
+        device_index=device_index,
+        alignment_size=alignment_size,
+    )
 
 
 def get_gpu_uuid(device_index=0):
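
A compact sketch of how the special strings resolve after this change, mirroring the branches above:

    from dask_cuda.utils import (
        get_device_total_memory,
        has_device_memory_resource,
        parse_device_memory_limit,
    )

    if has_device_memory_resource(0):
        # "auto" resolves to total device memory, "default" to 80% of it.
        assert parse_device_memory_limit("auto") == get_device_total_memory(0)
        assert parse_device_memory_limit("default") == parse_device_memory_limit(0.8)
    else:
        # Without a dedicated memory resource, both disable spilling.
        assert parse_device_memory_limit("auto") is None
        assert parse_device_memory_limit("default") is None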
@@ -644,17 +816,14 @@ def get_worker_config(dask_worker):
         ret["device-memory-limit"] = dask_worker.data.manager._device_memory_limit
     else:
         has_device = hasattr(dask_worker.data, "device_buffer")
-        if has_device:
+        if has_device and hasattr(dask_worker.data.device_buffer, "n"):
+            # If `n` is not an attribute, device spilling is disabled/unavailable.
             ret["device-memory-limit"] = dask_worker.data.device_buffer.n
 
     # using ucx ?
-    scheme, loc = parse_address(dask_worker.scheduler.address)
-    ret["protocol"] = scheme
-    if scheme == "ucx":
-        import ucp
-
-        ret["ucx-transports"] = ucp.get_active_transports()
-    elif scheme == "ucxx":
+    protocol, loc = parse_address(dask_worker.scheduler.address)
+    ret["protocol"] = protocol
+    if protocol in ("ucx", "ucxx"):
         import ucxx
 
         ret["ucx-transports"] = ucxx.get_active_transports()
@@ -689,7 +858,7 @@ async def _get_cluster_configuration(client):
     if worker_config:
         w = list(worker_config.values())[0]
         ret.update(w)
-    info = client.scheduler_info()
+    info = client.scheduler_info(n_workers=-1)
     workers = info.get("workers", {})
     ret["nworkers"] = len(workers)
     ret["nthreads"] = sum(w["nthreads"] for w in workers.values())
@@ -767,7 +936,7 @@ def get_rmm_device_memory_usage() -> Optional[int]:
     """Get current bytes allocated on current device through RMM
 
     Check the current RMM resource stack for resources such as
-    `StatisticsResourceAdaptor` and `TrackingResourceAdaptor`
+    ``StatisticsResourceAdaptor`` and ``TrackingResourceAdaptor``
     that can report the current allocated bytes. Returns None,
     if no such resources exist.
 
dask_cuda/utils_test.py CHANGED
@@ -1,3 +1,6 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
 from typing import Literal
 
 import distributed
@@ -8,7 +11,7 @@ class MockWorker(Worker):
     """Mock Worker class preventing NVML from getting used by SystemMonitor.
 
     By preventing the Worker from initializing NVML in the SystemMonitor, we can
-    mock test multiple devices in `CUDA_VISIBLE_DEVICES` behavior with single-GPU
+    mock test multiple devices in ``CUDA_VISIBLE_DEVICES`` behavior with single-GPU
     machines.
     """
 
@@ -26,17 +29,17 @@ class MockWorker(Worker):
 
 
 class IncreasedCloseTimeoutNanny(Nanny):
-    """Increase `Nanny`'s close timeout.
+    """Increase ``Nanny``'s close timeout.
 
-    The internal close timeout mechanism of `Nanny` recomputes the time left to kill
-    the `Worker` process based on elapsed time of the close task, which may leave
+    The internal close timeout mechanism of ``Nanny`` recomputes the time left to kill
+    the ``Worker`` process based on elapsed time of the close task, which may leave
     very little time for the subprocess to shutdown cleanly, which may cause tests
     to fail when the system is under higher load. This class increases the default
-    close timeout of 5.0 seconds that `Nanny` sets by default, which can be overriden
+    close timeout of 5.0 seconds that ``Nanny`` sets by default, which can be overridden
     via Distributed's public API.
 
-    This class can be used with the `worker_class` argument of `LocalCluster` or
-    `LocalCUDACluster` to provide a much higher default of 30.0 seconds.
+    This class can be used with the ``worker_class`` argument of ``LocalCluster`` or
+    ``LocalCUDACluster`` to provide a much higher default of 30.0 seconds.
     """
 
     async def close(  # type:ignore[override]
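
As the docstring notes, the class is passed through the worker_class argument; a usage sketch (cluster arguments beyond worker_class are illustrative):

    from dask_cuda import LocalCUDACluster
    from dask_cuda.utils_test import IncreasedCloseTimeoutNanny

    # Workers get up to 30 seconds (instead of Nanny's default 5) to shut down.
    cluster = LocalCUDACluster(worker_class=IncreasedCloseTimeoutNanny)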