dask-cuda 23.10.0a231015__py3-none-any.whl → 23.12.0a24__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only.
dask_cuda/VERSION ADDED
@@ -0,0 +1 @@
1
+ 23.12.00a24
dask_cuda/__init__.py CHANGED
@@ -11,6 +11,7 @@ import dask.dataframe.shuffle
11
11
  import dask.dataframe.multi
12
12
  import dask.bag.core
13
13
 
14
+ from ._version import __git_commit__, __version__
14
15
  from .cuda_worker import CUDAWorker
15
16
  from .explicit_comms.dataframe.shuffle import (
16
17
  get_rearrange_by_column_wrapper,
@@ -19,9 +20,6 @@ from .explicit_comms.dataframe.shuffle import (
19
20
  from .local_cuda_cluster import LocalCUDACluster
20
21
  from .proxify_device_objects import proxify_decorator, unproxify_decorator
21
22
 
22
- __version__ = "23.10.00"
23
-
24
- from . import compat
25
23
 
26
24
  # Monkey patching Dask to make use of explicit-comms when `DASK_EXPLICIT_COMMS=True`
27
25
  dask.dataframe.shuffle.rearrange_by_column = get_rearrange_by_column_wrapper(
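
The surviving context line above is the explicit-comms monkey patch that `__init__.py` applies to `dask.dataframe`. As a hedged illustration of opting in from Python, assuming the `"explicit-comms"` config key that the `DASK_EXPLICIT_COMMS` environment variable mentioned here maps to:

import dask
import dask.dataframe as dd
from distributed import Client

from dask_cuda import LocalCUDACluster

if __name__ == "__main__":
    with dask.config.set({"explicit-comms": True}):
        with LocalCUDACluster(n_workers=2) as cluster, Client(cluster) as client:
            ddf = dd.from_dict(
                {"key": [1, 2, 3, 4] * 25, "value": range(100)}, npartitions=4
            )
            # With the config flag set, the shuffle is routed through the
            # explicit-comms implementation patched in above.
            ddf.shuffle(on="key").compute()
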
dask_cuda/_version.py ADDED
@@ -0,0 +1,20 @@
1
+ # Copyright (c) 2023, NVIDIA CORPORATION.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import importlib.resources
16
+
17
+ __version__ = (
18
+ importlib.resources.files("dask_cuda").joinpath("VERSION").read_text().strip()
19
+ )
20
+ __git_commit__ = ""
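
With `_version.py` in place, the version string is read from the bundled `VERSION` data file at import time and re-exported from `dask_cuda/__init__.py`. A small sketch for confirming which build is installed, using only what the diff itself shows plus the public attributes:

import importlib.resources

import dask_cuda

print(dask_cuda.__version__)     # e.g. "23.12.00a24"
print(dask_cuda.__git_commit__)  # empty string in this wheel

# Equivalent direct read of the packaged data file
version = importlib.resources.files("dask_cuda").joinpath("VERSION").read_text().strip()
assert version == dask_cuda.__version__
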
dask_cuda/cuda_worker.py CHANGED
@@ -20,11 +20,9 @@ from distributed.worker_memory import parse_memory_limit
20
20
 
21
21
  from .device_host_file import DeviceHostFile
22
22
  from .initialize import initialize
23
+ from .plugins import CPUAffinity, PreImport, RMMSetup
23
24
  from .proxify_host_file import ProxifyHostFile
24
25
  from .utils import (
25
- CPUAffinity,
26
- PreImport,
27
- RMMSetup,
28
26
  cuda_visible_devices,
29
27
  get_cpu_affinity,
30
28
  get_n_gpus,
dask_cuda/device_host_file.py CHANGED
@@ -17,7 +17,7 @@ from distributed.protocol (
17
17
  serialize_bytelist,
18
18
  )
19
19
  from distributed.sizeof import safe_sizeof
20
- from distributed.spill import CustomFile as KeyAsStringFile
20
+ from distributed.spill import AnyKeyFile as KeyAsStringFile
21
21
  from distributed.utils import nbytes
22
22
 
23
23
  from .is_device_object import is_device_object
dask_cuda/local_cuda_cluster.py CHANGED
@@ -2,6 +2,7 @@ import copy
2
2
  import logging
3
3
  import os
4
4
  import warnings
5
+ from functools import partial
5
6
 
6
7
  import dask
7
8
  from distributed import LocalCluster, Nanny, Worker
@@ -9,11 +10,9 @@ from distributed.worker_memory import parse_memory_limit
9
10
 
10
11
  from .device_host_file import DeviceHostFile
11
12
  from .initialize import initialize
13
+ from .plugins import CPUAffinity, PreImport, RMMSetup
12
14
  from .proxify_host_file import ProxifyHostFile
13
15
  from .utils import (
14
- CPUAffinity,
15
- PreImport,
16
- RMMSetup,
17
16
  cuda_visible_devices,
18
17
  get_cpu_affinity,
19
18
  get_ucx_config,
@@ -334,12 +333,16 @@ class LocalCUDACluster(LocalCluster):
334
333
  )
335
334
 
336
335
  if worker_class is not None:
337
- from functools import partial
338
-
339
- worker_class = partial(
340
- LoggedNanny if log_spilling is True else Nanny,
341
- worker_class=worker_class,
342
- )
336
+ if log_spilling is True:
337
+ raise ValueError(
338
+ "Cannot enable `log_spilling` when `worker_class` is specified. If "
339
+ "logging is needed, ensure `worker_class` is a subclass of "
340
+ "`distributed.local_cuda_cluster.LoggedNanny` or a subclass of "
341
+ "`distributed.local_cuda_cluster.LoggedWorker`, and specify "
342
+ "`log_spilling=False`."
343
+ )
344
+ if not issubclass(worker_class, Nanny):
345
+ worker_class = partial(Nanny, worker_class=worker_class)
343
346
 
344
347
  self.pre_import = pre_import
345
348
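
The replacement logic above accepts any `Nanny` subclass as-is, wraps other worker classes in `partial(Nanny, worker_class=...)`, and now rejects `log_spilling=True` whenever `worker_class` is given. A hedged sketch of the supported combination; `MyWorker` is illustrative and not part of dask-cuda:

from distributed import Worker

from dask_cuda import LocalCUDACluster


class MyWorker(Worker):
    """Worker subclass used only for illustration."""


if __name__ == "__main__":
    with LocalCUDACluster(
        n_workers=1,
        worker_class=MyWorker,  # not a Nanny subclass, so it is wrapped in one
        log_spilling=False,     # log_spilling=True with worker_class now raises ValueError
    ) as cluster:
        pass
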
 
dask_cuda/plugins.py ADDED
@@ -0,0 +1,122 @@
1
+ import importlib
2
+ import os
3
+
4
+ from distributed import WorkerPlugin
5
+
6
+ from .utils import get_rmm_log_file_name, parse_device_memory_limit
7
+
8
+
9
+ class CPUAffinity(WorkerPlugin):
10
+ def __init__(self, cores):
11
+ self.cores = cores
12
+
13
+ def setup(self, worker=None):
14
+ os.sched_setaffinity(0, self.cores)
15
+
16
+
17
+ class RMMSetup(WorkerPlugin):
18
+ def __init__(
19
+ self,
20
+ initial_pool_size,
21
+ maximum_pool_size,
22
+ managed_memory,
23
+ async_alloc,
24
+ release_threshold,
25
+ log_directory,
26
+ track_allocations,
27
+ ):
28
+ if initial_pool_size is None and maximum_pool_size is not None:
29
+ raise ValueError(
30
+ "`rmm_maximum_pool_size` was specified without specifying "
31
+ "`rmm_pool_size`.`rmm_pool_size` must be specified to use RMM pool."
32
+ )
33
+ if async_alloc is True:
34
+ if managed_memory is True:
35
+ raise ValueError(
36
+ "`rmm_managed_memory` is incompatible with the `rmm_async`."
37
+ )
38
+ if async_alloc is False and release_threshold is not None:
39
+ raise ValueError("`rmm_release_threshold` requires `rmm_async`.")
40
+
41
+ self.initial_pool_size = initial_pool_size
42
+ self.maximum_pool_size = maximum_pool_size
43
+ self.managed_memory = managed_memory
44
+ self.async_alloc = async_alloc
45
+ self.release_threshold = release_threshold
46
+ self.logging = log_directory is not None
47
+ self.log_directory = log_directory
48
+ self.rmm_track_allocations = track_allocations
49
+
50
+ def setup(self, worker=None):
51
+ if self.initial_pool_size is not None:
52
+ self.initial_pool_size = parse_device_memory_limit(
53
+ self.initial_pool_size, alignment_size=256
54
+ )
55
+
56
+ if self.async_alloc:
57
+ import rmm
58
+
59
+ if self.release_threshold is not None:
60
+ self.release_threshold = parse_device_memory_limit(
61
+ self.release_threshold, alignment_size=256
62
+ )
63
+
64
+ mr = rmm.mr.CudaAsyncMemoryResource(
65
+ initial_pool_size=self.initial_pool_size,
66
+ release_threshold=self.release_threshold,
67
+ )
68
+
69
+ if self.maximum_pool_size is not None:
70
+ self.maximum_pool_size = parse_device_memory_limit(
71
+ self.maximum_pool_size, alignment_size=256
72
+ )
73
+ mr = rmm.mr.LimitingResourceAdaptor(
74
+ mr, allocation_limit=self.maximum_pool_size
75
+ )
76
+
77
+ rmm.mr.set_current_device_resource(mr)
78
+ if self.logging:
79
+ rmm.enable_logging(
80
+ log_file_name=get_rmm_log_file_name(
81
+ worker, self.logging, self.log_directory
82
+ )
83
+ )
84
+ elif self.initial_pool_size is not None or self.managed_memory:
85
+ import rmm
86
+
87
+ pool_allocator = False if self.initial_pool_size is None else True
88
+
89
+ if self.initial_pool_size is not None:
90
+ if self.maximum_pool_size is not None:
91
+ self.maximum_pool_size = parse_device_memory_limit(
92
+ self.maximum_pool_size, alignment_size=256
93
+ )
94
+
95
+ rmm.reinitialize(
96
+ pool_allocator=pool_allocator,
97
+ managed_memory=self.managed_memory,
98
+ initial_pool_size=self.initial_pool_size,
99
+ maximum_pool_size=self.maximum_pool_size,
100
+ logging=self.logging,
101
+ log_file_name=get_rmm_log_file_name(
102
+ worker, self.logging, self.log_directory
103
+ ),
104
+ )
105
+ if self.rmm_track_allocations:
106
+ import rmm
107
+
108
+ mr = rmm.mr.get_current_device_resource()
109
+ rmm.mr.set_current_device_resource(rmm.mr.TrackingResourceAdaptor(mr))
110
+
111
+
112
+ class PreImport(WorkerPlugin):
113
+ def __init__(self, libraries):
114
+ if libraries is None:
115
+ libraries = []
116
+ elif isinstance(libraries, str):
117
+ libraries = libraries.split(",")
118
+ self.libraries = libraries
119
+
120
+ def setup(self, worker=None):
121
+ for l in self.libraries:
122
+ importlib.import_module(l)
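
The three classes above were previously plain helper classes in `dask_cuda.utils`; they are now `distributed.WorkerPlugin` subclasses. `LocalCUDACluster` and the CLI still build them for you from keyword arguments such as `pre_import=` and `rmm_pool_size=`, but as plugins they can also be registered on an existing cluster. A hedged sketch (newer distributed releases prefer `Client.register_plugin` over `register_worker_plugin`):

from distributed import Client

from dask_cuda import LocalCUDACluster
from dask_cuda.plugins import PreImport

if __name__ == "__main__":
    with LocalCUDACluster(n_workers=1) as cluster, Client(cluster) as client:
        # Import cupy eagerly on every current and future worker.
        client.register_worker_plugin(PreImport("cupy"))
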
dask_cuda/tests/test_dask_cuda_worker.py CHANGED
@@ -40,7 +40,7 @@ def test_cuda_visible_devices_and_memory_limit_and_nthreads(loop): # noqa: F811
40
40
  str(nthreads),
41
41
  "--no-dashboard",
42
42
  "--worker-class",
43
- "dask_cuda.utils.MockWorker",
43
+ "dask_cuda.utils_test.MockWorker",
44
44
  ]
45
45
  ):
46
46
  with Client("127.0.0.1:9359", loop=loop) as client:
@@ -329,7 +329,7 @@ def test_cuda_mig_visible_devices_and_memory_limit_and_nthreads(loop): # noqa:
329
329
  str(nthreads),
330
330
  "--no-dashboard",
331
331
  "--worker-class",
332
- "dask_cuda.utils.MockWorker",
332
+ "dask_cuda.utils_test.MockWorker",
333
333
  ]
334
334
  ):
335
335
  with Client("127.0.0.1:9359", loop=loop) as client:
@@ -364,7 +364,7 @@ def test_cuda_visible_devices_uuid(loop): # noqa: F811
364
364
  "127.0.0.1",
365
365
  "--no-dashboard",
366
366
  "--worker-class",
367
- "dask_cuda.utils.MockWorker",
367
+ "dask_cuda.utils_test.MockWorker",
368
368
  ]
369
369
  ):
370
370
  with Client("127.0.0.1:9359", loop=loop) as client:
dask_cuda/tests/test_explicit_comms.py CHANGED
@@ -17,6 +17,7 @@ from distributed.deploy.local import LocalCluster
17
17
  import dask_cuda
18
18
  from dask_cuda.explicit_comms import comms
19
19
  from dask_cuda.explicit_comms.dataframe.shuffle import shuffle as explicit_comms_shuffle
20
+ from dask_cuda.utils_test import IncreasedCloseTimeoutNanny
20
21
 
21
22
  mp = mp.get_context("spawn") # type: ignore
22
23
  ucp = pytest.importorskip("ucp")
@@ -35,6 +36,7 @@ def _test_local_cluster(protocol):
35
36
  dashboard_address=None,
36
37
  n_workers=4,
37
38
  threads_per_worker=1,
39
+ worker_class=IncreasedCloseTimeoutNanny,
38
40
  processes=True,
39
41
  ) as cluster:
40
42
  with Client(cluster) as client:
@@ -56,6 +58,7 @@ def _test_dataframe_merge_empty_partitions(nrows, npartitions):
56
58
  dashboard_address=None,
57
59
  n_workers=npartitions,
58
60
  threads_per_worker=1,
61
+ worker_class=IncreasedCloseTimeoutNanny,
59
62
  processes=True,
60
63
  ) as cluster:
61
64
  with Client(cluster):
@@ -102,6 +105,7 @@ def _test_dataframe_shuffle(backend, protocol, n_workers, _partitions):
102
105
  dashboard_address=None,
103
106
  n_workers=n_workers,
104
107
  threads_per_worker=1,
108
+ worker_class=IncreasedCloseTimeoutNanny,
105
109
  processes=True,
106
110
  ) as cluster:
107
111
  with Client(cluster) as client:
@@ -204,6 +208,7 @@ def test_dask_use_explicit_comms(in_cluster):
204
208
  dashboard_address=None,
205
209
  n_workers=2,
206
210
  threads_per_worker=1,
211
+ worker_class=IncreasedCloseTimeoutNanny,
207
212
  processes=True,
208
213
  ) as cluster:
209
214
  with Client(cluster):
@@ -221,6 +226,7 @@ def _test_dataframe_shuffle_merge(backend, protocol, n_workers):
221
226
  dashboard_address=None,
222
227
  n_workers=n_workers,
223
228
  threads_per_worker=1,
229
+ worker_class=IncreasedCloseTimeoutNanny,
224
230
  processes=True,
225
231
  ) as cluster:
226
232
  with Client(cluster):
@@ -327,6 +333,7 @@ def test_lock_workers():
327
333
  dashboard_address=None,
328
334
  n_workers=4,
329
335
  threads_per_worker=5,
336
+ worker_class=IncreasedCloseTimeoutNanny,
330
337
  processes=True,
331
338
  ) as cluster:
332
339
  ps = []
dask_cuda/tests/test_initialize.py CHANGED
@@ -10,6 +10,7 @@ from distributed.deploy.local import LocalCluster
10
10
 
11
11
  from dask_cuda.initialize import initialize
12
12
  from dask_cuda.utils import get_ucx_config
13
+ from dask_cuda.utils_test import IncreasedCloseTimeoutNanny
13
14
 
14
15
  mp = mp.get_context("spawn") # type: ignore
15
16
  ucp = pytest.importorskip("ucp")
@@ -29,6 +30,7 @@ def _test_initialize_ucx_tcp():
29
30
  n_workers=1,
30
31
  threads_per_worker=1,
31
32
  processes=True,
33
+ worker_class=IncreasedCloseTimeoutNanny,
32
34
  config={"distributed.comm.ucx": get_ucx_config(**kwargs)},
33
35
  ) as cluster:
34
36
  with Client(cluster) as client:
@@ -64,6 +66,7 @@ def _test_initialize_ucx_nvlink():
64
66
  n_workers=1,
65
67
  threads_per_worker=1,
66
68
  processes=True,
69
+ worker_class=IncreasedCloseTimeoutNanny,
67
70
  config={"distributed.comm.ucx": get_ucx_config(**kwargs)},
68
71
  ) as cluster:
69
72
  with Client(cluster) as client:
@@ -100,6 +103,7 @@ def _test_initialize_ucx_infiniband():
100
103
  n_workers=1,
101
104
  threads_per_worker=1,
102
105
  processes=True,
106
+ worker_class=IncreasedCloseTimeoutNanny,
103
107
  config={"distributed.comm.ucx": get_ucx_config(**kwargs)},
104
108
  ) as cluster:
105
109
  with Client(cluster) as client:
@@ -138,6 +142,7 @@ def _test_initialize_ucx_all():
138
142
  n_workers=1,
139
143
  threads_per_worker=1,
140
144
  processes=True,
145
+ worker_class=IncreasedCloseTimeoutNanny,
141
146
  config={"distributed.comm.ucx": get_ucx_config()},
142
147
  ) as cluster:
143
148
  with Client(cluster) as client:
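
These tests feed the result of `get_ucx_config()` into a plain `distributed.LocalCluster` through its `config=` argument. A hedged sketch of the same pattern outside the test suite; `protocol="ucx"` is an assumption about the surrounding (unshown) test context, and a UCX-enabled environment (ucx-py) is required:

from distributed import Client, LocalCluster

from dask_cuda.utils import get_ucx_config
from dask_cuda.utils_test import IncreasedCloseTimeoutNanny

if __name__ == "__main__":
    with LocalCluster(
        protocol="ucx",
        n_workers=1,
        threads_per_worker=1,
        processes=True,
        worker_class=IncreasedCloseTimeoutNanny,
        config={"distributed.comm.ucx": get_ucx_config()},
    ) as cluster:
        with Client(cluster) as client:
            client.run(lambda: None)
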
dask_cuda/tests/test_local_cuda_cluster.py CHANGED
@@ -13,13 +13,13 @@ from distributed.utils_test import gen_test, raises_with_cause
13
13
  from dask_cuda import CUDAWorker, LocalCUDACluster, utils
14
14
  from dask_cuda.initialize import initialize
15
15
  from dask_cuda.utils import (
16
- MockWorker,
17
16
  get_cluster_configuration,
18
17
  get_device_total_memory,
19
18
  get_gpu_count_mig,
20
19
  get_gpu_uuid_from_index,
21
20
  print_cluster_config,
22
21
  )
22
+ from dask_cuda.utils_test import MockWorker
23
23
 
24
24
 
25
25
  @gen_test(timeout=20)
@@ -337,6 +337,7 @@ async def test_pre_import():
337
337
 
338
338
 
339
339
  # Intentionally not using @gen_test to skip cleanup checks
340
+ @pytest.mark.xfail(reason="https://github.com/rapidsai/dask-cuda/issues/1265")
340
341
  def test_pre_import_not_found():
341
342
  async def _test_pre_import_not_found():
342
343
  with raises_with_cause(RuntimeError, None, ImportError, None):
@@ -491,6 +492,7 @@ def test_print_cluster_config(capsys):
491
492
  assert "[plugin]" in captured.out
492
493
 
493
494
 
495
+ @pytest.mark.xfail(reason="https://github.com/rapidsai/dask-cuda/issues/1265")
494
496
  def test_death_timeout_raises():
495
497
  with pytest.raises(asyncio.exceptions.TimeoutError):
496
498
  with LocalCUDACluster(
dask_cuda/tests/test_proxify_host_file.py CHANGED
@@ -19,6 +19,7 @@ from dask_cuda.get_device_memory_objects import get_device_memory_ids
19
19
  from dask_cuda.proxify_host_file import ProxifyHostFile
20
20
  from dask_cuda.proxy_object import ProxyObject, asproxy, unproxy
21
21
  from dask_cuda.utils import get_device_total_memory
22
+ from dask_cuda.utils_test import IncreasedCloseTimeoutNanny
22
23
 
23
24
  cupy = pytest.importorskip("cupy")
24
25
  cupy.cuda.set_allocator(None)
@@ -393,7 +394,10 @@ async def test_compatibility_mode_dataframe_shuffle(compatibility_mode, npartiti
393
394
 
394
395
  with dask.config.set(jit_unspill_compatibility_mode=compatibility_mode):
395
396
  async with dask_cuda.LocalCUDACluster(
396
- n_workers=1, jit_unspill=True, asynchronous=True
397
+ n_workers=1,
398
+ jit_unspill=True,
399
+ worker_class=IncreasedCloseTimeoutNanny,
400
+ asynchronous=True,
397
401
  ) as cluster:
398
402
  async with Client(cluster, asynchronous=True) as client:
399
403
  ddf = dask.dataframe.from_pandas(
dask_cuda/tests/test_proxy.py CHANGED
@@ -23,6 +23,7 @@ from dask_cuda import LocalCUDACluster, proxy_object
23
23
  from dask_cuda.disk_io import SpillToDiskFile
24
24
  from dask_cuda.proxify_device_objects import proxify_device_objects
25
25
  from dask_cuda.proxify_host_file import ProxifyHostFile
26
+ from dask_cuda.utils_test import IncreasedCloseTimeoutNanny
26
27
 
27
28
  # Make the "disk" serializer available and use a directory that are
28
29
  # remove on exit.
@@ -422,6 +423,7 @@ async def test_communicating_proxy_objects(protocol, send_serializers):
422
423
  async with dask_cuda.LocalCUDACluster(
423
424
  n_workers=1,
424
425
  protocol=protocol,
426
+ worker_class=IncreasedCloseTimeoutNanny,
425
427
  asynchronous=True,
426
428
  ) as cluster:
427
429
  async with Client(cluster, asynchronous=True) as client:
dask_cuda/tests/test_spill.py CHANGED
@@ -1,3 +1,4 @@
1
+ import gc
1
2
  import os
2
3
  from time import sleep
3
4
 
@@ -11,6 +12,7 @@ from distributed.sizeof import sizeof
11
12
  from distributed.utils_test import gen_cluster, gen_test, loop # noqa: F401
12
13
 
13
14
  from dask_cuda import LocalCUDACluster, utils
15
+ from dask_cuda.utils_test import IncreasedCloseTimeoutNanny
14
16
 
15
17
  if utils.get_device_total_memory() < 1e10:
16
18
  pytest.skip("Not enough GPU memory", allow_module_level=True)
@@ -58,7 +60,10 @@ def assert_device_host_file_size(
58
60
 
59
61
 
60
62
  def worker_assert(
61
- dask_worker, total_size, device_chunk_overhead, serialized_chunk_overhead
63
+ total_size,
64
+ device_chunk_overhead,
65
+ serialized_chunk_overhead,
66
+ dask_worker=None,
62
67
  ):
63
68
  assert_device_host_file_size(
64
69
  dask_worker.data, total_size, device_chunk_overhead, serialized_chunk_overhead
@@ -66,7 +71,10 @@ def worker_assert(
66
71
 
67
72
 
68
73
  def delayed_worker_assert(
69
- dask_worker, total_size, device_chunk_overhead, serialized_chunk_overhead
74
+ total_size,
75
+ device_chunk_overhead,
76
+ serialized_chunk_overhead,
77
+ dask_worker=None,
70
78
  ):
71
79
  start = time()
72
80
  while not device_host_file_size_matches(
@@ -82,6 +90,18 @@ def delayed_worker_assert(
82
90
  )
83
91
 
84
92
 
93
+ def assert_host_chunks(spills_to_disk, dask_worker=None):
94
+ if spills_to_disk is False:
95
+ assert len(dask_worker.data.host)
96
+
97
+
98
+ def assert_disk_chunks(spills_to_disk, dask_worker=None):
99
+ if spills_to_disk is True:
100
+ assert len(dask_worker.data.disk or list()) > 0
101
+ else:
102
+ assert len(dask_worker.data.disk or list()) == 0
103
+
104
+
85
105
  @pytest.mark.parametrize(
86
106
  "params",
87
107
  [
@@ -122,7 +142,7 @@ def delayed_worker_assert(
122
142
  },
123
143
  ],
124
144
  )
125
- @gen_test(timeout=120)
145
+ @gen_test(timeout=30)
126
146
  async def test_cupy_cluster_device_spill(params):
127
147
  cupy = pytest.importorskip("cupy")
128
148
  with dask.config.set(
@@ -141,9 +161,12 @@ async def test_cupy_cluster_device_spill(params):
141
161
  asynchronous=True,
142
162
  device_memory_limit=params["device_memory_limit"],
143
163
  memory_limit=params["memory_limit"],
164
+ worker_class=IncreasedCloseTimeoutNanny,
144
165
  ) as cluster:
145
166
  async with Client(cluster, asynchronous=True) as client:
146
167
 
168
+ await client.wait_for_workers(1)
169
+
147
170
  rs = da.random.RandomState(RandomState=cupy.random.RandomState)
148
171
  x = rs.random(int(50e6), chunks=2e6)
149
172
  await wait(x)
@@ -153,7 +176,10 @@ async def test_cupy_cluster_device_spill(params):
153
176
 
154
177
  # Allow up to 1024 bytes overhead per chunk serialized
155
178
  await client.run(
156
- lambda dask_worker: worker_assert(dask_worker, x.nbytes, 1024, 1024)
179
+ worker_assert,
180
+ x.nbytes,
181
+ 1024,
182
+ 1024,
157
183
  )
158
184
 
159
185
  y = client.compute(x.sum())
@@ -162,20 +188,19 @@ async def test_cupy_cluster_device_spill(params):
162
188
  assert (abs(res / x.size) - 0.5) < 1e-3
163
189
 
164
190
  await client.run(
165
- lambda dask_worker: worker_assert(dask_worker, x.nbytes, 1024, 1024)
191
+ worker_assert,
192
+ x.nbytes,
193
+ 1024,
194
+ 1024,
166
195
  )
167
- host_chunks = await client.run(
168
- lambda dask_worker: len(dask_worker.data.host)
196
+ await client.run(
197
+ assert_host_chunks,
198
+ params["spills_to_disk"],
169
199
  )
170
- disk_chunks = await client.run(
171
- lambda dask_worker: len(dask_worker.data.disk or list())
200
+ await client.run(
201
+ assert_disk_chunks,
202
+ params["spills_to_disk"],
172
203
  )
173
- for hc, dc in zip(host_chunks.values(), disk_chunks.values()):
174
- if params["spills_to_disk"]:
175
- assert dc > 0
176
- else:
177
- assert hc > 0
178
- assert dc == 0
179
204
 
180
205
 
181
206
  @pytest.mark.parametrize(
@@ -218,7 +243,7 @@ async def test_cupy_cluster_device_spill(params):
218
243
  },
219
244
  ],
220
245
  )
221
- @gen_test(timeout=120)
246
+ @gen_test(timeout=30)
222
247
  async def test_cudf_cluster_device_spill(params):
223
248
  cudf = pytest.importorskip("cudf")
224
249
 
@@ -240,9 +265,12 @@ async def test_cudf_cluster_device_spill(params):
240
265
  asynchronous=True,
241
266
  device_memory_limit=params["device_memory_limit"],
242
267
  memory_limit=params["memory_limit"],
268
+ worker_class=IncreasedCloseTimeoutNanny,
243
269
  ) as cluster:
244
270
  async with Client(cluster, asynchronous=True) as client:
245
271
 
272
+ await client.wait_for_workers(1)
273
+
246
274
  # There's a known issue with datetime64:
247
275
  # https://github.com/numpy/numpy/issues/4983#issuecomment-441332940
248
276
  # The same error above happens when spilling datetime64 to disk
@@ -264,26 +292,35 @@ async def test_cudf_cluster_device_spill(params):
264
292
  await wait(cdf2)
265
293
 
266
294
  del cdf
295
+ gc.collect()
267
296
 
268
- host_chunks = await client.run(
269
- lambda dask_worker: len(dask_worker.data.host)
297
+ await client.run(
298
+ assert_host_chunks,
299
+ params["spills_to_disk"],
270
300
  )
271
- disk_chunks = await client.run(
272
- lambda dask_worker: len(dask_worker.data.disk or list())
301
+ await client.run(
302
+ assert_disk_chunks,
303
+ params["spills_to_disk"],
273
304
  )
274
- for hc, dc in zip(host_chunks.values(), disk_chunks.values()):
275
- if params["spills_to_disk"]:
276
- assert dc > 0
277
- else:
278
- assert hc > 0
279
- assert dc == 0
280
305
 
281
306
  await client.run(
282
- lambda dask_worker: worker_assert(dask_worker, nbytes, 32, 2048)
307
+ worker_assert,
308
+ nbytes,
309
+ 32,
310
+ 2048,
283
311
  )
284
312
 
285
313
  del cdf2
286
314
 
287
- await client.run(
288
- lambda dask_worker: delayed_worker_assert(dask_worker, 0, 0, 0)
289
- )
315
+ while True:
316
+ try:
317
+ await client.run(
318
+ delayed_worker_assert,
319
+ 0,
320
+ 0,
321
+ 0,
322
+ )
323
+ except AssertionError:
324
+ gc.collect()
325
+ else:
326
+ break
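
The rewritten assertions above lean on two documented behaviours of `Client.run`: extra positional arguments are forwarded to the function on every worker, and a function that accepts a `dask_worker` keyword argument receives the Worker instance automatically. A minimal self-contained sketch:

from distributed import Client, LocalCluster


def count_stored_keys(minimum, dask_worker=None):
    # `minimum` comes from the client.run(...) call; `dask_worker` is injected.
    assert len(dask_worker.data) >= minimum
    return len(dask_worker.data)


if __name__ == "__main__":
    with LocalCluster(n_workers=2, processes=True) as cluster, Client(cluster) as client:
        futures = client.scatter(list(range(10)), broadcast=True)
        print(client.run(count_stored_keys, 1))  # {worker_address: key_count, ...}
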
dask_cuda/utils.py CHANGED
@@ -1,4 +1,3 @@
1
- import importlib
2
1
  import math
3
2
  import operator
4
3
  import os
@@ -18,7 +17,7 @@ import dask
18
17
  import distributed # noqa: required for dask.config.get("distributed.comm.ucx")
19
18
  from dask.config import canonical_name
20
19
  from dask.utils import format_bytes, parse_bytes
21
- from distributed import Worker, wait
20
+ from distributed import wait
22
21
  from distributed.comm import parse_address
23
22
 
24
23
  try:
@@ -32,122 +31,6 @@ except ImportError:
32
31
  yield
33
32
 
34
33
 
35
- class CPUAffinity:
36
- def __init__(self, cores):
37
- self.cores = cores
38
-
39
- def setup(self, worker=None):
40
- os.sched_setaffinity(0, self.cores)
41
-
42
-
43
- class RMMSetup:
44
- def __init__(
45
- self,
46
- initial_pool_size,
47
- maximum_pool_size,
48
- managed_memory,
49
- async_alloc,
50
- release_threshold,
51
- log_directory,
52
- track_allocations,
53
- ):
54
- if initial_pool_size is None and maximum_pool_size is not None:
55
- raise ValueError(
56
- "`rmm_maximum_pool_size` was specified without specifying "
57
- "`rmm_pool_size`.`rmm_pool_size` must be specified to use RMM pool."
58
- )
59
- if async_alloc is True:
60
- if managed_memory is True:
61
- raise ValueError(
62
- "`rmm_managed_memory` is incompatible with the `rmm_async`."
63
- )
64
- if async_alloc is False and release_threshold is not None:
65
- raise ValueError("`rmm_release_threshold` requires `rmm_async`.")
66
-
67
- self.initial_pool_size = initial_pool_size
68
- self.maximum_pool_size = maximum_pool_size
69
- self.managed_memory = managed_memory
70
- self.async_alloc = async_alloc
71
- self.release_threshold = release_threshold
72
- self.logging = log_directory is not None
73
- self.log_directory = log_directory
74
- self.rmm_track_allocations = track_allocations
75
-
76
- def setup(self, worker=None):
77
- if self.initial_pool_size is not None:
78
- self.initial_pool_size = parse_device_memory_limit(
79
- self.initial_pool_size, alignment_size=256
80
- )
81
-
82
- if self.async_alloc:
83
- import rmm
84
-
85
- if self.release_threshold is not None:
86
- self.release_threshold = parse_device_memory_limit(
87
- self.release_threshold, alignment_size=256
88
- )
89
-
90
- mr = rmm.mr.CudaAsyncMemoryResource(
91
- initial_pool_size=self.initial_pool_size,
92
- release_threshold=self.release_threshold,
93
- )
94
-
95
- if self.maximum_pool_size is not None:
96
- self.maximum_pool_size = parse_device_memory_limit(
97
- self.maximum_pool_size, alignment_size=256
98
- )
99
- mr = rmm.mr.LimitingResourceAdaptor(
100
- mr, allocation_limit=self.maximum_pool_size
101
- )
102
-
103
- rmm.mr.set_current_device_resource(mr)
104
- if self.logging:
105
- rmm.enable_logging(
106
- log_file_name=get_rmm_log_file_name(
107
- worker, self.logging, self.log_directory
108
- )
109
- )
110
- elif self.initial_pool_size is not None or self.managed_memory:
111
- import rmm
112
-
113
- pool_allocator = False if self.initial_pool_size is None else True
114
-
115
- if self.initial_pool_size is not None:
116
- if self.maximum_pool_size is not None:
117
- self.maximum_pool_size = parse_device_memory_limit(
118
- self.maximum_pool_size, alignment_size=256
119
- )
120
-
121
- rmm.reinitialize(
122
- pool_allocator=pool_allocator,
123
- managed_memory=self.managed_memory,
124
- initial_pool_size=self.initial_pool_size,
125
- maximum_pool_size=self.maximum_pool_size,
126
- logging=self.logging,
127
- log_file_name=get_rmm_log_file_name(
128
- worker, self.logging, self.log_directory
129
- ),
130
- )
131
- if self.rmm_track_allocations:
132
- import rmm
133
-
134
- mr = rmm.mr.get_current_device_resource()
135
- rmm.mr.set_current_device_resource(rmm.mr.TrackingResourceAdaptor(mr))
136
-
137
-
138
- class PreImport:
139
- def __init__(self, libraries):
140
- if libraries is None:
141
- libraries = []
142
- elif isinstance(libraries, str):
143
- libraries = libraries.split(",")
144
- self.libraries = libraries
145
-
146
- def setup(self, worker=None):
147
- for l in self.libraries:
148
- importlib.import_module(l)
149
-
150
-
151
34
  def unpack_bitmask(x, mask_bits=64):
152
35
  """Unpack a list of integers containing bitmasks.
153
36
 
@@ -669,27 +552,6 @@ def parse_device_memory_limit(device_memory_limit, device_index=0, alignment_siz
669
552
  return _align(int(device_memory_limit), alignment_size)
670
553
 
671
554
 
672
- class MockWorker(Worker):
673
- """Mock Worker class preventing NVML from getting used by SystemMonitor.
674
-
675
- By preventing the Worker from initializing NVML in the SystemMonitor, we can
676
- mock test multiple devices in `CUDA_VISIBLE_DEVICES` behavior with single-GPU
677
- machines.
678
- """
679
-
680
- def __init__(self, *args, **kwargs):
681
- distributed.diagnostics.nvml.device_get_count = MockWorker.device_get_count
682
- self._device_get_count = distributed.diagnostics.nvml.device_get_count
683
- super().__init__(*args, **kwargs)
684
-
685
- def __del__(self):
686
- distributed.diagnostics.nvml.device_get_count = self._device_get_count
687
-
688
- @staticmethod
689
- def device_get_count():
690
- return 0
691
-
692
-
693
555
  def get_gpu_uuid_from_index(device_index=0):
694
556
  """Get GPU UUID from CUDA device index.
695
557
 
dask_cuda/utils_test.py ADDED
@@ -0,0 +1,45 @@
1
+ from typing import Literal
2
+
3
+ import distributed
4
+ from distributed import Nanny, Worker
5
+
6
+
7
+ class MockWorker(Worker):
8
+ """Mock Worker class preventing NVML from getting used by SystemMonitor.
9
+
10
+ By preventing the Worker from initializing NVML in the SystemMonitor, we can
11
+ mock test multiple devices in `CUDA_VISIBLE_DEVICES` behavior with single-GPU
12
+ machines.
13
+ """
14
+
15
+ def __init__(self, *args, **kwargs):
16
+ distributed.diagnostics.nvml.device_get_count = MockWorker.device_get_count
17
+ self._device_get_count = distributed.diagnostics.nvml.device_get_count
18
+ super().__init__(*args, **kwargs)
19
+
20
+ def __del__(self):
21
+ distributed.diagnostics.nvml.device_get_count = self._device_get_count
22
+
23
+ @staticmethod
24
+ def device_get_count():
25
+ return 0
26
+
27
+
28
+ class IncreasedCloseTimeoutNanny(Nanny):
29
+ """Increase `Nanny`'s close timeout.
30
+
31
+ The internal close timeout mechanism of `Nanny` recomputes the time left to kill
32
+ the `Worker` process based on elapsed time of the close task, which may leave
33
+ very little time for the subprocess to shutdown cleanly, which may cause tests
34
+ to fail when the system is under higher load. This class increases the default
35
+ close timeout of 5.0 seconds that `Nanny` sets by default, which can be overridden
36
+ via Distributed's public API.
37
+
38
+ This class can be used with the `worker_class` argument of `LocalCluster` or
39
+ `LocalCUDACluster` to provide a much higher default of 30.0 seconds.
40
+ """
41
+
42
+ async def close( # type:ignore[override]
43
+ self, timeout: float = 30.0, reason: str = "nanny-close"
44
+ ) -> Literal["OK"]:
45
+ return await super().close(timeout=timeout, reason=reason)
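
A short usage sketch of the class defined above, mirroring how the updated tests hand it to a cluster (cluster sizing is illustrative):

from distributed import Client

from dask_cuda import LocalCUDACluster
from dask_cuda.utils_test import IncreasedCloseTimeoutNanny

if __name__ == "__main__":
    with LocalCUDACluster(
        n_workers=1,
        worker_class=IncreasedCloseTimeoutNanny,  # workers get up to 30s to shut down
    ) as cluster:
        with Client(cluster) as client:
            print(client.run(lambda: "ok"))
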
dask_cuda/worker_spec.py CHANGED
@@ -5,7 +5,8 @@ from distributed.system import MEMORY_LIMIT
5
5
 
6
6
  from .initialize import initialize
7
7
  from .local_cuda_cluster import cuda_visible_devices
8
- from .utils import CPUAffinity, get_cpu_affinity, get_gpu_count
8
+ from .plugins import CPUAffinity
9
+ from .utils import get_cpu_affinity, get_gpu_count
9
10
 
10
11
 
11
12
  def worker_spec(
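
For context, `worker_spec()` returns a spec dictionary meant for `distributed.SpecCluster`. The sketch below is hedged: it assumes the default arguments are usable on a machine with at least one visible GPU and that the returned mapping has one entry per device.

from distributed import Scheduler, SpecCluster

from dask_cuda.worker_spec import worker_spec

if __name__ == "__main__":
    spec = worker_spec()  # assumption: one worker spec per GPU in CUDA_VISIBLE_DEVICES
    cluster = SpecCluster(
        workers=spec,
        scheduler={"cls": Scheduler, "options": {"dashboard_address": ":8787"}},
    )
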
dask_cuda-23.10.0a231015.dist-info/METADATA → dask_cuda-23.12.0a24.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: dask-cuda
3
- Version: 23.10.0a231015
3
+ Version: 23.12.0a24
4
4
  Summary: Utilities for Dask and CUDA interactions
5
5
  Author: NVIDIA Corporation
6
6
  License: Apache-2.0
@@ -17,8 +17,8 @@ Classifier: Programming Language :: Python :: 3.10
17
17
  Requires-Python: >=3.9
18
18
  Description-Content-Type: text/markdown
19
19
  License-File: LICENSE
20
- Requires-Dist: dask ==2023.9.2
21
- Requires-Dist: distributed ==2023.9.2
20
+ Requires-Dist: dask >=2023.9.2
21
+ Requires-Dist: distributed >=2023.9.2
22
22
  Requires-Dist: pynvml <11.5,>=11.0.0
23
23
  Requires-Dist: numpy >=1.21
24
24
  Requires-Dist: numba >=0.57
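
The dask and distributed pins are relaxed here from exact `==2023.9.2` to `>=2023.9.2` lower bounds. A quick standard-library check of what the installed wheel actually declares:

from importlib.metadata import requires, version

print(version("dask-cuda"))   # e.g. "23.12.0a24"
print(requires("dask-cuda"))  # includes "dask >=2023.9.2" and "distributed >=2023.9.2"
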
dask_cuda-23.10.0a231015.dist-info/RECORD → dask_cuda-23.12.0a24.dist-info/RECORD CHANGED
@@ -1,19 +1,22 @@
1
- dask_cuda/__init__.py,sha256=2oMXKPmTjhzvy2sCRD8O88sV0cqVzAnlxLn9-3j3_os,1452
1
+ dask_cuda/VERSION,sha256=B3lEoXnOJhj1jdnSXB9iPziZLRAQGvXofHbxd6cA664,12
2
+ dask_cuda/__init__.py,sha256=XnMTUi-SvoGn7g1Dj6XW97HnQzGQv0G3EnvSjcZ7vU4,1455
3
+ dask_cuda/_version.py,sha256=FgBzL-H3uFWUDb0QvqJw3AytPr1PG8LbMnHxQEX8Vx4,738
2
4
  dask_cuda/cli.py,sha256=XNRH0bu-6jzRoyWJB5qSWuzePJSh3z_5Ng6rDCnz7lg,15970
3
- dask_cuda/compat.py,sha256=BLXv9IHUtD3h6-T_8MX-uGt-UDMG6EuGuyN-zw3XndU,4084
4
- dask_cuda/cuda_worker.py,sha256=hUJ3dCdeF1GxL0Oio-d-clQ5tLxQ9xjwU6Bse5JW54g,8571
5
- dask_cuda/device_host_file.py,sha256=D0rHOFz1TRfvaecoP30x3JRWe1TiHUaq45Dg-v0DfoY,10272
5
+ dask_cuda/cuda_worker.py,sha256=bIu-ESeIpJG_WaTYrv0z9z5juJ1qR5i_5Ng3CN1WK8s,8579
6
+ dask_cuda/device_host_file.py,sha256=yS31LGtt9VFAG78uBBlTDr7HGIng2XymV1OxXIuEMtM,10272
6
7
  dask_cuda/disk_io.py,sha256=urSLKiPvJvYmKCzDPOUDCYuLI3r1RUiyVh3UZGRoF_Y,6626
7
8
  dask_cuda/get_device_memory_objects.py,sha256=zMSqWzm5rflRInbNMz7U2Ewv5nMcE-H8stMJeWHVWyc,3890
8
9
  dask_cuda/initialize.py,sha256=mzPgKhs8oLgUWpqd4ckvLNKvhLoHjt96RrBPeVneenI,5231
9
10
  dask_cuda/is_device_object.py,sha256=CnajvbQiX0FzFzwft0MqK1OPomx3ZGDnDxT56wNjixw,1046
10
11
  dask_cuda/is_spillable_object.py,sha256=CddGmg0tuSpXh2m_TJSY6GRpnl1WRHt1CRcdWgHPzWA,1457
11
- dask_cuda/local_cuda_cluster.py,sha256=hjjgqFkGyuEqYMIYbxBV4xW2b7M6UPw9TnYM1Tf5r_4,17377
12
+ dask_cuda/local_cuda_cluster.py,sha256=w2HXMZtEukwklkB3J6l6DqZstNA5uvGEdFkdzpyUJ6k,17810
13
+ dask_cuda/plugins.py,sha256=cnHsdrXx7PBPmrzHX6YEkCH5byCsUk8LE2FeTeu8ZLU,4259
12
14
  dask_cuda/proxify_device_objects.py,sha256=99CD7LOE79YiQGJ12sYl_XImVhJXpFR4vG5utdkjTQo,8108
13
15
  dask_cuda/proxify_host_file.py,sha256=Wf5CFCC1JN5zmfvND3ls0M5FL01Y8VhHrk0xV3UQ9kk,30850
14
16
  dask_cuda/proxy_object.py,sha256=bZq92kjgFB-ad_luSAFT_RItV3nssmiEk4OOSp34laU,29812
15
- dask_cuda/utils.py,sha256=IGlr6SZAhULIo4WJrhJxyHAy7l6mp9vN_U1QZjqUYJY,29815
16
- dask_cuda/worker_spec.py,sha256=EQffH_fuBBaghmO8o9kxJ7EAQiB4gaW-uPRYesPknSs,4356
17
+ dask_cuda/utils.py,sha256=wNRItbIXrOpH77AUUrZNrGqgIiGNzpClXYl0QmQqfxs,25002
18
+ dask_cuda/utils_test.py,sha256=WNMR0gic2tuP3pgygcR9g52NfyX8iGMOan6juXhpkCE,1694
19
+ dask_cuda/worker_spec.py,sha256=7-Uq_e5q2SkTlsmctMcYLCa9_3RiiVHZLIN7ctfaFmE,4376
17
20
  dask_cuda/benchmarks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
21
  dask_cuda/benchmarks/common.py,sha256=sEIFnRZS6wbyKCQyB4fDclYLc2YqC0PolurR5qzuRxw,6393
19
22
  dask_cuda/benchmarks/local_cudf_groupby.py,sha256=2iHk-a-GvLmAgajwQJNrqmZ-WJeiyMFEyflcxh7SPO8,8894
@@ -27,24 +30,24 @@ dask_cuda/explicit_comms/comms.py,sha256=Su6PuNo68IyS-AwoqU4S9TmqWsLvUdNa0jot2hx
27
30
  dask_cuda/explicit_comms/dataframe/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
28
31
  dask_cuda/explicit_comms/dataframe/shuffle.py,sha256=2f2wlPyqXpryIHgMpsZzs3pDE7eyslYam-jQh3ujszQ,20124
29
32
  dask_cuda/tests/test_cudf_builtin_spilling.py,sha256=u3kW91YRLdHFycvpGfSQKrEucu5khMJ1k4sjmddO490,4910
30
- dask_cuda/tests/test_dask_cuda_worker.py,sha256=VgybyylO7eaSk9yVBj1snp3vM7ZTG-VPEcE8agTmaWI,17714
33
+ dask_cuda/tests/test_dask_cuda_worker.py,sha256=gViHaMCSfB6ip125OEi9D0nfKC-qBXRoHz6BRodEdb4,17729
31
34
  dask_cuda/tests/test_device_host_file.py,sha256=79ssUISo1YhsW_7HdwqPfsH2LRzS2bi5BjPym1Sdgqw,5882
32
35
  dask_cuda/tests/test_dgx.py,sha256=bKX-GvkYjWlmcEIK15aGErxmc0qPqIWOG1CeDFGoXFU,6381
33
- dask_cuda/tests/test_explicit_comms.py,sha256=WJDQQkqaYT9tRiz0zgC9_udzRG3DuhPlbH7X-wshC7w,11925
36
+ dask_cuda/tests/test_explicit_comms.py,sha256=3Q3o9BX4ksCgz11o38o5QhKg3Rv-EtTsGnVG83wwyyo,12283
34
37
  dask_cuda/tests/test_from_array.py,sha256=i2Vha4mchB0BopTlEdXV7CxY7qyTzFYdgYQTmukZX38,493
35
38
  dask_cuda/tests/test_gds.py,sha256=6jf0HPTHAIG8Mp_FC4Ai4zpn-U1K7yk0fSXg8He8-r8,1513
36
- dask_cuda/tests/test_initialize.py,sha256=EV3FTqBRX_kxHJ0ZEij34JpLyOJvGIYB_hQc-0afoG8,5235
37
- dask_cuda/tests/test_local_cuda_cluster.py,sha256=5-55CSMDJqBXqQzFQibmbWwvVOFC5iq7F1KtvtUx0kE,17417
38
- dask_cuda/tests/test_proxify_host_file.py,sha256=vnmUuU9w9hO4Et-qwnvY5VMkoohRt62cKhyP-wi7zKM,18492
39
- dask_cuda/tests/test_proxy.py,sha256=eJuXU0KRQC36R8g0WN9gyIeZ3tbKFlqMxybEzmaT1LA,23371
40
- dask_cuda/tests/test_spill.py,sha256=RfgIDWUkTbe7XqdDVJNnRuB_2U-IUvV_rwtZhY8OofE,9741
39
+ dask_cuda/tests/test_initialize.py,sha256=-Vo8SVBrVEKB0V1C6ia8khvbHJt4BC0xEjMNLhNbFxI,5491
40
+ dask_cuda/tests/test_local_cuda_cluster.py,sha256=1zlbRLn8ukopl5u8wBEfyyEhWpUblHYnwcPiPJO5bAU,17603
41
+ dask_cuda/tests/test_proxify_host_file.py,sha256=cp-U1uNPhesQaHbftKV8ir_dt5fbs0ZXSIsL39oI0fE,18630
42
+ dask_cuda/tests/test_proxy.py,sha256=Nu9vLx-dALINcF_wsxuFYUryRE0Jq43w7bAYAchK8RY,23480
43
+ dask_cuda/tests/test_spill.py,sha256=xN9PbVERBYMuZxvscSO0mAM22loq9WT3ltZVBFxlmM4,10239
41
44
  dask_cuda/tests/test_utils.py,sha256=wgYPvu7Sk61C64pah9ZbK8cnBXK5RyUCpu3G2ny6OZQ,8832
42
45
  dask_cuda/tests/test_worker_spec.py,sha256=Bvu85vkqm6ZDAYPXKMJlI2pm9Uc5tiYKNtO4goXSw-I,2399
43
46
  examples/ucx/client_initialize.py,sha256=YN3AXHF8btcMd6NicKKhKR9SXouAsK1foJhFspbOn70,1262
44
47
  examples/ucx/local_cuda_cluster.py,sha256=7xVY3EhwhkY2L4VZin_BiMCbrjhirDNChoC86KiETNc,1983
45
- dask_cuda-23.10.0a231015.dist-info/LICENSE,sha256=MjI3I-EgxfEvZlgjk82rgiFsZqSDXHFETd2QJ89UwDA,11348
46
- dask_cuda-23.10.0a231015.dist-info/METADATA,sha256=CxnTtTdisHQZtHiF3hS16hC-X_fhRJ21oIVEga974JM,2285
47
- dask_cuda-23.10.0a231015.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
48
- dask_cuda-23.10.0a231015.dist-info/entry_points.txt,sha256=UcRaKVEpywtxc6pF1VnfMB0UK4sJg7a8_NdZF67laPM,136
49
- dask_cuda-23.10.0a231015.dist-info/top_level.txt,sha256=3kKxJxeM108fuYc_lwwlklP7YBU9IEmdmRAouzi397o,33
50
- dask_cuda-23.10.0a231015.dist-info/RECORD,,
48
+ dask_cuda-23.12.0a24.dist-info/LICENSE,sha256=MjI3I-EgxfEvZlgjk82rgiFsZqSDXHFETd2QJ89UwDA,11348
49
+ dask_cuda-23.12.0a24.dist-info/METADATA,sha256=1MrxlpZ1ah-mzK-LqOVqkoJD8M5pRC_WH-j7nwdOya8,2281
50
+ dask_cuda-23.12.0a24.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
51
+ dask_cuda-23.12.0a24.dist-info/entry_points.txt,sha256=UcRaKVEpywtxc6pF1VnfMB0UK4sJg7a8_NdZF67laPM,136
52
+ dask_cuda-23.12.0a24.dist-info/top_level.txt,sha256=3kKxJxeM108fuYc_lwwlklP7YBU9IEmdmRAouzi397o,33
53
+ dask_cuda-23.12.0a24.dist-info/RECORD,,
dask_cuda/compat.py DELETED
@@ -1,118 +0,0 @@
1
- import pickle
2
-
3
- import msgpack
4
- from packaging.version import Version
5
-
6
- import dask
7
- import distributed
8
- import distributed.comm.utils
9
- import distributed.protocol
10
- from distributed.comm.utils import OFFLOAD_THRESHOLD, nbytes, offload
11
- from distributed.protocol.core import (
12
- Serialized,
13
- decompress,
14
- logger,
15
- merge_and_deserialize,
16
- msgpack_decode_default,
17
- msgpack_opts,
18
- )
19
-
20
- if Version(distributed.__version__) >= Version("2023.8.1"):
21
- # Monkey-patch protocol.core.loads (and its users)
22
- async def from_frames(
23
- frames, deserialize=True, deserializers=None, allow_offload=True
24
- ):
25
- """
26
- Unserialize a list of Distributed protocol frames.
27
- """
28
- size = False
29
-
30
- def _from_frames():
31
- try:
32
- # Patched code
33
- return loads(
34
- frames, deserialize=deserialize, deserializers=deserializers
35
- )
36
- # end patched code
37
- except EOFError:
38
- if size > 1000:
39
- datastr = "[too large to display]"
40
- else:
41
- datastr = frames
42
- # Aid diagnosing
43
- logger.error("truncated data stream (%d bytes): %s", size, datastr)
44
- raise
45
-
46
- if allow_offload and deserialize and OFFLOAD_THRESHOLD:
47
- size = sum(map(nbytes, frames))
48
- if (
49
- allow_offload
50
- and deserialize
51
- and OFFLOAD_THRESHOLD
52
- and size > OFFLOAD_THRESHOLD
53
- ):
54
- res = await offload(_from_frames)
55
- else:
56
- res = _from_frames()
57
-
58
- return res
59
-
60
- def loads(frames, deserialize=True, deserializers=None):
61
- """Transform bytestream back into Python value"""
62
-
63
- allow_pickle = dask.config.get("distributed.scheduler.pickle")
64
-
65
- try:
66
-
67
- def _decode_default(obj):
68
- offset = obj.get("__Serialized__", 0)
69
- if offset > 0:
70
- sub_header = msgpack.loads(
71
- frames[offset],
72
- object_hook=msgpack_decode_default,
73
- use_list=False,
74
- **msgpack_opts,
75
- )
76
- offset += 1
77
- sub_frames = frames[offset : offset + sub_header["num-sub-frames"]]
78
- if deserialize:
79
- if "compression" in sub_header:
80
- sub_frames = decompress(sub_header, sub_frames)
81
- return merge_and_deserialize(
82
- sub_header, sub_frames, deserializers=deserializers
83
- )
84
- else:
85
- return Serialized(sub_header, sub_frames)
86
-
87
- offset = obj.get("__Pickled__", 0)
88
- if offset > 0:
89
- sub_header = msgpack.loads(frames[offset])
90
- offset += 1
91
- sub_frames = frames[offset : offset + sub_header["num-sub-frames"]]
92
- # Patched code
93
- if "compression" in sub_header:
94
- sub_frames = decompress(sub_header, sub_frames)
95
- # end patched code
96
- if allow_pickle:
97
- return pickle.loads(
98
- sub_header["pickled-obj"], buffers=sub_frames
99
- )
100
- else:
101
- raise ValueError(
102
- "Unpickle on the Scheduler isn't allowed, "
103
- "set `distributed.scheduler.pickle=true`"
104
- )
105
-
106
- return msgpack_decode_default(obj)
107
-
108
- return msgpack.loads(
109
- frames[0], object_hook=_decode_default, use_list=False, **msgpack_opts
110
- )
111
-
112
- except Exception:
113
- logger.critical("Failed to deserialize", exc_info=True)
114
- raise
115
-
116
- distributed.protocol.loads = loads
117
- distributed.protocol.core.loads = loads
118
- distributed.comm.utils.from_frames = from_frames