cudf-polars-cu13 25.12.0__py3-none-any.whl → 26.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. cudf_polars/GIT_COMMIT +1 -1
  2. cudf_polars/VERSION +1 -1
  3. cudf_polars/callback.py +28 -7
  4. cudf_polars/containers/column.py +51 -26
  5. cudf_polars/dsl/expressions/binaryop.py +1 -1
  6. cudf_polars/dsl/expressions/boolean.py +1 -1
  7. cudf_polars/dsl/expressions/selection.py +1 -1
  8. cudf_polars/dsl/expressions/string.py +29 -20
  9. cudf_polars/dsl/expressions/ternary.py +25 -1
  10. cudf_polars/dsl/expressions/unary.py +11 -8
  11. cudf_polars/dsl/ir.py +351 -281
  12. cudf_polars/dsl/translate.py +18 -15
  13. cudf_polars/dsl/utils/aggregations.py +10 -5
  14. cudf_polars/experimental/base.py +10 -0
  15. cudf_polars/experimental/benchmarks/pdsh.py +1 -1
  16. cudf_polars/experimental/benchmarks/utils.py +83 -2
  17. cudf_polars/experimental/distinct.py +2 -0
  18. cudf_polars/experimental/explain.py +1 -1
  19. cudf_polars/experimental/expressions.py +8 -5
  20. cudf_polars/experimental/groupby.py +2 -0
  21. cudf_polars/experimental/io.py +64 -42
  22. cudf_polars/experimental/join.py +15 -2
  23. cudf_polars/experimental/parallel.py +10 -7
  24. cudf_polars/experimental/rapidsmpf/collectives/__init__.py +9 -0
  25. cudf_polars/experimental/rapidsmpf/collectives/allgather.py +90 -0
  26. cudf_polars/experimental/rapidsmpf/collectives/common.py +96 -0
  27. cudf_polars/experimental/rapidsmpf/{shuffle.py → collectives/shuffle.py} +90 -114
  28. cudf_polars/experimental/rapidsmpf/core.py +194 -67
  29. cudf_polars/experimental/rapidsmpf/dask.py +172 -0
  30. cudf_polars/experimental/rapidsmpf/dispatch.py +6 -3
  31. cudf_polars/experimental/rapidsmpf/io.py +162 -70
  32. cudf_polars/experimental/rapidsmpf/join.py +162 -77
  33. cudf_polars/experimental/rapidsmpf/nodes.py +421 -180
  34. cudf_polars/experimental/rapidsmpf/repartition.py +130 -65
  35. cudf_polars/experimental/rapidsmpf/union.py +24 -5
  36. cudf_polars/experimental/rapidsmpf/utils.py +228 -16
  37. cudf_polars/experimental/shuffle.py +18 -4
  38. cudf_polars/experimental/sort.py +13 -6
  39. cudf_polars/experimental/spilling.py +1 -1
  40. cudf_polars/testing/plugin.py +6 -3
  41. cudf_polars/utils/config.py +67 -0
  42. cudf_polars/utils/versions.py +3 -3
  43. {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/METADATA +9 -10
  44. {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/RECORD +47 -43
  45. {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/WHEEL +1 -1
  46. {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/licenses/LICENSE +0 -0
  47. {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  """Re-chunking logic for the RapidsMPF streaming runtime."""
4
4
 
@@ -7,14 +7,21 @@ from __future__ import annotations
7
7
  import math
8
8
  from typing import TYPE_CHECKING, Any
9
9
 
10
+ from rapidsmpf.memory.buffer import MemoryType
10
11
  from rapidsmpf.streaming.core.message import Message
11
12
  from rapidsmpf.streaming.core.node import define_py_node
12
13
  from rapidsmpf.streaming.cudf.table_chunk import TableChunk
13
14
 
14
15
  from cudf_polars.containers import DataFrame
16
+ from cudf_polars.experimental.rapidsmpf.collectives.allgather import AllGatherManager
15
17
  from cudf_polars.experimental.rapidsmpf.dispatch import generate_ir_sub_network
16
18
  from cudf_polars.experimental.rapidsmpf.nodes import shutdown_on_error
17
- from cudf_polars.experimental.rapidsmpf.utils import ChannelManager
19
+ from cudf_polars.experimental.rapidsmpf.utils import (
20
+ ChannelManager,
21
+ Metadata,
22
+ empty_table_chunk,
23
+ opaque_reservation,
24
+ )
18
25
  from cudf_polars.experimental.repartition import Repartition
19
26
  from cudf_polars.experimental.utils import _concat
20
27
 
@@ -34,7 +41,8 @@ async def concatenate_node(
34
41
  ch_out: ChannelPair,
35
42
  ch_in: ChannelPair,
36
43
  *,
37
- max_chunks: int | None,
44
+ output_count: int,
45
+ collective_id: int,
38
46
  ) -> None:
39
47
  """
40
48
  Concatenate node for rapidsmpf.
@@ -51,66 +59,122 @@ async def concatenate_node(
51
59
  The output ChannelPair.
52
60
  ch_in
53
61
  The input ChannelPair.
54
- max_chunks
55
- The maximum number of chunks to concatenate at once.
56
- If `None`, concatenate all input chunks.
62
+ output_count
63
+ The expected number of output chunks.
64
+ collective_id
65
+ Pre-allocated collective ID for this operation.
57
66
  """
58
- # TODO: Use multiple streams
59
- max_chunks = max(2, max_chunks) if max_chunks else None
60
- async with shutdown_on_error(context, ch_in.data, ch_out.data):
61
- seq_num = 0
62
- while True:
63
- chunks: list[TableChunk] = []
64
- msg: TableChunk | None = None
65
-
66
- # Collect chunks up to max_chunks or until end of stream
67
- while len(chunks) < (max_chunks or float("inf")):
68
- msg = await ch_in.data.recv(context)
69
- if msg is None:
70
- break
71
- chunks.append(
72
- TableChunk.from_message(msg).make_available_and_spill(
73
- context.br(), allow_overbooking=True
74
- )
67
+ async with shutdown_on_error(
68
+ context, ch_in.metadata, ch_in.data, ch_out.metadata, ch_out.data
69
+ ):
70
+ # Receive metadata.
71
+ input_metadata = await ch_in.recv_metadata(context)
72
+ metadata = Metadata(output_count)
73
+
74
+ # max_chunks corresponds to the number of chunks we can
75
+ # concatenate together. If None, we must concatenate everything.
76
+ # Since a single-partition operation gets "special treatment",
77
+ # we must make sure `output_count == 1` is always satisfied.
78
+ max_chunks: int | None = None
79
+ if output_count > 1:
80
+ # Make sure max_chunks is at least 2.
81
+ max_chunks = max(2, math.ceil(input_metadata.count / output_count))
82
+
83
+ # Check if we need global communication.
84
+ need_global_repartition = (
85
+ # Avoid allgather of already-duplicated data
86
+ context.comm().nranks > 1
87
+ and not input_metadata.duplicated
88
+ and output_count == 1
89
+ )
90
+
91
+ chunks: list[TableChunk]
92
+ msg: TableChunk | None
93
+ if need_global_repartition:
94
+ # Assume this means "global repartitioning" for now
95
+
96
+ # Send metadata.
97
+ metadata.duplicated = True
98
+ await ch_out.send_metadata(context, metadata)
99
+
100
+ allgather = AllGatherManager(context, collective_id)
101
+ stream = context.get_stream_from_pool()
102
+ seq_num = 0
103
+ while (msg := await ch_in.data.recv(context)) is not None:
104
+ allgather.insert(seq_num, TableChunk.from_message(msg))
105
+ seq_num += 1
106
+ del msg
107
+ allgather.insert_finished()
108
+
109
+ # Extract concatenated result
110
+ result_table = await allgather.extract_concatenated(stream)
111
+
112
+ # If no chunks were gathered, result_table has 0 columns.
113
+ # We need to create an empty table with the correct schema.
114
+ if result_table.num_columns() == 0 and len(ir.schema) > 0:
115
+ output_chunk = empty_table_chunk(ir, context, stream)
116
+ else:
117
+ output_chunk = TableChunk.from_pylibcudf_table(
118
+ result_table, stream, exclusive_view=True
75
119
  )
76
120
 
77
- # Process collected chunks
78
- if chunks:
79
- df = (
80
- DataFrame.from_table(
81
- chunks[0].table_view(),
82
- list(ir.schema.keys()),
83
- list(ir.schema.values()),
84
- chunks[0].stream,
121
+ await ch_out.data.send(context, Message(0, output_chunk))
122
+ else:
123
+ # Send metadata.
124
+ metadata.duplicated = input_metadata.duplicated
125
+ await ch_out.send_metadata(context, metadata)
126
+
127
+ # Local repartitioning
128
+ seq_num = 0
129
+ while True:
130
+ chunks = []
131
+ done_receiving = False
132
+
133
+ # Collect chunks up to max_chunks or until end of stream
134
+ while len(chunks) < (max_chunks or float("inf")):
135
+ msg = await ch_in.data.recv(context)
136
+ if msg is None:
137
+ done_receiving = True
138
+ break
139
+ chunks.append(
140
+ TableChunk.from_message(msg).make_available_and_spill(
141
+ context.br(), allow_overbooking=True
142
+ )
85
143
  )
86
- if len(chunks) == 1
87
- else _concat(
88
- *(
89
- DataFrame.from_table(
90
- chunk.table_view(),
91
- list(ir.schema.keys()),
92
- list(ir.schema.values()),
93
- chunk.stream,
94
- )
95
- for chunk in chunks
96
- ),
97
- context=ir_context,
98
- )
99
- )
100
- await ch_out.data.send(
101
- context,
102
- Message(
103
- seq_num,
104
- TableChunk.from_pylibcudf_table(
105
- df.table, df.stream, exclusive_view=True
106
- ),
107
- ),
108
- )
109
- seq_num += 1
144
+ del msg
110
145
 
111
- # Break if we reached end of stream
112
- if msg is None:
113
- break
146
+ if chunks:
147
+ input_bytes = sum(
148
+ chunk.data_alloc_size(MemoryType.DEVICE) for chunk in chunks
149
+ )
150
+ with opaque_reservation(context, input_bytes):
151
+ df = _concat(
152
+ *(
153
+ DataFrame.from_table(
154
+ chunk.table_view(),
155
+ list(ir.schema.keys()),
156
+ list(ir.schema.values()),
157
+ chunk.stream,
158
+ )
159
+ for chunk in chunks
160
+ ),
161
+ context=ir_context,
162
+ )
163
+ await ch_out.data.send(
164
+ context,
165
+ Message(
166
+ seq_num,
167
+ TableChunk.from_pylibcudf_table(
168
+ df.table, df.stream, exclusive_view=True
169
+ ),
170
+ ),
171
+ )
172
+ seq_num += 1
173
+ del df, chunks
174
+
175
+ # Break if we reached end of stream
176
+ if done_receiving:
177
+ break
114
178
 
115
179
  await ch_out.data.drain(context)
116
180
 
@@ -118,18 +182,15 @@ async def concatenate_node(
118
182
  @generate_ir_sub_network.register(Repartition)
119
183
  def _(
120
184
  ir: Repartition, rec: SubNetGenerator
121
- ) -> tuple[list[Any], dict[IR, ChannelManager]]:
185
+ ) -> tuple[dict[IR, list[Any]], dict[IR, ChannelManager]]:
122
186
  # Repartition node.
123
187
 
124
188
  partition_info = rec.state["partition_info"]
125
- max_chunks: int | None = None
126
189
  if partition_info[ir].count > 1:
127
190
  count_output = partition_info[ir].count
128
191
  count_input = partition_info[ir.children[0]].count
129
192
  if count_input < count_output:
130
193
  raise ValueError("Repartitioning to more chunks is not supported.")
131
- # Make sure max_chunks is at least 2
132
- max_chunks = max(2, math.ceil(count_input / count_output))
133
194
 
134
195
  # Process children
135
196
  nodes, channels = rec(ir.children[0])
@@ -137,15 +198,19 @@ def _(
137
198
  # Create output ChannelManager
138
199
  channels[ir] = ChannelManager(rec.state["context"])
139
200
 
201
+ # Look up the reserved shuffle ID for this operation
202
+ collective_id = rec.state["collective_id_map"][ir]
203
+
140
204
  # Add python node
141
- nodes.append(
205
+ nodes[ir] = [
142
206
  concatenate_node(
143
207
  rec.state["context"],
144
208
  ir,
145
209
  rec.state["ir_context"],
146
210
  channels[ir].reserve_input_slot(),
147
211
  channels[ir.children[0]].reserve_output_slot(),
148
- max_chunks=max_chunks,
212
+ output_count=partition_info[ir].count,
213
+ collective_id=collective_id,
149
214
  )
150
- )
215
+ ]
151
216
  return nodes, channels
@@ -16,6 +16,7 @@ from cudf_polars.experimental.rapidsmpf.dispatch import (
16
16
  from cudf_polars.experimental.rapidsmpf.nodes import define_py_node, shutdown_on_error
17
17
  from cudf_polars.experimental.rapidsmpf.utils import (
18
18
  ChannelManager,
19
+ Metadata,
19
20
  process_children,
20
21
  )
21
22
 
@@ -51,8 +52,24 @@ async def union_node(
51
52
  chs_in
52
53
  The input ChannelPairs.
53
54
  """
54
- # TODO: Use multiple streams
55
- async with shutdown_on_error(context, *[ch.data for ch in chs_in], ch_out.data):
55
+ async with shutdown_on_error(
56
+ context,
57
+ *[ch.metadata for ch in chs_in],
58
+ *[ch.data for ch in chs_in],
59
+ ch_out.metadata,
60
+ ch_out.data,
61
+ ):
62
+ # Merge and forward metadata.
63
+ total_count = 0
64
+ duplicated = True
65
+ for ch_in in chs_in:
66
+ metadata = await ch_in.recv_metadata(context)
67
+ total_count += metadata.count
68
+ duplicated = duplicated and metadata.duplicated
69
+ await ch_out.send_metadata(
70
+ context, Metadata(total_count, duplicated=duplicated)
71
+ )
72
+
56
73
  seq_num_offset = 0
57
74
  for ch_in in chs_in:
58
75
  num_ch_chunks = 0
@@ -73,7 +90,9 @@ async def union_node(
73
90
 
74
91
 
75
92
  @generate_ir_sub_network.register(Union)
76
- def _(ir: Union, rec: SubNetGenerator) -> tuple[list[Any], dict[IR, ChannelManager]]:
93
+ def _(
94
+ ir: Union, rec: SubNetGenerator
95
+ ) -> tuple[dict[IR, list[Any]], dict[IR, ChannelManager]]:
77
96
  # Union operation.
78
97
  # Pass-through all child chunks in channel order.
79
98
 
@@ -84,7 +103,7 @@ def _(ir: Union, rec: SubNetGenerator) -> tuple[list[Any], dict[IR, ChannelManag
84
103
  channels[ir] = ChannelManager(rec.state["context"])
85
104
 
86
105
  # Add simple python node
87
- nodes.append(
106
+ nodes[ir] = [
88
107
  union_node(
89
108
  rec.state["context"],
90
109
  ir,
@@ -92,5 +111,5 @@ def _(ir: Union, rec: SubNetGenerator) -> tuple[list[Any], dict[IR, ChannelManag
92
111
  channels[ir].reserve_input_slot(),
93
112
  *[channels[c].reserve_output_slot() for c in ir.children],
94
113
  )
95
- )
114
+ ]
96
115
  return nodes, channels
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  """Utility functions and classes for the RapidsMPF streaming runtime."""
4
4
 
@@ -6,26 +6,33 @@ from __future__ import annotations
6
6
 
7
7
  import asyncio
8
8
  import operator
9
- from contextlib import asynccontextmanager
9
+ from contextlib import asynccontextmanager, contextmanager
10
10
  from dataclasses import dataclass
11
11
  from functools import reduce
12
- from typing import TYPE_CHECKING, Any, TypeAlias
12
+ from typing import TYPE_CHECKING, Any
13
+
14
+ from rapidsmpf.streaming.chunks.arbitrary import ArbitraryChunk
15
+ from rapidsmpf.streaming.core.message import Message
16
+ from rapidsmpf.streaming.cudf.table_chunk import TableChunk
17
+
18
+ import pylibcudf as plc
19
+
20
+ from cudf_polars.containers import DataFrame
13
21
 
14
22
  if TYPE_CHECKING:
15
- from collections.abc import AsyncIterator
23
+ from collections.abc import AsyncIterator, Callable, Iterator
16
24
 
25
+ from rapidsmpf.memory.memory_reservation import MemoryReservation
17
26
  from rapidsmpf.streaming.core.channel import Channel
18
27
  from rapidsmpf.streaming.core.context import Context
19
- from rapidsmpf.streaming.cudf.table_chunk import TableChunk
28
+ from rapidsmpf.streaming.core.spillable_messages import SpillableMessages
29
+
30
+ from rmm.pylibrmm.stream import Stream
20
31
 
21
32
  from cudf_polars.dsl.ir import IR
22
33
  from cudf_polars.experimental.rapidsmpf.dispatch import SubNetGenerator
23
34
 
24
35
 
25
- # Type alias for metadata payloads (placeholder - not used yet)
26
- MetadataPayload: TypeAlias = Any
27
-
28
-
29
36
  @asynccontextmanager
30
37
  async def shutdown_on_error(
31
38
  context: Context, *channels: Channel[Any]
@@ -48,6 +55,29 @@ async def shutdown_on_error(
48
55
  raise
49
56
 
50
57
 
58
+ class Metadata:
59
+ """Metadata payload for an individual ChannelPair."""
60
+
61
+ __slots__ = ("count", "duplicated", "partitioned_on")
62
+ count: int
63
+ """Chunk-count estimate."""
64
+ partitioned_on: tuple[str, ...]
65
+ """Partitioned-on columns."""
66
+ duplicated: bool
67
+ """Whether the data is duplicated on all workers."""
68
+
69
+ def __init__(
70
+ self,
71
+ count: int,
72
+ *,
73
+ partitioned_on: tuple[str, ...] = (),
74
+ duplicated: bool = False,
75
+ ):
76
+ self.count = count
77
+ self.partitioned_on = partitioned_on
78
+ self.duplicated = duplicated
79
+
80
+
51
81
  @dataclass
52
82
  class ChannelPair:
53
83
  """
@@ -70,7 +100,7 @@ class ChannelPair:
70
100
  in follow-up work.
71
101
  """
72
102
 
73
- metadata: Channel[MetadataPayload]
103
+ metadata: Channel[ArbitraryChunk]
74
104
  data: Channel[TableChunk]
75
105
 
76
106
  @classmethod
@@ -81,6 +111,39 @@ class ChannelPair:
81
111
  data=context.create_channel(),
82
112
  )
83
113
 
114
+ async def send_metadata(self, ctx: Context, metadata: Metadata) -> None:
115
+ """
116
+ Send metadata and drain the metadata channel.
117
+
118
+ Parameters
119
+ ----------
120
+ ctx :
121
+ The streaming context.
122
+ metadata :
123
+ The metadata to send.
124
+ """
125
+ msg = Message(0, ArbitraryChunk(metadata))
126
+ await self.metadata.send(ctx, msg)
127
+ await self.metadata.drain(ctx)
128
+
129
+ async def recv_metadata(self, ctx: Context) -> Metadata:
130
+ """
131
+ Receive metadata from the metadata channel.
132
+
133
+ Parameters
134
+ ----------
135
+ ctx :
136
+ The streaming context.
137
+
138
+ Returns
139
+ -------
140
+ ChunkMetadata
141
+ The metadata, or None if channel is drained.
142
+ """
143
+ msg = await self.metadata.recv(ctx)
144
+ assert msg is not None, f"Expected Metadata message, got {msg}."
145
+ return ArbitraryChunk.from_message(msg).release()
146
+
84
147
 
85
148
  class ChannelManager:
86
149
  """A utility class for managing ChannelPair objects."""
@@ -131,13 +194,13 @@ class ChannelManager:
131
194
 
132
195
  def process_children(
133
196
  ir: IR, rec: SubNetGenerator
134
- ) -> tuple[list[Any], dict[IR, ChannelManager]]:
197
+ ) -> tuple[dict[IR, list[Any]], dict[IR, ChannelManager]]:
135
198
  """
136
199
  Process children IR nodes and aggregate their nodes and channels.
137
200
 
138
201
  This helper function recursively processes all children of an IR node,
139
- collects their streaming network nodes into a flat list, and merges
140
- their channel dictionaries.
202
+ collects their streaming network nodes into a dictionary mapping IR nodes
203
+ to their associated nodes, and merges their channel dictionaries.
141
204
 
142
205
  Parameters
143
206
  ----------
@@ -149,14 +212,163 @@ def process_children(
149
212
  Returns
150
213
  -------
151
214
  nodes
152
- Flat list of all streaming network nodes from all children.
215
+ Dictionary mapping each IR node to its list of streaming network nodes.
153
216
  channels
154
217
  Dictionary mapping each child IR node to its ChannelManager.
155
218
  """
156
219
  if not ir.children:
157
- return [], {}
220
+ return {}, {}
158
221
 
159
222
  _nodes_list, _channels_list = zip(*(rec(c) for c in ir.children), strict=True)
160
- nodes: list[Any] = list(reduce(operator.add, _nodes_list, []))
223
+ nodes: dict[IR, list[Any]] = reduce(operator.or_, _nodes_list)
161
224
  channels: dict[IR, ChannelManager] = reduce(operator.or_, _channels_list)
162
225
  return nodes, channels
226
+
227
+
228
+ def empty_table_chunk(ir: IR, context: Context, stream: Stream) -> TableChunk:
229
+ """
230
+ Make an empty table chunk.
231
+
232
+ Parameters
233
+ ----------
234
+ ir
235
+ The IR node to use for the schema.
236
+ context
237
+ The rapidsmpf context.
238
+ stream
239
+ The stream to use for the table chunk.
240
+
241
+ Returns
242
+ -------
243
+ The empty table chunk.
244
+ """
245
+ # Create an empty table with the correct schema
246
+ # Use dtype.plc_type to get the full DataType (preserves precision/scale for Decimals)
247
+ empty_columns = [
248
+ plc.column_factories.make_empty_column(dtype.plc_type, stream=stream)
249
+ for dtype in ir.schema.values()
250
+ ]
251
+ empty_table = plc.Table(empty_columns)
252
+
253
+ return TableChunk.from_pylibcudf_table(
254
+ empty_table,
255
+ stream,
256
+ exclusive_view=True,
257
+ )
258
+
259
+
260
+ def chunk_to_frame(chunk: TableChunk, ir: IR) -> DataFrame:
261
+ """
262
+ Convert a TableChunk to a DataFrame.
263
+
264
+ Parameters
265
+ ----------
266
+ chunk
267
+ The TableChunk to convert.
268
+ ir
269
+ The IR node to use for the schema.
270
+
271
+ Returns
272
+ -------
273
+ A DataFrame.
274
+ """
275
+ return DataFrame.from_table(
276
+ chunk.table_view(),
277
+ list(ir.schema.keys()),
278
+ list(ir.schema.values()),
279
+ chunk.stream,
280
+ )
281
+
282
+
283
+ def make_spill_function(
284
+ spillable_messages_list: list[SpillableMessages],
285
+ context: Context,
286
+ ) -> Callable[[int], int]:
287
+ """
288
+ Create a spill function for a list of SpillableMessages containers.
289
+
290
+ This utility creates a spill function that can be registered with a
291
+ SpillManager. The spill function uses a smart spilling strategy that
292
+ prioritizes:
293
+ 1. Longest queues first (slow consumers that won't need data soon)
294
+ 2. Newest messages first (just arrived, won't be consumed soon)
295
+
296
+ This strategy keeps "hot" data (about to be consumed) in fast memory
297
+ while spilling "cold" data (won't be needed for a while) to slower tiers.
298
+
299
+ Parameters
300
+ ----------
301
+ spillable_messages_list
302
+ List of SpillableMessages containers to create a spill function for.
303
+ context
304
+ The RapidsMPF context to use for accessing the BufferResource.
305
+
306
+ Returns
307
+ -------
308
+ A spill function that takes an amount (in bytes) and returns the
309
+ actual amount spilled (in bytes).
310
+
311
+ Notes
312
+ -----
313
+ The spilling strategy is particularly effective for fanout scenarios
314
+ where different consumers may process messages at different rates. By
315
+ prioritizing longest queues and newest messages, we maximize the time
316
+ data can remain in slower memory before it's needed.
317
+ """
318
+
319
+ def spill_func(amount: int) -> int:
320
+ """Spill messages from the buffers to free device/host memory."""
321
+ spilled = 0
322
+
323
+ # Collect all messages with metadata for smart spilling
324
+ # Format: (message_id, container_idx, queue_length, sm)
325
+ all_messages: list[tuple[int, int, int, SpillableMessages]] = []
326
+ for container_idx, sm in enumerate(spillable_messages_list):
327
+ content_descriptions = sm.get_content_descriptions()
328
+ queue_length = len(content_descriptions)
329
+ all_messages.extend(
330
+ (message_id, container_idx, queue_length, sm)
331
+ for message_id in content_descriptions
332
+ )
333
+
334
+ # Spill newest messages first from the longest queues
335
+ # Sort by: (1) queue length descending, (2) message_id descending
336
+ # This prioritizes:
337
+ # - Longest queues (slow consumers that won't need data soon)
338
+ # - Newest messages (just arrived, won't be consumed soon)
339
+ all_messages.sort(key=lambda x: (-x[2], -x[0]))
340
+
341
+ # Spill messages until we've freed enough memory
342
+ for message_id, _, _, sm in all_messages:
343
+ if spilled >= amount:
344
+ break
345
+ # Try to spill this message
346
+ spilled += sm.spill(mid=message_id, br=context.br())
347
+
348
+ return spilled
349
+
350
+ return spill_func
351
+
352
+
353
+ @contextmanager
354
+ def opaque_reservation(
355
+ context: Context,
356
+ estimated_bytes: int,
357
+ ) -> Iterator[MemoryReservation]:
358
+ """
359
+ Reserve memory for opaque allocations.
360
+
361
+ Parameters
362
+ ----------
363
+ context
364
+ The RapidsMPF context.
365
+ estimated_bytes
366
+ The estimated number of bytes to reserve.
367
+
368
+ Yields
369
+ ------
370
+ The memory reservation.
371
+ """
372
+ yield context.br().reserve_device_memory_and_spill(
373
+ estimated_bytes, allow_overbooking=True
374
+ )