cudf-polars-cu13 25.12.0__py3-none-any.whl → 26.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47):
  1. cudf_polars/GIT_COMMIT +1 -1
  2. cudf_polars/VERSION +1 -1
  3. cudf_polars/callback.py +28 -7
  4. cudf_polars/containers/column.py +51 -26
  5. cudf_polars/dsl/expressions/binaryop.py +1 -1
  6. cudf_polars/dsl/expressions/boolean.py +1 -1
  7. cudf_polars/dsl/expressions/selection.py +1 -1
  8. cudf_polars/dsl/expressions/string.py +29 -20
  9. cudf_polars/dsl/expressions/ternary.py +25 -1
  10. cudf_polars/dsl/expressions/unary.py +11 -8
  11. cudf_polars/dsl/ir.py +351 -281
  12. cudf_polars/dsl/translate.py +18 -15
  13. cudf_polars/dsl/utils/aggregations.py +10 -5
  14. cudf_polars/experimental/base.py +10 -0
  15. cudf_polars/experimental/benchmarks/pdsh.py +1 -1
  16. cudf_polars/experimental/benchmarks/utils.py +83 -2
  17. cudf_polars/experimental/distinct.py +2 -0
  18. cudf_polars/experimental/explain.py +1 -1
  19. cudf_polars/experimental/expressions.py +8 -5
  20. cudf_polars/experimental/groupby.py +2 -0
  21. cudf_polars/experimental/io.py +64 -42
  22. cudf_polars/experimental/join.py +15 -2
  23. cudf_polars/experimental/parallel.py +10 -7
  24. cudf_polars/experimental/rapidsmpf/collectives/__init__.py +9 -0
  25. cudf_polars/experimental/rapidsmpf/collectives/allgather.py +90 -0
  26. cudf_polars/experimental/rapidsmpf/collectives/common.py +96 -0
  27. cudf_polars/experimental/rapidsmpf/{shuffle.py → collectives/shuffle.py} +90 -114
  28. cudf_polars/experimental/rapidsmpf/core.py +194 -67
  29. cudf_polars/experimental/rapidsmpf/dask.py +172 -0
  30. cudf_polars/experimental/rapidsmpf/dispatch.py +6 -3
  31. cudf_polars/experimental/rapidsmpf/io.py +162 -70
  32. cudf_polars/experimental/rapidsmpf/join.py +162 -77
  33. cudf_polars/experimental/rapidsmpf/nodes.py +421 -180
  34. cudf_polars/experimental/rapidsmpf/repartition.py +130 -65
  35. cudf_polars/experimental/rapidsmpf/union.py +24 -5
  36. cudf_polars/experimental/rapidsmpf/utils.py +228 -16
  37. cudf_polars/experimental/shuffle.py +18 -4
  38. cudf_polars/experimental/sort.py +13 -6
  39. cudf_polars/experimental/spilling.py +1 -1
  40. cudf_polars/testing/plugin.py +6 -3
  41. cudf_polars/utils/config.py +67 -0
  42. cudf_polars/utils/versions.py +3 -3
  43. {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/METADATA +9 -10
  44. {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/RECORD +47 -43
  45. {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/WHEEL +1 -1
  46. {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/licenses/LICENSE +0 -0
  47. {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/top_level.txt +0 -0
@@ -21,7 +21,7 @@ if TYPE_CHECKING:
21
21
  from cudf_polars.dsl.expr import NamedExpr
22
22
  from cudf_polars.dsl.ir import IR, IRExecutionContext
23
23
  from cudf_polars.experimental.parallel import LowerIRTransformer
24
- from cudf_polars.utils.config import ShuffleMethod
24
+ from cudf_polars.utils.config import ShuffleMethod, ShufflerInsertionMethod
25
25
 
26
26
 
27
27
  def _maybe_shuffle_frame(
@@ -30,6 +30,8 @@ def _maybe_shuffle_frame(
30
30
  partition_info: MutableMapping[IR, PartitionInfo],
31
31
  shuffle_method: ShuffleMethod,
32
32
  output_count: int,
33
+ *,
34
+ shuffler_insertion_method: ShufflerInsertionMethod,
33
35
  ) -> IR:
34
36
  # Shuffle `frame` if it isn't already shuffled.
35
37
  if (
@@ -44,6 +46,7 @@ def _maybe_shuffle_frame(
44
46
  frame.schema,
45
47
  on,
46
48
  shuffle_method,
49
+ shuffler_insertion_method,
47
50
  frame,
48
51
  )
49
52
  partition_info[frame] = PartitionInfo(
@@ -60,6 +63,8 @@ def _make_hash_join(
60
63
  left: IR,
61
64
  right: IR,
62
65
  shuffle_method: ShuffleMethod,
66
+ *,
67
+ shuffler_insertion_method: ShufflerInsertionMethod,
63
68
  ) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
64
69
  # Shuffle left and right dataframes (if necessary)
65
70
  new_left = _maybe_shuffle_frame(
@@ -68,6 +73,7 @@ def _make_hash_join(
68
73
  partition_info,
69
74
  shuffle_method,
70
75
  output_count,
76
+ shuffler_insertion_method=shuffler_insertion_method,
71
77
  )
72
78
  new_right = _maybe_shuffle_frame(
73
79
  right,
@@ -75,6 +81,7 @@ def _make_hash_join(
75
81
  partition_info,
76
82
  shuffle_method,
77
83
  output_count,
84
+ shuffler_insertion_method=shuffler_insertion_method,
78
85
  )
79
86
  if left != new_left or right != new_right:
80
87
  ir = ir.reconstruct([new_left, new_right])
@@ -144,7 +151,9 @@ def _make_bcast_join(
144
151
  left: IR,
145
152
  right: IR,
146
153
  shuffle_method: ShuffleMethod,
154
+ *,
147
155
  streaming_runtime: str,
156
+ shuffler_insertion_method: ShufflerInsertionMethod,
148
157
  ) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
149
158
  if ir.options[0] != "Inner":
150
159
  left_count = partition_info[left].count
@@ -171,6 +180,7 @@ def _make_bcast_join(
171
180
  partition_info,
172
181
  shuffle_method,
173
182
  right_count,
183
+ shuffler_insertion_method=shuffler_insertion_method,
174
184
  )
175
185
  else:
176
186
  left = _maybe_shuffle_frame(
@@ -179,6 +189,7 @@ def _make_bcast_join(
179
189
  partition_info,
180
190
  shuffle_method,
181
191
  left_count,
192
+ shuffler_insertion_method=shuffler_insertion_method,
182
193
  )
183
194
 
184
195
  new_node = ir.reconstruct([left, right])
@@ -290,7 +301,8 @@ def _(
290
301
  left,
291
302
  right,
292
303
  config_options.executor.shuffle_method,
293
- config_options.executor.runtime,
304
+ streaming_runtime=config_options.executor.runtime,
305
+ shuffler_insertion_method=config_options.executor.shuffler_insertion_method,
294
306
  )
295
307
  else:
296
308
  # Create a hash join
@@ -301,6 +313,7 @@ def _(
301
313
  left,
302
314
  right,
303
315
  config_options.executor.shuffle_method,
316
+ shuffler_insertion_method=config_options.executor.shuffler_insertion_method,
304
317
  )
305
318
 
306
319
 
@@ -45,6 +45,7 @@ if TYPE_CHECKING:
45
45
 
46
46
  import polars as pl
47
47
 
48
+ from cudf_polars.experimental.base import StatsCollector
48
49
  from cudf_polars.experimental.dispatch import LowerIRTransformer, State
49
50
  from cudf_polars.utils.config import ConfigOptions
50
51
 
@@ -61,7 +62,7 @@ def _(
61
62
 
62
63
  def lower_ir_graph(
63
64
  ir: IR, config_options: ConfigOptions
64
- ) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
65
+ ) -> tuple[IR, MutableMapping[IR, PartitionInfo], StatsCollector]:
65
66
  """
66
67
  Rewrite an IR graph and extract partitioning information.
67
68
 
@@ -74,9 +75,10 @@ def lower_ir_graph(
74
75
 
75
76
  Returns
76
77
  -------
77
- new_ir, partition_info
78
- The rewritten graph, and a mapping from unique nodes
79
- in the new graph to associated partitioning information.
78
+ new_ir, partition_info, stats
79
+ The rewritten graph, a mapping from unique nodes
80
+ in the new graph to associated partitioning information,
81
+ and the statistics collector.
80
82
 
81
83
  Notes
82
84
  -----
@@ -92,7 +94,7 @@ def lower_ir_graph(
92
94
  "stats": collect_statistics(ir, config_options),
93
95
  }
94
96
  mapper: LowerIRTransformer = CachingVisitor(lower_ir_node, state=state)
95
- return mapper(ir)
97
+ return *mapper(ir), state["stats"]
96
98
 
97
99
 
98
100
  def task_graph(
@@ -245,7 +247,8 @@ def evaluate_rapidsmpf(
245
247
  """
246
248
  from cudf_polars.experimental.rapidsmpf.core import evaluate_logical_plan
247
249
 
248
- return evaluate_logical_plan(ir, config_options)
250
+ result, _ = evaluate_logical_plan(ir, config_options, collect_metadata=False)
251
+ return result
249
252
 
250
253
 
251
254
  def evaluate_streaming(
@@ -277,7 +280,7 @@ def evaluate_streaming(
277
280
  return evaluate_rapidsmpf(ir, config_options)
278
281
  else:
279
282
  # Using the default task engine.
280
- ir, partition_info = lower_ir_graph(ir, config_options)
283
+ ir, partition_info, _ = lower_ir_graph(ir, config_options)
281
284
 
282
285
  graph, key = task_graph(ir, partition_info, config_options)
283
286
 
@@ -0,0 +1,9 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Collective operations for the RapidsMPF streaming runtime."""
4
+
5
+ from __future__ import annotations
6
+
7
+ from cudf_polars.experimental.rapidsmpf.collectives.common import ReserveOpIDs
8
+
9
+ __all__ = ["ReserveOpIDs"]
@@ -0,0 +1,90 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """AllGather logic for the RapidsMPF streaming runtime."""
4
+
5
+ from __future__ import annotations
6
+
7
+ import asyncio
8
+ from typing import TYPE_CHECKING
9
+
10
+ from rapidsmpf.integrations.cudf.partition import unpack_and_concat
11
+ from rapidsmpf.memory.packed_data import PackedData
12
+ from rapidsmpf.streaming.coll.allgather import AllGather
13
+
14
+ from pylibcudf.contiguous_split import pack
15
+
16
+ if TYPE_CHECKING:
17
+ from rapidsmpf.streaming.core.context import Context
18
+ from rapidsmpf.streaming.cudf.table_chunk import TableChunk
19
+
20
+ import pylibcudf as plc
21
+ from rmm.pylibrmm.stream import Stream
22
+
23
+
24
+ class AllGatherManager:
25
+ """
26
+ AllGather manager.
27
+
28
+ Parameters
29
+ ----------
30
+ context: Context
31
+ The streaming context.
32
+ op_id: int
33
+ Pre-allocated operation ID for this operation.
34
+ """
35
+
36
+ def __init__(self, context: Context, op_id: int):
37
+ self.context = context
38
+ self.allgather = AllGather(self.context, op_id)
39
+
40
+ def insert(self, sequence_number: int, chunk: TableChunk) -> None:
41
+ """
42
+ Insert a chunk into the AllGatherContext.
43
+
44
+ Parameters
45
+ ----------
46
+ sequence_number: int
47
+ The sequence number of the chunk to insert.
48
+ chunk: TableChunk
49
+ The table chunk to insert.
50
+ """
51
+ self.allgather.insert(
52
+ sequence_number,
53
+ PackedData.from_cudf_packed_columns(
54
+ pack(
55
+ chunk.table_view(),
56
+ chunk.stream,
57
+ ),
58
+ chunk.stream,
59
+ self.context.br(),
60
+ ),
61
+ )
62
+ del chunk
63
+
64
+ def insert_finished(self) -> None:
65
+ """Insert finished into the AllGatherManager."""
66
+ self.allgather.insert_finished()
67
+
68
+ async def extract_concatenated(
69
+ self, stream: Stream, *, ordered: bool = True
70
+ ) -> plc.Table:
71
+ """
72
+ Extract the concatenated result.
73
+
74
+ Parameters
75
+ ----------
76
+ stream: Stream
77
+ The stream to use for chunk extraction.
78
+ ordered: bool
79
+ Whether to extract the data in ordered or unordered fashion.
80
+
81
+ Returns
82
+ -------
83
+ The concatenated AllGather result.
84
+ """
85
+ return await asyncio.to_thread(
86
+ unpack_and_concat,
87
+ partitions=await self.allgather.extract_all(self.context, ordered=ordered),
88
+ stream=stream,
89
+ br=self.context.br(),
90
+ )
@@ -0,0 +1,96 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Common utilities for collective operations."""
4
+
5
+ from __future__ import annotations
6
+
7
+ import threading
8
+ from typing import TYPE_CHECKING, Literal
9
+
10
+ from rapidsmpf.shuffler import Shuffler
11
+
12
+ from cudf_polars.dsl.traversal import traversal
13
+ from cudf_polars.experimental.join import Join
14
+ from cudf_polars.experimental.repartition import Repartition
15
+ from cudf_polars.experimental.shuffle import Shuffle
16
+
17
+ if TYPE_CHECKING:
18
+ from types import TracebackType
19
+
20
+ from cudf_polars.dsl.ir import IR
21
+
22
+
23
+ # Set of available collective IDs
24
+ _collective_id_vacancy: set[int] = set(range(Shuffler.max_concurrent_shuffles))
25
+ _collective_id_vacancy_lock: threading.Lock = threading.Lock()
26
+
27
+
28
+ def _get_new_collective_id() -> int:
29
+ with _collective_id_vacancy_lock:
30
+ if not _collective_id_vacancy:
31
+ raise ValueError(
32
+ f"Cannot shuffle more than {Shuffler.max_concurrent_shuffles} "
33
+ "times in a single query."
34
+ )
35
+
36
+ return _collective_id_vacancy.pop()
37
+
38
+
39
+ def _release_collective_id(collective_id: int) -> None:
40
+ """Release a collective ID back to the vacancy set."""
41
+ with _collective_id_vacancy_lock:
42
+ _collective_id_vacancy.add(collective_id)
43
+
44
+
45
+ class ReserveOpIDs:
46
+ """
47
+ Context manager to reserve collective IDs for pipeline execution.
48
+
49
+ Parameters
50
+ ----------
51
+ ir : IR
52
+ The root IR node of the pipeline.
53
+
54
+ Notes
55
+ -----
56
+ This context manager:
57
+ 1. Identifies all Shuffle nodes in the IR
58
+ 2. Reserves collective IDs from the vacancy pool
59
+ 3. Creates a mapping from IR nodes to their reserved IDs
60
+ 4. Releases all IDs back to the pool on __exit__
61
+ """
62
+
63
+ def __init__(self, ir: IR):
64
+ # Find all collective IR nodes.
65
+ self.collective_nodes: list[IR] = [
66
+ node
67
+ for node in traversal([ir])
68
+ if isinstance(node, (Shuffle, Join, Repartition))
69
+ ]
70
+ self.collective_id_map: dict[IR, int] = {}
71
+
72
+ def __enter__(self) -> dict[IR, int]:
73
+ """
74
+ Reserve collective IDs and return the mapping.
75
+
76
+ Returns
77
+ -------
78
+ collective_id_map : dict[IR, int]
79
+ Mapping from IR nodes to their reserved collective IDs.
80
+ """
81
+ # Reserve IDs and map nodes directly to their IDs
82
+ for node in self.collective_nodes:
83
+ self.collective_id_map[node] = _get_new_collective_id()
84
+
85
+ return self.collective_id_map
86
+
87
+ def __exit__(
88
+ self,
89
+ exc_type: type | None,
90
+ exc_val: Exception | None,
91
+ exc_tb: TracebackType | None,
92
+ ) -> Literal[False]:
93
+ """Release all reserved collective IDs back to the vacancy pool."""
94
+ for collective_id in self.collective_id_map.values():
95
+ _release_collective_id(collective_id)
96
+ return False
@@ -1,20 +1,16 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  """Shuffle logic for the RapidsMPF streaming runtime."""
4
4
 
5
5
  from __future__ import annotations
6
6
 
7
- import threading
8
- from typing import TYPE_CHECKING, Any, Literal
7
+ from typing import TYPE_CHECKING, Any
9
8
 
10
- from rapidsmpf.communicator.single import new_communicator
11
- from rapidsmpf.config import Options, get_environment_variables
12
9
  from rapidsmpf.integrations.cudf.partition import (
13
10
  partition_and_pack as py_partition_and_pack,
14
11
  unpack_and_concat as py_unpack_and_concat,
15
12
  )
16
- from rapidsmpf.progress_thread import ProgressThread
17
- from rapidsmpf.shuffler import Shuffler
13
+ from rapidsmpf.streaming.coll.shuffler import ShufflerAsync
18
14
  from rapidsmpf.streaming.core.message import Message
19
15
  from rapidsmpf.streaming.core.node import define_py_node
20
16
  from rapidsmpf.streaming.cudf.table_chunk import TableChunk
@@ -24,12 +20,13 @@ from cudf_polars.experimental.rapidsmpf.dispatch import (
24
20
  generate_ir_sub_network,
25
21
  )
26
22
  from cudf_polars.experimental.rapidsmpf.nodes import shutdown_on_error
27
- from cudf_polars.experimental.rapidsmpf.utils import ChannelManager
23
+ from cudf_polars.experimental.rapidsmpf.utils import (
24
+ ChannelManager,
25
+ Metadata,
26
+ )
28
27
  from cudf_polars.experimental.shuffle import Shuffle
29
28
 
30
29
  if TYPE_CHECKING:
31
- from types import TracebackType
32
-
33
30
  from rapidsmpf.streaming.core.context import Context
34
31
 
35
32
  import pylibcudf as plc
@@ -40,36 +37,9 @@ if TYPE_CHECKING:
40
37
  from cudf_polars.experimental.rapidsmpf.utils import ChannelPair
41
38
 
42
39
 
43
- # TODO: This implementation only supports a single GPU for now.
44
- # Multi-GPU support will require a distinct GlobalShuffle
45
- # context manager, and updated _shuffle_id_vacancy logic.
46
-
47
-
48
- # Set of available shuffle IDs
49
- _shuffle_id_vacancy: set[int] = set(range(Shuffler.max_concurrent_shuffles))
50
- _shuffle_id_vacancy_lock: threading.Lock = threading.Lock()
51
-
52
-
53
- def _get_new_shuffle_id() -> int:
54
- with _shuffle_id_vacancy_lock:
55
- if not _shuffle_id_vacancy:
56
- raise ValueError(
57
- f"Cannot shuffle more than {Shuffler.max_concurrent_shuffles} "
58
- "times in a single query."
59
- )
60
-
61
- return _shuffle_id_vacancy.pop()
62
-
63
-
64
- def _release_shuffle_id(op_id: int) -> None:
65
- """Release a shuffle ID back to the vacancy set."""
66
- with _shuffle_id_vacancy_lock:
67
- _shuffle_id_vacancy.add(op_id)
68
-
69
-
70
- class LocalShuffle:
40
+ class ShuffleManager:
71
41
  """
72
- Local shuffle instance context manager.
42
+ ShufflerAsync manager.
73
43
 
74
44
  Parameters
75
45
  ----------
@@ -79,6 +49,8 @@ class LocalShuffle:
79
49
  The number of partitions to shuffle into.
80
50
  columns_to_hash: tuple[int, ...]
81
51
  The columns to hash.
52
+ collective_id: int
53
+ The collective ID.
82
54
  """
83
55
 
84
56
  def __init__(
@@ -86,43 +58,20 @@ class LocalShuffle:
86
58
  context: Context,
87
59
  num_partitions: int,
88
60
  columns_to_hash: tuple[int, ...],
61
+ collective_id: int,
89
62
  ):
90
63
  self.context = context
91
- self.br = context.br()
92
64
  self.num_partitions = num_partitions
93
65
  self.columns_to_hash = columns_to_hash
94
- self._insertion_finished = False
95
-
96
- def __enter__(self) -> LocalShuffle:
97
- """Enter the local shuffle instance context manager."""
98
- self.op_id = _get_new_shuffle_id()
99
- statistics = self.context.statistics()
100
- comm = new_communicator(Options(get_environment_variables()))
101
- progress_thread = ProgressThread(comm, statistics)
102
- self.shuffler = Shuffler(
103
- comm=comm,
104
- progress_thread=progress_thread,
105
- op_id=self.op_id,
106
- total_num_partitions=self.num_partitions,
107
- br=self.br,
108
- statistics=statistics,
66
+ self.shuffler = ShufflerAsync(
67
+ context,
68
+ collective_id,
69
+ num_partitions,
109
70
  )
110
- return self
111
-
112
- def __exit__(
113
- self,
114
- exc_type: type | None,
115
- exc_val: Exception | None,
116
- exc_tb: TracebackType | None,
117
- ) -> Literal[False]:
118
- """Exit the local shuffle instance context manager."""
119
- self.shuffler.shutdown()
120
- _release_shuffle_id(self.op_id)
121
- return False
122
71
 
123
72
  def insert_chunk(self, chunk: TableChunk) -> None:
124
73
  """
125
- Insert a chunk into the local shuffle instance.
74
+ Insert a chunk into the ShuffleContext.
126
75
 
127
76
  Parameters
128
77
  ----------
@@ -135,15 +84,19 @@ class LocalShuffle:
135
84
  columns_to_hash=self.columns_to_hash,
136
85
  num_partitions=self.num_partitions,
137
86
  stream=chunk.stream,
138
- br=self.br,
87
+ br=self.context.br(),
139
88
  )
140
89
 
141
90
  # Insert into shuffler
142
- self.shuffler.insert_chunks(partitioned_chunks)
91
+ self.shuffler.insert(partitioned_chunks)
92
+
93
+ async def insert_finished(self) -> None:
94
+ """Insert finished into the ShuffleManager."""
95
+ await self.shuffler.insert_finished(self.context)
143
96
 
144
- def extract_chunk(self, sequence_number: int, stream: Stream) -> plc.Table:
97
+ async def extract_chunk(self, sequence_number: int, stream: Stream) -> plc.Table:
145
98
  """
146
- Extract a chunk from the local shuffle instance.
99
+ Extract a chunk from the ShuffleManager.
147
100
 
148
101
  Parameters
149
102
  ----------
@@ -156,21 +109,18 @@ class LocalShuffle:
156
109
  -------
157
110
  The extracted table.
158
111
  """
159
- if not self._insertion_finished:
160
- self.shuffler.insert_finished(list(range(self.num_partitions)))
161
- self._insertion_finished = True
162
-
163
- self.shuffler.wait_on(sequence_number)
164
- partition_chunks = self.shuffler.extract(sequence_number)
112
+ partition_chunks = await self.shuffler.extract_async(
113
+ self.context, sequence_number
114
+ )
165
115
  return py_unpack_and_concat(
166
116
  partitions=partition_chunks,
167
117
  stream=stream,
168
- br=self.br,
118
+ br=self.context.br(),
169
119
  )
170
120
 
171
121
 
172
122
  @define_py_node()
173
- async def local_shuffle_node(
123
+ async def shuffle_node(
174
124
  context: Context,
175
125
  ir: Shuffle,
176
126
  ir_context: IRExecutionContext,
@@ -178,6 +128,7 @@ async def local_shuffle_node(
178
128
  ch_out: ChannelPair,
179
129
  columns_to_hash: tuple[int, ...],
180
130
  num_partitions: int,
131
+ collective_id: int,
181
132
  ) -> None:
182
133
  """
183
134
  Execute a local shuffle pipeline in a single node.
@@ -202,46 +153,68 @@ async def local_shuffle_node(
202
153
  Tuple of column indices to use for hashing.
203
154
  num_partitions
204
155
  Number of partitions to shuffle into.
156
+ collective_id
157
+ The collective ID.
205
158
  """
206
159
  async with shutdown_on_error(
207
160
  context, ch_in.metadata, ch_in.data, ch_out.metadata, ch_out.data
208
161
  ):
209
- # Create LocalShuffle context manager to handle shuffler lifecycle
210
- # TODO: Use ir_context to get the stream (not available yet)
211
- with LocalShuffle(context, num_partitions, columns_to_hash) as local_shuffle:
212
- # Process input chunks
213
- while True:
214
- msg = await ch_in.data.recv(context)
215
- if msg is None:
216
- break
217
-
218
- # Extract TableChunk from message
219
- chunk = TableChunk.from_message(msg).make_available_and_spill(
220
- context.br(), allow_overbooking=True
221
- )
162
+ # Receive and send updated metadata.
163
+ _ = await ch_in.recv_metadata(context)
164
+ column_names = list(ir.schema.keys())
165
+ partitioned_on = tuple(column_names[i] for i in columns_to_hash)
166
+ output_metadata = Metadata(
167
+ max(1, num_partitions // context.comm().nranks),
168
+ partitioned_on=partitioned_on,
169
+ )
170
+ await ch_out.send_metadata(context, output_metadata)
222
171
 
223
- # Get the table view and insert into shuffler
224
- local_shuffle.insert_chunk(chunk)
225
-
226
- # Extract shuffled partitions and send them out
227
- # LocalShuffle.extract_chunk handles insert_finished, wait, extract, and unpack
228
- stream = ir_context.get_cuda_stream()
229
- for partition_id in range(num_partitions):
230
- # Create a new TableChunk with the result
231
- output_chunk = TableChunk.from_pylibcudf_table(
232
- table=local_shuffle.extract_chunk(partition_id, stream),
233
- stream=stream,
234
- exclusive_view=True,
235
- )
172
+ # Create ShuffleManager instance
173
+ shuffle = ShuffleManager(
174
+ context, num_partitions, columns_to_hash, collective_id
175
+ )
236
176
 
237
- # Send the output chunk
238
- await ch_out.data.send(context, Message(partition_id, output_chunk))
177
+ # Process input chunks
178
+ while (msg := await ch_in.data.recv(context)) is not None:
179
+ # Extract TableChunk from message and insert into shuffler
180
+ shuffle.insert_chunk(
181
+ TableChunk.from_message(msg).make_available_and_spill(
182
+ context.br(), allow_overbooking=True
183
+ )
184
+ )
185
+ del msg
186
+
187
+ # Insert finished
188
+ await shuffle.insert_finished()
189
+
190
+ # Extract shuffled partitions and send them out
191
+ stream = ir_context.get_cuda_stream()
192
+ for partition_id in range(
193
+ # Round-robin partition assignment
194
+ context.comm().rank,
195
+ num_partitions,
196
+ context.comm().nranks,
197
+ ):
198
+ # Extract and send the output chunk
199
+ await ch_out.data.send(
200
+ context,
201
+ Message(
202
+ partition_id,
203
+ TableChunk.from_pylibcudf_table(
204
+ table=await shuffle.extract_chunk(partition_id, stream),
205
+ stream=stream,
206
+ exclusive_view=True,
207
+ ),
208
+ ),
209
+ )
239
210
 
240
211
  await ch_out.data.drain(context)
241
212
 
242
213
 
243
214
  @generate_ir_sub_network.register(Shuffle)
244
- def _(ir: Shuffle, rec: SubNetGenerator) -> tuple[list[Any], dict[IR, ChannelManager]]:
215
+ def _(
216
+ ir: Shuffle, rec: SubNetGenerator
217
+ ) -> tuple[dict[IR, list[Any]], dict[IR, ChannelManager]]:
245
218
  # Local shuffle operation.
246
219
 
247
220
  # Process children
@@ -257,13 +230,15 @@ def _(ir: Shuffle, rec: SubNetGenerator) -> tuple[list[Any], dict[IR, ChannelMan
257
230
  columns_to_hash = tuple(column_names.index(k.name) for k in keys)
258
231
  num_partitions = rec.state["partition_info"][ir].count
259
232
 
233
+ # Look up the reserved collective ID for this operation
234
+ collective_id = rec.state["collective_id_map"][ir]
235
+
260
236
  # Create output ChannelManager
261
237
  channels[ir] = ChannelManager(rec.state["context"])
262
238
 
263
- # Complete shuffle pipeline in a single node
264
- # LocalShuffle context manager handles shuffle ID lifecycle internally
265
- nodes.append(
266
- local_shuffle_node(
239
+ # Complete shuffle node
240
+ nodes[ir] = [
241
+ shuffle_node(
267
242
  context,
268
243
  ir,
269
244
  rec.state["ir_context"],
@@ -271,7 +246,8 @@ def _(ir: Shuffle, rec: SubNetGenerator) -> tuple[list[Any], dict[IR, ChannelMan
271
246
  ch_out=channels[ir].reserve_input_slot(),
272
247
  columns_to_hash=columns_to_hash,
273
248
  num_partitions=num_partitions,
249
+ collective_id=collective_id,
274
250
  )
275
- )
251
+ ]
276
252
 
277
253
  return nodes, channels