dask-cuda 25.2.0__py3-none-any.whl → 25.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dask_cuda/GIT_COMMIT ADDED
@@ -0,0 +1 @@
1
+ 1f834655ecc6286b9e3082f037594f70dcb74062
dask_cuda/VERSION CHANGED
@@ -1 +1 @@
1
- 25.02.00
1
+ 25.06.00
dask_cuda/__init__.py CHANGED
@@ -5,8 +5,6 @@ if sys.platform != "linux":
5
5
 
6
6
  import dask
7
7
  import dask.utils
8
- import dask.dataframe.shuffle
9
- from .explicit_comms.dataframe.shuffle import patch_shuffle_expression
10
8
  from distributed.protocol.cuda import cuda_deserialize, cuda_serialize
11
9
  from distributed.protocol.serialize import dask_deserialize, dask_serialize
12
10
 
@@ -14,30 +12,43 @@ from ._version import __git_commit__, __version__
14
12
  from .cuda_worker import CUDAWorker
15
13
 
16
14
  from .local_cuda_cluster import LocalCUDACluster
17
- from .proxify_device_objects import proxify_decorator, unproxify_decorator
18
15
 
19
16
 
20
- # Monkey patching Dask to make use of explicit-comms when `DASK_EXPLICIT_COMMS=True`
21
- patch_shuffle_expression()
22
- # Monkey patching Dask to make use of proxify and unproxify in compatibility mode
23
- dask.dataframe.shuffle.shuffle_group = proxify_decorator(
24
- dask.dataframe.shuffle.shuffle_group
25
- )
26
- dask.dataframe.core._concat = unproxify_decorator(dask.dataframe.core._concat)
27
-
28
-
29
- def _register_cudf_spill_aware():
30
- import cudf
31
-
32
- # Only enable Dask/cuDF spilling if cuDF spilling is disabled, see
33
- # https://github.com/rapidsai/dask-cuda/issues/1363
34
- if not cudf.get_option("spill"):
35
- # This reproduces the implementation of `_register_cudf`, see
36
- # https://github.com/dask/distributed/blob/40fcd65e991382a956c3b879e438be1b100dff97/distributed/protocol/__init__.py#L106-L115
37
- from cudf.comm import serialize
38
-
39
-
40
- for registry in [cuda_serialize, cuda_deserialize, dask_serialize, dask_deserialize]:
41
- for lib in ["cudf", "dask_cudf"]:
42
- if lib in registry._lazy:
43
- registry._lazy[lib] = _register_cudf_spill_aware
17
+ try:
18
+ import dask.dataframe as dask_dataframe
19
+ except ImportError:
20
+ # Dask DataFrame (optional) isn't installed
21
+ dask_dataframe = None
22
+
23
+
24
+ if dask_dataframe is not None:
25
+ from .explicit_comms.dataframe.shuffle import patch_shuffle_expression
26
+ from .proxify_device_objects import proxify_decorator, unproxify_decorator
27
+
28
+ # Monkey patching Dask to make use of explicit-comms when `DASK_EXPLICIT_COMMS=True`
29
+ patch_shuffle_expression()
30
+ # Monkey patching Dask to make use of proxify and unproxify in compatibility mode
31
+ dask_dataframe.shuffle.shuffle_group = proxify_decorator(
32
+ dask.dataframe.shuffle.shuffle_group
33
+ )
34
+ dask_dataframe.core._concat = unproxify_decorator(dask.dataframe.core._concat)
35
+
36
+ def _register_cudf_spill_aware():
37
+ import cudf
38
+
39
+ # Only enable Dask/cuDF spilling if cuDF spilling is disabled, see
40
+ # https://github.com/rapidsai/dask-cuda/issues/1363
41
+ if not cudf.get_option("spill"):
42
+ # This reproduces the implementation of `_register_cudf`, see
43
+ # https://github.com/dask/distributed/blob/40fcd65e991382a956c3b879e438be1b100dff97/distributed/protocol/__init__.py#L106-L115
44
+ from cudf.comm import serialize
45
+
46
+ for registry in [
47
+ cuda_serialize,
48
+ cuda_deserialize,
49
+ dask_serialize,
50
+ dask_deserialize,
51
+ ]:
52
+ for lib in ["cudf", "dask_cudf"]:
53
+ if lib in registry._lazy:
54
+ registry._lazy[lib] = _register_cudf_spill_aware
dask_cuda/_compat.py ADDED
@@ -0,0 +1,18 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
2
+
3
+ import functools
4
+ import importlib.metadata
5
+
6
+ import packaging.version
7
+
8
+
9
+ @functools.lru_cache(maxsize=None)
10
+ def get_dask_version() -> packaging.version.Version:
11
+ return packaging.version.parse(importlib.metadata.version("dask"))
12
+
13
+
14
+ @functools.lru_cache(maxsize=None)
15
+ def DASK_2025_4_0():
16
+ # dask 2025.4.0 isn't currently released, so we're relying
17
+ # on a strictly-greater-than comparison against 2025.3.0 here.
18
+ return get_dask_version() > packaging.version.parse("2025.3.0")
@@ -1,15 +1,21 @@
1
+ # Copyright (c) 2021-2025 NVIDIA CORPORATION.
1
2
  import asyncio
