dask-cuda 25.4.0__py3-none-any.whl → 25.8.0__py3-none-any.whl

This diff shows the changes between publicly available package versions released to one of the supported registries. The information is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (53)
  1. dask_cuda/GIT_COMMIT +1 -1
  2. dask_cuda/VERSION +1 -1
  3. dask_cuda/_compat.py +18 -0
  4. dask_cuda/benchmarks/common.py +4 -1
  5. dask_cuda/benchmarks/local_cudf_groupby.py +4 -1
  6. dask_cuda/benchmarks/local_cudf_merge.py +5 -2
  7. dask_cuda/benchmarks/local_cudf_shuffle.py +5 -2
  8. dask_cuda/benchmarks/local_cupy.py +4 -1
  9. dask_cuda/benchmarks/local_cupy_map_overlap.py +4 -1
  10. dask_cuda/benchmarks/utils.py +7 -4
  11. dask_cuda/cli.py +21 -15
  12. dask_cuda/cuda_worker.py +27 -57
  13. dask_cuda/device_host_file.py +31 -15
  14. dask_cuda/disk_io.py +7 -4
  15. dask_cuda/explicit_comms/comms.py +11 -7
  16. dask_cuda/explicit_comms/dataframe/shuffle.py +147 -55
  17. dask_cuda/get_device_memory_objects.py +18 -3
  18. dask_cuda/initialize.py +80 -44
  19. dask_cuda/is_device_object.py +4 -1
  20. dask_cuda/is_spillable_object.py +4 -1
  21. dask_cuda/local_cuda_cluster.py +63 -66
  22. dask_cuda/plugins.py +17 -16
  23. dask_cuda/proxify_device_objects.py +15 -10
  24. dask_cuda/proxify_host_file.py +30 -27
  25. dask_cuda/proxy_object.py +20 -17
  26. dask_cuda/tests/conftest.py +41 -0
  27. dask_cuda/tests/test_dask_cuda_worker.py +114 -27
  28. dask_cuda/tests/test_dgx.py +10 -18
  29. dask_cuda/tests/test_explicit_comms.py +51 -18
  30. dask_cuda/tests/test_from_array.py +7 -5
  31. dask_cuda/tests/test_initialize.py +16 -37
  32. dask_cuda/tests/test_local_cuda_cluster.py +164 -54
  33. dask_cuda/tests/test_proxify_host_file.py +33 -4
  34. dask_cuda/tests/test_proxy.py +18 -16
  35. dask_cuda/tests/test_rdd_ucx.py +160 -0
  36. dask_cuda/tests/test_spill.py +107 -27
  37. dask_cuda/tests/test_utils.py +106 -20
  38. dask_cuda/tests/test_worker_spec.py +5 -2
  39. dask_cuda/utils.py +319 -68
  40. dask_cuda/utils_test.py +23 -7
  41. dask_cuda/worker_common.py +196 -0
  42. dask_cuda/worker_spec.py +12 -5
  43. {dask_cuda-25.4.0.dist-info → dask_cuda-25.8.0.dist-info}/METADATA +5 -4
  44. dask_cuda-25.8.0.dist-info/RECORD +63 -0
  45. {dask_cuda-25.4.0.dist-info → dask_cuda-25.8.0.dist-info}/WHEEL +1 -1
  46. dask_cuda-25.8.0.dist-info/top_level.txt +6 -0
  47. shared-actions/check_nightly_success/check-nightly-success/check.py +148 -0
  48. shared-actions/telemetry-impls/summarize/bump_time.py +54 -0
  49. shared-actions/telemetry-impls/summarize/send_trace.py +409 -0
  50. dask_cuda-25.4.0.dist-info/RECORD +0 -56
  51. dask_cuda-25.4.0.dist-info/top_level.txt +0 -5
  52. {dask_cuda-25.4.0.dist-info → dask_cuda-25.8.0.dist-info}/entry_points.txt +0 -0
  53. {dask_cuda-25.4.0.dist-info → dask_cuda-25.8.0.dist-info}/licenses/LICENSE +0 -0
dask_cuda/explicit_comms/dataframe/shuffle.py CHANGED
@@ -1,6 +1,9 @@
+# Copyright (c) 2021-2025 NVIDIA CORPORATION.
+
 from __future__ import annotations
 
 import asyncio
+import functools
 from collections import defaultdict
 from math import ceil
 from operator import getitem
@@ -23,6 +26,7 @@ from distributed import wait
 from distributed.protocol import nested_deserialize, to_serialize
 from distributed.worker import Worker
 
+from ..._compat import DASK_2025_4_0
 from .. import comms
 
 T = TypeVar("T")
@@ -61,13 +65,13 @@ def get_no_comm_postprocess(
 ) -> Callable[[DataFrame], DataFrame]:
     """Get function for post-processing partitions not communicated
 
-    In cuDF, the `group_split_dispatch` uses `scatter_by_map` to create
+    In cuDF, the ``group_split_dispatch`` uses ``scatter_by_map`` to create
     the partitions, which is implemented by splitting a single base dataframe
     into multiple partitions. This means that memory are not freed until
     ALL partitions are deleted.
 
     In order to free memory ASAP, we can deep copy partitions NOT being
-    communicated. We do this when `num_rounds != batchsize`.
+    communicated. We do this when ``num_rounds != batchsize``.
 
     Parameters
     ----------
@@ -112,7 +116,7 @@ async def send(
     rank_to_out_part_ids: Dict[int, Set[int]],
     out_part_id_to_dataframe: Dict[int, DataFrame],
 ) -> None:
-    """Notice, items sent are removed from `out_part_id_to_dataframe`"""
+    """Notice, items sent are removed from ``out_part_id_to_dataframe``"""
     futures = []
     for rank, out_part_ids in rank_to_out_part_ids.items():
         if rank != myrank:
@@ -131,7 +135,7 @@ async def recv(
     out_part_id_to_dataframe_list: Dict[int, List[DataFrame]],
     proxify: Proxify,
 ) -> None:
-    """Notice, received items are appended to `out_parts_list`"""
+    """Notice, received items are appended to ``out_parts_list``"""
 
     async def read_msg(rank: int) -> None:
         msg: Dict[int, DataFrame] = nested_deserialize(await eps[rank].read())
@@ -146,11 +150,11 @@ async def recv(
 def compute_map_index(
     df: DataFrame, column_names: List[str], npartitions: int
 ) -> Series:
-    """Return a Series that maps each row `df` to a partition ID
+    """Return a Series that maps each row ``df`` to a partition ID
 
     The partitions are determined by hashing the columns given by column_names
-    unless if `column_names[0] == "_partitions"`, in which case the values of
-    `column_names[0]` are used as index.
+    unless if ``column_names[0] == "_partitions"``, in which case the values of
+    ``column_names[0]`` are used as index.
 
     Parameters
     ----------
@@ -164,7 +168,7 @@ def compute_map_index(
     Returns
     -------
     Series
-        Series that maps each row `df` to a partition ID
+        Series that maps each row ``df`` to a partition ID
     """
 
     if column_names[0] == "_partitions":
@@ -189,8 +193,8 @@ def partition_dataframe(
     """Partition dataframe to a dict of dataframes
 
     The partitions are determined by hashing the columns given by column_names
-    unless `column_names[0] == "_partitions"`, in which case the values of
-    `column_names[0]` are used as index.
+    unless ``column_names[0] == "_partitions"``, in which case the values of
+    ``column_names[0]`` are used as index.
 
     Parameters
     ----------
@@ -297,13 +301,13 @@ async def send_recv_partitions(
     rank_to_out_part_ids
         dict that for each worker rank specifies a set of output partition IDs.
         If the worker shouldn't return any partitions, it is excluded from the
-        dict. Partition IDs are global integers `0..npartitions` and corresponds
-        to the dict keys returned by `group_split_dispatch`.
+        dict. Partition IDs are global integers ``0..npartitions`` and corresponds
+        to the dict keys returned by ``group_split_dispatch``.
     out_part_id_to_dataframe
         Mapping from partition ID to dataframe. This dict is cleared on return.
     no_comm_postprocess
         Function to post-process partitions not communicated.
-        See `get_no_comm_postprocess`
+        See ``get_no_comm_postprocess``
     proxify
         Function to proxify object.
     out_part_id_to_dataframe_list
@@ -361,8 +365,8 @@ async def shuffle_task(
     rank_to_out_part_ids: dict
         dict that for each worker rank specifies a set of output partition IDs.
         If the worker shouldn't return any partitions, it is excluded from the
-        dict. Partition IDs are global integers `0..npartitions` and corresponds
-        to the dict keys returned by `group_split_dispatch`.
+        dict. Partition IDs are global integers ``0..npartitions`` and corresponds
+        to the dict keys returned by ``group_split_dispatch``.
     column_names: list of strings
         List of column names on which we want to split.
     npartitions: int
@@ -445,7 +449,7 @@ def shuffle(
         List of column names on which we want to split.
     npartitions: int or None
         The desired number of output partitions. If None, the number of output
-        partitions equals `df.npartitions`
+        partitions equals ``df.npartitions``
     ignore_index: bool
         Ignore index during shuffle. If True, performance may improve,
         but index values will not be preserved.
@@ -456,7 +460,7 @@ def shuffle(
         If -1, each worker will handle all its partitions in a single round and
         all techniques to reduce memory usage are disabled, which might be faster
         when memory pressure isn't an issue.
-        If None, the value of `DASK_EXPLICIT_COMMS_BATCHSIZE` is used or 1 if not
+        If None, the value of ``DASK_EXPLICIT_COMMS_BATCHSIZE`` is used or 1 if not
         set thus by default, we prioritize robustness over performance.
 
     Returns
@@ -467,12 +471,12 @@ def shuffle(
     Developer Notes
     ---------------
     The implementation consist of three steps:
-      (a) Stage the partitions of `df` on all workers and then cancel them
+      (a) Stage the partitions of ``df`` on all workers and then cancel them
           thus at this point the Dask Scheduler doesn't know about any of the
           the partitions.
       (b) Submit a task on each worker that shuffle (all-to-all communicate)
          the staged partitions and return a list of dataframe-partitions.
-      (c) Submit a dask graph that extract (using `getitem()`) individual
+      (c) Submit a dask graph that extract (using ``getitem()``) individual
          dataframe-partitions from (b).
     """
     c = comms.default_comms()
@@ -582,48 +586,136 @@ def _use_explicit_comms() -> bool:
     return False
 
 
-def patch_shuffle_expression() -> None:
-    """Patch Dasks Shuffle expression.
+_base_lower = dask_expr._shuffle.Shuffle._lower
+_base_compute = dask.base.compute
 
-    Notice, this is monkey patched into Dask at dask_cuda
-    import, and it changes `Shuffle._layer` to lower into
-    an `ECShuffle` expression when the 'explicit-comms'
-    config is set to `True`.
+
+def _contains_shuffle_expr(*args) -> bool:
     """
+    Check whether any of the arguments is a Shuffle expression.
 
-    class ECShuffle(dask_expr._shuffle.TaskShuffle):
-        """Explicit-Comms Shuffle Expression."""
-
-        def _layer(self):
-            # Execute an explicit-comms shuffle
-            if not hasattr(self, "_ec_shuffled"):
-                on = self.partitioning_index
-                df = dask_expr.new_collection(self.frame)
-                self._ec_shuffled = shuffle(
-                    df,
-                    [on] if isinstance(on, str) else on,
-                    self.npartitions_out,
-                    self.ignore_index,
+    This is called by ``compute``, which is given a sequence of Dask Collections
+    to process. For each of those, we'll check whether the expresion contains a
+    Shuffle operation.
+    """
+    for collection in args:
+        if isinstance(collection, dask.dataframe.DataFrame):
+            shuffle_ops = list(
+                collection.expr.find_operations(
+                    (
+                        dask_expr._shuffle.RearrangeByColumn,
+                        dask_expr.SetIndex,
+                        dask_expr._shuffle.Shuffle,
+                    )
                 )
-                graph = self._ec_shuffled.dask.copy()
-                shuffled_name = self._ec_shuffled._name
-                for i in range(self.npartitions_out):
-                    graph[(self._name, i)] = graph[(shuffled_name, i)]
-            return graph
-
-    _base_lower = dask_expr._shuffle.Shuffle._lower
-
-    def _patched_lower(self):
-        if self.method in (None, "tasks") and _use_explicit_comms():
-            return ECShuffle(
-                self.frame,
-                self.partitioning_index,
+            )
+            if len(shuffle_ops) > 0:
+                return True
+    return False
+
+
+@functools.wraps(_base_compute)
+def _patched_compute(
+    *args,
+    traverse=True,
+    optimize_graph=True,
+    scheduler=None,
+    get=None,
+    **kwargs,
+):
+    # A patched version of dask.compute that explicitly materializes the task
+    # graph when we're using explicit-comms and the expression contains a
+    # Shuffle operation.
+    # https://github.com/rapidsai/dask-upstream-testing/issues/37#issuecomment-2779798670
+    # contains more details on the issue.
+    if DASK_2025_4_0() and _use_explicit_comms() and _contains_shuffle_expr(*args):
+        from dask.base import (
+            collections_to_expr,
+            flatten,
+            get_scheduler,
+            shorten_traceback,
+            unpack_collections,
+        )
+
+        collections, repack = unpack_collections(*args, traverse=traverse)
+        if not collections:
+            return args
+
+        schedule = get_scheduler(
+            scheduler=scheduler,
+            collections=collections,
+            get=get,
+        )
+        from dask._expr import FinalizeCompute
+
+        expr = collections_to_expr(collections, optimize_graph)
+        expr = FinalizeCompute(expr)
+
+        with shorten_traceback():
+            expr = expr.optimize()
+            keys = list(flatten(expr.__dask_keys__()))
+
+        # materialize the HLG here
+        expr = dict(expr.__dask_graph__())
+
+        results = schedule(expr, keys, **kwargs)
+        return repack(results)
+
+    else:
+        return _base_compute(
+            *args,
+            traverse=traverse,
+            optimize_graph=optimize_graph,
+            scheduler=scheduler,
+            get=get,
+            **kwargs,
+        )
+
+
+class ECShuffle(dask_expr._shuffle.TaskShuffle):
+    """Explicit-Comms Shuffle Expression."""
+
+    def _layer(self):
+        # Execute an explicit-comms shuffle
+        if not hasattr(self, "_ec_shuffled"):
+            on = self.partitioning_index
+            df = dask_expr.new_collection(self.frame)
+            ec_shuffled = shuffle(
+                df,
+                [on] if isinstance(on, str) else on,
                 self.npartitions_out,
                 self.ignore_index,
-                self.options,
-                self.original_partitioning_index,
             )
-        else:
-            return _base_lower(self)
+            object.__setattr__(self, "_ec_shuffled", ec_shuffled)
+        graph = self._ec_shuffled.dask.copy()
+        shuffled_name = self._ec_shuffled._name
+        for i in range(self.npartitions_out):
+            graph[(self._name, i)] = graph[(shuffled_name, i)]
+        return graph
+
+
+def _patched_lower(self):
+    if self.method in (None, "tasks") and _use_explicit_comms():
+        return ECShuffle(
+            self.frame,
+            self.partitioning_index,
+            self.npartitions_out,
+            self.ignore_index,
+            self.options,
+            self.original_partitioning_index,
+        )
+    else:
+        return _base_lower(self)
+
+
+def patch_shuffle_expression() -> None:
+    """Patch Dasks Shuffle expression.
+
+    Notice, this is monkey patched into Dask at dask_cuda
+    import, and it changes ``Shuffle._layer`` to lower into
+    an ``ECShuffle`` expression when the 'explicit-comms'
+    config is set to ``True``.
+    """
+    dask.base.compute = _patched_compute
 
     dask_expr._shuffle.Shuffle._lower = _patched_lower
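The patched compute path above only takes effect when the 'explicit-comms' Dask config is enabled. A minimal sketch of how that path is typically exercised follows; it is not part of the diff, and the timeseries demo data and the "id" shuffle column are illustrative assumptions:

import dask
from distributed import Client
from dask_cuda import LocalCUDACluster

if __name__ == "__main__":
    # LocalCUDACluster starts one worker per visible GPU by default.
    cluster = LocalCUDACluster()
    client = Client(cluster)

    with dask.config.set({"explicit-comms": True}):
        df = dask.datasets.timeseries()  # small pandas-backed demo dataframe
        # With the flag set, Shuffle._lower returns an ECShuffle expression and,
        # on Dask >= 2025.4.0, dask.base.compute routes through _patched_compute.
        result = df.shuffle(on="id").compute()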
dask_cuda/get_device_memory_objects.py CHANGED
@@ -1,3 +1,5 @@
+# Copyright (c) 2025, NVIDIA CORPORATION.
+
 from typing import Set
 
 from dask.sizeof import sizeof
@@ -25,10 +27,10 @@ class DeviceMemoryId:
 
 
 def get_device_memory_ids(obj) -> Set[DeviceMemoryId]:
-    """Find all CUDA device objects in `obj`
+    """Find all CUDA device objects in ``obj``
 
-    Search through `obj` and find all CUDA device objects, which are objects
-    that either are known to `dispatch` or implement `__cuda_array_interface__`.
+    Search through ``obj`` and find all CUDA device objects, which are objects
+    that either are known to ``dispatch`` or implement ``__cuda_array_interface__``.
 
     Parameters
     ----------
@@ -140,3 +142,16 @@ def register_cupy(): # NB: this overwrites dask.sizeof.register_cupy()
     @sizeof.register(cupy.ndarray)
     def sizeof_cupy_ndarray(x):
         return int(x.nbytes)
+
+
+@sizeof.register_lazy("pylibcudf")
+def register_pylibcudf():
+    import pylibcudf
+
+    @sizeof.register(pylibcudf.column.OwnerWithCAI)
+    def sizeof_owner_with_cai(x):
+        # OwnerWithCAI implements __cuda_array_interface__ so this should always
+        # be zero-copy
+        col = pylibcudf.column.Column.from_cuda_array_interface(x)
+        # col.data() returns a gpumemoryview, which knows the size in bytes
+        return col.data().nbytes
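The new registration follows dask's lazy sizeof dispatch pattern: register_lazy defers work until the named module is first encountered, and register adds a handler for a concrete type. A small self-contained sketch of that pattern is shown below; FakeDeviceBuffer is an invented stand-in used only so the example runs without a GPU or pylibcudf:

from dask.sizeof import sizeof


class FakeDeviceBuffer:
    """Hypothetical object that knows its allocation size in bytes."""

    def __init__(self, nbytes):
        self.nbytes = nbytes


@sizeof.register(FakeDeviceBuffer)
def sizeof_fake_device_buffer(x):
    # Mirrors sizeof_owner_with_cai above: report the underlying buffer size.
    return int(x.nbytes)


print(sizeof(FakeDeviceBuffer(1 << 20)))  # 1048576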
dask_cuda/initialize.py CHANGED
@@ -1,3 +1,6 @@
+# SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
 import logging
 import os
 
@@ -7,7 +10,7 @@ import numba.cuda
 import dask
 from distributed.diagnostics.nvml import get_device_index_and_uuid, has_cuda_context
 
-from .utils import get_ucx_config
+from .utils import _get_active_ucx_implementation_name, get_ucx_config
 
 logger = logging.getLogger(__name__)
 
@@ -22,65 +25,97 @@ def _create_cuda_context_handler():
         numba.cuda.current_context()
 
 
-def _create_cuda_context(protocol="ucx"):
-    if protocol not in ["ucx", "ucxx"]:
-        return
+def _warn_generic():
     try:
+        # TODO: update when UCX-Py is removed, see
+        # https://github.com/rapidsai/dask-cuda/issues/1517
+        import distributed.comm.ucx
+
         # Added here to ensure the parent `LocalCUDACluster` process creates the CUDA
         # context directly from the UCX module, thus avoiding a similar warning there.
-        try:
-            if protocol == "ucx":
-                import distributed.comm.ucx
-
-                distributed.comm.ucx.init_once()
-            elif protocol == "ucxx":
-                import distributed_ucxx.ucxx
-
-                distributed_ucxx.ucxx.init_once()
-        except ModuleNotFoundError:
-            # UCX initialization has to be delegated to Distributed, it will take care
-            # of setting correct environment variables and importing `ucp` after that.
-            # Therefore if ``import ucp`` fails we can just continue here.
-            pass
+        cuda_visible_device = get_device_index_and_uuid(
+            os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]
+        )
+        ctx = has_cuda_context()
+        if (
+            ctx.has_context
+            and not distributed.comm.ucx.cuda_context_created.has_context
+        ):
+            distributed.comm.ucx._warn_existing_cuda_context(ctx, os.getpid())
+
+        _create_cuda_context_handler()
+
+        if not distributed.comm.ucx.cuda_context_created.has_context:
+            ctx = has_cuda_context()
+            if ctx.has_context and ctx.device_info != cuda_visible_device:
+                distributed.comm.ucx._warn_cuda_context_wrong_device(
+                    cuda_visible_device, ctx.device_info, os.getpid()
+                )
+
+    except Exception:
+        logger.error("Unable to start CUDA Context", exc_info=True)
+
+
+def _initialize_ucx():
+    try:
+        import distributed.comm.ucx
+
+        distributed.comm.ucx.init_once()
+    except ModuleNotFoundError:
+        # UCX initialization has to be delegated to Distributed, it will take care
+        # of setting correct environment variables and importing `ucp` after that.
+        # Therefore if ``import ucp`` fails we can just continue here.
+        pass
+
+
+def _initialize_ucxx():
+    try:
+        # Added here to ensure the parent `LocalCUDACluster` process creates the CUDA
+        # context directly from the UCX module, thus avoiding a similar warning there.
+        import distributed_ucxx.ucxx
+
+        distributed_ucxx.ucxx.init_once()
 
         cuda_visible_device = get_device_index_and_uuid(
             os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]
         )
         ctx = has_cuda_context()
-        if protocol == "ucx":
-            if (
-                ctx.has_context
-                and not distributed.comm.ucx.cuda_context_created.has_context
-            ):
-                distributed.comm.ucx._warn_existing_cuda_context(ctx, os.getpid())
-        elif protocol == "ucxx":
-            if (
-                ctx.has_context
-                and not distributed_ucxx.ucxx.cuda_context_created.has_context
-            ):
-                distributed_ucxx.ucxx._warn_existing_cuda_context(ctx, os.getpid())
+        if (
+            ctx.has_context
+            and not distributed_ucxx.ucxx.cuda_context_created.has_context
+        ):
+            distributed_ucxx.ucxx._warn_existing_cuda_context(ctx, os.getpid())
 
         _create_cuda_context_handler()
 
-        if protocol == "ucx":
-            if not distributed.comm.ucx.cuda_context_created.has_context:
-                ctx = has_cuda_context()
-                if ctx.has_context and ctx.device_info != cuda_visible_device:
-                    distributed.comm.ucx._warn_cuda_context_wrong_device(
-                        cuda_visible_device, ctx.device_info, os.getpid()
-                    )
-        elif protocol == "ucxx":
-            if not distributed_ucxx.ucxx.cuda_context_created.has_context:
-                ctx = has_cuda_context()
-                if ctx.has_context and ctx.device_info != cuda_visible_device:
-                    distributed_ucxx.ucxx._warn_cuda_context_wrong_device(
-                        cuda_visible_device, ctx.device_info, os.getpid()
-                    )
+        if not distributed_ucxx.ucxx.cuda_context_created.has_context:
+            ctx = has_cuda_context()
+            if ctx.has_context and ctx.device_info != cuda_visible_device:
+                distributed_ucxx.ucxx._warn_cuda_context_wrong_device(
+                    cuda_visible_device, ctx.device_info, os.getpid()
+                )
 
     except Exception:
        logger.error("Unable to start CUDA Context", exc_info=True)
 
 
+def _create_cuda_context(protocol="ucx"):
+    if protocol not in ["ucx", "ucxx", "ucx-old"]:
+        return
+
+    try:
+        ucx_implementation = _get_active_ucx_implementation_name(protocol)
+    except ValueError:
+        # Not a UCX protocol, just raise CUDA context warnings if needed.
+        _warn_generic()
+    else:
+        if ucx_implementation == "ucxx":
+            _initialize_ucxx()
+        else:
+            _initialize_ucx()
+            _warn_generic()
+
+
 def initialize(
     create_cuda_context=True,
     enable_tcp_over_ucx=None,
@@ -138,6 +173,7 @@ def initialize(
         enable_infiniband=enable_infiniband,
         enable_nvlink=enable_nvlink,
         enable_rdmacm=enable_rdmacm,
+        protocol=protocol,
    )
    dask.config.set({"distributed.comm.ucx": ucx_config})
 
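For reference, the refactored helpers above all sit behind the same client-side entry point dask-cuda documents for UCX setups. A hedged sketch of typical usage, with a placeholder scheduler address, is:

from dask_cuda.initialize import initialize
from distributed import Client

# Enable UCX transports on the client before connecting; the flags mirror the
# keyword arguments forwarded to get_ucx_config in the diff above.
initialize(
    enable_tcp_over_ucx=True,
    enable_infiniband=False,
    enable_nvlink=True,
    enable_rdmacm=False,
)

# "scheduler-host" is a placeholder for a scheduler started with a UCX protocol.
client = Client("ucx://scheduler-host:8786")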
dask_cuda/is_device_object.py CHANGED
@@ -1,3 +1,4 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
 from __future__ import absolute_import, division, print_function
 
 from dask.utils import Dispatch
@@ -35,6 +36,8 @@ def register_cudf():
     def is_device_object_cudf_series(s):
         return True
 
-    @is_device_object.register(cudf.BaseIndex)
+    @is_device_object.register(cudf.Index)
+    @is_device_object.register(cudf.RangeIndex)
+    @is_device_object.register(cudf.MultiIndex)
     def is_device_object_cudf_index(s):
         return True
dask_cuda/is_spillable_object.py CHANGED
@@ -1,3 +1,4 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
 from __future__ import absolute_import, division, print_function
 
 from typing import Optional
@@ -34,7 +35,9 @@ def register_cudf():
     def is_device_object_cudf_dataframe(df):
         return cudf_spilling_status()
 
-    @is_spillable_object.register(cudf.BaseIndex)
+    @is_spillable_object.register(cudf.Index)
+    @is_spillable_object.register(cudf.RangeIndex)
+    @is_spillable_object.register(cudf.MultiIndex)
     def is_device_object_cudf_index(s):
         return cudf_spilling_status()
 
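Both registries above are dask.utils.Dispatch instances, so a handler is resolved by walking the argument type's MRO; the change simply registers the concrete cudf.Index, cudf.RangeIndex, and cudf.MultiIndex classes in place of the old single cudf.BaseIndex entry. A toy sketch of that dispatch pattern, using an invented FakeIndex class so it runs without cudf:

from dask.utils import Dispatch

is_device_object_demo = Dispatch(name="is_device_object_demo")


@is_device_object_demo.register(object)
def _default(obj):
    # Fall-back handler: anything unregistered is not a device object.
    return False


class FakeIndex:
    """Hypothetical stand-in for a device-backed index type."""


@is_device_object_demo.register(FakeIndex)
def _fake_index(obj):
    return True


print(is_device_object_demo(42))           # False
print(is_device_object_demo(FakeIndex()))  # True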