cudf-polars-cu13 25.10.0__py3-none-any.whl → 26.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. cudf_polars/GIT_COMMIT +1 -1
  2. cudf_polars/VERSION +1 -1
  3. cudf_polars/callback.py +60 -15
  4. cudf_polars/containers/column.py +137 -77
  5. cudf_polars/containers/dataframe.py +123 -34
  6. cudf_polars/containers/datatype.py +134 -13
  7. cudf_polars/dsl/expr.py +0 -2
  8. cudf_polars/dsl/expressions/aggregation.py +80 -28
  9. cudf_polars/dsl/expressions/binaryop.py +34 -14
  10. cudf_polars/dsl/expressions/boolean.py +110 -37
  11. cudf_polars/dsl/expressions/datetime.py +59 -30
  12. cudf_polars/dsl/expressions/literal.py +11 -5
  13. cudf_polars/dsl/expressions/rolling.py +460 -119
  14. cudf_polars/dsl/expressions/selection.py +9 -8
  15. cudf_polars/dsl/expressions/slicing.py +1 -1
  16. cudf_polars/dsl/expressions/string.py +256 -114
  17. cudf_polars/dsl/expressions/struct.py +19 -7
  18. cudf_polars/dsl/expressions/ternary.py +33 -3
  19. cudf_polars/dsl/expressions/unary.py +126 -64
  20. cudf_polars/dsl/ir.py +1053 -350
  21. cudf_polars/dsl/to_ast.py +30 -13
  22. cudf_polars/dsl/tracing.py +194 -0
  23. cudf_polars/dsl/translate.py +307 -107
  24. cudf_polars/dsl/utils/aggregations.py +43 -30
  25. cudf_polars/dsl/utils/reshape.py +14 -2
  26. cudf_polars/dsl/utils/rolling.py +12 -8
  27. cudf_polars/dsl/utils/windows.py +35 -20
  28. cudf_polars/experimental/base.py +55 -2
  29. cudf_polars/experimental/benchmarks/pdsds.py +12 -126
  30. cudf_polars/experimental/benchmarks/pdsh.py +792 -2
  31. cudf_polars/experimental/benchmarks/utils.py +596 -39
  32. cudf_polars/experimental/dask_registers.py +47 -20
  33. cudf_polars/experimental/dispatch.py +9 -3
  34. cudf_polars/experimental/distinct.py +2 -0
  35. cudf_polars/experimental/explain.py +15 -2
  36. cudf_polars/experimental/expressions.py +30 -15
  37. cudf_polars/experimental/groupby.py +25 -4
  38. cudf_polars/experimental/io.py +156 -124
  39. cudf_polars/experimental/join.py +53 -23
  40. cudf_polars/experimental/parallel.py +68 -19
  41. cudf_polars/experimental/rapidsmpf/__init__.py +8 -0
  42. cudf_polars/experimental/rapidsmpf/collectives/__init__.py +9 -0
  43. cudf_polars/experimental/rapidsmpf/collectives/allgather.py +90 -0
  44. cudf_polars/experimental/rapidsmpf/collectives/common.py +96 -0
  45. cudf_polars/experimental/rapidsmpf/collectives/shuffle.py +253 -0
  46. cudf_polars/experimental/rapidsmpf/core.py +488 -0
  47. cudf_polars/experimental/rapidsmpf/dask.py +172 -0
  48. cudf_polars/experimental/rapidsmpf/dispatch.py +153 -0
  49. cudf_polars/experimental/rapidsmpf/io.py +696 -0
  50. cudf_polars/experimental/rapidsmpf/join.py +322 -0
  51. cudf_polars/experimental/rapidsmpf/lower.py +74 -0
  52. cudf_polars/experimental/rapidsmpf/nodes.py +735 -0
  53. cudf_polars/experimental/rapidsmpf/repartition.py +216 -0
  54. cudf_polars/experimental/rapidsmpf/union.py +115 -0
  55. cudf_polars/experimental/rapidsmpf/utils.py +374 -0
  56. cudf_polars/experimental/repartition.py +9 -2
  57. cudf_polars/experimental/select.py +177 -14
  58. cudf_polars/experimental/shuffle.py +46 -12
  59. cudf_polars/experimental/sort.py +100 -26
  60. cudf_polars/experimental/spilling.py +1 -1
  61. cudf_polars/experimental/statistics.py +24 -5
  62. cudf_polars/experimental/utils.py +25 -7
  63. cudf_polars/testing/asserts.py +13 -8
  64. cudf_polars/testing/io.py +2 -1
  65. cudf_polars/testing/plugin.py +93 -17
  66. cudf_polars/typing/__init__.py +86 -32
  67. cudf_polars/utils/config.py +473 -58
  68. cudf_polars/utils/cuda_stream.py +70 -0
  69. cudf_polars/utils/versions.py +5 -4
  70. cudf_polars_cu13-26.2.0.dist-info/METADATA +181 -0
  71. cudf_polars_cu13-26.2.0.dist-info/RECORD +108 -0
  72. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/WHEEL +1 -1
  73. cudf_polars_cu13-25.10.0.dist-info/METADATA +0 -136
  74. cudf_polars_cu13-25.10.0.dist-info/RECORD +0 -92
  75. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/licenses/LICENSE +0 -0
  76. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,216 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Re-chunking logic for the RapidsMPF streaming runtime."""
4
+
5
+ from __future__ import annotations
6
+
7
+ import math
8
+ from typing import TYPE_CHECKING, Any
9
+
10
+ from rapidsmpf.memory.buffer import MemoryType
11
+ from rapidsmpf.streaming.core.message import Message
12
+ from rapidsmpf.streaming.core.node import define_py_node
13
+ from rapidsmpf.streaming.cudf.table_chunk import TableChunk
14
+
15
+ from cudf_polars.containers import DataFrame
16
+ from cudf_polars.experimental.rapidsmpf.collectives.allgather import AllGatherManager
17
+ from cudf_polars.experimental.rapidsmpf.dispatch import generate_ir_sub_network
18
+ from cudf_polars.experimental.rapidsmpf.nodes import shutdown_on_error
19
+ from cudf_polars.experimental.rapidsmpf.utils import (
20
+ ChannelManager,
21
+ Metadata,
22
+ empty_table_chunk,
23
+ opaque_reservation,
24
+ )
25
+ from cudf_polars.experimental.repartition import Repartition
26
+ from cudf_polars.experimental.utils import _concat
27
+
28
+ if TYPE_CHECKING:
29
+ from rapidsmpf.streaming.core.context import Context
30
+
31
+ from cudf_polars.dsl.ir import IR, IRExecutionContext
32
+ from cudf_polars.experimental.rapidsmpf.dispatch import SubNetGenerator
33
+ from cudf_polars.experimental.rapidsmpf.utils import ChannelPair
34
+
35
+
36
@define_py_node()
async def concatenate_node(
    context: Context,
    ir: Repartition,
    ir_context: IRExecutionContext,
    ch_out: ChannelPair,
    ch_in: ChannelPair,
    *,
    output_count: int,
    collective_id: int,
) -> None:
    """
    Concatenate node for rapidsmpf.

    Reduces the number of chunks flowing through the stream, either
    locally (concatenating groups of chunks) or globally (allgathering
    all chunks across ranks into a single duplicated chunk).

    Parameters
    ----------
    context
        The rapidsmpf context.
    ir
        The Repartition IR node.
    ir_context
        The execution context for the IR node.
    ch_out
        The output ChannelPair.
    ch_in
        The input ChannelPair.
    output_count
        The expected number of output chunks.
    collective_id
        Pre-allocated collective ID for this operation.
    """
    async with shutdown_on_error(
        context, ch_in.metadata, ch_in.data, ch_out.metadata, ch_out.data
    ):
        # Receive metadata.
        input_metadata = await ch_in.recv_metadata(context)
        metadata = Metadata(output_count)

        # max_chunks corresponds to the number of chunks we can
        # concatenate together. If None, we must concatenate everything.
        # Since a single-partition operation gets "special treatment",
        # we must make sure `output_count == 1` is always satisfied.
        max_chunks: int | None = None
        if output_count > 1:
            # Make sure max_chunks is at least 2.
            max_chunks = max(2, math.ceil(input_metadata.count / output_count))

        # Check if we need global communication.
        need_global_repartition = (
            # Avoid allgather of already-duplicated data
            context.comm().nranks > 1
            and not input_metadata.duplicated
            and output_count == 1
        )

        chunks: list[TableChunk]
        msg: TableChunk | None
        if need_global_repartition:
            # Assume this means "global repartitioning" for now

            # Send metadata. The allgather result is present on every
            # rank, so the output is marked as duplicated.
            metadata.duplicated = True
            await ch_out.send_metadata(context, metadata)

            allgather = AllGatherManager(context, collective_id)
            stream = context.get_stream_from_pool()
            seq_num = 0
            # Feed every incoming chunk into the allgather; sequence
            # numbers are assigned locally in arrival order.
            while (msg := await ch_in.data.recv(context)) is not None:
                allgather.insert(seq_num, TableChunk.from_message(msg))
                seq_num += 1
                del msg
            allgather.insert_finished()

            # Extract concatenated result
            result_table = await allgather.extract_concatenated(stream)

            # If no chunks were gathered, result_table has 0 columns.
            # We need to create an empty table with the correct schema.
            if result_table.num_columns() == 0 and len(ir.schema) > 0:
                output_chunk = empty_table_chunk(ir, context, stream)
            else:
                output_chunk = TableChunk.from_pylibcudf_table(
                    result_table, stream, exclusive_view=True
                )

            # Global repartitioning always produces exactly one chunk.
            await ch_out.data.send(context, Message(0, output_chunk))
        else:
            # Send metadata. Duplication status is unchanged by a
            # purely local concatenation.
            metadata.duplicated = input_metadata.duplicated
            await ch_out.send_metadata(context, metadata)

            # Local repartitioning
            seq_num = 0
            while True:
                chunks = []
                done_receiving = False

                # Collect chunks up to max_chunks or until end of stream
                while len(chunks) < (max_chunks or float("inf")):
                    msg = await ch_in.data.recv(context)
                    if msg is None:
                        done_receiving = True
                        break
                    chunks.append(
                        TableChunk.from_message(msg).make_available_and_spill(
                            context.br(), allow_overbooking=True
                        )
                    )
                    del msg

                if chunks:
                    # Reserve (roughly) the device memory the concat
                    # output will need before running it.
                    input_bytes = sum(
                        chunk.data_alloc_size(MemoryType.DEVICE) for chunk in chunks
                    )
                    with opaque_reservation(context, input_bytes):
                        df = _concat(
                            *(
                                DataFrame.from_table(
                                    chunk.table_view(),
                                    list(ir.schema.keys()),
                                    list(ir.schema.values()),
                                    chunk.stream,
                                )
                                for chunk in chunks
                            ),
                            context=ir_context,
                        )
                    await ch_out.data.send(
                        context,
                        Message(
                            seq_num,
                            TableChunk.from_pylibcudf_table(
                                df.table, df.stream, exclusive_view=True
                            ),
                        ),
                    )
                    seq_num += 1
                    # Drop references promptly to keep peak memory down.
                    del df, chunks

                # Break if we reached end of stream
                if done_receiving:
                    break

        await ch_out.data.drain(context)
180
+
181
+
182
@generate_ir_sub_network.register(Repartition)
def _(
    ir: Repartition, rec: SubNetGenerator
) -> tuple[dict[IR, list[Any]], dict[IR, ChannelManager]]:
    # Build the streaming sub-network for a Repartition IR node.
    state = rec.state
    pinfo = state["partition_info"]
    child = ir.children[0]

    # We only support reducing (never increasing) the chunk count.
    out_count = pinfo[ir].count
    if out_count > 1 and pinfo[child].count < out_count:
        raise ValueError("Repartitioning to more chunks is not supported.")

    # Recurse into the child sub-network first.
    nodes, channels = rec(child)

    # Allocate the output channel slots for this node.
    ctx = state["context"]
    channels[ir] = ChannelManager(ctx)

    # Look up the reserved shuffle ID for this operation.
    collective_id = state["collective_id_map"][ir]

    # Wire up the concatenate node between child output and our output.
    nodes[ir] = [
        concatenate_node(
            ctx,
            ir,
            state["ir_context"],
            channels[ir].reserve_input_slot(),
            channels[child].reserve_output_slot(),
            output_count=out_count,
            collective_id=collective_id,
        )
    ]
    return nodes, channels
@@ -0,0 +1,115 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Union logic for the RapidsMPF streaming runtime."""
4
+
5
+ from __future__ import annotations
6
+
7
+ from typing import TYPE_CHECKING, Any
8
+
9
+ from rapidsmpf.streaming.core.message import Message
10
+ from rapidsmpf.streaming.cudf.table_chunk import TableChunk
11
+
12
+ from cudf_polars.dsl.ir import Union
13
+ from cudf_polars.experimental.rapidsmpf.dispatch import (
14
+ generate_ir_sub_network,
15
+ )
16
+ from cudf_polars.experimental.rapidsmpf.nodes import define_py_node, shutdown_on_error
17
+ from cudf_polars.experimental.rapidsmpf.utils import (
18
+ ChannelManager,
19
+ Metadata,
20
+ process_children,
21
+ )
22
+
23
+ if TYPE_CHECKING:
24
+ from rapidsmpf.streaming.core.context import Context
25
+
26
+ from cudf_polars.dsl.ir import IR, IRExecutionContext
27
+ from cudf_polars.experimental.rapidsmpf.core import SubNetGenerator
28
+ from cudf_polars.experimental.rapidsmpf.utils import ChannelPair
29
+
30
+
31
@define_py_node()
async def union_node(
    context: Context,
    ir: Union,
    ir_context: IRExecutionContext,
    ch_out: ChannelPair,
    *chs_in: ChannelPair,
) -> None:
    """
    Union node for rapidsmpf.

    Forwards all chunks from every input channel to the output channel,
    preserving channel order (all chunks of the first input come before
    any chunk of the second, and so on).

    Parameters
    ----------
    context
        The rapidsmpf context.
    ir
        The Union IR node.
    ir_context
        The execution context for the IR node.
    ch_out
        The output ChannelPair.
    chs_in
        The input ChannelPairs.
    """
    async with shutdown_on_error(
        context,
        *[ch.metadata for ch in chs_in],
        *[ch.data for ch in chs_in],
        ch_out.metadata,
        ch_out.data,
    ):
        # Merge and forward metadata. The union is "duplicated" only if
        # every input is duplicated; the chunk count is the sum.
        # NOTE(review): with zero inputs this reports duplicated=True and
        # count=0 — confirm Union always has at least one child.
        total_count = 0
        duplicated = True
        for ch_in in chs_in:
            metadata = await ch_in.recv_metadata(context)
            total_count += metadata.count
            duplicated = duplicated and metadata.duplicated
        await ch_out.send_metadata(
            context, Metadata(total_count, duplicated=duplicated)
        )

        # Forward data chunks, offsetting sequence numbers so chunks of
        # later inputs sort after chunks of earlier inputs.
        seq_num_offset = 0
        for ch_in in chs_in:
            num_ch_chunks = 0
            while (msg := await ch_in.data.recv(context)) is not None:
                num_ch_chunks += 1
                await ch_out.data.send(
                    context,
                    Message(
                        msg.sequence_number + seq_num_offset,
                        TableChunk.from_message(msg).make_available_and_spill(
                            context.br(), allow_overbooking=True
                        ),
                    ),
                )
            seq_num_offset += num_ch_chunks

        await ch_out.data.drain(context)
