dask-cuda 23.10.0a231016__py3-none-any.whl → 23.12.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
Files changed (35)
  1. dask_cuda/VERSION +1 -0
  2. dask_cuda/__init__.py +1 -3
  3. dask_cuda/_version.py +20 -0
  4. dask_cuda/benchmarks/local_cudf_groupby.py +1 -1
  5. dask_cuda/benchmarks/local_cudf_merge.py +1 -1
  6. dask_cuda/benchmarks/local_cudf_shuffle.py +1 -1
  7. dask_cuda/benchmarks/local_cupy.py +1 -1
  8. dask_cuda/benchmarks/local_cupy_map_overlap.py +1 -1
  9. dask_cuda/benchmarks/utils.py +1 -1
  10. dask_cuda/cuda_worker.py +1 -3
  11. dask_cuda/device_host_file.py +1 -1
  12. dask_cuda/initialize.py +47 -16
  13. dask_cuda/local_cuda_cluster.py +18 -12
  14. dask_cuda/plugins.py +122 -0
  15. dask_cuda/tests/test_dask_cuda_worker.py +3 -3
  16. dask_cuda/tests/test_dgx.py +45 -17
  17. dask_cuda/tests/test_explicit_comms.py +11 -4
  18. dask_cuda/tests/test_from_array.py +6 -2
  19. dask_cuda/tests/test_initialize.py +69 -21
  20. dask_cuda/tests/test_local_cuda_cluster.py +47 -14
  21. dask_cuda/tests/test_proxify_host_file.py +5 -1
  22. dask_cuda/tests/test_proxy.py +14 -3
  23. dask_cuda/tests/test_spill.py +67 -30
  24. dask_cuda/tests/test_utils.py +20 -6
  25. dask_cuda/utils.py +6 -140
  26. dask_cuda/utils_test.py +45 -0
  27. dask_cuda/worker_spec.py +2 -1
  28. {dask_cuda-23.10.0a231016.dist-info → dask_cuda-23.12.0.dist-info}/METADATA +4 -4
  29. dask_cuda-23.12.0.dist-info/RECORD +53 -0
  30. {dask_cuda-23.10.0a231016.dist-info → dask_cuda-23.12.0.dist-info}/WHEEL +1 -1
  31. dask_cuda/compat.py +0 -118
  32. dask_cuda-23.10.0a231016.dist-info/RECORD +0 -50
  33. {dask_cuda-23.10.0a231016.dist-info → dask_cuda-23.12.0.dist-info}/LICENSE +0 -0
  34. {dask_cuda-23.10.0a231016.dist-info → dask_cuda-23.12.0.dist-info}/entry_points.txt +0 -0
  35. {dask_cuda-23.10.0a231016.dist-info → dask_cuda-23.12.0.dist-info}/top_level.txt +0 -0
dask_cuda/VERSION ADDED
@@ -0,0 +1 @@
+ 23.12.00
dask_cuda/__init__.py CHANGED
@@ -11,6 +11,7 @@ import dask.dataframe.shuffle
  import dask.dataframe.multi
  import dask.bag.core

+ from ._version import __git_commit__, __version__
  from .cuda_worker import CUDAWorker
  from .explicit_comms.dataframe.shuffle import (
      get_rearrange_by_column_wrapper,
@@ -19,9 +20,6 @@ from .explicit_comms.dataframe.shuffle import (
  from .local_cuda_cluster import LocalCUDACluster
  from .proxify_device_objects import proxify_decorator, unproxify_decorator

- __version__ = "23.10.00"
-
- from . import compat

  # Monkey patching Dask to make use of explicit-comms when `DASK_EXPLICIT_COMMS=True`
  dask.dataframe.shuffle.rearrange_by_column = get_rearrange_by_column_wrapper(
dask_cuda/_version.py ADDED
@@ -0,0 +1,20 @@
+ # Copyright (c) 2023, NVIDIA CORPORATION.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import importlib.resources
+
+ __version__ = (
+     importlib.resources.files("dask_cuda").joinpath("VERSION").read_text().strip()
+ )
+ __git_commit__ = "e1638aec4fbbc3dd2155c45c3ae3df293efa35c8"
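
Note: the version string is no longer hard-coded in `__init__.py`; it is read from the packaged VERSION file at import time. A minimal sketch of what a consumer sees (both attributes are re-exported by the `__init__.py` change above):

    import dask_cuda

    print(dask_cuda.__version__)     # "23.12.00", read from dask_cuda/VERSION
    print(dask_cuda.__git_commit__)  # commit hash baked into _version.py at build time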
dask_cuda/benchmarks/local_cudf_groupby.py CHANGED
@@ -139,7 +139,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
          key="Device memory limit", value=f"{format_bytes(args.device_memory_limit)}"
      )
      print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
-     if args.protocol == "ucx":
+     if args.protocol in ["ucx", "ucxx"]:
          print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
          print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
          print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
dask_cuda/benchmarks/local_cudf_merge.py CHANGED
@@ -217,7 +217,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
      )
      print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
      print_key_value(key="Frac-match", value=f"{args.frac_match}")
-     if args.protocol == "ucx":
+     if args.protocol in ["ucx", "ucxx"]:
          print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
          print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
          print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
dask_cuda/benchmarks/local_cudf_shuffle.py CHANGED
@@ -146,7 +146,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
          key="Device memory limit", value=f"{format_bytes(args.device_memory_limit)}"
      )
      print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
-     if args.protocol == "ucx":
+     if args.protocol in ["ucx", "ucxx"]:
          print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
          print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
          print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
dask_cuda/benchmarks/local_cupy.py CHANGED
@@ -193,7 +193,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
      )
      print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
      print_key_value(key="Protocol", value=f"{args.protocol}")
-     if args.protocol == "ucx":
+     if args.protocol in ["ucx", "ucxx"]:
          print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
          print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
          print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
dask_cuda/benchmarks/local_cupy_map_overlap.py CHANGED
@@ -78,7 +78,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
      )
      print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
      print_key_value(key="Protocol", value=f"{args.protocol}")
-     if args.protocol == "ucx":
+     if args.protocol in ["ucx", "ucxx"]:
          print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
          print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
          print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
dask_cuda/benchmarks/utils.py CHANGED
@@ -73,7 +73,7 @@ def parse_benchmark_args(description="Generic dask-cuda Benchmark", args_list=[]
      cluster_args.add_argument(
          "-p",
          "--protocol",
-         choices=["tcp", "ucx"],
+         choices=["tcp", "ucx", "ucxx"],
          default="tcp",
          type=str,
          help="The communication protocol to use.",
dask_cuda/cuda_worker.py CHANGED
@@ -20,11 +20,9 @@ from distributed.worker_memory import parse_memory_limit

  from .device_host_file import DeviceHostFile
  from .initialize import initialize
+ from .plugins import CPUAffinity, PreImport, RMMSetup
  from .proxify_host_file import ProxifyHostFile
  from .utils import (
-     CPUAffinity,
-     PreImport,
-     RMMSetup,
      cuda_visible_devices,
      get_cpu_affinity,
      get_n_gpus,
dask_cuda/device_host_file.py CHANGED
@@ -17,7 +17,7 @@ from distributed.protocol import (
      serialize_bytelist,
  )
  from distributed.sizeof import safe_sizeof
- from distributed.spill import CustomFile as KeyAsStringFile
+ from distributed.spill import AnyKeyFile as KeyAsStringFile
  from distributed.utils import nbytes

  from .is_device_object import is_device_object
dask_cuda/initialize.py CHANGED
@@ -5,7 +5,6 @@ import click
  import numba.cuda

  import dask
- import distributed.comm.ucx
  from distributed.diagnostics.nvml import get_device_index_and_uuid, has_cuda_context

  from .utils import get_ucx_config
@@ -23,12 +22,21 @@ def _create_cuda_context_handler():
          numba.cuda.current_context()


- def _create_cuda_context():
+ def _create_cuda_context(protocol="ucx"):
+     if protocol not in ["ucx", "ucxx"]:
+         return
      try:
          # Added here to ensure the parent `LocalCUDACluster` process creates the CUDA
          # context directly from the UCX module, thus avoiding a similar warning there.
          try:
-             distributed.comm.ucx.init_once()
+             if protocol == "ucx":
+                 import distributed.comm.ucx
+
+                 distributed.comm.ucx.init_once()
+             elif protocol == "ucxx":
+                 import distributed_ucxx.ucxx
+
+                 distributed_ucxx.ucxx.init_once()
          except ModuleNotFoundError:
              # UCX initialization has to be delegated to Distributed, it will take care
              # of setting correct environment variables and importing `ucp` after that.
@@ -39,20 +47,35 @@ def _create_cuda_context():
              os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]
          )
          ctx = has_cuda_context()
-         if (
-             ctx.has_context
-             and not distributed.comm.ucx.cuda_context_created.has_context
-         ):
-             distributed.comm.ucx._warn_existing_cuda_context(ctx, os.getpid())
+         if protocol == "ucx":
+             if (
+                 ctx.has_context
+                 and not distributed.comm.ucx.cuda_context_created.has_context
+             ):
+                 distributed.comm.ucx._warn_existing_cuda_context(ctx, os.getpid())
+         elif protocol == "ucxx":
+             if (
+                 ctx.has_context
+                 and not distributed_ucxx.ucxx.cuda_context_created.has_context
+             ):
+                 distributed_ucxx.ucxx._warn_existing_cuda_context(ctx, os.getpid())

          _create_cuda_context_handler()

-         if not distributed.comm.ucx.cuda_context_created.has_context:
-             ctx = has_cuda_context()
-             if ctx.has_context and ctx.device_info != cuda_visible_device:
-                 distributed.comm.ucx._warn_cuda_context_wrong_device(
-                     cuda_visible_device, ctx.device_info, os.getpid()
-                 )
+         if protocol == "ucx":
+             if not distributed.comm.ucx.cuda_context_created.has_context:
+                 ctx = has_cuda_context()
+                 if ctx.has_context and ctx.device_info != cuda_visible_device:
+                     distributed.comm.ucx._warn_cuda_context_wrong_device(
+                         cuda_visible_device, ctx.device_info, os.getpid()
+                     )
+         elif protocol == "ucxx":
+             if not distributed_ucxx.ucxx.cuda_context_created.has_context:
+                 ctx = has_cuda_context()
+                 if ctx.has_context and ctx.device_info != cuda_visible_device:
+                     distributed_ucxx.ucxx._warn_cuda_context_wrong_device(
+                         cuda_visible_device, ctx.device_info, os.getpid()
+                     )

      except Exception:
          logger.error("Unable to start CUDA Context", exc_info=True)
@@ -64,6 +87,7 @@ def initialize(
      enable_infiniband=None,
      enable_nvlink=None,
      enable_rdmacm=None,
+     protocol="ucx",
  ):
      """Create CUDA context and initialize UCX-Py, depending on user parameters.

@@ -118,7 +142,7 @@ def initialize(
      dask.config.set({"distributed.comm.ucx": ucx_config})

      if create_cuda_context:
-         _create_cuda_context()
+         _create_cuda_context(protocol=protocol)


  @click.command()
@@ -127,6 +151,12 @@ def initialize(
      default=False,
      help="Create CUDA context",
  )
+ @click.option(
+     "--protocol",
+     default=None,
+     type=str,
+     help="Communication protocol, such as: 'tcp', 'tls', 'ucx' or 'ucxx'.",
+ )
  @click.option(
      "--enable-tcp-over-ucx/--disable-tcp-over-ucx",
      default=False,
@@ -150,10 +180,11 @@ def initialize(
  def dask_setup(
      service,
      create_cuda_context,
+     protocol,
      enable_tcp_over_ucx,
      enable_infiniband,
      enable_nvlink,
      enable_rdmacm,
  ):
      if create_cuda_context:
-         _create_cuda_context()
+         _create_cuda_context(protocol=protocol)
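
Note: client scripts can now pass the protocol through `initialize()`, so the parent process creates its CUDA context from the matching communication module (`distributed.comm.ucx` or `distributed_ucxx.ucxx`). A minimal sketch, assuming the `distributed-ucxx` package is installed when requesting "ucxx":

    from dask_cuda.initialize import initialize

    # `protocol` is the new keyword; it defaults to "ucx" as before.
    initialize(
        protocol="ucxx",
        enable_tcp_over_ucx=True,
    )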
dask_cuda/local_cuda_cluster.py CHANGED
@@ -2,6 +2,7 @@ import copy
  import logging
  import os
  import warnings
+ from functools import partial

  import dask
  from distributed import LocalCluster, Nanny, Worker
@@ -9,11 +10,9 @@ from distributed.worker_memory import parse_memory_limit

  from .device_host_file import DeviceHostFile
  from .initialize import initialize
+ from .plugins import CPUAffinity, PreImport, RMMSetup
  from .proxify_host_file import ProxifyHostFile
  from .utils import (
-     CPUAffinity,
-     PreImport,
-     RMMSetup,
      cuda_visible_devices,
      get_cpu_affinity,
      get_ucx_config,
@@ -320,8 +319,11 @@ class LocalCUDACluster(LocalCluster):
          if enable_tcp_over_ucx or enable_infiniband or enable_nvlink:
              if protocol is None:
                  protocol = "ucx"
-             elif protocol != "ucx":
-                 raise TypeError("Enabling InfiniBand or NVLink requires protocol='ucx'")
+             elif protocol not in ["ucx", "ucxx"]:
+                 raise TypeError(
+                     "Enabling InfiniBand or NVLink requires protocol='ucx' or "
+                     "protocol='ucxx'"
+                 )

          self.host = kwargs.get("host", None)

@@ -334,12 +336,16 @@ class LocalCUDACluster(LocalCluster):
          )

          if worker_class is not None:
-             from functools import partial
-
-             worker_class = partial(
-                 LoggedNanny if log_spilling is True else Nanny,
-                 worker_class=worker_class,
-             )
+             if log_spilling is True:
+                 raise ValueError(
+                     "Cannot enable `log_spilling` when `worker_class` is specified. If "
+                     "logging is needed, ensure `worker_class` is a subclass of "
+                     "`distributed.local_cuda_cluster.LoggedNanny` or a subclass of "
+                     "`distributed.local_cuda_cluster.LoggedWorker`, and specify "
+                     "`log_spilling=False`."
+                 )
+             if not issubclass(worker_class, Nanny):
+                 worker_class = partial(Nanny, worker_class=worker_class)

          self.pre_import = pre_import

@@ -368,7 +374,7 @@ class LocalCUDACluster(LocalCluster):
          ) + ["dask_cuda.initialize"]
          self.new_spec["options"]["preload_argv"] = self.new_spec["options"].get(
              "preload_argv", []
-         ) + ["--create-cuda-context"]
+         ) + ["--create-cuda-context", "--protocol", protocol]

          self.cuda_visible_devices = CUDA_VISIBLE_DEVICES
          self.scale(n_workers)
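
Note: with the relaxed check, the UCX transport flags may now be combined with either protocol="ucx" or protocol="ucxx". A minimal usage sketch (it mirrors `_test_tcp_over_ucx` in the test diff below):

    from distributed import Client

    from dask_cuda import LocalCUDACluster

    if __name__ == "__main__":
        # "ucxx" is now accepted wherever "ucx" was required before.
        with LocalCUDACluster(protocol="ucxx", enable_tcp_over_ucx=True) as cluster:
            with Client(cluster) as client:
                print(client.run(lambda: "ok"))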
dask_cuda/plugins.py ADDED
@@ -0,0 +1,122 @@
+ import importlib
+ import os
+
+ from distributed import WorkerPlugin
+
+ from .utils import get_rmm_log_file_name, parse_device_memory_limit
+
+
+ class CPUAffinity(WorkerPlugin):
+     def __init__(self, cores):
+         self.cores = cores
+
+     def setup(self, worker=None):
+         os.sched_setaffinity(0, self.cores)
+
+
+ class RMMSetup(WorkerPlugin):
+     def __init__(
+         self,
+         initial_pool_size,
+         maximum_pool_size,
+         managed_memory,
+         async_alloc,
+         release_threshold,
+         log_directory,
+         track_allocations,
+     ):
+         if initial_pool_size is None and maximum_pool_size is not None:
+             raise ValueError(
+                 "`rmm_maximum_pool_size` was specified without specifying "
+                 "`rmm_pool_size`.`rmm_pool_size` must be specified to use RMM pool."
+             )
+         if async_alloc is True:
+             if managed_memory is True:
+                 raise ValueError(
+                     "`rmm_managed_memory` is incompatible with the `rmm_async`."
+                 )
+         if async_alloc is False and release_threshold is not None:
+             raise ValueError("`rmm_release_threshold` requires `rmm_async`.")
+
+         self.initial_pool_size = initial_pool_size
+         self.maximum_pool_size = maximum_pool_size
+         self.managed_memory = managed_memory
+         self.async_alloc = async_alloc
+         self.release_threshold = release_threshold
+         self.logging = log_directory is not None
+         self.log_directory = log_directory
+         self.rmm_track_allocations = track_allocations
+
+     def setup(self, worker=None):
+         if self.initial_pool_size is not None:
+             self.initial_pool_size = parse_device_memory_limit(
+                 self.initial_pool_size, alignment_size=256
+             )
+
+         if self.async_alloc:
+             import rmm
+
+             if self.release_threshold is not None:
+                 self.release_threshold = parse_device_memory_limit(
+                     self.release_threshold, alignment_size=256
+                 )
+
+             mr = rmm.mr.CudaAsyncMemoryResource(
+                 initial_pool_size=self.initial_pool_size,
+                 release_threshold=self.release_threshold,
+             )
+
+             if self.maximum_pool_size is not None:
+                 self.maximum_pool_size = parse_device_memory_limit(
+                     self.maximum_pool_size, alignment_size=256
+                 )
+                 mr = rmm.mr.LimitingResourceAdaptor(
+                     mr, allocation_limit=self.maximum_pool_size
+                 )
+
+             rmm.mr.set_current_device_resource(mr)
+             if self.logging:
+                 rmm.enable_logging(
+                     log_file_name=get_rmm_log_file_name(
+                         worker, self.logging, self.log_directory
+                     )
+                 )
+         elif self.initial_pool_size is not None or self.managed_memory:
+             import rmm
+
+             pool_allocator = False if self.initial_pool_size is None else True
+
+             if self.initial_pool_size is not None:
+                 if self.maximum_pool_size is not None:
+                     self.maximum_pool_size = parse_device_memory_limit(
+                         self.maximum_pool_size, alignment_size=256
+                     )
+
+             rmm.reinitialize(
+                 pool_allocator=pool_allocator,
+                 managed_memory=self.managed_memory,
+                 initial_pool_size=self.initial_pool_size,
+                 maximum_pool_size=self.maximum_pool_size,
+                 logging=self.logging,
+                 log_file_name=get_rmm_log_file_name(
+                     worker, self.logging, self.log_directory
+                 ),
+             )
+         if self.rmm_track_allocations:
+             import rmm
+
+             mr = rmm.mr.get_current_device_resource()
+             rmm.mr.set_current_device_resource(rmm.mr.TrackingResourceAdaptor(mr))
+
+
+ class PreImport(WorkerPlugin):
+     def __init__(self, libraries):
+         if libraries is None:
+             libraries = []
+         elif isinstance(libraries, str):
+             libraries = libraries.split(",")
+         self.libraries = libraries
+
+     def setup(self, worker=None):
+         for l in self.libraries:
+             importlib.import_module(l)
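
Note: these three classes previously lived in `dask_cuda.utils` (see the `cuda_worker.py` and `local_cuda_cluster.py` diffs above); dask-cuda attaches them to workers itself, but since they are plain `WorkerPlugin`s they can also be registered by hand. A hedged sketch using `PreImport`:

    from distributed import Client

    from dask_cuda import LocalCUDACluster
    from dask_cuda.plugins import PreImport

    if __name__ == "__main__":
        with LocalCUDACluster() as cluster, Client(cluster) as client:
            # Imports the listed libraries on every worker at setup time; a
            # comma-separated string is split, as in PreImport.__init__ above.
            client.register_worker_plugin(PreImport("numpy,pandas"))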
dask_cuda/tests/test_dask_cuda_worker.py CHANGED
@@ -40,7 +40,7 @@ def test_cuda_visible_devices_and_memory_limit_and_nthreads(loop): # noqa: F811
              str(nthreads),
              "--no-dashboard",
              "--worker-class",
-             "dask_cuda.utils.MockWorker",
+             "dask_cuda.utils_test.MockWorker",
          ]
      ):
          with Client("127.0.0.1:9359", loop=loop) as client:
@@ -329,7 +329,7 @@ def test_cuda_mig_visible_devices_and_memory_limit_and_nthreads(loop): # noqa:
              str(nthreads),
              "--no-dashboard",
              "--worker-class",
-             "dask_cuda.utils.MockWorker",
+             "dask_cuda.utils_test.MockWorker",
          ]
      ):
          with Client("127.0.0.1:9359", loop=loop) as client:
@@ -364,7 +364,7 @@ def test_cuda_visible_devices_uuid(loop): # noqa: F811
              "127.0.0.1",
              "--no-dashboard",
              "--worker-class",
-             "dask_cuda.utils.MockWorker",
+             "dask_cuda.utils_test.MockWorker",
          ]
      ):
          with Client("127.0.0.1:9359", loop=loop) as client:
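
Note: `MockWorker` moved to the new test-only module `dask_cuda.utils_test` as part of slimming down `dask_cuda.utils` (compare `utils.py +6 -140` and `utils_test.py +45 -0` in the file list), so callers update the import path:

    from dask_cuda.utils_test import MockWorker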
dask_cuda/tests/test_dgx.py CHANGED
@@ -73,10 +73,13 @@ def test_default():
      assert not p.exitcode


- def _test_tcp_over_ucx():
-     ucp = pytest.importorskip("ucp")
+ def _test_tcp_over_ucx(protocol):
+     if protocol == "ucx":
+         ucp = pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         ucp = pytest.importorskip("ucxx")

-     with LocalCUDACluster(enable_tcp_over_ucx=True) as cluster:
+     with LocalCUDACluster(protocol=protocol, enable_tcp_over_ucx=True) as cluster:
          with Client(cluster) as client:
              res = da.from_array(numpy.arange(10000), chunks=(1000,))
              res = res.sum().compute()
@@ -93,10 +96,17 @@ def _test_tcp_over_ucx():
      assert all(client.run(check_ucx_options).values())


- def test_tcp_over_ucx():
-     ucp = pytest.importorskip("ucp")  # NOQA: F841
+ @pytest.mark.parametrize(
+     "protocol",
+     ["ucx", "ucxx"],
+ )
+ def test_tcp_over_ucx(protocol):
+     if protocol == "ucx":
+         pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         pytest.importorskip("ucxx")

-     p = mp.Process(target=_test_tcp_over_ucx)
+     p = mp.Process(target=_test_tcp_over_ucx, args=(protocol,))
      p.start()
      p.join()
      assert not p.exitcode
@@ -117,9 +127,22 @@ def test_tcp_only():
      assert not p.exitcode


- def _test_ucx_infiniband_nvlink(enable_infiniband, enable_nvlink, enable_rdmacm):
+ def _test_ucx_infiniband_nvlink(
+     skip_queue, protocol, enable_infiniband, enable_nvlink, enable_rdmacm
+ ):
      cupy = pytest.importorskip("cupy")
-     ucp = pytest.importorskip("ucp")
+     if protocol == "ucx":
+         ucp = pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         ucp = pytest.importorskip("ucxx")
+
+     if enable_infiniband and not any(
+         [at.startswith("rc") for at in ucp.get_active_transports()]
+     ):
+         skip_queue.put("No support available for 'rc' transport in UCX")
+         return
+     else:
+         skip_queue.put("ok")

      if enable_infiniband is None and enable_nvlink is None and enable_rdmacm is None:
          enable_tcp_over_ucx = None
@@ -135,6 +158,7 @@ def _test_ucx_infiniband_nvlink(enable_infiniband, enable_nvlink, enable_rdmacm)
          cm_tls_priority = ["tcp"]

      initialize(
+         protocol=protocol,
          enable_tcp_over_ucx=enable_tcp_over_ucx,
          enable_infiniband=enable_infiniband,
          enable_nvlink=enable_nvlink,
@@ -142,6 +166,7 @@ def _test_ucx_infiniband_nvlink(enable_infiniband, enable_nvlink, enable_rdmacm)
      )

      with LocalCUDACluster(
+         protocol=protocol,
          interface="ib0",
          enable_tcp_over_ucx=enable_tcp_over_ucx,
          enable_infiniband=enable_infiniband,
@@ -171,6 +196,7 @@ def _test_ucx_infiniband_nvlink(enable_infiniband, enable_nvlink, enable_rdmacm)
      assert all(client.run(check_ucx_options).values())


+ @pytest.mark.parametrize("protocol", ["ucx", "ucxx"])
  @pytest.mark.parametrize(
      "params",
      [
@@ -185,16 +211,19 @@ def _test_ucx_infiniband_nvlink(enable_infiniband, enable_nvlink, enable_rdmacm)
      _get_dgx_version() == DGXVersion.DGX_A100,
      reason="Automatic InfiniBand device detection Unsupported for %s" % _get_dgx_name(),
  )
- def test_ucx_infiniband_nvlink(params):
-     ucp = pytest.importorskip("ucp")  # NOQA: F841
+ def test_ucx_infiniband_nvlink(protocol, params):
+     if protocol == "ucx":
+         pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         pytest.importorskip("ucxx")

-     if params["enable_infiniband"]:
-         if not any([at.startswith("rc") for at in ucp.get_active_transports()]):
-             pytest.skip("No support available for 'rc' transport in UCX")
+     skip_queue = mp.Queue()

      p = mp.Process(
          target=_test_ucx_infiniband_nvlink,
          args=(
+             skip_queue,
+             protocol,
              params["enable_infiniband"],
              params["enable_nvlink"],
              params["enable_rdmacm"],
@@ -203,9 +232,8 @@ def test_ucx_infiniband_nvlink(params):
      p.start()
      p.join()

-     # Starting a new cluster on the same pytest process after an rdmacm cluster
-     # has been used may cause UCX-Py to complain about being already initialized.
-     if params["enable_rdmacm"] is True:
-         ucp.reset()
+     skip_msg = skip_queue.get()
+     if skip_msg != "ok":
+         pytest.skip(skip_msg)

      assert not p.exitcode
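
Note: the `skip_queue` replaces the old in-process transport check (and the `ucp.reset()` workaround): the child process probes the environment itself and reports either "ok" or a skip reason back to the parent, which then calls `pytest.skip()`. The pattern in isolation, as a minimal sketch:

    import multiprocessing as mp

    import pytest

    def _child(skip_queue):
        # Child-side capability probe: put a reason string to request a skip.
        skip_queue.put("ok")

    def test_pattern():
        skip_queue = mp.Queue()
        p = mp.Process(target=_child, args=(skip_queue,))
        p.start()
        p.join()

        skip_msg = skip_queue.get()
        if skip_msg != "ok":
            pytest.skip(skip_msg)

        assert not p.exitcode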