dask-cuda 25.6.0-py3-none-any.whl → 25.10.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. dask_cuda/GIT_COMMIT +1 -1
  2. dask_cuda/VERSION +1 -1
  3. dask_cuda/benchmarks/common.py +4 -1
  4. dask_cuda/benchmarks/local_cudf_groupby.py +3 -0
  5. dask_cuda/benchmarks/local_cudf_merge.py +4 -1
  6. dask_cuda/benchmarks/local_cudf_shuffle.py +4 -1
  7. dask_cuda/benchmarks/local_cupy.py +3 -0
  8. dask_cuda/benchmarks/local_cupy_map_overlap.py +3 -0
  9. dask_cuda/benchmarks/utils.py +6 -3
  10. dask_cuda/cli.py +21 -15
  11. dask_cuda/cuda_worker.py +28 -58
  12. dask_cuda/device_host_file.py +31 -15
  13. dask_cuda/disk_io.py +7 -4
  14. dask_cuda/explicit_comms/comms.py +11 -7
  15. dask_cuda/explicit_comms/dataframe/shuffle.py +23 -23
  16. dask_cuda/get_device_memory_objects.py +4 -7
  17. dask_cuda/initialize.py +149 -94
  18. dask_cuda/local_cuda_cluster.py +52 -70
  19. dask_cuda/plugins.py +17 -16
  20. dask_cuda/proxify_device_objects.py +12 -10
  21. dask_cuda/proxify_host_file.py +30 -27
  22. dask_cuda/proxy_object.py +20 -17
  23. dask_cuda/tests/conftest.py +41 -0
  24. dask_cuda/tests/test_cudf_builtin_spilling.py +3 -1
  25. dask_cuda/tests/test_dask_cuda_worker.py +109 -25
  26. dask_cuda/tests/test_dask_setup.py +193 -0
  27. dask_cuda/tests/test_dgx.py +20 -44
  28. dask_cuda/tests/test_explicit_comms.py +31 -12
  29. dask_cuda/tests/test_from_array.py +4 -6
  30. dask_cuda/tests/test_initialize.py +233 -65
  31. dask_cuda/tests/test_local_cuda_cluster.py +129 -68
  32. dask_cuda/tests/test_proxify_host_file.py +28 -7
  33. dask_cuda/tests/test_proxy.py +15 -13
  34. dask_cuda/tests/test_spill.py +10 -3
  35. dask_cuda/tests/test_utils.py +100 -29
  36. dask_cuda/tests/test_worker_spec.py +6 -0
  37. dask_cuda/utils.py +211 -42
  38. dask_cuda/utils_test.py +10 -7
  39. dask_cuda/worker_common.py +196 -0
  40. dask_cuda/worker_spec.py +6 -1
  41. {dask_cuda-25.6.0.dist-info → dask_cuda-25.10.0.dist-info}/METADATA +11 -4
  42. dask_cuda-25.10.0.dist-info/RECORD +63 -0
  43. dask_cuda-25.10.0.dist-info/top_level.txt +6 -0
  44. shared-actions/check_nightly_success/check-nightly-success/check.py +148 -0
  45. shared-actions/telemetry-impls/summarize/bump_time.py +54 -0
  46. shared-actions/telemetry-impls/summarize/send_trace.py +409 -0
  47. dask_cuda-25.6.0.dist-info/RECORD +0 -57
  48. dask_cuda-25.6.0.dist-info/top_level.txt +0 -4
  49. {dask_cuda-25.6.0.dist-info → dask_cuda-25.10.0.dist-info}/WHEEL +0 -0
  50. {dask_cuda-25.6.0.dist-info → dask_cuda-25.10.0.dist-info}/entry_points.txt +0 -0
  51. {dask_cuda-25.6.0.dist-info → dask_cuda-25.10.0.dist-info}/licenses/LICENSE +0 -0
dask_cuda/tests/test_dask_setup.py
@@ -0,0 +1,193 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
+ import json
+ import os
+ import time
+ from contextlib import contextmanager
+ from unittest.mock import Mock, patch
+
+ import pytest
+
+ from distributed import Client
+ from distributed.utils import open_port
+ from distributed.utils_test import popen
+
+ from dask_cuda.initialize import dask_setup
+ from dask_cuda.utils import wait_workers
+
+
+ def test_dask_setup_function_with_mock_worker():
+     """Test the dask_setup function directly with a mock worker."""
+     # Create a mock worker object
+     mock_worker = Mock()
+
+     with patch("dask_cuda.initialize._create_cuda_context") as mock_create_context:
+         # Test with create_cuda_context=True
+         # Call the underlying function directly (the Click decorator wraps the real
+         # function)
+         dask_setup.callback(
+             worker=mock_worker,
+             create_cuda_context=True,
+         )
+
+         mock_create_context.assert_called_once_with()
+
+         mock_create_context.reset_mock()
+
+         # Test with create_cuda_context=False
+         dask_setup.callback(
+             worker=mock_worker,
+             create_cuda_context=False,
+         )
+
+         mock_create_context.assert_not_called()
+
+
+ @contextmanager
+ def start_dask_scheduler(protocol: str, max_attempts: int = 5, timeout: int = 10):
+     """Start a Dask scheduler in a subprocess.
+
+     Attempts to start a Dask scheduler in a subprocess; if the port is not
+     available, retries on a different port, up to a maximum of `max_attempts`
+     attempts. The stdout and stderr of the process are read to determine whether
+     the scheduler failed to bind to the port or succeeded, and no more than
+     `timeout` seconds are awaited between reads.
+
+     This is primarily useful because UCX does not release TCP ports immediately.
+     A workaround that removes the need for this function is setting
+     `UCX_TCP_CM_REUSEADDR=y`, but that has to be set explicitly when running the
+     tests, which is not very friendly.
+
+     Parameters
+     ----------
+     protocol: str
+         Communication protocol to use.
+     max_attempts: int
+         Maximum attempts to try to open scheduler.
+     timeout: int
+         Time to wait while reading stdout/stderr of subprocess.
+     """
+     port = open_port()
+     for _ in range(max_attempts):
+         with popen(
+             [
+                 "dask",
+                 "scheduler",
+                 "--no-dashboard",
+                 "--protocol",
+                 protocol,
+                 "--port",
+                 str(port),
+             ],
+             capture_output=True,  # Capture stdout and stderr
+         ) as scheduler_process:
+             # Check if the scheduler process started successfully by streaming output
+             try:
+                 start_time = time.monotonic()
+                 while True:
+                     if time.monotonic() - start_time > timeout:
+                         raise TimeoutError("Timeout while waiting for scheduler output")
+
+                     line = scheduler_process.stdout.readline()
+                     if not line:
+                         break  # End of output
+                     print(
+                         line.decode(), end=""
+                     )  # Since capture_output=True, print the line here
+                     if b"Scheduler at:" in line:
+                         # Scheduler is now listening
+                         break
+                     elif b"UCXXBusyError" in line:
+                         raise Exception("UCXXBusyError detected in scheduler output")
+             except Exception:
+                 port += 1
+             else:
+                 yield scheduler_process, port
+                 return
+     else:
+         pytest.fail(f"Failed to start dask scheduler after {max_attempts} attempts.")
+
+
+ @pytest.mark.timeout(30)
+ @patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"})
+ @pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
+ def test_dask_cuda_worker_cli_integration(protocol, tmp_path):
+     """Test that the dask cuda worker CLI correctly passes arguments to dask_setup.
+
+     Verifies the end-to-end integration where the CLI tool actually launches and
+     calls dask_setup with the correct args.
+     """
+
+     # Use pytest's tmp_path for file management
+     capture_file_path = tmp_path / "dask_setup_integration_test.json"
+     preload_file = tmp_path / "preload_capture.py"
+
+     # Write the preload script to tmp_path
+     preload_file.write_text(
+         f'''
+ import json
+ import os
+
+ def capture_dask_setup_call(worker, create_cuda_context):
+     """Capture dask_setup arguments and write them to a file."""
+     result = {{
+         'worker_protocol': getattr(worker, '_protocol', 'unknown'),
+         'create_cuda_context': create_cuda_context,
+         'test_success': True
+     }}
+
+     # Write immediately to ensure it gets captured
+     with open(r"{capture_file_path}", 'w') as f:
+         json.dump(result, f)
+
+ # Patch dask_setup callback
+ from dask_cuda.initialize import dask_setup
+ dask_setup.callback = capture_dask_setup_call
+ '''
+     )
+
+     with start_dask_scheduler(protocol=protocol) as scheduler_process_port:
+         scheduler_process, scheduler_port = scheduler_process_port
+         sched_addr = f"{protocol}://127.0.0.1:{scheduler_port}"
+         print(f"{sched_addr=}", flush=True)
+
+         # Build dask cuda worker args
+         dask_cuda_worker_args = [
+             "dask",
+             "cuda",
+             "worker",
+             sched_addr,
+             "--host",
+             "127.0.0.1",
+             "--no-dashboard",
+             "--preload",
+             str(preload_file),
+             "--death-timeout",
+             "10",
+         ]
+
+         with popen(dask_cuda_worker_args):
+             # Wait and check for worker connection
+             with Client(sched_addr) as client:
+                 assert wait_workers(client, n_gpus=1)
+
+             # Check if dask_setup was called and captured correctly
+             if capture_file_path.exists():
+                 with open(capture_file_path, "r") as cf:
+                     captured_args = json.load(cf)
+
+                 # Verify the critical arguments were passed correctly
+                 assert (
+                     captured_args["create_cuda_context"] is True
+                 ), "create_cuda_context should be True"
+
+                 # Verify worker has a protocol set
+                 assert (
+                     captured_args["worker_protocol"] == protocol
+                 ), "Worker should have a protocol"
+             else:
+                 pytest.fail(
+                     "capture file not found: dask_setup was not called or "
+                     "failed to write to file"
+                 )
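
The integration test above leans on Dask's preload mechanism: `dask cuda worker --preload <script>` executes the script inside the worker process before startup, which is what lets the test replace `dask_setup.callback`. A minimal sketch of the same pattern outside pytest, assuming a scheduler is already running; the script name, scheduler address, and print-based probe are illustrative, not part of the package:

    # preload_probe.py -- hypothetical preload script
    from dask_cuda.initialize import dask_setup

    # dask_setup is a Click command; .callback is the wrapped function.
    _original = dask_setup.callback

    def probe(worker, create_cuda_context):
        # Observe the arguments the worker passes in, then defer to the original.
        print(f"dask_setup called with {create_cuda_context=}")
        return _original(worker=worker, create_cuda_context=create_cuda_context)

    dask_setup.callback = probe

Launching `dask cuda worker tcp://127.0.0.1:8786 --preload preload_probe.py` would then print the probe line once per worker when dask-cuda's initialization preload fires.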
dask_cuda/tests/test_dgx.py
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  import multiprocessing as mp
  import os
  from enum import Enum, auto
@@ -15,10 +18,6 @@ mp = mp.get_context("spawn")  # type: ignore
  psutil = pytest.importorskip("psutil")


- def _is_ucx_116(ucp):
-     return ucp.get_ucx_version()[:2] == (1, 16)
-
-
  class DGXVersion(Enum):
      DGX_1 = auto()
      DGX_2 = auto()
@@ -77,20 +76,17 @@ def test_default():
      assert not p.exitcode


- def _test_tcp_over_ucx(protocol):
-     if protocol == "ucx":
-         ucp = pytest.importorskip("ucp")
-     elif protocol == "ucxx":
-         ucp = pytest.importorskip("ucxx")
+ def _test_tcp_over_ucx():
+     ucxx = pytest.importorskip("ucxx")

-     with LocalCUDACluster(protocol=protocol, enable_tcp_over_ucx=True) as cluster:
+     with LocalCUDACluster(protocol="ucx", enable_tcp_over_ucx=True) as cluster:
          with Client(cluster) as client:
              res = da.from_array(numpy.arange(10000), chunks=(1000,))
              res = res.sum().compute()
              assert res == 49995000

              def check_ucx_options():
-                 conf = ucp.get_config()
+                 conf = ucxx.get_config()
                  assert "TLS" in conf
                  assert "tcp" in conf["TLS"]
                  assert "cuda_copy" in conf["TLS"]
@@ -100,19 +96,10 @@ def _test_tcp_over_ucx(protocol):
              assert all(client.run(check_ucx_options).values())


- @pytest.mark.parametrize(
-     "protocol",
-     ["ucx", "ucxx"],
- )
- def test_tcp_over_ucx(protocol):
-     if protocol == "ucx":
-         ucp = pytest.importorskip("ucp")
-     elif protocol == "ucxx":
-         ucp = pytest.importorskip("ucxx")
-     if _is_ucx_116(ucp):
-         pytest.skip("https://github.com/rapidsai/ucx-py/issues/1037")
-
-     p = mp.Process(target=_test_tcp_over_ucx, args=(protocol,))
+ def test_tcp_over_ucx():
+     pytest.importorskip("distributed_ucxx")
+
+     p = mp.Process(target=_test_tcp_over_ucx)
      p.start()
      p.join()
      assert not p.exitcode
@@ -134,25 +121,22 @@ def test_tcp_only():


  def _test_ucx_infiniband_nvlink(
-     skip_queue, protocol, enable_infiniband, enable_nvlink, enable_rdmacm
+     skip_queue, enable_infiniband, enable_nvlink, enable_rdmacm
  ):
+     ucxx = pytest.importorskip("ucxx")
      cupy = pytest.importorskip("cupy")
-     if protocol == "ucx":
-         ucp = pytest.importorskip("ucp")
-     elif protocol == "ucxx":
-         ucp = pytest.importorskip("ucxx")

      if enable_infiniband and not any(
-         [at.startswith("rc") for at in ucp.get_active_transports()]
+         [at.startswith("rc") for at in ucxx.get_active_transports()]
      ):
          skip_queue.put("No support available for 'rc' transport in UCX")
          return
      else:
          skip_queue.put("ok")

-     # `ucp.get_active_transports()` call above initializes UCX, we must reset it
+     # `ucxx.get_active_transports()` call above initializes UCX, we must reset it
      # so that Dask doesn't try to initialize it again and raise an exception.
-     ucp.reset()
+     ucxx.reset()

      if enable_infiniband is None and enable_nvlink is None and enable_rdmacm is None:
          enable_tcp_over_ucx = None
@@ -168,7 +152,6 @@ def _test_ucx_infiniband_nvlink(
          cm_tls_priority = ["tcp"]

      initialize(
-         protocol=protocol,
          enable_tcp_over_ucx=enable_tcp_over_ucx,
          enable_infiniband=enable_infiniband,
          enable_nvlink=enable_nvlink,
@@ -176,7 +159,7 @@ def _test_ucx_infiniband_nvlink(
      )

      with LocalCUDACluster(
-         protocol=protocol,
+         protocol="ucx",
          interface="ib0",
          enable_tcp_over_ucx=enable_tcp_over_ucx,
          enable_infiniband=enable_infiniband,
@@ -190,7 +173,7 @@ def _test_ucx_infiniband_nvlink(
              assert res == 49995000

              def check_ucx_options():
-                 conf = ucp.get_config()
+                 conf = ucxx.get_config()
                  assert "TLS" in conf
                  assert all(t in conf["TLS"] for t in cm_tls)
                  assert all(p in conf["SOCKADDR_TLS_PRIORITY"] for p in cm_tls_priority)
@@ -206,7 +189,6 @@ def _test_ucx_infiniband_nvlink(
              assert all(client.run(check_ucx_options).values())


- @pytest.mark.parametrize("protocol", ["ucx", "ucxx"])
  @pytest.mark.parametrize(
      "params",
      [
@@ -221,13 +203,8 @@ def _test_ucx_infiniband_nvlink(
      _get_dgx_version() == DGXVersion.DGX_A100,
      reason="Automatic InfiniBand device detection Unsupported for %s" % _get_dgx_name(),
  )
- def test_ucx_infiniband_nvlink(protocol, params):
-     if protocol == "ucx":
-         ucp = pytest.importorskip("ucp")
-     elif protocol == "ucxx":
-         ucp = pytest.importorskip("ucxx")
-     if _is_ucx_116(ucp) and params["enable_infiniband"] is False:
-         pytest.skip("https://github.com/rapidsai/ucx-py/issues/1037")
+ def test_ucx_infiniband_nvlink(params):
+     pytest.importorskip("distributed_ucxx")

      skip_queue = mp.Queue()

@@ -235,7 +212,6 @@ def test_ucx_infiniband_nvlink(protocol, params):
          target=_test_ucx_infiniband_nvlink,
          args=(
              skip_queue,
-             protocol,
              params["enable_infiniband"],
              params["enable_nvlink"],
              params["enable_rdmacm"],
dask_cuda/tests/test_explicit_comms.py
@@ -1,4 +1,5 @@
- # Copyright (c) 2021-2025 NVIDIA CORPORATION.
+ # SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0

  import asyncio
  import multiprocessing as mp
@@ -28,7 +29,6 @@ from dask_cuda.explicit_comms.dataframe.shuffle import (
  from dask_cuda.utils_test import IncreasedCloseTimeoutNanny

  mp = mp.get_context("spawn")  # type: ignore
- ucp = pytest.importorskip("ucp")


  # Notice, all of the following tests are executed in a new process such
@@ -53,8 +53,11 @@ def _test_local_cluster(protocol):
      assert sum(c.run(my_rank, 0)) == sum(range(4))


- @pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
+ @pytest.mark.parametrize("protocol", ["tcp", "ucx"])
  def test_local_cluster(protocol):
+     if protocol.startswith("ucx"):
+         pytest.importorskip("distributed_ucxx")
+
      p = mp.Process(target=_test_local_cluster, args=(protocol,))
      p.start()
      p.join()
@@ -97,7 +100,7 @@ def test_dataframe_merge_empty_partitions():


  def check_partitions(df, npartitions):
-     """Check that all values in `df` hashes to the same"""
+     """Check that all values in ``df`` hashes to the same"""
      dtypes = {}
      for col, dtype in df.dtypes.items():
          if pd.api.types.is_numeric_dtype(dtype):
@@ -199,11 +202,13 @@ def _test_dataframe_shuffle(backend, protocol, n_workers, _partitions):

  @pytest.mark.parametrize("nworkers", [1, 2, 3])
  @pytest.mark.parametrize("backend", ["pandas", "cudf"])
- @pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
+ @pytest.mark.parametrize("protocol", ["tcp", "ucx"])
  @pytest.mark.parametrize("_partitions", [True, False])
  def test_dataframe_shuffle(backend, protocol, nworkers, _partitions):
      if backend == "cudf":
          pytest.importorskip("cudf")
+     if protocol.startswith("ucx"):
+         pytest.importorskip("distributed_ucxx")

      p = mp.Process(
          target=_test_dataframe_shuffle, args=(backend, protocol, nworkers, _partitions)
@@ -320,10 +325,13 @@ def _test_dataframe_shuffle_merge(backend, protocol, n_workers):

  @pytest.mark.parametrize("nworkers", [1, 2, 4])
  @pytest.mark.parametrize("backend", ["pandas", "cudf"])
- @pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
+ @pytest.mark.parametrize("protocol", ["tcp", "ucx"])
  def test_dataframe_shuffle_merge(backend, protocol, nworkers):
      if backend == "cudf":
          pytest.importorskip("cudf")
+     if protocol.startswith("ucx"):
+         pytest.importorskip("distributed_ucxx")
+
      p = mp.Process(
          target=_test_dataframe_shuffle_merge, args=(backend, protocol, nworkers)
      )
@@ -357,9 +365,14 @@ def _test_jit_unspill(protocol):
      assert_eq(got, expected)


- @pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
+ @pytest.mark.parametrize("protocol", ["tcp", "ucx"])
+ @pytest.mark.skip_if_no_device_memory(
+     "JIT-Unspill not supported in devices without dedicated memory resource"
+ )
  def test_jit_unspill(protocol):
      pytest.importorskip("cudf")
+     if protocol.startswith("ucx"):
+         pytest.importorskip("distributed_ucxx")

      p = mp.Process(target=_test_jit_unspill, args=(protocol,))
      p.start()
@@ -384,7 +397,7 @@ def _test_lock_workers(scheduler_address, ranks):

  def test_lock_workers():
      """
-     Testing `run(...,lock_workers=True)` by spawning 30 runs with overlapping
+     Testing ``run(...,lock_workers=True)`` by spawning 30 runs with overlapping
      and non-overlapping worker sets.
      """
      try:
@@ -423,7 +436,9 @@ def test_create_destroy_create():
      with LocalCluster(n_workers=1) as cluster:
          with Client(cluster) as client:
              context = comms.default_comms()
-             scheduler_addresses_old = list(client.scheduler_info()["workers"].keys())
+             scheduler_addresses_old = list(
+                 client.scheduler_info(n_workers=-1)["workers"].keys()
+             )
              comms_addresses_old = list(comms.default_comms().worker_addresses)
              assert comms.default_comms() is context
              assert len(comms._comms_cache) == 1
@@ -444,7 +459,9 @@ def test_create_destroy_create():
      # because we referenced the old cluster's addresses.
      with LocalCluster(n_workers=1) as cluster:
          with Client(cluster) as client:
-             scheduler_addresses_new = list(client.scheduler_info()["workers"].keys())
+             scheduler_addresses_new = list(
+                 client.scheduler_info(n_workers=-1)["workers"].keys()
+             )
              comms_addresses_new = list(comms.default_comms().worker_addresses)

              assert scheduler_addresses_new == comms_addresses_new
@@ -485,7 +502,8 @@ def test_scaled_cluster_gets_new_comms_context():
                  "n_workers": 2,
              }
              expected_1 = {
-                 k: expected_values for k in client.scheduler_info()["workers"]
+                 k: expected_values
+                 for k in client.scheduler_info(n_workers=-1)["workers"]
              }
              assert result_1 == expected_1

@@ -513,7 +531,8 @@ def test_scaled_cluster_gets_new_comms_context():
                  "n_workers": 3,
              }
              expected_2 = {
-                 k: expected_values for k in client.scheduler_info()["workers"]
+                 k: expected_values
+                 for k in client.scheduler_info(n_workers=-1)["workers"]
              }
              assert result_2 == expected_2

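The other recurring change in this file is `client.scheduler_info(n_workers=-1)`: newer `distributed` releases cap how many workers `scheduler_info()` reports by default, which is presumably why the tests now pass `n_workers=-1` to enumerate them all. A short sketch of the pattern; the cluster size is illustrative:

    from distributed import Client, LocalCluster

    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        # n_workers=-1 requests every worker; the default may truncate the
        # mapping on recent distributed versions.
        workers = client.scheduler_info(n_workers=-1)["workers"]
        assert len(workers) == 2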
dask_cuda/tests/test_from_array.py
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  import pytest

  import dask.array as da
@@ -8,13 +11,8 @@ from dask_cuda import LocalCUDACluster
  cupy = pytest.importorskip("cupy")


- @pytest.mark.parametrize("protocol", ["ucx", "ucxx", "tcp"])
+ @pytest.mark.parametrize("protocol", ["ucx", "tcp"])
  def test_ucx_from_array(protocol):
-     if protocol == "ucx":
-         pytest.importorskip("ucp")
-     elif protocol == "ucxx":
-         pytest.importorskip("ucxx")
-
      N = 10_000
      with LocalCUDACluster(protocol=protocol) as cluster:
          with Client(cluster):
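
With the `ucp`/`ucxx` import branches deleted, the UCX case presumably relies on the same `distributed_ucxx` guard the other test modules adopt in this release. A sketch of that pattern (the guard is an assumption; this hunk only shows the deletion):

    import dask.array as da
    import numpy as np
    import pytest
    from distributed import Client

    from dask_cuda import LocalCUDACluster

    @pytest.mark.parametrize("protocol", ["ucx", "tcp"])
    def test_from_array_sketch(protocol):
        if protocol == "ucx":
            # Assumed guard, mirroring sibling tests in this release.
            pytest.importorskip("distributed_ucxx")
        with LocalCUDACluster(protocol=protocol) as cluster, Client(cluster):
            x = da.from_array(np.arange(10_000), chunks=(1_000,))
            assert x.sum().compute() == 49_995_000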