dask-cuda 23.12.0a231027__py3-none-any.whl → 24.2.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
dask_cuda/VERSION ADDED
@@ -0,0 +1 @@
+ 24.02.00
dask_cuda/__init__.py CHANGED
@@ -11,6 +11,7 @@ import dask.dataframe.shuffle
  import dask.dataframe.multi
  import dask.bag.core
 
+ from ._version import __git_commit__, __version__
  from .cuda_worker import CUDAWorker
  from .explicit_comms.dataframe.shuffle import (
      get_rearrange_by_column_wrapper,
@@ -19,9 +20,6 @@ from .explicit_comms.dataframe.shuffle import (
  from .local_cuda_cluster import LocalCUDACluster
  from .proxify_device_objects import proxify_decorator, unproxify_decorator
 
- __version__ = "23.12.00"
-
- from . import compat
 
  # Monkey patching Dask to make use of explicit-comms when `DASK_EXPLICIT_COMMS=True`
  dask.dataframe.shuffle.rearrange_by_column = get_rearrange_by_column_wrapper(
dask_cuda/_version.py ADDED
@@ -0,0 +1,20 @@
+ # Copyright (c) 2023, NVIDIA CORPORATION.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import importlib.resources
+
+ __version__ = (
+     importlib.resources.files("dask_cuda").joinpath("VERSION").read_text().strip()
+ )
+ __git_commit__ = "96bedbc2931170ef44ed9929e00bcf5bb2a9de56"
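With this change the version is resolved at import time instead of being hard-coded in `__init__.py`. A minimal sketch of what the new attributes expose (values taken from this diff; `importlib.resources.files` requires Python 3.9+, matching the wheel's `Requires-Python`):

    import dask_cuda

    # __version__ is read from the packaged dask_cuda/VERSION file added above;
    # __git_commit__ is the hash baked into _version.py at build time.
    print(dask_cuda.__version__)     # "24.02.00"
    print(dask_cuda.__git_commit__)  # "96bedbc2931170ef44ed9929e00bcf5bb2a9de56"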
@@ -139,7 +139,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
          key="Device memory limit", value=f"{format_bytes(args.device_memory_limit)}"
      )
      print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
-     if args.protocol == "ucx":
+     if args.protocol in ["ucx", "ucxx"]:
          print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
          print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
          print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
@@ -217,7 +217,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
      )
      print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
      print_key_value(key="Frac-match", value=f"{args.frac_match}")
-     if args.protocol == "ucx":
+     if args.protocol in ["ucx", "ucxx"]:
          print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
          print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
          print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
@@ -146,7 +146,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
          key="Device memory limit", value=f"{format_bytes(args.device_memory_limit)}"
      )
      print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
-     if args.protocol == "ucx":
+     if args.protocol in ["ucx", "ucxx"]:
          print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
          print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
          print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
@@ -193,7 +193,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
      )
      print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
      print_key_value(key="Protocol", value=f"{args.protocol}")
-     if args.protocol == "ucx":
+     if args.protocol in ["ucx", "ucxx"]:
          print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
          print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
          print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
@@ -78,7 +78,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
      )
      print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
      print_key_value(key="Protocol", value=f"{args.protocol}")
-     if args.protocol == "ucx":
+     if args.protocol in ["ucx", "ucxx"]:
          print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
          print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
          print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
@@ -73,7 +73,7 @@ def parse_benchmark_args(description="Generic dask-cuda Benchmark", args_list=[]
      cluster_args.add_argument(
          "-p",
          "--protocol",
-         choices=["tcp", "ucx"],
+         choices=["tcp", "ucx", "ucxx"],
          default="tcp",
          type=str,
          help="The communication protocol to use.",
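For illustration, a standalone sketch of the parser behavior after this change (plain argparse, not the full parse_benchmark_args setup):

    import argparse

    parser = argparse.ArgumentParser(description="Generic dask-cuda Benchmark")
    parser.add_argument(
        "-p",
        "--protocol",
        choices=["tcp", "ucx", "ucxx"],  # "ucxx" is newly accepted
        default="tcp",
        type=str,
        help="The communication protocol to use.",
    )

    print(parser.parse_args(["-p", "ucxx"]).protocol)  # "ucxx"
    # Any value outside the choices list still exits with an argparse error.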
@@ -17,7 +17,7 @@ from distributed.protocol import (
      serialize_bytelist,
  )
  from distributed.sizeof import safe_sizeof
- from distributed.spill import CustomFile as KeyAsStringFile
+ from distributed.spill import AnyKeyFile as KeyAsStringFile
  from distributed.utils import nbytes
 
  from .is_device_object import is_device_object
@@ -577,7 +577,7 @@ def get_rearrange_by_column_wrapper(func):
          kw = kw.arguments
          # Notice, we only overwrite the default and the "tasks" shuffle
          # algorithm. The "disk" and "p2p" algorithm, we don't touch.
-         if kw["shuffle"] in ("tasks", None):
+         if kw["shuffle_method"] in ("tasks", None):
              col = kw["col"]
              if isinstance(col, str):
                  col = [col]
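The rename tracks dask.dataframe, where the `shuffle=` keyword of `DataFrame.shuffle` became `shuffle_method=`. A hedged sketch of the call shape the wrapper now intercepts, using a pandas-backed frame for simplicity:

    import pandas as pd
    import dask.dataframe as dd

    ddf = dd.from_pandas(pd.DataFrame({"key": range(10)}), npartitions=2)

    # With DASK_EXPLICIT_COMMS=True, the monkey-patched rearrange_by_column
    # sees kw["shuffle_method"] in ("tasks", None) and swaps in explicit-comms.
    shuffled = ddf.shuffle(on="key", shuffle_method="tasks")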
@@ -124,6 +124,10 @@ def get_device_memory_objects_register_cudf():
      def get_device_memory_objects_cudf_multiindex(obj):
          return dispatch(obj._columns)
 
+     @dispatch.register(cudf.core.column.ColumnBase)
+     def get_device_memory_objects_cudf_column(obj):
+         return dispatch(obj.data) + dispatch(obj.children) + dispatch(obj.mask)
+
 
  @sizeof.register_lazy("cupy")
  def register_cupy():  # NB: this overwrites dask.sizeof.register_cupy()
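The new `ColumnBase` registration walks a column's value buffer, its child columns (e.g. the chars/offsets of a string column), and its null mask, so masked columns are now counted. A hedged sketch of those attributes on a nullable cudf column (accessing the internal `_column` purely for illustration):

    import cudf

    col = cudf.Series([1, None, 3])._column  # internal column object

    print(col.data is not None)  # True: device buffer holding the values
    print(col.mask is not None)  # True: a null mask was allocated for the None
    print(len(col.children))     # 0 for a plain numeric column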
dask_cuda/initialize.py CHANGED
@@ -5,7 +5,6 @@ import click
  import numba.cuda
 
  import dask
- import distributed.comm.ucx
  from distributed.diagnostics.nvml import get_device_index_and_uuid, has_cuda_context
 
  from .utils import get_ucx_config
@@ -23,12 +22,21 @@ def _create_cuda_context_handler():
          numba.cuda.current_context()
 
 
- def _create_cuda_context():
+ def _create_cuda_context(protocol="ucx"):
+     if protocol not in ["ucx", "ucxx"]:
+         return
      try:
          # Added here to ensure the parent `LocalCUDACluster` process creates the CUDA
          # context directly from the UCX module, thus avoiding a similar warning there.
          try:
-             distributed.comm.ucx.init_once()
+             if protocol == "ucx":
+                 import distributed.comm.ucx
+
+                 distributed.comm.ucx.init_once()
+             elif protocol == "ucxx":
+                 import distributed_ucxx.ucxx
+
+                 distributed_ucxx.ucxx.init_once()
          except ModuleNotFoundError:
              # UCX initialization has to be delegated to Distributed, it will take care
              # of setting correct environment variables and importing `ucp` after that.
@@ -39,20 +47,35 @@ def _create_cuda_context():
              os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]
          )
          ctx = has_cuda_context()
-         if (
-             ctx.has_context
-             and not distributed.comm.ucx.cuda_context_created.has_context
-         ):
-             distributed.comm.ucx._warn_existing_cuda_context(ctx, os.getpid())
+         if protocol == "ucx":
+             if (
+                 ctx.has_context
+                 and not distributed.comm.ucx.cuda_context_created.has_context
+             ):
+                 distributed.comm.ucx._warn_existing_cuda_context(ctx, os.getpid())
+         elif protocol == "ucxx":
+             if (
+                 ctx.has_context
+                 and not distributed_ucxx.ucxx.cuda_context_created.has_context
+             ):
+                 distributed_ucxx.ucxx._warn_existing_cuda_context(ctx, os.getpid())
 
          _create_cuda_context_handler()
 
-         if not distributed.comm.ucx.cuda_context_created.has_context:
-             ctx = has_cuda_context()
-             if ctx.has_context and ctx.device_info != cuda_visible_device:
-                 distributed.comm.ucx._warn_cuda_context_wrong_device(
-                     cuda_visible_device, ctx.device_info, os.getpid()
-                 )
+         if protocol == "ucx":
+             if not distributed.comm.ucx.cuda_context_created.has_context:
+                 ctx = has_cuda_context()
+                 if ctx.has_context and ctx.device_info != cuda_visible_device:
+                     distributed.comm.ucx._warn_cuda_context_wrong_device(
+                         cuda_visible_device, ctx.device_info, os.getpid()
+                     )
+         elif protocol == "ucxx":
+             if not distributed_ucxx.ucxx.cuda_context_created.has_context:
+                 ctx = has_cuda_context()
+                 if ctx.has_context and ctx.device_info != cuda_visible_device:
+                     distributed_ucxx.ucxx._warn_cuda_context_wrong_device(
+                         cuda_visible_device, ctx.device_info, os.getpid()
+                     )
 
      except Exception:
          logger.error("Unable to start CUDA Context", exc_info=True)
@@ -64,6 +87,7 @@ def initialize(
      enable_infiniband=None,
      enable_nvlink=None,
      enable_rdmacm=None,
+     protocol="ucx",
  ):
      """Create CUDA context and initialize UCX-Py, depending on user parameters.
 
@@ -118,7 +142,7 @@ def initialize(
      dask.config.set({"distributed.comm.ucx": ucx_config})
 
      if create_cuda_context:
-         _create_cuda_context()
+         _create_cuda_context(protocol=protocol)
 
 
  @click.command()
@@ -127,6 +151,12 @@ def initialize(
      default=False,
      help="Create CUDA context",
  )
+ @click.option(
+     "--protocol",
+     default=None,
+     type=str,
+     help="Communication protocol, such as: 'tcp', 'tls', 'ucx' or 'ucxx'.",
+ )
  @click.option(
      "--enable-tcp-over-ucx/--disable-tcp-over-ucx",
      default=False,
@@ -150,10 +180,11 @@ def initialize(
  def dask_setup(
      service,
      create_cuda_context,
+     protocol,
      enable_tcp_over_ucx,
      enable_infiniband,
      enable_nvlink,
      enable_rdmacm,
  ):
      if create_cuda_context:
-         _create_cuda_context()
+         _create_cuda_context(protocol=protocol)
@@ -319,8 +319,11 @@ class LocalCUDACluster(LocalCluster):
          if enable_tcp_over_ucx or enable_infiniband or enable_nvlink:
              if protocol is None:
                  protocol = "ucx"
-             elif protocol != "ucx":
-                 raise TypeError("Enabling InfiniBand or NVLink requires protocol='ucx'")
+             elif protocol not in ["ucx", "ucxx"]:
+                 raise TypeError(
+                     "Enabling InfiniBand or NVLink requires protocol='ucx' or "
+                     "protocol='ucxx'"
+                 )
 
          self.host = kwargs.get("host", None)
 
@@ -371,7 +374,7 @@ class LocalCUDACluster(LocalCluster):
          ) + ["dask_cuda.initialize"]
          self.new_spec["options"]["preload_argv"] = self.new_spec["options"].get(
              "preload_argv", []
-         ) + ["--create-cuda-context"]
+         ) + ["--create-cuda-context", "--protocol", protocol]
 
          self.cuda_visible_devices = CUDA_VISIBLE_DEVICES
          self.scale(n_workers)
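Taken together, `initialize()` and the worker preload now thread the protocol choice through to UCX or UCXX setup. A minimal usage sketch, assuming the `ucxx`/`distributed-ucxx` packages are installed when `protocol="ucxx"` is requested:

    from dask_cuda import LocalCUDACluster
    from dask_cuda.initialize import initialize

    # Create the CUDA context and configure UCXX in the client process...
    initialize(protocol="ucxx", enable_tcp_over_ucx=True)

    # ...then start workers; each one is preloaded with
    # ["--create-cuda-context", "--protocol", "ucxx"] per the hunk above.
    cluster = LocalCUDACluster(protocol="ucxx", enable_tcp_over_ucx=True)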
@@ -73,10 +73,13 @@ def test_default():
      assert not p.exitcode
 
 
- def _test_tcp_over_ucx():
-     ucp = pytest.importorskip("ucp")
+ def _test_tcp_over_ucx(protocol):
+     if protocol == "ucx":
+         ucp = pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         ucp = pytest.importorskip("ucxx")
 
-     with LocalCUDACluster(enable_tcp_over_ucx=True) as cluster:
+     with LocalCUDACluster(protocol=protocol, enable_tcp_over_ucx=True) as cluster:
          with Client(cluster) as client:
              res = da.from_array(numpy.arange(10000), chunks=(1000,))
              res = res.sum().compute()
@@ -93,10 +96,17 @@ def _test_tcp_over_ucx():
      assert all(client.run(check_ucx_options).values())
 
 
- def test_tcp_over_ucx():
-     ucp = pytest.importorskip("ucp")  # NOQA: F841
+ @pytest.mark.parametrize(
+     "protocol",
+     ["ucx", "ucxx"],
+ )
+ def test_tcp_over_ucx(protocol):
+     if protocol == "ucx":
+         pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         pytest.importorskip("ucxx")
 
-     p = mp.Process(target=_test_tcp_over_ucx)
+     p = mp.Process(target=_test_tcp_over_ucx, args=(protocol,))
      p.start()
      p.join()
      assert not p.exitcode
@@ -117,9 +127,26 @@ def test_tcp_only():
      assert not p.exitcode
 
 
- def _test_ucx_infiniband_nvlink(enable_infiniband, enable_nvlink, enable_rdmacm):
+ def _test_ucx_infiniband_nvlink(
+     skip_queue, protocol, enable_infiniband, enable_nvlink, enable_rdmacm
+ ):
      cupy = pytest.importorskip("cupy")
-     ucp = pytest.importorskip("ucp")
+     if protocol == "ucx":
+         ucp = pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         ucp = pytest.importorskip("ucxx")
+
+     if enable_infiniband and not any(
+         [at.startswith("rc") for at in ucp.get_active_transports()]
+     ):
+         skip_queue.put("No support available for 'rc' transport in UCX")
+         return
+     else:
+         skip_queue.put("ok")
+
+     # `ucp.get_active_transports()` call above initializes UCX, we must reset it
+     # so that Dask doesn't try to initialize it again and raise an exception.
+     ucp.reset()
 
      if enable_infiniband is None and enable_nvlink is None and enable_rdmacm is None:
          enable_tcp_over_ucx = None
@@ -135,6 +162,7 @@ def _test_ucx_infiniband_nvlink(enable_infiniband, enable_nvlink, enable_rdmacm)
          cm_tls_priority = ["tcp"]
 
      initialize(
+         protocol=protocol,
          enable_tcp_over_ucx=enable_tcp_over_ucx,
          enable_infiniband=enable_infiniband,
          enable_nvlink=enable_nvlink,
@@ -142,6 +170,7 @@ def _test_ucx_infiniband_nvlink(enable_infiniband, enable_nvlink, enable_rdmacm)
      )
 
      with LocalCUDACluster(
+         protocol=protocol,
          interface="ib0",
          enable_tcp_over_ucx=enable_tcp_over_ucx,
          enable_infiniband=enable_infiniband,
@@ -171,6 +200,7 @@ def _test_ucx_infiniband_nvlink(enable_infiniband, enable_nvlink, enable_rdmacm)
      assert all(client.run(check_ucx_options).values())
 
 
+ @pytest.mark.parametrize("protocol", ["ucx", "ucxx"])
  @pytest.mark.parametrize(
      "params",
      [
@@ -185,16 +215,19 @@ def _test_ucx_infiniband_nvlink(enable_infiniband, enable_nvlink, enable_rdmacm)
      _get_dgx_version() == DGXVersion.DGX_A100,
      reason="Automatic InfiniBand device detection Unsupported for %s" % _get_dgx_name(),
  )
- def test_ucx_infiniband_nvlink(params):
-     ucp = pytest.importorskip("ucp")  # NOQA: F841
+ def test_ucx_infiniband_nvlink(protocol, params):
+     if protocol == "ucx":
+         pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         pytest.importorskip("ucxx")
 
-     if params["enable_infiniband"]:
-         if not any([at.startswith("rc") for at in ucp.get_active_transports()]):
-             pytest.skip("No support available for 'rc' transport in UCX")
+     skip_queue = mp.Queue()
 
      p = mp.Process(
          target=_test_ucx_infiniband_nvlink,
          args=(
+             skip_queue,
+             protocol,
              params["enable_infiniband"],
              params["enable_nvlink"],
              params["enable_rdmacm"],
@@ -203,9 +236,8 @@ def test_ucx_infiniband_nvlink(params):
      p.start()
      p.join()
 
-     # Starting a new cluster on the same pytest process after an rdmacm cluster
-     # has been used may cause UCX-Py to complain about being already initialized.
-     if params["enable_rdmacm"] is True:
-         ucp.reset()
+     skip_msg = skip_queue.get()
+     if skip_msg != "ok":
+         pytest.skip(skip_msg)
 
      assert not p.exitcode
@@ -1,6 +1,9 @@
  import asyncio
  import multiprocessing as mp
  import os
+ import signal
+ import time
+ from functools import partial
  from unittest.mock import patch
 
  import numpy as np
@@ -44,7 +47,7 @@ def _test_local_cluster(protocol):
      assert sum(c.run(my_rank, 0)) == sum(range(4))
 
 
- @pytest.mark.parametrize("protocol", ["tcp", "ucx"])
+ @pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
  def test_local_cluster(protocol):
      p = mp.Process(target=_test_local_cluster, args=(protocol,))
      p.start()
@@ -160,7 +163,7 @@ def _test_dataframe_shuffle(backend, protocol, n_workers, _partitions):
 
  @pytest.mark.parametrize("nworkers", [1, 2, 3])
  @pytest.mark.parametrize("backend", ["pandas", "cudf"])
- @pytest.mark.parametrize("protocol", ["tcp", "ucx"])
+ @pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
  @pytest.mark.parametrize("_partitions", [True, False])
  def test_dataframe_shuffle(backend, protocol, nworkers, _partitions):
      if backend == "cudf":
@@ -175,7 +178,7 @@ def test_dataframe_shuffle(backend, protocol, nworkers, _partitions):
 
 
  @pytest.mark.parametrize("in_cluster", [True, False])
- def test_dask_use_explicit_comms(in_cluster):
+ def _test_dask_use_explicit_comms(in_cluster):
      def check_shuffle():
          """Check if shuffle use explicit-comms by search for keys named
          'explicit-comms-shuffle'
@@ -217,6 +220,31 @@ def test_dask_use_explicit_comms(in_cluster):
          check_shuffle()
 
 
+ @pytest.mark.parametrize("in_cluster", [True, False])
+ def test_dask_use_explicit_comms(in_cluster):
+     def _timeout(process, function, timeout):
+         if process.is_alive():
+             function()
+         timeout = time.time() + timeout
+         while process.is_alive() and time.time() < timeout:
+             time.sleep(0.1)
+
+     p = mp.Process(target=_test_dask_use_explicit_comms, args=(in_cluster,))
+     p.start()
+
+     # Timeout before killing process
+     _timeout(p, lambda: None, 60.0)
+
+     # Send SIGINT (i.e., KeyboardInterrupt) hoping we get a stack trace.
+     _timeout(p, partial(p._popen._send_signal, signal.SIGINT), 3.0)
+
+     # SIGINT didn't work, kill process.
+     _timeout(p, p.kill, 3.0)
+
+     assert not p.is_alive()
+     assert p.exitcode == 0
+
+
  def _test_dataframe_shuffle_merge(backend, protocol, n_workers):
      if backend == "cudf":
          cudf = pytest.importorskip("cudf")
@@ -256,7 +284,7 @@ def _test_dataframe_shuffle_merge(backend, protocol, n_workers):
 
  @pytest.mark.parametrize("nworkers", [1, 2, 4])
  @pytest.mark.parametrize("backend", ["pandas", "cudf"])
- @pytest.mark.parametrize("protocol", ["tcp", "ucx"])
+ @pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
  def test_dataframe_shuffle_merge(backend, protocol, nworkers):
      if backend == "cudf":
          pytest.importorskip("cudf")
@@ -293,7 +321,7 @@ def _test_jit_unspill(protocol):
      assert_eq(got, expected)
 
 
- @pytest.mark.parametrize("protocol", ["tcp", "ucx"])
+ @pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
  def test_jit_unspill(protocol):
      pytest.importorskip("cudf")
 
@@ -5,12 +5,16 @@ from distributed import Client
 
  from dask_cuda import LocalCUDACluster
 
- pytest.importorskip("ucp")
  cupy = pytest.importorskip("cupy")
 
 
- @pytest.mark.parametrize("protocol", ["ucx", "tcp"])
+ @pytest.mark.parametrize("protocol", ["ucx", "ucxx", "tcp"])
  def test_ucx_from_array(protocol):
+     if protocol == "ucx":
+         pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         pytest.importorskip("ucxx")
+
      N = 10_000
      with LocalCUDACluster(protocol=protocol) as cluster:
          with Client(cluster):
@@ -13,7 +13,6 @@ from dask_cuda.utils import get_ucx_config
  from dask_cuda.utils_test import IncreasedCloseTimeoutNanny
 
  mp = mp.get_context("spawn")  # type: ignore
- ucp = pytest.importorskip("ucp")
 
  # Notice, all of the following tests is executed in a new process such
  # that UCX options of the different tests doesn't conflict.
@@ -21,11 +20,16 @@ ucp = pytest.importorskip("ucp")
  # of UCX before retrieving the current config.
 
 
- def _test_initialize_ucx_tcp():
+ def _test_initialize_ucx_tcp(protocol):
+     if protocol == "ucx":
+         ucp = pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         ucp = pytest.importorskip("ucxx")
+
      kwargs = {"enable_tcp_over_ucx": True}
-     initialize(**kwargs)
+     initialize(protocol=protocol, **kwargs)
      with LocalCluster(
-         protocol="ucx",
+         protocol=protocol,
          dashboard_address=None,
          n_workers=1,
          threads_per_worker=1,
@@ -50,18 +54,29 @@ def _test_initialize_ucx_tcp():
      assert all(client.run(check_ucx_options).values())
 
 
- def test_initialize_ucx_tcp():
-     p = mp.Process(target=_test_initialize_ucx_tcp)
+ @pytest.mark.parametrize("protocol", ["ucx", "ucxx"])
+ def test_initialize_ucx_tcp(protocol):
+     if protocol == "ucx":
+         pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         pytest.importorskip("ucxx")
+
+     p = mp.Process(target=_test_initialize_ucx_tcp, args=(protocol,))
      p.start()
      p.join()
      assert not p.exitcode
 
 
- def _test_initialize_ucx_nvlink():
+ def _test_initialize_ucx_nvlink(protocol):
+     if protocol == "ucx":
+         ucp = pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         ucp = pytest.importorskip("ucxx")
+
      kwargs = {"enable_nvlink": True}
-     initialize(**kwargs)
+     initialize(protocol=protocol, **kwargs)
      with LocalCluster(
-         protocol="ucx",
+         protocol=protocol,
          dashboard_address=None,
          n_workers=1,
          threads_per_worker=1,
@@ -87,18 +102,29 @@ def _test_initialize_ucx_nvlink():
      assert all(client.run(check_ucx_options).values())
 
 
- def test_initialize_ucx_nvlink():
-     p = mp.Process(target=_test_initialize_ucx_nvlink)
+ @pytest.mark.parametrize("protocol", ["ucx", "ucxx"])
+ def test_initialize_ucx_nvlink(protocol):
+     if protocol == "ucx":
+         pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         pytest.importorskip("ucxx")
+
+     p = mp.Process(target=_test_initialize_ucx_nvlink, args=(protocol,))
      p.start()
      p.join()
      assert not p.exitcode
 
 
- def _test_initialize_ucx_infiniband():
+ def _test_initialize_ucx_infiniband(protocol):
+     if protocol == "ucx":
+         ucp = pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         ucp = pytest.importorskip("ucxx")
+
      kwargs = {"enable_infiniband": True}
-     initialize(**kwargs)
+     initialize(protocol=protocol, **kwargs)
      with LocalCluster(
-         protocol="ucx",
+         protocol=protocol,
          dashboard_address=None,
          n_workers=1,
          threads_per_worker=1,
@@ -127,17 +153,28 @@ def _test_initialize_ucx_infiniband():
  @pytest.mark.skipif(
      "ib0" not in psutil.net_if_addrs(), reason="Infiniband interface ib0 not found"
  )
- def test_initialize_ucx_infiniband():
-     p = mp.Process(target=_test_initialize_ucx_infiniband)
+ @pytest.mark.parametrize("protocol", ["ucx", "ucxx"])
+ def test_initialize_ucx_infiniband(protocol):
+     if protocol == "ucx":
+         pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         pytest.importorskip("ucxx")
+
+     p = mp.Process(target=_test_initialize_ucx_infiniband, args=(protocol,))
      p.start()
      p.join()
      assert not p.exitcode
 
 
- def _test_initialize_ucx_all():
-     initialize()
+ def _test_initialize_ucx_all(protocol):
+     if protocol == "ucx":
+         ucp = pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         ucp = pytest.importorskip("ucxx")
+
+     initialize(protocol=protocol)
      with LocalCluster(
-         protocol="ucx",
+         protocol=protocol,
          dashboard_address=None,
          n_workers=1,
          threads_per_worker=1,
@@ -166,8 +203,14 @@ def _test_initialize_ucx_all():
      assert all(client.run(check_ucx_options).values())
 
 
- def test_initialize_ucx_all():
-     p = mp.Process(target=_test_initialize_ucx_all)
+ @pytest.mark.parametrize("protocol", ["ucx", "ucxx"])
+ def test_initialize_ucx_all(protocol):
+     if protocol == "ucx":
+         pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         pytest.importorskip("ucxx")
+
+     p = mp.Process(target=_test_initialize_ucx_all, args=(protocol,))
      p.start()
      p.join()
      assert not p.exitcode
@@ -87,23 +87,38 @@ async def test_with_subset_of_cuda_visible_devices():
      }
 
 
+ @pytest.mark.parametrize(
+     "protocol",
+     ["ucx", "ucxx"],
+ )
  @gen_test(timeout=20)
- async def test_ucx_protocol():
-     pytest.importorskip("ucp")
+ async def test_ucx_protocol(protocol):
+     if protocol == "ucx":
+         pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         pytest.importorskip("ucxx")
 
      async with LocalCUDACluster(
-         protocol="ucx", asynchronous=True, data=dict
+         protocol=protocol, asynchronous=True, data=dict
      ) as cluster:
          assert all(
-             ws.address.startswith("ucx://") for ws in cluster.scheduler.workers.values()
+             ws.address.startswith(f"{protocol}://")
+             for ws in cluster.scheduler.workers.values()
          )
 
 
+ @pytest.mark.parametrize(
+     "protocol",
+     ["ucx", "ucxx"],
+ )
  @gen_test(timeout=20)
- async def test_explicit_ucx_with_protocol_none():
-     pytest.importorskip("ucp")
+ async def test_explicit_ucx_with_protocol_none(protocol):
+     if protocol == "ucx":
+         pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         pytest.importorskip("ucxx")
 
-     initialize(enable_tcp_over_ucx=True)
+     initialize(protocol=protocol, enable_tcp_over_ucx=True)
      async with LocalCUDACluster(
          protocol=None, enable_tcp_over_ucx=True, asynchronous=True, data=dict
      ) as cluster:
@@ -113,11 +128,18 @@ async def test_explicit_ucx_with_protocol_none():
 
 
  @pytest.mark.filterwarnings("ignore:Exception ignored in")
+ @pytest.mark.parametrize(
+     "protocol",
+     ["ucx", "ucxx"],
+ )
  @gen_test(timeout=20)
- async def test_ucx_protocol_type_error():
-     pytest.importorskip("ucp")
+ async def test_ucx_protocol_type_error(protocol):
+     if protocol == "ucx":
+         pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         pytest.importorskip("ucxx")
 
-     initialize(enable_tcp_over_ucx=True)
+     initialize(protocol=protocol, enable_tcp_over_ucx=True)
      with pytest.raises(TypeError):
          async with LocalCUDACluster(
              protocol="tcp", enable_tcp_over_ucx=True, asynchronous=True, data=dict
@@ -337,6 +359,7 @@ async def test_pre_import():
 
 
  # Intentionally not using @gen_test to skip cleanup checks
+ @pytest.mark.xfail(reason="https://github.com/rapidsai/dask-cuda/issues/1265")
  def test_pre_import_not_found():
      async def _test_pre_import_not_found():
          with raises_with_cause(RuntimeError, None, ImportError, None):
@@ -477,20 +500,30 @@ async def test_worker_fraction_limits():
      )
 
 
- def test_print_cluster_config(capsys):
+ @pytest.mark.parametrize(
+     "protocol",
+     ["ucx", "ucxx"],
+ )
+ def test_print_cluster_config(capsys, protocol):
+     if protocol == "ucx":
+         pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         pytest.importorskip("ucxx")
+
      pytest.importorskip("rich")
      with LocalCUDACluster(
-         n_workers=1, device_memory_limit="1B", jit_unspill=True, protocol="ucx"
+         n_workers=1, device_memory_limit="1B", jit_unspill=True, protocol=protocol
      ) as cluster:
          with Client(cluster) as client:
              print_cluster_config(client)
              captured = capsys.readouterr()
              assert "Dask Cluster Configuration" in captured.out
-             assert "ucx" in captured.out
+             assert protocol in captured.out
              assert "1 B" in captured.out
              assert "[plugin]" in captured.out
 
 
+ @pytest.mark.xfail(reason="https://github.com/rapidsai/dask-cuda/issues/1265")
  def test_death_timeout_raises():
      with pytest.raises(asyncio.exceptions.TimeoutError):
          with LocalCUDACluster(
@@ -302,13 +302,24 @@ def test_dataframes_share_dev_mem(root_dir):
  def test_cudf_get_device_memory_objects():
      cudf = pytest.importorskip("cudf")
      objects = [
-         cudf.DataFrame({"a": range(10), "b": range(10)}, index=reversed(range(10))),
+         cudf.DataFrame(
+             {"a": [0, 1, 2, 3, None, 5, 6, 7, 8, 9], "b": range(10)},
+             index=reversed(range(10)),
+         ),
          cudf.MultiIndex(
              levels=[[1, 2], ["blue", "red"]], codes=[[0, 0, 1, 1], [1, 0, 1, 0]]
          ),
      ]
      res = get_device_memory_ids(objects)
-     assert len(res) == 4, "We expect four buffer objects"
+     # Buffers are:
+     # 1. int data for objects[0].a
+     # 2. mask data for objects[0].a
+     # 3. int data for objects[0].b
+     # 4. int data for objects[0].index
+     # 5. int data for objects[1].levels[0]
+     # 6. char data for objects[1].levels[1]
+     # 7. offset data for objects[1].levels[1]
+     assert len(res) == 7, "We expect seven buffer objects"
 
 
  def test_externals(root_dir):
@@ -403,7 +414,7 @@ async def test_compatibility_mode_dataframe_shuffle(compatibility_mode, npartiti
      ddf = dask.dataframe.from_pandas(
          cudf.DataFrame({"key": np.arange(10)}), npartitions=npartitions
      )
-     res = ddf.shuffle(on="key", shuffle="tasks").persist()
+     res = ddf.shuffle(on="key", shuffle_method="tasks").persist()
 
      # With compatibility mode on, we shouldn't encounter any proxy objects
      if compatibility_mode:
@@ -306,6 +306,7 @@ async def test_spilling_local_cuda_cluster(jit_unspill):
          n_workers=1,
          device_memory_limit="1B",
          jit_unspill=jit_unspill,
+         worker_class=IncreasedCloseTimeoutNanny,
          asynchronous=True,
      ) as cluster:
          async with Client(cluster, asynchronous=True) as client:
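`IncreasedCloseTimeoutNanny` lives in `dask_cuda.utils_test` (see the RECORD listing below). A hedged sketch of the pattern this hunk applies:

    from dask_cuda import LocalCUDACluster
    from dask_cuda.utils_test import IncreasedCloseTimeoutNanny

    # Use a Nanny subclass with a longer close timeout so that slow worker
    # shutdowns don't turn into spurious test failures.
    cluster = LocalCUDACluster(n_workers=1, worker_class=IncreasedCloseTimeoutNanny)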
@@ -400,10 +401,14 @@ class _PxyObjTest(proxy_object.ProxyObject):
 
 
  @pytest.mark.parametrize("send_serializers", [None, ("dask", "pickle"), ("cuda",)])
- @pytest.mark.parametrize("protocol", ["tcp", "ucx"])
+ @pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
  @gen_test(timeout=120)
  async def test_communicating_proxy_objects(protocol, send_serializers):
      """Testing serialization of cuDF dataframe when communicating"""
+     if protocol == "ucx":
+         pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         pytest.importorskip("ucxx")
      cudf = pytest.importorskip("cudf")
 
      def task(x):
@@ -412,7 +417,7 @@ async def test_communicating_proxy_objects(protocol, send_serializers):
          serializers_used = x._pxy_get().serializer
 
          # Check that `x` is serialized with the expected serializers
-         if protocol == "ucx":
+         if protocol in ["ucx", "ucxx"]:
              if send_serializers is None:
                  assert serializers_used == "cuda"
              else:
@@ -443,11 +448,15 @@ async def test_communicating_proxy_objects(protocol, send_serializers):
          await client.submit(task, df)
 
 
- @pytest.mark.parametrize("protocol", ["tcp", "ucx"])
+ @pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
  @pytest.mark.parametrize("shared_fs", [True, False])
  @gen_test(timeout=20)
  async def test_communicating_disk_objects(protocol, shared_fs):
      """Testing disk serialization of cuDF dataframe when communicating"""
+     if protocol == "ucx":
+         pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         pytest.importorskip("ucxx")
      cudf = pytest.importorskip("cudf")
      ProxifyHostFile._spill_to_disk.shared_filesystem = shared_fs
 
@@ -79,11 +79,18 @@ def test_get_device_total_memory():
      assert total_mem > 0
 
 
- def test_get_preload_options_default():
-     pytest.importorskip("ucp")
+ @pytest.mark.parametrize(
+     "protocol",
+     ["ucx", "ucxx"],
+ )
+ def test_get_preload_options_default(protocol):
+     if protocol == "ucx":
+         pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         pytest.importorskip("ucxx")
 
      opts = get_preload_options(
-         protocol="ucx",
+         protocol=protocol,
          create_cuda_context=True,
      )
 
@@ -93,14 +100,21 @@ def test_get_preload_options_default():
      assert opts["preload_argv"] == ["--create-cuda-context"]
 
 
+ @pytest.mark.parametrize(
+     "protocol",
+     ["ucx", "ucxx"],
+ )
  @pytest.mark.parametrize("enable_tcp", [True, False])
  @pytest.mark.parametrize("enable_infiniband", [True, False])
  @pytest.mark.parametrize("enable_nvlink", [True, False])
- def test_get_preload_options(enable_tcp, enable_infiniband, enable_nvlink):
-     pytest.importorskip("ucp")
+ def test_get_preload_options(protocol, enable_tcp, enable_infiniband, enable_nvlink):
+     if protocol == "ucx":
+         pytest.importorskip("ucp")
+     elif protocol == "ucxx":
+         pytest.importorskip("ucxx")
 
      opts = get_preload_options(
-         protocol="ucx",
+         protocol=protocol,
          create_cuda_context=True,
          enable_tcp_over_ucx=enable_tcp,
          enable_infiniband=enable_infiniband,
dask_cuda/utils.py CHANGED
@@ -287,7 +287,7 @@ def get_preload_options(
      if create_cuda_context:
          preload_options["preload_argv"].append("--create-cuda-context")
 
-     if protocol == "ucx":
+     if protocol in ["ucx", "ucxx"]:
          initialize_ucx_argv = []
          if enable_tcp_over_ucx:
              initialize_ucx_argv.append("--enable-tcp-over-ucx")
@@ -625,6 +625,10 @@ def get_worker_config(dask_worker):
          import ucp
 
          ret["ucx-transports"] = ucp.get_active_transports()
+     elif scheme == "ucxx":
+         import ucxx
+
+         ret["ucx-transports"] = ucxx.get_active_transports()
 
      # comm timeouts
      ret["distributed.comm.timeouts"] = dask.config.get("distributed.comm.timeouts")
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: dask-cuda
- Version: 23.12.0a231027
+ Version: 24.2.0
  Summary: Utilities for Dask and CUDA interactions
  Author: NVIDIA Corporation
  License: Apache-2.0
@@ -17,12 +17,12 @@ Classifier: Programming Language :: Python :: 3.10
  Requires-Python: >=3.9
  Description-Content-Type: text/markdown
  License-File: LICENSE
- Requires-Dist: dask ==2023.9.2
- Requires-Dist: distributed ==2023.9.2
- Requires-Dist: pynvml <11.5,>=11.0.0
- Requires-Dist: numpy >=1.21
+ Requires-Dist: click >=8.1
  Requires-Dist: numba >=0.57
- Requires-Dist: pandas <1.6.0dev0,>=1.3
+ Requires-Dist: numpy >=1.21
+ Requires-Dist: pandas <1.6.0.dev0,>=1.3
+ Requires-Dist: pynvml <11.5,>=11.0.0
+ Requires-Dist: rapids-dask-dependency ==24.2.*
  Requires-Dist: zict >=2.0.0
  Provides-Extra: docs
  Requires-Dist: numpydoc >=1.1.0 ; extra == 'docs'
@@ -30,7 +30,12 @@ Requires-Dist: sphinx ; extra == 'docs'
  Requires-Dist: sphinx-click >=2.7.1 ; extra == 'docs'
  Requires-Dist: sphinx-rtd-theme >=0.5.1 ; extra == 'docs'
  Provides-Extra: test
+ Requires-Dist: cudf ==24.2.* ; extra == 'test'
+ Requires-Dist: dask-cudf ==24.2.* ; extra == 'test'
+ Requires-Dist: kvikio ==24.2.* ; extra == 'test'
  Requires-Dist: pytest ; extra == 'test'
+ Requires-Dist: pytest-cov ; extra == 'test'
+ Requires-Dist: ucx-py ==0.36.* ; extra == 'test'
 
  Dask CUDA
  =========
@@ -0,0 +1,53 @@
+ dask_cuda/VERSION,sha256=LOsdRePwGiMfhM2DrcvIm5wG4HpS3B0cMVJnTcjfKmM,9
+ dask_cuda/__init__.py,sha256=XnMTUi-SvoGn7g1Dj6XW97HnQzGQv0G3EnvSjcZ7vU4,1455
+ dask_cuda/_version.py,sha256=iR6Kt93dZiHB4aBed4vCsW9knxQxMl0--nC5HoIaxyE,778
+ dask_cuda/cli.py,sha256=XNRH0bu-6jzRoyWJB5qSWuzePJSh3z_5Ng6rDCnz7lg,15970
+ dask_cuda/cuda_worker.py,sha256=bIu-ESeIpJG_WaTYrv0z9z5juJ1qR5i_5Ng3CN1WK8s,8579
+ dask_cuda/device_host_file.py,sha256=yS31LGtt9VFAG78uBBlTDr7HGIng2XymV1OxXIuEMtM,10272
+ dask_cuda/disk_io.py,sha256=urSLKiPvJvYmKCzDPOUDCYuLI3r1RUiyVh3UZGRoF_Y,6626
+ dask_cuda/get_device_memory_objects.py,sha256=R3U2cq4fJZPgtsUKyIguy9161p3Q99oxmcCmTcg6BtQ,4075
+ dask_cuda/initialize.py,sha256=Gjcxs_c8DTafgsHe5-2mw4lJdOmbFJJAZVOnxA8lTjM,6462
+ dask_cuda/is_device_object.py,sha256=CnajvbQiX0FzFzwft0MqK1OPomx3ZGDnDxT56wNjixw,1046
+ dask_cuda/is_spillable_object.py,sha256=CddGmg0tuSpXh2m_TJSY6GRpnl1WRHt1CRcdWgHPzWA,1457
+ dask_cuda/local_cuda_cluster.py,sha256=hoEiEfJqAQrRS7N632VatSl1245GiWMT5B77Wc-i5C0,17928
+ dask_cuda/plugins.py,sha256=cnHsdrXx7PBPmrzHX6YEkCH5byCsUk8LE2FeTeu8ZLU,4259
+ dask_cuda/proxify_device_objects.py,sha256=99CD7LOE79YiQGJ12sYl_XImVhJXpFR4vG5utdkjTQo,8108
+ dask_cuda/proxify_host_file.py,sha256=Wf5CFCC1JN5zmfvND3ls0M5FL01Y8VhHrk0xV3UQ9kk,30850
+ dask_cuda/proxy_object.py,sha256=bZq92kjgFB-ad_luSAFT_RItV3nssmiEk4OOSp34laU,29812
+ dask_cuda/utils.py,sha256=RWlLK2cPHaCuNNhr8bW8etBeGklwREQJOafQbTydStk,25121
+ dask_cuda/utils_test.py,sha256=WNMR0gic2tuP3pgygcR9g52NfyX8iGMOan6juXhpkCE,1694
+ dask_cuda/worker_spec.py,sha256=7-Uq_e5q2SkTlsmctMcYLCa9_3RiiVHZLIN7ctfaFmE,4376
+ dask_cuda/benchmarks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ dask_cuda/benchmarks/common.py,sha256=sEIFnRZS6wbyKCQyB4fDclYLc2YqC0PolurR5qzuRxw,6393
+ dask_cuda/benchmarks/local_cudf_groupby.py,sha256=T9lA9nb4Wzu46AH--SJEVCeCm3650J7slapdNR_08FU,8904
+ dask_cuda/benchmarks/local_cudf_merge.py,sha256=POjxoPx4zY1TjG2S_anElL6rDtC5Jhn3nF4HABlnwZg,12447
+ dask_cuda/benchmarks/local_cudf_shuffle.py,sha256=M-Lp3O3q8uyY50imQqMKZYwkAmyR0NApjx2ipGxDkXw,8608
+ dask_cuda/benchmarks/local_cupy.py,sha256=aUKIYfeR7c77K4kKk697Rxo8tG8kFabQ9jQEVGr-oTs,10762
+ dask_cuda/benchmarks/local_cupy_map_overlap.py,sha256=_texYmam1K_XbzIvURltui5KRsISGFNylXiGUtgRIz0,6442
+ dask_cuda/benchmarks/utils.py,sha256=baL5zK6VS6Mw_M4x9zJe8vMLUd2SZd1lS78JrL-h6oo,26896
+ dask_cuda/explicit_comms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ dask_cuda/explicit_comms/comms.py,sha256=Su6PuNo68IyS-AwoqU4S9TmqWsLvUdNa0jot2hx8jQQ,10400
+ dask_cuda/explicit_comms/dataframe/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ dask_cuda/explicit_comms/dataframe/shuffle.py,sha256=YferHNWKsMea8tele-ynPVr_6RAZNZIR-VzK_uFuEQU,20131
+ dask_cuda/tests/test_cudf_builtin_spilling.py,sha256=u3kW91YRLdHFycvpGfSQKrEucu5khMJ1k4sjmddO490,4910
+ dask_cuda/tests/test_dask_cuda_worker.py,sha256=gViHaMCSfB6ip125OEi9D0nfKC-qBXRoHz6BRodEdb4,17729
+ dask_cuda/tests/test_device_host_file.py,sha256=79ssUISo1YhsW_7HdwqPfsH2LRzS2bi5BjPym1Sdgqw,5882
+ dask_cuda/tests/test_dgx.py,sha256=Oh2vwL_CdUzSVQQoiIu6SPwXGRtmXwaW_Hh3ipXPUOc,7162
+ dask_cuda/tests/test_explicit_comms.py,sha256=I4lSW-NQ0E08baEoG7cY4Ix3blGb1Auz88q2BNd1cPA,13136
+ dask_cuda/tests/test_from_array.py,sha256=okT1B6UqHmLxoy0uER0Ylm3UyOmi5BAXwJpTuTAw44I,601
+ dask_cuda/tests/test_gds.py,sha256=6jf0HPTHAIG8Mp_FC4Ai4zpn-U1K7yk0fSXg8He8-r8,1513
+ dask_cuda/tests/test_initialize.py,sha256=Rba59ZbljEm1yyN94_sWZPEE_f7hWln95aiBVc49pmY,6960
+ dask_cuda/tests/test_local_cuda_cluster.py,sha256=G3kR-4o-vCqWWfSuQLFKVEK0F243FaDSgRlDTUll5aU,18376
+ dask_cuda/tests/test_proxify_host_file.py,sha256=Yiv0sDcUoWw0d2oiPeHGoHqqSSM4lfQ4rChCiaxb6EU,18994
+ dask_cuda/tests/test_proxy.py,sha256=6iicSYYT2BGo1iKUQ7jM00mCjC4gtfwwxFXfGwH3QHc,23807
+ dask_cuda/tests/test_spill.py,sha256=xN9PbVERBYMuZxvscSO0mAM22loq9WT3ltZVBFxlmM4,10239
+ dask_cuda/tests/test_utils.py,sha256=JRIwXfemc3lWSzLJX0VcvR1_0wB4yeoOTsw7kB6z6pU,9176
+ dask_cuda/tests/test_worker_spec.py,sha256=Bvu85vkqm6ZDAYPXKMJlI2pm9Uc5tiYKNtO4goXSw-I,2399
+ examples/ucx/client_initialize.py,sha256=YN3AXHF8btcMd6NicKKhKR9SXouAsK1foJhFspbOn70,1262
+ examples/ucx/local_cuda_cluster.py,sha256=7xVY3EhwhkY2L4VZin_BiMCbrjhirDNChoC86KiETNc,1983
+ dask_cuda-24.2.0.dist-info/LICENSE,sha256=MjI3I-EgxfEvZlgjk82rgiFsZqSDXHFETd2QJ89UwDA,11348
+ dask_cuda-24.2.0.dist-info/METADATA,sha256=WDvD-un12aVVPatfnue3HLTts2j9cUz9lxSUrzh05vE,2524
+ dask_cuda-24.2.0.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+ dask_cuda-24.2.0.dist-info/entry_points.txt,sha256=UcRaKVEpywtxc6pF1VnfMB0UK4sJg7a8_NdZF67laPM,136
+ dask_cuda-24.2.0.dist-info/top_level.txt,sha256=3kKxJxeM108fuYc_lwwlklP7YBU9IEmdmRAouzi397o,33
+ dask_cuda-24.2.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: bdist_wheel (0.41.2)
+ Generator: bdist_wheel (0.42.0)
  Root-Is-Purelib: true
  Tag: py3-none-any
 
dask_cuda/compat.py DELETED
@@ -1,118 +0,0 @@
- import pickle
-
- import msgpack
- from packaging.version import Version
-
- import dask
- import distributed
- import distributed.comm.utils
- import distributed.protocol
- from distributed.comm.utils import OFFLOAD_THRESHOLD, nbytes, offload
- from distributed.protocol.core import (
-     Serialized,
-     decompress,
-     logger,
-     merge_and_deserialize,
-     msgpack_decode_default,
-     msgpack_opts,
- )
-
- if Version(distributed.__version__) >= Version("2023.8.1"):
-     # Monkey-patch protocol.core.loads (and its users)
-     async def from_frames(
-         frames, deserialize=True, deserializers=None, allow_offload=True
-     ):
-         """
-         Unserialize a list of Distributed protocol frames.
-         """
-         size = False
-
-         def _from_frames():
-             try:
-                 # Patched code
-                 return loads(
-                     frames, deserialize=deserialize, deserializers=deserializers
-                 )
-                 # end patched code
-             except EOFError:
-                 if size > 1000:
-                     datastr = "[too large to display]"
-                 else:
-                     datastr = frames
-                 # Aid diagnosing
-                 logger.error("truncated data stream (%d bytes): %s", size, datastr)
-                 raise
-
-         if allow_offload and deserialize and OFFLOAD_THRESHOLD:
-             size = sum(map(nbytes, frames))
-         if (
-             allow_offload
-             and deserialize
-             and OFFLOAD_THRESHOLD
-             and size > OFFLOAD_THRESHOLD
-         ):
-             res = await offload(_from_frames)
-         else:
-             res = _from_frames()
-
-         return res
-
-     def loads(frames, deserialize=True, deserializers=None):
-         """Transform bytestream back into Python value"""
-
-         allow_pickle = dask.config.get("distributed.scheduler.pickle")
-
-         try:
-
-             def _decode_default(obj):
-                 offset = obj.get("__Serialized__", 0)
-                 if offset > 0:
-                     sub_header = msgpack.loads(
-                         frames[offset],
-                         object_hook=msgpack_decode_default,
-                         use_list=False,
-                         **msgpack_opts,
-                     )
-                     offset += 1
-                     sub_frames = frames[offset : offset + sub_header["num-sub-frames"]]
-                     if deserialize:
-                         if "compression" in sub_header:
-                             sub_frames = decompress(sub_header, sub_frames)
-                         return merge_and_deserialize(
-                             sub_header, sub_frames, deserializers=deserializers
-                         )
-                     else:
-                         return Serialized(sub_header, sub_frames)
-
-                 offset = obj.get("__Pickled__", 0)
-                 if offset > 0:
-                     sub_header = msgpack.loads(frames[offset])
-                     offset += 1
-                     sub_frames = frames[offset : offset + sub_header["num-sub-frames"]]
-                     # Patched code
-                     if "compression" in sub_header:
-                         sub_frames = decompress(sub_header, sub_frames)
-                     # end patched code
-                     if allow_pickle:
-                         return pickle.loads(
-                             sub_header["pickled-obj"], buffers=sub_frames
-                         )
-                     else:
-                         raise ValueError(
-                             "Unpickle on the Scheduler isn't allowed, "
-                             "set `distributed.scheduler.pickle=true`"
-                         )
-
-                 return msgpack_decode_default(obj)
-
-             return msgpack.loads(
-                 frames[0], object_hook=_decode_default, use_list=False, **msgpack_opts
-             )
-
-         except Exception:
-             logger.critical("Failed to deserialize", exc_info=True)
-             raise
-
-     distributed.protocol.loads = loads
-     distributed.protocol.core.loads = loads
-     distributed.comm.utils.from_frames = from_frames
@@ -1,52 +0,0 @@
- dask_cuda/__init__.py,sha256=xtogPs_QSmTTMOWetj9CqLaCwdF-bANfrD75LYpulMc,1452
- dask_cuda/cli.py,sha256=XNRH0bu-6jzRoyWJB5qSWuzePJSh3z_5Ng6rDCnz7lg,15970
- dask_cuda/compat.py,sha256=BLXv9IHUtD3h6-T_8MX-uGt-UDMG6EuGuyN-zw3XndU,4084
- dask_cuda/cuda_worker.py,sha256=bIu-ESeIpJG_WaTYrv0z9z5juJ1qR5i_5Ng3CN1WK8s,8579
- dask_cuda/device_host_file.py,sha256=D0rHOFz1TRfvaecoP30x3JRWe1TiHUaq45Dg-v0DfoY,10272
- dask_cuda/disk_io.py,sha256=urSLKiPvJvYmKCzDPOUDCYuLI3r1RUiyVh3UZGRoF_Y,6626
- dask_cuda/get_device_memory_objects.py,sha256=zMSqWzm5rflRInbNMz7U2Ewv5nMcE-H8stMJeWHVWyc,3890
- dask_cuda/initialize.py,sha256=mzPgKhs8oLgUWpqd4ckvLNKvhLoHjt96RrBPeVneenI,5231
- dask_cuda/is_device_object.py,sha256=CnajvbQiX0FzFzwft0MqK1OPomx3ZGDnDxT56wNjixw,1046
- dask_cuda/is_spillable_object.py,sha256=CddGmg0tuSpXh2m_TJSY6GRpnl1WRHt1CRcdWgHPzWA,1457
- dask_cuda/local_cuda_cluster.py,sha256=w2HXMZtEukwklkB3J6l6DqZstNA5uvGEdFkdzpyUJ6k,17810
- dask_cuda/plugins.py,sha256=cnHsdrXx7PBPmrzHX6YEkCH5byCsUk8LE2FeTeu8ZLU,4259
- dask_cuda/proxify_device_objects.py,sha256=99CD7LOE79YiQGJ12sYl_XImVhJXpFR4vG5utdkjTQo,8108
- dask_cuda/proxify_host_file.py,sha256=Wf5CFCC1JN5zmfvND3ls0M5FL01Y8VhHrk0xV3UQ9kk,30850
- dask_cuda/proxy_object.py,sha256=bZq92kjgFB-ad_luSAFT_RItV3nssmiEk4OOSp34laU,29812
- dask_cuda/utils.py,sha256=wNRItbIXrOpH77AUUrZNrGqgIiGNzpClXYl0QmQqfxs,25002
- dask_cuda/utils_test.py,sha256=WNMR0gic2tuP3pgygcR9g52NfyX8iGMOan6juXhpkCE,1694
- dask_cuda/worker_spec.py,sha256=7-Uq_e5q2SkTlsmctMcYLCa9_3RiiVHZLIN7ctfaFmE,4376
- dask_cuda/benchmarks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- dask_cuda/benchmarks/common.py,sha256=sEIFnRZS6wbyKCQyB4fDclYLc2YqC0PolurR5qzuRxw,6393
- dask_cuda/benchmarks/local_cudf_groupby.py,sha256=2iHk-a-GvLmAgajwQJNrqmZ-WJeiyMFEyflcxh7SPO8,8894
- dask_cuda/benchmarks/local_cudf_merge.py,sha256=vccM5PyzZVW99-a8YaIgftsGAiA5yXnT9NoAusx0PZY,12437
- dask_cuda/benchmarks/local_cudf_shuffle.py,sha256=LaNCMKhhfE1lYFUUWMtdYH-efbqV6YTFhKC-Eog9-8Q,8598
- dask_cuda/benchmarks/local_cupy.py,sha256=G36CI46ROtNPf6QISK6QjoguB2Qb19ScsylwLFOlMy4,10752
- dask_cuda/benchmarks/local_cupy_map_overlap.py,sha256=rQNLGvpX1XpgK-0Wx5fd3kV9Veu2ulBd5eX2sanNlEQ,6432
- dask_cuda/benchmarks/utils.py,sha256=mx_JKe4q1xFNwKJX03o8dEwc48iqnqHm-ZTHOcMn17E,26888
- dask_cuda/explicit_comms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- dask_cuda/explicit_comms/comms.py,sha256=Su6PuNo68IyS-AwoqU4S9TmqWsLvUdNa0jot2hx8jQQ,10400
- dask_cuda/explicit_comms/dataframe/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- dask_cuda/explicit_comms/dataframe/shuffle.py,sha256=2f2wlPyqXpryIHgMpsZzs3pDE7eyslYam-jQh3ujszQ,20124
- dask_cuda/tests/test_cudf_builtin_spilling.py,sha256=u3kW91YRLdHFycvpGfSQKrEucu5khMJ1k4sjmddO490,4910
- dask_cuda/tests/test_dask_cuda_worker.py,sha256=gViHaMCSfB6ip125OEi9D0nfKC-qBXRoHz6BRodEdb4,17729
- dask_cuda/tests/test_device_host_file.py,sha256=79ssUISo1YhsW_7HdwqPfsH2LRzS2bi5BjPym1Sdgqw,5882
- dask_cuda/tests/test_dgx.py,sha256=bKX-GvkYjWlmcEIK15aGErxmc0qPqIWOG1CeDFGoXFU,6381
- dask_cuda/tests/test_explicit_comms.py,sha256=3Q3o9BX4ksCgz11o38o5QhKg3Rv-EtTsGnVG83wwyyo,12283
- dask_cuda/tests/test_from_array.py,sha256=i2Vha4mchB0BopTlEdXV7CxY7qyTzFYdgYQTmukZX38,493
- dask_cuda/tests/test_gds.py,sha256=6jf0HPTHAIG8Mp_FC4Ai4zpn-U1K7yk0fSXg8He8-r8,1513
- dask_cuda/tests/test_initialize.py,sha256=-Vo8SVBrVEKB0V1C6ia8khvbHJt4BC0xEjMNLhNbFxI,5491
- dask_cuda/tests/test_local_cuda_cluster.py,sha256=rKusvUq_5FNqev51Vli1wlQQH7qL_5lIsjQN57e0A4k,17445
- dask_cuda/tests/test_proxify_host_file.py,sha256=cp-U1uNPhesQaHbftKV8ir_dt5fbs0ZXSIsL39oI0fE,18630
- dask_cuda/tests/test_proxy.py,sha256=Nu9vLx-dALINcF_wsxuFYUryRE0Jq43w7bAYAchK8RY,23480
- dask_cuda/tests/test_spill.py,sha256=xN9PbVERBYMuZxvscSO0mAM22loq9WT3ltZVBFxlmM4,10239
- dask_cuda/tests/test_utils.py,sha256=wgYPvu7Sk61C64pah9ZbK8cnBXK5RyUCpu3G2ny6OZQ,8832
- dask_cuda/tests/test_worker_spec.py,sha256=Bvu85vkqm6ZDAYPXKMJlI2pm9Uc5tiYKNtO4goXSw-I,2399
- examples/ucx/client_initialize.py,sha256=YN3AXHF8btcMd6NicKKhKR9SXouAsK1foJhFspbOn70,1262
- examples/ucx/local_cuda_cluster.py,sha256=7xVY3EhwhkY2L4VZin_BiMCbrjhirDNChoC86KiETNc,1983
- dask_cuda-23.12.0a231027.dist-info/LICENSE,sha256=MjI3I-EgxfEvZlgjk82rgiFsZqSDXHFETd2QJ89UwDA,11348
- dask_cuda-23.12.0a231027.dist-info/METADATA,sha256=gbaDcC9Wti46hsAM8g0wzySIvt9MVvxyU_zA-PSdl1s,2285
- dask_cuda-23.12.0a231027.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
- dask_cuda-23.12.0a231027.dist-info/entry_points.txt,sha256=UcRaKVEpywtxc6pF1VnfMB0UK4sJg7a8_NdZF67laPM,136
- dask_cuda-23.12.0a231027.dist-info/top_level.txt,sha256=3kKxJxeM108fuYc_lwwlklP7YBU9IEmdmRAouzi397o,33
- dask_cuda-23.12.0a231027.dist-info/RECORD,,