dask-cuda 25.4.0__py3-none-any.whl → 25.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dask_cuda/GIT_COMMIT CHANGED
@@ -1 +1 @@
- e9ebd92886e6f518af02faf8a2cdadeb700b25a9
+ 1f834655ecc6286b9e3082f037594f70dcb74062
dask_cuda/VERSION CHANGED
@@ -1 +1 @@
- 25.04.00
+ 25.06.00
dask_cuda/_compat.py ADDED
@@ -0,0 +1,18 @@
+ # Copyright (c) 2025 NVIDIA CORPORATION.
+
+ import functools
+ import importlib.metadata
+
+ import packaging.version
+
+
+ @functools.lru_cache(maxsize=None)
+ def get_dask_version() -> packaging.version.Version:
+     return packaging.version.parse(importlib.metadata.version("dask"))
+
+
+ @functools.lru_cache(maxsize=None)
+ def DASK_2025_4_0():
+     # dask 2025.4.0 isn't currently released, so we're relying
+     # on strictly greater than here.
+     return get_dask_version() > packaging.version.parse("2025.3.0")
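The new module gives dask-cuda a cached view of the installed Dask version so later code can branch on it. A minimal sketch of how such a gate is consumed (the import path is the one added above; the two branches are illustrative, not taken from the package):

```python
from dask_cuda._compat import DASK_2025_4_0

if DASK_2025_4_0():
    # dask >= 2025.4.x: expression-based graphs, so compute() needs the
    # explicit materialization patch added in shuffle.py below.
    pass
else:
    # Older dask releases: the pre-existing code path is sufficient.
    pass
```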
dask_cuda/explicit_comms/dataframe/shuffle.py CHANGED
@@ -1,6 +1,9 @@
+ # Copyright (c) 2021-2025 NVIDIA CORPORATION.
+
  from __future__ import annotations

  import asyncio
+ import functools
  from collections import defaultdict
  from math import ceil
  from operator import getitem
@@ -23,6 +26,7 @@ from distributed import wait
  from distributed.protocol import nested_deserialize, to_serialize
  from distributed.worker import Worker

+ from ..._compat import DASK_2025_4_0
  from .. import comms

  T = TypeVar("T")
@@ -582,6 +586,128 @@ def _use_explicit_comms() -> bool:
      return False


+ _base_lower = dask_expr._shuffle.Shuffle._lower
+ _base_compute = dask.base.compute
+
+
+ def _contains_shuffle_expr(*args) -> bool:
+     """
+     Check whether any of the arguments is a Shuffle expression.
+
+     This is called by `compute`, which is given a sequence of Dask Collections
+     to process. For each of those, we'll check whether the expresion contains a
+     Shuffle operation.
+     """
+     for collection in args:
+         if isinstance(collection, dask.dataframe.DataFrame):
+             shuffle_ops = list(
+                 collection.expr.find_operations(
+                     (
+                         dask_expr._shuffle.RearrangeByColumn,
+                         dask_expr.SetIndex,
+                         dask_expr._shuffle.Shuffle,
+                     )
+                 )
+             )
+             if len(shuffle_ops) > 0:
+                 return True
+     return False
+
+
+ @functools.wraps(_base_compute)
+ def _patched_compute(
+     *args,
+     traverse=True,
+     optimize_graph=True,
+     scheduler=None,
+     get=None,
+     **kwargs,
+ ):
+     # A patched version of dask.compute that explicitly materializes the task
+     # graph when we're using explicit-comms and the expression contains a
+     # Shuffle operation.
+     # https://github.com/rapidsai/dask-upstream-testing/issues/37#issuecomment-2779798670
+     # contains more details on the issue.
+     if DASK_2025_4_0() and _use_explicit_comms() and _contains_shuffle_expr(*args):
+         from dask.base import (
+             collections_to_expr,
+             flatten,
+             get_scheduler,
+             shorten_traceback,
+             unpack_collections,
+         )
+
+         collections, repack = unpack_collections(*args, traverse=traverse)
+         if not collections:
+             return args
+
+         schedule = get_scheduler(
+             scheduler=scheduler,
+             collections=collections,
+             get=get,
+         )
+         from dask._expr import FinalizeCompute
+
+         expr = collections_to_expr(collections, optimize_graph)
+         expr = FinalizeCompute(expr)
+
+         with shorten_traceback():
+             expr = expr.optimize()
+             keys = list(flatten(expr.__dask_keys__()))
+
+             # materialize the HLG here
+             expr = dict(expr.__dask_graph__())
+
+             results = schedule(expr, keys, **kwargs)
+             return repack(results)
+
+     else:
+         return _base_compute(
+             *args,
+             traverse=traverse,
+             optimize_graph=optimize_graph,
+             scheduler=scheduler,
+             get=get,
+             **kwargs,
+         )
+
+
+ class ECShuffle(dask_expr._shuffle.TaskShuffle):
+     """Explicit-Comms Shuffle Expression."""
+
+     def _layer(self):
+         # Execute an explicit-comms shuffle
+         if not hasattr(self, "_ec_shuffled"):
+             on = self.partitioning_index
+             df = dask_expr.new_collection(self.frame)
+             ec_shuffled = shuffle(
+                 df,
+                 [on] if isinstance(on, str) else on,
+                 self.npartitions_out,
+                 self.ignore_index,
+             )
+             object.__setattr__(self, "_ec_shuffled", ec_shuffled)
+         graph = self._ec_shuffled.dask.copy()
+         shuffled_name = self._ec_shuffled._name
+         for i in range(self.npartitions_out):
+             graph[(self._name, i)] = graph[(shuffled_name, i)]
+         return graph
+
+
+ def _patched_lower(self):
+     if self.method in (None, "tasks") and _use_explicit_comms():
+         return ECShuffle(
+             self.frame,
+             self.partitioning_index,
+             self.npartitions_out,
+             self.ignore_index,
+             self.options,
+             self.original_partitioning_index,
+         )
+     else:
+         return _base_lower(self)
+
+
  def patch_shuffle_expression() -> None:
      """Patch Dasks Shuffle expression.

@@ -590,40 +716,6 @@ def patch_shuffle_expression() -> None:
      an `ECShuffle` expression when the 'explicit-comms'
      config is set to `True`.
      """
-
-     class ECShuffle(dask_expr._shuffle.TaskShuffle):
-         """Explicit-Comms Shuffle Expression."""
-
-         def _layer(self):
-             # Execute an explicit-comms shuffle
-             if not hasattr(self, "_ec_shuffled"):
-                 on = self.partitioning_index
-                 df = dask_expr.new_collection(self.frame)
-                 self._ec_shuffled = shuffle(
-                     df,
-                     [on] if isinstance(on, str) else on,
-                     self.npartitions_out,
-                     self.ignore_index,
-                 )
-             graph = self._ec_shuffled.dask.copy()
-             shuffled_name = self._ec_shuffled._name
-             for i in range(self.npartitions_out):
-                 graph[(self._name, i)] = graph[(shuffled_name, i)]
-             return graph
-
-     _base_lower = dask_expr._shuffle.Shuffle._lower
-
-     def _patched_lower(self):
-         if self.method in (None, "tasks") and _use_explicit_comms():
-             return ECShuffle(
-                 self.frame,
-                 self.partitioning_index,
-                 self.npartitions_out,
-                 self.ignore_index,
-                 self.options,
-                 self.original_partitioning_index,
-             )
-         else:
-             return _base_lower(self)
+     dask.base.compute = _patched_compute

      dask_expr._shuffle.Shuffle._lower = _patched_lower
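These hunks hoist `ECShuffle` and `_patched_lower` to module level and additionally wrap `dask.base.compute`, so `patch_shuffle_expression()` now installs both patches. A rough sketch of how the explicit-comms path is switched on (dask-cuda normally calls `patch_shuffle_expression()` itself when a CUDA cluster or worker starts; calling it directly here and the DataFrame contents are purely illustrative):

```python
import dask
import dask.dataframe as dd
import pandas as pd

from dask_cuda.explicit_comms.dataframe.shuffle import patch_shuffle_expression

# Install the patched Shuffle._lower and the patched dask.base.compute.
patch_shuffle_expression()

df = dd.from_pandas(
    pd.DataFrame({"key": range(100), "value": range(100)}), npartitions=4
)

with dask.config.set(explicit_comms=True):
    # With the config flag set (and a distributed client running), shuffles
    # lower to ECShuffle and compute() materializes the task graph.
    shuffled = df.shuffle(on="key")
```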
dask_cuda/get_device_memory_objects.py CHANGED
@@ -1,3 +1,5 @@
+ # Copyright (c) 2025, NVIDIA CORPORATION.
+
  from typing import Set

  from dask.sizeof import sizeof
@@ -140,3 +142,16 @@ def register_cupy(): # NB: this overwrites dask.sizeof.register_cupy()
      @sizeof.register(cupy.ndarray)
      def sizeof_cupy_ndarray(x):
          return int(x.nbytes)
+
+
+ @sizeof.register_lazy("pylibcudf")
+ def register_pylibcudf():
+     import pylibcudf
+
+     @sizeof.register(pylibcudf.column.OwnerWithCAI)
+     def sizeof_owner_with_cai(x):
+         # OwnerWithCAI implements __cuda_array_interface__ so this should always
+         # be zero-copy
+         col = pylibcudf.column.Column.from_cuda_array_interface(x)
+         # col.data() returns a gpumemoryview, which knows the size in bytes
+         return col.data().nbytes
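The new registration lets `dask.sizeof` report the device footprint of the `pylibcudf` owner objects that back cuDF columns. A small sketch of what it measures, mirroring the `test_sizeof_owner_with_cai` test added further down (assumes cudf/pylibcudf and a visible GPU; the 24 bytes come from three int64 values):

```python
import cudf
import dask.sizeof
import dask_cuda

s = cudf.Series([1, 2, 3])  # three int64 values -> 24 bytes of device memory

# dispatch() extracts the device-memory owner(s) backing the series.
[owner] = dask_cuda.get_device_memory_objects.dispatch(s)
print(dask.sizeof.sizeof(owner))  # 24
```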
dask_cuda/is_device_object.py CHANGED
@@ -1,3 +1,4 @@
+ # Copyright (c) 2025 NVIDIA CORPORATION.
  from __future__ import absolute_import, division, print_function

  from dask.utils import Dispatch
@@ -35,6 +36,8 @@ def register_cudf():
      def is_device_object_cudf_series(s):
          return True

-     @is_device_object.register(cudf.BaseIndex)
+     @is_device_object.register(cudf.Index)
+     @is_device_object.register(cudf.RangeIndex)
+     @is_device_object.register(cudf.MultiIndex)
      def is_device_object_cudf_index(s):
          return True
dask_cuda/is_spillable_object.py CHANGED
@@ -1,3 +1,4 @@
+ # Copyright (c) 2025 NVIDIA CORPORATION.
  from __future__ import absolute_import, division, print_function

  from typing import Optional
@@ -34,7 +35,9 @@ def register_cudf():
      def is_device_object_cudf_dataframe(df):
          return cudf_spilling_status()

-     @is_spillable_object.register(cudf.BaseIndex)
+     @is_spillable_object.register(cudf.Index)
+     @is_spillable_object.register(cudf.RangeIndex)
+     @is_spillable_object.register(cudf.MultiIndex)
      def is_device_object_cudf_index(s):
          return cudf_spilling_status()

dask_cuda/proxify_device_objects.py CHANGED
@@ -1,3 +1,4 @@
+ # Copyright (c) 2025 NVIDIA CORPORATION.
  import functools
  import pydoc
  from collections import defaultdict
@@ -242,7 +243,9 @@ def _register_cudf():

      @dispatch.register(cudf.DataFrame)
      @dispatch.register(cudf.Series)
-     @dispatch.register(cudf.BaseIndex)
+     @dispatch.register(cudf.Index)
+     @dispatch.register(cudf.MultiIndex)
+     @dispatch.register(cudf.RangeIndex)
      def proxify_device_object_cudf_dataframe(
          obj, proxied_id_to_proxy, found_proxies, excl_proxies
      ):
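The same single `cudf.BaseIndex` registration is replaced by explicit `Index`/`RangeIndex`/`MultiIndex` registrations in is_device_object.py, is_spillable_object.py, and this file. A toy sketch of how `dask.utils.Dispatch` resolves handlers may help explain why each concrete index type is listed: lookup walks the argument's method resolution order, so a type is only covered if it or one of its bases was registered (the classes below are made up, and the inference about the cuDF class hierarchy is an assumption, not stated in the diff):

```python
from dask.utils import Dispatch

is_gpu_backed = Dispatch(name="is_gpu_backed")

class Base: ...
class Concrete(Base): ...
class Unregistered: ...

@is_gpu_backed.register(Base)
def _(obj):
    return True

print(is_gpu_backed(Concrete()))  # True: lookup walks the MRO up to Base

try:
    is_gpu_backed(Unregistered())
except TypeError as exc:
    # No registration matches anywhere in Unregistered's MRO.
    print(exc)
```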
dask_cuda/tests/test_dask_cuda_worker.py CHANGED
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  from __future__ import absolute_import, division, print_function

  import os
@@ -16,7 +19,7 @@ from dask_cuda.utils import (
      get_cluster_configuration,
      get_device_total_memory,
      get_gpu_count_mig,
-     get_gpu_uuid_from_index,
+     get_gpu_uuid,
      get_n_gpus,
      wait_workers,
  )
@@ -409,7 +412,7 @@ def test_cuda_mig_visible_devices_and_memory_limit_and_nthreads(loop): # noqa:


  def test_cuda_visible_devices_uuid(loop): # noqa: F811
-     gpu_uuid = get_gpu_uuid_from_index(0)
+     gpu_uuid = get_gpu_uuid(0)

      with patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": gpu_uuid}):
          with popen(["dask", "scheduler", "--port", "9359", "--no-dashboard"]):
dask_cuda/tests/test_explicit_comms.py CHANGED
@@ -21,18 +21,16 @@ from distributed.deploy.local import LocalCluster

  import dask_cuda
  from dask_cuda.explicit_comms import comms
- from dask_cuda.explicit_comms.dataframe.shuffle import shuffle as explicit_comms_shuffle
+ from dask_cuda.explicit_comms.dataframe.shuffle import (
+     _contains_shuffle_expr,
+     shuffle as explicit_comms_shuffle,
+ )
  from dask_cuda.utils_test import IncreasedCloseTimeoutNanny

  mp = mp.get_context("spawn") # type: ignore
  ucp = pytest.importorskip("ucp")


- # Set default shuffle method to "tasks"
- if dask.config.get("dataframe.shuffle.method", None) is None:
-     dask.config.set({"dataframe.shuffle.method": "tasks"})
-
-
  # Notice, all of the following tests is executed in a new process such
  # that UCX options of the different tests doesn't conflict.

@@ -530,3 +528,20 @@ def test_scaled_cluster_gets_new_comms_context():
      expected = shuffled.compute()

      assert_eq(result, expected)
+
+
+ def test_contains_shuffle_expr():
+     df = dd.from_pandas(pd.DataFrame({"key": np.arange(10)}), npartitions=2)
+     assert not _contains_shuffle_expr(df)
+
+     with dask.config.set(explicit_comms=True):
+         shuffled = df.shuffle(on="key")
+
+     assert _contains_shuffle_expr(shuffled)
+     assert not _contains_shuffle_expr(df)
+
+     # this requires an active client.
+     with LocalCluster(n_workers=1) as cluster:
+         with Client(cluster):
+             explict_shuffled = explicit_comms_shuffle(df, ["key"])
+             assert not _contains_shuffle_expr(explict_shuffled)
dask_cuda/tests/test_local_cuda_cluster.py CHANGED
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  import asyncio
  import os
  import pkgutil
@@ -16,7 +19,7 @@ from dask_cuda.utils import (
      get_cluster_configuration,
      get_device_total_memory,
      get_gpu_count_mig,
-     get_gpu_uuid_from_index,
+     get_gpu_uuid,
      print_cluster_config,
  )
  from dask_cuda.utils_test import MockWorker
@@ -419,7 +422,7 @@ async def test_available_mig_workers():

  @gen_test(timeout=20)
  async def test_gpu_uuid():
-     gpu_uuid = get_gpu_uuid_from_index(0)
+     gpu_uuid = get_gpu_uuid(0)

      async with LocalCUDACluster(
          CUDA_VISIBLE_DEVICES=gpu_uuid,
dask_cuda/tests/test_proxify_host_file.py CHANGED
@@ -1,3 +1,5 @@
+ # Copyright (c) 2025, NVIDIA CORPORATION.
+
  from typing import Iterable
  from unittest.mock import patch

@@ -414,7 +416,7 @@ async def test_compatibility_mode_dataframe_shuffle(compatibility_mode, npartiti
      ddf = dask.dataframe.from_pandas(
          cudf.DataFrame({"key": np.arange(10)}), npartitions=npartitions
      )
-     res = ddf.shuffle(on="key", shuffle_method="tasks").persist()
+     [res] = client.persist([ddf.shuffle(on="key", shuffle_method="tasks")])

      # With compatibility mode on, we shouldn't encounter any proxy objects
      if compatibility_mode:
@@ -440,7 +442,7 @@ async def test_worker_force_spill_to_disk():
      async with Client(cluster, asynchronous=True) as client:
          # Create a df that are spilled to host memory immediately
          df = cudf.DataFrame({"key": np.arange(10**8)})
-         ddf = dask.dataframe.from_pandas(df, npartitions=1).persist()
+         [ddf] = client.persist([dask.dataframe.from_pandas(df, npartitions=1)])
          await ddf

          async def f(dask_worker):
@@ -498,3 +500,14 @@ def test_on_demand_debug_info():
      assert f"WARNING - RMM allocation of {size} failed" in log
      assert f"RMM allocs: {size}" in log
      assert "traceback:" in log
+
+
+ def test_sizeof_owner_with_cai():
+     cudf = pytest.importorskip("cudf")
+     s = cudf.Series([1, 2, 3])
+
+     items = dask_cuda.get_device_memory_objects.dispatch(s)
+     assert len(items) == 1
+     item = items[0]
+     result = dask.sizeof.sizeof(item)
+     assert result == 24
dask_cuda/tests/test_spill.py CHANGED
@@ -1,14 +1,18 @@
+ # Copyright (c) 2025, NVIDIA CORPORATION.
+
  import gc
  import os
  from time import sleep
+ from typing import TypedDict

  import pytest

  import dask
  from dask import array as da
- from distributed import Client, wait
+ from distributed import Client, Worker, wait
  from distributed.metrics import time
  from distributed.sizeof import sizeof
+ from distributed.utils import Deadline
  from distributed.utils_test import gen_cluster, gen_test, loop # noqa: F401

  import dask_cudf
@@ -72,24 +76,66 @@ def cudf_spill(request):


  def device_host_file_size_matches(
-     dhf, total_bytes, device_chunk_overhead=0, serialized_chunk_overhead=1024
+     dask_worker: Worker,
+     total_bytes,
+     device_chunk_overhead=0,
+     serialized_chunk_overhead=1024,
  ):
-     byte_sum = dhf.device_buffer.fast.total_weight
+     worker_data_sizes = collect_device_host_file_size(
+         dask_worker,
+         device_chunk_overhead=device_chunk_overhead,
+         serialized_chunk_overhead=serialized_chunk_overhead,
+     )
+     byte_sum = (
+         worker_data_sizes["device_fast"]
+         + worker_data_sizes["host_fast"]
+         + worker_data_sizes["host_buffer"]
+         + worker_data_sizes["disk"]
+     )
+     return (
+         byte_sum >= total_bytes
+         and byte_sum
+         <= total_bytes
+         + worker_data_sizes["device_overhead"]
+         + worker_data_sizes["host_overhead"]
+         + worker_data_sizes["disk_overhead"]
+     )
+
+
+ class WorkerDataSizes(TypedDict):
+     device_fast: int
+     host_fast: int
+     host_buffer: int
+     disk: int
+     device_overhead: int
+     host_overhead: int
+     disk_overhead: int
+
+
+ def collect_device_host_file_size(
+     dask_worker: Worker,
+     device_chunk_overhead: int,
+     serialized_chunk_overhead: int,
+ ) -> WorkerDataSizes:
+     dhf = dask_worker.data

-     # `dhf.host_buffer.fast` is only available when Worker's `memory_limit != 0`
+     device_fast = dhf.device_buffer.fast.total_weight or 0
      if hasattr(dhf.host_buffer, "fast"):
-         byte_sum += dhf.host_buffer.fast.total_weight
+         host_fast = dhf.host_buffer.fast.total_weight or 0
+         host_buffer = 0
      else:
-         byte_sum += sum([sizeof(b) for b in dhf.host_buffer.values()])
+         host_buffer = sum([sizeof(b) for b in dhf.host_buffer.values()])
+         host_fast = 0

-     # `dhf.disk` is only available when Worker's `memory_limit != 0`
      if dhf.disk is not None:
          file_path = [
              os.path.join(dhf.disk.directory, fname)
              for fname in dhf.disk.filenames.values()
          ]
          file_size = [os.path.getsize(f) for f in file_path]
-         byte_sum += sum(file_size)
+         disk = sum(file_size)
+     else:
+         disk = 0

      # Allow up to chunk_overhead bytes overhead per chunk
      device_overhead = len(dhf.device) * device_chunk_overhead
@@ -98,17 +144,25 @@ def device_host_file_size_matches(
          len(dhf.disk) * serialized_chunk_overhead if dhf.disk is not None else 0
      )

-     return (
-         byte_sum >= total_bytes
-         and byte_sum <= total_bytes + device_overhead + host_overhead + disk_overhead
+     return WorkerDataSizes(
+         device_fast=device_fast,
+         host_fast=host_fast,
+         host_buffer=host_buffer,
+         disk=disk,
+         device_overhead=device_overhead,
+         host_overhead=host_overhead,
+         disk_overhead=disk_overhead,
      )


  def assert_device_host_file_size(
-     dhf, total_bytes, device_chunk_overhead=0, serialized_chunk_overhead=1024
+     dask_worker: Worker,
+     total_bytes,
+     device_chunk_overhead=0,
+     serialized_chunk_overhead=1024,
  ):
      assert device_host_file_size_matches(
-         dhf, total_bytes, device_chunk_overhead, serialized_chunk_overhead
+         dask_worker, total_bytes, device_chunk_overhead, serialized_chunk_overhead
      )


@@ -119,7 +173,7 @@ def worker_assert(
      dask_worker=None,
  ):
      assert_device_host_file_size(
-         dask_worker.data, total_size, device_chunk_overhead, serialized_chunk_overhead
+         dask_worker, total_size, device_chunk_overhead, serialized_chunk_overhead
      )


@@ -131,12 +185,12 @@ def delayed_worker_assert(
  ):
      start = time()
      while not device_host_file_size_matches(
-         dask_worker.data, total_size, device_chunk_overhead, serialized_chunk_overhead
+         dask_worker, total_size, device_chunk_overhead, serialized_chunk_overhead
      ):
          sleep(0.01)
          if time() < start + 3:
              assert_device_host_file_size(
-                 dask_worker.data,
+                 dask_worker,
                  total_size,
                  device_chunk_overhead,
                  serialized_chunk_overhead,
@@ -224,8 +278,8 @@ async def test_cupy_cluster_device_spill(params):
                  x = rs.random(int(50e6), chunks=2e6)
                  await wait(x)

-                 xx = x.persist()
-                 await wait(xx)
+                 [xx] = client.persist([x])
+                 await xx

                  # Allow up to 1024 bytes overhead per chunk serialized
                  await client.run(
@@ -344,19 +398,38 @@ async def test_cudf_cluster_device_spill(params, cudf_spill):
                  sizes = sizes.to_arrow().to_pylist()
                  nbytes = sum(sizes)

-                 cdf2 = cdf.persist()
-                 await wait(cdf2)
+                 [cdf2] = client.persist([cdf])
+                 await cdf2

                  del cdf
                  gc.collect()

                  if enable_cudf_spill:
-                     await client.run(
-                         worker_assert,
-                         0,
-                         0,
-                         0,
+                     expected_data = WorkerDataSizes(
+                         device_fast=0,
+                         host_fast=0,
+                         host_buffer=0,
+                         disk=0,
+                         device_overhead=0,
+                         host_overhead=0,
+                         disk_overhead=0,
                      )
+
+                     deadline = Deadline.after(duration=3)
+                     while not deadline.expired:
+                         data = await client.run(
+                             collect_device_host_file_size,
+                             device_chunk_overhead=0,
+                             serialized_chunk_overhead=0,
+                         )
+                         expected = {k: expected_data for k in data}
+                         if data == expected:
+                             break
+                         sleep(0.01)
+
+                     # final assertion for pytest to reraise with a nice traceback
+                     assert data == expected
+
                  else:
                      await client.run(
                          assert_host_chunks,
@@ -419,8 +492,8 @@ async def test_cudf_spill_cluster(cudf_spill):
                  }
              )

-             ddf = dask_cudf.from_cudf(cdf, npartitions=2).sum().persist()
-             await wait(ddf)
+             [ddf] = client.persist([dask_cudf.from_cudf(cdf, npartitions=2).sum()])
+             await ddf

              await client.run(_assert_cudf_spill_stats, enable_cudf_spill)
              _assert_cudf_spill_stats(enable_cudf_spill)
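A recurring change in the test hunks above (here and in test_proxify_host_file.py) is replacing `collection.persist()` with `client.persist([collection])`, which routes persistence through an explicit client and unpacks the returned list. A CPU-only sketch of the two equivalent spellings (cluster size and array contents are arbitrary):

```python
from dask import array as da
from distributed import Client, LocalCluster, wait

with LocalCluster(n_workers=1, processes=False) as cluster, Client(cluster) as client:
    x = da.ones(1_000_000, chunks=100_000)

    # Old spelling used by these tests: xx = x.persist(); wait(xx)
    # New spelling: pass a list of collections and unpack the result.
    [xx] = client.persist([x])
    wait(xx)
    print(xx.sum().compute())  # 1000000.0
```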
dask_cuda/utils.py CHANGED
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  import math
  import operator
  import os
@@ -86,6 +89,45 @@ def get_gpu_count():
      return pynvml.nvmlDeviceGetCount()


+ def get_gpu_handle(device_id=0):
+     """Get GPU handle from device index or UUID.
+
+     Parameters
+     ----------
+     device_id: int or str
+         The index or UUID of the device from which to obtain the handle.
+
+     Raises
+     ------
+     ValueError
+         If acquiring the device handle for the device specified failed.
+     pynvml.NVMLError
+         If any NVML error occurred while initializing.
+
+     Examples
+     --------
+     >>> get_gpu_handle(device_id=0)
+
+     >>> get_gpu_handle(device_id="GPU-9fb42d6f-7d6b-368f-f79c-3c3e784c93f6")
+     """
+     pynvml.nvmlInit()
+
+     try:
+         if device_id and not str(device_id).isnumeric():
+             # This means device_id is UUID.
+             # This works for both MIG and non-MIG device UUIDs.
+             handle = pynvml.nvmlDeviceGetHandleByUUID(str.encode(device_id))
+             if pynvml.nvmlDeviceIsMigDeviceHandle(handle):
+                 # Additionally get parent device handle
+                 # if the device itself is a MIG instance
+                 handle = pynvml.nvmlDeviceGetDeviceHandleFromMigDeviceHandle(handle)
+         else:
+             handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
+         return handle
+     except pynvml.NVMLError:
+         raise ValueError(f"Invalid device index or UUID: {device_id}")
+
+
  @toolz.memoize
  def get_gpu_count_mig(return_uuids=False):
      """Return the number of MIG instances available
@@ -129,7 +171,7 @@ def get_cpu_affinity(device_index=None):
      Parameters
      ----------
      device_index: int or str
-         Index or UUID of the GPU device
+         The index or UUID of the device from which to obtain the CPU affinity.

      Examples
      --------
@@ -148,26 +190,15 @@ def get_cpu_affinity(device_index=None):
      40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
      60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79]
      """
-     pynvml.nvmlInit()
-
      try:
-         if device_index and not str(device_index).isnumeric():
-             # This means device_index is UUID.
-             # This works for both MIG and non-MIG device UUIDs.
-             handle = pynvml.nvmlDeviceGetHandleByUUID(str.encode(device_index))
-             if pynvml.nvmlDeviceIsMigDeviceHandle(handle):
-                 # Additionally get parent device handle
-                 # if the device itself is a MIG instance
-                 handle = pynvml.nvmlDeviceGetDeviceHandleFromMigDeviceHandle(handle)
-         else:
-             handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
+         handle = get_gpu_handle(device_index)
          # Result is a list of 64-bit integers, thus ceil(get_cpu_count() / 64)
          affinity = pynvml.nvmlDeviceGetCpuAffinity(
              handle,
              math.ceil(get_cpu_count() / 64),
          )
          return unpack_bitmask(affinity)
-     except pynvml.NVMLError:
+     except (pynvml.NVMLError, ValueError):
          warnings.warn(
              "Cannot get CPU affinity for device with index %d, setting default affinity"
              % device_index
@@ -182,18 +213,15 @@ def get_n_gpus():
      return get_gpu_count()


- def get_device_total_memory(index=0):
-     """
-     Return total memory of CUDA device with index or with device identifier UUID
-     """
-     pynvml.nvmlInit()
+ def get_device_total_memory(device_index=0):
+     """Return total memory of CUDA device with index or with device identifier UUID.

-     if index and not str(index).isnumeric():
-         # This means index is UUID. This works for both MIG and non-MIG device UUIDs.
-         handle = pynvml.nvmlDeviceGetHandleByUUID(str.encode(str(index)))
-     else:
-         # This is a device index
-         handle = pynvml.nvmlDeviceGetHandleByIndex(index)
+     Parameters
+     ----------
+     device_index: int or str
+         The index or UUID of the device from which to obtain the CPU affinity.
+     """
+     handle = get_gpu_handle(device_index)
      return pynvml.nvmlDeviceGetMemoryInfo(handle).total


@@ -553,26 +581,26 @@ def parse_device_memory_limit(device_memory_limit, device_index=0, alignment_siz
      return _align(int(device_memory_limit), alignment_size)


- def get_gpu_uuid_from_index(device_index=0):
+ def get_gpu_uuid(device_index=0):
      """Get GPU UUID from CUDA device index.

      Parameters
      ----------
      device_index: int or str
-         The index of the device from which to obtain the UUID. Default: 0.
+         The index or UUID of the device from which to obtain the UUID.

      Examples
      --------
-     >>> get_gpu_uuid_from_index()
+     >>> get_gpu_uuid()
      'GPU-9baca7f5-0f2f-01ac-6b05-8da14d6e9005'

-     >>> get_gpu_uuid_from_index(3)
+     >>> get_gpu_uuid(3)
      'GPU-9fb42d6f-7d6b-368f-f79c-3c3e784c93f6'
-     """
-     import pynvml

-     pynvml.nvmlInit()
-     handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
+     >>> get_gpu_uuid("GPU-9fb42d6f-7d6b-368f-f79c-3c3e784c93f6")
+     'GPU-9fb42d6f-7d6b-368f-f79c-3c3e784c93f6'
+     """
+     handle = get_gpu_handle(device_index)
      try:
          return pynvml.nvmlDeviceGetUUID(handle).decode("utf-8")
      except AttributeError:
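`get_gpu_handle` now centralizes the NVML handle lookup that `get_cpu_affinity`, `get_device_total_memory`, and the renamed `get_gpu_uuid` all share, so each of them accepts either a device index or a UUID. A short sketch (assumes an NVIDIA driver, pynvml, and at least one visible GPU):

```python
from dask_cuda.utils import get_cpu_affinity, get_device_total_memory, get_gpu_uuid

uuid = get_gpu_uuid(0)                # e.g. 'GPU-9baca7f5-0f2f-01ac-6b05-8da14d6e9005'
print(get_device_total_memory(uuid))  # UUIDs now work wherever an index does
print(get_cpu_affinity(uuid))         # falls back to a default affinity on failure
```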
dask_cuda-25.4.0.dist-info/METADATA → dask_cuda-25.6.0.dist-info/METADATA RENAMED
@@ -1,9 +1,9 @@
  Metadata-Version: 2.4
  Name: dask-cuda
- Version: 25.4.0
+ Version: 25.6.0
  Summary: Utilities for Dask and CUDA interactions
  Author: NVIDIA Corporation
- License: Apache 2.0
+ License: Apache-2.0
  Project-URL: Homepage, https://github.com/rapidsai/dask-cuda
  Project-URL: Documentation, https://docs.rapids.ai/api/dask-cuda/stable/
  Project-URL: Source, https://github.com/rapidsai/dask-cuda
@@ -15,15 +15,16 @@ Classifier: Programming Language :: Python :: 3
  Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
  Classifier: Programming Language :: Python :: 3.12
+ Classifier: Programming Language :: Python :: 3.13
  Requires-Python: >=3.10
  Description-Content-Type: text/markdown
  License-File: LICENSE
  Requires-Dist: click>=8.1
- Requires-Dist: numba<0.61.0a0,>=0.59.1
+ Requires-Dist: numba<0.62.0a0,>=0.59.1
  Requires-Dist: numpy<3.0a0,>=1.23
  Requires-Dist: pandas>=1.3
  Requires-Dist: pynvml<13.0.0a0,>=12.0.0
- Requires-Dist: rapids-dask-dependency==25.4.*
+ Requires-Dist: rapids-dask-dependency==25.6.*
  Requires-Dist: zict>=2.0.0
  Provides-Extra: docs
  Requires-Dist: numpydoc>=1.1.0; extra == "docs"
dask_cuda-25.4.0.dist-info/RECORD → dask_cuda-25.6.0.dist-info/RECORD RENAMED
@@ -1,21 +1,22 @@
- dask_cuda/GIT_COMMIT,sha256=wbY8QunTBf6nZeA4ulUfzAdQWyE7hoxV330KmJ3VnjA,41
- dask_cuda/VERSION,sha256=EM36MPurzJgotElKb8R7ZaIOF2woBA69gsVnmiyf-LY,8
+ dask_cuda/GIT_COMMIT,sha256=TiWUPXNqs5gL3lxRLAbL9S16XUILnjLBQ-tX9pxEkwE,41
+ dask_cuda/VERSION,sha256=mkkPLCPxib-wy79AMMpM4Bq103DbRbHiXhZFFnGa_sk,8
  dask_cuda/__init__.py,sha256=Wbc7R0voN4vsQkb7SKuVXH0YXuXtfnAxrupxfM4lT10,1933
+ dask_cuda/_compat.py,sha256=AG2lKGAtZitDPBjHeFDKLTN_B5HKodrhZ2kHlk1Z-D0,498
  dask_cuda/_version.py,sha256=cHDO9AzNtxkCVhwYu7hL3H7RPAkQnxpKBjElOst3rkI,964
  dask_cuda/cli.py,sha256=cScVyNiA_l9uXeDgkIcmbcR4l4cH1_1shqSqsVmuHPE,17053
  dask_cuda/cuda_worker.py,sha256=rZ1ITG_ZCbuaMA9e8uSqCjU8Km4AMphGGrxpBPQG8xU,9477
  dask_cuda/device_host_file.py,sha256=yS31LGtt9VFAG78uBBlTDr7HGIng2XymV1OxXIuEMtM,10272
  dask_cuda/disk_io.py,sha256=urSLKiPvJvYmKCzDPOUDCYuLI3r1RUiyVh3UZGRoF_Y,6626
- dask_cuda/get_device_memory_objects.py,sha256=R3U2cq4fJZPgtsUKyIguy9161p3Q99oxmcCmTcg6BtQ,4075
+ dask_cuda/get_device_memory_objects.py,sha256=peqXY8nAOtZpo9Pk1innP0rKySB8X4647YYqrwLYPHo,4569
  dask_cuda/initialize.py,sha256=Gjcxs_c8DTafgsHe5-2mw4lJdOmbFJJAZVOnxA8lTjM,6462
- dask_cuda/is_device_object.py,sha256=CnajvbQiX0FzFzwft0MqK1OPomx3ZGDnDxT56wNjixw,1046
- dask_cuda/is_spillable_object.py,sha256=CddGmg0tuSpXh2m_TJSY6GRpnl1WRHt1CRcdWgHPzWA,1457
+ dask_cuda/is_device_object.py,sha256=x9klFdeQzLcug7wZMxN3GK2AS121tlDe-LQ2uznm5yo,1179
+ dask_cuda/is_spillable_object.py,sha256=8gj6QgtKcmzrpQwy8rE-pS1R8tjaJOeD-Fzr6LumjJg,1596
  dask_cuda/local_cuda_cluster.py,sha256=wqwKVRV6jT13sf9e-XsvbVBlTrnhmcbmHQBFPTFcayw,20335
  dask_cuda/plugins.py,sha256=A2aT8HA6q_JhIEx6-XKcpbWEbl7aTg1GNoZQH8_vh00,7197
- dask_cuda/proxify_device_objects.py,sha256=99CD7LOE79YiQGJ12sYl_XImVhJXpFR4vG5utdkjTQo,8108
+ dask_cuda/proxify_device_objects.py,sha256=jWljqWddOT8NksyNKOh_9nFoV70_3P6s8P91oXdCfEk,8225
  dask_cuda/proxify_host_file.py,sha256=Wf5CFCC1JN5zmfvND3ls0M5FL01Y8VhHrk0xV3UQ9kk,30850
  dask_cuda/proxy_object.py,sha256=mrCCGwS-mltcY8oddJEXnPL6rV2dBpGgsFypBVbxRsA,30150
- dask_cuda/utils.py,sha256=Goq-m78rYZ-bcJitg47N1h_PC4PDuzXG0CUVH7V8azU,25515
+ dask_cuda/utils.py,sha256=wJ-oTj6mJHojz7JEMTh_QFnvz5igj4ULCbpI0r_XqMY,26273
  dask_cuda/utils_test.py,sha256=WNMR0gic2tuP3pgygcR9g52NfyX8iGMOan6juXhpkCE,1694
  dask_cuda/worker_spec.py,sha256=7-Uq_e5q2SkTlsmctMcYLCa9_3RiiVHZLIN7ctfaFmE,4376
  dask_cuda/benchmarks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -30,27 +31,27 @@ dask_cuda/benchmarks/utils.py,sha256=_x0XXL_F3W-fExpuQfTBwuK3WnrVuXQQepbnvjUqS9o
  dask_cuda/explicit_comms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  dask_cuda/explicit_comms/comms.py,sha256=uq-XPOH38dFcYS_13Vomj2ER6zxQz7DPeSM000mOVmY,11541
  dask_cuda/explicit_comms/dataframe/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- dask_cuda/explicit_comms/dataframe/shuffle.py,sha256=g9xDyFKmblEuevZt5Drh66uMLw-LUNOI8CIucDdACmY,21231
+ dask_cuda/explicit_comms/dataframe/shuffle.py,sha256=yG9_7BuXSswiZjFfs6kVdHBA2-mlSBKN1i6phgNTJMY,23815
  dask_cuda/tests/test_cudf_builtin_spilling.py,sha256=qVN9J0Hdv66A9COFArLIdRriyyxEKpS3lEZGHbVHaq8,4903
- dask_cuda/tests/test_dask_cuda_worker.py,sha256=C1emlr47yGa3TdSSlAXJRzguY4bcH74htk21x9th7nQ,20556
+ dask_cuda/tests/test_dask_cuda_worker.py,sha256=yG_RcOTF6vt-LBBVrjEQ_2vRZvVfFgFDMedPMSzkFws,20657
  dask_cuda/tests/test_device_host_file.py,sha256=79ssUISo1YhsW_7HdwqPfsH2LRzS2bi5BjPym1Sdgqw,5882
  dask_cuda/tests/test_dgx.py,sha256=BPCF4ZvhrVKkT43OOFHdijuo-M34vW3V18C8rRH1HXg,7489
- dask_cuda/tests/test_explicit_comms.py,sha256=xnQjjUrd6RFd9CS99pVuWY1frfiMXzRv_fW4rk9opOk,19465
+ dask_cuda/tests/test_explicit_comms.py,sha256=hrNrTKP-pBSohyUqn1hnXKkUttGwRLeYY2bniEXM1FM,19944
  dask_cuda/tests/test_from_array.py,sha256=okT1B6UqHmLxoy0uER0Ylm3UyOmi5BAXwJpTuTAw44I,601
  dask_cuda/tests/test_gds.py,sha256=j1Huud6UGm1fbkyRLQEz_ysrVw__5AimwSn_M-2GEvs,1513
  dask_cuda/tests/test_initialize.py,sha256=4Ovv_ClokKibPX6wfuaoQgN4eKCohagRFoE3s3D7Huk,8119
- dask_cuda/tests/test_local_cuda_cluster.py,sha256=Lc9QncyGwBwhaZPGBfreXJf3ZC9Zd8SjDc2fpeQ-BT0,19710
- dask_cuda/tests/test_proxify_host_file.py,sha256=LC3jjo_gbfhdIy1Zy_ynmgyv31HXFoBINCe1-XXZ4XU,18994
+ dask_cuda/tests/test_local_cuda_cluster.py,sha256=AiVUx3PkuIeobw1QXdr3mvom_l8DFVvRIvMQE91zAag,19811
+ dask_cuda/tests/test_proxify_host_file.py,sha256=pFORynzqGpe9mz_rPwTVW6O4VoY2E0EmjsT7Ux_c920,19333
  dask_cuda/tests/test_proxy.py,sha256=U9uE-QesTwquNKzTReEKiYgoRgS_pfGW-A-gJNppHyg,23817
- dask_cuda/tests/test_spill.py,sha256=CYMbp5HDBYlZ7T_n8RfSOZxaWFcAQKjprjRM7Wupcdw,13419
+ dask_cuda/tests/test_spill.py,sha256=A4-pJWCfShUaEGKbUdeIpcVL8zCyyPfAjdlJ0As3LDQ,15462
  dask_cuda/tests/test_utils.py,sha256=PQI_oTONWnKSKlkQfEeK-vlmYa0-cPpDjDEbm74cNCE,9104
  dask_cuda/tests/test_version.py,sha256=vK2HjlRLX0nxwvRsYxBqhoZryBNZklzA-vdnyuWDxVg,365
  dask_cuda/tests/test_worker_spec.py,sha256=Bvu85vkqm6ZDAYPXKMJlI2pm9Uc5tiYKNtO4goXSw-I,2399
- dask_cuda-25.4.0.dist-info/licenses/LICENSE,sha256=MjI3I-EgxfEvZlgjk82rgiFsZqSDXHFETd2QJ89UwDA,11348
+ dask_cuda-25.6.0.dist-info/licenses/LICENSE,sha256=MjI3I-EgxfEvZlgjk82rgiFsZqSDXHFETd2QJ89UwDA,11348
  examples/ucx/client_initialize.py,sha256=YN3AXHF8btcMd6NicKKhKR9SXouAsK1foJhFspbOn70,1262
  examples/ucx/local_cuda_cluster.py,sha256=7xVY3EhwhkY2L4VZin_BiMCbrjhirDNChoC86KiETNc,1983
- dask_cuda-25.4.0.dist-info/METADATA,sha256=udK2maTnpkUBnOOtTvGOwySUtJxnIo4rcIOmySPBuOk,2294
- dask_cuda-25.4.0.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
- dask_cuda-25.4.0.dist-info/entry_points.txt,sha256=UcRaKVEpywtxc6pF1VnfMB0UK4sJg7a8_NdZF67laPM,136
- dask_cuda-25.4.0.dist-info/top_level.txt,sha256=3kKxJxeM108fuYc_lwwlklP7YBU9IEmdmRAouzi397o,33
- dask_cuda-25.4.0.dist-info/RECORD,,
+ dask_cuda-25.6.0.dist-info/METADATA,sha256=Eolq3LbkRkU0ukh5enuHzLIK-YYo-Q_PX2bobo-rT1E,2345
+ dask_cuda-25.6.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ dask_cuda-25.6.0.dist-info/entry_points.txt,sha256=UcRaKVEpywtxc6pF1VnfMB0UK4sJg7a8_NdZF67laPM,136
+ dask_cuda-25.6.0.dist-info/top_level.txt,sha256=S_m57qClWFTZ9rBMNTPikpBiy9vTn6_4pjGuInt0XE8,28
+ dask_cuda-25.6.0.dist-info/RECORD,,
dask_cuda-25.4.0.dist-info/WHEEL → dask_cuda-25.6.0.dist-info/WHEEL RENAMED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (78.1.0)
+ Generator: setuptools (80.9.0)
  Root-Is-Purelib: true
  Tag: py3-none-any

dask_cuda-25.4.0.dist-info/top_level.txt → dask_cuda-25.6.0.dist-info/top_level.txt RENAMED
@@ -1,5 +1,4 @@
  ci
  conda
  dask_cuda
- dist
  examples