2
3
  import concurrent.futures
3
4
  import contextlib
4
5
  import time
5
6
  import uuid
7
+ import weakref
6
8
  from typing import Any, Dict, Hashable, Iterable, List, Optional
7
9
 
8
10
  import distributed.comm
11
+ from dask.tokenize import tokenize
9
12
  from distributed import Client, Worker, default_client, get_worker
10
13
  from distributed.comm.addressing import parse_address, parse_host_port, unparse_address
11
14
 
12
- _default_comms = None
15
+ # Mapping tokenize(client ID, [worker addresses]) to CommsContext
16
+ _comms_cache: weakref.WeakValueDictionary[
17
+ str, "CommsContext"
18
+ ] = weakref.WeakValueDictionary()
13
19
 
14
20
 
15
21
  def get_multi_lock_or_null_context(multi_lock_context, *args, **kwargs):
@@ -38,9 +44,10 @@ def get_multi_lock_or_null_context(multi_lock_context, *args, **kwargs):
38
44
 
39
45
 
40
46
  def default_comms(client: Optional[Client] = None) -> "CommsContext":
41
- """Return the default comms object
47
+ """Return the default comms object for ``client``.
42
48
 
43
- Creates a new default comms object if no one exist.
49
+ Creates a new default comms object if one does not already exist
50
+ for ``client``.
44
51
 
45
52
  Parameters
46
53
  ----------
@@ -52,11 +59,31 @@ def default_comms(client: Optional[Client] = None) -> "CommsContext":
52
59
  -------
53
60
  comms: CommsContext
54
61
  The default comms object