90
+
91
+
92
@generate_ir_sub_network.register(Union)
def _(
    ir: Union, rec: SubNetGenerator
) -> tuple[dict[IR, list[Any]], dict[IR, ChannelManager]]:
    # Union operation: pass through all child chunks in channel order.

    # Build the sub-networks of every child first.
    nodes, channels = process_children(ir, rec)

    # Allocate the output channel slots for this node.
    ctx = rec.state["context"]
    channels[ir] = ChannelManager(ctx)

    # One output slot per child feeds the union node.
    child_slots = [channels[c].reserve_output_slot() for c in ir.children]
    nodes[ir] = [
        union_node(
            ctx,
            ir,
            rec.state["ir_context"],
            channels[ir].reserve_input_slot(),
            *child_slots,
        )
    ]
    return nodes, channels
@@ -0,0 +1,374 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Utility functions and classes for the RapidsMPF streaming runtime."""
4
+
5
+ from __future__ import annotations
6
+
7
+ import asyncio
8
+ import operator
9
+ from contextlib import asynccontextmanager, contextmanager
10
+ from dataclasses import dataclass
11
+ from functools import reduce
12
+ from typing import TYPE_CHECKING, Any
13
+
14
+ from rapidsmpf.streaming.chunks.arbitrary import ArbitraryChunk
15
+ from rapidsmpf.streaming.core.message import Message
16
+ from rapidsmpf.streaming.cudf.table_chunk import TableChunk
17
+
18
+ import pylibcudf as plc
19
+
20
+ from cudf_polars.containers import DataFrame
21
+
22
+ if TYPE_CHECKING:
23
+ from collections.abc import AsyncIterator, Callable, Iterator
24
+
25
+ from rapidsmpf.memory.memory_reservation import MemoryReservation
26
+ from rapidsmpf.streaming.core.channel import Channel
27
+ from rapidsmpf.streaming.core.context import Context
28
+ from rapidsmpf.streaming.core.spillable_messages import SpillableMessages
29
+
30
+ from rmm.pylibrmm.stream import Stream
31
+
32
+ from cudf_polars.dsl.ir import IR
33
+ from cudf_polars.experimental.rapidsmpf.dispatch import SubNetGenerator
34
+
35
+
36
@asynccontextmanager
async def shutdown_on_error(
    context: Context, *channels: Channel[Any]
) -> AsyncIterator[None]:
    """
    Shut down the given channels when the wrapped body raises.

    Parameters
    ----------
    context
        The rapidsmpf context.
    channels
        The channels to shutdown.
    """
    # TODO: This probably belongs in rapidsmpf.
    try:
        yield
    except BaseException:
        # Shut every channel down concurrently, then re-raise the
        # original exception unchanged.
        pending = [ch.shutdown(context) for ch in channels]
        await asyncio.gather(*pending)
        raise
56
+
57
+
58
class Metadata:
    """Per-ChannelPair metadata payload sent ahead of the table data."""

    # count: estimated number of chunks that will flow on the data channel.
    # partitioned_on: column names the data is partitioned on ("" tuple if none).
    # duplicated: True when every worker holds a full copy of the data.
    __slots__ = ("count", "duplicated", "partitioned_on")
    count: int
    partitioned_on: tuple[str, ...]
    duplicated: bool

    def __init__(
        self,
        count: int,
        *,
        partitioned_on: tuple[str, ...] = (),
        duplicated: bool = False,
    ):
        self.duplicated = duplicated
        self.partitioned_on = partitioned_on
        self.count = count
79
+
80
+
81
@dataclass
class ChannelPair:
    """
    A pair of channels for metadata and table data.

    This abstraction ensures that metadata and data are kept separate,
    avoiding ordering issues and making the code more type-safe.

    Attributes
    ----------
    metadata :
        Channel for metadata.
    data :
        Channel for table data chunks.

    Notes
    -----
    This is a placeholder implementation. The metadata channel exists
    but is not used yet. Metadata handling will be fully implemented
    in follow-up work.
    """

    metadata: Channel[ArbitraryChunk]
    data: Channel[TableChunk]

    @classmethod
    def create(cls, context: Context) -> ChannelPair:
        """Create a new ChannelPair with fresh channels."""
        return cls(
            metadata=context.create_channel(),
            data=context.create_channel(),
        )

    async def send_metadata(self, ctx: Context, metadata: Metadata) -> None:
        """
        Send metadata and drain the metadata channel.

        Exactly one metadata message is sent per ChannelPair; the
        channel is drained immediately afterwards.

        Parameters
        ----------
        ctx :
            The streaming context.
        metadata :
            The metadata to send.
        """
        msg = Message(0, ArbitraryChunk(metadata))
        await self.metadata.send(ctx, msg)
        await self.metadata.drain(ctx)

    async def recv_metadata(self, ctx: Context) -> Metadata:
        """
        Receive metadata from the metadata channel.

        Parameters
        ----------
        ctx :
            The streaming context.

        Returns
        -------
        Metadata
            The received metadata. The upstream producer is expected to
            always send exactly one metadata message before draining, so
            a drained (None) result is treated as a programming error
            (assertion-checked).
        """
        msg = await self.metadata.recv(ctx)
        assert msg is not None, f"Expected Metadata message, got {msg}."
        return ArbitraryChunk.from_message(msg).release()
146
+
147
+
148
class ChannelManager:
    """A utility class for managing ChannelPair objects.

    A fixed number of ChannelPair slots is created up front; producers
    and consumers then claim slots via ``reserve_input_slot`` and
    ``reserve_output_slot``. Input and output reservations advance
    independent cursors over the *same* slot list, pairing producer i
    with consumer i.
    """

    def __init__(self, context: Context, *, count: int = 1):
        """
        Initialize the ChannelManager with a given number of ChannelPair slots.

        Parameters
        ----------
        context
            The rapidsmpf context.
        count: int
            The number of ChannelPair slots to allocate.
        """
        self._channel_slots = [ChannelPair.create(context) for _ in range(count)]
        self._reserved_output_slots: int = 0
        self._reserved_input_slots: int = 0

    def _next_slot(self, kind: str, used: int) -> ChannelPair:
        # Shared guard for both reservation directions: return the slot
        # at position `used`, or raise when all slots are taken.
        if used >= len(self._channel_slots):
            raise ValueError(f"No more {kind} channel-pair slots available")
        return self._channel_slots[used]

    def reserve_input_slot(self) -> ChannelPair:
        """
        Reserve an input channel-pair slot.

        Returns
        -------
        The reserved ChannelPair.

        Raises
        ------
        ValueError
            If every slot has already been reserved as an input.
        """
        pair = self._next_slot("input", self._reserved_input_slots)
        self._reserved_input_slots += 1
        return pair

    def reserve_output_slot(self) -> ChannelPair:
        """
        Reserve an output channel-pair slot.

        Returns
        -------
        The reserved ChannelPair.

        Raises
        ------
        ValueError
            If every slot has already been reserved as an output.
        """
        pair = self._next_slot("output", self._reserved_output_slots)
        self._reserved_output_slots += 1
        return pair
193
+
194
+
195
def process_children(
    ir: IR, rec: SubNetGenerator
) -> tuple[dict[IR, list[Any]], dict[IR, ChannelManager]]:
    """
    Process children IR nodes and aggregate their nodes and channels.

    Recursively generates the streaming sub-network of every child of
    ``ir`` and merges the resulting node and channel mappings into two
    flat dictionaries.

    Parameters
    ----------
    ir
        The IR node whose children should be processed.
    rec
        Recursive SubNetGenerator callable.

    Returns
    -------
    nodes
        Dictionary mapping each IR node to its list of streaming network nodes.
    channels
        Dictionary mapping each child IR node to its ChannelManager.
    """
    nodes: dict[IR, list[Any]] = {}
    channels: dict[IR, ChannelManager] = {}
    for child in ir.children:
        child_nodes, child_channels = rec(child)
        nodes.update(child_nodes)
        channels.update(child_channels)
    return nodes, channels
226
+
227
+
228
def empty_table_chunk(ir: IR, context: Context, stream: Stream) -> TableChunk:
    """
    Build a TableChunk with zero rows but the correct schema.

    Parameters
    ----------
    ir
        The IR node to use for the schema.
    context
        The rapidsmpf context.
    stream
        The stream to use for the table chunk.

    Returns
    -------
    The empty table chunk.
    """
    # dtype.plc_type carries the full pylibcudf DataType, so fixed-point
    # precision/scale is preserved for Decimal columns.
    columns = [
        plc.column_factories.make_empty_column(dtype.plc_type, stream=stream)
        for dtype in ir.schema.values()
    ]
    return TableChunk.from_pylibcudf_table(
        plc.Table(columns),
        stream,
        exclusive_view=True,
    )
258
+
259
+
260
def chunk_to_frame(chunk: TableChunk, ir: IR) -> DataFrame:
    """
    Wrap a TableChunk's table view in a cudf-polars DataFrame.

    Parameters
    ----------
    chunk
        The TableChunk to convert.
    ir
        The IR node to use for the schema.

    Returns
    -------
    A DataFrame.
    """
    names = list(ir.schema.keys())
    dtypes = list(ir.schema.values())
    return DataFrame.from_table(chunk.table_view(), names, dtypes, chunk.stream)
281
+
282
+
283
def make_spill_function(
    spillable_messages_list: list[SpillableMessages],
    context: Context,
) -> Callable[[int], int]:
    """
    Create a spill function for a list of SpillableMessages containers.

    The returned callable can be registered with a SpillManager. Its
    strategy prioritizes:

    1. Longest queues first (slow consumers that won't need data soon)
    2. Newest messages first (just arrived, won't be consumed soon)

    This keeps "hot" data (about to be consumed) in fast memory while
    spilling "cold" data (not needed for a while) to slower tiers.

    Parameters
    ----------
    spillable_messages_list
        List of SpillableMessages containers to create a spill function for.
    context
        The RapidsMPF context to use for accessing the BufferResource.

    Returns
    -------
    A spill function that takes an amount (in bytes) and returns the
    actual amount spilled (in bytes).

    Notes
    -----
    The strategy is particularly effective for fanout scenarios where
    consumers process messages at different rates: by spilling from the
    longest queues and the newest messages first, data stays in slower
    memory as long as possible before it is needed.
    """

    def spill_func(amount: int) -> int:
        """Spill queued messages until at least ``amount`` bytes are freed."""
        # Collect every spill candidate as (queue_length, message_id, container).
        candidates: list[tuple[int, int, SpillableMessages]] = []
        for sm in spillable_messages_list:
            descriptions = sm.get_content_descriptions()
            qlen = len(descriptions)
            candidates.extend((qlen, mid, sm) for mid in descriptions)

        # Longest queues first; within a queue, newest (highest id) first.
        candidates.sort(key=lambda item: (-item[0], -item[1]))

        freed = 0
        for _, mid, sm in candidates:
            if freed >= amount:
                break
            # Try to spill this message; spill() reports bytes released.
            freed += sm.spill(mid=mid, br=context.br())

        return freed

    return spill_func
351
+
352
+
353
@contextmanager
def opaque_reservation(
    context: Context,
    estimated_bytes: int,
) -> Iterator[MemoryReservation]:
    """
    Reserve device memory for opaque allocations.

    Parameters
    ----------
    context
        The RapidsMPF context.
    estimated_bytes
        The estimated number of bytes to reserve.

    Yields
    ------
    The memory reservation.
    """
    # Overbooking is allowed: the estimate is best-effort, not a hard cap.
    reservation = context.br().reserve_device_memory_and_spill(
        estimated_bytes, allow_overbooking=True
    )
    yield reservation