dask-cuda 25.4.0__py3-none-any.whl → 25.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. dask_cuda/GIT_COMMIT +1 -1
  2. dask_cuda/VERSION +1 -1
  3. dask_cuda/_compat.py +18 -0
  4. dask_cuda/benchmarks/common.py +4 -1
  5. dask_cuda/benchmarks/local_cudf_groupby.py +4 -1
  6. dask_cuda/benchmarks/local_cudf_merge.py +5 -2
  7. dask_cuda/benchmarks/local_cudf_shuffle.py +5 -2
  8. dask_cuda/benchmarks/local_cupy.py +4 -1
  9. dask_cuda/benchmarks/local_cupy_map_overlap.py +4 -1
  10. dask_cuda/benchmarks/utils.py +7 -4
  11. dask_cuda/cli.py +21 -15
  12. dask_cuda/cuda_worker.py +27 -57
  13. dask_cuda/device_host_file.py +31 -15
  14. dask_cuda/disk_io.py +7 -4
  15. dask_cuda/explicit_comms/comms.py +11 -7
  16. dask_cuda/explicit_comms/dataframe/shuffle.py +147 -55
  17. dask_cuda/get_device_memory_objects.py +18 -3
  18. dask_cuda/initialize.py +80 -44
  19. dask_cuda/is_device_object.py +4 -1
  20. dask_cuda/is_spillable_object.py +4 -1
  21. dask_cuda/local_cuda_cluster.py +63 -66
  22. dask_cuda/plugins.py +17 -16
  23. dask_cuda/proxify_device_objects.py +15 -10
  24. dask_cuda/proxify_host_file.py +30 -27
  25. dask_cuda/proxy_object.py +20 -17
  26. dask_cuda/tests/conftest.py +41 -0
  27. dask_cuda/tests/test_dask_cuda_worker.py +114 -27
  28. dask_cuda/tests/test_dgx.py +10 -18
  29. dask_cuda/tests/test_explicit_comms.py +51 -18
  30. dask_cuda/tests/test_from_array.py +7 -5
  31. dask_cuda/tests/test_initialize.py +16 -37
  32. dask_cuda/tests/test_local_cuda_cluster.py +164 -54
  33. dask_cuda/tests/test_proxify_host_file.py +33 -4
  34. dask_cuda/tests/test_proxy.py +18 -16
  35. dask_cuda/tests/test_rdd_ucx.py +160 -0
  36. dask_cuda/tests/test_spill.py +107 -27
  37. dask_cuda/tests/test_utils.py +106 -20
  38. dask_cuda/tests/test_worker_spec.py +5 -2
  39. dask_cuda/utils.py +319 -68
  40. dask_cuda/utils_test.py +23 -7
  41. dask_cuda/worker_common.py +196 -0
  42. dask_cuda/worker_spec.py +12 -5
  43. {dask_cuda-25.4.0.dist-info → dask_cuda-25.8.0.dist-info}/METADATA +5 -4
  44. dask_cuda-25.8.0.dist-info/RECORD +63 -0
  45. {dask_cuda-25.4.0.dist-info → dask_cuda-25.8.0.dist-info}/WHEEL +1 -1
  46. dask_cuda-25.8.0.dist-info/top_level.txt +6 -0
  47. shared-actions/check_nightly_success/check-nightly-success/check.py +148 -0
  48. shared-actions/telemetry-impls/summarize/bump_time.py +54 -0
  49. shared-actions/telemetry-impls/summarize/send_trace.py +409 -0
  50. dask_cuda-25.4.0.dist-info/RECORD +0 -56
  51. dask_cuda-25.4.0.dist-info/top_level.txt +0 -5
  52. {dask_cuda-25.4.0.dist-info → dask_cuda-25.8.0.dist-info}/entry_points.txt +0 -0
  53. {dask_cuda-25.4.0.dist-info → dask_cuda-25.8.0.dist-info}/licenses/LICENSE +0 -0
dask_cuda/tests/test_rdd_ucx.py
@@ -0,0 +1,160 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
+
+ import importlib
+ import io
+ import multiprocessing as mp
+ import sys
+
+ import pytest
+
+ from dask_cuda import LocalCUDACluster
+
+ mp = mp.get_context("spawn")  # type: ignore
+
+
+ def _has_distributed_ucxx() -> bool:
+     return bool(importlib.util.find_spec("distributed_ucxx"))
+
+
+ def _test_protocol_ucx():
+     with LocalCUDACluster(protocol="ucx") as cluster:
+         assert cluster.scheduler_comm.address.startswith("ucx://")
+
+         if _has_distributed_ucxx():
+             import distributed_ucxx
+
+             assert all(
+                 isinstance(batched_send.comm, distributed_ucxx.ucxx.UCXX)
+                 for batched_send in cluster.scheduler.stream_comms.values()
+             )
+         else:
+             import rapids_dask_dependency
+
+             assert all(
+                 isinstance(
+                     batched_send.comm,
+                     rapids_dask_dependency.patches.distributed.comm.__rdd_patch_ucx.UCX,
+                 )
+                 for batched_send in cluster.scheduler.stream_comms.values()
+             )
+
+
+ def _test_protocol_ucxx():
+     if _has_distributed_ucxx():
+         with LocalCUDACluster(protocol="ucxx") as cluster:
+             assert cluster.scheduler_comm.address.startswith("ucxx://")
+             import distributed_ucxx
+
+             assert all(
+                 isinstance(batched_send.comm, distributed_ucxx.ucxx.UCXX)
+                 for batched_send in cluster.scheduler.stream_comms.values()
+             )
+     else:
+         with pytest.raises(RuntimeError, match="Cluster failed to start"):
+             LocalCUDACluster(protocol="ucxx")
+
+
+ def _test_protocol_ucx_old():
+     with LocalCUDACluster(protocol="ucx-old") as cluster:
+         assert cluster.scheduler_comm.address.startswith("ucx-old://")
+
+         import rapids_dask_dependency
+
+         assert all(
+             isinstance(
+                 batched_send.comm,
+                 rapids_dask_dependency.patches.distributed.comm.__rdd_patch_ucx.UCX,
+             )
+             for batched_send in cluster.scheduler.stream_comms.values()
+         )
+
+
+ def _run_test_with_output_capture(test_func_name, conn):
+     """Run a test function in a subprocess and capture stdout/stderr."""
+     # Redirect stdout and stderr to capture output
+     old_stdout = sys.stdout
+     old_stderr = sys.stderr
+     captured_output = io.StringIO()
+     sys.stdout = sys.stderr = captured_output
+
+     try:
+         # Import and run the test function
+         if test_func_name == "_test_protocol_ucx":
+             _test_protocol_ucx()
+         elif test_func_name == "_test_protocol_ucxx":
+             _test_protocol_ucxx()
+         elif test_func_name == "_test_protocol_ucx_old":
+             _test_protocol_ucx_old()
+         else:
+             raise ValueError(f"Unknown test function: {test_func_name}")
+
+         output = captured_output.getvalue()
+         conn.send((True, output))  # True = success
+     except Exception as e:
+         output = captured_output.getvalue()
+         output += f"\nException: {e}"
+         import traceback
+
+         output += f"\nTraceback:\n{traceback.format_exc()}"
+         conn.send((False, output))  # False = failure
+     finally:
+         # Restore original stdout/stderr
+         sys.stdout = old_stdout
+         sys.stderr = old_stderr
+         conn.close()
+
+
+ @pytest.mark.parametrize("protocol", ["ucx", "ucxx", "ucx-old"])
+ def test_rdd_protocol(protocol):
+     """Test rapids-dask-dependency protocol selection"""
+     if protocol == "ucx":
+         test_func_name = "_test_protocol_ucx"
+     elif protocol == "ucxx":
+         test_func_name = "_test_protocol_ucxx"
+     else:
+         test_func_name = "_test_protocol_ucx_old"
+
+     # Create a pipe for communication between parent and child processes
+     parent_conn, child_conn = mp.Pipe()
+     p = mp.Process(
+         target=_run_test_with_output_capture, args=(test_func_name, child_conn)
+     )
+
+     p.start()
+     p.join(timeout=60)
+
+     if p.is_alive():
+         p.kill()
+         p.close()
+         raise TimeoutError("Test process timed out")
+
+     # Get the result from the child process
+     success, output = parent_conn.recv()
+
+     # Check that the test passed
+     assert success, f"Test failed in subprocess. Output:\n{output}"
+
+     # For the ucx protocol, check if warnings are printed when distributed_ucxx is not
+     # available
+     if protocol == "ucx" and not _has_distributed_ucxx():
+         # Check if the warning about protocol='ucx' is printed
+         print(f"Output for {protocol} protocol:\n{output}")
+         assert (
+             "you have requested protocol='ucx'" in output
+         ), f"Expected warning not found in output: {output}"
+         assert (
+             "'distributed-ucxx' is not installed" in output
+         ), f"Expected warning about distributed-ucxx not found in output: {output}"
+     elif protocol == "ucx" and _has_distributed_ucxx():
+         # When distributed_ucxx is available, the warning should NOT be printed
+         assert "you have requested protocol='ucx'" not in output, (
+             "Warning should not be printed when distributed_ucxx is available: "
+             f"{output}"
+         )
+     elif protocol == "ucx-old":
+         # The ucx-old protocol should not generate warnings
+         assert (
+             "you have requested protocol='ucx'" not in output
+         ), f"Warning should not be printed for ucx-old protocol: {output}"
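The new test module above exercises dask-cuda 25.8.0's UCX protocol selection: protocol="ucx" prefers distributed-ucxx when it is installed and otherwise falls back to the legacy UCX comm while warning that 'distributed-ucxx' is not installed, "ucxx" requires distributed-ucxx, and "ucx-old" always selects the legacy comm. A minimal usage sketch, assuming a host with at least one visible GPU and a working UCX installation (not part of the diff):

from distributed import Client

from dask_cuda import LocalCUDACluster

if __name__ == "__main__":
    # Assumes at least one visible GPU and UCX support on this host.
    # "ucx" resolves to distributed-ucxx when available, otherwise the legacy
    # UCX comm (with a warning); "ucx-old" forces the legacy comm; "ucxx"
    # requires distributed-ucxx to be installed.
    with LocalCUDACluster(protocol="ucx") as cluster, Client(cluster) as client:
        client.wait_for_workers(1)
        # The test above expects the scheduler address to use the ucx:// scheme.
        print(cluster.scheduler_comm.address)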
dask_cuda/tests/test_spill.py
@@ -1,14 +1,18 @@
+ # Copyright (c) 2025, NVIDIA CORPORATION.
+
  import gc
  import os
  from time import sleep
+ from typing import TypedDict

  import pytest

  import dask
  from dask import array as da
- from distributed import Client, wait
+ from distributed import Client, Worker, wait
  from distributed.metrics import time
  from distributed.sizeof import sizeof
+ from distributed.utils import Deadline
  from distributed.utils_test import gen_cluster, gen_test, loop  # noqa: F401

  import dask_cudf
@@ -16,6 +20,13 @@ import dask_cudf
  from dask_cuda import LocalCUDACluster, utils
  from dask_cuda.utils_test import IncreasedCloseTimeoutNanny

+ if not utils.has_device_memory_resource():
+     pytest.skip(
+         "No spilling tests supported for devices without memory resources. "
+         "See https://github.com/rapidsai/dask-cuda/issues/1510",
+         allow_module_level=True,
+     )
+
  if utils.get_device_total_memory() < 1e10:
      pytest.skip("Not enough GPU memory", allow_module_level=True)

@@ -72,24 +83,66 @@ def cudf_spill(request):


  def device_host_file_size_matches(
-     dhf, total_bytes, device_chunk_overhead=0, serialized_chunk_overhead=1024
+     dask_worker: Worker,
+     total_bytes,
+     device_chunk_overhead=0,
+     serialized_chunk_overhead=1024,
  ):
-     byte_sum = dhf.device_buffer.fast.total_weight
+     worker_data_sizes = collect_device_host_file_size(
+         dask_worker,
+         device_chunk_overhead=device_chunk_overhead,
+         serialized_chunk_overhead=serialized_chunk_overhead,
+     )
+     byte_sum = (
+         worker_data_sizes["device_fast"]
+         + worker_data_sizes["host_fast"]
+         + worker_data_sizes["host_buffer"]
+         + worker_data_sizes["disk"]
+     )
+     return (
+         byte_sum >= total_bytes
+         and byte_sum
+         <= total_bytes
+         + worker_data_sizes["device_overhead"]
+         + worker_data_sizes["host_overhead"]
+         + worker_data_sizes["disk_overhead"]
+     )
+
+
+ class WorkerDataSizes(TypedDict):
+     device_fast: int
+     host_fast: int
+     host_buffer: int
+     disk: int
+     device_overhead: int
+     host_overhead: int
+     disk_overhead: int
+

-     # `dhf.host_buffer.fast` is only available when Worker's `memory_limit != 0`
+ def collect_device_host_file_size(
+     dask_worker: Worker,
+     device_chunk_overhead: int,
+     serialized_chunk_overhead: int,
+ ) -> WorkerDataSizes:
+     dhf = dask_worker.data
+
+     device_fast = dhf.device_buffer.fast.total_weight or 0
      if hasattr(dhf.host_buffer, "fast"):
-         byte_sum += dhf.host_buffer.fast.total_weight
+         host_fast = dhf.host_buffer.fast.total_weight or 0
+         host_buffer = 0
      else:
-         byte_sum += sum([sizeof(b) for b in dhf.host_buffer.values()])
+         host_buffer = sum([sizeof(b) for b in dhf.host_buffer.values()])
+         host_fast = 0

-     # `dhf.disk` is only available when Worker's `memory_limit != 0`
      if dhf.disk is not None:
          file_path = [
              os.path.join(dhf.disk.directory, fname)
              for fname in dhf.disk.filenames.values()
          ]
          file_size = [os.path.getsize(f) for f in file_path]
-         byte_sum += sum(file_size)
+         disk = sum(file_size)
+     else:
+         disk = 0

      # Allow up to chunk_overhead bytes overhead per chunk
      device_overhead = len(dhf.device) * device_chunk_overhead
@@ -98,17 +151,25 @@ def device_host_file_size_matches(
          len(dhf.disk) * serialized_chunk_overhead if dhf.disk is not None else 0
      )

-     return (
-         byte_sum >= total_bytes
-         and byte_sum <= total_bytes + device_overhead + host_overhead + disk_overhead
+     return WorkerDataSizes(
+         device_fast=device_fast,
+         host_fast=host_fast,
+         host_buffer=host_buffer,
+         disk=disk,
+         device_overhead=device_overhead,
+         host_overhead=host_overhead,
+         disk_overhead=disk_overhead,
      )


  def assert_device_host_file_size(
-     dhf, total_bytes, device_chunk_overhead=0, serialized_chunk_overhead=1024
+     dask_worker: Worker,
+     total_bytes,
+     device_chunk_overhead=0,
+     serialized_chunk_overhead=1024,
  ):
      assert device_host_file_size_matches(
-         dhf, total_bytes, device_chunk_overhead, serialized_chunk_overhead
+         dask_worker, total_bytes, device_chunk_overhead, serialized_chunk_overhead
      )


@@ -119,7 +180,7 @@ def worker_assert(
      dask_worker=None,
  ):
      assert_device_host_file_size(
-         dask_worker.data, total_size, device_chunk_overhead, serialized_chunk_overhead
+         dask_worker, total_size, device_chunk_overhead, serialized_chunk_overhead
      )


@@ -131,12 +192,12 @@ def delayed_worker_assert(
  ):
      start = time()
      while not device_host_file_size_matches(
-         dask_worker.data, total_size, device_chunk_overhead, serialized_chunk_overhead
+         dask_worker, total_size, device_chunk_overhead, serialized_chunk_overhead
      ):
          sleep(0.01)
          if time() < start + 3:
              assert_device_host_file_size(
-                 dask_worker.data,
+                 dask_worker,
                  total_size,
                  device_chunk_overhead,
                  serialized_chunk_overhead,
@@ -224,8 +285,8 @@ async def test_cupy_cluster_device_spill(params):
                  x = rs.random(int(50e6), chunks=2e6)
                  await wait(x)

-                 xx = x.persist()
-                 await wait(xx)
+                 [xx] = client.persist([x])
+                 await xx

                  # Allow up to 1024 bytes overhead per chunk serialized
                  await client.run(
@@ -344,19 +405,38 @@ async def test_cudf_cluster_device_spill(params, cudf_spill):
                  sizes = sizes.to_arrow().to_pylist()
                  nbytes = sum(sizes)

-                 cdf2 = cdf.persist()
-                 await wait(cdf2)
+                 [cdf2] = client.persist([cdf])
+                 await cdf2

                  del cdf
                  gc.collect()

                  if enable_cudf_spill:
-                     await client.run(
-                         worker_assert,
-                         0,
-                         0,
-                         0,
+                     expected_data = WorkerDataSizes(
+                         device_fast=0,
+                         host_fast=0,
+                         host_buffer=0,
+                         disk=0,
+                         device_overhead=0,
+                         host_overhead=0,
+                         disk_overhead=0,
                      )
+
+                     deadline = Deadline.after(duration=3)
+                     while not deadline.expired:
+                         data = await client.run(
+                             collect_device_host_file_size,
+                             device_chunk_overhead=0,
+                             serialized_chunk_overhead=0,
+                         )
+                         expected = {k: expected_data for k in data}
+                         if data == expected:
+                             break
+                         sleep(0.01)
+
+                     # final assertion for pytest to reraise with a nice traceback
+                     assert data == expected
+
                  else:
                      await client.run(
                          assert_host_chunks,
@@ -419,8 +499,8 @@ async def test_cudf_spill_cluster(cudf_spill):
                  }
              )

-             ddf = dask_cudf.from_cudf(cdf, npartitions=2).sum().persist()
-             await wait(ddf)
+             [ddf] = client.persist([dask_cudf.from_cudf(cdf, npartitions=2).sum()])
+             await ddf

              await client.run(_assert_cudf_spill_stats, enable_cudf_spill)
              _assert_cudf_spill_stats(enable_cudf_spill)
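Beyond the helper refactor, the spill tests switch from the `xx = x.persist(); await wait(xx)` idiom to a collection-based `client.persist([...])` call followed by awaiting the persisted collection, and the cuDF-spill branch now polls worker state under a `distributed.utils.Deadline` instead of asserting once. A condensed sketch of the new persist idiom only, assuming an asynchronous client inside a `gen_test`-style coroutine as in the tests above (variable names are illustrative):

from dask import array as da
from distributed import Client


async def _persist_example(client: Client) -> None:
    # `client` is assumed to have been created with asynchronous=True.
    x = da.random.random((10_000, 10_000), chunks=(1_000, 1_000))

    # Old idiom (removed above):  xx = x.persist(); await wait(xx)
    # New idiom (added above): persist through the client, then await the
    # persisted collection directly.
    [xx] = client.persist([x])
    await xx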
dask_cuda/tests/test_utils.py
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  import os
  from unittest.mock import patch

@@ -15,11 +18,13 @@ from dask_cuda.utils import (
      get_n_gpus,
      get_preload_options,
      get_ucx_config,
+     has_device_memory_resource,
      nvml_device_index,
      parse_cuda_visible_device,
      parse_device_memory_limit,
      unpack_bitmask,
  )
+ from dask_cuda.utils_test import get_ucx_implementation


  @patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1,2"})
@@ -76,19 +81,19 @@ def test_get_device_total_memory():
      for i in range(get_n_gpus()):
          with cuda.gpus[i]:
              total_mem = get_device_total_memory(i)
-             assert type(total_mem) is int
-             assert total_mem > 0
+             if has_device_memory_resource():
+                 assert type(total_mem) is int
+                 assert total_mem > 0
+             else:
+                 assert total_mem is None


  @pytest.mark.parametrize(
      "protocol",
-     ["ucx", "ucxx"],
+     ["ucx", "ucx-old"],
  )
  def test_get_preload_options_default(protocol):
-     if protocol == "ucx":
-         pytest.importorskip("ucp")
-     elif protocol == "ucxx":
-         pytest.importorskip("ucxx")
+     get_ucx_implementation(protocol)

      opts = get_preload_options(
          protocol=protocol,
@@ -103,16 +108,13 @@ def test_get_preload_options_default(protocol):

  @pytest.mark.parametrize(
      "protocol",
-     ["ucx", "ucxx"],
+     ["ucx", "ucx-old"],
  )
  @pytest.mark.parametrize("enable_tcp", [True, False])
  @pytest.mark.parametrize("enable_infiniband", [True, False])
  @pytest.mark.parametrize("enable_nvlink", [True, False])
  def test_get_preload_options(protocol, enable_tcp, enable_infiniband, enable_nvlink):
-     if protocol == "ucx":
-         pytest.importorskip("ucp")
-     elif protocol == "ucxx":
-         pytest.importorskip("ucxx")
+     get_ucx_implementation(protocol)

      opts = get_preload_options(
          protocol=protocol,
@@ -135,11 +137,17 @@ def test_get_preload_options(protocol, enable_tcp, enable_infiniband, enable_nvl
          assert "--enable-nvlink" in opts["preload_argv"]


+ @pytest.mark.parametrize(
+     "protocol",
+     ["ucx", "ucx-old"],
+ )
  @pytest.mark.parametrize("enable_tcp_over_ucx", [True, False, None])
  @pytest.mark.parametrize("enable_nvlink", [True, False, None])
  @pytest.mark.parametrize("enable_infiniband", [True, False, None])
- def test_get_ucx_config(enable_tcp_over_ucx, enable_infiniband, enable_nvlink):
-     pytest.importorskip("ucp")
+ def test_get_ucx_config(
+     protocol, enable_tcp_over_ucx, enable_infiniband, enable_nvlink
+ ):
+     get_ucx_implementation(protocol)

      kwargs = {
          "enable_tcp_over_ucx": enable_tcp_over_ucx,
@@ -234,20 +242,98 @@ def test_parse_visible_devices():
          parse_cuda_visible_device([])


+ def test_parse_device_bytes():
+     total = get_device_total_memory(0)
+
+     assert parse_device_memory_limit(None) is None
+     assert parse_device_memory_limit(0) is None
+     assert parse_device_memory_limit("0") is None
+     assert parse_device_memory_limit("0.0") is None
+     assert parse_device_memory_limit("0 GiB") is None
+
+     assert parse_device_memory_limit(1) == 1
+     assert parse_device_memory_limit("1") == 1
+
+     assert parse_device_memory_limit(1000000000) == 1000000000
+     assert parse_device_memory_limit("1GB") == 1000000000
+
+     if has_device_memory_resource(0):
+         assert parse_device_memory_limit(1.0) == total
+         assert parse_device_memory_limit("1.0") == total
+
+         assert parse_device_memory_limit(0.8) == int(total * 0.8)
+         assert parse_device_memory_limit(0.8, alignment_size=256) == int(
+             total * 0.8 // 256 * 256
+         )
+
+         assert parse_device_memory_limit("default") == parse_device_memory_limit(0.8)
+     else:
+         assert parse_device_memory_limit("default") is None
+
+         with pytest.raises(ValueError):
+             assert parse_device_memory_limit(1.0) == total
+         with pytest.raises(ValueError):
+             assert parse_device_memory_limit("1.0") == total
+         with pytest.raises(ValueError):
+             assert parse_device_memory_limit(0.8) == int(total * 0.8)
+         with pytest.raises(ValueError):
+             assert parse_device_memory_limit(0.8, alignment_size=256) == int(
+                 total * 0.8 // 256 * 256
+             )
+
+
  def test_parse_device_memory_limit():
      total = get_device_total_memory(0)

-     assert parse_device_memory_limit(None) == total
-     assert parse_device_memory_limit(0) == total
+     assert parse_device_memory_limit(None) is None
+     assert parse_device_memory_limit(0) is None
+     assert parse_device_memory_limit("0") is None
+     assert parse_device_memory_limit(0.0) is None
+     assert parse_device_memory_limit("0 GiB") is None
+
+     assert parse_device_memory_limit(1) == 1
+     assert parse_device_memory_limit("1") == 1
+
      assert parse_device_memory_limit("auto") == total

-     assert parse_device_memory_limit(0.8) == int(total * 0.8)
-     assert parse_device_memory_limit(0.8, alignment_size=256) == int(
-         total * 0.8 // 256 * 256
-     )
      assert parse_device_memory_limit(1000000000) == 1000000000
      assert parse_device_memory_limit("1GB") == 1000000000

+     if has_device_memory_resource(0):
+         assert parse_device_memory_limit(1.0) == total
+         assert parse_device_memory_limit("1.0") == total
+
+         assert parse_device_memory_limit(0.8) == int(total * 0.8)
+         assert parse_device_memory_limit(0.8, alignment_size=256) == int(
+             total * 0.8 // 256 * 256
+         )
+         assert parse_device_memory_limit("default") == parse_device_memory_limit(0.8)
+     else:
+         assert parse_device_memory_limit("default") is None
+
+         with pytest.raises(ValueError):
+             assert parse_device_memory_limit(1.0) == total
+         with pytest.raises(ValueError):
+             assert parse_device_memory_limit("1.0") == total
+         with pytest.raises(ValueError):
+             assert parse_device_memory_limit(0.8) == int(total * 0.8)
+         with pytest.raises(ValueError):
+             assert parse_device_memory_limit(0.8, alignment_size=256) == int(
+                 total * 0.8 // 256 * 256
+             )
+
+
+ def test_has_device_memory_resoure():
+     has_memory_resource = has_device_memory_resource()
+     total = get_device_total_memory(0)
+
+     if has_memory_resource:
+         # Tested only in devices with a memory resource
+         assert total == parse_device_memory_limit("auto")
+     else:
+         # Tested only in devices without a memory resource
+         assert total is None
+

  def test_parse_visible_mig_devices():
      pynvml.nvmlInit()
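Together these tests pin down the 25.8.0 memory-limit semantics: parse_device_memory_limit returns None for None or zero-valued inputs, still resolves "auto" to the device's total memory, and accepts fractional limits (and "default") only on devices that expose a memory resource, raising ValueError otherwise; the new has_device_memory_resource helper reports whether that support exists. A hedged sketch of how a caller might branch on this, using only names that appear in the tests above and assuming device 0 is visible:

from dask_cuda.utils import has_device_memory_resource, parse_device_memory_limit

# Sketch based on the test expectations above; device 0 is assumed visible.
if has_device_memory_resource(0):
    # Fractions are resolved against the device's total memory.
    limit = parse_device_memory_limit(0.8)
else:
    # Without a device memory resource only absolute sizes are accepted
    # (None/0 disables the limit); fractional values raise ValueError.
    limit = parse_device_memory_limit("10GB")

print(limit)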
dask_cuda/tests/test_worker_spec.py
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  import pytest

  from distributed import Nanny
@@ -28,7 +31,7 @@ def _check_env_value(spec, k, v):
  @pytest.mark.parametrize("num_devices", [1, 4])
  @pytest.mark.parametrize("cls", [Nanny])
  @pytest.mark.parametrize("interface", [None, "eth0", "enp1s0f0"])
- @pytest.mark.parametrize("protocol", [None, "tcp", "ucx"])
+ @pytest.mark.parametrize("protocol", [None, "tcp", "ucx", "ucx-old"])
  @pytest.mark.parametrize("dashboard_address", [None, ":0", ":8787"])
  @pytest.mark.parametrize("threads_per_worker", [1, 8])
  @pytest.mark.parametrize("silence_logs", [False, True])
@@ -58,7 +61,7 @@ def test_worker_spec(
          enable_nvlink=enable_nvlink,
      )

-     if (enable_infiniband or enable_nvlink) and protocol != "ucx":
+     if (enable_infiniband or enable_nvlink) and protocol not in ("ucx", "ucx-old"):
          with pytest.raises(
              TypeError, match="Enabling InfiniBand or NVLink requires protocol='ucx'"
          ):