62
+
63
+ Notes
64
+ -----
65
+ There are some subtle points around explicit-comms and the lifecycle
66
+ of a Dask Cluster.
67
+
68
+ A :class:`CommsContext` establishes explicit communication channels
69
+ between the workers *at the time it's created*. If workers are added
70
+ or removed, they will not be included in the communication channels
71
+ with the other workers.
72
+
73
+ If you need to refresh the explicit communications channels, then
74
+ create a new :class:`CommsContext` object or call ``default_comms``
75
+ again after workers have been added to or removed from the cluster.
55
76
  """
56
- global _default_comms
57
- if _default_comms is None:
58
- _default_comms = CommsContext(client=client)
59
- return _default_comms
77
+ # Comms are unique to a {client, [workers]} pair, so we key our
78
+ # cache by the token of that.
79
+ client = client or default_client()
80
+ token = tokenize(client.id, list(client.scheduler_info()["workers"].keys()))
81
+ maybe_comms = _comms_cache.get(token)
82
+ if maybe_comms is None:
83
+ maybe_comms = CommsContext(client=client)
84
+ _comms_cache[token] = maybe_comms
85
+
86
+ return maybe_comms
60
87
 
61
88
 
62
89
  def worker_state(sessionId: Optional[int] = None) -> dict:
@@ -1,6 +1,9 @@
1
+ # Copyright (c) 2021-2025 NVIDIA CORPORATION.
2
+
1
3
  from __future__ import annotations
2
4
 
3
5
  import asyncio
6
+ import functools
4
7
  from collections import defaultdict
5
8
  from math import ceil
6
9
  from operator import getitem
@@ -23,6 +26,7 @@ from distributed import wait
23
26
  from distributed.protocol import nested_deserialize, to_serialize
24
27
  from distributed.worker import Worker
25
28
 
29
+ from ..._compat import DASK_2025_4_0
26
30
  from .. import comms
27
31
 
28
32
  T = TypeVar("T")
@@ -582,6 +586,128 @@ def _use_explicit_comms() -> bool:
582
586
  return False
583
587
 
584
588
 
589
+ _base_lower = dask_expr._shuffle.Shuffle._lower
590
+ _base_compute = dask.base.compute
591
+
592
+
593
+ def _contains_shuffle_expr(*args) -> bool:
594
+ """
595
+ Check whether any of the arguments is a Shuffle expression.
596
+
597
+ This is called by `compute`, which is given a sequence of Dask Collections
598
+ to process. For each of those, we'll check whether the expression contains a
599
+ Shuffle operation.
600
+ """
601
+ for collection in args:
602
+ if isinstance(collection, dask.dataframe.DataFrame):
603
+ shuffle_ops = list(
604
+ collection.expr.find_operations(
605
+ (
606
+ dask_expr._shuffle.RearrangeByColumn,
607
+ dask_expr.SetIndex,
608
+ dask_expr._shuffle.Shuffle,
609
+ )
610
+ )
611
+ )
612
+ if len(shuffle_ops) > 0:
613
+ return True
614
+ return False
615
+
616
+
617
+ @functools.wraps(_base_compute)
618
+ def _patched_compute(
619
+ *args,
620
+ traverse=True,
621
+ optimize_graph=True,
622
+ scheduler=None,
623
+ get=None,
624
+ **kwargs,
625
+ ):
626
+ # A patched version of dask.compute that explicitly materializes the task
627
+ # graph when we're using explicit-comms and the expression contains a
628
+ # Shuffle operation.
629
+ # https://github.com/rapidsai/dask-upstream-testing/issues/37#issuecomment-2779798670
630
+ # contains more details on the issue.
631
+ if DASK_2025_4_0() and _use_explicit_comms() and _contains_shuffle_expr(*args):
632
+ from dask.base import (
633
+ collections_to_expr,
634
+ flatten,
635
+ get_scheduler,
636
+ shorten_traceback,
637
+ unpack_collections,
638
+ )
639
+
640
+ collections, repack = unpack_collections(*args, traverse=traverse)
641
+ if not collections:
642
+ return args
643
+
644
+ schedule = get_scheduler(
645
+ scheduler=scheduler,
646
+ collections=collections,
647
+ get=get,
648
+ )
649
+ from dask._expr import FinalizeCompute
650
+
651
+ expr = collections_to_expr(collections, optimize_graph)
652
+ expr = FinalizeCompute(expr)
653
+
654
+ with shorten_traceback():
655
+ expr = expr.optimize()
656
+ keys = list(flatten(expr.__dask_keys__()))
657
+
658
+ # materialize the HLG here
659
+ expr = dict(expr.__dask_graph__())
660
+
661
+ results = schedule(expr, keys, **kwargs)
662
+ return repack(results)
663
+
664
+ else:
665
+ return _base_compute(
666
+ *args,
667
+ traverse=traverse,
668
+ optimize_graph=optimize_graph,
669
+ scheduler=scheduler,
670
+ get=get,
671
+ **kwargs,
672
+ )
673
+
674
+
675
+ class ECShuffle(dask_expr._shuffle.TaskShuffle):
676
+ """Explicit-Comms Shuffle Expression."""
677
+
678
+ def _layer(self):
679
+ # Execute an explicit-comms shuffle
680
+ if not hasattr(self, "_ec_shuffled"):
681
+ on = self.partitioning_index
682
+ df = dask_expr.new_collection(self.frame)
683
+ ec_shuffled = shuffle(
684
+ df,
685
+ [on] if isinstance(on, str) else on,
686
+ self.npartitions_out,
687
+ self.ignore_index,
688
+ )
689
+ object.__setattr__(self, "_ec_shuffled", ec_shuffled)
690
+ graph = self._ec_shuffled.dask.copy()
691
+ shuffled_name = self._ec_shuffled._name
692
+ for i in range(self.npartitions_out):
693
+ graph[(self._name, i)] = graph[(shuffled_name, i)]
694
+ return graph
695
+
696
+
697
+ def _patched_lower(self):
698
+ if self.method in (None, "tasks") and _use_explicit_comms():
699
+ return ECShuffle(
700
+ self.frame,
701
+ self.partitioning_index,
702
+ self.npartitions_out,
703
+ self.ignore_index,
704
+ self.options,
705
+ self.original_partitioning_index,
706
+ )
707
+ else:
708
+ return _base_lower(self)
709
+
710
+
585
711
  def patch_shuffle_expression() -> None:
586
712
  """Patch Dasks Shuffle expression.
