cudf-polars-cu13 25.12.0__py3-none-any.whl → 26.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. cudf_polars/GIT_COMMIT +1 -1
  2. cudf_polars/VERSION +1 -1
  3. cudf_polars/callback.py +28 -7
  4. cudf_polars/containers/column.py +51 -26
  5. cudf_polars/dsl/expressions/binaryop.py +1 -1
  6. cudf_polars/dsl/expressions/boolean.py +1 -1
  7. cudf_polars/dsl/expressions/selection.py +1 -1
  8. cudf_polars/dsl/expressions/string.py +29 -20
  9. cudf_polars/dsl/expressions/ternary.py +25 -1
  10. cudf_polars/dsl/expressions/unary.py +11 -8
  11. cudf_polars/dsl/ir.py +351 -281
  12. cudf_polars/dsl/translate.py +18 -15
  13. cudf_polars/dsl/utils/aggregations.py +10 -5
  14. cudf_polars/experimental/base.py +10 -0
  15. cudf_polars/experimental/benchmarks/pdsh.py +1 -1
  16. cudf_polars/experimental/benchmarks/utils.py +83 -2
  17. cudf_polars/experimental/distinct.py +2 -0
  18. cudf_polars/experimental/explain.py +1 -1
  19. cudf_polars/experimental/expressions.py +8 -5
  20. cudf_polars/experimental/groupby.py +2 -0
  21. cudf_polars/experimental/io.py +64 -42
  22. cudf_polars/experimental/join.py +15 -2
  23. cudf_polars/experimental/parallel.py +10 -7
  24. cudf_polars/experimental/rapidsmpf/collectives/__init__.py +9 -0
  25. cudf_polars/experimental/rapidsmpf/collectives/allgather.py +90 -0
  26. cudf_polars/experimental/rapidsmpf/collectives/common.py +96 -0
  27. cudf_polars/experimental/rapidsmpf/{shuffle.py → collectives/shuffle.py} +90 -114
  28. cudf_polars/experimental/rapidsmpf/core.py +194 -67
  29. cudf_polars/experimental/rapidsmpf/dask.py +172 -0
  30. cudf_polars/experimental/rapidsmpf/dispatch.py +6 -3
  31. cudf_polars/experimental/rapidsmpf/io.py +162 -70
  32. cudf_polars/experimental/rapidsmpf/join.py +162 -77
  33. cudf_polars/experimental/rapidsmpf/nodes.py +421 -180
  34. cudf_polars/experimental/rapidsmpf/repartition.py +130 -65
  35. cudf_polars/experimental/rapidsmpf/union.py +24 -5
  36. cudf_polars/experimental/rapidsmpf/utils.py +228 -16
  37. cudf_polars/experimental/shuffle.py +18 -4
  38. cudf_polars/experimental/sort.py +13 -6
  39. cudf_polars/experimental/spilling.py +1 -1
  40. cudf_polars/testing/plugin.py +6 -3
  41. cudf_polars/utils/config.py +67 -0
  42. cudf_polars/utils/versions.py +3 -3
  43. {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/METADATA +9 -10
  44. {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/RECORD +47 -43
  45. {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/WHEEL +1 -1
  46. {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/licenses/LICENSE +0 -0
  47. {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/top_level.txt +0 -0
cudf_polars/experimental/rapidsmpf/nodes.py
@@ -1,4 +1,4 @@
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
  # SPDX-License-Identifier: Apache-2.0
  """Core node definitions for the RapidsMPF streaming runtime."""
 
@@ -7,17 +7,23 @@ from __future__ import annotations
  import asyncio
  from typing import TYPE_CHECKING, Any, cast
 
+ from rapidsmpf.memory.buffer import MemoryType
  from rapidsmpf.streaming.core.message import Message
  from rapidsmpf.streaming.core.node import define_py_node
+ from rapidsmpf.streaming.core.spillable_messages import SpillableMessages
  from rapidsmpf.streaming.cudf.table_chunk import TableChunk
 
  from cudf_polars.containers import DataFrame
- from cudf_polars.dsl.ir import IR, Empty
+ from cudf_polars.dsl.ir import IR, Cache, Empty, Filter, Projection
  from cudf_polars.experimental.rapidsmpf.dispatch import (
      generate_ir_sub_network,
  )
  from cudf_polars.experimental.rapidsmpf.utils import (
      ChannelManager,
+     Metadata,
+     empty_table_chunk,
+     make_spill_function,
+     opaque_reservation,
      process_children,
      shutdown_on_error,
  )
@@ -37,6 +43,8 @@ async def default_node_single(
      ir_context: IRExecutionContext,
      ch_out: ChannelPair,
      ch_in: ChannelPair,
+     *,
+     preserve_partitioning: bool = False,
  ) -> None:
      """
      Single-channel default node for rapidsmpf.
@@ -53,32 +61,71 @@ async def default_node_single(
          The output ChannelPair.
      ch_in
          The input ChannelPair.
+     preserve_partitioning
+         Whether to preserve the partitioning metadata of the input chunks.
 
      Notes
      -----
      Chunks are processed in the order they are received.
      """
-     async with shutdown_on_error(context, ch_in.data, ch_out.data):
-         while (msg := await ch_in.data.recv(context)) is not None:
-             chunk = TableChunk.from_message(msg).make_available_and_spill(
-                 context.br(), allow_overbooking=True
-             )
-             seq_num = msg.sequence_number
-             df = await asyncio.to_thread(
-                 ir.do_evaluate,
-                 *ir._non_child_args,
-                 DataFrame.from_table(
-                     chunk.table_view(),
-                     list(ir.children[0].schema.keys()),
-                     list(ir.children[0].schema.values()),
-                     chunk.stream,
-                 ),
-                 context=ir_context,
-             )
-             chunk = TableChunk.from_pylibcudf_table(
-                 df.table, chunk.stream, exclusive_view=True
-             )
-             await ch_out.data.send(context, Message(seq_num, chunk))
+     async with shutdown_on_error(
+         context, ch_in.metadata, ch_in.data, ch_out.metadata, ch_out.data
+     ):
+         # Recv/send metadata.
+         metadata_in = await ch_in.recv_metadata(context)
+         metadata_out = Metadata(
+             metadata_in.count,
+             partitioned_on=metadata_in.partitioned_on if preserve_partitioning else (),
+             duplicated=metadata_in.duplicated,
+         )
+         await ch_out.send_metadata(context, metadata_out)
+
+         # Recv/send data.
+         seq_num = 0
+         receiving = True
+         received_any = False
+         while receiving:
+             msg = await ch_in.data.recv(context)
+             if msg is None:
+                 receiving = False
+                 if received_any:
+                     break
+                 else:
+                     # Make sure we have an empty chunk in case do_evaluate
+                     # always produces rows (e.g. aggregation)
+                     stream = ir_context.get_cuda_stream()
+                     chunk = empty_table_chunk(ir.children[0], context, stream)
+             else:
+                 received_any = True
+                 chunk = TableChunk.from_message(msg).make_available_and_spill(
+                     context.br(), allow_overbooking=True
+                 )
+                 seq_num = msg.sequence_number
+                 del msg
+
+             input_bytes = chunk.data_alloc_size(MemoryType.DEVICE)
+             with opaque_reservation(context, input_bytes):
+                 df = await asyncio.to_thread(
+                     ir.do_evaluate,
+                     *ir._non_child_args,
+                     DataFrame.from_table(
+                         chunk.table_view(),
+                         list(ir.children[0].schema.keys()),
+                         list(ir.children[0].schema.values()),
+                         chunk.stream,
+                     ),
+                     context=ir_context,
+                 )
+                 await ch_out.data.send(
+                     context,
+                     Message(
+                         seq_num,
+                         TableChunk.from_pylibcudf_table(
+                             df.table, chunk.stream, exclusive_view=True
+                         ),
+                     ),
+                 )
+             del df, chunk
 
          await ch_out.data.drain(context)
 
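[Reviewer note, not part of the diff] The metadata handshake introduced above relies on a Metadata container from cudf_polars/experimental/rapidsmpf/utils.py, whose definition is not shown in this excerpt. A minimal sketch consistent with how nodes.py uses it (field names come from the call sites; the defaults are assumptions, not confirmed):

    from dataclasses import dataclass
    from typing import Any

    @dataclass
    class Metadata:
        """Sketch inferred from usage in this diff, not the real class."""

        count: int  # expected number of data chunks on the channel
        partitioned_on: tuple[Any, ...] = ()  # keys the chunks are partitioned on
        # Assumed default: True acts as the identity for the AND-reduction in
        # default_node_multi below. The actual default is not visible here.
        duplicated: bool = True

    def forward_metadata(md_in: Metadata, *, preserve_partitioning: bool) -> Metadata:
        # Hypothetical helper restating the Metadata(...) construction in
        # default_node_single: partitioning info survives only for pointwise
        # IR nodes that never move rows between partitions.
        return Metadata(
            md_in.count,
            partitioned_on=md_in.partitioned_on if preserve_partitioning else (),
            duplicated=md_in.duplicated,
        )
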
@@ -90,6 +137,8 @@ async def default_node_multi(
      ir_context: IRExecutionContext,
      ch_out: ChannelPair,
      chs_in: tuple[ChannelPair, ...],
+     *,
+     partitioning_index: int | None = None,
  ) -> None:
      """
      Pointwise node for rapidsmpf.
@@ -103,17 +152,30 @@ async def default_node_multi(
      ir_context
          The execution context for the IR node.
      ch_out
-         The output ChannelPair (metadata already sent).
+         The output ChannelPair.
      chs_in
-         Tuple of input ChannelPairs (metadata already received).
-
-     Notes
-     -----
-     Input chunks must be aligned for evaluation. Messages from each input
-     channel are assumed to arrive in sequence number order, so we only need
-     to hold one chunk per channel at a time.
+         Tuple of input ChannelPairs.
+     partitioning_index
+         Index of the input channel to preserve partitioning information for.
+         If None, no partitioning information is preserved.
      """
-     async with shutdown_on_error(context, *[ch.data for ch in chs_in], ch_out.data):
+     async with shutdown_on_error(
+         context,
+         *[ch.metadata for ch in chs_in],
+         ch_out.metadata,
+         *[ch.data for ch in chs_in],
+         ch_out.data,
+     ):
+         # Merge and forward basic metadata.
+         metadata = Metadata(1)
+         for idx, ch_in in enumerate(chs_in):
+             md_child = await ch_in.recv_metadata(context)
+             metadata.count = max(md_child.count, metadata.count)
+             metadata.duplicated = metadata.duplicated and md_child.duplicated
+             if idx == partitioning_index:
+                 metadata.partitioned_on = md_child.partitioned_on
+         await ch_out.send_metadata(context, metadata)
+
          seq_num = 0
          n_children = len(chs_in)
          finished_channels: set[int] = set()
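[Reviewer note, not part of the diff] The merge loop above reduces child metadata with order-independent rules: counts combine by max (inputs are consumed pointwise, so the longest stream sets the output length), duplicated combines by logical AND (the output is replicated on every rank only if all inputs are), and partitioning keys are taken from at most one designated child. Extracted as a standalone function over the hypothetical Metadata sketch from the earlier note:

    def merge_child_metadata(
        children: list[Metadata], partitioning_index: int | None = None
    ) -> Metadata:
        # Same reduction as the loop in default_node_multi (sketch only).
        out = Metadata(1)
        for idx, md in enumerate(children):
            out.count = max(md.count, out.count)               # longest input wins
            out.duplicated = out.duplicated and md.duplicated  # AND across inputs
            if idx == partitioning_index:
                out.partitioned_on = md.partitioned_on         # keep one child's keys
        return out
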
@@ -122,11 +184,10 @@
          ready_chunks: list[TableChunk | None] = [None] * n_children
          chunk_count: list[int] = [0] * n_children
 
+         # Recv/send data.
          while True:
              # Receive from all non-finished channels
-             for ch_idx, (ch_in, _child) in enumerate(
-                 zip(chs_in, ir.children, strict=True)
-             ):
+             for ch_idx, ch_in in enumerate(chs_in):
                  if ch_idx in finished_channels:
                      continue  # This channel already finished, reuse its data
 
@@ -138,19 +199,20 @@
                  # Store the new chunk (replacing previous if any)
                  ready_chunks[ch_idx] = TableChunk.from_message(msg)
                  chunk_count[ch_idx] += 1
-                 assert ready_chunks[ch_idx] is not None, (
-                     f"Channel {ch_idx} has no data after receive loop."
-                 )
+                 del msg
 
              # If all channels finished, we're done
              if len(finished_channels) == n_children:
                  break
 
-             # Convert chunks to DataFrames right before evaluation
-             # All chunks are guaranteed to be non-None by the assertion above
-             assert all(chunk is not None for chunk in ready_chunks), (
-                 "All chunks must be non-None"
-             )
+             # Check if any channel drained without providing data.
+             # If so, create an empty chunk for that channel.
+             for ch_idx, child in enumerate(ir.children):
+                 if ready_chunks[ch_idx] is None:
+                     # Channel drained without data - create empty chunk
+                     stream = ir_context.get_cuda_stream()
+                     ready_chunks[ch_idx] = empty_table_chunk(child, context, stream)
+
              # Ensure all table chunks are unspilled and available.
              ready_chunks = [
                  chunk.make_available_and_spill(context.br(), allow_overbooking=True)
@@ -166,27 +228,33 @@
                  for chunk, child in zip(ready_chunks, ir.children, strict=True)
              ]
 
-             # Evaluate the IR node with current chunks
-             df = await asyncio.to_thread(
-                 ir.do_evaluate,
-                 *ir._non_child_args,
-                 *dfs,
-                 context=ir_context,
+             input_bytes = sum(
+                 chunk.data_alloc_size(MemoryType.DEVICE)
+                 for chunk in cast(list[TableChunk], ready_chunks)
              )
-             await ch_out.data.send(
-                 context,
-                 Message(
-                     seq_num,
-                     TableChunk.from_pylibcudf_table(
-                         df.table,
-                         df.stream,
-                         exclusive_view=True,
+             with opaque_reservation(context, input_bytes):
+                 df = await asyncio.to_thread(
+                     ir.do_evaluate,
+                     *ir._non_child_args,
+                     *dfs,
+                     context=ir_context,
+                 )
+                 await ch_out.data.send(
+                     context,
+                     Message(
+                         seq_num,
+                         TableChunk.from_pylibcudf_table(
+                             df.table,
+                             df.stream,
+                             exclusive_view=True,
+                         ),
                      ),
-                 ),
-             )
-             seq_num += 1
+                 )
+                 seq_num += 1
+             del df, dfs
 
          # Drain the output channel
+         del ready_chunks
          await ch_out.data.drain(context)
 
 
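[Reviewer note, not part of the diff] opaque_reservation is a new helper in cudf_polars/experimental/rapidsmpf/utils.py whose body is not shown here. From its call sites, it reserves device memory sized to the inputs before ir.do_evaluate runs, so rapidsmpf's buffer resource can account for (and spill around) libcudf allocations it cannot observe directly. A hypothetical sketch of such a context manager, built only on the reserve_device_memory_and_spill call that appears later in this diff:

    from contextlib import contextmanager

    @contextmanager
    def opaque_reservation(context, nbytes: int):
        # Hypothetical sketch, not the real implementation. Reserve roughly
        # the device memory an "opaque" evaluation may need, spilling other
        # buffers first if the device is full; overbooking keeps inputs
        # larger than free memory from failing outright.
        reservation = context.br().reserve_device_memory_and_spill(
            nbytes, allow_overbooking=True
        )
        try:
            yield reservation
        finally:
            # Dropping the reservation returns the headroom to the pool.
            del reservation
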
@@ -213,12 +281,23 @@ async def fanout_node_bounded(
      """
      # TODO: Use rapidsmpf fanout node once available.
      # See: https://github.com/rapidsai/rapidsmpf/issues/560
-     async with shutdown_on_error(context, ch_in.data, *[ch.data for ch in chs_out]):
+     async with shutdown_on_error(
+         context,
+         ch_in.metadata,
+         ch_in.data,
+         *[ch.metadata for ch in chs_out],
+         *[ch.data for ch in chs_out],
+     ):
+         # Forward metadata to all outputs.
+         metadata = await ch_in.recv_metadata(context)
+         await asyncio.gather(*(ch.send_metadata(context, metadata) for ch in chs_out))
+
          while (msg := await ch_in.data.recv(context)) is not None:
              table_chunk = TableChunk.from_message(msg).make_available_and_spill(
                  context.br(), allow_overbooking=True
              )
              seq_num = msg.sequence_number
+             del msg
              for ch_out in chs_out:
                  await ch_out.data.send(
                      context,
@@ -231,6 +310,7 @@
                          ),
                      ),
                  )
+             del table_chunk
 
          await asyncio.gather(*(ch.data.drain(context) for ch in chs_out))
 
@@ -242,14 +322,15 @@
      *chs_out: ChannelPair,
  ) -> None:
      """
-     Unbounded fanout node for rapidsmpf.
+     Unbounded fanout node for rapidsmpf with spilling support.
 
      Broadcasts chunks from input to all output channels. This is called
      "unbounded" because it handles the case where one channel may consume
      all data before another channel consumes any data.
 
-     The implementation uses adaptive sending:
-     - Maintains a FIFO buffer for each output channel
+     The implementation uses adaptive sending with spillable buffers:
+     - Maintains a spillable FIFO buffer for each output channel
+     - Messages are buffered in host memory (spillable to disk)
      - Sends to all channels concurrently
      - Receives next chunk as soon as any channel makes progress
      - Efficient for both balanced and imbalanced consumption patterns
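[Reviewer note, not part of the diff] The buffering this docstring describes pairs one SpillableMessages store per output with a plain list of message ids preserving FIFO order, since the store hands back an id per inserted message rather than maintaining order itself. A rough illustration using only the insert/extract calls visible in the next hunk (num_outputs assumed in scope):

    # Sketch of the per-output buffering used below (illustration only).
    buffers = [SpillableMessages() for _ in range(num_outputs)]
    fifo_ids: list[list[int]] = [[] for _ in range(num_outputs)]

    def enqueue(idx: int, msg) -> None:
        # The store owns the payload, which the spill manager may move to
        # host memory or disk; the id list keeps per-output FIFO order.
        fifo_ids[idx].append(buffers[idx].insert(msg))

    def dequeue(idx: int):
        # Extract removes the oldest message from the store for sending.
        return buffers[idx].extract(mid=fifo_ids[idx].pop(0))
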
@@ -265,107 +346,182 @@ async def fanout_node_unbounded(
      """
      # TODO: Use rapidsmpf fanout node once available.
      # See: https://github.com/rapidsai/rapidsmpf/issues/560
-     async with shutdown_on_error(context, ch_in.data, *[ch.data for ch in chs_out]):
-         # FIFO buffer for each output channel
-         output_buffers: list[list[Message]] = [[] for _ in chs_out]
-
-         # Track active send/drain tasks for each output
-         active_tasks: dict[int, asyncio.Task] = {}
-
-         # Track which outputs need to be drained (set when no more input)
-         needs_drain: set[int] = set()
-
-         # Receive task
-         recv_task: asyncio.Task | None = asyncio.create_task(ch_in.data.recv(context))
-
-         # Flag to indicate we should start a new receive (for backpressure)
-         can_receive: bool = True
-
-         async def send_one_from_buffer(idx: int) -> None:
-             """Send one buffered message for output idx."""
-             if output_buffers[idx]:
-                 msg = output_buffers[idx].pop(0)
-                 await chs_out[idx].data.send(context, msg)
-
-         async def drain_output(idx: int) -> None:
-             """Drain output channel idx."""
-             await chs_out[idx].data.drain(context)
-
-         # Main loop: coordinate receiving, sending, and draining
-         while (
-             recv_task is not None or active_tasks or any(output_buffers) or needs_drain
-         ):
-             # Collect all currently active tasks
-             tasks_to_wait = list(active_tasks.values())
-             # Only include recv_task if we're allowed to receive
-             if recv_task is not None and can_receive:
-                 tasks_to_wait.append(recv_task)
-
-             # Start new tasks for outputs with work to do
-             for idx in range(len(chs_out)):
-                 if idx not in active_tasks:
-                     if output_buffers[idx]:
-                         # Send next buffered message
-                         task = asyncio.create_task(send_one_from_buffer(idx))
-                         active_tasks[idx] = task
-                         tasks_to_wait.append(task)
-                     elif idx in needs_drain:
-                         # Buffer empty and no more input - drain this output
-                         task = asyncio.create_task(drain_output(idx))
-                         active_tasks[idx] = task
-                         tasks_to_wait.append(task)
-                         needs_drain.discard(idx)
-
-             # If nothing to wait for, we're done
-             if not tasks_to_wait:
-                 break
+     async with shutdown_on_error(
+         context,
+         ch_in.metadata,
+         ch_in.data,
+         *[ch.metadata for ch in chs_out],
+         *[ch.data for ch in chs_out],
+     ):
+         # Forward metadata to all outputs.
+         metadata = await ch_in.recv_metadata(context)
+         await asyncio.gather(*(ch.send_metadata(context, metadata) for ch in chs_out))
+
+         # Spillable FIFO buffer for each output channel
+         output_buffers: list[SpillableMessages] = [SpillableMessages() for _ in chs_out]
+         num_outputs = len(chs_out)
+
+         # Track message IDs in FIFO order for each output buffer
+         buffer_ids: list[list[int]] = [[] for _ in chs_out]
+
+         # Register a single spill function for all buffers
+         # This ensures global FIFO ordering when spilling across all outputs
+         spill_func_id = context.br().spill_manager.add_spill_function(
+             make_spill_function(output_buffers, context), priority=0
+         )
+
+         try:
+             # Track active send/drain tasks for each output
+             active_tasks: dict[int, asyncio.Task] = {}
+
+             # Track which outputs need to be drained (set when no more input)
+             needs_drain: set[int] = set()
 
-             # Wait for ANY task to complete
-             done, _ = await asyncio.wait(
-                 tasks_to_wait, return_when=asyncio.FIRST_COMPLETED
+             # Receive task
+             recv_task: asyncio.Task | None = asyncio.create_task(
+                 ch_in.data.recv(context)
              )
 
-             # Process completed tasks
-             for task in done:
-                 if task is recv_task:
-                     # Receive completed
-                     msg = task.result()
-                     if msg is None:
-                         # End of input - mark all outputs as needing drain
-                         recv_task = None
-                         needs_drain.update(range(len(chs_out)))
-                     else:
-                         # Add message to all output buffers
-                         chunk = TableChunk.from_message(msg).make_available_and_spill(
-                             context.br(), allow_overbooking=True
-                         )
-                         seq_num = msg.sequence_number
-                         for buffer in output_buffers:
-                             message = Message(
-                                 seq_num,
-                                 TableChunk.from_pylibcudf_table(
-                                     chunk.table_view(),
-                                     chunk.stream,
-                                     exclusive_view=False,
-                                 ),
+             # Flag to indicate we should start a new receive (for backpressure)
+             can_receive: bool = True
+
+             async def send_one_from_buffer(idx: int) -> None:
+                 """
+                 Send one buffered message for output idx.
+
+                 The message remains in host memory (spillable) through the channel.
+                 The downstream consumer will call make_available() when needed.
+                 """
+                 if buffer_ids[idx]:
+                     mid = buffer_ids[idx].pop(0)
+                     msg = output_buffers[idx].extract(mid=mid)
+                     await chs_out[idx].data.send(context, msg)
+
+             async def drain_output(idx: int) -> None:
+                 """Drain output channel idx."""
+                 await chs_out[idx].data.drain(context)
+
+             # Main loop: coordinate receiving, sending, and draining
+             while (
+                 recv_task is not None or active_tasks or any(buffer_ids) or needs_drain
+             ):
+                 # Collect all currently active tasks
+                 tasks_to_wait = list(active_tasks.values())
+                 # Only include recv_task if we're allowed to receive
+                 if recv_task is not None and can_receive:
+                     tasks_to_wait.append(recv_task)
+
+                 # Start new tasks for outputs with work to do
+                 for idx in range(len(chs_out)):
+                     if idx not in active_tasks:
+                         if buffer_ids[idx]:
+                             # Send next buffered message
+                             task = asyncio.create_task(send_one_from_buffer(idx))
+                             active_tasks[idx] = task
+                             tasks_to_wait.append(task)
+                         elif idx in needs_drain:
+                             # Buffer empty and no more input - drain this output
+                             task = asyncio.create_task(drain_output(idx))
+                             active_tasks[idx] = task
+                             tasks_to_wait.append(task)
+                             needs_drain.discard(idx)
+
+                 # If nothing to wait for, we're done
+                 if not tasks_to_wait:
+                     break
+
+                 # Wait for ANY task to complete
+                 done, _ = await asyncio.wait(
+                     tasks_to_wait, return_when=asyncio.FIRST_COMPLETED
+                 )
+
+                 # Process completed tasks
+                 for task in done:
+                     if task is recv_task:
+                         # Receive completed
+                         msg = task.result()
+                         if msg is None:
+                             # End of input - mark all outputs as needing drain
+                             recv_task = None
+                             needs_drain.update(range(len(chs_out)))
+                         else:
+                             # Determine where to copy based on:
+                             # 1. Current message location (avoid unnecessary transfers)
+                             # 2. Available memory (avoid OOM)
+                             content_desc = msg.get_content_description()
+                             device_size = content_desc.content_sizes.get(
+                                 MemoryType.DEVICE, 0
                              )
-                             buffer.append(message)
+                             copy_cost = msg.copy_cost()
 
-                         # Don't receive next chunk until at least one send completes
-                         can_receive = False
-                         recv_task = asyncio.create_task(ch_in.data.recv(context))
-                 else:
-                     # Must be a send or drain task - find which output and remove it
-                     for idx, at in list(active_tasks.items()):
-                         if at is task:
-                             del active_tasks[idx]
-                             # A send completed - allow receiving again
-                             can_receive = True
-                             break
+                             # Check if we have enough device memory for all copies
+                             # We need (num_outputs - 1) copies since last one reuses original
+                             num_copies = num_outputs - 1
+                             total_copy_cost = copy_cost * num_copies
+                             available_device_mem = context.br().memory_available(
+                                 MemoryType.DEVICE
+                             )
+
+                             # Decide target memory:
+                             # Use device ONLY if message is in device AND we have sufficient headroom.
+                             # TODO: Use further information about the downstream operations to make
+                             # a more informed decision.
+                             required_headroom = total_copy_cost * 2
+                             if (
+                                 device_size > 0
+                                 and available_device_mem >= required_headroom
+                             ):
+                                 # Use reserve_device_memory_and_spill to automatically trigger spilling
+                                 # if needed to make room for the copy
+                                 memory_reservation = (
+                                     context.br().reserve_device_memory_and_spill(
+                                         total_copy_cost,
+                                         allow_overbooking=True,
+                                     )
+                                 )
+                             else:
+                                 # Use host memory for buffering - much safer
+                                 # Downstream consumers will make_available() when they need device memory
+                                 memory_reservation, _ = context.br().reserve(
+                                     MemoryType.HOST,
+                                     total_copy_cost,
+                                     allow_overbooking=True,
+                                 )
+
+                             # Copy message for each output buffer
+                             # Copies are spillable and allow downstream consumers
+                             # to control device memory allocation
+                             for idx, sm in enumerate(output_buffers):
+                                 if idx < num_outputs - 1:
+                                     # Copy to target memory and insert into spillable buffer
+                                     mid = sm.insert(msg.copy(memory_reservation))
+                                 else:
+                                     # Optimization: reuse the original message for last output
+                                     # (no copy needed)
+                                     mid = sm.insert(msg)
+                                 buffer_ids[idx].append(mid)
+
+                             # Don't receive next chunk until at least one send completes
+                             can_receive = False
+                             recv_task = asyncio.create_task(ch_in.data.recv(context))
+                     else:
+                         # Must be a send or drain task - find which output and remove it
+                         for idx, at in list(active_tasks.items()):
+                             if at is task:
+                                 del active_tasks[idx]
+                                 # A send completed - allow receiving again
+                                 can_receive = True
+                                 break
+
+         finally:
+             # Clean up spill function registration
+             context.br().spill_manager.remove_spill_function(spill_func_id)
 
 
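[Reviewer note, not part of the diff] The target-memory decision above is a small standalone policy: with N outputs only N - 1 copies are needed (the last output reuses the original message), and the device path is taken only when the payload already lives on device and at least twice the total copy cost is free. A condensed restatement using only calls that appear in this hunk (MemoryType comes from the module imports):

    def choose_copy_target(msg, num_outputs: int, br) -> MemoryType:
        # Condensed sketch of the policy in fanout_node_unbounded.
        desc = msg.get_content_description()
        device_size = desc.content_sizes.get(MemoryType.DEVICE, 0)
        total_copy_cost = msg.copy_cost() * (num_outputs - 1)  # last output reuses msg
        # 2x headroom so concurrent nodes are not starved of device memory.
        if device_size > 0 and br.memory_available(MemoryType.DEVICE) >= 2 * total_copy_cost:
            return MemoryType.DEVICE  # data already on device, room to spare
        return MemoryType.HOST  # safer default: buffer spillably in host memory
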
  @generate_ir_sub_network.register(IR)
- def _(ir: IR, rec: SubNetGenerator) -> tuple[list[Any], dict[IR, ChannelManager]]:
+ def _(
+     ir: IR, rec: SubNetGenerator
+ ) -> tuple[dict[IR, list[Any]], dict[IR, ChannelManager]]:
      # Default generate_ir_sub_network logic.
      # Use simple pointwise node.
 
@@ -377,18 +533,27 @@ def _(ir: IR, rec: SubNetGenerator) -> tuple[list[Any], dict[IR, ChannelManager]
 
      if len(ir.children) == 1:
          # Single-channel default node
-         nodes.append(
+         preserve_partitioning = isinstance(
+             # TODO: We don't need to worry about
+             # non-pointwise Filter operations here,
+             # because the lowering stage would have
+             # collapsed to one partition anyway.
+             ir,
+             (Cache, Projection, Filter),
+         )
+         nodes[ir] = [
              default_node_single(
                  rec.state["context"],
                  ir,
                  rec.state["ir_context"],
                  channels[ir].reserve_input_slot(),
                  channels[ir.children[0]].reserve_output_slot(),
+                 preserve_partitioning=preserve_partitioning,
              )
-         )
+         ]
      else:
          # Multi-channel default node
-         nodes.append(
+         nodes[ir] = [
              default_node_multi(
                  rec.state["context"],
                  ir,
@@ -396,7 +561,7 @@ def _(ir: IR, rec: SubNetGenerator) -> tuple[list[Any], dict[IR, ChannelManager]
                  channels[ir].reserve_input_slot(),
                  tuple(channels[c].reserve_output_slot() for c in ir.children),
              )
-         )
+         ]
 
      return nodes, channels
 
@@ -422,7 +587,10 @@ async def empty_node(
      ch_out
          The output ChannelPair.
      """
-     async with shutdown_on_error(context, ch_out.data):
+     async with shutdown_on_error(context, ch_out.metadata, ch_out.data):
+         # Send metadata indicating a single empty chunk
+         await ch_out.send_metadata(context, Metadata(1, duplicated=True))
+
          # Evaluate the IR node to create an empty DataFrame
          df: DataFrame = ir.do_evaluate(*ir._non_child_args, context=ir_context)
 
@@ -436,20 +604,22 @@
 
 
  @generate_ir_sub_network.register(Empty)
- def _(ir: Empty, rec: SubNetGenerator) -> tuple[list[Any], dict[IR, ChannelManager]]:
+ def _(
+     ir: Empty, rec: SubNetGenerator
+ ) -> tuple[dict[IR, list[Any]], dict[IR, ChannelManager]]:
      """Generate network for Empty node - produces one empty chunk."""
      context = rec.state["context"]
      ir_context = rec.state["ir_context"]
      channels: dict[IR, ChannelManager] = {ir: ChannelManager(rec.state["context"])}
-     nodes: list[Any] = [
-         empty_node(context, ir, ir_context, channels[ir].reserve_input_slot())
-     ]
+     nodes: dict[IR, list[Any]] = {
+         ir: [empty_node(context, ir, ir_context, channels[ir].reserve_input_slot())]
+     }
      return nodes, channels
 
 
  def generate_ir_sub_network_wrapper(
      ir: IR, rec: SubNetGenerator
- ) -> tuple[list[Any], dict[IR, ChannelManager]]:
+ ) -> tuple[dict[IR, list[Any]], dict[IR, ChannelManager]]:
      """
      Generate a sub-network for the RapidsMPF streaming runtime.
 
@@ -463,7 +633,7 @@ def generate_ir_sub_network_wrapper(
      Returns
      -------
      nodes
-         List of streaming-network node(s) for the subgraph.
+         Dictionary mapping each IR node to its list of streaming-network node(s).
      channels
          Dictionary mapping between each IR node and its
          corresponding streaming-network output ChannelManager.
@@ -474,21 +644,92 @@ def generate_ir_sub_network_wrapper(
      if (fanout_info := rec.state["fanout_nodes"].get(ir)) is not None:
          count = fanout_info.num_consumers
          manager = ChannelManager(rec.state["context"], count=count)
+         fanout_node: Any
          if fanout_info.unbounded:
-             nodes.append(
-                 fanout_node_unbounded(
-                     rec.state["context"],
-                     channels[ir].reserve_output_slot(),
-                     *[manager.reserve_input_slot() for _ in range(count)],
-                 )
+             fanout_node = fanout_node_unbounded(
+                 rec.state["context"],
+                 channels[ir].reserve_output_slot(),
+                 *[manager.reserve_input_slot() for _ in range(count)],
              )
          else:  # "bounded"
-             nodes.append(
-                 fanout_node_bounded(
-                     rec.state["context"],
-                     channels[ir].reserve_output_slot(),
-                     *[manager.reserve_input_slot() for _ in range(count)],
-                 )
+             fanout_node = fanout_node_bounded(
+                 rec.state["context"],
+                 channels[ir].reserve_output_slot(),
+                 *[manager.reserve_input_slot() for _ in range(count)],
              )
+         nodes[ir].append(fanout_node)
          channels[ir] = manager
      return nodes, channels
+
+
+ @define_py_node()
+ async def metadata_feeder_node(
+     context: Context,
+     channel: ChannelPair,
+     metadata: Metadata,
+ ) -> None:
+     """
+     Feed metadata to a channel pair.
+
+     Parameters
+     ----------
+     context
+         The rapidsmpf context.
+     channel
+         The channel pair.
+     metadata
+         The metadata to feed.
+     """
+     async with shutdown_on_error(context, channel.metadata, channel.data):
+         await channel.send_metadata(context, metadata)
+
+
+ @define_py_node()
+ async def metadata_drain_node(
+     context: Context,
+     ir: IR,
+     ir_context: IRExecutionContext,
+     ch_in: ChannelPair,
+     ch_out: Any,
+     metadata_collector: list[Metadata] | None,
+ ) -> None:
+     """
+     Drain metadata and forward data to a single channel.
+
+     Parameters
+     ----------
+     context
+         The rapidsmpf context.
+     ir
+         The IR node.
+     ir_context
+         The execution context for the IR node.
+     ch_in
+         The input ChannelPair (with metadata and data channels).
+     ch_out
+         The output data channel.
+     metadata_collector
+         The list to collect the final metadata.
+         This list will be mutated when the network is executed.
+         If None, metadata will not be collected.
+     """
+     async with shutdown_on_error(context, ch_in.metadata, ch_in.data, ch_out):
+         # Drain metadata channel (we don't need it after this point)
+         metadata = await ch_in.recv_metadata(context)
+         send_empty = metadata.duplicated and context.comm().rank != 0
+         if metadata_collector is not None:
+             metadata_collector.append(metadata)
+
+         # Forward non-duplicated data messages
+         while (msg := await ch_in.data.recv(context)) is not None:
+             if not send_empty:
+                 await ch_out.send(context, msg)
+
+         # Send empty data if needed
+         if send_empty:
+             stream = ir_context.get_cuda_stream()
+             await ch_out.send(
+                 context, Message(0, empty_table_chunk(ir, context, stream))
+             )
+
+         await ch_out.drain(context)
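
[Reviewer note, not part of the diff] duplicated=True means every rank holds an identical copy of the result (compare empty_node above, which sends Metadata(1, duplicated=True)), so forwarding from all ranks would multiply the output. The drain rule keeps exactly one copy cluster-wide; worked through for a three-rank run:

    # send_empty = metadata.duplicated and context.comm().rank != 0
    #   rank 0: True and (0 != 0) -> False -> forwards its real chunks
    #   rank 1: True and (1 != 0) -> True  -> drops its duplicate chunks and
    #           emits one empty chunk so the data channel still terminates
    #   rank 2: same as rank 1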