dask-cuda 23.12.0a231026__py3-none-any.whl → 24.2.0a3__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries; it is provided for informational purposes only.
Files changed (35)
  1. dask_cuda/VERSION +1 -0
  2. dask_cuda/__init__.py +1 -3
  3. dask_cuda/_version.py +20 -0
  4. dask_cuda/benchmarks/local_cudf_groupby.py +1 -1
  5. dask_cuda/benchmarks/local_cudf_merge.py +1 -1
  6. dask_cuda/benchmarks/local_cudf_shuffle.py +1 -1
  7. dask_cuda/benchmarks/local_cupy.py +1 -1
  8. dask_cuda/benchmarks/local_cupy_map_overlap.py +1 -1
  9. dask_cuda/benchmarks/utils.py +1 -1
  10. dask_cuda/cuda_worker.py +1 -3
  11. dask_cuda/device_host_file.py +1 -1
  12. dask_cuda/initialize.py +47 -16
  13. dask_cuda/local_cuda_cluster.py +19 -19
  14. dask_cuda/plugins.py +122 -0
  15. dask_cuda/tests/test_dask_cuda_worker.py +3 -3
  16. dask_cuda/tests/test_dgx.py +45 -17
  17. dask_cuda/tests/test_explicit_comms.py +5 -5
  18. dask_cuda/tests/test_from_array.py +6 -2
  19. dask_cuda/tests/test_initialize.py +69 -21
  20. dask_cuda/tests/test_local_cuda_cluster.py +47 -14
  21. dask_cuda/tests/test_proxify_host_file.py +5 -1
  22. dask_cuda/tests/test_proxy.py +13 -3
  23. dask_cuda/tests/test_spill.py +3 -0
  24. dask_cuda/tests/test_utils.py +20 -6
  25. dask_cuda/utils.py +6 -140
  26. dask_cuda/utils_test.py +45 -0
  27. dask_cuda/worker_spec.py +2 -1
  28. {dask_cuda-23.12.0a231026.dist-info → dask_cuda-24.2.0a3.dist-info}/METADATA +2 -3
  29. dask_cuda-24.2.0a3.dist-info/RECORD +53 -0
  30. {dask_cuda-23.12.0a231026.dist-info → dask_cuda-24.2.0a3.dist-info}/WHEEL +1 -1
  31. dask_cuda/compat.py +0 -118
  32. dask_cuda-23.12.0a231026.dist-info/RECORD +0 -50
  33. {dask_cuda-23.12.0a231026.dist-info → dask_cuda-24.2.0a3.dist-info}/LICENSE +0 -0
  34. {dask_cuda-23.12.0a231026.dist-info → dask_cuda-24.2.0a3.dist-info}/entry_points.txt +0 -0
  35. {dask_cuda-23.12.0a231026.dist-info → dask_cuda-24.2.0a3.dist-info}/top_level.txt +0 -0
dask_cuda/tests/test_initialize.py CHANGED
@@ -10,9 +10,9 @@ from distributed.deploy.local import LocalCluster
 
 from dask_cuda.initialize import initialize
 from dask_cuda.utils import get_ucx_config
+from dask_cuda.utils_test import IncreasedCloseTimeoutNanny
 
 mp = mp.get_context("spawn")  # type: ignore
-ucp = pytest.importorskip("ucp")
 
 # Notice, all of the following tests is executed in a new process such
 # that UCX options of the different tests doesn't conflict.
@@ -20,15 +20,21 @@ ucp = pytest.importorskip("ucp")
 # of UCX before retrieving the current config.
 
 
-def _test_initialize_ucx_tcp():
+def _test_initialize_ucx_tcp(protocol):
+    if protocol == "ucx":
+        ucp = pytest.importorskip("ucp")
+    elif protocol == "ucxx":
+        ucp = pytest.importorskip("ucxx")
+
     kwargs = {"enable_tcp_over_ucx": True}
-    initialize(**kwargs)
+    initialize(protocol=protocol, **kwargs)
     with LocalCluster(
-        protocol="ucx",
+        protocol=protocol,
         dashboard_address=None,
         n_workers=1,
         threads_per_worker=1,
         processes=True,
+        worker_class=IncreasedCloseTimeoutNanny,
         config={"distributed.comm.ucx": get_ucx_config(**kwargs)},
     ) as cluster:
         with Client(cluster) as client:
@@ -48,22 +54,34 @@ def _test_initialize_ucx_tcp():
             assert all(client.run(check_ucx_options).values())
 
 
-def test_initialize_ucx_tcp():
-    p = mp.Process(target=_test_initialize_ucx_tcp)
+@pytest.mark.parametrize("protocol", ["ucx", "ucxx"])
+def test_initialize_ucx_tcp(protocol):
+    if protocol == "ucx":
+        pytest.importorskip("ucp")
+    elif protocol == "ucxx":
+        pytest.importorskip("ucxx")
+
+    p = mp.Process(target=_test_initialize_ucx_tcp, args=(protocol,))
     p.start()
     p.join()
     assert not p.exitcode
 
 
-def _test_initialize_ucx_nvlink():
+def _test_initialize_ucx_nvlink(protocol):
+    if protocol == "ucx":
+        ucp = pytest.importorskip("ucp")
+    elif protocol == "ucxx":
+        ucp = pytest.importorskip("ucxx")
+
     kwargs = {"enable_nvlink": True}
-    initialize(**kwargs)
+    initialize(protocol=protocol, **kwargs)
     with LocalCluster(
-        protocol="ucx",
+        protocol=protocol,
         dashboard_address=None,
         n_workers=1,
         threads_per_worker=1,
         processes=True,
+        worker_class=IncreasedCloseTimeoutNanny,
         config={"distributed.comm.ucx": get_ucx_config(**kwargs)},
     ) as cluster:
         with Client(cluster) as client:
@@ -84,22 +102,34 @@ def _test_initialize_ucx_nvlink():
             assert all(client.run(check_ucx_options).values())
 
 
-def test_initialize_ucx_nvlink():
-    p = mp.Process(target=_test_initialize_ucx_nvlink)
+@pytest.mark.parametrize("protocol", ["ucx", "ucxx"])
+def test_initialize_ucx_nvlink(protocol):
+    if protocol == "ucx":
+        pytest.importorskip("ucp")
+    elif protocol == "ucxx":
+        pytest.importorskip("ucxx")
+
+    p = mp.Process(target=_test_initialize_ucx_nvlink, args=(protocol,))
     p.start()
     p.join()
     assert not p.exitcode
 
 
-def _test_initialize_ucx_infiniband():
+def _test_initialize_ucx_infiniband(protocol):
+    if protocol == "ucx":
+        ucp = pytest.importorskip("ucp")
+    elif protocol == "ucxx":
+        ucp = pytest.importorskip("ucxx")
+
     kwargs = {"enable_infiniband": True}
-    initialize(**kwargs)
+    initialize(protocol=protocol, **kwargs)
     with LocalCluster(
-        protocol="ucx",
+        protocol=protocol,
         dashboard_address=None,
         n_workers=1,
         threads_per_worker=1,
         processes=True,
+        worker_class=IncreasedCloseTimeoutNanny,
         config={"distributed.comm.ucx": get_ucx_config(**kwargs)},
     ) as cluster:
         with Client(cluster) as client:
@@ -123,21 +153,33 @@ def _test_initialize_ucx_infiniband():
 @pytest.mark.skipif(
     "ib0" not in psutil.net_if_addrs(), reason="Infiniband interface ib0 not found"
 )
-def test_initialize_ucx_infiniband():
-    p = mp.Process(target=_test_initialize_ucx_infiniband)
+@pytest.mark.parametrize("protocol", ["ucx", "ucxx"])
+def test_initialize_ucx_infiniband(protocol):
+    if protocol == "ucx":
+        pytest.importorskip("ucp")
+    elif protocol == "ucxx":
+        pytest.importorskip("ucxx")
+
+    p = mp.Process(target=_test_initialize_ucx_infiniband, args=(protocol,))
     p.start()
     p.join()
     assert not p.exitcode
 
 
-def _test_initialize_ucx_all():
-    initialize()
+def _test_initialize_ucx_all(protocol):
+    if protocol == "ucx":
+        ucp = pytest.importorskip("ucp")
+    elif protocol == "ucxx":
+        ucp = pytest.importorskip("ucxx")
+
+    initialize(protocol=protocol)
     with LocalCluster(
-        protocol="ucx",
+        protocol=protocol,
         dashboard_address=None,
         n_workers=1,
         threads_per_worker=1,
         processes=True,
+        worker_class=IncreasedCloseTimeoutNanny,
         config={"distributed.comm.ucx": get_ucx_config()},
     ) as cluster:
         with Client(cluster) as client:
@@ -161,8 +203,14 @@ def _test_initialize_ucx_all():
             assert all(client.run(check_ucx_options).values())
 
 
-def test_initialize_ucx_all():
-    p = mp.Process(target=_test_initialize_ucx_all)
+@pytest.mark.parametrize("protocol", ["ucx", "ucxx"])
+def test_initialize_ucx_all(protocol):
+    if protocol == "ucx":
+        pytest.importorskip("ucp")
+    elif protocol == "ucxx":
+        pytest.importorskip("ucxx")
+
+    p = mp.Process(target=_test_initialize_ucx_all, args=(protocol,))
    p.start()
    p.join()
    assert not p.exitcode
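
Every hunk above applies the same change to this test module: each test is parametrized over the "ucx" (ucx-py) and "ucxx" backends, skips when the matching bindings are not installed, and runs its worker under IncreasedCloseTimeoutNanny. A condensed sketch of that pattern, for reference only and not part of the diff; the single-worker sizing is an assumption:

```python
# Illustrative sketch of the ucx/ucxx parametrization pattern used in this file.
import pytest
from dask.distributed import Client, LocalCluster

from dask_cuda.initialize import initialize
from dask_cuda.utils import get_ucx_config
from dask_cuda.utils_test import IncreasedCloseTimeoutNanny


def run(protocol):
    # "ucx" is backed by ucx-py (ucp); "ucxx" by the newer ucxx bindings.
    pytest.importorskip("ucp" if protocol == "ucx" else "ucxx")

    kwargs = {"enable_tcp_over_ucx": True}
    initialize(protocol=protocol, **kwargs)  # configure UCX before the cluster starts
    with LocalCluster(
        protocol=protocol,
        dashboard_address=None,
        n_workers=1,
        threads_per_worker=1,
        processes=True,
        worker_class=IncreasedCloseTimeoutNanny,  # allow slower, clean worker shutdown
        config={"distributed.comm.ucx": get_ucx_config(**kwargs)},
    ) as cluster:
        with Client(cluster) as client:
            assert len(client.scheduler_info()["workers"]) == 1
```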
dask_cuda/tests/test_local_cuda_cluster.py CHANGED
@@ -13,13 +13,13 @@ from distributed.utils_test import gen_test, raises_with_cause
 from dask_cuda import CUDAWorker, LocalCUDACluster, utils
 from dask_cuda.initialize import initialize
 from dask_cuda.utils import (
-    MockWorker,
     get_cluster_configuration,
     get_device_total_memory,
     get_gpu_count_mig,
     get_gpu_uuid_from_index,
     print_cluster_config,
 )
+from dask_cuda.utils_test import MockWorker
 
 
 @gen_test(timeout=20)
@@ -87,23 +87,38 @@ async def test_with_subset_of_cuda_visible_devices():
         }
 
 
+@pytest.mark.parametrize(
+    "protocol",
+    ["ucx", "ucxx"],
+)
 @gen_test(timeout=20)
-async def test_ucx_protocol():
-    pytest.importorskip("ucp")
+async def test_ucx_protocol(protocol):
+    if protocol == "ucx":
+        pytest.importorskip("ucp")
+    elif protocol == "ucxx":
+        pytest.importorskip("ucxx")
 
     async with LocalCUDACluster(
-        protocol="ucx", asynchronous=True, data=dict
+        protocol=protocol, asynchronous=True, data=dict
     ) as cluster:
         assert all(
-            ws.address.startswith("ucx://") for ws in cluster.scheduler.workers.values()
+            ws.address.startswith(f"{protocol}://")
+            for ws in cluster.scheduler.workers.values()
         )
 
 
+@pytest.mark.parametrize(
+    "protocol",
+    ["ucx", "ucxx"],
+)
 @gen_test(timeout=20)
-async def test_explicit_ucx_with_protocol_none():
-    pytest.importorskip("ucp")
+async def test_explicit_ucx_with_protocol_none(protocol):
+    if protocol == "ucx":
+        pytest.importorskip("ucp")
+    elif protocol == "ucxx":
+        pytest.importorskip("ucxx")
 
-    initialize(enable_tcp_over_ucx=True)
+    initialize(protocol=protocol, enable_tcp_over_ucx=True)
     async with LocalCUDACluster(
         protocol=None, enable_tcp_over_ucx=True, asynchronous=True, data=dict
     ) as cluster:
@@ -113,11 +128,18 @@ async def test_explicit_ucx_with_protocol_none():
 
 
 @pytest.mark.filterwarnings("ignore:Exception ignored in")
+@pytest.mark.parametrize(
+    "protocol",
+    ["ucx", "ucxx"],
+)
 @gen_test(timeout=20)
-async def test_ucx_protocol_type_error():
-    pytest.importorskip("ucp")
+async def test_ucx_protocol_type_error(protocol):
+    if protocol == "ucx":
+        pytest.importorskip("ucp")
+    elif protocol == "ucxx":
+        pytest.importorskip("ucxx")
 
-    initialize(enable_tcp_over_ucx=True)
+    initialize(protocol=protocol, enable_tcp_over_ucx=True)
     with pytest.raises(TypeError):
         async with LocalCUDACluster(
             protocol="tcp", enable_tcp_over_ucx=True, asynchronous=True, data=dict
@@ -337,6 +359,7 @@ async def test_pre_import():
 
 
 # Intentionally not using @gen_test to skip cleanup checks
+@pytest.mark.xfail(reason="https://github.com/rapidsai/dask-cuda/issues/1265")
 def test_pre_import_not_found():
     async def _test_pre_import_not_found():
         with raises_with_cause(RuntimeError, None, ImportError, None):
@@ -477,20 +500,30 @@ async def test_worker_fraction_limits():
         )
 
 
-def test_print_cluster_config(capsys):
+@pytest.mark.parametrize(
+    "protocol",
+    ["ucx", "ucxx"],
+)
+def test_print_cluster_config(capsys, protocol):
+    if protocol == "ucx":
+        pytest.importorskip("ucp")
+    elif protocol == "ucxx":
+        pytest.importorskip("ucxx")
+
     pytest.importorskip("rich")
     with LocalCUDACluster(
-        n_workers=1, device_memory_limit="1B", jit_unspill=True, protocol="ucx"
+        n_workers=1, device_memory_limit="1B", jit_unspill=True, protocol=protocol
     ) as cluster:
         with Client(cluster) as client:
             print_cluster_config(client)
             captured = capsys.readouterr()
             assert "Dask Cluster Configuration" in captured.out
-            assert "ucx" in captured.out
+            assert protocol in captured.out
             assert "1 B" in captured.out
             assert "[plugin]" in captured.out
 
 
+@pytest.mark.xfail(reason="https://github.com/rapidsai/dask-cuda/issues/1265")
 def test_death_timeout_raises():
     with pytest.raises(asyncio.exceptions.TimeoutError):
         with LocalCUDACluster(
dask_cuda/tests/test_proxify_host_file.py CHANGED
@@ -19,6 +19,7 @@ from dask_cuda.get_device_memory_objects import get_device_memory_ids
 from dask_cuda.proxify_host_file import ProxifyHostFile
 from dask_cuda.proxy_object import ProxyObject, asproxy, unproxy
 from dask_cuda.utils import get_device_total_memory
+from dask_cuda.utils_test import IncreasedCloseTimeoutNanny
 
 cupy = pytest.importorskip("cupy")
 cupy.cuda.set_allocator(None)
@@ -393,7 +394,10 @@ async def test_compatibility_mode_dataframe_shuffle(compatibility_mode, npartiti
 
     with dask.config.set(jit_unspill_compatibility_mode=compatibility_mode):
         async with dask_cuda.LocalCUDACluster(
-            n_workers=1, jit_unspill=True, asynchronous=True
+            n_workers=1,
+            jit_unspill=True,
+            worker_class=IncreasedCloseTimeoutNanny,
+            asynchronous=True,
         ) as cluster:
             async with Client(cluster, asynchronous=True) as client:
                 ddf = dask.dataframe.from_pandas(
dask_cuda/tests/test_proxy.py CHANGED
@@ -23,6 +23,7 @@ from dask_cuda import LocalCUDACluster, proxy_object
 from dask_cuda.disk_io import SpillToDiskFile
 from dask_cuda.proxify_device_objects import proxify_device_objects
 from dask_cuda.proxify_host_file import ProxifyHostFile
+from dask_cuda.utils_test import IncreasedCloseTimeoutNanny
 
 # Make the "disk" serializer available and use a directory that are
 # remove on exit.
@@ -399,10 +400,14 @@ class _PxyObjTest(proxy_object.ProxyObject):
 
 
 @pytest.mark.parametrize("send_serializers", [None, ("dask", "pickle"), ("cuda",)])
-@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
+@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
 @gen_test(timeout=120)
 async def test_communicating_proxy_objects(protocol, send_serializers):
     """Testing serialization of cuDF dataframe when communicating"""
+    if protocol == "ucx":
+        pytest.importorskip("ucp")
+    elif protocol == "ucxx":
+        pytest.importorskip("ucxx")
     cudf = pytest.importorskip("cudf")
 
     def task(x):
@@ -411,7 +416,7 @@ async def test_communicating_proxy_objects(protocol, send_serializers):
         serializers_used = x._pxy_get().serializer
 
         # Check that `x` is serialized with the expected serializers
-        if protocol == "ucx":
+        if protocol in ["ucx", "ucxx"]:
             if send_serializers is None:
                 assert serializers_used == "cuda"
             else:
@@ -422,6 +427,7 @@ async def test_communicating_proxy_objects(protocol, send_serializers):
     async with dask_cuda.LocalCUDACluster(
         n_workers=1,
         protocol=protocol,
+        worker_class=IncreasedCloseTimeoutNanny,
         asynchronous=True,
     ) as cluster:
         async with Client(cluster, asynchronous=True) as client:
@@ -441,11 +447,15 @@ async def test_communicating_proxy_objects(protocol, send_serializers):
             await client.submit(task, df)
 
 
-@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
+@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
 @pytest.mark.parametrize("shared_fs", [True, False])
 @gen_test(timeout=20)
 async def test_communicating_disk_objects(protocol, shared_fs):
     """Testing disk serialization of cuDF dataframe when communicating"""
+    if protocol == "ucx":
+        pytest.importorskip("ucp")
+    elif protocol == "ucxx":
+        pytest.importorskip("ucxx")
     cudf = pytest.importorskip("cudf")
     ProxifyHostFile._spill_to_disk.shared_filesystem = shared_fs
 
dask_cuda/tests/test_spill.py CHANGED
@@ -12,6 +12,7 @@ from distributed.sizeof import sizeof
 from distributed.utils_test import gen_cluster, gen_test, loop  # noqa: F401
 
 from dask_cuda import LocalCUDACluster, utils
+from dask_cuda.utils_test import IncreasedCloseTimeoutNanny
 
 if utils.get_device_total_memory() < 1e10:
     pytest.skip("Not enough GPU memory", allow_module_level=True)
@@ -160,6 +161,7 @@ async def test_cupy_cluster_device_spill(params):
             asynchronous=True,
             device_memory_limit=params["device_memory_limit"],
             memory_limit=params["memory_limit"],
+            worker_class=IncreasedCloseTimeoutNanny,
         ) as cluster:
             async with Client(cluster, asynchronous=True) as client:
 
@@ -263,6 +265,7 @@ async def test_cudf_cluster_device_spill(params):
             asynchronous=True,
             device_memory_limit=params["device_memory_limit"],
             memory_limit=params["memory_limit"],
+            worker_class=IncreasedCloseTimeoutNanny,
         ) as cluster:
             async with Client(cluster, asynchronous=True) as client:
 
dask_cuda/tests/test_utils.py CHANGED
@@ -79,11 +79,18 @@ def test_get_device_total_memory():
     assert total_mem > 0
 
 
-def test_get_preload_options_default():
-    pytest.importorskip("ucp")
+@pytest.mark.parametrize(
+    "protocol",
+    ["ucx", "ucxx"],
+)
+def test_get_preload_options_default(protocol):
+    if protocol == "ucx":
+        pytest.importorskip("ucp")
+    elif protocol == "ucxx":
+        pytest.importorskip("ucxx")
 
     opts = get_preload_options(
-        protocol="ucx",
+        protocol=protocol,
         create_cuda_context=True,
     )
 
@@ -93,14 +100,21 @@ def test_get_preload_options_default():
     assert opts["preload_argv"] == ["--create-cuda-context"]
 
 
+@pytest.mark.parametrize(
+    "protocol",
+    ["ucx", "ucxx"],
+)
 @pytest.mark.parametrize("enable_tcp", [True, False])
 @pytest.mark.parametrize("enable_infiniband", [True, False])
 @pytest.mark.parametrize("enable_nvlink", [True, False])
-def test_get_preload_options(enable_tcp, enable_infiniband, enable_nvlink):
-    pytest.importorskip("ucp")
+def test_get_preload_options(protocol, enable_tcp, enable_infiniband, enable_nvlink):
+    if protocol == "ucx":
+        pytest.importorskip("ucp")
+    elif protocol == "ucxx":
+        pytest.importorskip("ucxx")
 
     opts = get_preload_options(
-        protocol="ucx",
+        protocol=protocol,
        create_cuda_context=True,
        enable_tcp_over_ucx=enable_tcp,
        enable_infiniband=enable_infiniband,
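
For reference, the default-options test above reduces to the following sketch; only the preload_argv behavior asserted in the diff is claimed here, and the real test first skips when the chosen bindings are missing.

```python
# Illustrative only: mirrors test_get_preload_options_default for "ucxx".
from dask_cuda.utils import get_preload_options

opts = get_preload_options(protocol="ucxx", create_cuda_context=True)
# The diff asserts exactly this for both "ucx" and "ucxx":
assert opts["preload_argv"] == ["--create-cuda-context"]
```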
dask_cuda/utils.py CHANGED
@@ -1,4 +1,3 @@
-import importlib
 import math
 import operator
 import os
@@ -18,7 +17,7 @@ import dask
 import distributed  # noqa: required for dask.config.get("distributed.comm.ucx")
 from dask.config import canonical_name
 from dask.utils import format_bytes, parse_bytes
-from distributed import Worker, WorkerPlugin, wait
+from distributed import wait
 from distributed.comm import parse_address
 
 try:
@@ -32,122 +31,6 @@ except ImportError:
         yield
 
 
-class CPUAffinity(WorkerPlugin):
-    def __init__(self, cores):
-        self.cores = cores
-
-    def setup(self, worker=None):
-        os.sched_setaffinity(0, self.cores)
-
-
-class RMMSetup(WorkerPlugin):
-    def __init__(
-        self,
-        initial_pool_size,
-        maximum_pool_size,
-        managed_memory,
-        async_alloc,
-        release_threshold,
-        log_directory,
-        track_allocations,
-    ):
-        if initial_pool_size is None and maximum_pool_size is not None:
-            raise ValueError(
-                "`rmm_maximum_pool_size` was specified without specifying "
-                "`rmm_pool_size`.`rmm_pool_size` must be specified to use RMM pool."
-            )
-        if async_alloc is True:
-            if managed_memory is True:
-                raise ValueError(
-                    "`rmm_managed_memory` is incompatible with the `rmm_async`."
-                )
-        if async_alloc is False and release_threshold is not None:
-            raise ValueError("`rmm_release_threshold` requires `rmm_async`.")
-
-        self.initial_pool_size = initial_pool_size
-        self.maximum_pool_size = maximum_pool_size
-        self.managed_memory = managed_memory
-        self.async_alloc = async_alloc
-        self.release_threshold = release_threshold
-        self.logging = log_directory is not None
-        self.log_directory = log_directory
-        self.rmm_track_allocations = track_allocations
-
-    def setup(self, worker=None):
-        if self.initial_pool_size is not None:
-            self.initial_pool_size = parse_device_memory_limit(
-                self.initial_pool_size, alignment_size=256
-            )
-
-        if self.async_alloc:
-            import rmm
-
-            if self.release_threshold is not None:
-                self.release_threshold = parse_device_memory_limit(
-                    self.release_threshold, alignment_size=256
-                )
-
-            mr = rmm.mr.CudaAsyncMemoryResource(
-                initial_pool_size=self.initial_pool_size,
-                release_threshold=self.release_threshold,
-            )
-
-            if self.maximum_pool_size is not None:
-                self.maximum_pool_size = parse_device_memory_limit(
-                    self.maximum_pool_size, alignment_size=256
-                )
-                mr = rmm.mr.LimitingResourceAdaptor(
-                    mr, allocation_limit=self.maximum_pool_size
-                )
-
-            rmm.mr.set_current_device_resource(mr)
-            if self.logging:
-                rmm.enable_logging(
-                    log_file_name=get_rmm_log_file_name(
-                        worker, self.logging, self.log_directory
-                    )
-                )
-        elif self.initial_pool_size is not None or self.managed_memory:
-            import rmm
-
-            pool_allocator = False if self.initial_pool_size is None else True
-
-            if self.initial_pool_size is not None:
-                if self.maximum_pool_size is not None:
-                    self.maximum_pool_size = parse_device_memory_limit(
-                        self.maximum_pool_size, alignment_size=256
-                    )
-
-            rmm.reinitialize(
-                pool_allocator=pool_allocator,
-                managed_memory=self.managed_memory,
-                initial_pool_size=self.initial_pool_size,
-                maximum_pool_size=self.maximum_pool_size,
-                logging=self.logging,
-                log_file_name=get_rmm_log_file_name(
-                    worker, self.logging, self.log_directory
-                ),
-            )
-        if self.rmm_track_allocations:
-            import rmm
-
-            mr = rmm.mr.get_current_device_resource()
-            rmm.mr.set_current_device_resource(rmm.mr.TrackingResourceAdaptor(mr))
-
-
-class PreImport(WorkerPlugin):
-    def __init__(self, libraries):
-        if libraries is None:
-            libraries = []
-        elif isinstance(libraries, str):
-            libraries = libraries.split(",")
-        self.libraries = libraries
-
-    def setup(self, worker=None):
-        for l in self.libraries:
-            importlib.import_module(l)
-
-
 def unpack_bitmask(x, mask_bits=64):
     """Unpack a list of integers containing bitmasks.
 
@@ -404,7 +287,7 @@ def get_preload_options(
     if create_cuda_context:
         preload_options["preload_argv"].append("--create-cuda-context")
 
-    if protocol == "ucx":
+    if protocol in ["ucx", "ucxx"]:
         initialize_ucx_argv = []
         if enable_tcp_over_ucx:
             initialize_ucx_argv.append("--enable-tcp-over-ucx")
@@ -669,27 +552,6 @@ def parse_device_memory_limit(device_memory_limit, device_index=0, alignment_siz
     return _align(int(device_memory_limit), alignment_size)
 
 
-class MockWorker(Worker):
-    """Mock Worker class preventing NVML from getting used by SystemMonitor.
-
-    By preventing the Worker from initializing NVML in the SystemMonitor, we can
-    mock test multiple devices in `CUDA_VISIBLE_DEVICES` behavior with single-GPU
-    machines.
-    """
-
-    def __init__(self, *args, **kwargs):
-        distributed.diagnostics.nvml.device_get_count = MockWorker.device_get_count
-        self._device_get_count = distributed.diagnostics.nvml.device_get_count
-        super().__init__(*args, **kwargs)
-
-    def __del__(self):
-        distributed.diagnostics.nvml.device_get_count = self._device_get_count
-
-    @staticmethod
-    def device_get_count():
-        return 0
-
-
 def get_gpu_uuid_from_index(device_index=0):
     """Get GPU UUID from CUDA device index.
 
@@ -763,6 +625,10 @@ def get_worker_config(dask_worker):
         import ucp
 
         ret["ucx-transports"] = ucp.get_active_transports()
+    elif scheme == "ucxx":
+        import ucxx
+
+        ret["ucx-transports"] = ucxx.get_active_transports()
 
     # comm timeouts
     ret["distributed.comm.timeouts"] = dask.config.get("distributed.comm.timeouts")
dask_cuda/utils_test.py ADDED
@@ -0,0 +1,45 @@
+from typing import Literal
+
+import distributed
+from distributed import Nanny, Worker
+
+
+class MockWorker(Worker):
+    """Mock Worker class preventing NVML from getting used by SystemMonitor.
+
+    By preventing the Worker from initializing NVML in the SystemMonitor, we can
+    mock test multiple devices in `CUDA_VISIBLE_DEVICES` behavior with single-GPU
+    machines.
+    """
+
+    def __init__(self, *args, **kwargs):
+        distributed.diagnostics.nvml.device_get_count = MockWorker.device_get_count
+        self._device_get_count = distributed.diagnostics.nvml.device_get_count
+        super().__init__(*args, **kwargs)
+
+    def __del__(self):
+        distributed.diagnostics.nvml.device_get_count = self._device_get_count
+
+    @staticmethod
+    def device_get_count():
+        return 0
+
+
+class IncreasedCloseTimeoutNanny(Nanny):
+    """Increase `Nanny`'s close timeout.
+
+    The internal close timeout mechanism of `Nanny` recomputes the time left to kill
+    the `Worker` process based on elapsed time of the close task, which may leave
+    very little time for the subprocess to shutdown cleanly, which may cause tests
+    to fail when the system is under higher load. This class increases the default
+    close timeout of 5.0 seconds that `Nanny` sets by default, which can be overriden
+    via Distributed's public API.
+
+    This class can be used with the `worker_class` argument of `LocalCluster` or
+    `LocalCUDACluster` to provide a much higher default of 30.0 seconds.
+    """
+
+    async def close(  # type:ignore[override]
+        self, timeout: float = 30.0, reason: str = "nanny-close"
+    ) -> Literal["OK"]:
+        return await super().close(timeout=timeout, reason=reason)
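
As its docstring notes, the new Nanny subclass is meant to be passed through the worker_class argument of LocalCluster or LocalCUDACluster, as the test hunks above do. For example (illustrative only, single-worker sizing assumed):

```python
# Example usage: give test workers up to 30 s to shut down cleanly.
from dask.distributed import Client
from dask_cuda import LocalCUDACluster
from dask_cuda.utils_test import IncreasedCloseTimeoutNanny

with LocalCUDACluster(
    n_workers=1,
    worker_class=IncreasedCloseTimeoutNanny,
) as cluster:
    with Client(cluster) as client:
        pass  # run the test workload here
```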
dask_cuda/worker_spec.py CHANGED
@@ -5,7 +5,8 @@ from distributed.system import MEMORY_LIMIT
 
 from .initialize import initialize
 from .local_cuda_cluster import cuda_visible_devices
-from .utils import CPUAffinity, get_cpu_affinity, get_gpu_count
+from .plugins import CPUAffinity
+from .utils import get_cpu_affinity, get_gpu_count
 
 
 def worker_spec(
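
This import change reflects the move of the worker plugins out of dask_cuda.utils into the new dask_cuda.plugins module (dask_cuda/plugins.py +122 -0 in the file list above). Only CPUAffinity appears in this hunk; RMMSetup and PreImport, removed from utils.py in the same release, are presumed to have moved along with it. Downstream imports would change along these lines:

```python
# New import location in 24.02:
from dask_cuda.plugins import CPUAffinity

# 23.12 location, removed from dask_cuda.utils in this release:
# from dask_cuda.utils import CPUAffinity
```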