dask-cuda 25.4.0__py3-none-any.whl → 25.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dask_cuda/GIT_COMMIT +1 -1
- dask_cuda/VERSION +1 -1
- dask_cuda/_compat.py +18 -0
- dask_cuda/benchmarks/common.py +4 -1
- dask_cuda/benchmarks/local_cudf_groupby.py +4 -1
- dask_cuda/benchmarks/local_cudf_merge.py +5 -2
- dask_cuda/benchmarks/local_cudf_shuffle.py +5 -2
- dask_cuda/benchmarks/local_cupy.py +4 -1
- dask_cuda/benchmarks/local_cupy_map_overlap.py +4 -1
- dask_cuda/benchmarks/utils.py +7 -4
- dask_cuda/cli.py +21 -15
- dask_cuda/cuda_worker.py +27 -57
- dask_cuda/device_host_file.py +31 -15
- dask_cuda/disk_io.py +7 -4
- dask_cuda/explicit_comms/comms.py +11 -7
- dask_cuda/explicit_comms/dataframe/shuffle.py +147 -55
- dask_cuda/get_device_memory_objects.py +18 -3
- dask_cuda/initialize.py +80 -44
- dask_cuda/is_device_object.py +4 -1
- dask_cuda/is_spillable_object.py +4 -1
- dask_cuda/local_cuda_cluster.py +63 -66
- dask_cuda/plugins.py +17 -16
- dask_cuda/proxify_device_objects.py +15 -10
- dask_cuda/proxify_host_file.py +30 -27
- dask_cuda/proxy_object.py +20 -17
- dask_cuda/tests/conftest.py +41 -0
- dask_cuda/tests/test_dask_cuda_worker.py +114 -27
- dask_cuda/tests/test_dgx.py +10 -18
- dask_cuda/tests/test_explicit_comms.py +51 -18
- dask_cuda/tests/test_from_array.py +7 -5
- dask_cuda/tests/test_initialize.py +16 -37
- dask_cuda/tests/test_local_cuda_cluster.py +164 -54
- dask_cuda/tests/test_proxify_host_file.py +33 -4
- dask_cuda/tests/test_proxy.py +18 -16
- dask_cuda/tests/test_rdd_ucx.py +160 -0
- dask_cuda/tests/test_spill.py +107 -27
- dask_cuda/tests/test_utils.py +106 -20
- dask_cuda/tests/test_worker_spec.py +5 -2
- dask_cuda/utils.py +319 -68
- dask_cuda/utils_test.py +23 -7
- dask_cuda/worker_common.py +196 -0
- dask_cuda/worker_spec.py +12 -5
- {dask_cuda-25.4.0.dist-info → dask_cuda-25.8.0.dist-info}/METADATA +5 -4
- dask_cuda-25.8.0.dist-info/RECORD +63 -0
- {dask_cuda-25.4.0.dist-info → dask_cuda-25.8.0.dist-info}/WHEEL +1 -1
- dask_cuda-25.8.0.dist-info/top_level.txt +6 -0
- shared-actions/check_nightly_success/check-nightly-success/check.py +148 -0
- shared-actions/telemetry-impls/summarize/bump_time.py +54 -0
- shared-actions/telemetry-impls/summarize/send_trace.py +409 -0
- dask_cuda-25.4.0.dist-info/RECORD +0 -56
- dask_cuda-25.4.0.dist-info/top_level.txt +0 -5
- {dask_cuda-25.4.0.dist-info → dask_cuda-25.8.0.dist-info}/entry_points.txt +0 -0
- {dask_cuda-25.4.0.dist-info → dask_cuda-25.8.0.dist-info}/licenses/LICENSE +0 -0
dask_cuda/explicit_comms/dataframe/shuffle.py
CHANGED
@@ -1,6 +1,9 @@
+# Copyright (c) 2021-2025 NVIDIA CORPORATION.
+
 from __future__ import annotations
 
 import asyncio
+import functools
 from collections import defaultdict
 from math import ceil
 from operator import getitem
@@ -23,6 +26,7 @@ from distributed import wait
 from distributed.protocol import nested_deserialize, to_serialize
 from distributed.worker import Worker
 
+from ..._compat import DASK_2025_4_0
 from .. import comms
 
 T = TypeVar("T")
@@ -61,13 +65,13 @@ def get_no_comm_postprocess(
 ) -> Callable[[DataFrame], DataFrame]:
     """Get function for post-processing partitions not communicated
 
-    In cuDF, the
+    In cuDF, the ``group_split_dispatch`` uses ``scatter_by_map`` to create
     the partitions, which is implemented by splitting a single base dataframe
     into multiple partitions. This means that memory are not freed until
     ALL partitions are deleted.
 
     In order to free memory ASAP, we can deep copy partitions NOT being
-    communicated. We do this when
+    communicated. We do this when ``num_rounds != batchsize``.
 
     Parameters
     ----------
@@ -112,7 +116,7 @@ async def send(
     rank_to_out_part_ids: Dict[int, Set[int]],
     out_part_id_to_dataframe: Dict[int, DataFrame],
 ) -> None:
-    """Notice, items sent are removed from
+    """Notice, items sent are removed from ``out_part_id_to_dataframe``"""
     futures = []
     for rank, out_part_ids in rank_to_out_part_ids.items():
         if rank != myrank:
@@ -131,7 +135,7 @@ async def recv(
     out_part_id_to_dataframe_list: Dict[int, List[DataFrame]],
     proxify: Proxify,
 ) -> None:
-    """Notice, received items are appended to
+    """Notice, received items are appended to ``out_parts_list``"""
 
     async def read_msg(rank: int) -> None:
         msg: Dict[int, DataFrame] = nested_deserialize(await eps[rank].read())
@@ -146,11 +150,11 @@ async def recv(
 def compute_map_index(
     df: DataFrame, column_names: List[str], npartitions: int
 ) -> Series:
-    """Return a Series that maps each row
+    """Return a Series that maps each row ``df`` to a partition ID
 
     The partitions are determined by hashing the columns given by column_names
-    unless if
-
+    unless if ``column_names[0] == "_partitions"``, in which case the values of
+    ``column_names[0]`` are used as index.
 
     Parameters
     ----------
@@ -164,7 +168,7 @@ def compute_map_index(
     Returns
     -------
     Series
-        Series that maps each row
+        Series that maps each row ``df`` to a partition ID
     """
 
     if column_names[0] == "_partitions":
@@ -189,8 +193,8 @@ def partition_dataframe(
     """Partition dataframe to a dict of dataframes
 
     The partitions are determined by hashing the columns given by column_names
-    unless
-
+    unless ``column_names[0] == "_partitions"``, in which case the values of
+    ``column_names[0]`` are used as index.
 
     Parameters
     ----------
@@ -297,13 +301,13 @@ async def send_recv_partitions(
     rank_to_out_part_ids
         dict that for each worker rank specifies a set of output partition IDs.
         If the worker shouldn't return any partitions, it is excluded from the
-        dict. Partition IDs are global integers
-        to the dict keys returned by
+        dict. Partition IDs are global integers ``0..npartitions`` and corresponds
+        to the dict keys returned by ``group_split_dispatch``.
     out_part_id_to_dataframe
         Mapping from partition ID to dataframe. This dict is cleared on return.
     no_comm_postprocess
         Function to post-process partitions not communicated.
-        See
+        See ``get_no_comm_postprocess``
     proxify
         Function to proxify object.
     out_part_id_to_dataframe_list
@@ -361,8 +365,8 @@ async def shuffle_task(
     rank_to_out_part_ids: dict
         dict that for each worker rank specifies a set of output partition IDs.
         If the worker shouldn't return any partitions, it is excluded from the
-        dict. Partition IDs are global integers
-        to the dict keys returned by
+        dict. Partition IDs are global integers ``0..npartitions`` and corresponds
+        to the dict keys returned by ``group_split_dispatch``.
     column_names: list of strings
         List of column names on which we want to split.
     npartitions: int
@@ -445,7 +449,7 @@ def shuffle(
         List of column names on which we want to split.
     npartitions: int or None
         The desired number of output partitions. If None, the number of output
-        partitions equals
+        partitions equals ``df.npartitions``
     ignore_index: bool
         Ignore index during shuffle. If True, performance may improve,
         but index values will not be preserved.
@@ -456,7 +460,7 @@ def shuffle(
         If -1, each worker will handle all its partitions in a single round and
        all techniques to reduce memory usage are disabled, which might be faster
        when memory pressure isn't an issue.
-        If None, the value of
+        If None, the value of ``DASK_EXPLICIT_COMMS_BATCHSIZE`` is used or 1 if not
        set thus by default, we prioritize robustness over performance.
 
     Returns
@@ -467,12 +471,12 @@ def shuffle(
     Developer Notes
     ---------------
     The implementation consist of three steps:
-      (a) Stage the partitions of
+      (a) Stage the partitions of ``df`` on all workers and then cancel them
           thus at this point the Dask Scheduler doesn't know about any of the
           the partitions.
       (b) Submit a task on each worker that shuffle (all-to-all communicate)
          the staged partitions and return a list of dataframe-partitions.
-      (c) Submit a dask graph that extract (using
+      (c) Submit a dask graph that extract (using ``getitem()``) individual
          dataframe-partitions from (b).
     """
     c = comms.default_comms()
@@ -582,48 +586,136 @@ def _use_explicit_comms() -> bool:
     return False
 
 
-
-
+_base_lower = dask_expr._shuffle.Shuffle._lower
+_base_compute = dask.base.compute
 
-
-
-    an `ECShuffle` expression when the 'explicit-comms'
-    config is set to `True`.
+
+def _contains_shuffle_expr(*args) -> bool:
     """
+    Check whether any of the arguments is a Shuffle expression.
 
-
-
-
-
-
-
-
-
-
-
-
-
+    This is called by ``compute``, which is given a sequence of Dask Collections
+    to process. For each of those, we'll check whether the expresion contains a
+    Shuffle operation.
+    """
+    for collection in args:
+        if isinstance(collection, dask.dataframe.DataFrame):
+            shuffle_ops = list(
+                collection.expr.find_operations(
+                    (
+                        dask_expr._shuffle.RearrangeByColumn,
+                        dask_expr.SetIndex,
+                        dask_expr._shuffle.Shuffle,
+                    )
                 )
-
-
-
-
-
-
-
-
-
-
-
-
-
+            )
+            if len(shuffle_ops) > 0:
+                return True
+    return False
+
+
+@functools.wraps(_base_compute)
+def _patched_compute(
+    *args,
+    traverse=True,
+    optimize_graph=True,
+    scheduler=None,
+    get=None,
+    **kwargs,
+):
+    # A patched version of dask.compute that explicitly materializes the task
+    # graph when we're using explicit-comms and the expression contains a
+    # Shuffle operation.
+    # https://github.com/rapidsai/dask-upstream-testing/issues/37#issuecomment-2779798670
+    # contains more details on the issue.
+    if DASK_2025_4_0() and _use_explicit_comms() and _contains_shuffle_expr(*args):
+        from dask.base import (
+            collections_to_expr,
+            flatten,
+            get_scheduler,
+            shorten_traceback,
+            unpack_collections,
+        )
+
+        collections, repack = unpack_collections(*args, traverse=traverse)
+        if not collections:
+            return args
+
+        schedule = get_scheduler(
+            scheduler=scheduler,
+            collections=collections,
+            get=get,
+        )
+        from dask._expr import FinalizeCompute
+
+        expr = collections_to_expr(collections, optimize_graph)
+        expr = FinalizeCompute(expr)
+
+        with shorten_traceback():
+            expr = expr.optimize()
+            keys = list(flatten(expr.__dask_keys__()))
+
+            # materialize the HLG here
+            expr = dict(expr.__dask_graph__())
+
+            results = schedule(expr, keys, **kwargs)
+        return repack(results)
+
+    else:
+        return _base_compute(
+            *args,
+            traverse=traverse,
+            optimize_graph=optimize_graph,
+            scheduler=scheduler,
+            get=get,
+            **kwargs,
+        )
+
+
+class ECShuffle(dask_expr._shuffle.TaskShuffle):
+    """Explicit-Comms Shuffle Expression."""
+
+    def _layer(self):
+        # Execute an explicit-comms shuffle
+        if not hasattr(self, "_ec_shuffled"):
+            on = self.partitioning_index
+            df = dask_expr.new_collection(self.frame)
+            ec_shuffled = shuffle(
+                df,
+                [on] if isinstance(on, str) else on,
                 self.npartitions_out,
                 self.ignore_index,
-                self.options,
-                self.original_partitioning_index,
             )
-
-
+            object.__setattr__(self, "_ec_shuffled", ec_shuffled)
+        graph = self._ec_shuffled.dask.copy()
+        shuffled_name = self._ec_shuffled._name
+        for i in range(self.npartitions_out):
+            graph[(self._name, i)] = graph[(shuffled_name, i)]
+        return graph
+
+
+def _patched_lower(self):
+    if self.method in (None, "tasks") and _use_explicit_comms():
+        return ECShuffle(
+            self.frame,
+            self.partitioning_index,
+            self.npartitions_out,
+            self.ignore_index,
+            self.options,
+            self.original_partitioning_index,
+        )
+    else:
+        return _base_lower(self)
+
+
+def patch_shuffle_expression() -> None:
+    """Patch Dasks Shuffle expression.
+
+    Notice, this is monkey patched into Dask at dask_cuda
+    import, and it changes ``Shuffle._layer`` to lower into
+    an ``ECShuffle`` expression when the 'explicit-comms'
+    config is set to ``True``.
+    """
+    dask.base.compute = _patched_compute
 
     dask_expr._shuffle.Shuffle._lower = _patched_lower
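The patched ``compute`` and ``Shuffle._lower`` above only take effect when explicit-comms is enabled in the Dask config. The sketch below shows how a user might exercise that path; it assumes the ``"explicit-comms"`` config key checked by ``_use_explicit_comms`` and a standard ``LocalCUDACluster`` setup, so treat it as illustrative rather than part of this release's documented API.

    # Hedged sketch: enable explicit-comms so shuffles lower into ECShuffle and
    # compute() goes through _patched_compute. The config key is assumed from
    # _use_explicit_comms(); verify against your dask-cuda version.
    import dask
    import dask.dataframe as dd
    from distributed import Client
    from dask_cuda import LocalCUDACluster

    if __name__ == "__main__":
        cluster = LocalCUDACluster()  # one worker per visible GPU
        client = Client(cluster)

        with dask.config.set({"explicit-comms": True}):
            df = dd.from_dict({"key": [1, 2, 3, 4] * 25, "x": range(100)}, npartitions=4)
            shuffled = df.shuffle("key")      # candidate for the ECShuffle lowering
            print(shuffled.compute().head())  # compute() dispatches via the patch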
dask_cuda/get_device_memory_objects.py
CHANGED
@@ -1,3 +1,5 @@
+# Copyright (c) 2025, NVIDIA CORPORATION.
+
 from typing import Set
 
 from dask.sizeof import sizeof
@@ -25,10 +27,10 @@ class DeviceMemoryId:
 
 
 def get_device_memory_ids(obj) -> Set[DeviceMemoryId]:
-    """Find all CUDA device objects in
+    """Find all CUDA device objects in ``obj``
 
-    Search through
-    that either are known to
+    Search through ``obj`` and find all CUDA device objects, which are objects
+    that either are known to ``dispatch`` or implement ``__cuda_array_interface__``.
 
     Parameters
     ----------
@@ -140,3 +142,16 @@ def register_cupy():  # NB: this overwrites dask.sizeof.register_cupy()
     @sizeof.register(cupy.ndarray)
     def sizeof_cupy_ndarray(x):
         return int(x.nbytes)
+
+
+@sizeof.register_lazy("pylibcudf")
+def register_pylibcudf():
+    import pylibcudf
+
+    @sizeof.register(pylibcudf.column.OwnerWithCAI)
+    def sizeof_owner_with_cai(x):
+        # OwnerWithCAI implements __cuda_array_interface__ so this should always
+        # be zero-copy
+        col = pylibcudf.column.Column.from_cuda_array_interface(x)
+        # col.data() returns a gpumemoryview, which knows the size in bytes
+        return col.data().nbytes
dask_cuda/initialize.py
CHANGED
@@ -1,3 +1,6 @@
+# SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
 import logging
 import os
 
@@ -7,7 +10,7 @@ import numba.cuda
 import dask
 from distributed.diagnostics.nvml import get_device_index_and_uuid, has_cuda_context
 
-from .utils import get_ucx_config
+from .utils import _get_active_ucx_implementation_name, get_ucx_config
 
 logger = logging.getLogger(__name__)
 
@@ -22,65 +25,97 @@ def _create_cuda_context_handler():
         numba.cuda.current_context()
 
 
-def
-    if protocol not in ["ucx", "ucxx"]:
-        return
+def _warn_generic():
     try:
+        # TODO: update when UCX-Py is removed, see
+        # https://github.com/rapidsai/dask-cuda/issues/1517
+        import distributed.comm.ucx
+
         # Added here to ensure the parent `LocalCUDACluster` process creates the CUDA
         # context directly from the UCX module, thus avoiding a similar warning there.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        cuda_visible_device = get_device_index_and_uuid(
+            os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]
+        )
+        ctx = has_cuda_context()
+        if (
+            ctx.has_context
+            and not distributed.comm.ucx.cuda_context_created.has_context
+        ):
+            distributed.comm.ucx._warn_existing_cuda_context(ctx, os.getpid())
+
+        _create_cuda_context_handler()
+
+        if not distributed.comm.ucx.cuda_context_created.has_context:
+            ctx = has_cuda_context()
+            if ctx.has_context and ctx.device_info != cuda_visible_device:
+                distributed.comm.ucx._warn_cuda_context_wrong_device(
+                    cuda_visible_device, ctx.device_info, os.getpid()
+                )
+
+    except Exception:
+        logger.error("Unable to start CUDA Context", exc_info=True)
+
+
+def _initialize_ucx():
+    try:
+        import distributed.comm.ucx
+
+        distributed.comm.ucx.init_once()
+    except ModuleNotFoundError:
+        # UCX initialization has to be delegated to Distributed, it will take care
+        # of setting correct environment variables and importing `ucp` after that.
+        # Therefore if ``import ucp`` fails we can just continue here.
+        pass
+
+
+def _initialize_ucxx():
+    try:
+        # Added here to ensure the parent `LocalCUDACluster` process creates the CUDA
+        # context directly from the UCX module, thus avoiding a similar warning there.
+        import distributed_ucxx.ucxx
+
+        distributed_ucxx.ucxx.init_once()
 
         cuda_visible_device = get_device_index_and_uuid(
             os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]
         )
         ctx = has_cuda_context()
-        if
-
-
-
-        )
-            distributed.comm.ucx._warn_existing_cuda_context(ctx, os.getpid())
-        elif protocol == "ucxx":
-            if (
-                ctx.has_context
-                and not distributed_ucxx.ucxx.cuda_context_created.has_context
-            ):
-                distributed_ucxx.ucxx._warn_existing_cuda_context(ctx, os.getpid())
+        if (
+            ctx.has_context
+            and not distributed_ucxx.ucxx.cuda_context_created.has_context
+        ):
+            distributed_ucxx.ucxx._warn_existing_cuda_context(ctx, os.getpid())
 
         _create_cuda_context_handler()
 
-        if
-
-
-
-
-
-        )
-        elif protocol == "ucxx":
-            if not distributed_ucxx.ucxx.cuda_context_created.has_context:
-                ctx = has_cuda_context()
-                if ctx.has_context and ctx.device_info != cuda_visible_device:
-                    distributed_ucxx.ucxx._warn_cuda_context_wrong_device(
-                        cuda_visible_device, ctx.device_info, os.getpid()
-                    )
+        if not distributed_ucxx.ucxx.cuda_context_created.has_context:
+            ctx = has_cuda_context()
+            if ctx.has_context and ctx.device_info != cuda_visible_device:
+                distributed_ucxx.ucxx._warn_cuda_context_wrong_device(
+                    cuda_visible_device, ctx.device_info, os.getpid()
+                )
 
     except Exception:
         logger.error("Unable to start CUDA Context", exc_info=True)
 
 
+def _create_cuda_context(protocol="ucx"):
+    if protocol not in ["ucx", "ucxx", "ucx-old"]:
+        return
+
+    try:
+        ucx_implementation = _get_active_ucx_implementation_name(protocol)
+    except ValueError:
+        # Not a UCX protocol, just raise CUDA context warnings if needed.
+        _warn_generic()
+    else:
+        if ucx_implementation == "ucxx":
+            _initialize_ucxx()
+        else:
+            _initialize_ucx()
+        _warn_generic()
+
+
 def initialize(
     create_cuda_context=True,
     enable_tcp_over_ucx=None,
@@ -138,6 +173,7 @@ def initialize(
         enable_infiniband=enable_infiniband,
         enable_nvlink=enable_nvlink,
         enable_rdmacm=enable_rdmacm,
+        protocol=protocol,
    )
    dask.config.set({"distributed.comm.ucx": ucx_config})
 
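The refactor above splits UCX bring-up into ``_initialize_ucx``/``_initialize_ucxx``, selected via ``_get_active_ucx_implementation_name``, while the public entry point remains ``initialize()``. A hedged client-side sketch follows; the flag values are illustrative, and whether ``protocol="ucx"`` resolves to UCXX or UCX-Py depends on which packages are installed.

    # Hedged sketch of client-side UCX setup mirroring what the workers do.
    from dask_cuda.initialize import initialize

    initialize(
        protocol="ucx",
        enable_tcp_over_ucx=True,
        enable_nvlink=True,
        enable_infiniband=False,
    )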
dask_cuda/is_device_object.py
CHANGED
@@ -1,3 +1,4 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
 from __future__ import absolute_import, division, print_function
 
 from dask.utils import Dispatch
@@ -35,6 +36,8 @@ def register_cudf():
     def is_device_object_cudf_series(s):
         return True
 
-    @is_device_object.register(cudf.
+    @is_device_object.register(cudf.Index)
+    @is_device_object.register(cudf.RangeIndex)
+    @is_device_object.register(cudf.MultiIndex)
     def is_device_object_cudf_index(s):
         return True
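Both this file and is_spillable_object.py below rely on ``dask.utils.Dispatch``, which routes on the argument's type (walking its MRO) and lets the cuDF handlers be registered lazily. A minimal sketch of that mechanism with stand-in types (nothing here is a dask-cuda or cuDF class):

    # Sketch of the Dispatch pattern; FakeDeviceBuffer is a stand-in type.
    from dask.utils import Dispatch

    is_on_gpu = Dispatch(name="is_on_gpu")

    @is_on_gpu.register(object)
    def _default(obj):
        # Fallback handler: anything exposing __cuda_array_interface__ counts.
        return hasattr(obj, "__cuda_array_interface__")

    class FakeDeviceBuffer:
        __cuda_array_interface__ = {}

    print(is_on_gpu(FakeDeviceBuffer()))  # True
    print(is_on_gpu([1, 2, 3]))           # False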
dask_cuda/is_spillable_object.py
CHANGED
@@ -1,3 +1,4 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
 from __future__ import absolute_import, division, print_function
 
 from typing import Optional
@@ -34,7 +35,9 @@ def register_cudf():
     def is_device_object_cudf_dataframe(df):
         return cudf_spilling_status()
 
-    @is_spillable_object.register(cudf.
+    @is_spillable_object.register(cudf.Index)
+    @is_spillable_object.register(cudf.RangeIndex)
+    @is_spillable_object.register(cudf.MultiIndex)
     def is_device_object_cudf_index(s):
         return cudf_spilling_status()
 
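The registrations above return ``cudf_spilling_status()``, i.e. these objects are only reported as spillable when cuDF's own spilling is active. A hedged sketch of turning that on from user code (the ``"spill"`` option name follows cuDF's documented spilling support; verify against your cuDF version):

    # Sketch: with cuDF-managed spilling enabled, dask-cuda defers to cuDF's
    # spill manager for these objects instead of its own device-host-disk path.
    import cudf

    cudf.set_option("spill", True)  # alternatively: export CUDF_SPILL=on

    ser = cudf.Series(range(1_000_000))
    idx = cudf.RangeIndex(10)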