dask-cuda 25.4.0__py3-none-any.whl → 25.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. dask_cuda/GIT_COMMIT +1 -1
  2. dask_cuda/VERSION +1 -1
  3. dask_cuda/_compat.py +18 -0
  4. dask_cuda/benchmarks/common.py +4 -1
  5. dask_cuda/benchmarks/local_cudf_groupby.py +4 -1
  6. dask_cuda/benchmarks/local_cudf_merge.py +5 -2
  7. dask_cuda/benchmarks/local_cudf_shuffle.py +5 -2
  8. dask_cuda/benchmarks/local_cupy.py +4 -1
  9. dask_cuda/benchmarks/local_cupy_map_overlap.py +4 -1
  10. dask_cuda/benchmarks/utils.py +7 -4
  11. dask_cuda/cli.py +21 -15
  12. dask_cuda/cuda_worker.py +27 -57
  13. dask_cuda/device_host_file.py +31 -15
  14. dask_cuda/disk_io.py +7 -4
  15. dask_cuda/explicit_comms/comms.py +11 -7
  16. dask_cuda/explicit_comms/dataframe/shuffle.py +147 -55
  17. dask_cuda/get_device_memory_objects.py +18 -3
  18. dask_cuda/initialize.py +80 -44
  19. dask_cuda/is_device_object.py +4 -1
  20. dask_cuda/is_spillable_object.py +4 -1
  21. dask_cuda/local_cuda_cluster.py +63 -66
  22. dask_cuda/plugins.py +17 -16
  23. dask_cuda/proxify_device_objects.py +15 -10
  24. dask_cuda/proxify_host_file.py +30 -27
  25. dask_cuda/proxy_object.py +20 -17
  26. dask_cuda/tests/conftest.py +41 -0
  27. dask_cuda/tests/test_dask_cuda_worker.py +114 -27
  28. dask_cuda/tests/test_dgx.py +10 -18
  29. dask_cuda/tests/test_explicit_comms.py +51 -18
  30. dask_cuda/tests/test_from_array.py +7 -5
  31. dask_cuda/tests/test_initialize.py +16 -37
  32. dask_cuda/tests/test_local_cuda_cluster.py +164 -54
  33. dask_cuda/tests/test_proxify_host_file.py +33 -4
  34. dask_cuda/tests/test_proxy.py +18 -16
  35. dask_cuda/tests/test_rdd_ucx.py +160 -0
  36. dask_cuda/tests/test_spill.py +107 -27
  37. dask_cuda/tests/test_utils.py +106 -20
  38. dask_cuda/tests/test_worker_spec.py +5 -2
  39. dask_cuda/utils.py +319 -68
  40. dask_cuda/utils_test.py +23 -7
  41. dask_cuda/worker_common.py +196 -0
  42. dask_cuda/worker_spec.py +12 -5
  43. {dask_cuda-25.4.0.dist-info → dask_cuda-25.8.0.dist-info}/METADATA +5 -4
  44. dask_cuda-25.8.0.dist-info/RECORD +63 -0
  45. {dask_cuda-25.4.0.dist-info → dask_cuda-25.8.0.dist-info}/WHEEL +1 -1
  46. dask_cuda-25.8.0.dist-info/top_level.txt +6 -0
  47. shared-actions/check_nightly_success/check-nightly-success/check.py +148 -0
  48. shared-actions/telemetry-impls/summarize/bump_time.py +54 -0
  49. shared-actions/telemetry-impls/summarize/send_trace.py +409 -0
  50. dask_cuda-25.4.0.dist-info/RECORD +0 -56
  51. dask_cuda-25.4.0.dist-info/top_level.txt +0 -5
  52. {dask_cuda-25.4.0.dist-info → dask_cuda-25.8.0.dist-info}/entry_points.txt +0 -0
  53. {dask_cuda-25.4.0.dist-info → dask_cuda-25.8.0.dist-info}/licenses/LICENSE +0 -0
dask_cuda/GIT_COMMIT CHANGED
@@ -1 +1 @@
- e9ebd92886e6f518af02faf8a2cdadeb700b25a9
+ bde9a4d3ee2c4338f56b3acf919b8e756ecb35b3
dask_cuda/VERSION CHANGED
@@ -1 +1 @@
- 25.04.00
+ 25.08.00
dask_cuda/_compat.py ADDED
@@ -0,0 +1,18 @@
+ # Copyright (c) 2025 NVIDIA CORPORATION.
+
+ import functools
+ import importlib.metadata
+
+ import packaging.version
+
+
+ @functools.lru_cache(maxsize=None)
+ def get_dask_version() -> packaging.version.Version:
+     return packaging.version.parse(importlib.metadata.version("dask"))
+
+
+ @functools.lru_cache(maxsize=None)
+ def DASK_2025_4_0():
+     # dask 2025.4.0 isn't currently released, so we're relying
+     # on strictly greater than here.
+     return get_dask_version() > packaging.version.parse("2025.3.0")
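A hypothetical downstream call site (illustrative only, not part of the wheel contents) can branch on this cached check to stay compatible with both older and newer Dask releases:

from dask_cuda._compat import DASK_2025_4_0

def choose_code_path():
    # Illustrative only: pick a behavior based on the installed Dask version,
    # using the cached check added in dask_cuda/_compat.py.
    if DASK_2025_4_0():
        return "dask >= 2025.4 behavior"
    return "pre-2025.4 behavior"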
dask_cuda/benchmarks/common.py CHANGED
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  import contextlib
  from argparse import Namespace
  from functools import partial
@@ -124,7 +127,7 @@ def run(client: Client, args: Namespace, config: Config):
      """
 
      wait_for_cluster(client, shutdown_on_failure=True)
-     assert len(client.scheduler_info()["workers"]) > 0
+     assert len(client.scheduler_info(n_workers=-1)["workers"]) > 0
      setup_memory_pools(
          client=client,
          is_gpu=args.type == "gpu",
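Several changes in this release replace client.scheduler_info() with client.scheduler_info(n_workers=-1). Newer distributed releases may report only a limited number of workers by default, and -1 requests metadata for all of them; a small illustrative snippet (the connection details below are hypothetical):

from distributed import Client

client = Client(scheduler_file="scheduler.json")  # hypothetical connection

# Ask the scheduler for metadata on every worker, not just a truncated subset.
workers = client.scheduler_info(n_workers=-1)["workers"]
print(f"cluster currently reports {len(workers)} workers")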
dask_cuda/benchmarks/local_cudf_groupby.py CHANGED
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  import contextlib
  from collections import ChainMap
  from time import perf_counter as clock
@@ -138,7 +141,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
          key="Device memory limit", value=f"{format_bytes(args.device_memory_limit)}"
      )
      print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
-     if args.protocol in ["ucx", "ucxx"]:
+     if args.protocol in ["ucx", "ucxx", "ucx-old"]:
          print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
          print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
          print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
dask_cuda/benchmarks/local_cudf_merge.py CHANGED
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  import contextlib
  import math
  from collections import ChainMap
@@ -166,7 +169,7 @@ def merge(args, ddf1, ddf2):
 
  def bench_once(client, args, write_profile=None):
      # Generate random Dask dataframes
-     n_workers = len(client.scheduler_info()["workers"])
+     n_workers = len(client.scheduler_info(n_workers=-1)["workers"])
      # Allow the number of chunks to vary between
      # the "base" and "other" DataFrames
      args.base_chunks = args.base_chunks or n_workers
@@ -224,7 +227,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
      )
      print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
      print_key_value(key="Frac-match", value=f"{args.frac_match}")
-     if args.protocol in ["ucx", "ucxx"]:
+     if args.protocol in ["ucx", "ucxx", "ucx-old"]:
          print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
          print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
          print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
dask_cuda/benchmarks/local_cudf_shuffle.py CHANGED
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  import contextlib
  from collections import ChainMap
  from time import perf_counter
@@ -70,7 +73,7 @@ def create_data(
      """
      chunksize = args.partition_size // np.float64().nbytes
 
-     workers = list(client.scheduler_info()["workers"].keys())
+     workers = list(client.scheduler_info(n_workers=-1)["workers"].keys())
      assert len(workers) > 0
 
      dist = args.partition_distribution
@@ -149,7 +152,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
          key="Device memory limit", value=f"{format_bytes(args.device_memory_limit)}"
      )
      print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
-     if args.protocol in ["ucx", "ucxx"]:
+     if args.protocol in ["ucx", "ucxx", "ucx-old"]:
          print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
          print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
          print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
dask_cuda/benchmarks/local_cupy.py CHANGED
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  import contextlib
  from collections import ChainMap
  from time import perf_counter as clock
@@ -192,7 +195,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
      )
      print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
      print_key_value(key="Protocol", value=f"{args.protocol}")
-     if args.protocol in ["ucx", "ucxx"]:
+     if args.protocol in ["ucx", "ucxx", "ucx-old"]:
          print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
          print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
          print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
dask_cuda/benchmarks/local_cupy_map_overlap.py CHANGED
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  import contextlib
  from collections import ChainMap
  from time import perf_counter as clock
@@ -77,7 +80,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
      )
      print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
      print_key_value(key="Protocol", value=f"{args.protocol}")
-     if args.protocol in ["ucx", "ucxx"]:
+     if args.protocol in ["ucx", "ucxx", "ucx-old"]:
          print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
          print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
          print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
dask_cuda/benchmarks/utils.py CHANGED
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  import argparse
  import itertools
  import json
@@ -77,7 +80,7 @@ def parse_benchmark_args(
      cluster_args.add_argument(
          "-p",
          "--protocol",
-         choices=["tcp", "ucx", "ucxx"],
+         choices=["tcp", "ucx", "ucxx", "ucx-old"],
          default="tcp",
          type=str,
          help="The communication protocol to use.",
@@ -122,7 +125,7 @@ def parse_benchmark_args(
          "pool size."
          ""
          ".. note::"
-         " When paired with `--enable-rmm-async` the maximum size cannot be "
+         " When paired with ``--enable-rmm-async`` the maximum size cannot be "
          " guaranteed due to fragmentation."
          ""
          ".. note::"
@@ -641,11 +644,11 @@ def wait_for_cluster(client, timeout=120, shutdown_on_failure=True):
      for _ in range(timeout // 5):
          print(
              "Waiting for workers to come up, "
-             f"have {len(client.scheduler_info().get('workers', []))}, "
+             f"have {len(client.scheduler_info(n_workers=-1).get('workers', []))}, "
              f"want {expected}"
          )
          time.sleep(5)
-         nworkers = len(client.scheduler_info().get("workers", []))
+         nworkers = len(client.scheduler_info(n_workers=-1).get("workers", []))
          if nworkers == expected:
              return
      else:
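The benchmark CLI now accepts a fourth protocol value, "ucx-old", alongside "tcp", "ucx" and "ucxx". A standalone argparse sketch of the same choice list (the real parser lives in dask_cuda/benchmarks/utils.py and takes many more options):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "-p",
    "--protocol",
    choices=["tcp", "ucx", "ucxx", "ucx-old"],
    default="tcp",
    type=str,
    help="The communication protocol to use.",
)

# Selecting the newly accepted value:
args = parser.parse_args(["--protocol", "ucx-old"])
assert args.protocol == "ucx-old"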
dask_cuda/cli.py CHANGED
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  from __future__ import absolute_import, division, print_function
 
  import logging
@@ -90,16 +93,20 @@ def cuda():
      help="""Size of the host LRU cache, which is used to determine when the worker
      starts spilling to disk (not available if JIT-Unspill is enabled). Can be an
      integer (bytes), float (fraction of total system memory), string (like ``"5GB"``
-     or ``"5000M"``), or ``"auto"``, 0, or ``None`` for no memory management.""",
+     or ``"5000M"``), or ``"auto"`` or ``0`` for no memory management.""",
  )
  @click.option(
      "--device-memory-limit",
-     default="0.8",
+     default="default",
      show_default=True,
      help="""Size of the CUDA device LRU cache, which is used to determine when the
      worker starts spilling to host memory. Can be an integer (bytes), float (fraction of
-     total device memory), string (like ``"5GB"`` or ``"5000M"``), or ``"auto"`` or 0 to
-     disable spilling to host (i.e. allow full device memory usage).""",
+     total device memory), string (like ``"5GB"`` or ``"5000M"``), ``"auto"`` or ``0``
+     to disable spilling to host (i.e. allow full device memory usage). Another special
+     value ``"default"`` (which happens to be the default) is also available and uses the
+     recommended Dask-CUDA's defaults and means 80% of the total device memory (analogous
+     to ``0.8``), and disabled spilling (analogous to ``auto``/``0``) on devices without
+     a dedicated memory resource, such as system on a chip (SoC) devices.""",
  )
  @click.option(
      "--enable-cudf-spill/--disable-cudf-spill",
@@ -113,7 +120,7 @@ def cuda():
      type=int,
      default=0,
      help="""Set the cuDF spilling statistics level. This option has no effect if
-     `--enable-cudf-spill` is not specified.""",
+     ``--enable-cudf-spill`` is not specified.""",
  )
  @click.option(
      "--rmm-pool-size",
@@ -135,8 +142,8 @@ def cuda():
      to set the maximum pool size.
 
      .. note::
-         When paired with `--enable-rmm-async` the maximum size cannot be guaranteed due
-         to fragmentation.
+         When paired with ``--enable-rmm-async`` the maximum size cannot be guaranteed
+         due to fragmentation.
 
      .. note::
          This size is a per-worker configuration, and not cluster-wide.""",
@@ -160,9 +167,8 @@ def cuda():
      allocator. See ``rmm.mr.CudaAsyncMemoryResource`` for more info.
 
      .. warning::
-         The asynchronous allocator requires CUDA Toolkit 11.2 or newer. It is also
-         incompatible with RMM pools and managed memory, trying to enable both will
-         result in failure.""",
+         The asynchronous allocator is incompatible with RMM pools and managed memory,
+         trying to enable both will result in failure.""",
  )
  @click.option(
      "--set-rmm-allocator-for-libs",
@@ -245,12 +251,12 @@ def cuda():
      "--shared-filesystem/--no-shared-filesystem",
      default=None,
      type=bool,
-     help="""If `--shared-filesystem` is specified, inform JIT-Unspill that
-     `local_directory` is a shared filesystem available for all workers, whereas
-     `--no-shared-filesystem` informs it may not assume it's a shared filesystem.
+     help="""If ``--shared-filesystem`` is specified, inform JIT-Unspill that
+     ``local_directory`` is a shared filesystem available for all workers, whereas
+     ``--no-shared-filesystem`` informs it may not assume it's a shared filesystem.
      If neither is specified, JIT-Unspill will decide based on the Dask config value
-     specified by `"jit-unspill-shared-fs"`.
-     Notice, a shared filesystem must support the `os.link()` operation.""",
+     specified by ``"jit-unspill-shared-fs"``.
+     Notice, a shared filesystem must support the ``os.link()`` operation.""",
  )
  @scheduler_file
  @click.option(
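As described in the new help text, ``--device-memory-limit`` now defaults to the special value "default". A hedged sketch of the documented resolution rule (the helper name and signature below are illustrative, not the actual implementation inside dask-cuda):

def resolve_device_memory_limit(value, has_dedicated_device_memory=True):
    # Illustrative reading of the documented behaviour: "default" means 80% of
    # total device memory on discrete GPUs, and no host spilling (like
    # "auto"/0) on SoC-style devices without a dedicated memory resource.
    if value == "default":
        return 0.8 if has_dedicated_device_memory else 0
    return value

assert resolve_device_memory_limit("default") == 0.8
assert resolve_device_memory_limit("default", has_dedicated_device_memory=False) == 0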
dask_cuda/cuda_worker.py CHANGED
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  from __future__ import absolute_import, division, print_function
 
  import asyncio
@@ -18,18 +21,9 @@ from distributed.proctitle import (
  )
  from distributed.worker_memory import parse_memory_limit
 
- from .device_host_file import DeviceHostFile
  from .initialize import initialize
- from .plugins import CPUAffinity, CUDFSetup, PreImport, RMMSetup
- from .proxify_host_file import ProxifyHostFile
- from .utils import (
-     cuda_visible_devices,
-     get_cpu_affinity,
-     get_n_gpus,
-     get_ucx_config,
-     nvml_device_index,
-     parse_device_memory_limit,
- )
+ from .utils import cuda_visible_devices, get_n_gpus, get_ucx_config, nvml_device_index
+ from .worker_common import worker_data_function, worker_plugins
 
 
  class CUDAWorker(Server):
@@ -40,7 +34,7 @@ class CUDAWorker(Server):
          nthreads=1,
          name=None,
          memory_limit="auto",
-         device_memory_limit="auto",
+         device_memory_limit="default",
          enable_cudf_spill=False,
          cudf_spill_stats=0,
          rmm_pool_size=None,
@@ -166,35 +160,14 @@
 
          if jit_unspill is None:
              jit_unspill = dask.config.get("jit-unspill", default=False)
-         if device_memory_limit is None and memory_limit is None:
-             data = lambda _: {}
-         elif jit_unspill:
-             if enable_cudf_spill:
-                 warnings.warn(
-                     "Enabling cuDF spilling and JIT-Unspill together is not "
-                     "safe, consider disabling JIT-Unspill."
-                 )
 
-             data = lambda i: (
-                 ProxifyHostFile,
-                 {
-                     "device_memory_limit": parse_device_memory_limit(
-                         device_memory_limit, device_index=i
-                     ),
-                     "memory_limit": memory_limit,
-                     "shared_filesystem": shared_filesystem,
-                 },
-             )
-         else:
-             data = lambda i: (
-                 DeviceHostFile,
-                 {
-                     "device_memory_limit": parse_device_memory_limit(
-                         device_memory_limit, device_index=i
-                     ),
-                     "memory_limit": memory_limit,
-                 },
-             )
+         data = worker_data_function(
+             device_memory_limit=device_memory_limit,
+             memory_limit=memory_limit,
+             jit_unspill=jit_unspill,
+             enable_cudf_spill=enable_cudf_spill,
+             shared_filesystem=shared_filesystem,
+         )
 
          cudf_spill_warning = dask.config.get("cudf-spill-warning", default=True)
          if enable_cudf_spill and cudf_spill_warning:
@@ -220,23 +193,20 @@
                  preload_argv=(list(preload_argv) or []) + ["--create-cuda-context"],
                  security=security,
                  env={"CUDA_VISIBLE_DEVICES": cuda_visible_devices(i)},
-                 plugins={
-                     CPUAffinity(
-                         get_cpu_affinity(nvml_device_index(i, cuda_visible_devices(i)))
-                     ),
-                     RMMSetup(
-                         initial_pool_size=rmm_pool_size,
-                         maximum_pool_size=rmm_maximum_pool_size,
-                         managed_memory=rmm_managed_memory,
-                         async_alloc=rmm_async,
-                         release_threshold=rmm_release_threshold,
-                         log_directory=rmm_log_directory,
-                         track_allocations=rmm_track_allocations,
-                         external_lib_list=rmm_allocator_external_lib_list,
-                     ),
-                     PreImport(pre_import),
-                     CUDFSetup(spill=enable_cudf_spill, spill_stats=cudf_spill_stats),
-                 },
+                 plugins=worker_plugins(
+                     device_index=nvml_device_index(i, cuda_visible_devices(i)),
+                     rmm_initial_pool_size=rmm_pool_size,
+                     rmm_maximum_pool_size=rmm_maximum_pool_size,
+                     rmm_managed_memory=rmm_managed_memory,
+                     rmm_async_alloc=rmm_async,
+                     rmm_release_threshold=rmm_release_threshold,
+                     rmm_log_directory=rmm_log_directory,
+                     rmm_track_allocations=rmm_track_allocations,
+                     rmm_allocator_external_lib_list=rmm_allocator_external_lib_list,
+                     pre_import=pre_import,
+                     enable_cudf_spill=enable_cudf_spill,
+                     cudf_spill_stats=cudf_spill_stats,
+                 ),
                  name=name if nprocs == 1 or name is None else str(name) + "-" + str(i),
                  local_directory=local_directory,
                  config={
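The inline branching removed above now lives in the new dask_cuda/worker_common.py module. Based only on the removed code, a sketch of what ``worker_data_function`` plausibly returns; the actual implementation is not shown in this diff and may differ:

import warnings

from dask_cuda.device_host_file import DeviceHostFile
from dask_cuda.proxify_host_file import ProxifyHostFile
from dask_cuda.utils import parse_device_memory_limit


def sketch_worker_data_function(
    device_memory_limit, memory_limit, jit_unspill, enable_cudf_spill, shared_filesystem
):
    # Hypothetical reconstruction: return a callable mapping a device index to
    # the worker's ``data`` argument, mirroring the removed lambdas above.
    if device_memory_limit is None and memory_limit is None:
        return lambda _: {}
    if jit_unspill:
        if enable_cudf_spill:
            warnings.warn(
                "Enabling cuDF spilling and JIT-Unspill together is not "
                "safe, consider disabling JIT-Unspill."
            )
        return lambda i: (
            ProxifyHostFile,
            {
                "device_memory_limit": parse_device_memory_limit(
                    device_memory_limit, device_index=i
                ),
                "memory_limit": memory_limit,
                "shared_filesystem": shared_filesystem,
            },
        )
    return lambda i: (
        DeviceHostFile,
        {
            "device_memory_limit": parse_device_memory_limit(
                device_memory_limit, device_index=i
            ),
            "memory_limit": memory_limit,
        },
    )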
dask_cuda/device_host_file.py CHANGED
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  import itertools
  import logging
  import os
@@ -35,7 +38,7 @@ def _serialize_bytelist(x, **kwargs):
  class LoggedBuffer(Buffer):
      """Extends zict.Buffer with logging capabilities
 
-     Two arguments `fast_name` and `slow_name` are passed to constructor that
+     Two arguments ``fast_name`` and ``slow_name`` are passed to constructor that
      identify a user-friendly name for logging of where spilling is going from/to.
      For example, their names can be "Device" and "Host" to identify that spilling
      is happening from a CUDA device into system memory.
@@ -112,7 +115,7 @@ class DeviceSerialized:
 
      This stores a device-side object as
      1. A msgpack encodable header
-     2. A list of `bytes`-like objects (like NumPy arrays)
+     2. A list of ``bytes``-like objects (like NumPy arrays)
      that are in host memory
      """
 
@@ -169,12 +172,13 @@ class DeviceHostFile(ZictBase):
      ----------
      worker_local_directory: path
          Path where to store serialized objects on disk
-     device_memory_limit: int
+     device_memory_limit: int or None
          Number of bytes of CUDA device memory for device LRU cache,
-         spills to host cache once filled.
-     memory_limit: int
+         spills to host cache once filled. Setting this ``0`` or ``None``
+         means unlimited device memory, implies no spilling to host.
+     memory_limit: int or None
          Number of bytes of host memory for host LRU cache, spills to
-         disk once filled. Setting this to `0` or `None` means unlimited
+         disk once filled. Setting this to ``0`` or ``None`` means unlimited
          host memory, implies no spilling to disk.
      log_spilling: bool
          If True, all spilling operations will be logged directly to
@@ -230,15 +234,22 @@
          self.device_keys = set()
          self.device_func = dict()
          self.device_host_func = Func(device_to_host, host_to_device, self.host_buffer)
-         self.device_buffer = Buffer(
-             self.device_func,
-             self.device_host_func,
-             device_memory_limit,
-             weight=lambda k, v: safe_sizeof(v),
-             **device_buffer_kwargs,
-         )
+         if device_memory_limit is None:
+             self.device_buffer = self.device_func
+         else:
+             self.device_buffer = Buffer(
+                 self.device_func,
+                 self.device_host_func,
+                 device_memory_limit,
+                 weight=lambda k, v: safe_sizeof(v),
+                 **device_buffer_kwargs,
+             )
 
-         self.device = self.device_buffer.fast.d
+         self.device = (
+             self.device_buffer
+             if device_memory_limit is None
+             else self.device_buffer.fast.d
+         )
          self.host = (
              self.host_buffer if memory_limit is None else self.host_buffer.fast.d
          )
@@ -283,7 +294,12 @@
          if key in self.others:
              del self.others[key]
          else:
-             del self.device_buffer[key]
+             if isinstance(self.device_buffer, dict) and key not in self.device_buffer:
+                 # If `self.device_buffer` is a dictionary, host `key`s are inserted
+                 # directly into `self.host_buffer`.
+                 del self.host_buffer[key]
+             else:
+                 del self.device_buffer[key]
 
      def evict(self):
          """Evicts least recently used host buffer (aka, CPU or system memory)
dask_cuda/disk_io.py CHANGED
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  import itertools
  import os
  import os.path
@@ -106,7 +109,7 @@ class SpillToDiskProperties:
      root_dir : os.PathLike
          Path to the root directory to write serialized data.
      shared_filesystem: bool or None, default None
-         Whether the `root_dir` above is shared between all workers or not.
+         Whether the ``root_dir`` above is shared between all workers or not.
          If ``None``, the "jit-unspill-shared-fs" config value are used, which
          defaults to False.
      gds: bool
@@ -154,10 +157,10 @@ def disk_write(path: str, frames: Iterable, shared_filesystem: bool, gds=False)
          The frames to write to disk
      shared_filesystem: bool
          Whether the target filesystem is shared between all workers or not.
-         If True, the filesystem must support the `os.link()` operation.
+         If True, the filesystem must support the ``os.link()`` operation.
      gds: bool
          Enable the use of GPUDirect Storage. Notice, the consecutive
-         `disk_read()` must enable GDS as well.
+         ``disk_read()`` must enable GDS as well.
 
      Returns
      -------
@@ -196,7 +199,7 @@ def disk_read(header: Mapping, gds=False) -> list:
          The metadata of the frames to read
      gds: bool
          Enable the use of GPUDirect Storage. Notice, this must
-         match the GDS option set by the prior `disk_write()` call.
+         match the GDS option set by the prior ``disk_write()`` call.
 
      Returns
      -------
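The GDS flag must match between a write and the subsequent read. A round-trip sketch, assuming (it is not shown in this diff) that ``disk_write`` returns the header mapping consumed by ``disk_read``:

from dask_cuda.disk_io import disk_read, disk_write

frames = [b"some serialized bytes", b"more bytes"]
header = disk_write(
    "/tmp/spilled-object",    # hypothetical target path
    frames,
    shared_filesystem=False,  # True requires os.link() support on the filesystem
    gds=False,
)
restored = disk_read(header, gds=False)  # must use the same GDS setting as the write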
dask_cuda/explicit_comms/comms.py CHANGED
@@ -33,7 +33,7 @@ def get_multi_lock_or_null_context(multi_lock_context, *args, **kwargs):
      Returns
      -------
      context: context
-         Either `MultiLock(*args, **kwargs)` or a NULL context
+         Either ``MultiLock(*args, **kwargs)`` or a NULL context
      """
      if multi_lock_context:
          from distributed import MultiLock
@@ -52,7 +52,7 @@ def default_comms(client: Optional[Client] = None) -> "CommsContext":
      Parameters
      ----------
      client: Client, optional
-         If no default comm object exists, create the new comm on `client`
+         If no default comm object exists, create the new comm on ``client``
          are returned.
 
      Returns
@@ -77,7 +77,9 @@ def default_comms(client: Optional[Client] = None) -> "CommsContext":
      # Comms are unique to a {client, [workers]} pair, so we key our
      # cache by the token of that.
      client = client or default_client()
-     token = tokenize(client.id, list(client.scheduler_info()["workers"].keys()))
+     token = tokenize(
+         client.id, list(client.scheduler_info(n_workers=-1)["workers"].keys())
+     )
      maybe_comms = _comms_cache.get(token)
      if maybe_comms is None:
          maybe_comms = CommsContext(client=client)
@@ -206,7 +208,9 @@ class CommsContext:
          self.sessionId = uuid.uuid4().int
 
          # Get address of all workers (not Nanny addresses)
-         self.worker_addresses = list(self.client.scheduler_info()["workers"].keys())
+         self.worker_addresses = list(
+             self.client.scheduler_info(n_workers=-1)["workers"].keys()
+         )
 
          # Make all workers listen and get all listen addresses
          self.worker_direct_addresses = []
@@ -248,7 +252,7 @@ class CommsContext:
          Returns
          -------
          ret: object or Future
-             If wait=True, the result of `coroutine`
+             If wait=True, the result of ``coroutine``
              If wait=False, Future that can be waited on later.
          """
          ret = self.client.submit(
@@ -305,7 +309,7 @@ class CommsContext:
      def stage_keys(self, name: str, keys: Iterable[Hashable]) -> Dict[int, set]:
          """Staging keys on workers under the given name
 
-         In an explicit-comms task, use `pop_staging_area(..., name)` to access
+         In an explicit-comms task, use ``pop_staging_area(..., name)`` to access
          the staged keys and the associated data.
 
          Notes
@@ -335,7 +339,7 @@
 
 
  def pop_staging_area(session_state: dict, name: str) -> Dict[str, Any]:
-     """Pop the staging area called `name`
+     """Pop the staging area called ``name``
 
      This function must be called within a running explicit-comms task.