dask-cuda 25.6.0__py3-none-any.whl → 25.8.0__py3-none-any.whl

This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (50)
  1. dask_cuda/GIT_COMMIT +1 -1
  2. dask_cuda/VERSION +1 -1
  3. dask_cuda/benchmarks/common.py +4 -1
  4. dask_cuda/benchmarks/local_cudf_groupby.py +4 -1
  5. dask_cuda/benchmarks/local_cudf_merge.py +5 -2
  6. dask_cuda/benchmarks/local_cudf_shuffle.py +5 -2
  7. dask_cuda/benchmarks/local_cupy.py +4 -1
  8. dask_cuda/benchmarks/local_cupy_map_overlap.py +4 -1
  9. dask_cuda/benchmarks/utils.py +7 -4
  10. dask_cuda/cli.py +21 -15
  11. dask_cuda/cuda_worker.py +27 -57
  12. dask_cuda/device_host_file.py +31 -15
  13. dask_cuda/disk_io.py +7 -4
  14. dask_cuda/explicit_comms/comms.py +11 -7
  15. dask_cuda/explicit_comms/dataframe/shuffle.py +23 -23
  16. dask_cuda/get_device_memory_objects.py +3 -3
  17. dask_cuda/initialize.py +80 -44
  18. dask_cuda/local_cuda_cluster.py +63 -66
  19. dask_cuda/plugins.py +17 -16
  20. dask_cuda/proxify_device_objects.py +12 -10
  21. dask_cuda/proxify_host_file.py +30 -27
  22. dask_cuda/proxy_object.py +20 -17
  23. dask_cuda/tests/conftest.py +41 -0
  24. dask_cuda/tests/test_dask_cuda_worker.py +109 -25
  25. dask_cuda/tests/test_dgx.py +10 -18
  26. dask_cuda/tests/test_explicit_comms.py +30 -12
  27. dask_cuda/tests/test_from_array.py +7 -5
  28. dask_cuda/tests/test_initialize.py +16 -37
  29. dask_cuda/tests/test_local_cuda_cluster.py +159 -52
  30. dask_cuda/tests/test_proxify_host_file.py +19 -3
  31. dask_cuda/tests/test_proxy.py +18 -16
  32. dask_cuda/tests/test_rdd_ucx.py +160 -0
  33. dask_cuda/tests/test_spill.py +7 -0
  34. dask_cuda/tests/test_utils.py +106 -20
  35. dask_cuda/tests/test_worker_spec.py +5 -2
  36. dask_cuda/utils.py +261 -38
  37. dask_cuda/utils_test.py +23 -7
  38. dask_cuda/worker_common.py +196 -0
  39. dask_cuda/worker_spec.py +12 -5
  40. {dask_cuda-25.6.0.dist-info → dask_cuda-25.8.0.dist-info}/METADATA +2 -2
  41. dask_cuda-25.8.0.dist-info/RECORD +63 -0
  42. dask_cuda-25.8.0.dist-info/top_level.txt +6 -0
  43. shared-actions/check_nightly_success/check-nightly-success/check.py +148 -0
  44. shared-actions/telemetry-impls/summarize/bump_time.py +54 -0
  45. shared-actions/telemetry-impls/summarize/send_trace.py +409 -0
  46. dask_cuda-25.6.0.dist-info/RECORD +0 -57
  47. dask_cuda-25.6.0.dist-info/top_level.txt +0 -4
  48. {dask_cuda-25.6.0.dist-info → dask_cuda-25.8.0.dist-info}/WHEEL +0 -0
  49. {dask_cuda-25.6.0.dist-info → dask_cuda-25.8.0.dist-info}/entry_points.txt +0 -0
  50. {dask_cuda-25.6.0.dist-info → dask_cuda-25.8.0.dist-info}/licenses/LICENSE +0 -0
dask_cuda/explicit_comms/dataframe/shuffle.py CHANGED
@@ -65,13 +65,13 @@ def get_no_comm_postprocess(
  ) -> Callable[[DataFrame], DataFrame]:
  """Get function for post-processing partitions not communicated

- In cuDF, the `group_split_dispatch` uses `scatter_by_map` to create
+ In cuDF, the ``group_split_dispatch`` uses ``scatter_by_map`` to create
  the partitions, which is implemented by splitting a single base dataframe
  into multiple partitions. This means that memory are not freed until
  ALL partitions are deleted.

  In order to free memory ASAP, we can deep copy partitions NOT being
- communicated. We do this when `num_rounds != batchsize`.
+ communicated. We do this when ``num_rounds != batchsize``.

  Parameters
  ----------
@@ -116,7 +116,7 @@ async def send(
  rank_to_out_part_ids: Dict[int, Set[int]],
  out_part_id_to_dataframe: Dict[int, DataFrame],
  ) -> None:
- """Notice, items sent are removed from `out_part_id_to_dataframe`"""
+ """Notice, items sent are removed from ``out_part_id_to_dataframe``"""
  futures = []
  for rank, out_part_ids in rank_to_out_part_ids.items():
  if rank != myrank:
@@ -135,7 +135,7 @@ async def recv(
  out_part_id_to_dataframe_list: Dict[int, List[DataFrame]],
  proxify: Proxify,
  ) -> None:
- """Notice, received items are appended to `out_parts_list`"""
+ """Notice, received items are appended to ``out_parts_list``"""

  async def read_msg(rank: int) -> None:
  msg: Dict[int, DataFrame] = nested_deserialize(await eps[rank].read())
@@ -150,11 +150,11 @@ async def recv(
  def compute_map_index(
  df: DataFrame, column_names: List[str], npartitions: int
  ) -> Series:
- """Return a Series that maps each row `df` to a partition ID
+ """Return a Series that maps each row ``df`` to a partition ID

  The partitions are determined by hashing the columns given by column_names
- unless if `column_names[0] == "_partitions"`, in which case the values of
- `column_names[0]` are used as index.
+ unless if ``column_names[0] == "_partitions"``, in which case the values of
+ ``column_names[0]`` are used as index.

  Parameters
  ----------
@@ -168,7 +168,7 @@ def compute_map_index(
  Returns
  -------
  Series
- Series that maps each row `df` to a partition ID
+ Series that maps each row ``df`` to a partition ID
  """

  if column_names[0] == "_partitions":
@@ -193,8 +193,8 @@ def partition_dataframe(
  """Partition dataframe to a dict of dataframes

  The partitions are determined by hashing the columns given by column_names
- unless `column_names[0] == "_partitions"`, in which case the values of
- `column_names[0]` are used as index.
+ unless ``column_names[0] == "_partitions"``, in which case the values of
+ ``column_names[0]`` are used as index.

  Parameters
  ----------
@@ -301,13 +301,13 @@ async def send_recv_partitions(
  rank_to_out_part_ids
  dict that for each worker rank specifies a set of output partition IDs.
  If the worker shouldn't return any partitions, it is excluded from the
- dict. Partition IDs are global integers `0..npartitions` and corresponds
- to the dict keys returned by `group_split_dispatch`.
+ dict. Partition IDs are global integers ``0..npartitions`` and corresponds
+ to the dict keys returned by ``group_split_dispatch``.
  out_part_id_to_dataframe
  Mapping from partition ID to dataframe. This dict is cleared on return.
  no_comm_postprocess
  Function to post-process partitions not communicated.
- See `get_no_comm_postprocess`
+ See ``get_no_comm_postprocess``
  proxify
  Function to proxify object.
  out_part_id_to_dataframe_list
@@ -365,8 +365,8 @@ async def shuffle_task(
  rank_to_out_part_ids: dict
  dict that for each worker rank specifies a set of output partition IDs.
  If the worker shouldn't return any partitions, it is excluded from the
- dict. Partition IDs are global integers `0..npartitions` and corresponds
- to the dict keys returned by `group_split_dispatch`.
+ dict. Partition IDs are global integers ``0..npartitions`` and corresponds
+ to the dict keys returned by ``group_split_dispatch``.
  column_names: list of strings
  List of column names on which we want to split.
  npartitions: int
@@ -449,7 +449,7 @@ def shuffle(
  List of column names on which we want to split.
  npartitions: int or None
  The desired number of output partitions. If None, the number of output
- partitions equals `df.npartitions`
+ partitions equals ``df.npartitions``
  ignore_index: bool
  Ignore index during shuffle. If True, performance may improve,
  but index values will not be preserved.
@@ -460,7 +460,7 @@ def shuffle(
  If -1, each worker will handle all its partitions in a single round and
  all techniques to reduce memory usage are disabled, which might be faster
  when memory pressure isn't an issue.
- If None, the value of `DASK_EXPLICIT_COMMS_BATCHSIZE` is used or 1 if not
+ If None, the value of ``DASK_EXPLICIT_COMMS_BATCHSIZE`` is used or 1 if not
  set thus by default, we prioritize robustness over performance.

  Returns
@@ -471,12 +471,12 @@ def shuffle(
  Developer Notes
  ---------------
  The implementation consist of three steps:
- (a) Stage the partitions of `df` on all workers and then cancel them
+ (a) Stage the partitions of ``df`` on all workers and then cancel them
  thus at this point the Dask Scheduler doesn't know about any of the
  the partitions.
  (b) Submit a task on each worker that shuffle (all-to-all communicate)
  the staged partitions and return a list of dataframe-partitions.
- (c) Submit a dask graph that extract (using `getitem()`) individual
+ (c) Submit a dask graph that extract (using ``getitem()``) individual
  dataframe-partitions from (b).
  """
  c = comms.default_comms()
@@ -594,7 +594,7 @@ def _contains_shuffle_expr(*args) -> bool:
  """
  Check whether any of the arguments is a Shuffle expression.

- This is called by `compute`, which is given a sequence of Dask Collections
+ This is called by ``compute``, which is given a sequence of Dask Collections
  to process. For each of those, we'll check whether the expresion contains a
  Shuffle operation.
  """
@@ -712,9 +712,9 @@ def patch_shuffle_expression() -> None:
  """Patch Dasks Shuffle expression.

  Notice, this is monkey patched into Dask at dask_cuda
- import, and it changes `Shuffle._layer` to lower into
- an `ECShuffle` expression when the 'explicit-comms'
- config is set to `True`.
+ import, and it changes ``Shuffle._layer`` to lower into
+ an ``ECShuffle`` expression when the 'explicit-comms'
+ config is set to ``True``.
  """
  dask.base.compute = _patched_compute

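The docstrings above describe how the explicit-comms shuffle is wired in: when the 'explicit-comms' config is true, Dask's Shuffle expression is lowered into dask-cuda's ECShuffle, and the per-round batch size can be controlled via ``DASK_EXPLICIT_COMMS_BATCHSIZE``. As a hedged sketch (not part of this diff), enabling it from user code could look roughly like this; the config key and environment variable come from the docstrings above, while the cluster size and demo data are illustrative only:

```python
# Hedged sketch: enabling the explicit-comms shuffle path described above.
import dask
import dask.datasets
from distributed import Client
from dask_cuda import LocalCUDACluster

if __name__ == "__main__":
    with dask.config.set({"explicit-comms": True}):
        with LocalCUDACluster(n_workers=2) as cluster, Client(cluster) as client:
            df = dask.datasets.timeseries()  # pandas-backed demo data
            # With the config above, the Shuffle expression is lowered into
            # dask-cuda's ECShuffle (see patch_shuffle_expression above).
            out = df.shuffle(on="id").persist()
```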
dask_cuda/get_device_memory_objects.py CHANGED
@@ -27,10 +27,10 @@ class DeviceMemoryId:


  def get_device_memory_ids(obj) -> Set[DeviceMemoryId]:
- """Find all CUDA device objects in `obj`
+ """Find all CUDA device objects in ``obj``

- Search through `obj` and find all CUDA device objects, which are objects
- that either are known to `dispatch` or implement `__cuda_array_interface__`.
+ Search through ``obj`` and find all CUDA device objects, which are objects
+ that either are known to ``dispatch`` or implement ``__cuda_array_interface__``.

  Parameters
  ----------
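The docstring above says device objects are recognized either through Dask's dispatch mechanism or through the ``__cuda_array_interface__`` protocol. A CuPy array is a typical example of the latter; this illustrative snippet (not from the diff, assumes CuPy and a GPU) just shows what "implements ``__cuda_array_interface__``" means:

```python
# Illustrative only: a CuPy array exposes the CUDA Array Interface,
# which is one of the two ways device objects are detected.
import cupy as cp

arr = cp.arange(10)
assert hasattr(arr, "__cuda_array_interface__")
# The interface is a plain dict with metadata such as dtype and shape.
print(arr.__cuda_array_interface__["typestr"], arr.__cuda_array_interface__["shape"])
```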
dask_cuda/initialize.py CHANGED
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  import logging
  import os

@@ -7,7 +10,7 @@ import numba.cuda
  import dask
  from distributed.diagnostics.nvml import get_device_index_and_uuid, has_cuda_context

- from .utils import get_ucx_config
+ from .utils import _get_active_ucx_implementation_name, get_ucx_config

  logger = logging.getLogger(__name__)

@@ -22,65 +25,97 @@ def _create_cuda_context_handler():
  numba.cuda.current_context()


- def _create_cuda_context(protocol="ucx"):
- if protocol not in ["ucx", "ucxx"]:
- return
+ def _warn_generic():
  try:
+ # TODO: update when UCX-Py is removed, see
+ # https://github.com/rapidsai/dask-cuda/issues/1517
+ import distributed.comm.ucx
+
  # Added here to ensure the parent `LocalCUDACluster` process creates the CUDA
  # context directly from the UCX module, thus avoiding a similar warning there.
- try:
- if protocol == "ucx":
- import distributed.comm.ucx
-
- distributed.comm.ucx.init_once()
- elif protocol == "ucxx":
- import distributed_ucxx.ucxx
-
- distributed_ucxx.ucxx.init_once()
- except ModuleNotFoundError:
- # UCX initialization has to be delegated to Distributed, it will take care
- # of setting correct environment variables and importing `ucp` after that.
- # Therefore if ``import ucp`` fails we can just continue here.
- pass
+ cuda_visible_device = get_device_index_and_uuid(
+ os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]
+ )
+ ctx = has_cuda_context()
+ if (
+ ctx.has_context
+ and not distributed.comm.ucx.cuda_context_created.has_context
+ ):
+ distributed.comm.ucx._warn_existing_cuda_context(ctx, os.getpid())
+
+ _create_cuda_context_handler()
+
+ if not distributed.comm.ucx.cuda_context_created.has_context:
+ ctx = has_cuda_context()
+ if ctx.has_context and ctx.device_info != cuda_visible_device:
+ distributed.comm.ucx._warn_cuda_context_wrong_device(
+ cuda_visible_device, ctx.device_info, os.getpid()
+ )
+
+ except Exception:
+ logger.error("Unable to start CUDA Context", exc_info=True)
+
+
+ def _initialize_ucx():
+ try:
+ import distributed.comm.ucx
+
+ distributed.comm.ucx.init_once()
+ except ModuleNotFoundError:
+ # UCX initialization has to be delegated to Distributed, it will take care
+ # of setting correct environment variables and importing `ucp` after that.
+ # Therefore if ``import ucp`` fails we can just continue here.
+ pass
+
+
+ def _initialize_ucxx():
+ try:
+ # Added here to ensure the parent `LocalCUDACluster` process creates the CUDA
+ # context directly from the UCX module, thus avoiding a similar warning there.
+ import distributed_ucxx.ucxx
+
+ distributed_ucxx.ucxx.init_once()

  cuda_visible_device = get_device_index_and_uuid(
  os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]
  )
  ctx = has_cuda_context()
- if protocol == "ucx":
- if (
- ctx.has_context
- and not distributed.comm.ucx.cuda_context_created.has_context
- ):
- distributed.comm.ucx._warn_existing_cuda_context(ctx, os.getpid())
- elif protocol == "ucxx":
- if (
- ctx.has_context
- and not distributed_ucxx.ucxx.cuda_context_created.has_context
- ):
- distributed_ucxx.ucxx._warn_existing_cuda_context(ctx, os.getpid())
+ if (
+ ctx.has_context
+ and not distributed_ucxx.ucxx.cuda_context_created.has_context
+ ):
+ distributed_ucxx.ucxx._warn_existing_cuda_context(ctx, os.getpid())

  _create_cuda_context_handler()

- if protocol == "ucx":
- if not distributed.comm.ucx.cuda_context_created.has_context:
- ctx = has_cuda_context()
- if ctx.has_context and ctx.device_info != cuda_visible_device:
- distributed.comm.ucx._warn_cuda_context_wrong_device(
- cuda_visible_device, ctx.device_info, os.getpid()
- )
- elif protocol == "ucxx":
- if not distributed_ucxx.ucxx.cuda_context_created.has_context:
- ctx = has_cuda_context()
- if ctx.has_context and ctx.device_info != cuda_visible_device:
- distributed_ucxx.ucxx._warn_cuda_context_wrong_device(
- cuda_visible_device, ctx.device_info, os.getpid()
- )
+ if not distributed_ucxx.ucxx.cuda_context_created.has_context:
+ ctx = has_cuda_context()
+ if ctx.has_context and ctx.device_info != cuda_visible_device:
+ distributed_ucxx.ucxx._warn_cuda_context_wrong_device(
+ cuda_visible_device, ctx.device_info, os.getpid()
+ )

  except Exception:
  logger.error("Unable to start CUDA Context", exc_info=True)


+ def _create_cuda_context(protocol="ucx"):
+ if protocol not in ["ucx", "ucxx", "ucx-old"]:
+ return
+
+ try:
+ ucx_implementation = _get_active_ucx_implementation_name(protocol)
+ except ValueError:
+ # Not a UCX protocol, just raise CUDA context warnings if needed.
+ _warn_generic()
+ else:
+ if ucx_implementation == "ucxx":
+ _initialize_ucxx()
+ else:
+ _initialize_ucx()
+ _warn_generic()
+
+
  def initialize(
  create_cuda_context=True,
  enable_tcp_over_ucx=None,
@@ -138,6 +173,7 @@ def initialize(
  enable_infiniband=enable_infiniband,
  enable_nvlink=enable_nvlink,
  enable_rdmacm=enable_rdmacm,
+ protocol=protocol,
  )
  dask.config.set({"distributed.comm.ucx": ucx_config})

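The refactor above splits CUDA-context creation per UCX implementation and forwards the selected protocol into ``get_ucx_config``. As a hedged usage sketch (keyword names taken from this diff and dask-cuda's documented ``initialize`` API; the scheduler address and feature flags are placeholders), client-side setup could look like this:

```python
# Hedged sketch: client-side UCX setup, now protocol-aware per the diff above.
from distributed import Client
from dask_cuda import initialize

initialize(
    create_cuda_context=True,
    enable_tcp_over_ucx=True,
    enable_infiniband=False,   # illustrative; enable only if the hardware supports it
    enable_nvlink=False,
    protocol="ucx",            # "ucxx" or "ucx-old" pick a specific implementation
)
client = Client("ucx://scheduler-address:8786")  # placeholder address
```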
dask_cuda/local_cuda_cluster.py CHANGED
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  import copy
  import logging
  import os
@@ -8,18 +11,15 @@ import dask
  from distributed import LocalCluster, Nanny, Worker
  from distributed.worker_memory import parse_memory_limit

- from .device_host_file import DeviceHostFile
  from .initialize import initialize
- from .plugins import CPUAffinity, CUDFSetup, PreImport, RMMSetup
- from .proxify_host_file import ProxifyHostFile
  from .utils import (
  cuda_visible_devices,
- get_cpu_affinity,
  get_ucx_config,
  nvml_device_index,
  parse_cuda_visible_device,
  parse_device_memory_limit,
  )
+ from .worker_common import worker_data_function, worker_plugins


  class LoggedWorker(Worker):
@@ -68,11 +68,16 @@ class LocalCUDACluster(LocalCluster):
  starts spilling to disk (not available if JIT-Unspill is enabled). Can be an
  integer (bytes), float (fraction of total system memory), string (like ``"5GB"``
  or ``"5000M"``), or ``"auto"``, 0, or ``None`` for no memory management.
- device_memory_limit : int, float, str, or None, default 0.8
+ device_memory_limit : int, float, str, or None, default "default"
  Size of the CUDA device LRU cache, which is used to determine when the worker
  starts spilling to host memory. Can be an integer (bytes), float (fraction of
- total device memory), string (like ``"5GB"`` or ``"5000M"``), or ``"auto"``, 0,
+ total device memory), string (like ``"5GB"`` or ``"5000M"``), ``"auto"``, ``0``
  or ``None`` to disable spilling to host (i.e. allow full device memory usage).
+ Another special value ``"default"`` (which happens to be the default) is also
+ available and uses the recommended Dask-CUDA's defaults and means 80% of the
+ total device memory (analogous to ``0.8``), and disabled spilling (analogous
+ to ``auto``/``0``) on devices without a dedicated memory resource, such as
+ system on a chip (SoC) devices.
  enable_cudf_spill : bool, default False
  Enable automatic cuDF spilling.

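As an illustration of the documented ``device_memory_limit`` values above (the ``"default"`` behavior and the explicit alternatives are taken from the docstring; the numbers are examples only):

```python
# Illustrative sketch of the new device_memory_limit default described above.
from dask_cuda import LocalCUDACluster

cluster = LocalCUDACluster(device_memory_limit="default")   # new default behavior
# cluster = LocalCUDACluster(device_memory_limit=0.8)       # fixed 80% threshold
# cluster = LocalCUDACluster(device_memory_limit="10GB")    # absolute limit
# cluster = LocalCUDACluster(device_memory_limit=0)         # disable host spilling
```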
@@ -87,7 +92,7 @@ class LocalCUDACluster(LocalCluster):
  ``dask.temporary-directory`` in the local Dask configuration, using the current
  working directory if this is not set.
  shared_filesystem: bool or None, default None
- Whether the `local_directory` above is shared between all workers or not.
+ Whether the ``local_directory`` above is shared between all workers or not.
  If ``None``, the "jit-unspill-shared-fs" config value are used, which
  defaults to True. Notice, in all other cases this option defaults to False,
  but on a local cluster it defaults to True -- we assume all workers use the
@@ -100,13 +105,16 @@ class LocalCUDACluster(LocalCluster):
  are not supported or disabled.
  enable_infiniband : bool, default None
  Set environment variables to enable UCX over InfiniBand, requires
- ``protocol="ucx"`` and implies ``enable_tcp_over_ucx=True`` when ``True``.
+ ``protocol="ucx"``, ``protocol="ucxx"`` or ``protocol="ucx-old"``, and implies
+ ``enable_tcp_over_ucx=True`` when ``True``.
  enable_nvlink : bool, default None
- Set environment variables to enable UCX over NVLink, requires ``protocol="ucx"``
- and implies ``enable_tcp_over_ucx=True`` when ``True``.
+ Set environment variables to enable UCX over NVLink, requires
+ ``protocol="ucx"``, ``protocol="ucxx"`` or ``protocol="ucx-old"``, and implies
+ ``enable_tcp_over_ucx=True`` when ``True``.
  enable_rdmacm : bool, default None
  Set environment variables to enable UCX RDMA connection manager support,
- requires ``protocol="ucx"`` and ``enable_infiniband=True``.
+ requires ``protocol="ucx"``, ``protocol="ucxx"`` or ``protocol="ucx-old"``,
+ and ``enable_infiniband=True``.
  rmm_pool_size : int, str or None, default None
  RMM pool size to initialize each worker with. Can be an integer (bytes), float
  (fraction of total device memory), string (like ``"5GB"`` or ``"5000M"``), or
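A hedged sketch of the UCX-related options documented just above: InfiniBand and NVLink require one of the UCX protocols and imply TCP-over-UCX, while RDMACM additionally requires InfiniBand. The flags below are illustrative; enable only what the hardware supports.

```python
# Hedged sketch of a UCX-enabled cluster per the docstring above.
from dask_cuda import LocalCUDACluster

cluster = LocalCUDACluster(
    protocol="ucx",            # or "ucxx" / "ucx-old"
    enable_tcp_over_ucx=True,
    enable_nvlink=True,
    enable_infiniband=True,
    enable_rdmacm=True,        # requires enable_infiniband=True
)
```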
@@ -123,8 +131,8 @@ class LocalCUDACluster(LocalCluster):
  and to set the maximum pool size.

  .. note::
- When paired with `--enable-rmm-async` the maximum size cannot be guaranteed
- due to fragmentation.
+ When paired with ``--enable-rmm-async`` the maximum size cannot be
+ guaranteed due to fragmentation.

  .. note::
  This size is a per-worker configuration, and not cluster-wide.
@@ -140,9 +148,8 @@ class LocalCUDACluster(LocalCluster):
  See ``rmm.mr.CudaAsyncMemoryResource`` for more info.

  .. warning::
- The asynchronous allocator requires CUDA Toolkit 11.2 or newer. It is also
- incompatible with RMM pools and managed memory. Trying to enable both will
- result in an exception.
+ The asynchronous allocator is incompatible with RMM pools and managed
+ memory. Trying to enable both will result in an exception.
  rmm_allocator_external_lib_list: str, list or None, default None
  List of external libraries for which to set RMM as the allocator.
  Supported options are: ``["torch", "cupy"]``. Can be a comma-separated string
@@ -201,7 +208,8 @@ class LocalCUDACluster(LocalCluster):
  Raises
  ------
  TypeError
- If InfiniBand or NVLink are enabled and ``protocol!="ucx"``.
+ If InfiniBand or NVLink are enabled and
+ ``protocol not in ("ucx", "ucxx", "ucx-old")``.
  ValueError
  If RMM pool, RMM managed memory or RMM async allocator are requested but RMM
  cannot be imported.
@@ -221,10 +229,9 @@ class LocalCUDACluster(LocalCluster):
  n_workers=None,
  threads_per_worker=1,
  memory_limit="auto",
- device_memory_limit=0.8,
+ device_memory_limit="default",
  enable_cudf_spill=False,
  cudf_spill_stats=0,
- data=None,
  local_directory=None,
  shared_filesystem=None,
  protocol=None,
@@ -242,7 +249,6 @@ class LocalCUDACluster(LocalCluster):
  rmm_track_allocations=False,
  jit_unspill=None,
  log_spilling=False,
- worker_class=None,
  pre_import=None,
  **kwargs,
  ):
@@ -339,40 +345,29 @@ class LocalCUDACluster(LocalCluster):
  jit_unspill = dask.config.get("jit-unspill", default=False)
  data = kwargs.pop("data", None)
  if data is None:
- if device_memory_limit is None and memory_limit is None:
- data = {}
- elif jit_unspill:
- if enable_cudf_spill:
- warnings.warn(
- "Enabling cuDF spilling and JIT-Unspill together is not "
- "safe, consider disabling JIT-Unspill."
- )
-
- data = (
- ProxifyHostFile,
- {
- "device_memory_limit": self.device_memory_limit,
- "memory_limit": self.memory_limit,
- "shared_filesystem": shared_filesystem,
- },
- )
- else:
- data = (
- DeviceHostFile,
- {
- "device_memory_limit": self.device_memory_limit,
- "memory_limit": self.memory_limit,
- "log_spilling": log_spilling,
- },
- )
+ self.data = worker_data_function(
+ device_memory_limit=self.device_memory_limit,
+ memory_limit=self.memory_limit,
+ jit_unspill=jit_unspill,
+ enable_cudf_spill=enable_cudf_spill,
+ shared_filesystem=shared_filesystem,
+ )

  if enable_tcp_over_ucx or enable_infiniband or enable_nvlink:
  if protocol is None:
- protocol = "ucx"
- elif protocol not in ["ucx", "ucxx"]:
+ ucx_protocol = dask.config.get(
+ "distributed.comm.ucx.ucx-protocol", default=None
+ )
+ if ucx_protocol is not None:
+ # TODO: remove when UCX-Py is removed,
+ # see https://github.com/rapidsai/dask-cuda/issues/1517
+ protocol = ucx_protocol
+ else:
+ protocol = "ucx"
+ elif protocol not in ("ucx", "ucxx", "ucx-old"):
  raise TypeError(
- "Enabling InfiniBand or NVLink requires protocol='ucx' or "
- "protocol='ucxx'"
+ "Enabling InfiniBand or NVLink requires protocol='ucx', "
+ "protocol='ucxx' or protocol='ucx-old'"
  )

  self.host = kwargs.get("host", None)
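Per the protocol-resolution change above, when UCX features are requested without an explicit protocol the cluster now consults the ``distributed.comm.ucx.ucx-protocol`` config key before falling back to ``"ucx"``. A hedged sketch of what that selection might look like (the config key is taken from this diff; whether ``"ucxx"`` is an accepted value is an assumption):

```python
# Hedged sketch: steering implicit protocol selection via the new config key.
import dask
from dask_cuda import LocalCUDACluster

with dask.config.set({"distributed.comm.ucx.ucx-protocol": "ucxx"}):
    # No protocol argument given; with the config above the cluster resolves
    # the implicit UCX protocol from the config instead of defaulting to "ucx".
    cluster = LocalCUDACluster(enable_tcp_over_ucx=True)
```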
@@ -385,6 +380,7 @@ class LocalCUDACluster(LocalCluster):
  enable_rdmacm=enable_rdmacm,
  )

+ worker_class = kwargs.pop("worker_class", None)
  if worker_class is not None:
  if log_spilling is True:
  raise ValueError(
@@ -441,28 +437,29 @@ class LocalCUDACluster(LocalCluster):
  spec = copy.deepcopy(self.new_spec)
  worker_count = self.cuda_visible_devices.index(name)
  visible_devices = cuda_visible_devices(worker_count, self.cuda_visible_devices)
+ device_index = nvml_device_index(0, visible_devices)
  spec["options"].update(
  {
  "env": {
  "CUDA_VISIBLE_DEVICES": visible_devices,
  },
- "plugins": {
- CPUAffinity(
- get_cpu_affinity(nvml_device_index(0, visible_devices))
- ),
- RMMSetup(
- initial_pool_size=self.rmm_pool_size,
- maximum_pool_size=self.rmm_maximum_pool_size,
- managed_memory=self.rmm_managed_memory,
- async_alloc=self.rmm_async,
- release_threshold=self.rmm_release_threshold,
- log_directory=self.rmm_log_directory,
- track_allocations=self.rmm_track_allocations,
- external_lib_list=self.rmm_allocator_external_lib_list,
+ **({"data": self.data(device_index)} if hasattr(self, "data") else {}),
+ "plugins": worker_plugins(
+ device_index=device_index,
+ rmm_initial_pool_size=self.rmm_pool_size,
+ rmm_maximum_pool_size=self.rmm_maximum_pool_size,
+ rmm_managed_memory=self.rmm_managed_memory,
+ rmm_async_alloc=self.rmm_async,
+ rmm_release_threshold=self.rmm_release_threshold,
+ rmm_log_directory=self.rmm_log_directory,
+ rmm_track_allocations=self.rmm_track_allocations,
+ rmm_allocator_external_lib_list=(
+ self.rmm_allocator_external_lib_list
  ),
- PreImport(self.pre_import),
- CUDFSetup(self.enable_cudf_spill, self.cudf_spill_stats),
- },
+ pre_import=self.pre_import,
+ enable_cudf_spill=self.enable_cudf_spill,
+ cudf_spill_stats=self.cudf_spill_stats,
+ ),
  }
  )

dask_cuda/plugins.py CHANGED
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  import importlib
  import logging
  import os
@@ -5,7 +8,7 @@ from typing import Callable, Dict

  from distributed import WorkerPlugin

- from .utils import get_rmm_log_file_name, parse_device_memory_limit
+ from .utils import get_rmm_log_file_name, parse_device_bytes


  class CPUAffinity(WorkerPlugin):
@@ -75,28 +78,26 @@ class RMMSetup(WorkerPlugin):
  self.external_lib_list = external_lib_list

  def setup(self, worker=None):
- if self.initial_pool_size is not None:
- self.initial_pool_size = parse_device_memory_limit(
- self.initial_pool_size, alignment_size=256
- )
+ self.initial_pool_size = parse_device_bytes(
+ self.initial_pool_size, alignment_size=256
+ )

  if self.async_alloc:
  import rmm

- if self.release_threshold is not None:
- self.release_threshold = parse_device_memory_limit(
- self.release_threshold, alignment_size=256
- )
+ self.release_threshold = parse_device_bytes(
+ self.release_threshold, alignment_size=256
+ )

  mr = rmm.mr.CudaAsyncMemoryResource(
  initial_pool_size=self.initial_pool_size,
  release_threshold=self.release_threshold,
  )

+ self.maximum_pool_size = parse_device_bytes(
+ self.maximum_pool_size, alignment_size=256
+ )
  if self.maximum_pool_size is not None:
- self.maximum_pool_size = parse_device_memory_limit(
- self.maximum_pool_size, alignment_size=256
- )
  mr = rmm.mr.LimitingResourceAdaptor(
  mr, allocation_limit=self.maximum_pool_size
  )
@@ -114,10 +115,9 @@ class RMMSetup(WorkerPlugin):
  pool_allocator = False if self.initial_pool_size is None else True

  if self.initial_pool_size is not None:
- if self.maximum_pool_size is not None:
- self.maximum_pool_size = parse_device_memory_limit(
- self.maximum_pool_size, alignment_size=256
- )
+ self.maximum_pool_size = parse_device_bytes(
+ self.maximum_pool_size, alignment_size=256
+ )

  rmm.reinitialize(
  pool_allocator=pool_allocator,
@@ -129,6 +129,7 @@ class RMMSetup(WorkerPlugin):
  worker, self.logging, self.log_directory
  ),
  )
+
  if self.rmm_track_allocations:
  import rmm
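The RMMSetup changes above route pool sizes through ``parse_device_bytes`` with a 256-byte alignment. That helper is internal to dask-cuda; as a rough illustration of the string-to-bytes step only, ``dask.utils.parse_bytes`` (a real Dask utility) behaves like this:

```python
# Illustration of how size strings like those accepted for RMM pool sizes
# resolve to byte counts; this is not dask-cuda's parse_device_bytes itself.
from dask.utils import parse_bytes

print(parse_bytes("5GB"))     # 5000000000
print(parse_bytes("5000M"))   # 5000000000
# Fractional values (e.g. 0.8) are additionally resolved against total device
# memory by dask-cuda before alignment; that part is not shown here.
```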