dask-cuda 25.6.0__py3-none-any.whl → 25.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dask_cuda/GIT_COMMIT +1 -1
- dask_cuda/VERSION +1 -1
- dask_cuda/benchmarks/common.py +4 -1
- dask_cuda/benchmarks/local_cudf_groupby.py +3 -0
- dask_cuda/benchmarks/local_cudf_merge.py +4 -1
- dask_cuda/benchmarks/local_cudf_shuffle.py +4 -1
- dask_cuda/benchmarks/local_cupy.py +3 -0
- dask_cuda/benchmarks/local_cupy_map_overlap.py +3 -0
- dask_cuda/benchmarks/utils.py +6 -3
- dask_cuda/cli.py +21 -15
- dask_cuda/cuda_worker.py +28 -58
- dask_cuda/device_host_file.py +31 -15
- dask_cuda/disk_io.py +7 -4
- dask_cuda/explicit_comms/comms.py +11 -7
- dask_cuda/explicit_comms/dataframe/shuffle.py +23 -23
- dask_cuda/get_device_memory_objects.py +4 -7
- dask_cuda/initialize.py +149 -94
- dask_cuda/local_cuda_cluster.py +52 -70
- dask_cuda/plugins.py +17 -16
- dask_cuda/proxify_device_objects.py +12 -10
- dask_cuda/proxify_host_file.py +30 -27
- dask_cuda/proxy_object.py +20 -17
- dask_cuda/tests/conftest.py +41 -0
- dask_cuda/tests/test_cudf_builtin_spilling.py +3 -1
- dask_cuda/tests/test_dask_cuda_worker.py +109 -25
- dask_cuda/tests/test_dask_setup.py +193 -0
- dask_cuda/tests/test_dgx.py +20 -44
- dask_cuda/tests/test_explicit_comms.py +31 -12
- dask_cuda/tests/test_from_array.py +4 -6
- dask_cuda/tests/test_initialize.py +233 -65
- dask_cuda/tests/test_local_cuda_cluster.py +129 -68
- dask_cuda/tests/test_proxify_host_file.py +28 -7
- dask_cuda/tests/test_proxy.py +15 -13
- dask_cuda/tests/test_spill.py +10 -3
- dask_cuda/tests/test_utils.py +100 -29
- dask_cuda/tests/test_worker_spec.py +6 -0
- dask_cuda/utils.py +211 -42
- dask_cuda/utils_test.py +10 -7
- dask_cuda/worker_common.py +196 -0
- dask_cuda/worker_spec.py +6 -1
- {dask_cuda-25.6.0.dist-info → dask_cuda-25.10.0.dist-info}/METADATA +11 -4
- dask_cuda-25.10.0.dist-info/RECORD +63 -0
- dask_cuda-25.10.0.dist-info/top_level.txt +6 -0
- shared-actions/check_nightly_success/check-nightly-success/check.py +148 -0
- shared-actions/telemetry-impls/summarize/bump_time.py +54 -0
- shared-actions/telemetry-impls/summarize/send_trace.py +409 -0
- dask_cuda-25.6.0.dist-info/RECORD +0 -57
- dask_cuda-25.6.0.dist-info/top_level.txt +0 -4
- {dask_cuda-25.6.0.dist-info → dask_cuda-25.10.0.dist-info}/WHEEL +0 -0
- {dask_cuda-25.6.0.dist-info → dask_cuda-25.10.0.dist-info}/entry_points.txt +0 -0
- {dask_cuda-25.6.0.dist-info → dask_cuda-25.10.0.dist-info}/licenses/LICENSE +0 -0
--- a/dask_cuda/tests/test_initialize.py
+++ b/dask_cuda/tests/test_initialize.py
@@ -1,6 +1,15 @@
+# SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
 import multiprocessing as mp
+import os
+import shutil
+import subprocess
 import sys
+import tempfile
+import textwrap
 
+import cuda.core.experimental
 import numpy
 import psutil
 import pytest
@@ -21,22 +30,19 @@ mp = mp.get_context("spawn")  # type: ignore
 # of UCX before retrieving the current config.
 
 
-def _test_initialize_ucx_tcp(protocol):
-    if protocol == "ucx":
-        ucp = pytest.importorskip("ucp")
-    elif protocol == "ucxx":
-        ucp = pytest.importorskip("ucxx")
+def _test_initialize_ucx_tcp():
+    ucxx = pytest.importorskip("ucxx")
 
     kwargs = {"enable_tcp_over_ucx": True}
-    initialize(protocol=protocol, **kwargs)
+    initialize(**kwargs)
     with LocalCluster(
-        protocol=protocol,
+        protocol="ucx",
         dashboard_address=None,
         n_workers=1,
         threads_per_worker=1,
         processes=True,
         worker_class=IncreasedCloseTimeoutNanny,
-        config={"distributed.comm.ucx": get_ucx_config(**kwargs)},
+        config={"distributed-ucxx": get_ucx_config(**kwargs)},
     ) as cluster:
         with Client(cluster) as client:
             res = da.from_array(numpy.arange(10000), chunks=(1000,))
@@ -44,7 +50,7 @@ def _test_initialize_ucx_tcp(protocol):
             assert res == 49995000
 
             def check_ucx_options():
-                conf = ucp.get_config()
+                conf = ucxx.get_config()
                 assert "TLS" in conf
                 assert "tcp" in conf["TLS"]
                 assert "cuda_copy" in conf["TLS"]
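The hunks above (and the nvlink, InfiniBand, and "all" variants that follow) all drop the `protocol` parameter in favor of a fixed `"ucx"` protocol backed by `ucxx`: `initialize()` is no longer passed `protocol=`, the options built by `get_ucx_config()` are now placed under a `distributed-ucxx` config key, and the test entry points gate on `distributed_ucxx` being importable. A minimal standalone sketch of the updated pattern, mirroring only what these hunks show (the import locations for `initialize` and `get_ucx_config` are assumed, since the test file's import block is not part of this diff):

import numpy
import dask.array as da
from dask.distributed import Client, LocalCluster

# Import paths assumed; the hunks above do not show the file's import block.
from dask_cuda.initialize import initialize
from dask_cuda.utils import get_ucx_config

kwargs = {"enable_tcp_over_ucx": True}
initialize(**kwargs)  # no `protocol=` argument in the updated usage

with LocalCluster(
    protocol="ucx",  # protocol is now always "ucx" rather than parametrized
    n_workers=1,
    threads_per_worker=1,
    processes=True,
    dashboard_address=None,
    config={"distributed-ucxx": get_ucx_config(**kwargs)},  # new config namespace
) as cluster:
    with Client(cluster) as client:
        res = da.from_array(numpy.arange(10000), chunks=(1000,)).sum().compute()
        assert res == 49995000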
@@ -55,35 +61,28 @@ def _test_initialize_ucx_tcp(protocol):
             assert all(client.run(check_ucx_options).values())
 
 
-@pytest.mark.parametrize("protocol", ["ucx", "ucxx"])
-def test_initialize_ucx_tcp(protocol):
-    if protocol == "ucx":
-        pytest.importorskip("ucp")
-    elif protocol == "ucxx":
-        pytest.importorskip("ucxx")
+def test_initialize_ucx_tcp():
+    pytest.importorskip("distributed_ucxx")
 
-    p = mp.Process(target=_test_initialize_ucx_tcp, args=(protocol,))
+    p = mp.Process(target=_test_initialize_ucx_tcp)
     p.start()
     p.join()
     assert not p.exitcode
 
 
-def _test_initialize_ucx_nvlink(protocol):
-    if protocol == "ucx":
-        ucp = pytest.importorskip("ucp")
-    elif protocol == "ucxx":
-        ucp = pytest.importorskip("ucxx")
+def _test_initialize_ucx_nvlink():
+    ucxx = pytest.importorskip("ucxx")
 
     kwargs = {"enable_nvlink": True}
-    initialize(protocol=protocol, **kwargs)
+    initialize(**kwargs)
     with LocalCluster(
-        protocol=protocol,
+        protocol="ucx",
         dashboard_address=None,
         n_workers=1,
         threads_per_worker=1,
         processes=True,
         worker_class=IncreasedCloseTimeoutNanny,
-        config={"distributed.comm.ucx": get_ucx_config(**kwargs)},
+        config={"distributed-ucxx": get_ucx_config(**kwargs)},
     ) as cluster:
         with Client(cluster) as client:
             res = da.from_array(numpy.arange(10000), chunks=(1000,))
@@ -91,7 +90,7 @@ def _test_initialize_ucx_nvlink(protocol):
             assert res == 49995000
 
             def check_ucx_options():
-                conf = ucp.get_config()
+                conf = ucxx.get_config()
                 assert "TLS" in conf
                 assert "cuda_ipc" in conf["TLS"]
                 assert "tcp" in conf["TLS"]
@@ -103,35 +102,28 @@ def _test_initialize_ucx_nvlink(protocol):
             assert all(client.run(check_ucx_options).values())
 
 
-@pytest.mark.parametrize("protocol", ["ucx", "ucxx"])
-def test_initialize_ucx_nvlink(protocol):
-    if protocol == "ucx":
-        pytest.importorskip("ucp")
-    elif protocol == "ucxx":
-        pytest.importorskip("ucxx")
+def test_initialize_ucx_nvlink():
+    pytest.importorskip("distributed_ucxx")
 
-    p = mp.Process(target=_test_initialize_ucx_nvlink, args=(protocol,))
+    p = mp.Process(target=_test_initialize_ucx_nvlink)
     p.start()
     p.join()
     assert not p.exitcode
 
 
-def _test_initialize_ucx_infiniband(protocol):
-    if protocol == "ucx":
-        ucp = pytest.importorskip("ucp")
-    elif protocol == "ucxx":
-        ucp = pytest.importorskip("ucxx")
+def _test_initialize_ucx_infiniband():
+    ucxx = pytest.importorskip("ucxx")
 
     kwargs = {"enable_infiniband": True}
-    initialize(protocol=protocol, **kwargs)
+    initialize(**kwargs)
     with LocalCluster(
-        protocol=protocol,
+        protocol="ucx",
         dashboard_address=None,
         n_workers=1,
         threads_per_worker=1,
         processes=True,
         worker_class=IncreasedCloseTimeoutNanny,
-        config={"distributed.comm.ucx": get_ucx_config(**kwargs)},
+        config={"distributed-ucxx": get_ucx_config(**kwargs)},
     ) as cluster:
         with Client(cluster) as client:
             res = da.from_array(numpy.arange(10000), chunks=(1000,))
@@ -139,7 +131,7 @@ def _test_initialize_ucx_infiniband(protocol):
             assert res == 49995000
 
             def check_ucx_options():
-                conf = ucp.get_config()
+                conf = ucxx.get_config()
                 assert "TLS" in conf
                 assert "rc" in conf["TLS"]
                 assert "tcp" in conf["TLS"]
@@ -154,34 +146,27 @@ def _test_initialize_ucx_infiniband(protocol):
 @pytest.mark.skipif(
     "ib0" not in psutil.net_if_addrs(), reason="Infiniband interface ib0 not found"
 )
-@pytest.mark.parametrize("protocol", ["ucx", "ucxx"])
-def test_initialize_ucx_infiniband(protocol):
-    if protocol == "ucx":
-        pytest.importorskip("ucp")
-    elif protocol == "ucxx":
-        pytest.importorskip("ucxx")
-
-    p = mp.Process(target=_test_initialize_ucx_infiniband, args=(protocol,))
+def test_initialize_ucx_infiniband():
+    pytest.importorskip("distributed_ucxx")
+
+    p = mp.Process(target=_test_initialize_ucx_infiniband)
     p.start()
     p.join()
     assert not p.exitcode
 
 
-def _test_initialize_ucx_all(protocol):
-    if protocol == "ucx":
-        ucp = pytest.importorskip("ucp")
-    elif protocol == "ucxx":
-        ucp = pytest.importorskip("ucxx")
+def _test_initialize_ucx_all():
+    ucxx = pytest.importorskip("ucxx")
 
-    initialize(protocol=protocol)
+    initialize()
     with LocalCluster(
-        protocol=protocol,
+        protocol="ucx",
         dashboard_address=None,
         n_workers=1,
         threads_per_worker=1,
         processes=True,
         worker_class=IncreasedCloseTimeoutNanny,
-        config={"distributed.comm.ucx": get_ucx_config()},
+        config={"distributed-ucxx": get_ucx_config()},
     ) as cluster:
         with Client(cluster) as client:
             res = da.from_array(numpy.arange(10000), chunks=(1000,))
@@ -189,7 +174,7 @@ def _test_initialize_ucx_all(protocol):
             assert res == 49995000
 
             def check_ucx_options():
-                conf = ucp.get_config()
+                conf = ucxx.get_config()
                 assert "TLS" in conf
                 assert conf["TLS"] == "all"
                 assert all(
@@ -204,14 +189,10 @@ def _test_initialize_ucx_all(protocol):
             assert all(client.run(check_ucx_options).values())
 
 
-@pytest.mark.parametrize("protocol", ["ucx", "ucxx"])
-def test_initialize_ucx_all(protocol):
-    if protocol == "ucx":
-        pytest.importorskip("ucp")
-    elif protocol == "ucxx":
-        pytest.importorskip("ucxx")
+def test_initialize_ucx_all():
+    pytest.importorskip("distributed_ucxx")
 
-    p = mp.Process(target=_test_initialize_ucx_all, args=(protocol,))
+    p = mp.Process(target=_test_initialize_ucx_all)
     p.start()
     p.join()
     assert not p.exitcode
@@ -250,3 +231,190 @@ def test_dask_cuda_import():
     p.start()
     p.join()
     assert not p.exitcode
+
+
+def _test_cuda_context_warning_with_subprocess_warnings(protocol):
+    """Test CUDA context warnings from both parent and worker subprocesses.
+
+    This test creates a standalone script that imports a problematic library
+    and creates LocalCUDACluster with processes=True. This should generate
+    warnings from both the parent process and each worker subprocess, since
+    they all inherit the CUDA context created at import time.
+    """
+    # Create temporary directory for our test files
+    temp_dir = tempfile.mkdtemp()
+
+    # Create the problematic library that creates CUDA context at import
+    problematic_library_code = textwrap.dedent(
+        """
+        # Problematic library that creates CUDA context at import time
+        import os
+
+        import cuda.core.experimental
+
+        try:
+            # Create CUDA context at import time, this will be inherited by subprocesses
+            cuda.core.experimental.Device().set_current()
+            print("Problematic library: Created CUDA context at import time")
+            os.environ['SUBPROCESS_CUDA_CONTEXT_CREATED'] = '1'
+        except Exception as e:
+            raise RuntimeError(
+                f"Problematic library: Failed to create CUDA context({e})"
+            )
+            os.environ['SUBPROCESS_CUDA_CONTEXT_CREATED'] = '0'
+        """
+    )
+
+    problematic_lib_path = os.path.join(temp_dir, "problematic_cuda_library.py")
+    with open(problematic_lib_path, "w") as f:
+        f.write(problematic_library_code)
+
+    # Create the main test script that imports the problematic library
+    # and creates LocalCUDACluster - this will run in a subprocess
+    main_script_code = textwrap.dedent(
+        f"""
+        # Main script that demonstrates the real-world problematic scenario
+        import os
+        import sys
+        import logging
+
+        # Add the temp directory to path so we can import our problematic library
+        sys.path.insert(0, '{temp_dir}')
+
+        print("=== Starting subprocess warnings test ===")
+
+        # This is the key part: import the problematic library BEFORE creating
+        # LocalCUDACluster. This creates a CUDA context that will be inherited
+        # by all worker subprocesses
+        print("Importing problematic library...")
+        import problematic_cuda_library
+
+        context_mode = os.environ.get('SUBPROCESS_CUDA_CONTEXT_CREATED', None)
+        if context_mode == "1":
+            print(f"Context creation successful")
+        else:
+            raise RuntimeError("Context creation failed")
+
+        if __name__ == "__main__":
+            try:
+                from dask_cuda import LocalCUDACluster
+                from dask_cuda.utils_test import IncreasedCloseTimeoutNanny
+
+                cluster = LocalCUDACluster(
+                    dashboard_address=None,
+                    worker_class=IncreasedCloseTimeoutNanny,
+                    protocol=f"{protocol}",
+                )
+                print("LocalCUDACluster created successfully!")
+
+                cluster.close()
+                print("Cluster closed successfully")
+
+            except Exception as e:
+                raise RuntimeError(f"Cluster setup error: {{e}}")
+
+        print("=== Subprocess warnings test completed ===")
+        """
+    )
+
+    main_script_path = os.path.join(temp_dir, "test_subprocess_warnings.py")
+    with open(main_script_path, "w") as f:
+        f.write(main_script_code)
+
+    try:
+        # Run the main script in a subprocess
+        result = subprocess.run(
+            [sys.executable, main_script_path],
+            capture_output=True,
+            text=True,
+            timeout=30,  # Reduced timeout for simpler test
+            cwd=os.getcwd(),
+        )
+
+        # Check for successful test execution regardless of warnings
+        assert (
+            "Context creation successful" in result.stdout
+        ), "Test did not create a CUDA context"
+        assert (
+            "Creating LocalCUDACluster" in result.stdout
+            or "LocalCUDACluster created successfully" in result.stdout
+        ), "LocalCUDACluster was not created"
+
+        # Check the log file for warnings from multiple processes
+        warnings_found = []
+        warnings_assigned_device_found = []
+
+        # Look for CUDA context warnings from different processes
+        lines = result.stderr.split("\n")
+        for line in lines:
+            if "A CUDA context for device" in line and "already exists" in line:
+                warnings_found.append(line)
+            if (
+                "should have a CUDA context assigned to device" in line
+                and "but instead the CUDA context is on device" in line
+            ):
+                warnings_assigned_device_found.append(line)
+
+        num_devices = cuda.core.experimental.system.num_devices
+
+        # Every worker raises the warning once. With protocol="ucx" the warning is
+        # raised once more by the parent process.
+        expected_warnings = num_devices if protocol == "tcp" else num_devices + 1
+        assert len(warnings_found) == expected_warnings, (
+            f"Expected {expected_warnings} CUDA context warnings, "
+            f"but found {len(warnings_assigned_device_found)}"
+        )
+
+        # Can only be tested in multi-GPU test environment, device 0 can never raise
+        # this warning (because it's where all CUDA contexts are created), thus one
+        # warning is raised by every device except 0.
+        expected_assigned_device_warnings = num_devices - 1
+        assert (
+            len(warnings_assigned_device_found) == expected_assigned_device_warnings
+        ), (
+            f"Expected {expected_assigned_device_warnings} warnings assigned to "
+            f"device, but found {len(warnings_assigned_device_found)}"
+        )
+
+        # Verify warnings contents
+        for warning in warnings_found:
+            assert (
+                "This is often the result of a CUDA-enabled library calling a "
+                "CUDA runtime function before Dask-CUDA" in warning
+            ), f"Warning missing explanatory text: {warning}"
+        for warning in warnings_assigned_device_found:
+            assert (
+                "This is often the result of a CUDA-enabled library calling a "
+                "CUDA runtime function before Dask-CUDA" in warning
+            ), f"Warning missing explanatory text: {warning}"
+    finally:
+        # Clean up temporary files
+        try:
+            if os.path.exists(temp_dir):
+                shutil.rmtree(temp_dir)
+        except Exception as e:
+            print(f"Cleanup error: {e}")
+
+
+@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
+def test_cuda_context_warning_with_subprocess_warnings(protocol):
+    """Test CUDA context warnings from parent and worker subprocesses.
+
+    This test creates a standalone script that imports a problematic library at the top
+    level and then creates LocalCUDACluster with processes=True. This replicates the
+    exact real-world scenario where:
+
+    1. User imports a problematic library that creates CUDA context at import time
+    2. User creates LocalCUDACluster with multiple workers
+    3. Each worker subprocess inherits the CUDA context and emits warnings
+    4. Multiple warnings are generated (parent process + each worker subprocess)
+
+    This is the ultimate test as it demonstrates the distributed warning scenario
+    that users actually encounter in production.
+    """
+    p = mp.Process(
+        target=_test_cuda_context_warning_with_subprocess_warnings, args=(protocol,)
+    )
+    p.start()
+    p.join()
+    assert not p.exitcode