dask-cuda 25.2.0__py3-none-any.whl → 25.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dask_cuda/GIT_COMMIT +1 -0
- dask_cuda/VERSION +1 -1
- dask_cuda/__init__.py +38 -27
- dask_cuda/_compat.py +18 -0
- dask_cuda/explicit_comms/comms.py +34 -7
- dask_cuda/explicit_comms/dataframe/shuffle.py +127 -35
- dask_cuda/get_device_memory_objects.py +15 -0
- dask_cuda/is_device_object.py +4 -1
- dask_cuda/is_spillable_object.py +4 -1
- dask_cuda/proxify_device_objects.py +4 -1
- dask_cuda/proxy_object.py +55 -35
- dask_cuda/tests/test_dask_cuda_worker.py +5 -2
- dask_cuda/tests/test_explicit_comms.py +136 -6
- dask_cuda/tests/test_initialize.py +36 -0
- dask_cuda/tests/test_local_cuda_cluster.py +5 -2
- dask_cuda/tests/test_proxify_host_file.py +15 -2
- dask_cuda/tests/test_spill.py +100 -27
- dask_cuda/utils.py +61 -33
- {dask_cuda-25.2.0.dist-info → dask_cuda-25.6.0.dist-info}/METADATA +7 -5
- {dask_cuda-25.2.0.dist-info → dask_cuda-25.6.0.dist-info}/RECORD +24 -22
- {dask_cuda-25.2.0.dist-info → dask_cuda-25.6.0.dist-info}/WHEEL +1 -1
- {dask_cuda-25.2.0.dist-info → dask_cuda-25.6.0.dist-info}/top_level.txt +0 -1
- {dask_cuda-25.2.0.dist-info → dask_cuda-25.6.0.dist-info}/entry_points.txt +0 -0
- {dask_cuda-25.2.0.dist-info → dask_cuda-25.6.0.dist-info/licenses}/LICENSE +0 -0
dask_cuda/GIT_COMMIT
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
1f834655ecc6286b9e3082f037594f70dcb74062
|
dask_cuda/VERSION
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
25.02.00
|
|
1
|
+
25.06.00
|
dask_cuda/__init__.py
CHANGED
|
@@ -5,8 +5,6 @@ if sys.platform != "linux":
|
|
|
5
5
|
|
|
6
6
|
import dask
|
|
7
7
|
import dask.utils
|
|
8
|
-
import dask.dataframe.shuffle
|
|
9
|
-
from .explicit_comms.dataframe.shuffle import patch_shuffle_expression
|
|
10
8
|
from distributed.protocol.cuda import cuda_deserialize, cuda_serialize
|
|
11
9
|
from distributed.protocol.serialize import dask_deserialize, dask_serialize
|
|
12
10
|
|
|
@@ -14,30 +12,43 @@ from ._version import __git_commit__, __version__
|
|
|
14
12
|
from .cuda_worker import CUDAWorker
|
|
15
13
|
|
|
16
14
|
from .local_cuda_cluster import LocalCUDACluster
|
|
17
|
-
from .proxify_device_objects import proxify_decorator, unproxify_decorator
|
|
18
15
|
|
|
19
16
|
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
#
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
if
|
|
43
|
-
|
|
17
|
+
try:
|
|
18
|
+
import dask.dataframe as dask_dataframe
|
|
19
|
+
except ImportError:
|
|
20
|
+
# Dask DataFrame (optional) isn't installed
|
|
21
|
+
dask_dataframe = None
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
if dask_dataframe is not None:
|
|
25
|
+
from .explicit_comms.dataframe.shuffle import patch_shuffle_expression
|
|
26
|
+
from .proxify_device_objects import proxify_decorator, unproxify_decorator
|
|
27
|
+
|
|
28
|
+
# Monkey patching Dask to make use of explicit-comms when `DASK_EXPLICIT_COMMS=True`
|
|
29
|
+
patch_shuffle_expression()
|
|
30
|
+
# Monkey patching Dask to make use of proxify and unproxify in compatibility mode
|
|
31
|
+
dask_dataframe.shuffle.shuffle_group = proxify_decorator(
|
|
32
|
+
dask.dataframe.shuffle.shuffle_group
|
|
33
|
+
)
|
|
34
|
+
dask_dataframe.core._concat = unproxify_decorator(dask.dataframe.core._concat)
|
|
35
|
+
|
|
36
|
+
def _register_cudf_spill_aware():
|
|
37
|
+
import cudf
|
|
38
|
+
|
|
39
|
+
# Only enable Dask/cuDF spilling if cuDF spilling is disabled, see
|
|
40
|
+
# https://github.com/rapidsai/dask-cuda/issues/1363
|
|
41
|
+
if not cudf.get_option("spill"):
|
|
42
|
+
# This reproduces the implementation of `_register_cudf`, see
|
|
43
|
+
# https://github.com/dask/distributed/blob/40fcd65e991382a956c3b879e438be1b100dff97/distributed/protocol/__init__.py#L106-L115
|
|
44
|
+
from cudf.comm import serialize
|
|
45
|
+
|
|
46
|
+
for registry in [
|
|
47
|
+
cuda_serialize,
|
|
48
|
+
cuda_deserialize,
|
|
49
|
+
dask_serialize,
|
|
50
|
+
dask_deserialize,
|
|
51
|
+
]:
|
|
52
|
+
for lib in ["cudf", "dask_cudf"]:
|
|
53
|
+
if lib in registry._lazy:
|
|
54
|
+
registry._lazy[lib] = _register_cudf_spill_aware
|
dask_cuda/_compat.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2025 NVIDIA CORPORATION.
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
import importlib.metadata
|
|
5
|
+
|
|
6
|
+
import packaging.version
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@functools.lru_cache(maxsize=None)
|
|
10
|
+
def get_dask_version() -> packaging.version.Version:
|
|
11
|
+
return packaging.version.parse(importlib.metadata.version("dask"))
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@functools.lru_cache(maxsize=None)
|
|
15
|
+
def DASK_2025_4_0():
|
|
16
|
+
# dask 2025.4.0 isn't currently released, so we're relying
|
|
17
|
+
# on strictly greater than here.
|
|
18
|
+
return get_dask_version() > packaging.version.parse("2025.3.0")
|
|
dask_cuda/explicit_comms/comms.py
CHANGED
|
@@ -1,15 +1,21 @@
|
|
|
1
|
+
# Copyright (c) 2021-2025 NVIDIA CORPORATION.
|
|
1
2
|
import asyncio
|
|
2
3
|
import concurrent.futures
|
|
3
4
|
import contextlib
|
|
4
5
|
import time
|
|
5
6
|
import uuid
|
|
7
|
+
import weakref
|
|
6
8
|
from typing import Any, Dict, Hashable, Iterable, List, Optional
|
|
7
9
|
|
|
8
10
|
import distributed.comm
|
|
11
|
+
from dask.tokenize import tokenize
|
|
9
12
|
from distributed import Client, Worker, default_client, get_worker
|
|
10
13
|
from distributed.comm.addressing import parse_address, parse_host_port, unparse_address
|
|
11
14
|
|
|
12
|
-
|
|
15
|
+
# Mapping tokenize(client ID, [worker addresses]) to CommsContext
|
|
16
|
+
_comms_cache: weakref.WeakValueDictionary[
|
|
17
|
+
str, "CommsContext"
|
|
18
|
+
] = weakref.WeakValueDictionary()
|
|
13
19
|
|
|
14
20
|
|
|
15
21
|
def get_multi_lock_or_null_context(multi_lock_context, *args, **kwargs):
|
|
@@ -38,9 +44,10 @@ def get_multi_lock_or_null_context(multi_lock_context, *args, **kwargs):
|
|
|
38
44
|
|
|
39
45
|
|
|
40
46
|
def default_comms(client: Optional[Client] = None) -> "CommsContext":
|
|
41
|
-
"""Return the default comms object
|
|
47
|
+
"""Return the default comms object for ``client``.
|
|
42
48
|
|
|
43
|
-
Creates a new default comms object if
|
|
49
|
+
Creates a new default comms object if one does not already exist
|
|
50
|
+
for ``client``.
|
|
44
51
|
|
|
45
52
|
Parameters
|
|
46
53
|
----------
|
|
@@ -52,11 +59,31 @@ def default_comms(client: Optional[Client] = None) -> "CommsContext":
|
|
|
52
59
|
-------
|
|
53
60
|
comms: CommsContext
|
|
54
61
|
The default comms object
|
|
62
|
+
|
|
63
|
+
Notes
|
|
64
|
+
-----
|
|
65
|
+
There are some subtle points around explicit-comms and the lifecycle
|
|
66
|
+
of a Dask Cluster.
|
|
67
|
+
|
|
68
|
+
A :class:`CommsContext` establishes explicit communication channels
|
|
69
|
+
between the workers *at the time it's created*. If workers are added
|
|
70
|
+
or removed, they will not be included in the communication channels
|
|
71
|
+
with the other workers.
|
|
72
|
+
|
|
73
|
+
If you need to refresh the explicit communications channels, then
|
|
74
|
+
create a new :class:`CommsContext` object or call ``default_comms``
|
|
75
|
+
again after workers have been added to or removed from the cluster.
|
|
55
76
|
"""
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
77
|
+
# Comms are unique to a {client, [workers]} pair, so we key our
|
|
78
|
+
# cache by the token of that.
|
|
79
|
+
client = client or default_client()
|
|
80
|
+
token = tokenize(client.id, list(client.scheduler_info()["workers"].keys()))
|
|
81
|
+
maybe_comms = _comms_cache.get(token)
|
|
82
|
+
if maybe_comms is None:
|
|
83
|
+
maybe_comms = CommsContext(client=client)
|
|
84
|
+
_comms_cache[token] = maybe_comms
|
|
85
|
+
|
|
86
|
+
return maybe_comms
|
|
60
87
|
|
|
61
88
|
|
|
62
89
|
def worker_state(sessionId: Optional[int] = None) -> dict:
|
|
dask_cuda/explicit_comms/dataframe/shuffle.py
CHANGED
|
@@ -1,6 +1,9 @@
|
|
|
1
|
+
# Copyright (c) 2021-2025 NVIDIA CORPORATION.
|
|
2
|
+
|
|
1
3
|
from __future__ import annotations
|
|
2
4
|
|
|
3
5
|
import asyncio
|
|
6
|
+
import functools
|
|
4
7
|
from collections import defaultdict
|
|
5
8
|
from math import ceil
|
|
6
9
|
from operator import getitem
|
|
@@ -23,6 +26,7 @@ from distributed import wait
|
|
|
23
26
|
from distributed.protocol import nested_deserialize, to_serialize
|
|
24
27
|
from distributed.worker import Worker
|
|
25
28
|
|
|
29
|
+
from ..._compat import DASK_2025_4_0
|
|
26
30
|
from .. import comms
|
|
27
31
|
|
|
28
32
|
T = TypeVar("T")
|
|
@@ -582,6 +586,128 @@ def _use_explicit_comms() -> bool:
|
|
|
582
586
|
return False
|
|
583
587
|
|
|
584
588
|
|
|
589
|
+
_base_lower = dask_expr._shuffle.Shuffle._lower
|
|
590
|
+
_base_compute = dask.base.compute
|
|
591
|
+
|
|
592
|
+
|
|
593
|
+
def _contains_shuffle_expr(*args) -> bool:
|
|
594
|
+
"""
|
|
595
|
+
Check whether any of the arguments is a Shuffle expression.
|
|
596
|
+
|
|
597
|
+
This is called by `compute`, which is given a sequence of Dask Collections
|
|
598
|
+
to process. For each of those, we'll check whether the expresion contains a
|
|
599
|
+
Shuffle operation.
|
|
600
|
+
"""
|
|
601
|
+
for collection in args:
|
|
602
|
+
if isinstance(collection, dask.dataframe.DataFrame):
|
|
603
|
+
shuffle_ops = list(
|
|
604
|
+
collection.expr.find_operations(
|
|
605
|
+
(
|
|
606
|
+
dask_expr._shuffle.RearrangeByColumn,
|
|
607
|
+
dask_expr.SetIndex,
|
|
608
|
+
dask_expr._shuffle.Shuffle,
|
|
609
|
+
)
|
|
610
|
+
)
|
|
611
|
+
)
|
|
612
|
+
if len(shuffle_ops) > 0:
|
|
613
|
+
return True
|
|
614
|
+
return False
|
|
615
|
+
|
|
616
|
+
|
|
617
|
+
@functools.wraps(_base_compute)
|
|
618
|
+
def _patched_compute(
|
|
619
|
+
*args,
|
|
620
|
+
traverse=True,
|
|
621
|
+
optimize_graph=True,
|
|
622
|
+
scheduler=None,
|
|
623
|
+
get=None,
|
|
624
|
+
**kwargs,
|
|
625
|
+
):
|
|
626
|
+
# A patched version of dask.compute that explicitly materializes the task
|
|
627
|
+
# graph when we're using explicit-comms and the expression contains a
|
|
628
|
+
# Shuffle operation.
|
|
629
|
+
# https://github.com/rapidsai/dask-upstream-testing/issues/37#issuecomment-2779798670
|
|
630
|
+
# contains more details on the issue.
|
|
631
|
+
if DASK_2025_4_0() and _use_explicit_comms() and _contains_shuffle_expr(*args):
|
|
632
|
+
from dask.base import (
|
|
633
|
+
collections_to_expr,
|
|
634
|
+
flatten,
|
|
635
|
+
get_scheduler,
|
|
636
|
+
shorten_traceback,
|
|
637
|
+
unpack_collections,
|
|
638
|
+
)
|
|
639
|
+
|
|
640
|
+
collections, repack = unpack_collections(*args, traverse=traverse)
|
|
641
|
+
if not collections:
|
|
642
|
+
return args
|
|
643
|
+
|
|
644
|
+
schedule = get_scheduler(
|
|
645
|
+
scheduler=scheduler,
|
|
646
|
+
collections=collections,
|
|
647
|
+
get=get,
|
|
648
|
+
)
|
|
649
|
+
from dask._expr import FinalizeCompute
|
|
650
|
+
|
|
651
|
+
expr = collections_to_expr(collections, optimize_graph)
|
|
652
|
+
expr = FinalizeCompute(expr)
|
|
653
|
+
|
|
654
|
+
with shorten_traceback():
|
|
655
|
+
expr = expr.optimize()
|
|
656
|
+
keys = list(flatten(expr.__dask_keys__()))
|
|
657
|
+
|
|
658
|
+
# materialize the HLG here
|
|
659
|
+
expr = dict(expr.__dask_graph__())
|
|
660
|
+
|
|
661
|
+
results = schedule(expr, keys, **kwargs)
|
|
662
|
+
return repack(results)
|
|
663
|
+
|
|
664
|
+
else:
|
|
665
|
+
return _base_compute(
|
|
666
|
+
*args,
|
|
667
|
+
traverse=traverse,
|
|
668
|
+
optimize_graph=optimize_graph,
|
|
669
|
+
scheduler=scheduler,
|
|
670
|
+
get=get,
|
|
671
|
+
**kwargs,
|
|
672
|
+
)
|
|
673
|
+
|
|
674
|
+
|
|
675
|
+
class ECShuffle(dask_expr._shuffle.TaskShuffle):
|
|
676
|
+
"""Explicit-Comms Shuffle Expression."""
|
|
677
|
+
|
|
678
|
+
def _layer(self):
|
|
679
|
+
# Execute an explicit-comms shuffle
|
|
680
|
+
if not hasattr(self, "_ec_shuffled"):
|
|
681
|
+
on = self.partitioning_index
|
|
682
|
+
df = dask_expr.new_collection(self.frame)
|
|
683
|
+
ec_shuffled = shuffle(
|
|
684
|
+
df,
|
|
685
|
+
[on] if isinstance(on, str) else on,
|
|
686
|
+
self.npartitions_out,
|
|
687
|
+
self.ignore_index,
|
|
688
|
+
)
|
|
689
|
+
object.__setattr__(self, "_ec_shuffled", ec_shuffled)
|
|
690
|
+
graph = self._ec_shuffled.dask.copy()
|
|
691
|
+
shuffled_name = self._ec_shuffled._name
|
|
692
|
+
for i in range(self.npartitions_out):
|
|
693
|
+
graph[(self._name, i)] = graph[(shuffled_name, i)]
|
|
694
|
+
return graph
|
|
695
|
+
|
|
696
|
+
|
|
697
|
+
def _patched_lower(self):
|
|
698
|
+
if self.method in (None, "tasks") and _use_explicit_comms():
|
|
699
|
+
return ECShuffle(
|
|
700
|
+
self.frame,
|
|
701
|
+
self.partitioning_index,
|
|
702
|
+
self.npartitions_out,
|
|
703
|
+
self.ignore_index,
|
|
704
|
+
self.options,
|
|
705
|
+
self.original_partitioning_index,
|
|
706
|
+
)
|
|
707
|
+
else:
|
|
708
|
+
return _base_lower(self)
|
|
709
|
+
|
|
710
|
+
|
|
585
711
|
def patch_shuffle_expression() -> None:
|
|
586
712
|
"""Patch Dasks Shuffle expression.
|
|
587
713
|
|
|
@@ -590,40 +716,6 @@ def patch_shuffle_expression() -> None:
|
|
|
590
716
|
an `ECShuffle` expression when the 'explicit-comms'
|
|
591
717
|
config is set to `True`.
|
|
592
718
|
"""
|
|
593
|
-
|
|
594
|
-
class ECShuffle(dask_expr._shuffle.TaskShuffle):
|
|
595
|
-
"""Explicit-Comms Shuffle Expression."""
|
|
596
|
-
|
|
597
|
-
def _layer(self):
|
|
598
|
-
# Execute an explicit-comms shuffle
|
|
599
|
-
if not hasattr(self, "_ec_shuffled"):
|
|
600
|
-
on = self.partitioning_index
|
|
601
|
-
df = dask_expr.new_collection(self.frame)
|
|
602
|
-
self._ec_shuffled = shuffle(
|
|
603
|
-
df,
|
|
604
|
-
[on] if isinstance(on, str) else on,
|
|
605
|
-
self.npartitions_out,
|
|
606
|
-
self.ignore_index,
|
|
607
|
-
)
|
|
608
|
-
graph = self._ec_shuffled.dask.copy()
|
|
609
|
-
shuffled_name = self._ec_shuffled._name
|
|
610
|
-
for i in range(self.npartitions_out):
|
|
611
|
-
graph[(self._name, i)] = graph[(shuffled_name, i)]
|
|
612
|
-
return graph
|
|
613
|
-
|
|
614
|
-
_base_lower = dask_expr._shuffle.Shuffle._lower
|
|
615
|
-
|
|
616
|
-
def _patched_lower(self):
|
|
617
|
-
if self.method in (None, "tasks") and _use_explicit_comms():
|
|
618
|
-
return ECShuffle(
|
|
619
|
-
self.frame,
|
|
620
|
-
self.partitioning_index,
|
|
621
|
-
self.npartitions_out,
|
|
622
|
-
self.ignore_index,
|
|
623
|
-
self.options,
|
|
624
|
-
self.original_partitioning_index,
|
|
625
|
-
)
|
|
626
|
-
else:
|
|
627
|
-
return _base_lower(self)
|
|
719
|
+
dask.base.compute = _patched_compute
|
|
628
720
|
|
|
629
721
|
dask_expr._shuffle.Shuffle._lower = _patched_lower
|
|
dask_cuda/get_device_memory_objects.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
# Copyright (c) 2025, NVIDIA CORPORATION.
|
|
2
|
+
|
|
1
3
|
from typing import Set
|
|
2
4
|
|
|
3
5
|
from dask.sizeof import sizeof
|
|
@@ -140,3 +142,16 @@ def register_cupy(): # NB: this overwrites dask.sizeof.register_cupy()
|
|
|
140
142
|
@sizeof.register(cupy.ndarray)
|
|
141
143
|
def sizeof_cupy_ndarray(x):
|
|
142
144
|
return int(x.nbytes)
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
@sizeof.register_lazy("pylibcudf")
|
|
148
|
+
def register_pylibcudf():
|
|
149
|
+
import pylibcudf
|
|
150
|
+
|
|
151
|
+
@sizeof.register(pylibcudf.column.OwnerWithCAI)
|
|
152
|
+
def sizeof_owner_with_cai(x):
|
|
153
|
+
# OwnerWithCAI implements __cuda_array_interface__ so this should always
|
|
154
|
+
# be zero-copy
|
|
155
|
+
col = pylibcudf.column.Column.from_cuda_array_interface(x)
|
|
156
|
+
# col.data() returns a gpumemoryview, which knows the size in bytes
|
|
157
|
+
return col.data().nbytes
|
dask_cuda/is_device_object.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
# Copyright (c) 2025 NVIDIA CORPORATION.
|
|
1
2
|
from __future__ import absolute_import, division, print_function
|
|
2
3
|
|
|
3
4
|
from dask.utils import Dispatch
|
|
@@ -35,6 +36,8 @@ def register_cudf():
|
|
|
35
36
|
def is_device_object_cudf_series(s):
|
|
36
37
|
return True
|
|
37
38
|
|
|
38
|
-
@is_device_object.register(cudf.
|
|
39
|
+
@is_device_object.register(cudf.Index)
|
|
40
|
+
@is_device_object.register(cudf.RangeIndex)
|
|
41
|
+
@is_device_object.register(cudf.MultiIndex)
|
|
39
42
|
def is_device_object_cudf_index(s):
|
|
40
43
|
return True
|
dask_cuda/is_spillable_object.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
# Copyright (c) 2025 NVIDIA CORPORATION.
|
|
1
2
|
from __future__ import absolute_import, division, print_function
|
|
2
3
|
|
|
3
4
|
from typing import Optional
|
|
@@ -34,7 +35,9 @@ def register_cudf():
|
|
|
34
35
|
def is_device_object_cudf_dataframe(df):
|
|
35
36
|
return cudf_spilling_status()
|
|
36
37
|
|
|
37
|
-
@is_spillable_object.register(cudf.
|
|
38
|
+
@is_spillable_object.register(cudf.Index)
|
|
39
|
+
@is_spillable_object.register(cudf.RangeIndex)
|
|
40
|
+
@is_spillable_object.register(cudf.MultiIndex)
|
|
38
41
|
def is_device_object_cudf_index(s):
|
|
39
42
|
return cudf_spilling_status()
|
|
40
43
|
|
|
dask_cuda/proxify_device_objects.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
# Copyright (c) 2025 NVIDIA CORPORATION.
|
|
1
2
|
import functools
|
|
2
3
|
import pydoc
|
|
3
4
|
from collections import defaultdict
|
|
@@ -242,7 +243,9 @@ def _register_cudf():
|
|
|
242
243
|
|
|
243
244
|
@dispatch.register(cudf.DataFrame)
|
|
244
245
|
@dispatch.register(cudf.Series)
|
|
245
|
-
@dispatch.register(cudf.
|
|
246
|
+
@dispatch.register(cudf.Index)
|
|
247
|
+
@dispatch.register(cudf.MultiIndex)
|
|
248
|
+
@dispatch.register(cudf.RangeIndex)
|
|
246
249
|
def proxify_device_object_cudf_dataframe(
|
|
247
250
|
obj, proxied_id_to_proxy, found_proxies, excl_proxies
|
|
248
251
|
):
|
dask_cuda/proxy_object.py
CHANGED
|
@@ -11,10 +11,6 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple, Type, Un
|
|
|
11
11
|
import pandas
|
|
12
12
|
|
|
13
13
|
import dask
|
|
14
|
-
import dask.array.core
|
|
15
|
-
import dask.dataframe.backends
|
|
16
|
-
import dask.dataframe.dispatch
|
|
17
|
-
import dask.dataframe.utils
|
|
18
14
|
import dask.utils
|
|
19
15
|
import distributed.protocol
|
|
20
16
|
import distributed.utils
|
|
@@ -30,6 +26,22 @@ if TYPE_CHECKING:
|
|
|
30
26
|
from .proxify_host_file import ProxyManager
|
|
31
27
|
|
|
32
28
|
|
|
29
|
+
try:
|
|
30
|
+
import dask.dataframe as dask_dataframe
|
|
31
|
+
import dask.dataframe.backends
|
|
32
|
+
import dask.dataframe.dispatch
|
|
33
|
+
import dask.dataframe.utils
|
|
34
|
+
except ImportError:
|
|
35
|
+
dask_dataframe = None
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
try:
|
|
39
|
+
import dask.array as dask_array
|
|
40
|
+
import dask.array.core
|
|
41
|
+
except ImportError:
|
|
42
|
+
dask_array = None
|
|
43
|
+
|
|
44
|
+
|
|
33
45
|
# List of attributes that should be copied to the proxy at creation, which makes
|
|
34
46
|
# them accessible without deserialization of the proxied object
|
|
35
47
|
_FIXED_ATTRS = ["name", "__len__"]
|
|
@@ -884,14 +896,6 @@ def obj_pxy_dask_deserialize(header, frames):
|
|
|
884
896
|
return subclass(pxy)
|
|
885
897
|
|
|
886
898
|
|
|
887
|
-
@dask.dataframe.dispatch.get_parallel_type.register(ProxyObject)
|
|
888
|
-
def get_parallel_type_proxy_object(obj: ProxyObject):
|
|
889
|
-
# Notice, `get_parallel_type()` needs a instance not a type object
|
|
890
|
-
return dask.dataframe.dispatch.get_parallel_type(
|
|
891
|
-
obj.__class__.__new__(obj.__class__)
|
|
892
|
-
)
|
|
893
|
-
|
|
894
|
-
|
|
895
899
|
def unproxify_input_wrapper(func):
|
|
896
900
|
"""Unproxify the input of `func`"""
|
|
897
901
|
|
|
@@ -904,26 +908,42 @@ def unproxify_input_wrapper(func):
|
|
|
904
908
|
return wrapper
|
|
905
909
|
|
|
906
910
|
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
)
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
911
|
+
if dask_array is not None:
|
|
912
|
+
|
|
913
|
+
# Register dispatch of ProxyObject on all known dispatch objects
|
|
914
|
+
for dispatch in (
|
|
915
|
+
dask.array.core.tensordot_lookup,
|
|
916
|
+
dask.array.core.einsum_lookup,
|
|
917
|
+
dask.array.core.concatenate_lookup,
|
|
918
|
+
):
|
|
919
|
+
dispatch.register(ProxyObject, unproxify_input_wrapper(dispatch))
|
|
920
|
+
|
|
921
|
+
|
|
922
|
+
if dask_dataframe is not None:
|
|
923
|
+
|
|
924
|
+
@dask.dataframe.dispatch.get_parallel_type.register(ProxyObject)
|
|
925
|
+
def get_parallel_type_proxy_object(obj: ProxyObject):
|
|
926
|
+
# Notice, `get_parallel_type()` needs a instance not a type object
|
|
927
|
+
return dask.dataframe.dispatch.get_parallel_type(
|
|
928
|
+
obj.__class__.__new__(obj.__class__)
|
|
929
|
+
)
|
|
930
|
+
|
|
931
|
+
# Register dispatch of ProxyObject on all known dispatch objects
|
|
932
|
+
for dispatch in (
|
|
933
|
+
dask.dataframe.dispatch.hash_object_dispatch,
|
|
934
|
+
dask.dataframe.dispatch.make_meta_dispatch,
|
|
935
|
+
dask.dataframe.utils.make_scalar,
|
|
936
|
+
dask.dataframe.dispatch.group_split_dispatch,
|
|
937
|
+
):
|
|
938
|
+
dispatch.register(ProxyObject, unproxify_input_wrapper(dispatch))
|
|
939
|
+
|
|
940
|
+
dask.dataframe.dispatch.concat_dispatch.register(
|
|
941
|
+
ProxyObject, unproxify_input_wrapper(dask.dataframe.dispatch.concat)
|
|
942
|
+
)
|
|
943
|
+
|
|
944
|
+
# We overwrite the Dask dispatch of Pandas objects in order to
|
|
945
|
+
# deserialize all ProxyObjects before concatenating
|
|
946
|
+
dask.dataframe.dispatch.concat_dispatch.register(
|
|
947
|
+
(pandas.DataFrame, pandas.Series, pandas.Index),
|
|
948
|
+
unproxify_input_wrapper(dask.dataframe.backends.concat_pandas),
|
|
949
|
+
)
|
|
dask_cuda/tests/test_dask_cuda_worker.py
CHANGED
|
@@ -1,3 +1,6 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
1
4
|
from __future__ import absolute_import, division, print_function
|
|
2
5
|
|
|
3
6
|
import os
|
|
@@ -16,7 +19,7 @@ from dask_cuda.utils import (
|
|
|
16
19
|
get_cluster_configuration,
|
|
17
20
|
get_device_total_memory,
|
|
18
21
|
get_gpu_count_mig,
|
|
19
|
-
|
|
22
|
+
get_gpu_uuid,
|
|
20
23
|
get_n_gpus,
|
|
21
24
|
wait_workers,
|
|
22
25
|
)
|
|
@@ -409,7 +412,7 @@ def test_cuda_mig_visible_devices_and_memory_limit_and_nthreads(loop): # noqa:
|
|
|
409
412
|
|
|
410
413
|
|
|
411
414
|
def test_cuda_visible_devices_uuid(loop): # noqa: F811
|
|
412
|
-
gpu_uuid =
|
|
415
|
+
gpu_uuid = get_gpu_uuid(0)
|
|
413
416
|
|
|
414
417
|
with patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": gpu_uuid}):
|
|
415
418
|
with popen(["dask", "scheduler", "--port", "9359", "--no-dashboard"]):
|