dask-cuda 23.12.0a231027__py3-none-any.whl → 24.2.0a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dask_cuda/VERSION ADDED
@@ -0,0 +1 @@
+ 24.02.00a3
dask_cuda/__init__.py CHANGED
@@ -11,6 +11,7 @@ import dask.dataframe.shuffle
  import dask.dataframe.multi
  import dask.bag.core

+ from ._version import __git_commit__, __version__
  from .cuda_worker import CUDAWorker
  from .explicit_comms.dataframe.shuffle import (
      get_rearrange_by_column_wrapper,
@@ -19,9 +20,6 @@ from .explicit_comms.dataframe.shuffle import (
  from .local_cuda_cluster import LocalCUDACluster
  from .proxify_device_objects import proxify_decorator, unproxify_decorator

- __version__ = "23.12.00"
-
- from . import compat

  # Monkey patching Dask to make use of explicit-comms when `DASK_EXPLICIT_COMMS=True`
  dask.dataframe.shuffle.rearrange_by_column = get_rearrange_by_column_wrapper(
dask_cuda/_version.py ADDED
@@ -0,0 +1,20 @@
+ # Copyright (c) 2023, NVIDIA CORPORATION.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import importlib.resources
+
+ __version__ = (
+     importlib.resources.files("dask_cuda").joinpath("VERSION").read_text().strip()
+ )
+ __git_commit__ = ""
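This new module replaces the hard-coded __version__ string removed from dask_cuda/__init__.py above: the version is now read at import time from the bundled VERSION file, so builds can stamp it without editing source. A minimal usage sketch, with the version string taken from the VERSION file added above:

    import dask_cuda

    print(dask_cuda.__version__)     # "24.02.00a3", read from dask_cuda/VERSION
    print(dask_cuda.__git_commit__)  # "" in this wheel build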
@@ -139,7 +139,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
          key="Device memory limit", value=f"{format_bytes(args.device_memory_limit)}"
      )
      print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
-     if args.protocol == "ucx":
+     if args.protocol in ["ucx", "ucxx"]:
          print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
          print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
          print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
@@ -217,7 +217,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
      )
      print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
      print_key_value(key="Frac-match", value=f"{args.frac_match}")
-     if args.protocol == "ucx":
+     if args.protocol in ["ucx", "ucxx"]:
          print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
          print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
          print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
@@ -146,7 +146,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
          key="Device memory limit", value=f"{format_bytes(args.device_memory_limit)}"
      )
      print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
-     if args.protocol == "ucx":
+     if args.protocol in ["ucx", "ucxx"]:
          print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
          print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
          print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
@@ -193,7 +193,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
      )
      print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
      print_key_value(key="Protocol", value=f"{args.protocol}")
-     if args.protocol == "ucx":
+     if args.protocol in ["ucx", "ucxx"]:
          print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
          print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
          print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
@@ -78,7 +78,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
      )
      print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
      print_key_value(key="Protocol", value=f"{args.protocol}")
-     if args.protocol == "ucx":
+     if args.protocol in ["ucx", "ucxx"]:
          print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
          print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
          print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
@@ -73,7 +73,7 @@ def parse_benchmark_args(description="Generic dask-cuda Benchmark", args_list=[]
      cluster_args.add_argument(
          "-p",
          "--protocol",
-         choices=["tcp", "ucx"],
+         choices=["tcp", "ucx", "ucxx"],
          default="tcp",
          type=str,
          help="The communication protocol to use.",
@@ -17,7 +17,7 @@ from distributed.protocol import (
      serialize_bytelist,
  )
  from distributed.sizeof import safe_sizeof
- from distributed.spill import CustomFile as KeyAsStringFile
+ from distributed.spill import AnyKeyFile as KeyAsStringFile
  from distributed.utils import nbytes

  from .is_device_object import is_device_object
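This tracks a rename in distributed: the spill file class that dask-cuda aliases as KeyAsStringFile is now AnyKeyFile rather than CustomFile. dask-cuda switches unconditionally because the wheel now pulls a compatible distributed via rapids-dask-dependency (see the METADATA diff below); a code base that had to span both releases could instead hedge the import, as in this illustrative sketch (not what device_host_file.py does):

    try:
        from distributed.spill import AnyKeyFile as KeyAsStringFile  # newer distributed
    except ImportError:
        from distributed.spill import CustomFile as KeyAsStringFile  # older distributed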
dask_cuda/initialize.py CHANGED
@@ -5,7 +5,6 @@ import click
  import numba.cuda

  import dask
- import distributed.comm.ucx
  from distributed.diagnostics.nvml import get_device_index_and_uuid, has_cuda_context

  from .utils import get_ucx_config
@@ -23,12 +22,21 @@ def _create_cuda_context_handler():
          numba.cuda.current_context()


- def _create_cuda_context():
+ def _create_cuda_context(protocol="ucx"):
+     if protocol not in ["ucx", "ucxx"]:
+         return
      try:
          # Added here to ensure the parent `LocalCUDACluster` process creates the CUDA
          # context directly from the UCX module, thus avoiding a similar warning there.
          try:
-             distributed.comm.ucx.init_once()
+             if protocol == "ucx":
+                 import distributed.comm.ucx
+
+                 distributed.comm.ucx.init_once()
+             elif protocol == "ucxx":
+                 import distributed_ucxx.ucxx
+
+                 distributed_ucxx.ucxx.init_once()
          except ModuleNotFoundError:
              # UCX initialization has to be delegated to Distributed, it will take care
              # of setting correct environment variables and importing `ucp` after that.
@@ -39,20 +47,35 @@ def _create_cuda_context():
              os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]
          )
          ctx = has_cuda_context()
-         if (
-             ctx.has_context
-             and not distributed.comm.ucx.cuda_context_created.has_context
-         ):
-             distributed.comm.ucx._warn_existing_cuda_context(ctx, os.getpid())
+         if protocol == "ucx":
+             if (
+                 ctx.has_context
+                 and not distributed.comm.ucx.cuda_context_created.has_context
+             ):
+                 distributed.comm.ucx._warn_existing_cuda_context(ctx, os.getpid())
+         elif protocol == "ucxx":
+             if (
+                 ctx.has_context
+                 and not distributed_ucxx.ucxx.cuda_context_created.has_context
+             ):
+                 distributed_ucxx.ucxx._warn_existing_cuda_context(ctx, os.getpid())

          _create_cuda_context_handler()

-         if not distributed.comm.ucx.cuda_context_created.has_context:
-             ctx = has_cuda_context()
-             if ctx.has_context and ctx.device_info != cuda_visible_device:
-                 distributed.comm.ucx._warn_cuda_context_wrong_device(
-                     cuda_visible_device, ctx.device_info, os.getpid()
-                 )
+         if protocol == "ucx":
+             if not distributed.comm.ucx.cuda_context_created.has_context:
+                 ctx = has_cuda_context()
+                 if ctx.has_context and ctx.device_info != cuda_visible_device:
+                     distributed.comm.ucx._warn_cuda_context_wrong_device(
+                         cuda_visible_device, ctx.device_info, os.getpid()
+                     )
+         elif protocol == "ucxx":
+             if not distributed_ucxx.ucxx.cuda_context_created.has_context:
+                 ctx = has_cuda_context()
+                 if ctx.has_context and ctx.device_info != cuda_visible_device:
+                     distributed_ucxx.ucxx._warn_cuda_context_wrong_device(
+                         cuda_visible_device, ctx.device_info, os.getpid()
+                     )

      except Exception:
          logger.error("Unable to start CUDA Context", exc_info=True)
@@ -64,6 +87,7 @@ def initialize(
      enable_infiniband=None,
      enable_nvlink=None,
      enable_rdmacm=None,
+     protocol="ucx",
  ):
      """Create CUDA context and initialize UCX-Py, depending on user parameters.

@@ -118,7 +142,7 @@ def initialize(
      dask.config.set({"distributed.comm.ucx": ucx_config})

      if create_cuda_context:
-         _create_cuda_context()
+         _create_cuda_context(protocol=protocol)


  @click.command()
@@ -127,6 +151,12 @@ def initialize(
      default=False,
      help="Create CUDA context",
  )
+ @click.option(
+     "--protocol",
+     default=None,
+     type=str,
+     help="Communication protocol, such as: 'tcp', 'tls', 'ucx' or 'ucxx'.",
+ )
  @click.option(
      "--enable-tcp-over-ucx/--disable-tcp-over-ucx",
      default=False,
@@ -150,10 +180,11 @@ def initialize(
  def dask_setup(
      service,
      create_cuda_context,
+     protocol,
      enable_tcp_over_ucx,
      enable_infiniband,
      enable_nvlink,
      enable_rdmacm,
  ):
      if create_cuda_context:
-         _create_cuda_context()
+         _create_cuda_context(protocol=protocol)
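The new --protocol preload option closes the loop for worker processes: dask_setup receives the value from preload_argv and forwards it to _create_cuda_context(). For a manually started worker this corresponds roughly to the following command line (a sketch only; LocalCUDACluster assembles the equivalent preload_argv itself, as the hunks below show):

    dask worker scheduler-address:8786 \
        --preload dask_cuda.initialize \
        --preload-argv "--create-cuda-context --protocol ucx"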
@@ -319,8 +319,11 @@ class LocalCUDACluster(LocalCluster):
          if enable_tcp_over_ucx or enable_infiniband or enable_nvlink:
              if protocol is None:
                  protocol = "ucx"
-             elif protocol != "ucx":
-                 raise TypeError("Enabling InfiniBand or NVLink requires protocol='ucx'")
+             elif protocol not in ["ucx", "ucxx"]:
+                 raise TypeError(
+                     "Enabling InfiniBand or NVLink requires protocol='ucx' or "
+                     "protocol='ucxx'"
+                 )

          self.host = kwargs.get("host", None)

@@ -371,7 +374,7 @@ class LocalCUDACluster(LocalCluster):
          ) + ["dask_cuda.initialize"]
          self.new_spec["options"]["preload_argv"] = self.new_spec["options"].get(
              "preload_argv", []
-         ) + ["--create-cuda-context"]
+         ) + ["--create-cuda-context", "--protocol", protocol]

          self.cuda_visible_devices = CUDA_VISIBLE_DEVICES
          self.scale(n_workers)
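Together these two hunks mean the UCX-specific flags accept either implementation, and the chosen protocol is propagated to every worker's preload. A short sketch of the accepted and rejected combinations (assumes NVLink-capable GPUs):

    from dask_cuda import LocalCUDACluster

    LocalCUDACluster(protocol="ucxx", enable_nvlink=True)  # accepted
    LocalCUDACluster(enable_nvlink=True)                   # protocol defaults to "ucx"
    LocalCUDACluster(protocol="tcp", enable_nvlink=True)   # raises TypeError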
@@ -73,10 +73,13 @@ def test_default():
      assert not p.exitcode


- def _test_tcp_over_ucx():
-     ucp = pytest.importorskip("ucp")
+ def _test_tcp_over_ucx(protocol):
+     if protocol == "ucx":
+         ucp = pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         ucp = pytest.importorskip("ucxx")

-     with LocalCUDACluster(enable_tcp_over_ucx=True) as cluster:
+     with LocalCUDACluster(protocol=protocol, enable_tcp_over_ucx=True) as cluster:
          with Client(cluster) as client:
              res = da.from_array(numpy.arange(10000), chunks=(1000,))
              res = res.sum().compute()
@@ -93,10 +96,17 @@ def _test_tcp_over_ucx():
              assert all(client.run(check_ucx_options).values())


- def test_tcp_over_ucx():
-     ucp = pytest.importorskip("ucp")  # NOQA: F841
+ @pytest.mark.parametrize(
+     "protocol",
+     ["ucx", "ucxx"],
+ )
+ def test_tcp_over_ucx(protocol):
+     if protocol == "ucx":
+         pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         pytest.importorskip("ucxx")

-     p = mp.Process(target=_test_tcp_over_ucx)
+     p = mp.Process(target=_test_tcp_over_ucx, args=(protocol,))
      p.start()
      p.join()
      assert not p.exitcode
@@ -117,9 +127,22 @@ def test_tcp_only():
      assert not p.exitcode


- def _test_ucx_infiniband_nvlink(enable_infiniband, enable_nvlink, enable_rdmacm):
+ def _test_ucx_infiniband_nvlink(
+     skip_queue, protocol, enable_infiniband, enable_nvlink, enable_rdmacm
+ ):
      cupy = pytest.importorskip("cupy")
-     ucp = pytest.importorskip("ucp")
+     if protocol == "ucx":
+         ucp = pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         ucp = pytest.importorskip("ucxx")
+
+     if enable_infiniband and not any(
+         [at.startswith("rc") for at in ucp.get_active_transports()]
+     ):
+         skip_queue.put("No support available for 'rc' transport in UCX")
+         return
+     else:
+         skip_queue.put("ok")

      if enable_infiniband is None and enable_nvlink is None and enable_rdmacm is None:
          enable_tcp_over_ucx = None
@@ -135,6 +158,7 @@ def _test_ucx_infiniband_nvlink(enable_infiniband, enable_nvlink, enable_rdmacm)
          cm_tls_priority = ["tcp"]

      initialize(
+         protocol=protocol,
          enable_tcp_over_ucx=enable_tcp_over_ucx,
          enable_infiniband=enable_infiniband,
          enable_nvlink=enable_nvlink,
@@ -142,6 +166,7 @@ def _test_ucx_infiniband_nvlink(enable_infiniband, enable_nvlink, enable_rdmacm)
      )

      with LocalCUDACluster(
+         protocol=protocol,
          interface="ib0",
          enable_tcp_over_ucx=enable_tcp_over_ucx,
          enable_infiniband=enable_infiniband,
@@ -171,6 +196,7 @@ def _test_ucx_infiniband_nvlink(enable_infiniband, enable_nvlink, enable_rdmacm)
          assert all(client.run(check_ucx_options).values())


+ @pytest.mark.parametrize("protocol", ["ucx", "ucxx"])
  @pytest.mark.parametrize(
      "params",
      [
@@ -185,16 +211,19 @@ def _test_ucx_infiniband_nvlink(enable_infiniband, enable_nvlink, enable_rdmacm)
      _get_dgx_version() == DGXVersion.DGX_A100,
      reason="Automatic InfiniBand device detection Unsupported for %s" % _get_dgx_name(),
  )
- def test_ucx_infiniband_nvlink(params):
-     ucp = pytest.importorskip("ucp")  # NOQA: F841
+ def test_ucx_infiniband_nvlink(protocol, params):
+     if protocol == "ucx":
+         pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         pytest.importorskip("ucxx")

-     if params["enable_infiniband"]:
-         if not any([at.startswith("rc") for at in ucp.get_active_transports()]):
-             pytest.skip("No support available for 'rc' transport in UCX")
+     skip_queue = mp.Queue()

      p = mp.Process(
          target=_test_ucx_infiniband_nvlink,
          args=(
+             skip_queue,
+             protocol,
              params["enable_infiniband"],
              params["enable_nvlink"],
              params["enable_rdmacm"],
@@ -203,9 +232,8 @@ def test_ucx_infiniband_nvlink(params):
      p.start()
      p.join()

-     # Starting a new cluster on the same pytest process after an rdmacm cluster
-     # has been used may cause UCX-Py to complain about being already initialized.
-     if params["enable_rdmacm"] is True:
-         ucp.reset()
+     skip_msg = skip_queue.get()
+     if skip_msg != "ok":
+         pytest.skip(skip_msg)

      assert not p.exitcode
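The transport check now runs inside the spawned process, where UCX is actually initialized, so the skip decision must be shipped back to pytest through a multiprocessing queue. The pattern reduced to its essentials (the _worker/test_driver names are illustrative only):

    import multiprocessing as mp

    import pytest

    def _worker(skip_queue):
        # Runs in the child process: report "ok" or the reason to skip.
        skip_queue.put("ok")

    def test_driver():
        ctx = mp.get_context("spawn")
        skip_queue = ctx.Queue()
        p = ctx.Process(target=_worker, args=(skip_queue,))
        p.start()
        p.join()
        skip_msg = skip_queue.get()
        if skip_msg != "ok":
            pytest.skip(skip_msg)
        assert not p.exitcode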
@@ -44,7 +44,7 @@ def _test_local_cluster(protocol):
      assert sum(c.run(my_rank, 0)) == sum(range(4))


- @pytest.mark.parametrize("protocol", ["tcp", "ucx"])
+ @pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
  def test_local_cluster(protocol):
      p = mp.Process(target=_test_local_cluster, args=(protocol,))
      p.start()
@@ -160,7 +160,7 @@ def _test_dataframe_shuffle(backend, protocol, n_workers, _partitions):

  @pytest.mark.parametrize("nworkers", [1, 2, 3])
  @pytest.mark.parametrize("backend", ["pandas", "cudf"])
- @pytest.mark.parametrize("protocol", ["tcp", "ucx"])
+ @pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
  @pytest.mark.parametrize("_partitions", [True, False])
  def test_dataframe_shuffle(backend, protocol, nworkers, _partitions):
      if backend == "cudf":
@@ -256,7 +256,7 @@ def _test_dataframe_shuffle_merge(backend, protocol, n_workers):

  @pytest.mark.parametrize("nworkers", [1, 2, 4])
  @pytest.mark.parametrize("backend", ["pandas", "cudf"])
- @pytest.mark.parametrize("protocol", ["tcp", "ucx"])
+ @pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
  def test_dataframe_shuffle_merge(backend, protocol, nworkers):
      if backend == "cudf":
          pytest.importorskip("cudf")
@@ -293,7 +293,7 @@ def _test_jit_unspill(protocol):
      assert_eq(got, expected)


- @pytest.mark.parametrize("protocol", ["tcp", "ucx"])
+ @pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
  def test_jit_unspill(protocol):
      pytest.importorskip("cudf")

@@ -5,12 +5,16 @@ from distributed import Client

  from dask_cuda import LocalCUDACluster

- pytest.importorskip("ucp")
  cupy = pytest.importorskip("cupy")


- @pytest.mark.parametrize("protocol", ["ucx", "tcp"])
+ @pytest.mark.parametrize("protocol", ["ucx", "ucxx", "tcp"])
  def test_ucx_from_array(protocol):
+     if protocol == "ucx":
+         pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         pytest.importorskip("ucxx")
+
      N = 10_000
      with LocalCUDACluster(protocol=protocol) as cluster:
          with Client(cluster):
@@ -13,7 +13,6 @@ from dask_cuda.utils import get_ucx_config
  from dask_cuda.utils_test import IncreasedCloseTimeoutNanny

  mp = mp.get_context("spawn")  # type: ignore
- ucp = pytest.importorskip("ucp")

  # Notice, all of the following tests is executed in a new process such
  # that UCX options of the different tests doesn't conflict.
@@ -21,11 +20,16 @@ ucp = pytest.importorskip("ucp")
  # of UCX before retrieving the current config.


- def _test_initialize_ucx_tcp():
+ def _test_initialize_ucx_tcp(protocol):
+     if protocol == "ucx":
+         ucp = pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         ucp = pytest.importorskip("ucxx")
+
      kwargs = {"enable_tcp_over_ucx": True}
-     initialize(**kwargs)
+     initialize(protocol=protocol, **kwargs)
      with LocalCluster(
-         protocol="ucx",
+         protocol=protocol,
          dashboard_address=None,
          n_workers=1,
          threads_per_worker=1,
@@ -50,18 +54,29 @@ def _test_initialize_ucx_tcp():
              assert all(client.run(check_ucx_options).values())


- def test_initialize_ucx_tcp():
-     p = mp.Process(target=_test_initialize_ucx_tcp)
+ @pytest.mark.parametrize("protocol", ["ucx", "ucxx"])
+ def test_initialize_ucx_tcp(protocol):
+     if protocol == "ucx":
+         pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         pytest.importorskip("ucxx")
+
+     p = mp.Process(target=_test_initialize_ucx_tcp, args=(protocol,))
      p.start()
      p.join()
      assert not p.exitcode


- def _test_initialize_ucx_nvlink():
+ def _test_initialize_ucx_nvlink(protocol):
+     if protocol == "ucx":
+         ucp = pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         ucp = pytest.importorskip("ucxx")
+
      kwargs = {"enable_nvlink": True}
-     initialize(**kwargs)
+     initialize(protocol=protocol, **kwargs)
      with LocalCluster(
-         protocol="ucx",
+         protocol=protocol,
          dashboard_address=None,
          n_workers=1,
          threads_per_worker=1,
@@ -87,18 +102,29 @@ def _test_initialize_ucx_nvlink():
              assert all(client.run(check_ucx_options).values())


- def test_initialize_ucx_nvlink():
-     p = mp.Process(target=_test_initialize_ucx_nvlink)
+ @pytest.mark.parametrize("protocol", ["ucx", "ucxx"])
+ def test_initialize_ucx_nvlink(protocol):
+     if protocol == "ucx":
+         pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         pytest.importorskip("ucxx")
+
+     p = mp.Process(target=_test_initialize_ucx_nvlink, args=(protocol,))
      p.start()
      p.join()
      assert not p.exitcode


- def _test_initialize_ucx_infiniband():
+ def _test_initialize_ucx_infiniband(protocol):
+     if protocol == "ucx":
+         ucp = pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         ucp = pytest.importorskip("ucxx")
+
      kwargs = {"enable_infiniband": True}
-     initialize(**kwargs)
+     initialize(protocol=protocol, **kwargs)
      with LocalCluster(
-         protocol="ucx",
+         protocol=protocol,
          dashboard_address=None,
          n_workers=1,
          threads_per_worker=1,
@@ -127,17 +153,28 @@
  @pytest.mark.skipif(
      "ib0" not in psutil.net_if_addrs(), reason="Infiniband interface ib0 not found"
  )
- def test_initialize_ucx_infiniband():
-     p = mp.Process(target=_test_initialize_ucx_infiniband)
+ @pytest.mark.parametrize("protocol", ["ucx", "ucxx"])
+ def test_initialize_ucx_infiniband(protocol):
+     if protocol == "ucx":
+         pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         pytest.importorskip("ucxx")
+
+     p = mp.Process(target=_test_initialize_ucx_infiniband, args=(protocol,))
      p.start()
      p.join()
      assert not p.exitcode


- def _test_initialize_ucx_all():
-     initialize()
+ def _test_initialize_ucx_all(protocol):
+     if protocol == "ucx":
+         ucp = pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         ucp = pytest.importorskip("ucxx")
+
+     initialize(protocol=protocol)
      with LocalCluster(
-         protocol="ucx",
+         protocol=protocol,
          dashboard_address=None,
          n_workers=1,
          threads_per_worker=1,
@@ -166,8 +203,14 @@ def _test_initialize_ucx_all():
              assert all(client.run(check_ucx_options).values())


- def test_initialize_ucx_all():
-     p = mp.Process(target=_test_initialize_ucx_all)
+ @pytest.mark.parametrize("protocol", ["ucx", "ucxx"])
+ def test_initialize_ucx_all(protocol):
+     if protocol == "ucx":
+         pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         pytest.importorskip("ucxx")
+
+     p = mp.Process(target=_test_initialize_ucx_all, args=(protocol,))
      p.start()
      p.join()
      assert not p.exitcode
@@ -87,23 +87,38 @@ async def test_with_subset_of_cuda_visible_devices():
      }


+ @pytest.mark.parametrize(
+     "protocol",
+     ["ucx", "ucxx"],
+ )
  @gen_test(timeout=20)
- async def test_ucx_protocol():
-     pytest.importorskip("ucp")
+ async def test_ucx_protocol(protocol):
+     if protocol == "ucx":
+         pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         pytest.importorskip("ucxx")

      async with LocalCUDACluster(
-         protocol="ucx", asynchronous=True, data=dict
+         protocol=protocol, asynchronous=True, data=dict
      ) as cluster:
          assert all(
-             ws.address.startswith("ucx://") for ws in cluster.scheduler.workers.values()
+             ws.address.startswith(f"{protocol}://")
+             for ws in cluster.scheduler.workers.values()
          )


+ @pytest.mark.parametrize(
+     "protocol",
+     ["ucx", "ucxx"],
+ )
  @gen_test(timeout=20)
- async def test_explicit_ucx_with_protocol_none():
-     pytest.importorskip("ucp")
+ async def test_explicit_ucx_with_protocol_none(protocol):
+     if protocol == "ucx":
+         pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         pytest.importorskip("ucxx")

-     initialize(enable_tcp_over_ucx=True)
+     initialize(protocol=protocol, enable_tcp_over_ucx=True)
      async with LocalCUDACluster(
          protocol=None, enable_tcp_over_ucx=True, asynchronous=True, data=dict
      ) as cluster:
@@ -113,11 +128,18 @@ async def test_explicit_ucx_with_protocol_none():


  @pytest.mark.filterwarnings("ignore:Exception ignored in")
+ @pytest.mark.parametrize(
+     "protocol",
+     ["ucx", "ucxx"],
+ )
  @gen_test(timeout=20)
- async def test_ucx_protocol_type_error():
-     pytest.importorskip("ucp")
+ async def test_ucx_protocol_type_error(protocol):
+     if protocol == "ucx":
+         pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         pytest.importorskip("ucxx")

-     initialize(enable_tcp_over_ucx=True)
+     initialize(protocol=protocol, enable_tcp_over_ucx=True)
      with pytest.raises(TypeError):
          async with LocalCUDACluster(
              protocol="tcp", enable_tcp_over_ucx=True, asynchronous=True, data=dict
@@ -337,6 +359,7 @@ async def test_pre_import():


  # Intentionally not using @gen_test to skip cleanup checks
+ @pytest.mark.xfail(reason="https://github.com/rapidsai/dask-cuda/issues/1265")
  def test_pre_import_not_found():
      async def _test_pre_import_not_found():
          with raises_with_cause(RuntimeError, None, ImportError, None):
@@ -477,20 +500,30 @@ async def test_worker_fraction_limits():
      )


- def test_print_cluster_config(capsys):
+ @pytest.mark.parametrize(
+     "protocol",
+     ["ucx", "ucxx"],
+ )
+ def test_print_cluster_config(capsys, protocol):
+     if protocol == "ucx":
+         pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         pytest.importorskip("ucxx")
+
      pytest.importorskip("rich")
      with LocalCUDACluster(
-         n_workers=1, device_memory_limit="1B", jit_unspill=True, protocol="ucx"
+         n_workers=1, device_memory_limit="1B", jit_unspill=True, protocol=protocol
      ) as cluster:
          with Client(cluster) as client:
              print_cluster_config(client)
      captured = capsys.readouterr()
      assert "Dask Cluster Configuration" in captured.out
-     assert "ucx" in captured.out
+     assert protocol in captured.out
      assert "1 B" in captured.out
      assert "[plugin]" in captured.out


+ @pytest.mark.xfail(reason="https://github.com/rapidsai/dask-cuda/issues/1265")
  def test_death_timeout_raises():
      with pytest.raises(asyncio.exceptions.TimeoutError):
          with LocalCUDACluster(
@@ -400,10 +400,14 @@ class _PxyObjTest(proxy_object.ProxyObject):


  @pytest.mark.parametrize("send_serializers", [None, ("dask", "pickle"), ("cuda",)])
- @pytest.mark.parametrize("protocol", ["tcp", "ucx"])
+ @pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
  @gen_test(timeout=120)
  async def test_communicating_proxy_objects(protocol, send_serializers):
      """Testing serialization of cuDF dataframe when communicating"""
+     if protocol == "ucx":
+         pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         pytest.importorskip("ucxx")
      cudf = pytest.importorskip("cudf")

      def task(x):
@@ -412,7 +416,7 @@ async def test_communicating_proxy_objects(protocol, send_serializers):
          serializers_used = x._pxy_get().serializer

          # Check that `x` is serialized with the expected serializers
-         if protocol == "ucx":
+         if protocol in ["ucx", "ucxx"]:
              if send_serializers is None:
                  assert serializers_used == "cuda"
              else:
@@ -443,11 +447,15 @@ async def test_communicating_proxy_objects(protocol, send_serializers):
          await client.submit(task, df)


- @pytest.mark.parametrize("protocol", ["tcp", "ucx"])
+ @pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
  @pytest.mark.parametrize("shared_fs", [True, False])
  @gen_test(timeout=20)
  async def test_communicating_disk_objects(protocol, shared_fs):
      """Testing disk serialization of cuDF dataframe when communicating"""
+     if protocol == "ucx":
+         pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         pytest.importorskip("ucxx")
      cudf = pytest.importorskip("cudf")
      ProxifyHostFile._spill_to_disk.shared_filesystem = shared_fs

@@ -79,11 +79,18 @@ def test_get_device_total_memory():
      assert total_mem > 0


- def test_get_preload_options_default():
-     pytest.importorskip("ucp")
+ @pytest.mark.parametrize(
+     "protocol",
+     ["ucx", "ucxx"],
+ )
+ def test_get_preload_options_default(protocol):
+     if protocol == "ucx":
+         pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         pytest.importorskip("ucxx")

      opts = get_preload_options(
-         protocol="ucx",
+         protocol=protocol,
          create_cuda_context=True,
      )

@@ -93,14 +100,21 @@ def test_get_preload_options_default():
      assert opts["preload_argv"] == ["--create-cuda-context"]


+ @pytest.mark.parametrize(
+     "protocol",
+     ["ucx", "ucxx"],
+ )
  @pytest.mark.parametrize("enable_tcp", [True, False])
  @pytest.mark.parametrize("enable_infiniband", [True, False])
  @pytest.mark.parametrize("enable_nvlink", [True, False])
- def test_get_preload_options(enable_tcp, enable_infiniband, enable_nvlink):
-     pytest.importorskip("ucp")
+ def test_get_preload_options(protocol, enable_tcp, enable_infiniband, enable_nvlink):
+     if protocol == "ucx":
+         pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         pytest.importorskip("ucxx")

      opts = get_preload_options(
-         protocol="ucx",
+         protocol=protocol,
          create_cuda_context=True,
          enable_tcp_over_ucx=enable_tcp,
          enable_infiniband=enable_infiniband,
dask_cuda/utils.py CHANGED
@@ -287,7 +287,7 @@ def get_preload_options(
      if create_cuda_context:
          preload_options["preload_argv"].append("--create-cuda-context")

-     if protocol == "ucx":
+     if protocol in ["ucx", "ucxx"]:
          initialize_ucx_argv = []
          if enable_tcp_over_ucx:
              initialize_ucx_argv.append("--enable-tcp-over-ucx")
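get_preload_options() is what produces the preload/preload_argv configuration consumed by the initialize.py changes earlier; with either UCX protocol it now appends the enable flags. A sketch consistent with the assertions visible in the test_utils.py hunks above:

    from dask_cuda.utils import get_preload_options

    opts = get_preload_options(protocol="ucxx", create_cuda_context=True)
    assert opts["preload_argv"] == ["--create-cuda-context"]

    opts = get_preload_options(
        protocol="ucxx",
        create_cuda_context=True,
        enable_tcp_over_ucx=True,
    )
    assert "--enable-tcp-over-ucx" in opts["preload_argv"]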
@@ -625,6 +625,10 @@ def get_worker_config(dask_worker):
          import ucp

          ret["ucx-transports"] = ucp.get_active_transports()
+     elif scheme == "ucxx":
+         import ucxx
+
+         ret["ucx-transports"] = ucxx.get_active_transports()

      # comm timeouts
      ret["distributed.comm.timeouts"] = dask.config.get("distributed.comm.timeouts")
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: dask-cuda
- Version: 23.12.0a231027
+ Version: 24.2.0a3
  Summary: Utilities for Dask and CUDA interactions
  Author: NVIDIA Corporation
  License: Apache-2.0
@@ -17,12 +17,11 @@ Classifier: Programming Language :: Python :: 3.10
  Requires-Python: >=3.9
  Description-Content-Type: text/markdown
  License-File: LICENSE
- Requires-Dist: dask ==2023.9.2
- Requires-Dist: distributed ==2023.9.2
  Requires-Dist: pynvml <11.5,>=11.0.0
  Requires-Dist: numpy >=1.21
  Requires-Dist: numba >=0.57
  Requires-Dist: pandas <1.6.0dev0,>=1.3
+ Requires-Dist: rapids-dask-dependency ==23.12.*,>=0.0.0a0
  Requires-Dist: zict >=2.0.0
  Provides-Extra: docs
  Requires-Dist: numpydoc >=1.1.0 ; extra == 'docs'
@@ -1,29 +1,30 @@
- dask_cuda/__init__.py,sha256=xtogPs_QSmTTMOWetj9CqLaCwdF-bANfrD75LYpulMc,1452
+ dask_cuda/VERSION,sha256=iO2KXnB96wpQpEYq6RxAonc8LlGIKB2Lot9hQmqhXUQ,11
+ dask_cuda/__init__.py,sha256=XnMTUi-SvoGn7g1Dj6XW97HnQzGQv0G3EnvSjcZ7vU4,1455
+ dask_cuda/_version.py,sha256=FgBzL-H3uFWUDb0QvqJw3AytPr1PG8LbMnHxQEX8Vx4,738
  dask_cuda/cli.py,sha256=XNRH0bu-6jzRoyWJB5qSWuzePJSh3z_5Ng6rDCnz7lg,15970
- dask_cuda/compat.py,sha256=BLXv9IHUtD3h6-T_8MX-uGt-UDMG6EuGuyN-zw3XndU,4084
  dask_cuda/cuda_worker.py,sha256=bIu-ESeIpJG_WaTYrv0z9z5juJ1qR5i_5Ng3CN1WK8s,8579
- dask_cuda/device_host_file.py,sha256=D0rHOFz1TRfvaecoP30x3JRWe1TiHUaq45Dg-v0DfoY,10272
+ dask_cuda/device_host_file.py,sha256=yS31LGtt9VFAG78uBBlTDr7HGIng2XymV1OxXIuEMtM,10272
  dask_cuda/disk_io.py,sha256=urSLKiPvJvYmKCzDPOUDCYuLI3r1RUiyVh3UZGRoF_Y,6626
  dask_cuda/get_device_memory_objects.py,sha256=zMSqWzm5rflRInbNMz7U2Ewv5nMcE-H8stMJeWHVWyc,3890
- dask_cuda/initialize.py,sha256=mzPgKhs8oLgUWpqd4ckvLNKvhLoHjt96RrBPeVneenI,5231
+ dask_cuda/initialize.py,sha256=Gjcxs_c8DTafgsHe5-2mw4lJdOmbFJJAZVOnxA8lTjM,6462
  dask_cuda/is_device_object.py,sha256=CnajvbQiX0FzFzwft0MqK1OPomx3ZGDnDxT56wNjixw,1046
  dask_cuda/is_spillable_object.py,sha256=CddGmg0tuSpXh2m_TJSY6GRpnl1WRHt1CRcdWgHPzWA,1457
- dask_cuda/local_cuda_cluster.py,sha256=w2HXMZtEukwklkB3J6l6DqZstNA5uvGEdFkdzpyUJ6k,17810
+ dask_cuda/local_cuda_cluster.py,sha256=hoEiEfJqAQrRS7N632VatSl1245GiWMT5B77Wc-i5C0,17928
  dask_cuda/plugins.py,sha256=cnHsdrXx7PBPmrzHX6YEkCH5byCsUk8LE2FeTeu8ZLU,4259
  dask_cuda/proxify_device_objects.py,sha256=99CD7LOE79YiQGJ12sYl_XImVhJXpFR4vG5utdkjTQo,8108
  dask_cuda/proxify_host_file.py,sha256=Wf5CFCC1JN5zmfvND3ls0M5FL01Y8VhHrk0xV3UQ9kk,30850
  dask_cuda/proxy_object.py,sha256=bZq92kjgFB-ad_luSAFT_RItV3nssmiEk4OOSp34laU,29812
- dask_cuda/utils.py,sha256=wNRItbIXrOpH77AUUrZNrGqgIiGNzpClXYl0QmQqfxs,25002
+ dask_cuda/utils.py,sha256=RWlLK2cPHaCuNNhr8bW8etBeGklwREQJOafQbTydStk,25121
  dask_cuda/utils_test.py,sha256=WNMR0gic2tuP3pgygcR9g52NfyX8iGMOan6juXhpkCE,1694
  dask_cuda/worker_spec.py,sha256=7-Uq_e5q2SkTlsmctMcYLCa9_3RiiVHZLIN7ctfaFmE,4376
  dask_cuda/benchmarks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  dask_cuda/benchmarks/common.py,sha256=sEIFnRZS6wbyKCQyB4fDclYLc2YqC0PolurR5qzuRxw,6393
- dask_cuda/benchmarks/local_cudf_groupby.py,sha256=2iHk-a-GvLmAgajwQJNrqmZ-WJeiyMFEyflcxh7SPO8,8894
- dask_cuda/benchmarks/local_cudf_merge.py,sha256=vccM5PyzZVW99-a8YaIgftsGAiA5yXnT9NoAusx0PZY,12437
- dask_cuda/benchmarks/local_cudf_shuffle.py,sha256=LaNCMKhhfE1lYFUUWMtdYH-efbqV6YTFhKC-Eog9-8Q,8598
- dask_cuda/benchmarks/local_cupy.py,sha256=G36CI46ROtNPf6QISK6QjoguB2Qb19ScsylwLFOlMy4,10752
- dask_cuda/benchmarks/local_cupy_map_overlap.py,sha256=rQNLGvpX1XpgK-0Wx5fd3kV9Veu2ulBd5eX2sanNlEQ,6432
- dask_cuda/benchmarks/utils.py,sha256=mx_JKe4q1xFNwKJX03o8dEwc48iqnqHm-ZTHOcMn17E,26888
+ dask_cuda/benchmarks/local_cudf_groupby.py,sha256=T9lA9nb4Wzu46AH--SJEVCeCm3650J7slapdNR_08FU,8904
+ dask_cuda/benchmarks/local_cudf_merge.py,sha256=POjxoPx4zY1TjG2S_anElL6rDtC5Jhn3nF4HABlnwZg,12447
+ dask_cuda/benchmarks/local_cudf_shuffle.py,sha256=M-Lp3O3q8uyY50imQqMKZYwkAmyR0NApjx2ipGxDkXw,8608
+ dask_cuda/benchmarks/local_cupy.py,sha256=aUKIYfeR7c77K4kKk697Rxo8tG8kFabQ9jQEVGr-oTs,10762
+ dask_cuda/benchmarks/local_cupy_map_overlap.py,sha256=_texYmam1K_XbzIvURltui5KRsISGFNylXiGUtgRIz0,6442
+ dask_cuda/benchmarks/utils.py,sha256=baL5zK6VS6Mw_M4x9zJe8vMLUd2SZd1lS78JrL-h6oo,26896
  dask_cuda/explicit_comms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  dask_cuda/explicit_comms/comms.py,sha256=Su6PuNo68IyS-AwoqU4S9TmqWsLvUdNa0jot2hx8jQQ,10400
  dask_cuda/explicit_comms/dataframe/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -31,22 +32,22 @@ dask_cuda/explicit_comms/dataframe/shuffle.py,sha256=2f2wlPyqXpryIHgMpsZzs3pDE7e
  dask_cuda/tests/test_cudf_builtin_spilling.py,sha256=u3kW91YRLdHFycvpGfSQKrEucu5khMJ1k4sjmddO490,4910
  dask_cuda/tests/test_dask_cuda_worker.py,sha256=gViHaMCSfB6ip125OEi9D0nfKC-qBXRoHz6BRodEdb4,17729
  dask_cuda/tests/test_device_host_file.py,sha256=79ssUISo1YhsW_7HdwqPfsH2LRzS2bi5BjPym1Sdgqw,5882
- dask_cuda/tests/test_dgx.py,sha256=bKX-GvkYjWlmcEIK15aGErxmc0qPqIWOG1CeDFGoXFU,6381
- dask_cuda/tests/test_explicit_comms.py,sha256=3Q3o9BX4ksCgz11o38o5QhKg3Rv-EtTsGnVG83wwyyo,12283
- dask_cuda/tests/test_from_array.py,sha256=i2Vha4mchB0BopTlEdXV7CxY7qyTzFYdgYQTmukZX38,493
+ dask_cuda/tests/test_dgx.py,sha256=IDP5vxDgVx6n2Hm-7PhlZQFb_zlgn_3nBW9t7MeJcTM,6986
+ dask_cuda/tests/test_explicit_comms.py,sha256=AZlBi16prk3hc0ydzqAGecMcYZeyUvu6Pecb7gY_0yY,12315
+ dask_cuda/tests/test_from_array.py,sha256=okT1B6UqHmLxoy0uER0Ylm3UyOmi5BAXwJpTuTAw44I,601
  dask_cuda/tests/test_gds.py,sha256=6jf0HPTHAIG8Mp_FC4Ai4zpn-U1K7yk0fSXg8He8-r8,1513
- dask_cuda/tests/test_initialize.py,sha256=-Vo8SVBrVEKB0V1C6ia8khvbHJt4BC0xEjMNLhNbFxI,5491
- dask_cuda/tests/test_local_cuda_cluster.py,sha256=rKusvUq_5FNqev51Vli1wlQQH7qL_5lIsjQN57e0A4k,17445
+ dask_cuda/tests/test_initialize.py,sha256=Rba59ZbljEm1yyN94_sWZPEE_f7hWln95aiBVc49pmY,6960
+ dask_cuda/tests/test_local_cuda_cluster.py,sha256=G3kR-4o-vCqWWfSuQLFKVEK0F243FaDSgRlDTUll5aU,18376
  dask_cuda/tests/test_proxify_host_file.py,sha256=cp-U1uNPhesQaHbftKV8ir_dt5fbs0ZXSIsL39oI0fE,18630
- dask_cuda/tests/test_proxy.py,sha256=Nu9vLx-dALINcF_wsxuFYUryRE0Jq43w7bAYAchK8RY,23480
+ dask_cuda/tests/test_proxy.py,sha256=niRZl4gsBgCqW8mDm4UG_zSro2mVoAMOfhy9n1ps1Wk,23758
  dask_cuda/tests/test_spill.py,sha256=xN9PbVERBYMuZxvscSO0mAM22loq9WT3ltZVBFxlmM4,10239
- dask_cuda/tests/test_utils.py,sha256=wgYPvu7Sk61C64pah9ZbK8cnBXK5RyUCpu3G2ny6OZQ,8832
+ dask_cuda/tests/test_utils.py,sha256=JRIwXfemc3lWSzLJX0VcvR1_0wB4yeoOTsw7kB6z6pU,9176
  dask_cuda/tests/test_worker_spec.py,sha256=Bvu85vkqm6ZDAYPXKMJlI2pm9Uc5tiYKNtO4goXSw-I,2399
  examples/ucx/client_initialize.py,sha256=YN3AXHF8btcMd6NicKKhKR9SXouAsK1foJhFspbOn70,1262
  examples/ucx/local_cuda_cluster.py,sha256=7xVY3EhwhkY2L4VZin_BiMCbrjhirDNChoC86KiETNc,1983
- dask_cuda-23.12.0a231027.dist-info/LICENSE,sha256=MjI3I-EgxfEvZlgjk82rgiFsZqSDXHFETd2QJ89UwDA,11348
- dask_cuda-23.12.0a231027.dist-info/METADATA,sha256=gbaDcC9Wti46hsAM8g0wzySIvt9MVvxyU_zA-PSdl1s,2285
- dask_cuda-23.12.0a231027.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
- dask_cuda-23.12.0a231027.dist-info/entry_points.txt,sha256=UcRaKVEpywtxc6pF1VnfMB0UK4sJg7a8_NdZF67laPM,136
- dask_cuda-23.12.0a231027.dist-info/top_level.txt,sha256=3kKxJxeM108fuYc_lwwlklP7YBU9IEmdmRAouzi397o,33
- dask_cuda-23.12.0a231027.dist-info/RECORD,,
+ dask_cuda-24.2.0a3.dist-info/LICENSE,sha256=MjI3I-EgxfEvZlgjk82rgiFsZqSDXHFETd2QJ89UwDA,11348
+ dask_cuda-24.2.0a3.dist-info/METADATA,sha256=In9gerK6-AOB1i7Z_z85xV4oa7afMcVYQKIcY5jYs_Q,2268
+ dask_cuda-24.2.0a3.dist-info/WHEEL,sha256=Xo9-1PvkuimrydujYJAjF7pCkriuXBpUPEjma1nZyJ0,92
+ dask_cuda-24.2.0a3.dist-info/entry_points.txt,sha256=UcRaKVEpywtxc6pF1VnfMB0UK4sJg7a8_NdZF67laPM,136
+ dask_cuda-24.2.0a3.dist-info/top_level.txt,sha256=3kKxJxeM108fuYc_lwwlklP7YBU9IEmdmRAouzi397o,33
+ dask_cuda-24.2.0a3.dist-info/RECORD,,
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: bdist_wheel (0.41.2)
+ Generator: bdist_wheel (0.41.3)
  Root-Is-Purelib: true
  Tag: py3-none-any

dask_cuda/compat.py DELETED
@@ -1,118 +0,0 @@
- import pickle
-
- import msgpack
- from packaging.version import Version
-
- import dask
- import distributed
- import distributed.comm.utils
- import distributed.protocol
- from distributed.comm.utils import OFFLOAD_THRESHOLD, nbytes, offload
- from distributed.protocol.core import (
-     Serialized,
-     decompress,
-     logger,
-     merge_and_deserialize,
-     msgpack_decode_default,
-     msgpack_opts,
- )
-
- if Version(distributed.__version__) >= Version("2023.8.1"):
-     # Monkey-patch protocol.core.loads (and its users)
-     async def from_frames(
-         frames, deserialize=True, deserializers=None, allow_offload=True
-     ):
-         """
-         Unserialize a list of Distributed protocol frames.
-         """
-         size = False
-
-         def _from_frames():
-             try:
-                 # Patched code
-                 return loads(
-                     frames, deserialize=deserialize, deserializers=deserializers
-                 )
-                 # end patched code
-             except EOFError:
-                 if size > 1000:
-                     datastr = "[too large to display]"
-                 else:
-                     datastr = frames
-                 # Aid diagnosing
-                 logger.error("truncated data stream (%d bytes): %s", size, datastr)
-                 raise
-
-         if allow_offload and deserialize and OFFLOAD_THRESHOLD:
-             size = sum(map(nbytes, frames))
-         if (
-             allow_offload
-             and deserialize
-             and OFFLOAD_THRESHOLD
-             and size > OFFLOAD_THRESHOLD
-         ):
-             res = await offload(_from_frames)
-         else:
-             res = _from_frames()
-
-         return res
-
-     def loads(frames, deserialize=True, deserializers=None):
-         """Transform bytestream back into Python value"""
-
-         allow_pickle = dask.config.get("distributed.scheduler.pickle")
-
-         try:
-
-             def _decode_default(obj):
-                 offset = obj.get("__Serialized__", 0)
-                 if offset > 0:
-                     sub_header = msgpack.loads(
-                         frames[offset],
-                         object_hook=msgpack_decode_default,
-                         use_list=False,
-                         **msgpack_opts,
-                     )
-                     offset += 1
-                     sub_frames = frames[offset : offset + sub_header["num-sub-frames"]]
-                     if deserialize:
-                         if "compression" in sub_header:
-                             sub_frames = decompress(sub_header, sub_frames)
-                         return merge_and_deserialize(
-                             sub_header, sub_frames, deserializers=deserializers
-                         )
-                     else:
-                         return Serialized(sub_header, sub_frames)
-
-                 offset = obj.get("__Pickled__", 0)
-                 if offset > 0:
-                     sub_header = msgpack.loads(frames[offset])
-                     offset += 1
-                     sub_frames = frames[offset : offset + sub_header["num-sub-frames"]]
-                     # Patched code
-                     if "compression" in sub_header:
-                         sub_frames = decompress(sub_header, sub_frames)
-                     # end patched code
-                     if allow_pickle:
-                         return pickle.loads(
-                             sub_header["pickled-obj"], buffers=sub_frames
-                         )
-                     else:
-                         raise ValueError(
-                             "Unpickle on the Scheduler isn't allowed, "
-                             "set `distributed.scheduler.pickle=true`"
-                         )
-
-                 return msgpack_decode_default(obj)
-
-             return msgpack.loads(
-                 frames[0], object_hook=_decode_default, use_list=False, **msgpack_opts
-             )
-
-         except Exception:
-             logger.critical("Failed to deserialize", exc_info=True)
-             raise
-
-     distributed.protocol.loads = loads
-     distributed.protocol.core.loads = loads
-     distributed.comm.utils.from_frames = from_frames