dask-cuda 25.4.0__py3-none-any.whl → 25.8.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
- dask_cuda/GIT_COMMIT +1 -1
- dask_cuda/VERSION +1 -1
- dask_cuda/_compat.py +18 -0
- dask_cuda/benchmarks/common.py +4 -1
- dask_cuda/benchmarks/local_cudf_groupby.py +4 -1
- dask_cuda/benchmarks/local_cudf_merge.py +5 -2
- dask_cuda/benchmarks/local_cudf_shuffle.py +5 -2
- dask_cuda/benchmarks/local_cupy.py +4 -1
- dask_cuda/benchmarks/local_cupy_map_overlap.py +4 -1
- dask_cuda/benchmarks/utils.py +7 -4
- dask_cuda/cli.py +21 -15
- dask_cuda/cuda_worker.py +27 -57
- dask_cuda/device_host_file.py +31 -15
- dask_cuda/disk_io.py +7 -4
- dask_cuda/explicit_comms/comms.py +11 -7
- dask_cuda/explicit_comms/dataframe/shuffle.py +147 -55
- dask_cuda/get_device_memory_objects.py +18 -3
- dask_cuda/initialize.py +80 -44
- dask_cuda/is_device_object.py +4 -1
- dask_cuda/is_spillable_object.py +4 -1
- dask_cuda/local_cuda_cluster.py +63 -66
- dask_cuda/plugins.py +17 -16
- dask_cuda/proxify_device_objects.py +15 -10
- dask_cuda/proxify_host_file.py +30 -27
- dask_cuda/proxy_object.py +20 -17
- dask_cuda/tests/conftest.py +41 -0
- dask_cuda/tests/test_dask_cuda_worker.py +114 -27
- dask_cuda/tests/test_dgx.py +10 -18
- dask_cuda/tests/test_explicit_comms.py +51 -18
- dask_cuda/tests/test_from_array.py +7 -5
- dask_cuda/tests/test_initialize.py +16 -37
- dask_cuda/tests/test_local_cuda_cluster.py +164 -54
- dask_cuda/tests/test_proxify_host_file.py +33 -4
- dask_cuda/tests/test_proxy.py +18 -16
- dask_cuda/tests/test_rdd_ucx.py +160 -0
- dask_cuda/tests/test_spill.py +107 -27
- dask_cuda/tests/test_utils.py +106 -20
- dask_cuda/tests/test_worker_spec.py +5 -2
- dask_cuda/utils.py +319 -68
- dask_cuda/utils_test.py +23 -7
- dask_cuda/worker_common.py +196 -0
- dask_cuda/worker_spec.py +12 -5
- {dask_cuda-25.4.0.dist-info → dask_cuda-25.8.0.dist-info}/METADATA +5 -4
- dask_cuda-25.8.0.dist-info/RECORD +63 -0
- {dask_cuda-25.4.0.dist-info → dask_cuda-25.8.0.dist-info}/WHEEL +1 -1
- dask_cuda-25.8.0.dist-info/top_level.txt +6 -0
- shared-actions/check_nightly_success/check-nightly-success/check.py +148 -0
- shared-actions/telemetry-impls/summarize/bump_time.py +54 -0
- shared-actions/telemetry-impls/summarize/send_trace.py +409 -0
- dask_cuda-25.4.0.dist-info/RECORD +0 -56
- dask_cuda-25.4.0.dist-info/top_level.txt +0 -5
- {dask_cuda-25.4.0.dist-info → dask_cuda-25.8.0.dist-info}/entry_points.txt +0 -0
- {dask_cuda-25.4.0.dist-info → dask_cuda-25.8.0.dist-info}/licenses/LICENSE +0 -0
dask_cuda/tests/test_rdd_ucx.py
ADDED
@@ -0,0 +1,160 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
+
+import importlib
+import io
+import multiprocessing as mp
+import sys
+
+import pytest
+
+from dask_cuda import LocalCUDACluster
+
+mp = mp.get_context("spawn")  # type: ignore
+
+
+def _has_distributed_ucxx() -> bool:
+    return bool(importlib.util.find_spec("distributed_ucxx"))
+
+
+def _test_protocol_ucx():
+    with LocalCUDACluster(protocol="ucx") as cluster:
+        assert cluster.scheduler_comm.address.startswith("ucx://")
+
+        if _has_distributed_ucxx():
+            import distributed_ucxx
+
+            assert all(
+                isinstance(batched_send.comm, distributed_ucxx.ucxx.UCXX)
+                for batched_send in cluster.scheduler.stream_comms.values()
+            )
+        else:
+            import rapids_dask_dependency
+
+            assert all(
+                isinstance(
+                    batched_send.comm,
+                    rapids_dask_dependency.patches.distributed.comm.__rdd_patch_ucx.UCX,
+                )
+                for batched_send in cluster.scheduler.stream_comms.values()
+            )
+
+
+def _test_protocol_ucxx():
+    if _has_distributed_ucxx():
+        with LocalCUDACluster(protocol="ucxx") as cluster:
+            assert cluster.scheduler_comm.address.startswith("ucxx://")
+            import distributed_ucxx
+
+            assert all(
+                isinstance(batched_send.comm, distributed_ucxx.ucxx.UCXX)
+                for batched_send in cluster.scheduler.stream_comms.values()
+            )
+    else:
+        with pytest.raises(RuntimeError, match="Cluster failed to start"):
+            LocalCUDACluster(protocol="ucxx")
+
+
+def _test_protocol_ucx_old():
+    with LocalCUDACluster(protocol="ucx-old") as cluster:
+        assert cluster.scheduler_comm.address.startswith("ucx-old://")
+
+        import rapids_dask_dependency
+
+        assert all(
+            isinstance(
+                batched_send.comm,
+                rapids_dask_dependency.patches.distributed.comm.__rdd_patch_ucx.UCX,
+            )
+            for batched_send in cluster.scheduler.stream_comms.values()
+        )
+
+
+def _run_test_with_output_capture(test_func_name, conn):
+    """Run a test function in a subprocess and capture stdout/stderr."""
+    # Redirect stdout and stderr to capture output
+    old_stdout = sys.stdout
+    old_stderr = sys.stderr
+    captured_output = io.StringIO()
+    sys.stdout = sys.stderr = captured_output
+
+    try:
+        # Import and run the test function
+        if test_func_name == "_test_protocol_ucx":
+            _test_protocol_ucx()
+        elif test_func_name == "_test_protocol_ucxx":
+            _test_protocol_ucxx()
+        elif test_func_name == "_test_protocol_ucx_old":
+            _test_protocol_ucx_old()
+        else:
+            raise ValueError(f"Unknown test function: {test_func_name}")
+
+        output = captured_output.getvalue()
+        conn.send((True, output))  # True = success
+    except Exception as e:
+        output = captured_output.getvalue()
+        output += f"\nException: {e}"
+        import traceback
+
+        output += f"\nTraceback:\n{traceback.format_exc()}"
+        conn.send((False, output))  # False = failure
+    finally:
+        # Restore original stdout/stderr
+        sys.stdout = old_stdout
+        sys.stderr = old_stderr
+        conn.close()
+
+
+@pytest.mark.parametrize("protocol", ["ucx", "ucxx", "ucx-old"])
+def test_rdd_protocol(protocol):
+    """Test rapids-dask-dependency protocol selection"""
+    if protocol == "ucx":
+        test_func_name = "_test_protocol_ucx"
+    elif protocol == "ucxx":
+        test_func_name = "_test_protocol_ucxx"
+    else:
+        test_func_name = "_test_protocol_ucx_old"
+
+    # Create a pipe for communication between parent and child processes
+    parent_conn, child_conn = mp.Pipe()
+    p = mp.Process(
+        target=_run_test_with_output_capture, args=(test_func_name, child_conn)
+    )
+
+    p.start()
+    p.join(timeout=60)
+
+    if p.is_alive():
+        p.kill()
+        p.close()
+        raise TimeoutError("Test process timed out")
+
+    # Get the result from the child process
+    success, output = parent_conn.recv()
+
+    # Check that the test passed
+    assert success, f"Test failed in subprocess. Output:\n{output}"
+
+    # For the ucx protocol, check if warnings are printed when distributed_ucxx is not
+    # available
+    if protocol == "ucx" and not _has_distributed_ucxx():
+        # Check if the warning about protocol='ucx' is printed
+        print(f"Output for {protocol} protocol:\n{output}")
+        assert (
+            "you have requested protocol='ucx'" in output
+        ), f"Expected warning not found in output: {output}"
+        assert (
+            "'distributed-ucxx' is not installed" in output
+        ), f"Expected warning about distributed-ucxx not found in output: {output}"
+    elif protocol == "ucx" and _has_distributed_ucxx():
+        # When distributed_ucxx is available, the warning should NOT be printed
+        assert "you have requested protocol='ucx'" not in output, (
+            "Warning should not be printed when distributed_ucxx is available: "
+            f"{output}"
+        )
+    elif protocol == "ucx-old":
+        # The ucx-old protocol should not generate warnings
+        assert (
+            "you have requested protocol='ucx'" not in output
+        ), f"Warning should not be printed for ucx-old protocol: {output}"
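The one non-obvious piece of machinery in this new module is how it isolates each cluster start in a spawned subprocess while still asserting on the child's stdout/stderr. A stripped-down, self-contained sketch of that pattern follows; the `_child` function and its `print` are illustrative stand-ins, not dask-cuda code, and only the `multiprocessing`/`io` APIs are real:

```python
# Minimal sketch of the spawn-subprocess/output-capture pattern used above.
import io
import multiprocessing as mp
import sys


def _child(conn):
    # Capture everything the child writes so the parent can assert on it.
    old_out, old_err = sys.stdout, sys.stderr
    captured = io.StringIO()
    sys.stdout = sys.stderr = captured
    try:
        print("work happens here")  # stand-in for the real test body
        conn.send((True, captured.getvalue()))
    except Exception as e:
        conn.send((False, captured.getvalue() + f"\nException: {e}"))
    finally:
        sys.stdout, sys.stderr = old_out, old_err
        conn.close()


if __name__ == "__main__":
    ctx = mp.get_context("spawn")  # fresh interpreter, no inherited CUDA state
    parent_conn, child_conn = ctx.Pipe()
    p = ctx.Process(target=_child, args=(child_conn,))
    p.start()
    p.join(timeout=60)
    success, output = parent_conn.recv()
    assert success, output
```

Running the check in a spawned child keeps each protocol's cluster startup (and any UCX initialization) from leaking into the parent pytest process, while the pipe carries both the pass/fail flag and the captured warnings back for assertion.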
dask_cuda/tests/test_spill.py
CHANGED
@@ -1,14 +1,18 @@
+# Copyright (c) 2025, NVIDIA CORPORATION.
+
 import gc
 import os
 from time import sleep
+from typing import TypedDict
 
 import pytest
 
 import dask
 from dask import array as da
-from distributed import Client, wait
+from distributed import Client, Worker, wait
 from distributed.metrics import time
 from distributed.sizeof import sizeof
+from distributed.utils import Deadline
 from distributed.utils_test import gen_cluster, gen_test, loop  # noqa: F401
 
 import dask_cudf
@@ -16,6 +20,13 @@ import dask_cudf
 from dask_cuda import LocalCUDACluster, utils
 from dask_cuda.utils_test import IncreasedCloseTimeoutNanny
 
+if not utils.has_device_memory_resource():
+    pytest.skip(
+        "No spilling tests supported for devices without memory resources. "
+        "See https://github.com/rapidsai/dask-cuda/issues/1510",
+        allow_module_level=True,
+    )
+
 if utils.get_device_total_memory() < 1e10:
     pytest.skip("Not enough GPU memory", allow_module_level=True)
 
@@ -72,24 +83,66 @@ def cudf_spill(request):
 
 
 def device_host_file_size_matches(
-
+    dask_worker: Worker,
+    total_bytes,
+    device_chunk_overhead=0,
+    serialized_chunk_overhead=1024,
 ):
-
+    worker_data_sizes = collect_device_host_file_size(
+        dask_worker,
+        device_chunk_overhead=device_chunk_overhead,
+        serialized_chunk_overhead=serialized_chunk_overhead,
+    )
+    byte_sum = (
+        worker_data_sizes["device_fast"]
+        + worker_data_sizes["host_fast"]
+        + worker_data_sizes["host_buffer"]
+        + worker_data_sizes["disk"]
+    )
+    return (
+        byte_sum >= total_bytes
+        and byte_sum
+        <= total_bytes
+        + worker_data_sizes["device_overhead"]
+        + worker_data_sizes["host_overhead"]
+        + worker_data_sizes["disk_overhead"]
+    )
+
+
+class WorkerDataSizes(TypedDict):
+    device_fast: int
+    host_fast: int
+    host_buffer: int
+    disk: int
+    device_overhead: int
+    host_overhead: int
+    disk_overhead: int
+
 
-
+def collect_device_host_file_size(
+    dask_worker: Worker,
+    device_chunk_overhead: int,
+    serialized_chunk_overhead: int,
+) -> WorkerDataSizes:
+    dhf = dask_worker.data
+
+    device_fast = dhf.device_buffer.fast.total_weight or 0
     if hasattr(dhf.host_buffer, "fast"):
-
+        host_fast = dhf.host_buffer.fast.total_weight or 0
+        host_buffer = 0
     else:
-
+        host_buffer = sum([sizeof(b) for b in dhf.host_buffer.values()])
+        host_fast = 0
 
-    # `dhf.disk` is only available when Worker's `memory_limit != 0`
     if dhf.disk is not None:
         file_path = [
             os.path.join(dhf.disk.directory, fname)
             for fname in dhf.disk.filenames.values()
         ]
         file_size = [os.path.getsize(f) for f in file_path]
-
+        disk = sum(file_size)
+    else:
+        disk = 0
 
     # Allow up to chunk_overhead bytes overhead per chunk
     device_overhead = len(dhf.device) * device_chunk_overhead
@@ -98,17 +151,25 @@ def device_host_file_size_matches(
         len(dhf.disk) * serialized_chunk_overhead if dhf.disk is not None else 0
     )
 
-    return (
-
-
+    return WorkerDataSizes(
+        device_fast=device_fast,
+        host_fast=host_fast,
+        host_buffer=host_buffer,
+        disk=disk,
+        device_overhead=device_overhead,
+        host_overhead=host_overhead,
+        disk_overhead=disk_overhead,
     )
 
 
 def assert_device_host_file_size(
-
+    dask_worker: Worker,
+    total_bytes,
+    device_chunk_overhead=0,
+    serialized_chunk_overhead=1024,
 ):
     assert device_host_file_size_matches(
-
+        dask_worker, total_bytes, device_chunk_overhead, serialized_chunk_overhead
    )
 
 
@@ -119,7 +180,7 @@ def worker_assert(
     dask_worker=None,
 ):
     assert_device_host_file_size(
-        dask_worker
+        dask_worker, total_size, device_chunk_overhead, serialized_chunk_overhead
     )
 
 
@@ -131,12 +192,12 @@ def delayed_worker_assert(
 ):
     start = time()
     while not device_host_file_size_matches(
-        dask_worker
+        dask_worker, total_size, device_chunk_overhead, serialized_chunk_overhead
     ):
         sleep(0.01)
         if time() < start + 3:
             assert_device_host_file_size(
-                dask_worker
+                dask_worker,
                 total_size,
                 device_chunk_overhead,
                 serialized_chunk_overhead,
@@ -224,8 +285,8 @@ async def test_cupy_cluster_device_spill(params):
     x = rs.random(int(50e6), chunks=2e6)
     await wait(x)
 
-    xx =
-    await
+    [xx] = client.persist([x])
+    await xx
 
     # Allow up to 1024 bytes overhead per chunk serialized
     await client.run(
@@ -344,19 +405,38 @@ async def test_cudf_cluster_device_spill(params, cudf_spill):
     sizes = sizes.to_arrow().to_pylist()
     nbytes = sum(sizes)
 
-    cdf2 =
-    await
+    [cdf2] = client.persist([cdf])
+    await cdf2
 
     del cdf
     gc.collect()
 
     if enable_cudf_spill:
-
-
-            0,
-            0,
-            0,
+        expected_data = WorkerDataSizes(
+            device_fast=0,
+            host_fast=0,
+            host_buffer=0,
+            disk=0,
+            device_overhead=0,
+            host_overhead=0,
+            disk_overhead=0,
         )
+
+        deadline = Deadline.after(duration=3)
+        while not deadline.expired:
+            data = await client.run(
+                collect_device_host_file_size,
+                device_chunk_overhead=0,
+                serialized_chunk_overhead=0,
+            )
+            expected = {k: expected_data for k in data}
+            if data == expected:
+                break
+            sleep(0.01)
+
+        # final assertion for pytest to reraise with a nice traceback
+        assert data == expected
+
     else:
         await client.run(
             assert_host_chunks,
@@ -419,8 +499,8 @@ async def test_cudf_spill_cluster(cudf_spill):
         }
     )
 
-    ddf = dask_cudf.from_cudf(cdf, npartitions=2).sum()
-    await
+    [ddf] = client.persist([dask_cudf.from_cudf(cdf, npartitions=2).sum()])
+    await ddf
 
     await client.run(_assert_cudf_spill_stats, enable_cudf_spill)
     _assert_cudf_spill_stats(enable_cudf_spill)
dask_cuda/tests/test_utils.py
CHANGED
@@ -1,3 +1,6 @@
+# SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
 import os
 from unittest.mock import patch
 
@@ -15,11 +18,13 @@ from dask_cuda.utils import (
     get_n_gpus,
     get_preload_options,
     get_ucx_config,
+    has_device_memory_resource,
     nvml_device_index,
     parse_cuda_visible_device,
     parse_device_memory_limit,
     unpack_bitmask,
 )
+from dask_cuda.utils_test import get_ucx_implementation
 
 
 @patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1,2"})
@@ -76,19 +81,19 @@ def test_get_device_total_memory():
     for i in range(get_n_gpus()):
         with cuda.gpus[i]:
             total_mem = get_device_total_memory(i)
-
-
+            if has_device_memory_resource():
+                assert type(total_mem) is int
+                assert total_mem > 0
+            else:
+                assert total_mem is None
 
 
 @pytest.mark.parametrize(
     "protocol",
-    ["ucx", "
+    ["ucx", "ucx-old"],
 )
 def test_get_preload_options_default(protocol):
-
-        pytest.importorskip("ucp")
-    elif protocol == "ucxx":
-        pytest.importorskip("ucxx")
+    get_ucx_implementation(protocol)
 
     opts = get_preload_options(
         protocol=protocol,
@@ -103,16 +108,13 @@ def test_get_preload_options_default(protocol):
 
 @pytest.mark.parametrize(
     "protocol",
-    ["ucx", "
+    ["ucx", "ucx-old"],
 )
 @pytest.mark.parametrize("enable_tcp", [True, False])
 @pytest.mark.parametrize("enable_infiniband", [True, False])
 @pytest.mark.parametrize("enable_nvlink", [True, False])
 def test_get_preload_options(protocol, enable_tcp, enable_infiniband, enable_nvlink):
-
-        pytest.importorskip("ucp")
-    elif protocol == "ucxx":
-        pytest.importorskip("ucxx")
+    get_ucx_implementation(protocol)
 
     opts = get_preload_options(
         protocol=protocol,
@@ -135,11 +137,17 @@ def test_get_preload_options(protocol, enable_tcp, enable_infiniband, enable_nvl
         assert "--enable-nvlink" in opts["preload_argv"]
 
 
+@pytest.mark.parametrize(
+    "protocol",
+    ["ucx", "ucx-old"],
+)
 @pytest.mark.parametrize("enable_tcp_over_ucx", [True, False, None])
 @pytest.mark.parametrize("enable_nvlink", [True, False, None])
 @pytest.mark.parametrize("enable_infiniband", [True, False, None])
-def test_get_ucx_config(
-
+def test_get_ucx_config(
+    protocol, enable_tcp_over_ucx, enable_infiniband, enable_nvlink
+):
+    get_ucx_implementation(protocol)
 
     kwargs = {
         "enable_tcp_over_ucx": enable_tcp_over_ucx,
@@ -234,20 +242,98 @@ def test_parse_visible_devices():
         parse_cuda_visible_device([])
 
 
+def test_parse_device_bytes():
+    total = get_device_total_memory(0)
+
+    assert parse_device_memory_limit(None) is None
+    assert parse_device_memory_limit(0) is None
+    assert parse_device_memory_limit("0") is None
+    assert parse_device_memory_limit("0.0") is None
+    assert parse_device_memory_limit("0 GiB") is None
+
+    assert parse_device_memory_limit(1) == 1
+    assert parse_device_memory_limit("1") == 1
+
+    assert parse_device_memory_limit(1000000000) == 1000000000
+    assert parse_device_memory_limit("1GB") == 1000000000
+
+    if has_device_memory_resource(0):
+        assert parse_device_memory_limit(1.0) == total
+        assert parse_device_memory_limit("1.0") == total
+
+        assert parse_device_memory_limit(0.8) == int(total * 0.8)
+        assert parse_device_memory_limit(0.8, alignment_size=256) == int(
+            total * 0.8 // 256 * 256
+        )
+
+        assert parse_device_memory_limit("default") == parse_device_memory_limit(0.8)
+    else:
+        assert parse_device_memory_limit("default") is None
+
+        with pytest.raises(ValueError):
+            assert parse_device_memory_limit(1.0) == total
+        with pytest.raises(ValueError):
+            assert parse_device_memory_limit("1.0") == total
+        with pytest.raises(ValueError):
+            assert parse_device_memory_limit(0.8) == int(total * 0.8)
+        with pytest.raises(ValueError):
+            assert parse_device_memory_limit(0.8, alignment_size=256) == int(
+                total * 0.8 // 256 * 256
+            )
+
+
 def test_parse_device_memory_limit():
     total = get_device_total_memory(0)
 
-    assert parse_device_memory_limit(None)
-    assert parse_device_memory_limit(0)
+    assert parse_device_memory_limit(None) is None
+    assert parse_device_memory_limit(0) is None
+    assert parse_device_memory_limit("0") is None
+    assert parse_device_memory_limit(0.0) is None
+    assert parse_device_memory_limit("0 GiB") is None
+
+    assert parse_device_memory_limit(1) == 1
+    assert parse_device_memory_limit("1") == 1
+
     assert parse_device_memory_limit("auto") == total
 
-    assert parse_device_memory_limit(0.8) == int(total * 0.8)
-    assert parse_device_memory_limit(0.8, alignment_size=256) == int(
-        total * 0.8 // 256 * 256
-    )
     assert parse_device_memory_limit(1000000000) == 1000000000
     assert parse_device_memory_limit("1GB") == 1000000000
 
+    if has_device_memory_resource(0):
+        assert parse_device_memory_limit(1.0) == total
+        assert parse_device_memory_limit("1.0") == total
+
+        assert parse_device_memory_limit(0.8) == int(total * 0.8)
+        assert parse_device_memory_limit(0.8, alignment_size=256) == int(
+            total * 0.8 // 256 * 256
+        )
+        assert parse_device_memory_limit("default") == parse_device_memory_limit(0.8)
+    else:
+        assert parse_device_memory_limit("default") is None
+
+        with pytest.raises(ValueError):
+            assert parse_device_memory_limit(1.0) == total
+        with pytest.raises(ValueError):
+            assert parse_device_memory_limit("1.0") == total
+        with pytest.raises(ValueError):
+            assert parse_device_memory_limit(0.8) == int(total * 0.8)
+        with pytest.raises(ValueError):
+            assert parse_device_memory_limit(0.8, alignment_size=256) == int(
+                total * 0.8 // 256 * 256
+            )
+
+
+def test_has_device_memory_resoure():
+    has_memory_resource = has_device_memory_resource()
+    total = get_device_total_memory(0)
+
+    if has_memory_resource:
+        # Tested only in devices with a memory resource
+        assert total == parse_device_memory_limit("auto")
+    else:
+        # Tested only in devices without a memory resource
+        assert total is None
+
 
 def test_parse_visible_mig_devices():
     pynvml.nvmlInit()
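Taken together, the assertions above pin down the 25.8 behavior of `parse_device_memory_limit`: `None` and any zero-valued input mean "no limit" and return `None`; plain integers and integer strings are byte counts; byte strings like `"1GB"` parse as bytes; and fractional values scale the device total, raising `ValueError` on devices without a memory resource, where `"default"` also resolves to `None`. A rough, illustration-only reimplementation of just that contract (this is not the dask_cuda code; `total_memory` is a hypothetical stand-in for the device-total lookup, with `None` meaning no memory resource):

```python
# Illustration-only sketch of the limit-parsing contract the tests encode;
# NOT the dask_cuda implementation.
from dask.utils import parse_bytes


def parse_limit_sketch(limit, total_memory, alignment_size=1):
    if limit is None:
        return None
    if limit == "auto":
        return total_memory
    if limit == "default":
        if total_memory is None:
            return None
        limit = 0.8
    if isinstance(limit, str):
        try:
            limit = int(limit)  # "1" -> 1 byte
        except ValueError:
            try:
                limit = float(limit)  # "1.0", "0.8" -> fraction of total
            except ValueError:
                limit = parse_bytes(limit)  # "1GB", "0 GiB" -> bytes
    if limit == 0:
        return None
    if isinstance(limit, float) and limit <= 1:
        if total_memory is None:
            raise ValueError("fractional limit requires a device memory resource")
        limit = total_memory * limit
    return int(limit) // alignment_size * alignment_size


# e.g. parse_limit_sketch(0.8, 10_000_000_000, alignment_size=256)
# matches int(10_000_000_000 * 0.8 // 256 * 256) from the tests above.
```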
dask_cuda/tests/test_worker_spec.py
CHANGED
@@ -1,3 +1,6 @@
+# SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
 import pytest
 
 from distributed import Nanny
@@ -28,7 +31,7 @@ def _check_env_value(spec, k, v):
 @pytest.mark.parametrize("num_devices", [1, 4])
 @pytest.mark.parametrize("cls", [Nanny])
 @pytest.mark.parametrize("interface", [None, "eth0", "enp1s0f0"])
-@pytest.mark.parametrize("protocol", [None, "tcp", "ucx"])
+@pytest.mark.parametrize("protocol", [None, "tcp", "ucx", "ucx-old"])
 @pytest.mark.parametrize("dashboard_address", [None, ":0", ":8787"])
 @pytest.mark.parametrize("threads_per_worker", [1, 8])
 @pytest.mark.parametrize("silence_logs", [False, True])
@@ -58,7 +61,7 @@ def test_worker_spec(
         enable_nvlink=enable_nvlink,
     )
 
-    if (enable_infiniband or enable_nvlink) and protocol
+    if (enable_infiniband or enable_nvlink) and protocol not in ("ucx", "ucx-old"):
         with pytest.raises(
             TypeError, match="Enabling InfiniBand or NVLink requires protocol='ucx'"
         ):