cudf-polars-cu13 25.10.0__py3-none-any.whl → 26.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. cudf_polars/GIT_COMMIT +1 -1
  2. cudf_polars/VERSION +1 -1
  3. cudf_polars/callback.py +60 -15
  4. cudf_polars/containers/column.py +137 -77
  5. cudf_polars/containers/dataframe.py +123 -34
  6. cudf_polars/containers/datatype.py +134 -13
  7. cudf_polars/dsl/expr.py +0 -2
  8. cudf_polars/dsl/expressions/aggregation.py +80 -28
  9. cudf_polars/dsl/expressions/binaryop.py +34 -14
  10. cudf_polars/dsl/expressions/boolean.py +110 -37
  11. cudf_polars/dsl/expressions/datetime.py +59 -30
  12. cudf_polars/dsl/expressions/literal.py +11 -5
  13. cudf_polars/dsl/expressions/rolling.py +460 -119
  14. cudf_polars/dsl/expressions/selection.py +9 -8
  15. cudf_polars/dsl/expressions/slicing.py +1 -1
  16. cudf_polars/dsl/expressions/string.py +256 -114
  17. cudf_polars/dsl/expressions/struct.py +19 -7
  18. cudf_polars/dsl/expressions/ternary.py +33 -3
  19. cudf_polars/dsl/expressions/unary.py +126 -64
  20. cudf_polars/dsl/ir.py +1053 -350
  21. cudf_polars/dsl/to_ast.py +30 -13
  22. cudf_polars/dsl/tracing.py +194 -0
  23. cudf_polars/dsl/translate.py +307 -107
  24. cudf_polars/dsl/utils/aggregations.py +43 -30
  25. cudf_polars/dsl/utils/reshape.py +14 -2
  26. cudf_polars/dsl/utils/rolling.py +12 -8
  27. cudf_polars/dsl/utils/windows.py +35 -20
  28. cudf_polars/experimental/base.py +55 -2
  29. cudf_polars/experimental/benchmarks/pdsds.py +12 -126
  30. cudf_polars/experimental/benchmarks/pdsh.py +792 -2
  31. cudf_polars/experimental/benchmarks/utils.py +596 -39
  32. cudf_polars/experimental/dask_registers.py +47 -20
  33. cudf_polars/experimental/dispatch.py +9 -3
  34. cudf_polars/experimental/distinct.py +2 -0
  35. cudf_polars/experimental/explain.py +15 -2
  36. cudf_polars/experimental/expressions.py +30 -15
  37. cudf_polars/experimental/groupby.py +25 -4
  38. cudf_polars/experimental/io.py +156 -124
  39. cudf_polars/experimental/join.py +53 -23
  40. cudf_polars/experimental/parallel.py +68 -19
  41. cudf_polars/experimental/rapidsmpf/__init__.py +8 -0
  42. cudf_polars/experimental/rapidsmpf/collectives/__init__.py +9 -0
  43. cudf_polars/experimental/rapidsmpf/collectives/allgather.py +90 -0
  44. cudf_polars/experimental/rapidsmpf/collectives/common.py +96 -0
  45. cudf_polars/experimental/rapidsmpf/collectives/shuffle.py +253 -0
  46. cudf_polars/experimental/rapidsmpf/core.py +488 -0
  47. cudf_polars/experimental/rapidsmpf/dask.py +172 -0
  48. cudf_polars/experimental/rapidsmpf/dispatch.py +153 -0
  49. cudf_polars/experimental/rapidsmpf/io.py +696 -0
  50. cudf_polars/experimental/rapidsmpf/join.py +322 -0
  51. cudf_polars/experimental/rapidsmpf/lower.py +74 -0
  52. cudf_polars/experimental/rapidsmpf/nodes.py +735 -0
  53. cudf_polars/experimental/rapidsmpf/repartition.py +216 -0
  54. cudf_polars/experimental/rapidsmpf/union.py +115 -0
  55. cudf_polars/experimental/rapidsmpf/utils.py +374 -0
  56. cudf_polars/experimental/repartition.py +9 -2
  57. cudf_polars/experimental/select.py +177 -14
  58. cudf_polars/experimental/shuffle.py +46 -12
  59. cudf_polars/experimental/sort.py +100 -26
  60. cudf_polars/experimental/spilling.py +1 -1
  61. cudf_polars/experimental/statistics.py +24 -5
  62. cudf_polars/experimental/utils.py +25 -7
  63. cudf_polars/testing/asserts.py +13 -8
  64. cudf_polars/testing/io.py +2 -1
  65. cudf_polars/testing/plugin.py +93 -17
  66. cudf_polars/typing/__init__.py +86 -32
  67. cudf_polars/utils/config.py +473 -58
  68. cudf_polars/utils/cuda_stream.py +70 -0
  69. cudf_polars/utils/versions.py +5 -4
  70. cudf_polars_cu13-26.2.0.dist-info/METADATA +181 -0
  71. cudf_polars_cu13-26.2.0.dist-info/RECORD +108 -0
  72. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/WHEEL +1 -1
  73. cudf_polars_cu13-25.10.0.dist-info/METADATA +0 -136
  74. cudf_polars_cu13-25.10.0.dist-info/RECORD +0 -92
  75. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/licenses/LICENSE +0 -0
  76. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/top_level.txt +0 -0
cudf_polars/experimental/parallel.py
@@ -22,6 +22,7 @@ from cudf_polars.dsl.ir import (
     Filter,
     HConcat,
     HStack,
+    IRExecutionContext,
     MapFunction,
     Projection,
     Slice,
@@ -42,7 +43,9 @@ if TYPE_CHECKING:
     from collections.abc import MutableMapping
     from typing import Any
 
-    from cudf_polars.containers import DataFrame
+    import polars as pl
+
+    from cudf_polars.experimental.base import StatsCollector
     from cudf_polars.experimental.dispatch import LowerIRTransformer, State
     from cudf_polars.utils.config import ConfigOptions
 
@@ -59,7 +62,7 @@ def _(
 
 def lower_ir_graph(
     ir: IR, config_options: ConfigOptions
-) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
+) -> tuple[IR, MutableMapping[IR, PartitionInfo], StatsCollector]:
     """
     Rewrite an IR graph and extract partitioning information.
 
@@ -72,9 +75,10 @@ lower_ir_graph
 
     Returns
     -------
-    new_ir, partition_info
-        The rewritten graph, and a mapping from unique nodes
-        in the new graph to associated partitioning information.
+    new_ir, partition_info, stats
+        The rewritten graph, a mapping from unique nodes
+        in the new graph to associated partitioning information,
+        and the statistics collector.
 
     Notes
     -----
@@ -90,7 +94,7 @@ lower_ir_graph
         "stats": collect_statistics(ir, config_options),
     }
     mapper: LowerIRTransformer = CachingVisitor(lower_ir_node, state=state)
-    return mapper(ir)
+    return *mapper(ir), state["stats"]
 
 
 def task_graph(
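Note: lower_ir_graph now returns the statistics collector alongside the rewritten graph and partitioning information. A minimal sketch of consuming the new three-element return at a call site (variable names are illustrative, not taken from the package):

    # Hypothetical call site for the updated lower_ir_graph signature.
    new_ir, partition_info, stats = lower_ir_graph(ir, config_options)
    # The task-based path below only needs the graph and partition info;
    # the StatsCollector is available for diagnostics or planning heuristics.
    graph, key = task_graph(new_ir, partition_info, config_options)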
@@ -110,6 +114,8 @@ def task_graph(
         associated partitioning information.
     config_options
         GPUEngine configuration options.
+    context
+        Runtime context for IR node execution.
 
     Returns
     -------
@@ -130,9 +136,13 @@
     --------
     generate_ir_tasks
     """
+    context = IRExecutionContext.from_config_options(config_options)
     graph = reduce(
         operator.or_,
-        (generate_ir_tasks(node, partition_info) for node in traversal([ir])),
+        (
+            generate_ir_tasks(node, partition_info, context=context)
+            for node in traversal([ir])
+        ),
     )
 
     key_name = get_key_name(ir)
@@ -140,7 +150,10 @@
 
     key: str | tuple[str, int]
     if partition_count > 1:
-        graph[key_name] = (_concat, *partition_info[ir].keys(ir))
+        graph[key_name] = (
+            partial(_concat, context=context),
+            *partition_info[ir].keys(ir),
+        )
         key = key_name
     else:
         key = (key_name, 0)
@@ -158,10 +171,10 @@ def get_scheduler(config_options: ConfigOptions) -> Any:
             "'in-memory' executor not supported in 'generate_ir_tasks'"
         )
 
-    scheduler = config_options.executor.scheduler
+    cluster = config_options.executor.cluster
 
     if (
-        scheduler == "distributed"
+        cluster == "distributed"
     ):  # pragma: no cover; block depends on executor type and Distributed cluster
         from distributed import get_client
 
@@ -171,12 +184,12 @@ def get_scheduler(config_options: ConfigOptions) -> Any:
         DaskRegisterManager.register_once()
         DaskRegisterManager.run_on_cluster(client)
         return client.get
-    elif scheduler == "synchronous":
+    elif cluster == "single":
         from cudf_polars.experimental.scheduler import synchronous_scheduler
 
         return synchronous_scheduler
     else:  # pragma: no cover
-        raise ValueError(f"{scheduler} not a supported scheduler option.")
+        raise ValueError(f"{cluster} not a supported cluster option.")
 
 
 def post_process_task_graph(
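Note: the streaming executor's scheduler selection is now keyed on a "cluster" option ("distributed" or "single") instead of the old "scheduler" option ("distributed" or "synchronous"). A hedged configuration sketch, assuming the option is still passed through GPUEngine executor_options (the exact key name is an assumption, not confirmed by this diff):

    import polars as pl

    # Hypothetical: run the streaming executor on a single process; use
    # "distributed" to dispatch through a Dask Distributed client instead.
    engine = pl.GPUEngine(
        executor="streaming",
        executor_options={"cluster": "single"},  # assumed key; previously "scheduler": "synchronous"
    )
    result = query.collect(engine=engine)  # query: a polars LazyFrame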
@@ -214,10 +227,34 @@ post_process_task_graph
     return graph
 
 
+def evaluate_rapidsmpf(
+    ir: IR,
+    config_options: ConfigOptions,
+) -> pl.DataFrame:  # pragma: no cover; rapidsmpf runtime not tested in CI yet
+    """
+    Evaluate with the RapidsMPF streaming runtime.
+
+    Parameters
+    ----------
+    ir
+        Logical plan to evaluate.
+    config_options
+        GPUEngine configuration options.
+
+    Returns
+    -------
+    A cudf-polars DataFrame object.
+    """
+    from cudf_polars.experimental.rapidsmpf.core import evaluate_logical_plan
+
+    result, _ = evaluate_logical_plan(ir, config_options, collect_metadata=False)
+    return result
+
+
 def evaluate_streaming(
     ir: IR,
     config_options: ConfigOptions,
-) -> DataFrame:
+) -> pl.DataFrame:
     """
     Evaluate an IR graph with partitioning.
 
@@ -235,16 +272,26 @@
     # Clear source info cache in case data was overwritten
     _clear_source_info_cache()
 
-    ir, partition_info = lower_ir_graph(ir, config_options)
+    assert config_options.executor.name == "streaming", "Executor must be streaming"
+    if (
+        config_options.executor.runtime == "rapidsmpf"
+    ):  # pragma: no cover; rapidsmpf runtime not tested in CI yet
+        # Using the RapidsMPF streaming runtime.
+        return evaluate_rapidsmpf(ir, config_options)
+    else:
+        # Using the default task engine.
+        ir, partition_info, _ = lower_ir_graph(ir, config_options)
 
-    graph, key = task_graph(ir, partition_info, config_options)
+        graph, key = task_graph(ir, partition_info, config_options)
 
-    return get_scheduler(config_options)(graph, key)
+        return get_scheduler(config_options)(graph, key).to_polars()
 
 
 @generate_ir_tasks.register(IR)
 def _(
-    ir: IR, partition_info: MutableMapping[IR, PartitionInfo]
+    ir: IR,
+    partition_info: MutableMapping[IR, PartitionInfo],
+    context: IRExecutionContext,
 ) -> MutableMapping[Any, Any]:
     # Generate pointwise (embarrassingly-parallel) tasks by default
     child_names = [get_key_name(c) for c in ir.children]
@@ -252,7 +299,7 @@ def _(
 
     return {
         key: (
-            ir.do_evaluate,
+            partial(ir.do_evaluate, context=context),
             *ir._non_child_args,
             *[
                 (child_name, 0 if bcast_child[j] else i)
@@ -292,7 +339,9 @@
 
 @generate_ir_tasks.register(Union)
 def _(
-    ir: Union, partition_info: MutableMapping[IR, PartitionInfo]
+    ir: Union,
+    partition_info: MutableMapping[IR, PartitionInfo],
+    context: IRExecutionContext,
 ) -> MutableMapping[Any, Any]:
     key_name = get_key_name(ir)
     partition = itertools.count()
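Note: evaluate_streaming now dispatches on config_options.executor.runtime: the RapidsMPF streaming runtime hands the logical plan to evaluate_logical_plan, while the default path lowers the IR, builds a task graph, and converts the result with .to_polars(). A hedged sketch of opting into the RapidsMPF runtime, assuming it is selected through executor_options (the key name is an assumption):

    import polars as pl

    # Hypothetical: request the streaming executor with the RapidsMPF runtime;
    # any other runtime value falls back to the task-graph path shown above.
    engine = pl.GPUEngine(
        executor="streaming",
        executor_options={"runtime": "rapidsmpf"},  # assumed key, mirrors executor.runtime
    )
    result = query.collect(engine=engine)  # query: a polars LazyFrame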
cudf_polars/experimental/rapidsmpf/__init__.py (new file)
@@ -0,0 +1,8 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
+"""RapidsMPF streaming-engine support."""
+
+from __future__ import annotations
+
+__all__: list[str] = []
cudf_polars/experimental/rapidsmpf/collectives/__init__.py (new file)
@@ -0,0 +1,9 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+"""Collective operations for the RapidsMPF streaming runtime."""
+
+from __future__ import annotations
+
+from cudf_polars.experimental.rapidsmpf.collectives.common import ReserveOpIDs
+
+__all__ = ["ReserveOpIDs"]
cudf_polars/experimental/rapidsmpf/collectives/allgather.py (new file)
@@ -0,0 +1,90 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+"""AllGather logic for the RapidsMPF streaming runtime."""
+
+from __future__ import annotations
+
+import asyncio
+from typing import TYPE_CHECKING
+
+from rapidsmpf.integrations.cudf.partition import unpack_and_concat
+from rapidsmpf.memory.packed_data import PackedData
+from rapidsmpf.streaming.coll.allgather import AllGather
+
+from pylibcudf.contiguous_split import pack
+
+if TYPE_CHECKING:
+    from rapidsmpf.streaming.core.context import Context
+    from rapidsmpf.streaming.cudf.table_chunk import TableChunk
+
+    import pylibcudf as plc
+    from rmm.pylibrmm.stream import Stream
+
+
+class AllGatherManager:
+    """
+    AllGather manager.
+
+    Parameters
+    ----------
+    context: Context
+        The streaming context.
+    op_id: int
+        Pre-allocated operation ID for this operation.
+    """
+
+    def __init__(self, context: Context, op_id: int):
+        self.context = context
+        self.allgather = AllGather(self.context, op_id)
+
+    def insert(self, sequence_number: int, chunk: TableChunk) -> None:
+        """
+        Insert a chunk into the AllGatherContext.
+
+        Parameters
+        ----------
+        sequence_number: int
+            The sequence number of the chunk to insert.
+        chunk: TableChunk
+            The table chunk to insert.
+        """
+        self.allgather.insert(
+            sequence_number,
+            PackedData.from_cudf_packed_columns(
+                pack(
+                    chunk.table_view(),
+                    chunk.stream,
+                ),
+                chunk.stream,
+                self.context.br(),
+            ),
+        )
+        del chunk
+
+    def insert_finished(self) -> None:
+        """Insert finished into the AllGatherManager."""
+        self.allgather.insert_finished()
+
+    async def extract_concatenated(
+        self, stream: Stream, *, ordered: bool = True
+    ) -> plc.Table:
+        """
+        Extract the concatenated result.
+
+        Parameters
+        ----------
+        stream: Stream
+            The stream to use for chunk extraction.
+        ordered: bool
+            Whether to extract the data in ordered or unordered fashion.
+
+        Returns
+        -------
+        The concatenated AllGather result.
+        """
+        return await asyncio.to_thread(
+            unpack_and_concat,
+            partitions=await self.allgather.extract_all(self.context, ordered=ordered),
+            stream=stream,
+            br=self.context.br(),
+        )
cudf_polars/experimental/rapidsmpf/collectives/common.py (new file)
@@ -0,0 +1,96 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+"""Common utilities for collective operations."""
+
+from __future__ import annotations
+
+import threading
+from typing import TYPE_CHECKING, Literal
+
+from rapidsmpf.shuffler import Shuffler
+
+from cudf_polars.dsl.traversal import traversal
+from cudf_polars.experimental.join import Join
+from cudf_polars.experimental.repartition import Repartition
+from cudf_polars.experimental.shuffle import Shuffle
+
+if TYPE_CHECKING:
+    from types import TracebackType
+
+    from cudf_polars.dsl.ir import IR
+
+
+# Set of available collective IDs
+_collective_id_vacancy: set[int] = set(range(Shuffler.max_concurrent_shuffles))
+_collective_id_vacancy_lock: threading.Lock = threading.Lock()
+
+
+def _get_new_collective_id() -> int:
+    with _collective_id_vacancy_lock:
+        if not _collective_id_vacancy:
+            raise ValueError(
+                f"Cannot shuffle more than {Shuffler.max_concurrent_shuffles} "
+                "times in a single query."
+            )
+
+        return _collective_id_vacancy.pop()
+
+
+def _release_collective_id(collective_id: int) -> None:
+    """Release a collective ID back to the vacancy set."""
+    with _collective_id_vacancy_lock:
+        _collective_id_vacancy.add(collective_id)
+
+
+class ReserveOpIDs:
+    """
+    Context manager to reserve collective IDs for pipeline execution.
+
+    Parameters
+    ----------
+    ir : IR
+        The root IR node of the pipeline.
+
+    Notes
+    -----
+    This context manager:
+    1. Identifies all Shuffle nodes in the IR
+    2. Reserves collective IDs from the vacancy pool
+    3. Creates a mapping from IR nodes to their reserved IDs
+    4. Releases all IDs back to the pool on __exit__
+    """
+
+    def __init__(self, ir: IR):
+        # Find all collective IR nodes.
+        self.collective_nodes: list[IR] = [
+            node
+            for node in traversal([ir])
+            if isinstance(node, (Shuffle, Join, Repartition))
+        ]
+        self.collective_id_map: dict[IR, int] = {}
+
+    def __enter__(self) -> dict[IR, int]:
+        """
+        Reserve collective IDs and return the mapping.
+
+        Returns
+        -------
+        collective_id_map : dict[IR, int]
+            Mapping from IR nodes to their reserved collective IDs.
+        """
+        # Reserve IDs and map nodes directly to their IDs
+        for node in self.collective_nodes:
+            self.collective_id_map[node] = _get_new_collective_id()
+
+        return self.collective_id_map
+
+    def __exit__(
+        self,
+        exc_type: type | None,
+        exc_val: Exception | None,
+        exc_tb: TracebackType | None,
+    ) -> Literal[False]:
+        """Release all reserved collective IDs back to the vacancy pool."""
+        for collective_id in self.collective_id_map.values():
+            _release_collective_id(collective_id)
+        return False
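A short sketch of the intended use of ReserveOpIDs around pipeline construction; only the context manager and the returned mapping come from the code above, the surrounding steps are placeholders:

    # Hypothetical: reserve collective IDs for the lifetime of one query.
    with ReserveOpIDs(ir) as collective_id_map:
        # Every Shuffle/Join/Repartition node in `ir` now maps to a unique ID,
        # e.g. collective_id_map[shuffle_node_in_plan] -> int.
        ...  # build and execute the streaming network here
    # On exit, all reserved IDs return to the vacancy pool for the next query.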
cudf_polars/experimental/rapidsmpf/collectives/shuffle.py (new file)
@@ -0,0 +1,253 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+"""Shuffle logic for the RapidsMPF streaming runtime."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+from rapidsmpf.integrations.cudf.partition import (
+    partition_and_pack as py_partition_and_pack,
+    unpack_and_concat as py_unpack_and_concat,
+)
+from rapidsmpf.streaming.coll.shuffler import ShufflerAsync
+from rapidsmpf.streaming.core.message import Message
+from rapidsmpf.streaming.core.node import define_py_node
+from rapidsmpf.streaming.cudf.table_chunk import TableChunk
+
+from cudf_polars.dsl.expr import Col
+from cudf_polars.experimental.rapidsmpf.dispatch import (
+    generate_ir_sub_network,
+)
+from cudf_polars.experimental.rapidsmpf.nodes import shutdown_on_error
+from cudf_polars.experimental.rapidsmpf.utils import (
+    ChannelManager,
+    Metadata,
+)
+from cudf_polars.experimental.shuffle import Shuffle
+
+if TYPE_CHECKING:
+    from rapidsmpf.streaming.core.context import Context
+
+    import pylibcudf as plc
+    from rmm.pylibrmm.stream import Stream
+
+    from cudf_polars.dsl.ir import IR, IRExecutionContext
+    from cudf_polars.experimental.rapidsmpf.core import SubNetGenerator
+    from cudf_polars.experimental.rapidsmpf.utils import ChannelPair
+
+
+class ShuffleManager:
+    """
+    ShufflerAsync manager.
+
+    Parameters
+    ----------
+    context: Context
+        The streaming context.
+    num_partitions: int
+        The number of partitions to shuffle into.
+    columns_to_hash: tuple[int, ...]
+        The columns to hash.
+    collective_id: int
+        The collective ID.
+    """
+
+    def __init__(
+        self,
+        context: Context,
+        num_partitions: int,
+        columns_to_hash: tuple[int, ...],
+        collective_id: int,
+    ):
+        self.context = context
+        self.num_partitions = num_partitions
+        self.columns_to_hash = columns_to_hash
+        self.shuffler = ShufflerAsync(
+            context,
+            collective_id,
+            num_partitions,
+        )
+
+    def insert_chunk(self, chunk: TableChunk) -> None:
+        """
+        Insert a chunk into the ShuffleContext.
+
+        Parameters
+        ----------
+        chunk: TableChunk
+            The table chunk to insert.
+        """
+        # Partition and pack using the Python function
+        partitioned_chunks = py_partition_and_pack(
+            table=chunk.table_view(),
+            columns_to_hash=self.columns_to_hash,
+            num_partitions=self.num_partitions,
+            stream=chunk.stream,
+            br=self.context.br(),
+        )
+
+        # Insert into shuffler
+        self.shuffler.insert(partitioned_chunks)
+
+    async def insert_finished(self) -> None:
+        """Insert finished into the ShuffleManager."""
+        await self.shuffler.insert_finished(self.context)
+
+    async def extract_chunk(self, sequence_number: int, stream: Stream) -> plc.Table:
+        """
+        Extract a chunk from the ShuffleManager.
+
+        Parameters
+        ----------
+        sequence_number: int
+            The sequence number of the chunk to extract.
+        stream: Stream
+            The stream to use for chunk extraction.
+
+        Returns
+        -------
+        The extracted table.
+        """
+        partition_chunks = await self.shuffler.extract_async(
+            self.context, sequence_number
+        )
+        return py_unpack_and_concat(
+            partitions=partition_chunks,
+            stream=stream,
+            br=self.context.br(),
+        )
+
+
+@define_py_node()
+async def shuffle_node(
+    context: Context,
+    ir: Shuffle,
+    ir_context: IRExecutionContext,
+    ch_in: ChannelPair,
+    ch_out: ChannelPair,
+    columns_to_hash: tuple[int, ...],
+    num_partitions: int,
+    collective_id: int,
+) -> None:
+    """
+    Execute a local shuffle pipeline in a single node.
+
+    This node combines partition_and_pack, shuffler, and unpack_and_concat
+    into a single Python node using rapidsmpf.shuffler.Shuffler and utilities
+    from rapidsmpf.integrations.cudf.partition.
+
+    Parameters
+    ----------
+    context
+        The rapidsmpf context.
+    ir
+        The Shuffle IR node.
+    ir_context
+        The execution context for the IR node.
+    ch_in
+        Input ChannelPair with metadata and data channels.
+    ch_out
+        Output ChannelPair with metadata and data channels.
+    columns_to_hash
+        Tuple of column indices to use for hashing.
+    num_partitions
+        Number of partitions to shuffle into.
+    collective_id
+        The collective ID.
+    """
+    async with shutdown_on_error(
+        context, ch_in.metadata, ch_in.data, ch_out.metadata, ch_out.data
+    ):
+        # Receive and send updated metadata.
+        _ = await ch_in.recv_metadata(context)
+        column_names = list(ir.schema.keys())
+        partitioned_on = tuple(column_names[i] for i in columns_to_hash)
+        output_metadata = Metadata(
+            max(1, num_partitions // context.comm().nranks),
+            partitioned_on=partitioned_on,
+        )
+        await ch_out.send_metadata(context, output_metadata)
+
+        # Create ShuffleManager instance
+        shuffle = ShuffleManager(
+            context, num_partitions, columns_to_hash, collective_id
+        )
+
+        # Process input chunks
+        while (msg := await ch_in.data.recv(context)) is not None:
+            # Extract TableChunk from message and insert into shuffler
+            shuffle.insert_chunk(
+                TableChunk.from_message(msg).make_available_and_spill(
+                    context.br(), allow_overbooking=True
+                )
+            )
+            del msg
+
+        # Insert finished
+        await shuffle.insert_finished()
+
+        # Extract shuffled partitions and send them out
+        stream = ir_context.get_cuda_stream()
+        for partition_id in range(
+            # Round-robin partition assignment
+            context.comm().rank,
+            num_partitions,
+            context.comm().nranks,
+        ):
+            # Extract and send the output chunk
+            await ch_out.data.send(
+                context,
+                Message(
+                    partition_id,
+                    TableChunk.from_pylibcudf_table(
+                        table=await shuffle.extract_chunk(partition_id, stream),
+                        stream=stream,
+                        exclusive_view=True,
+                    ),
+                ),
+            )
+
+        await ch_out.data.drain(context)
+
+
+@generate_ir_sub_network.register(Shuffle)
+def _(
+    ir: Shuffle, rec: SubNetGenerator
+) -> tuple[dict[IR, list[Any]], dict[IR, ChannelManager]]:
+    # Local shuffle operation.
+
+    # Process children
+    (child,) = ir.children
+    nodes, channels = rec(child)
+
+    keys: list[Col] = [ne.value for ne in ir.keys if isinstance(ne.value, Col)]
+    if len(keys) != len(ir.keys):  # pragma: no cover
+        raise NotImplementedError("Shuffle requires simple keys.")
+    column_names = list(ir.schema.keys())
+
+    context = rec.state["context"]
+    columns_to_hash = tuple(column_names.index(k.name) for k in keys)
+    num_partitions = rec.state["partition_info"][ir].count
+
+    # Look up the reserved collective ID for this operation
+    collective_id = rec.state["collective_id_map"][ir]
+
+    # Create output ChannelManager
+    channels[ir] = ChannelManager(rec.state["context"])
+
+    # Complete shuffle node
+    nodes[ir] = [
+        shuffle_node(
+            context,
+            ir,
+            rec.state["ir_context"],
+            ch_in=channels[child].reserve_output_slot(),
+            ch_out=channels[ir].reserve_input_slot(),
+            columns_to_hash=columns_to_hash,
+            num_partitions=num_partitions,
+            collective_id=collective_id,
+        )
+    ]
+
+    return nodes, channels
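A quick worked example of the round-robin partition ownership used when extracting shuffled partitions above: rank r of N ranks owns partitions r, r + N, r + 2N, and so on. With 8 output partitions and 3 ranks, rank 0 emits partitions 0, 3, 6; rank 1 emits 1, 4, 7; rank 2 emits 2, 5 (an illustrative check, not code from the package):

    # Mirrors range(context.comm().rank, num_partitions, context.comm().nranks).
    num_partitions, nranks = 8, 3
    for rank in range(nranks):
        print(rank, list(range(rank, num_partitions, nranks)))
    # 0 [0, 3, 6]
    # 1 [1, 4, 7]
    # 2 [2, 5]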