dask-cuda 23.12.0a231026__py3-none-any.whl → 24.2.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (37)
  1. dask_cuda/VERSION +1 -0
  2. dask_cuda/__init__.py +1 -3
  3. dask_cuda/_version.py +20 -0
  4. dask_cuda/benchmarks/local_cudf_groupby.py +1 -1
  5. dask_cuda/benchmarks/local_cudf_merge.py +1 -1
  6. dask_cuda/benchmarks/local_cudf_shuffle.py +1 -1
  7. dask_cuda/benchmarks/local_cupy.py +1 -1
  8. dask_cuda/benchmarks/local_cupy_map_overlap.py +1 -1
  9. dask_cuda/benchmarks/utils.py +1 -1
  10. dask_cuda/cuda_worker.py +1 -3
  11. dask_cuda/device_host_file.py +1 -1
  12. dask_cuda/explicit_comms/dataframe/shuffle.py +1 -1
  13. dask_cuda/get_device_memory_objects.py +4 -0
  14. dask_cuda/initialize.py +47 -16
  15. dask_cuda/local_cuda_cluster.py +19 -19
  16. dask_cuda/plugins.py +122 -0
  17. dask_cuda/tests/test_dask_cuda_worker.py +3 -3
  18. dask_cuda/tests/test_dgx.py +49 -17
  19. dask_cuda/tests/test_explicit_comms.py +34 -6
  20. dask_cuda/tests/test_from_array.py +6 -2
  21. dask_cuda/tests/test_initialize.py +69 -21
  22. dask_cuda/tests/test_local_cuda_cluster.py +47 -14
  23. dask_cuda/tests/test_proxify_host_file.py +19 -4
  24. dask_cuda/tests/test_proxy.py +14 -3
  25. dask_cuda/tests/test_spill.py +3 -0
  26. dask_cuda/tests/test_utils.py +20 -6
  27. dask_cuda/utils.py +6 -140
  28. dask_cuda/utils_test.py +45 -0
  29. dask_cuda/worker_spec.py +2 -1
  30. {dask_cuda-23.12.0a231026.dist-info → dask_cuda-24.2.0.dist-info}/METADATA +11 -6
  31. dask_cuda-24.2.0.dist-info/RECORD +53 -0
  32. {dask_cuda-23.12.0a231026.dist-info → dask_cuda-24.2.0.dist-info}/WHEEL +1 -1
  33. dask_cuda/compat.py +0 -118
  34. dask_cuda-23.12.0a231026.dist-info/RECORD +0 -50
  35. {dask_cuda-23.12.0a231026.dist-info → dask_cuda-24.2.0.dist-info}/LICENSE +0 -0
  36. {dask_cuda-23.12.0a231026.dist-info → dask_cuda-24.2.0.dist-info}/entry_points.txt +0 -0
  37. {dask_cuda-23.12.0a231026.dist-info → dask_cuda-24.2.0.dist-info}/top_level.txt +0 -0

dask_cuda/tests/test_explicit_comms.py
@@ -1,6 +1,9 @@
 import asyncio
 import multiprocessing as mp
 import os
+import signal
+import time
+from functools import partial
 from unittest.mock import patch

 import numpy as np
@@ -17,7 +20,7 @@ from distributed.deploy.local import LocalCluster
 import dask_cuda
 from dask_cuda.explicit_comms import comms
 from dask_cuda.explicit_comms.dataframe.shuffle import shuffle as explicit_comms_shuffle
-from dask_cuda.local_cuda_cluster import IncreasedCloseTimeoutNanny
+from dask_cuda.utils_test import IncreasedCloseTimeoutNanny

 mp = mp.get_context("spawn")  # type: ignore
 ucp = pytest.importorskip("ucp")
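
Note: IncreasedCloseTimeoutNanny moved out of dask_cuda/local_cuda_cluster.py into the new test-helper module dask_cuda/utils_test.py (added in this release, +45 lines), so it no longer ships as part of the library's runtime modules. Its body is not shown in this diff; a minimal sketch of such a Nanny subclass, assuming it simply grants the worker subprocess a longer shutdown grace period, could look like:

# Hypothetical sketch -- the real class lives in dask_cuda/utils_test.py
# and its exact body is not part of this diff.
from distributed import Nanny


class IncreasedCloseTimeoutNanny(Nanny):
    """Nanny whose worker subprocess gets extra time to shut down."""

    async def close(self, timeout=5.0, **kwargs):
        # Doubling the grace period reduces spurious teardown failures
        # on heavily loaded CI machines.
        return await super().close(timeout=timeout * 2, **kwargs)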
@@ -44,7 +47,7 @@ def _test_local_cluster(protocol):
             assert sum(c.run(my_rank, 0)) == sum(range(4))


-@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
+@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
 def test_local_cluster(protocol):
     p = mp.Process(target=_test_local_cluster, args=(protocol,))
     p.start()
@@ -160,7 +163,7 @@ def _test_dataframe_shuffle(backend, protocol, n_workers, _partitions):

 @pytest.mark.parametrize("nworkers", [1, 2, 3])
 @pytest.mark.parametrize("backend", ["pandas", "cudf"])
-@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
+@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
 @pytest.mark.parametrize("_partitions", [True, False])
 def test_dataframe_shuffle(backend, protocol, nworkers, _partitions):
     if backend == "cudf":
@@ -175,7 +178,7 @@ def test_dataframe_shuffle(backend, protocol, nworkers, _partitions):


 @pytest.mark.parametrize("in_cluster", [True, False])
-def test_dask_use_explicit_comms(in_cluster):
+def _test_dask_use_explicit_comms(in_cluster):
     def check_shuffle():
         """Check if shuffle use explicit-comms by search for keys named
         'explicit-comms-shuffle'
@@ -217,6 +220,31 @@ def test_dask_use_explicit_comms(in_cluster):
        check_shuffle()


+@pytest.mark.parametrize("in_cluster", [True, False])
+def test_dask_use_explicit_comms(in_cluster):
+    def _timeout(process, function, timeout):
+        if process.is_alive():
+            function()
+        timeout = time.time() + timeout
+        while process.is_alive() and time.time() < timeout:
+            time.sleep(0.1)
+
+    p = mp.Process(target=_test_dask_use_explicit_comms, args=(in_cluster,))
+    p.start()
+
+    # Timeout before killing process
+    _timeout(p, lambda: None, 60.0)
+
+    # Send SIGINT (i.e., KeyboardInterrupt) hoping we get a stack trace.
+    _timeout(p, partial(p._popen._send_signal, signal.SIGINT), 3.0)
+
+    # SIGINT didn't work, kill process.
+    _timeout(p, p.kill, 3.0)
+
+    assert not p.is_alive()
+    assert p.exitcode == 0
+
+
 def _test_dataframe_shuffle_merge(backend, protocol, n_workers):
     if backend == "cudf":
         cudf = pytest.importorskip("cudf")
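
The wrapper added above runs the actual test in a spawned subprocess and escalates in three stages: wait up to 60 s for a clean exit, send SIGINT so a hung test prints a traceback, then kill as a last resort. A self-contained sketch of the same escalation pattern, with a trivial placeholder child standing in for the test body:

# Standalone sketch of the escalation pattern above; `_child` is a
# placeholder, not part of dask-cuda.
import multiprocessing as mp
import signal
import time
from functools import partial


def _child():
    time.sleep(1.0)  # pretend to do work, then exit cleanly


def _timeout(process, function, timeout):
    if process.is_alive():
        function()
    deadline = time.time() + timeout
    while process.is_alive() and time.time() < deadline:
        time.sleep(0.1)


if __name__ == "__main__":
    p = mp.get_context("spawn").Process(target=_child)
    p.start()
    _timeout(p, lambda: None, 60.0)  # grace period for a clean exit
    # `Process._popen._send_signal` is private multiprocessing API,
    # mirrored here from the diff above.
    _timeout(p, partial(p._popen._send_signal, signal.SIGINT), 3.0)
    _timeout(p, p.kill, 3.0)  # last resort
    assert not p.is_alive() and p.exitcode == 0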
@@ -256,7 +284,7 @@ def _test_dataframe_shuffle_merge(backend, protocol, n_workers):

 @pytest.mark.parametrize("nworkers", [1, 2, 4])
 @pytest.mark.parametrize("backend", ["pandas", "cudf"])
-@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
+@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
 def test_dataframe_shuffle_merge(backend, protocol, nworkers):
     if backend == "cudf":
         pytest.importorskip("cudf")
@@ -293,7 +321,7 @@ def _test_jit_unspill(protocol):
     assert_eq(got, expected)


-@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
+@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
 def test_jit_unspill(protocol):
     pytest.importorskip("cudf")


dask_cuda/tests/test_from_array.py
@@ -5,12 +5,16 @@ from distributed import Client

 from dask_cuda import LocalCUDACluster

-pytest.importorskip("ucp")
 cupy = pytest.importorskip("cupy")


-@pytest.mark.parametrize("protocol", ["ucx", "tcp"])
+@pytest.mark.parametrize("protocol", ["ucx", "ucxx", "tcp"])
 def test_ucx_from_array(protocol):
+    if protocol == "ucx":
+        pytest.importorskip("ucp")
+    elif protocol == "ucxx":
+        pytest.importorskip("ucxx")
+
     N = 10_000
     with LocalCUDACluster(protocol=protocol) as cluster:
         with Client(cluster):
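
The same if/elif importorskip pair recurs in nearly every test below. A hypothetical helper (not present in dask-cuda) that captures the pattern:

# Hypothetical convenience helper, shown only to summarize the skip
# logic these tests repeat inline; it does not exist in the package.
import pytest


def skip_without_protocol_backend(protocol):
    """Skip the calling test if the backend for `protocol` is missing."""
    if protocol == "ucx":
        pytest.importorskip("ucp")  # UCX-Py backend
    elif protocol == "ucxx":
        pytest.importorskip("ucxx")  # UCXX backend
    # "tcp" requires no optional dependency.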

dask_cuda/tests/test_initialize.py
@@ -10,9 +10,9 @@ from distributed.deploy.local import LocalCluster

 from dask_cuda.initialize import initialize
 from dask_cuda.utils import get_ucx_config
+from dask_cuda.utils_test import IncreasedCloseTimeoutNanny

 mp = mp.get_context("spawn")  # type: ignore
-ucp = pytest.importorskip("ucp")

 # Notice, all of the following tests is executed in a new process such
 # that UCX options of the different tests doesn't conflict.
@@ -20,15 +20,21 @@ ucp = pytest.importorskip("ucp")
 # of UCX before retrieving the current config.


-def _test_initialize_ucx_tcp():
+def _test_initialize_ucx_tcp(protocol):
+    if protocol == "ucx":
+        ucp = pytest.importorskip("ucp")
+    elif protocol == "ucxx":
+        ucp = pytest.importorskip("ucxx")
+
     kwargs = {"enable_tcp_over_ucx": True}
-    initialize(**kwargs)
+    initialize(protocol=protocol, **kwargs)
     with LocalCluster(
-        protocol="ucx",
+        protocol=protocol,
         dashboard_address=None,
         n_workers=1,
         threads_per_worker=1,
         processes=True,
+        worker_class=IncreasedCloseTimeoutNanny,
         config={"distributed.comm.ucx": get_ucx_config(**kwargs)},
     ) as cluster:
         with Client(cluster) as client:
@@ -48,22 +54,34 @@ def _test_initialize_ucx_tcp():
             assert all(client.run(check_ucx_options).values())


-def test_initialize_ucx_tcp():
-    p = mp.Process(target=_test_initialize_ucx_tcp)
+@pytest.mark.parametrize("protocol", ["ucx", "ucxx"])
+def test_initialize_ucx_tcp(protocol):
+    if protocol == "ucx":
+        pytest.importorskip("ucp")
+    elif protocol == "ucxx":
+        pytest.importorskip("ucxx")
+
+    p = mp.Process(target=_test_initialize_ucx_tcp, args=(protocol,))
     p.start()
     p.join()
     assert not p.exitcode


-def _test_initialize_ucx_nvlink():
+def _test_initialize_ucx_nvlink(protocol):
+    if protocol == "ucx":
+        ucp = pytest.importorskip("ucp")
+    elif protocol == "ucxx":
+        ucp = pytest.importorskip("ucxx")
+
     kwargs = {"enable_nvlink": True}
-    initialize(**kwargs)
+    initialize(protocol=protocol, **kwargs)
     with LocalCluster(
-        protocol="ucx",
+        protocol=protocol,
         dashboard_address=None,
         n_workers=1,
         threads_per_worker=1,
         processes=True,
+        worker_class=IncreasedCloseTimeoutNanny,
         config={"distributed.comm.ucx": get_ucx_config(**kwargs)},
     ) as cluster:
         with Client(cluster) as client:
@@ -84,22 +102,34 @@ def _test_initialize_ucx_nvlink():
             assert all(client.run(check_ucx_options).values())


-def test_initialize_ucx_nvlink():
-    p = mp.Process(target=_test_initialize_ucx_nvlink)
+@pytest.mark.parametrize("protocol", ["ucx", "ucxx"])
+def test_initialize_ucx_nvlink(protocol):
+    if protocol == "ucx":
+        pytest.importorskip("ucp")
+    elif protocol == "ucxx":
+        pytest.importorskip("ucxx")
+
+    p = mp.Process(target=_test_initialize_ucx_nvlink, args=(protocol,))
     p.start()
     p.join()
     assert not p.exitcode


-def _test_initialize_ucx_infiniband():
+def _test_initialize_ucx_infiniband(protocol):
+    if protocol == "ucx":
+        ucp = pytest.importorskip("ucp")
+    elif protocol == "ucxx":
+        ucp = pytest.importorskip("ucxx")
+
     kwargs = {"enable_infiniband": True}
-    initialize(**kwargs)
+    initialize(protocol=protocol, **kwargs)
     with LocalCluster(
-        protocol="ucx",
+        protocol=protocol,
         dashboard_address=None,
         n_workers=1,
         threads_per_worker=1,
         processes=True,
+        worker_class=IncreasedCloseTimeoutNanny,
         config={"distributed.comm.ucx": get_ucx_config(**kwargs)},
     ) as cluster:
         with Client(cluster) as client:
@@ -123,21 +153,33 @@ def _test_initialize_ucx_infiniband():
 @pytest.mark.skipif(
     "ib0" not in psutil.net_if_addrs(), reason="Infiniband interface ib0 not found"
 )
-def test_initialize_ucx_infiniband():
-    p = mp.Process(target=_test_initialize_ucx_infiniband)
+@pytest.mark.parametrize("protocol", ["ucx", "ucxx"])
+def test_initialize_ucx_infiniband(protocol):
+    if protocol == "ucx":
+        pytest.importorskip("ucp")
+    elif protocol == "ucxx":
+        pytest.importorskip("ucxx")
+
+    p = mp.Process(target=_test_initialize_ucx_infiniband, args=(protocol,))
     p.start()
     p.join()
     assert not p.exitcode


-def _test_initialize_ucx_all():
-    initialize()
+def _test_initialize_ucx_all(protocol):
+    if protocol == "ucx":
+        ucp = pytest.importorskip("ucp")
+    elif protocol == "ucxx":
+        ucp = pytest.importorskip("ucxx")
+
+    initialize(protocol=protocol)
     with LocalCluster(
-        protocol="ucx",
+        protocol=protocol,
         dashboard_address=None,
         n_workers=1,
         threads_per_worker=1,
         processes=True,
+        worker_class=IncreasedCloseTimeoutNanny,
         config={"distributed.comm.ucx": get_ucx_config()},
     ) as cluster:
         with Client(cluster) as client:
@@ -161,8 +203,14 @@ def _test_initialize_ucx_all():
             assert all(client.run(check_ucx_options).values())


-def test_initialize_ucx_all():
-    p = mp.Process(target=_test_initialize_ucx_all)
+@pytest.mark.parametrize("protocol", ["ucx", "ucxx"])
+def test_initialize_ucx_all(protocol):
+    if protocol == "ucx":
+        pytest.importorskip("ucp")
+    elif protocol == "ucxx":
+        pytest.importorskip("ucxx")
+
+    p = mp.Process(target=_test_initialize_ucx_all, args=(protocol,))
     p.start()
     p.join()
     assert not p.exitcode
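
Taken together, these hunks make initialize() take the protocol explicitly ("ucx" for UCX-Py, "ucxx" for UCXX) instead of always assuming UCX-Py. A client-side usage sketch following the same pattern as the tests above (assuming UCX-Py is installed and a GPU is present):

# Usage sketch mirroring the test pattern above; assumes the `ucp`
# package (UCX-Py) and a CUDA-capable GPU are available.
from distributed import Client
from distributed.deploy.local import LocalCluster

from dask_cuda.initialize import initialize
from dask_cuda.utils import get_ucx_config

protocol = "ucx"  # or "ucxx" to use the UCXX backend instead
initialize(protocol=protocol, enable_tcp_over_ucx=True)

with LocalCluster(
    protocol=protocol,
    dashboard_address=None,
    n_workers=1,
    threads_per_worker=1,
    processes=True,
    config={"distributed.comm.ucx": get_ucx_config(enable_tcp_over_ucx=True)},
) as cluster:
    with Client(cluster) as client:
        print(client.scheduler_info())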

dask_cuda/tests/test_local_cuda_cluster.py
@@ -13,13 +13,13 @@ from distributed.utils_test import gen_test, raises_with_cause
 from dask_cuda import CUDAWorker, LocalCUDACluster, utils
 from dask_cuda.initialize import initialize
 from dask_cuda.utils import (
-    MockWorker,
     get_cluster_configuration,
     get_device_total_memory,
     get_gpu_count_mig,
     get_gpu_uuid_from_index,
     print_cluster_config,
 )
+from dask_cuda.utils_test import MockWorker


 @gen_test(timeout=20)
@@ -87,23 +87,38 @@ async def test_with_subset_of_cuda_visible_devices():
     }


+@pytest.mark.parametrize(
+    "protocol",
+    ["ucx", "ucxx"],
+)
 @gen_test(timeout=20)
-async def test_ucx_protocol():
-    pytest.importorskip("ucp")
+async def test_ucx_protocol(protocol):
+    if protocol == "ucx":
+        pytest.importorskip("ucp")
+    elif protocol == "ucxx":
+        pytest.importorskip("ucxx")

     async with LocalCUDACluster(
-        protocol="ucx", asynchronous=True, data=dict
+        protocol=protocol, asynchronous=True, data=dict
     ) as cluster:
         assert all(
-            ws.address.startswith("ucx://") for ws in cluster.scheduler.workers.values()
+            ws.address.startswith(f"{protocol}://")
+            for ws in cluster.scheduler.workers.values()
         )


+@pytest.mark.parametrize(
+    "protocol",
+    ["ucx", "ucxx"],
+)
 @gen_test(timeout=20)
-async def test_explicit_ucx_with_protocol_none():
-    pytest.importorskip("ucp")
+async def test_explicit_ucx_with_protocol_none(protocol):
+    if protocol == "ucx":
+        pytest.importorskip("ucp")
+    elif protocol == "ucxx":
+        pytest.importorskip("ucxx")

-    initialize(enable_tcp_over_ucx=True)
+    initialize(protocol=protocol, enable_tcp_over_ucx=True)
     async with LocalCUDACluster(
         protocol=None, enable_tcp_over_ucx=True, asynchronous=True, data=dict
     ) as cluster:
@@ -113,11 +128,18 @@ async def test_explicit_ucx_with_protocol_none():


 @pytest.mark.filterwarnings("ignore:Exception ignored in")
+@pytest.mark.parametrize(
+    "protocol",
+    ["ucx", "ucxx"],
+)
 @gen_test(timeout=20)
-async def test_ucx_protocol_type_error():
-    pytest.importorskip("ucp")
+async def test_ucx_protocol_type_error(protocol):
+    if protocol == "ucx":
+        pytest.importorskip("ucp")
+    elif protocol == "ucxx":
+        pytest.importorskip("ucxx")

-    initialize(enable_tcp_over_ucx=True)
+    initialize(protocol=protocol, enable_tcp_over_ucx=True)
     with pytest.raises(TypeError):
         async with LocalCUDACluster(
             protocol="tcp", enable_tcp_over_ucx=True, asynchronous=True, data=dict
@@ -337,6 +359,7 @@ async def test_pre_import():


 # Intentionally not using @gen_test to skip cleanup checks
+@pytest.mark.xfail(reason="https://github.com/rapidsai/dask-cuda/issues/1265")
 def test_pre_import_not_found():
     async def _test_pre_import_not_found():
         with raises_with_cause(RuntimeError, None, ImportError, None):
@@ -477,20 +500,30 @@ async def test_worker_fraction_limits():
     )


-def test_print_cluster_config(capsys):
+@pytest.mark.parametrize(
+    "protocol",
+    ["ucx", "ucxx"],
+)
+def test_print_cluster_config(capsys, protocol):
+    if protocol == "ucx":
+        pytest.importorskip("ucp")
+    elif protocol == "ucxx":
+        pytest.importorskip("ucxx")
+
     pytest.importorskip("rich")
     with LocalCUDACluster(
-        n_workers=1, device_memory_limit="1B", jit_unspill=True, protocol="ucx"
+        n_workers=1, device_memory_limit="1B", jit_unspill=True, protocol=protocol
     ) as cluster:
         with Client(cluster) as client:
             print_cluster_config(client)
     captured = capsys.readouterr()
     assert "Dask Cluster Configuration" in captured.out
-    assert "ucx" in captured.out
+    assert protocol in captured.out
     assert "1 B" in captured.out
     assert "[plugin]" in captured.out


+@pytest.mark.xfail(reason="https://github.com/rapidsai/dask-cuda/issues/1265")
 def test_death_timeout_raises():
     with pytest.raises(asyncio.exceptions.TimeoutError):
         with LocalCUDACluster(

dask_cuda/tests/test_proxify_host_file.py
@@ -19,6 +19,7 @@ from dask_cuda.get_device_memory_objects import get_device_memory_ids
 from dask_cuda.proxify_host_file import ProxifyHostFile
 from dask_cuda.proxy_object import ProxyObject, asproxy, unproxy
 from dask_cuda.utils import get_device_total_memory
+from dask_cuda.utils_test import IncreasedCloseTimeoutNanny

 cupy = pytest.importorskip("cupy")
 cupy.cuda.set_allocator(None)
@@ -301,13 +302,24 @@ def test_dataframes_share_dev_mem(root_dir):
 def test_cudf_get_device_memory_objects():
     cudf = pytest.importorskip("cudf")
     objects = [
-        cudf.DataFrame({"a": range(10), "b": range(10)}, index=reversed(range(10))),
+        cudf.DataFrame(
+            {"a": [0, 1, 2, 3, None, 5, 6, 7, 8, 9], "b": range(10)},
+            index=reversed(range(10)),
+        ),
         cudf.MultiIndex(
             levels=[[1, 2], ["blue", "red"]], codes=[[0, 0, 1, 1], [1, 0, 1, 0]]
         ),
     ]
     res = get_device_memory_ids(objects)
-    assert len(res) == 4, "We expect four buffer objects"
+    # Buffers are:
+    # 1. int data for objects[0].a
+    # 2. mask data for objects[0].a
+    # 3. int data for objects[0].b
+    # 4. int data for objects[0].index
+    # 5. int data for objects[1].levels[0]
+    # 6. char data for objects[1].levels[1]
+    # 7. offset data for objects[1].levels[1]
+    assert len(res) == 7, "We expect seven buffer objects"


 def test_externals(root_dir):
@@ -393,13 +405,16 @@ async def test_compatibility_mode_dataframe_shuffle(compatibility_mode, npartitions):

     with dask.config.set(jit_unspill_compatibility_mode=compatibility_mode):
         async with dask_cuda.LocalCUDACluster(
-            n_workers=1, jit_unspill=True, asynchronous=True
+            n_workers=1,
+            jit_unspill=True,
+            worker_class=IncreasedCloseTimeoutNanny,
+            asynchronous=True,
         ) as cluster:
             async with Client(cluster, asynchronous=True) as client:
                 ddf = dask.dataframe.from_pandas(
                     cudf.DataFrame({"key": np.arange(10)}), npartitions=npartitions
                 )
-                res = ddf.shuffle(on="key", shuffle="tasks").persist()
+                res = ddf.shuffle(on="key", shuffle_method="tasks").persist()

                 # With compatibility mode on, we shouldn't encounter any proxy objects
                 if compatibility_mode:
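
The updated count reflects cuDF's buffer model: a nullable numeric column owns a data buffer plus a validity-mask buffer (introduced here by the None in column "a"), and a string column owns both a character buffer and an offsets buffer. A small illustration, with the caveat that these column-introspection attributes are cudf internals and may differ between versions:

# Requires cudf and a GPU; `base_data`/`base_mask` are internal cudf
# attributes and may change across releases.
import cudf

df = cudf.DataFrame({"a": [0, 1, 2, 3, None, 5, 6, 7, 8, 9]})
col = df["a"]._column
print(col.base_data is not None)  # device buffer holding the integers
print(col.base_mask is not None)  # validity-mask buffer created by the None

s = cudf.Series(["blue", "red"])
# A string column carries child buffers for offsets and characters.
print(len(s._column.base_children))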

dask_cuda/tests/test_proxy.py
@@ -23,6 +23,7 @@ from dask_cuda import LocalCUDACluster, proxy_object
 from dask_cuda.disk_io import SpillToDiskFile
 from dask_cuda.proxify_device_objects import proxify_device_objects
 from dask_cuda.proxify_host_file import ProxifyHostFile
+from dask_cuda.utils_test import IncreasedCloseTimeoutNanny

 # Make the "disk" serializer available and use a directory that are
 # remove on exit.
@@ -305,6 +306,7 @@ async def test_spilling_local_cuda_cluster(jit_unspill):
         n_workers=1,
         device_memory_limit="1B",
         jit_unspill=jit_unspill,
+        worker_class=IncreasedCloseTimeoutNanny,
         asynchronous=True,
     ) as cluster:
         async with Client(cluster, asynchronous=True) as client:
@@ -399,10 +401,14 @@ class _PxyObjTest(proxy_object.ProxyObject):


 @pytest.mark.parametrize("send_serializers", [None, ("dask", "pickle"), ("cuda",)])
-@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
+@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
 @gen_test(timeout=120)
 async def test_communicating_proxy_objects(protocol, send_serializers):
     """Testing serialization of cuDF dataframe when communicating"""
+    if protocol == "ucx":
+        pytest.importorskip("ucp")
+    elif protocol == "ucxx":
+        pytest.importorskip("ucxx")
     cudf = pytest.importorskip("cudf")

     def task(x):
@@ -411,7 +417,7 @@ async def test_communicating_proxy_objects(protocol, send_serializers):
        serializers_used = x._pxy_get().serializer

         # Check that `x` is serialized with the expected serializers
-        if protocol == "ucx":
+        if protocol in ["ucx", "ucxx"]:
             if send_serializers is None:
                 assert serializers_used == "cuda"
             else:
@@ -422,6 +428,7 @@ async def test_communicating_proxy_objects(protocol, send_serializers):
     async with dask_cuda.LocalCUDACluster(
         n_workers=1,
         protocol=protocol,
+        worker_class=IncreasedCloseTimeoutNanny,
         asynchronous=True,
     ) as cluster:
         async with Client(cluster, asynchronous=True) as client:
@@ -441,11 +448,15 @@ async def test_communicating_proxy_objects(protocol, send_serializers):
            await client.submit(task, df)


-@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
+@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
 @pytest.mark.parametrize("shared_fs", [True, False])
 @gen_test(timeout=20)
 async def test_communicating_disk_objects(protocol, shared_fs):
     """Testing disk serialization of cuDF dataframe when communicating"""
+    if protocol == "ucx":
+        pytest.importorskip("ucp")
+    elif protocol == "ucxx":
+        pytest.importorskip("ucxx")
     cudf = pytest.importorskip("cudf")
     ProxifyHostFile._spill_to_disk.shared_filesystem = shared_fs


dask_cuda/tests/test_spill.py
@@ -12,6 +12,7 @@ from distributed.sizeof import sizeof
 from distributed.utils_test import gen_cluster, gen_test, loop  # noqa: F401

 from dask_cuda import LocalCUDACluster, utils
+from dask_cuda.utils_test import IncreasedCloseTimeoutNanny

 if utils.get_device_total_memory() < 1e10:
     pytest.skip("Not enough GPU memory", allow_module_level=True)
@@ -160,6 +161,7 @@ async def test_cupy_cluster_device_spill(params):
         asynchronous=True,
         device_memory_limit=params["device_memory_limit"],
         memory_limit=params["memory_limit"],
+        worker_class=IncreasedCloseTimeoutNanny,
     ) as cluster:
         async with Client(cluster, asynchronous=True) as client:

@@ -263,6 +265,7 @@ async def test_cudf_cluster_device_spill(params):
         asynchronous=True,
         device_memory_limit=params["device_memory_limit"],
         memory_limit=params["memory_limit"],
+        worker_class=IncreasedCloseTimeoutNanny,
     ) as cluster:
         async with Client(cluster, asynchronous=True) as client:


dask_cuda/tests/test_utils.py
@@ -79,11 +79,18 @@ def test_get_device_total_memory():
     assert total_mem > 0


-def test_get_preload_options_default():
-    pytest.importorskip("ucp")
+@pytest.mark.parametrize(
+    "protocol",
+    ["ucx", "ucxx"],
+)
+def test_get_preload_options_default(protocol):
+    if protocol == "ucx":
+        pytest.importorskip("ucp")
+    elif protocol == "ucxx":
+        pytest.importorskip("ucxx")

     opts = get_preload_options(
-        protocol="ucx",
+        protocol=protocol,
         create_cuda_context=True,
     )

@@ -93,14 +100,21 @@ def test_get_preload_options_default():
     assert opts["preload_argv"] == ["--create-cuda-context"]


+@pytest.mark.parametrize(
+    "protocol",
+    ["ucx", "ucxx"],
+)
 @pytest.mark.parametrize("enable_tcp", [True, False])
 @pytest.mark.parametrize("enable_infiniband", [True, False])
 @pytest.mark.parametrize("enable_nvlink", [True, False])
-def test_get_preload_options(enable_tcp, enable_infiniband, enable_nvlink):
-    pytest.importorskip("ucp")
+def test_get_preload_options(protocol, enable_tcp, enable_infiniband, enable_nvlink):
+    if protocol == "ucx":
+        pytest.importorskip("ucp")
+    elif protocol == "ucxx":
+        pytest.importorskip("ucxx")

     opts = get_preload_options(
-        protocol="ucx",
+        protocol=protocol,
         create_cuda_context=True,
         enable_tcp_over_ucx=enable_tcp,
         enable_infiniband=enable_infiniband,
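
For context, get_preload_options assembles the worker preload configuration asserted on above. A minimal call mirroring the default case in these tests (assuming UCX-Py is installed; the import path from dask_cuda.utils is inferred from this test module):

# Usage sketch grounded in the assertions above.
from dask_cuda.utils import get_preload_options

opts = get_preload_options(protocol="ucx", create_cuda_context=True)
# Per the default-case test above:
#   opts["preload_argv"] == ["--create-cuda-context"]
print(opts)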