dask-cuda 25.6.0__py3-none-any.whl → 25.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. dask_cuda/GIT_COMMIT +1 -1
  2. dask_cuda/VERSION +1 -1
  3. dask_cuda/benchmarks/common.py +4 -1
  4. dask_cuda/benchmarks/local_cudf_groupby.py +3 -0
  5. dask_cuda/benchmarks/local_cudf_merge.py +4 -1
  6. dask_cuda/benchmarks/local_cudf_shuffle.py +4 -1
  7. dask_cuda/benchmarks/local_cupy.py +3 -0
  8. dask_cuda/benchmarks/local_cupy_map_overlap.py +3 -0
  9. dask_cuda/benchmarks/utils.py +6 -3
  10. dask_cuda/cli.py +21 -15
  11. dask_cuda/cuda_worker.py +28 -58
  12. dask_cuda/device_host_file.py +31 -15
  13. dask_cuda/disk_io.py +7 -4
  14. dask_cuda/explicit_comms/comms.py +11 -7
  15. dask_cuda/explicit_comms/dataframe/shuffle.py +23 -23
  16. dask_cuda/get_device_memory_objects.py +4 -7
  17. dask_cuda/initialize.py +149 -94
  18. dask_cuda/local_cuda_cluster.py +52 -70
  19. dask_cuda/plugins.py +17 -16
  20. dask_cuda/proxify_device_objects.py +12 -10
  21. dask_cuda/proxify_host_file.py +30 -27
  22. dask_cuda/proxy_object.py +20 -17
  23. dask_cuda/tests/conftest.py +41 -0
  24. dask_cuda/tests/test_cudf_builtin_spilling.py +3 -1
  25. dask_cuda/tests/test_dask_cuda_worker.py +109 -25
  26. dask_cuda/tests/test_dask_setup.py +193 -0
  27. dask_cuda/tests/test_dgx.py +20 -44
  28. dask_cuda/tests/test_explicit_comms.py +31 -12
  29. dask_cuda/tests/test_from_array.py +4 -6
  30. dask_cuda/tests/test_initialize.py +233 -65
  31. dask_cuda/tests/test_local_cuda_cluster.py +129 -68
  32. dask_cuda/tests/test_proxify_host_file.py +28 -7
  33. dask_cuda/tests/test_proxy.py +15 -13
  34. dask_cuda/tests/test_spill.py +10 -3
  35. dask_cuda/tests/test_utils.py +100 -29
  36. dask_cuda/tests/test_worker_spec.py +6 -0
  37. dask_cuda/utils.py +211 -42
  38. dask_cuda/utils_test.py +10 -7
  39. dask_cuda/worker_common.py +196 -0
  40. dask_cuda/worker_spec.py +6 -1
  41. {dask_cuda-25.6.0.dist-info → dask_cuda-25.10.0.dist-info}/METADATA +11 -4
  42. dask_cuda-25.10.0.dist-info/RECORD +63 -0
  43. dask_cuda-25.10.0.dist-info/top_level.txt +6 -0
  44. shared-actions/check_nightly_success/check-nightly-success/check.py +148 -0
  45. shared-actions/telemetry-impls/summarize/bump_time.py +54 -0
  46. shared-actions/telemetry-impls/summarize/send_trace.py +409 -0
  47. dask_cuda-25.6.0.dist-info/RECORD +0 -57
  48. dask_cuda-25.6.0.dist-info/top_level.txt +0 -4
  49. {dask_cuda-25.6.0.dist-info → dask_cuda-25.10.0.dist-info}/WHEEL +0 -0
  50. {dask_cuda-25.6.0.dist-info → dask_cuda-25.10.0.dist-info}/entry_points.txt +0 -0
  51. {dask_cuda-25.6.0.dist-info → dask_cuda-25.10.0.dist-info}/licenses/LICENSE +0 -0
dask_cuda/GIT_COMMIT CHANGED
@@ -1 +1 @@
- 1f834655ecc6286b9e3082f037594f70dcb74062
+ 472ca1ce6d1fe836104a5a4f10b284ca9a828ea9
dask_cuda/VERSION CHANGED
@@ -1 +1 @@
- 25.06.00
+ 25.10.00
dask_cuda/benchmarks/common.py CHANGED
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  import contextlib
  from argparse import Namespace
  from functools import partial
@@ -124,7 +127,7 @@ def run(client: Client, args: Namespace, config: Config):
  """

  wait_for_cluster(client, shutdown_on_failure=True)
- assert len(client.scheduler_info()["workers"]) > 0
+ assert len(client.scheduler_info(n_workers=-1)["workers"]) > 0
  setup_memory_pools(
  client=client,
  is_gpu=args.type == "gpu",
dask_cuda/benchmarks/local_cudf_groupby.py CHANGED
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  import contextlib
  from collections import ChainMap
  from time import perf_counter as clock
dask_cuda/benchmarks/local_cudf_merge.py CHANGED
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  import contextlib
  import math
  from collections import ChainMap
@@ -166,7 +169,7 @@ def merge(args, ddf1, ddf2):

  def bench_once(client, args, write_profile=None):
  # Generate random Dask dataframes
- n_workers = len(client.scheduler_info()["workers"])
+ n_workers = len(client.scheduler_info(n_workers=-1)["workers"])
  # Allow the number of chunks to vary between
  # the "base" and "other" DataFrames
  args.base_chunks = args.base_chunks or n_workers
dask_cuda/benchmarks/local_cudf_shuffle.py CHANGED
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  import contextlib
  from collections import ChainMap
  from time import perf_counter
@@ -70,7 +73,7 @@ def create_data(
  """
  chunksize = args.partition_size // np.float64().nbytes

- workers = list(client.scheduler_info()["workers"].keys())
+ workers = list(client.scheduler_info(n_workers=-1)["workers"].keys())
  assert len(workers) > 0

  dist = args.partition_distribution
dask_cuda/benchmarks/local_cupy.py CHANGED
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  import contextlib
  from collections import ChainMap
  from time import perf_counter as clock
dask_cuda/benchmarks/local_cupy_map_overlap.py CHANGED
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  import contextlib
  from collections import ChainMap
  from time import perf_counter as clock
dask_cuda/benchmarks/utils.py CHANGED
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  import argparse
  import itertools
  import json
@@ -122,7 +125,7 @@ def parse_benchmark_args(
  "pool size."
  ""
  ".. note::"
- " When paired with `--enable-rmm-async` the maximum size cannot be "
+ " When paired with ``--enable-rmm-async`` the maximum size cannot be "
  " guaranteed due to fragmentation."
  ""
  ".. note::"
@@ -641,11 +644,11 @@ def wait_for_cluster(client, timeout=120, shutdown_on_failure=True):
  for _ in range(timeout // 5):
  print(
  "Waiting for workers to come up, "
- f"have {len(client.scheduler_info().get('workers', []))}, "
+ f"have {len(client.scheduler_info(n_workers=-1).get('workers', []))}, "
  f"want {expected}"
  )
  time.sleep(5)
- nworkers = len(client.scheduler_info().get("workers", []))
+ nworkers = len(client.scheduler_info(n_workers=-1).get("workers", []))
  if nworkers == expected:
  return
  else:
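Note: every benchmark above now calls ``client.scheduler_info(n_workers=-1)`` instead of ``client.scheduler_info()``. Newer ``distributed`` releases may return only a limited subset of workers from ``scheduler_info()`` by default, so ``-1`` appears to be needed to enumerate the whole cluster. A minimal sketch of the pattern, assuming a running cluster and a connected client (the scheduler address is hypothetical):

    from dask.distributed import Client

    client = Client("tcp://scheduler-address:8786")  # hypothetical address

    # Ask the scheduler for *all* workers, not just the default subset.
    info = client.scheduler_info(n_workers=-1)
    workers = list(info["workers"].keys())
    print(len(workers), "workers:", workers)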
dask_cuda/cli.py CHANGED
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  from __future__ import absolute_import, division, print_function

  import logging
@@ -90,16 +93,20 @@ def cuda():
  help="""Size of the host LRU cache, which is used to determine when the worker
  starts spilling to disk (not available if JIT-Unspill is enabled). Can be an
  integer (bytes), float (fraction of total system memory), string (like ``"5GB"``
- or ``"5000M"``), or ``"auto"``, 0, or ``None`` for no memory management.""",
+ or ``"5000M"``), or ``"auto"`` or ``0`` for no memory management.""",
  )
  @click.option(
  "--device-memory-limit",
- default="0.8",
+ default="default",
  show_default=True,
  help="""Size of the CUDA device LRU cache, which is used to determine when the
  worker starts spilling to host memory. Can be an integer (bytes), float (fraction of
- total device memory), string (like ``"5GB"`` or ``"5000M"``), or ``"auto"`` or 0 to
- disable spilling to host (i.e. allow full device memory usage).""",
+ total device memory), string (like ``"5GB"`` or ``"5000M"``), ``"auto"`` or ``0``
+ to disable spilling to host (i.e. allow full device memory usage). Another special
+ value ``"default"`` (which happens to be the default) is also available and uses the
+ recommended Dask-CUDA's defaults and means 80% of the total device memory (analogous
+ to ``0.8``), and disabled spilling (analogous to ``auto``/``0``) on devices without
+ a dedicated memory resource, such as system on a chip (SoC) devices.""",
  )
  @click.option(
  "--enable-cudf-spill/--disable-cudf-spill",
@@ -113,7 +120,7 @@ def cuda():
  type=int,
  default=0,
  help="""Set the cuDF spilling statistics level. This option has no effect if
- `--enable-cudf-spill` is not specified.""",
+ ``--enable-cudf-spill`` is not specified.""",
  )
  @click.option(
  "--rmm-pool-size",
@@ -135,8 +142,8 @@ def cuda():
  to set the maximum pool size.

  .. note::
- When paired with `--enable-rmm-async` the maximum size cannot be guaranteed due
- to fragmentation.
+ When paired with ``--enable-rmm-async`` the maximum size cannot be guaranteed
+ due to fragmentation.

  .. note::
  This size is a per-worker configuration, and not cluster-wide.""",
@@ -160,9 +167,8 @@ def cuda():
  allocator. See ``rmm.mr.CudaAsyncMemoryResource`` for more info.

  .. warning::
- The asynchronous allocator requires CUDA Toolkit 11.2 or newer. It is also
- incompatible with RMM pools and managed memory, trying to enable both will
- result in failure.""",
+ The asynchronous allocator is incompatible with RMM pools and managed memory,
+ trying to enable both will result in failure.""",
  )
  @click.option(
  "--set-rmm-allocator-for-libs",
@@ -245,12 +251,12 @@ def cuda():
  "--shared-filesystem/--no-shared-filesystem",
  default=None,
  type=bool,
- help="""If `--shared-filesystem` is specified, inform JIT-Unspill that
- `local_directory` is a shared filesystem available for all workers, whereas
- `--no-shared-filesystem` informs it may not assume it's a shared filesystem.
+ help="""If ``--shared-filesystem`` is specified, inform JIT-Unspill that
+ ``local_directory`` is a shared filesystem available for all workers, whereas
+ ``--no-shared-filesystem`` informs it may not assume it's a shared filesystem.
  If neither is specified, JIT-Unspill will decide based on the Dask config value
- specified by `"jit-unspill-shared-fs"`.
- Notice, a shared filesystem must support the `os.link()` operation.""",
+ specified by ``"jit-unspill-shared-fs"``.
+ Notice, a shared filesystem must support the ``os.link()`` operation.""",
  )
  @scheduler_file
  @click.option(
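Note: ``"default"`` is also the new default for the ``device_memory_limit`` argument of ``dask_cuda.CUDAWorker`` (see ``cuda_worker.py`` below). A minimal sketch of using it from Python, assuming ``LocalCUDACluster`` accepts the same sentinel values as the CLI option:

    from dask.distributed import Client
    from dask_cuda import LocalCUDACluster

    # "default" resolves to roughly 80% of device memory on discrete GPUs, and
    # to disabled host spilling on devices without dedicated memory (e.g. SoCs).
    cluster = LocalCUDACluster(device_memory_limit="default")  # assumed to mirror the CLI
    client = Client(cluster)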
dask_cuda/cuda_worker.py CHANGED
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  from __future__ import absolute_import, division, print_function

  import asyncio
@@ -18,18 +21,9 @@ from distributed.proctitle import (
  )
  from distributed.worker_memory import parse_memory_limit

- from .device_host_file import DeviceHostFile
  from .initialize import initialize
- from .plugins import CPUAffinity, CUDFSetup, PreImport, RMMSetup
- from .proxify_host_file import ProxifyHostFile
- from .utils import (
- cuda_visible_devices,
- get_cpu_affinity,
- get_n_gpus,
- get_ucx_config,
- nvml_device_index,
- parse_device_memory_limit,
- )
+ from .utils import cuda_visible_devices, get_n_gpus, get_ucx_config, nvml_device_index
+ from .worker_common import worker_data_function, worker_plugins


  class CUDAWorker(Server):
@@ -40,7 +34,7 @@ class CUDAWorker(Server):
  nthreads=1,
  name=None,
  memory_limit="auto",
- device_memory_limit="auto",
+ device_memory_limit="default",
  enable_cudf_spill=False,
  cudf_spill_stats=0,
  rmm_pool_size=None,
@@ -166,35 +160,14 @@ class CUDAWorker(Server):

  if jit_unspill is None:
  jit_unspill = dask.config.get("jit-unspill", default=False)
- if device_memory_limit is None and memory_limit is None:
- data = lambda _: {}
- elif jit_unspill:
- if enable_cudf_spill:
- warnings.warn(
- "Enabling cuDF spilling and JIT-Unspill together is not "
- "safe, consider disabling JIT-Unspill."
- )

- data = lambda i: (
- ProxifyHostFile,
- {
- "device_memory_limit": parse_device_memory_limit(
- device_memory_limit, device_index=i
- ),
- "memory_limit": memory_limit,
- "shared_filesystem": shared_filesystem,
- },
- )
- else:
- data = lambda i: (
- DeviceHostFile,
- {
- "device_memory_limit": parse_device_memory_limit(
- device_memory_limit, device_index=i
- ),
- "memory_limit": memory_limit,
- },
- )
+ data = worker_data_function(
+ device_memory_limit=device_memory_limit,
+ memory_limit=memory_limit,
+ jit_unspill=jit_unspill,
+ enable_cudf_spill=enable_cudf_spill,
+ shared_filesystem=shared_filesystem,
+ )

  cudf_spill_warning = dask.config.get("cudf-spill-warning", default=True)
  if enable_cudf_spill and cudf_spill_warning:
@@ -220,27 +193,24 @@ class CUDAWorker(Server):
  preload_argv=(list(preload_argv) or []) + ["--create-cuda-context"],
  security=security,
  env={"CUDA_VISIBLE_DEVICES": cuda_visible_devices(i)},
- plugins={
- CPUAffinity(
- get_cpu_affinity(nvml_device_index(i, cuda_visible_devices(i)))
- ),
- RMMSetup(
- initial_pool_size=rmm_pool_size,
- maximum_pool_size=rmm_maximum_pool_size,
- managed_memory=rmm_managed_memory,
- async_alloc=rmm_async,
- release_threshold=rmm_release_threshold,
- log_directory=rmm_log_directory,
- track_allocations=rmm_track_allocations,
- external_lib_list=rmm_allocator_external_lib_list,
- ),
- PreImport(pre_import),
- CUDFSetup(spill=enable_cudf_spill, spill_stats=cudf_spill_stats),
- },
+ plugins=worker_plugins(
+ device_index=nvml_device_index(i, cuda_visible_devices(i)),
+ rmm_initial_pool_size=rmm_pool_size,
+ rmm_maximum_pool_size=rmm_maximum_pool_size,
+ rmm_managed_memory=rmm_managed_memory,
+ rmm_async_alloc=rmm_async,
+ rmm_release_threshold=rmm_release_threshold,
+ rmm_log_directory=rmm_log_directory,
+ rmm_track_allocations=rmm_track_allocations,
+ rmm_allocator_external_lib_list=rmm_allocator_external_lib_list,
+ pre_import=pre_import,
+ enable_cudf_spill=enable_cudf_spill,
+ cudf_spill_stats=cudf_spill_stats,
+ ),
  name=name if nprocs == 1 or name is None else str(name) + "-" + str(i),
  local_directory=local_directory,
  config={
- "distributed.comm.ucx": get_ucx_config(
+ "distributed-ucxx": get_ucx_config(
  enable_tcp_over_ucx=enable_tcp_over_ucx,
  enable_infiniband=enable_infiniband,
  enable_nvlink=enable_nvlink,
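Note: the inline ``data`` lambdas and plugin set removed above now come from the new ``dask_cuda/worker_common.py`` (``worker_data_function`` and ``worker_plugins``), whose bodies are not shown in this diff. The sketch below only illustrates the pattern the removed code followed, a factory mapping a CUDA device index to a ``(storage_class, kwargs)`` pair for the worker's ``data=`` argument; ``make_data_factory`` is a hypothetical name, not the new API:

    from typing import Callable, Tuple, Type

    def make_data_factory(
        device_memory_limit, memory_limit, jit_unspill: bool
    ) -> Callable[[int], Tuple[Type, dict]]:
        """Hypothetical helper mirroring the removed inline lambdas."""
        from dask_cuda.device_host_file import DeviceHostFile
        from dask_cuda.proxify_host_file import ProxifyHostFile
        from dask_cuda.utils import parse_device_memory_limit

        storage_cls = ProxifyHostFile if jit_unspill else DeviceHostFile

        def data(device_index: int):
            # Resolve the per-device limit (in bytes) for this worker's GPU.
            # (The real JIT-Unspill path also passes shared_filesystem; omitted here.)
            return (
                storage_cls,
                {
                    "device_memory_limit": parse_device_memory_limit(
                        device_memory_limit, device_index=device_index
                    ),
                    "memory_limit": memory_limit,
                },
            )

        return data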
dask_cuda/device_host_file.py CHANGED
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  import itertools
  import logging
  import os
@@ -35,7 +38,7 @@ def _serialize_bytelist(x, **kwargs):
  class LoggedBuffer(Buffer):
  """Extends zict.Buffer with logging capabilities

- Two arguments `fast_name` and `slow_name` are passed to constructor that
+ Two arguments ``fast_name`` and ``slow_name`` are passed to constructor that
  identify a user-friendly name for logging of where spilling is going from/to.
  For example, their names can be "Device" and "Host" to identify that spilling
  is happening from a CUDA device into system memory.
@@ -112,7 +115,7 @@ class DeviceSerialized:

  This stores a device-side object as
  1. A msgpack encodable header
- 2. A list of `bytes`-like objects (like NumPy arrays)
+ 2. A list of ``bytes``-like objects (like NumPy arrays)
  that are in host memory
  """

@@ -169,12 +172,13 @@ class DeviceHostFile(ZictBase):
  ----------
  worker_local_directory: path
  Path where to store serialized objects on disk
- device_memory_limit: int
+ device_memory_limit: int or None
  Number of bytes of CUDA device memory for device LRU cache,
- spills to host cache once filled.
- memory_limit: int
+ spills to host cache once filled. Setting this ``0`` or ``None``
+ means unlimited device memory, implies no spilling to host.
+ memory_limit: int or None
  Number of bytes of host memory for host LRU cache, spills to
- disk once filled. Setting this to `0` or `None` means unlimited
+ disk once filled. Setting this to ``0`` or ``None`` means unlimited
  host memory, implies no spilling to disk.
  log_spilling: bool
  If True, all spilling operations will be logged directly to
@@ -230,15 +234,22 @@ class DeviceHostFile(ZictBase):
  self.device_keys = set()
  self.device_func = dict()
  self.device_host_func = Func(device_to_host, host_to_device, self.host_buffer)
- self.device_buffer = Buffer(
- self.device_func,
- self.device_host_func,
- device_memory_limit,
- weight=lambda k, v: safe_sizeof(v),
- **device_buffer_kwargs,
- )
+ if device_memory_limit is None:
+ self.device_buffer = self.device_func
+ else:
+ self.device_buffer = Buffer(
+ self.device_func,
+ self.device_host_func,
+ device_memory_limit,
+ weight=lambda k, v: safe_sizeof(v),
+ **device_buffer_kwargs,
+ )

- self.device = self.device_buffer.fast.d
+ self.device = (
+ self.device_buffer
+ if device_memory_limit is None
+ else self.device_buffer.fast.d
+ )
  self.host = (
  self.host_buffer if memory_limit is None else self.host_buffer.fast.d
  )
@@ -283,7 +294,12 @@ class DeviceHostFile(ZictBase):
  if key in self.others:
  del self.others[key]
  else:
- del self.device_buffer[key]
+ if isinstance(self.device_buffer, dict) and key not in self.device_buffer:
+ # If `self.device_buffer` is a dictionary, host `key`s are inserted
+ # directly into `self.host_buffer`.
+ del self.host_buffer[key]
+ else:
+ del self.device_buffer[key]

  def evict(self):
  """Evicts least recently used host buffer (aka, CPU or system memory)
dask_cuda/disk_io.py CHANGED
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  import itertools
  import os
  import os.path
@@ -106,7 +109,7 @@ class SpillToDiskProperties:
  root_dir : os.PathLike
  Path to the root directory to write serialized data.
  shared_filesystem: bool or None, default None
- Whether the `root_dir` above is shared between all workers or not.
+ Whether the ``root_dir`` above is shared between all workers or not.
  If ``None``, the "jit-unspill-shared-fs" config value are used, which
  defaults to False.
  gds: bool
@@ -154,10 +157,10 @@ def disk_write(path: str, frames: Iterable, shared_filesystem: bool, gds=False)
  The frames to write to disk
  shared_filesystem: bool
  Whether the target filesystem is shared between all workers or not.
- If True, the filesystem must support the `os.link()` operation.
+ If True, the filesystem must support the ``os.link()`` operation.
  gds: bool
  Enable the use of GPUDirect Storage. Notice, the consecutive
- `disk_read()` must enable GDS as well.
+ ``disk_read()`` must enable GDS as well.

  Returns
  -------
@@ -196,7 +199,7 @@ def disk_read(header: Mapping, gds=False) -> list:
  The metadata of the frames to read
  gds: bool
  Enable the use of GPUDirect Storage. Notice, this must
- match the GDS option set by the prior `disk_write()` call.
+ match the GDS option set by the prior ``disk_write()`` call.

  Returns
  -------
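Note: ``disk_write()`` and ``disk_read()`` form a pair: the write returns header metadata that the later read consumes, and the ``gds`` flag must match on both sides. A hedged usage sketch, assuming the mapping returned by ``disk_write`` is exactly what ``disk_read`` expects (implied by the docstrings above but not shown in this diff):

    import os
    import tempfile

    from dask_cuda.disk_io import disk_read, disk_write

    frames = [b"header-bytes", b"payload-bytes"]
    path = os.path.join(tempfile.mkdtemp(), "spill.bin")

    # Write the frames without GDS on a non-shared filesystem.
    header = disk_write(path, frames, shared_filesystem=False, gds=False)

    # Read them back with a matching GDS setting.
    restored = disk_read(header, gds=False)
    assert [bytes(f) for f in restored] == frames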
dask_cuda/explicit_comms/comms.py CHANGED
@@ -33,7 +33,7 @@ def get_multi_lock_or_null_context(multi_lock_context, *args, **kwargs):
  Returns
  -------
  context: context
- Either `MultiLock(*args, **kwargs)` or a NULL context
+ Either ``MultiLock(*args, **kwargs)`` or a NULL context
  """
  if multi_lock_context:
  from distributed import MultiLock
@@ -52,7 +52,7 @@ def default_comms(client: Optional[Client] = None) -> "CommsContext":
  Parameters
  ----------
  client: Client, optional
- If no default comm object exists, create the new comm on `client`
+ If no default comm object exists, create the new comm on ``client``
  are returned.

  Returns
@@ -77,7 +77,9 @@ def default_comms(client: Optional[Client] = None) -> "CommsContext":
  # Comms are unique to a {client, [workers]} pair, so we key our
  # cache by the token of that.
  client = client or default_client()
- token = tokenize(client.id, list(client.scheduler_info()["workers"].keys()))
+ token = tokenize(
+ client.id, list(client.scheduler_info(n_workers=-1)["workers"].keys())
+ )
  maybe_comms = _comms_cache.get(token)
  if maybe_comms is None:
  maybe_comms = CommsContext(client=client)
@@ -206,7 +208,9 @@ class CommsContext:
  self.sessionId = uuid.uuid4().int

  # Get address of all workers (not Nanny addresses)
- self.worker_addresses = list(self.client.scheduler_info()["workers"].keys())
+ self.worker_addresses = list(
+ self.client.scheduler_info(n_workers=-1)["workers"].keys()
+ )

  # Make all workers listen and get all listen addresses
  self.worker_direct_addresses = []
@@ -248,7 +252,7 @@ class CommsContext:
  Returns
  -------
  ret: object or Future
- If wait=True, the result of `coroutine`
+ If wait=True, the result of ``coroutine``
  If wait=False, Future that can be waited on later.
  """
  ret = self.client.submit(
@@ -305,7 +309,7 @@ class CommsContext:
  def stage_keys(self, name: str, keys: Iterable[Hashable]) -> Dict[int, set]:
  """Staging keys on workers under the given name

- In an explicit-comms task, use `pop_staging_area(..., name)` to access
+ In an explicit-comms task, use ``pop_staging_area(..., name)`` to access
  the staged keys and the associated data.

  Notes
@@ -335,7 +339,7 @@ class CommsContext:


  def pop_staging_area(session_state: dict, name: str) -> Dict[str, Any]:
- """Pop the staging area called `name`
+ """Pop the staging area called ``name``

  This function must be called within a running explicit-comms task.

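Note: the ``CommsContext`` changes above only affect how worker addresses are gathered (again via ``scheduler_info(n_workers=-1)``). A short, hedged example of obtaining the cached comms object for a running dask-cuda cluster:

    from dask.distributed import Client
    from dask_cuda import LocalCUDACluster
    from dask_cuda.explicit_comms.comms import default_comms

    cluster = LocalCUDACluster()
    client = Client(cluster)

    # One CommsContext is created and cached per {client, workers} pair.
    comms = default_comms(client)
    print(comms.sessionId)
    print(comms.worker_addresses)  # worker addresses, not nanny addresses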
dask_cuda/explicit_comms/dataframe/shuffle.py CHANGED
@@ -65,13 +65,13 @@ def get_no_comm_postprocess(
  ) -> Callable[[DataFrame], DataFrame]:
  """Get function for post-processing partitions not communicated

- In cuDF, the `group_split_dispatch` uses `scatter_by_map` to create
+ In cuDF, the ``group_split_dispatch`` uses ``scatter_by_map`` to create
  the partitions, which is implemented by splitting a single base dataframe
  into multiple partitions. This means that memory are not freed until
  ALL partitions are deleted.

  In order to free memory ASAP, we can deep copy partitions NOT being
- communicated. We do this when `num_rounds != batchsize`.
+ communicated. We do this when ``num_rounds != batchsize``.

  Parameters
  ----------
@@ -116,7 +116,7 @@ async def send(
  rank_to_out_part_ids: Dict[int, Set[int]],
  out_part_id_to_dataframe: Dict[int, DataFrame],
  ) -> None:
- """Notice, items sent are removed from `out_part_id_to_dataframe`"""
+ """Notice, items sent are removed from ``out_part_id_to_dataframe``"""
  futures = []
  for rank, out_part_ids in rank_to_out_part_ids.items():
  if rank != myrank:
@@ -135,7 +135,7 @@ async def recv(
  out_part_id_to_dataframe_list: Dict[int, List[DataFrame]],
  proxify: Proxify,
  ) -> None:
- """Notice, received items are appended to `out_parts_list`"""
+ """Notice, received items are appended to ``out_parts_list``"""

  async def read_msg(rank: int) -> None:
  msg: Dict[int, DataFrame] = nested_deserialize(await eps[rank].read())
@@ -150,11 +150,11 @@ async def recv(
  def compute_map_index(
  df: DataFrame, column_names: List[str], npartitions: int
  ) -> Series:
- """Return a Series that maps each row `df` to a partition ID
+ """Return a Series that maps each row ``df`` to a partition ID

  The partitions are determined by hashing the columns given by column_names
- unless if `column_names[0] == "_partitions"`, in which case the values of
- `column_names[0]` are used as index.
+ unless if ``column_names[0] == "_partitions"``, in which case the values of
+ ``column_names[0]`` are used as index.

  Parameters
  ----------
@@ -168,7 +168,7 @@ def compute_map_index(
  Returns
  -------
  Series
- Series that maps each row `df` to a partition ID
+ Series that maps each row ``df`` to a partition ID
  """

  if column_names[0] == "_partitions":
@@ -193,8 +193,8 @@ def partition_dataframe(
  """Partition dataframe to a dict of dataframes

  The partitions are determined by hashing the columns given by column_names
- unless `column_names[0] == "_partitions"`, in which case the values of
- `column_names[0]` are used as index.
+ unless ``column_names[0] == "_partitions"``, in which case the values of
+ ``column_names[0]`` are used as index.

  Parameters
  ----------
@@ -301,13 +301,13 @@ async def send_recv_partitions(
  rank_to_out_part_ids
  dict that for each worker rank specifies a set of output partition IDs.
  If the worker shouldn't return any partitions, it is excluded from the
- dict. Partition IDs are global integers `0..npartitions` and corresponds
- to the dict keys returned by `group_split_dispatch`.
+ dict. Partition IDs are global integers ``0..npartitions`` and corresponds
+ to the dict keys returned by ``group_split_dispatch``.
  out_part_id_to_dataframe
  Mapping from partition ID to dataframe. This dict is cleared on return.
  no_comm_postprocess
  Function to post-process partitions not communicated.
- See `get_no_comm_postprocess`
+ See ``get_no_comm_postprocess``
  proxify
  Function to proxify object.
  out_part_id_to_dataframe_list
@@ -365,8 +365,8 @@ async def shuffle_task(
  rank_to_out_part_ids: dict
  dict that for each worker rank specifies a set of output partition IDs.
  If the worker shouldn't return any partitions, it is excluded from the
- dict. Partition IDs are global integers `0..npartitions` and corresponds
- to the dict keys returned by `group_split_dispatch`.
+ dict. Partition IDs are global integers ``0..npartitions`` and corresponds
+ to the dict keys returned by ``group_split_dispatch``.
  column_names: list of strings
  List of column names on which we want to split.
  npartitions: int
@@ -449,7 +449,7 @@ def shuffle(
  List of column names on which we want to split.
  npartitions: int or None
  The desired number of output partitions. If None, the number of output
- partitions equals `df.npartitions`
+ partitions equals ``df.npartitions``
  ignore_index: bool
  Ignore index during shuffle. If True, performance may improve,
  but index values will not be preserved.
@@ -460,7 +460,7 @@ def shuffle(
  If -1, each worker will handle all its partitions in a single round and
  all techniques to reduce memory usage are disabled, which might be faster
  when memory pressure isn't an issue.
- If None, the value of `DASK_EXPLICIT_COMMS_BATCHSIZE` is used or 1 if not
+ If None, the value of ``DASK_EXPLICIT_COMMS_BATCHSIZE`` is used or 1 if not
  set thus by default, we prioritize robustness over performance.

  Returns
@@ -471,12 +471,12 @@ def shuffle(
  Developer Notes
  ---------------
  The implementation consist of three steps:
- (a) Stage the partitions of `df` on all workers and then cancel them
+ (a) Stage the partitions of ``df`` on all workers and then cancel them
  thus at this point the Dask Scheduler doesn't know about any of the
  the partitions.
  (b) Submit a task on each worker that shuffle (all-to-all communicate)
  the staged partitions and return a list of dataframe-partitions.
- (c) Submit a dask graph that extract (using `getitem()`) individual
+ (c) Submit a dask graph that extract (using ``getitem()``) individual
  dataframe-partitions from (b).
  """
  c = comms.default_comms()
@@ -594,7 +594,7 @@ def _contains_shuffle_expr(*args) -> bool:
  """
  Check whether any of the arguments is a Shuffle expression.

- This is called by `compute`, which is given a sequence of Dask Collections
+ This is called by ``compute``, which is given a sequence of Dask Collections
  to process. For each of those, we'll check whether the expresion contains a
  Shuffle operation.
  """
@@ -712,9 +712,9 @@ def patch_shuffle_expression() -> None:
  """Patch Dasks Shuffle expression.

  Notice, this is monkey patched into Dask at dask_cuda
- import, and it changes `Shuffle._layer` to lower into
- an `ECShuffle` expression when the 'explicit-comms'
- config is set to `True`.
+ import, and it changes ``Shuffle._layer`` to lower into
+ an ``ECShuffle`` expression when the 'explicit-comms'
+ config is set to ``True``.
  """
  dask.base.compute = _patched_compute
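Note: the patched ``Shuffle._layer`` only takes effect when the ``explicit-comms`` config value is true, so the user-facing switch remains the Dask config (or the ``DASK_EXPLICIT_COMMS`` environment variable). A hedged end-to-end sketch with made-up data; it assumes a dask-cuda cluster, since explicit-comms needs a distributed client:

    import dask
    import dask.dataframe as dd
    import pandas as pd
    from dask.distributed import Client
    from dask_cuda import LocalCUDACluster

    cluster = LocalCUDACluster()
    client = Client(cluster)

    ddf = dd.from_pandas(
        pd.DataFrame({"key": [1, 2, 3, 4] * 25, "value": range(100)}),
        npartitions=4,
    )

    # With the flag set, the Shuffle expression lowers to ECShuffle.
    with dask.config.set({"explicit-comms": True}):
        result = ddf.shuffle(on="key").compute()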