587
713
 
@@ -590,40 +716,6 @@ def patch_shuffle_expression() -> None:
590
716
  an `ECShuffle` expression when the 'explicit-comms'
591
717
  config is set to `True`.
592
718
  """
593
-
594
- class ECShuffle(dask_expr._shuffle.TaskShuffle):
595
- """Explicit-Comms Shuffle Expression."""
596
-
597
- def _layer(self):
598
- # Execute an explicit-comms shuffle
599
- if not hasattr(self, "_ec_shuffled"):
600
- on = self.partitioning_index
601
- df = dask_expr.new_collection(self.frame)
602
- self._ec_shuffled = shuffle(
603
- df,
604
- [on] if isinstance(on, str) else on,
605
- self.npartitions_out,
606
- self.ignore_index,
607
- )
608
- graph = self._ec_shuffled.dask.copy()
609
- shuffled_name = self._ec_shuffled._name
610
- for i in range(self.npartitions_out):
611
- graph[(self._name, i)] = graph[(shuffled_name, i)]
612
- return graph
613
-
614
- _base_lower = dask_expr._shuffle.Shuffle._lower
615
-
616
- def _patched_lower(self):
617
- if self.method in (None, "tasks") and _use_explicit_comms():
618
- return ECShuffle(
619
- self.frame,
620
- self.partitioning_index,
621
- self.npartitions_out,
622
- self.ignore_index,
623
- self.options,
624
- self.original_partitioning_index,
625
- )
626
- else:
627
- return _base_lower(self)
719
+ dask.base.compute = _patched_compute
628
720
 
629
721
  dask_expr._shuffle.Shuffle._lower = _patched_lower
@@ -1,3 +1,5 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION.
2
+
1
3
  from typing import Set
2
4
 
3
5
  from dask.sizeof import sizeof
@@ -140,3 +142,16 @@ def register_cupy(): # NB: this overwrites dask.sizeof.register_cupy()
140
142
  @sizeof.register(cupy.ndarray)
141
143
  def sizeof_cupy_ndarray(x):
142
144
  return int(x.nbytes)
145
+
146
+
147
+ @sizeof.register_lazy("pylibcudf")
148
+ def register_pylibcudf():
149
+ import pylibcudf
150
+
151
+ @sizeof.register(pylibcudf.column.OwnerWithCAI)
152
+ def sizeof_owner_with_cai(x):
153
+ # OwnerWithCAI implements __cuda_array_interface__ so this should always
154
+ # be zero-copy
155
+ col = pylibcudf.column.Column.from_cuda_array_interface(x)
156
+ # col.data() returns a gpumemoryview, which knows the size in bytes
157
+ return col.data().nbytes
@@ -1,3 +1,4 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
1
2
  from __future__ import absolute_import, division, print_function
2
3
 
3
4
  from dask.utils import Dispatch
@@ -35,6 +36,8 @@ def register_cudf():
35
36
  def is_device_object_cudf_series(s):
36
37
  return True
37
38
 
38
- @is_device_object.register(cudf.BaseIndex)
39
+ @is_device_object.register(cudf.Index)
40
+ @is_device_object.register(cudf.RangeIndex)
41
+ @is_device_object.register(cudf.MultiIndex)
39
42
  def is_device_object_cudf_index(s):
40
43
  return True
@@ -1,3 +1,4 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
1
2
  from __future__ import absolute_import, division, print_function
2
3
 
3
4
  from typing import Optional
@@ -34,7 +35,9 @@ def register_cudf():
34
35
  def is_device_object_cudf_dataframe(df):
35
36
  return cudf_spilling_status()
36
37
 
37
- @is_spillable_object.register(cudf.BaseIndex)
38
+ @is_spillable_object.register(cudf.Index)
39
+ @is_spillable_object.register(cudf.RangeIndex)
40
+ @is_spillable_object.register(cudf.MultiIndex)
38
41
  def is_device_object_cudf_index(s):
39
42
  return cudf_spilling_status()
40
43
 
@@ -1,3 +1,4 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION.
1
2
  import functools
2
3
  import pydoc
3
4
  from collections import defaultdict
@@ -242,7 +243,9 @@ def _register_cudf():
242
243
 
243
244
  @dispatch.register(cudf.DataFrame)
244
245
  @dispatch.register(cudf.Series)
245
- @dispatch.register(cudf.BaseIndex)
246
+ @dispatch.register(cudf.Index)
247
+ @dispatch.register(cudf.MultiIndex)
248
+ @dispatch.register(cudf.RangeIndex)
246
249
  def proxify_device_object_cudf_dataframe(
247
250
  obj, proxied_id_to_proxy, found_proxies, excl_proxies
248
251
  ):
dask_cuda/proxy_object.py CHANGED
@@ -11,10 +11,6 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple, Type, Un
11
11
  import pandas
12
12
 
13
13
  import dask
14
- import dask.array.core
15
- import dask.dataframe.backends
16
- import dask.dataframe.dispatch
17
- import dask.dataframe.utils
18
14
  import dask.utils
19
15
  import distributed.protocol
20
16
  import distributed.utils
@@ -30,6 +26,22 @@ if TYPE_CHECKING:
30
26
  from .proxify_host_file import ProxyManager
31
27
 
32
28
 
29
+ try:
30
+ import dask.dataframe as dask_dataframe
31
+ import dask.dataframe.backends
32
+ import dask.dataframe.dispatch
33
+ import dask.dataframe.utils
34
+ except ImportError:
35
+ dask_dataframe = None
36
+
37
+
38
+ try:
39
+ import dask.array as dask_array
40
+ import dask.array.core
41
+ except ImportError:
42
+ dask_array = None
43
+
44
+
33
45
  # List of attributes that should be copied to the proxy at creation, which makes
34
46
  # them accessible without deserialization of the proxied object
35
47
  _FIXED_ATTRS = ["name", "__len__"]
@@ -884,14 +896,6 @@ def obj_pxy_dask_deserialize(header, frames):
884
896
  return subclass(pxy)
885
897
 
886
898
 
887
- @dask.dataframe.dispatch.get_parallel_type.register(ProxyObject)
888
- def get_parallel_type_proxy_object(obj: ProxyObject):
889
- # Notice, `get_parallel_type()` needs a instance not a type object
890
- return dask.dataframe.dispatch.get_parallel_type(
891
- obj.__class__.__new__(obj.__class__)
892
- )
893
-
894
-
895
899
  def unproxify_input_wrapper(func):
896
900
  """Unproxify the input of `func`"""
897
901
 
@@ -904,26 +908,42 @@ def unproxify_input_wrapper(func):
904
908
  return wrapper
905
909
 
906
910
 
907
- # Register dispatch of ProxyObject on all known dispatch objects
908
- for dispatch in (
909
- dask.dataframe.dispatch.hash_object_dispatch,
910
- dask.dataframe.dispatch.make_meta_dispatch,
911
- dask.dataframe.utils.make_scalar,
912
- dask.dataframe.dispatch.group_split_dispatch,
913
- dask.array.core.tensordot_lookup,
914
- dask.array.core.einsum_lookup,
915
- dask.array.core.concatenate_lookup,
916
- ):
917
- dispatch.register(ProxyObject, unproxify_input_wrapper(dispatch))
918
-
919
- dask.dataframe.dispatch.concat_dispatch.register(
920
- ProxyObject, unproxify_input_wrapper(dask.dataframe.dispatch.concat)
921
- )
922
-
923
-
924
- # We overwrite the Dask dispatch of Pandas objects in order to
925
- # deserialize all ProxyObjects before concatenating
926
- dask.dataframe.dispatch.concat_dispatch.register(
927
- (pandas.DataFrame, pandas.Series, pandas.Index),
928
- unproxify_input_wrapper(dask.dataframe.backends.concat_pandas),
929
- )
911
+ if dask_array is not None:
912
+
913
+ # Register dispatch of ProxyObject on all known dispatch objects
914
+ for dispatch in (
915
+ dask.array.core.tensordot_lookup,
916
+ dask.array.core.einsum_lookup,
917
+ dask.array.core.concatenate_lookup,
918
+ ):
919
+ dispatch.register(ProxyObject, unproxify_input_wrapper(dispatch))
920
+
921
+
922
+ if dask_dataframe is not None:
923
+
924
+ @dask.dataframe.dispatch.get_parallel_type.register(ProxyObject)
925
+ def get_parallel_type_proxy_object(obj: ProxyObject):
926
+ # Notice, `get_parallel_type()` needs an instance, not a type object
927
+ return dask.dataframe.dispatch.get_parallel_type(
928
+ obj.__class__.__new__(obj.__class__)
929
+ )
930
+
931
+ # Register dispatch of ProxyObject on all known dispatch objects
932
+ for dispatch in (
933
+ dask.dataframe.dispatch.hash_object_dispatch,
934
+ dask.dataframe.dispatch.make_meta_dispatch,
935
+ dask.dataframe.utils.make_scalar,
936
+ dask.dataframe.dispatch.group_split_dispatch,
937
+ ):
938
+ dispatch.register(ProxyObject, unproxify_input_wrapper(dispatch))
939
+
940
+ dask.dataframe.dispatch.concat_dispatch.register(
941
+ ProxyObject, unproxify_input_wrapper(dask.dataframe.dispatch.concat)
942
+ )
943
+
944
+ # We overwrite the Dask dispatch of Pandas objects in order to
945
+ # deserialize all ProxyObjects before concatenating
946
+ dask.dataframe.dispatch.concat_dispatch.register(
947
+ (pandas.DataFrame, pandas.Series, pandas.Index),
948
+ unproxify_input_wrapper(dask.dataframe.backends.concat_pandas),
949
+ )
@@ -1,3 +1,6 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
1
4
  from __future__ import absolute_import, division, print_function
2
5
 
3
6
  import os
@@ -16,7 +19,7 @@ from dask_cuda.utils import (
16
19
  get_cluster_configuration,
17
20
  get_device_total_memory,
18
21
  get_gpu_count_mig,
19
- get_gpu_uuid_from_index,
22
+ get_gpu_uuid,
20
23
  get_n_gpus,
21
24
  wait_workers,
22
25
  )
@@ -409,7 +412,7 @@ def test_cuda_mig_visible_devices_and_memory_limit_and_nthreads(loop): # noqa:
409
412
 
410
413
 
411
414
  def test_cuda_visible_devices_uuid(loop): # noqa: F811
412
- gpu_uuid = get_gpu_uuid_from_index(0)
415
+ gpu_uuid = get_gpu_uuid(0)
413
416
 
414
417
  with patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": gpu_uuid}):
415
418
  with popen(["dask", "scheduler", "--port", "9359", "--no-dashboard"]):