cudf-polars-cu13 25.10.0-py3-none-any.whl → 26.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. cudf_polars/GIT_COMMIT +1 -1
  2. cudf_polars/VERSION +1 -1
  3. cudf_polars/callback.py +60 -15
  4. cudf_polars/containers/column.py +137 -77
  5. cudf_polars/containers/dataframe.py +123 -34
  6. cudf_polars/containers/datatype.py +134 -13
  7. cudf_polars/dsl/expr.py +0 -2
  8. cudf_polars/dsl/expressions/aggregation.py +80 -28
  9. cudf_polars/dsl/expressions/binaryop.py +34 -14
  10. cudf_polars/dsl/expressions/boolean.py +110 -37
  11. cudf_polars/dsl/expressions/datetime.py +59 -30
  12. cudf_polars/dsl/expressions/literal.py +11 -5
  13. cudf_polars/dsl/expressions/rolling.py +460 -119
  14. cudf_polars/dsl/expressions/selection.py +9 -8
  15. cudf_polars/dsl/expressions/slicing.py +1 -1
  16. cudf_polars/dsl/expressions/string.py +256 -114
  17. cudf_polars/dsl/expressions/struct.py +19 -7
  18. cudf_polars/dsl/expressions/ternary.py +33 -3
  19. cudf_polars/dsl/expressions/unary.py +126 -64
  20. cudf_polars/dsl/ir.py +1053 -350
  21. cudf_polars/dsl/to_ast.py +30 -13
  22. cudf_polars/dsl/tracing.py +194 -0
  23. cudf_polars/dsl/translate.py +307 -107
  24. cudf_polars/dsl/utils/aggregations.py +43 -30
  25. cudf_polars/dsl/utils/reshape.py +14 -2
  26. cudf_polars/dsl/utils/rolling.py +12 -8
  27. cudf_polars/dsl/utils/windows.py +35 -20
  28. cudf_polars/experimental/base.py +55 -2
  29. cudf_polars/experimental/benchmarks/pdsds.py +12 -126
  30. cudf_polars/experimental/benchmarks/pdsh.py +792 -2
  31. cudf_polars/experimental/benchmarks/utils.py +596 -39
  32. cudf_polars/experimental/dask_registers.py +47 -20
  33. cudf_polars/experimental/dispatch.py +9 -3
  34. cudf_polars/experimental/distinct.py +2 -0
  35. cudf_polars/experimental/explain.py +15 -2
  36. cudf_polars/experimental/expressions.py +30 -15
  37. cudf_polars/experimental/groupby.py +25 -4
  38. cudf_polars/experimental/io.py +156 -124
  39. cudf_polars/experimental/join.py +53 -23
  40. cudf_polars/experimental/parallel.py +68 -19
  41. cudf_polars/experimental/rapidsmpf/__init__.py +8 -0
  42. cudf_polars/experimental/rapidsmpf/collectives/__init__.py +9 -0
  43. cudf_polars/experimental/rapidsmpf/collectives/allgather.py +90 -0
  44. cudf_polars/experimental/rapidsmpf/collectives/common.py +96 -0
  45. cudf_polars/experimental/rapidsmpf/collectives/shuffle.py +253 -0
  46. cudf_polars/experimental/rapidsmpf/core.py +488 -0
  47. cudf_polars/experimental/rapidsmpf/dask.py +172 -0
  48. cudf_polars/experimental/rapidsmpf/dispatch.py +153 -0
  49. cudf_polars/experimental/rapidsmpf/io.py +696 -0
  50. cudf_polars/experimental/rapidsmpf/join.py +322 -0
  51. cudf_polars/experimental/rapidsmpf/lower.py +74 -0
  52. cudf_polars/experimental/rapidsmpf/nodes.py +735 -0
  53. cudf_polars/experimental/rapidsmpf/repartition.py +216 -0
  54. cudf_polars/experimental/rapidsmpf/union.py +115 -0
  55. cudf_polars/experimental/rapidsmpf/utils.py +374 -0
  56. cudf_polars/experimental/repartition.py +9 -2
  57. cudf_polars/experimental/select.py +177 -14
  58. cudf_polars/experimental/shuffle.py +46 -12
  59. cudf_polars/experimental/sort.py +100 -26
  60. cudf_polars/experimental/spilling.py +1 -1
  61. cudf_polars/experimental/statistics.py +24 -5
  62. cudf_polars/experimental/utils.py +25 -7
  63. cudf_polars/testing/asserts.py +13 -8
  64. cudf_polars/testing/io.py +2 -1
  65. cudf_polars/testing/plugin.py +93 -17
  66. cudf_polars/typing/__init__.py +86 -32
  67. cudf_polars/utils/config.py +473 -58
  68. cudf_polars/utils/cuda_stream.py +70 -0
  69. cudf_polars/utils/versions.py +5 -4
  70. cudf_polars_cu13-26.2.0.dist-info/METADATA +181 -0
  71. cudf_polars_cu13-26.2.0.dist-info/RECORD +108 -0
  72. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/WHEEL +1 -1
  73. cudf_polars_cu13-25.10.0.dist-info/METADATA +0 -136
  74. cudf_polars_cu13-25.10.0.dist-info/RECORD +0 -92
  75. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/licenses/LICENSE +0 -0
  76. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/top_level.txt +0 -0
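
The largest change in this release is the new cudf_polars/experimental/rapidsmpf/ subpackage (items 41-55 above); the hunk below shows one of its modules. Every streaming node in that module follows the same two-channel protocol: a ChannelPair carries a single Metadata message plus a stream of data Messages terminated by None, and the producer drains the data channel when done. A minimal sketch of that protocol, using only helpers the module itself imports (the passthrough node here is illustrative, not part of the package):

@define_py_node()
async def passthrough_node(
    context: Context, ch_out: ChannelPair, ch_in: ChannelPair
) -> None:
    # Illustrative only: forward metadata and data from ch_in to ch_out.
    async with shutdown_on_error(
        context, ch_in.metadata, ch_in.data, ch_out.metadata, ch_out.data
    ):
        # One-shot metadata exchange.
        await ch_out.send_metadata(context, await ch_in.recv_metadata(context))
        # Data messages flow until recv returns None (end of stream).
        while (msg := await ch_in.data.recv(context)) is not None:
            await ch_out.data.send(context, msg)
        await ch_out.data.drain(context)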
cudf_polars/experimental/rapidsmpf/nodes.py (new file, item 52 above)
@@ -0,0 +1,735 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+ """Core node definitions for the RapidsMPF streaming runtime."""
+
+ from __future__ import annotations
+
+ import asyncio
+ from typing import TYPE_CHECKING, Any, cast
+
+ from rapidsmpf.memory.buffer import MemoryType
+ from rapidsmpf.streaming.core.message import Message
+ from rapidsmpf.streaming.core.node import define_py_node
+ from rapidsmpf.streaming.core.spillable_messages import SpillableMessages
+ from rapidsmpf.streaming.cudf.table_chunk import TableChunk
+
+ from cudf_polars.containers import DataFrame
+ from cudf_polars.dsl.ir import IR, Cache, Empty, Filter, Projection
+ from cudf_polars.experimental.rapidsmpf.dispatch import (
+     generate_ir_sub_network,
+ )
+ from cudf_polars.experimental.rapidsmpf.utils import (
+     ChannelManager,
+     Metadata,
+     empty_table_chunk,
+     make_spill_function,
+     opaque_reservation,
+     process_children,
+     shutdown_on_error,
+ )
+
+ if TYPE_CHECKING:
+     from rapidsmpf.streaming.core.context import Context
+
+     from cudf_polars.dsl.ir import IRExecutionContext
+     from cudf_polars.experimental.rapidsmpf.dispatch import SubNetGenerator
+     from cudf_polars.experimental.rapidsmpf.utils import ChannelPair
+
+
+ @define_py_node()
+ async def default_node_single(
+     context: Context,
+     ir: IR,
+     ir_context: IRExecutionContext,
+     ch_out: ChannelPair,
+     ch_in: ChannelPair,
+     *,
+     preserve_partitioning: bool = False,
+ ) -> None:
+     """
+     Single-channel default node for rapidsmpf.
+
+     Parameters
+     ----------
+     context
+         The rapidsmpf context.
+     ir
+         The IR node.
+     ir_context
+         The execution context for the IR node.
+     ch_out
+         The output ChannelPair.
+     ch_in
+         The input ChannelPair.
+     preserve_partitioning
+         Whether to preserve the partitioning metadata of the input chunks.
+
+     Notes
+     -----
+     Chunks are processed in the order they are received.
+     """
+     async with shutdown_on_error(
+         context, ch_in.metadata, ch_in.data, ch_out.metadata, ch_out.data
+     ):
+         # Recv/send metadata.
+         metadata_in = await ch_in.recv_metadata(context)
+         metadata_out = Metadata(
+             metadata_in.count,
+             partitioned_on=metadata_in.partitioned_on if preserve_partitioning else (),
+             duplicated=metadata_in.duplicated,
+         )
+         await ch_out.send_metadata(context, metadata_out)
+
+         # Recv/send data.
+         seq_num = 0
+         receiving = True
+         received_any = False
+         while receiving:
+             msg = await ch_in.data.recv(context)
+             if msg is None:
+                 receiving = False
+                 if received_any:
+                     break
+                 else:
+                     # Make sure we have an empty chunk in case do_evaluate
+                     # always produces rows (e.g. aggregation)
+                     stream = ir_context.get_cuda_stream()
+                     chunk = empty_table_chunk(ir.children[0], context, stream)
+             else:
+                 received_any = True
+                 chunk = TableChunk.from_message(msg).make_available_and_spill(
+                     context.br(), allow_overbooking=True
+                 )
+                 seq_num = msg.sequence_number
+                 del msg
+
+             input_bytes = chunk.data_alloc_size(MemoryType.DEVICE)
+             with opaque_reservation(context, input_bytes):
+                 df = await asyncio.to_thread(
+                     ir.do_evaluate,
+                     *ir._non_child_args,
+                     DataFrame.from_table(
+                         chunk.table_view(),
+                         list(ir.children[0].schema.keys()),
+                         list(ir.children[0].schema.values()),
+                         chunk.stream,
+                     ),
+                     context=ir_context,
+                 )
+                 await ch_out.data.send(
+                     context,
+                     Message(
+                         seq_num,
+                         TableChunk.from_pylibcudf_table(
+                             df.table, chunk.stream, exclusive_view=True
+                         ),
+                     ),
+                 )
+             del df, chunk
+
+         await ch_out.data.drain(context)
+
+
+ @define_py_node()
+ async def default_node_multi(
+     context: Context,
+     ir: IR,
+     ir_context: IRExecutionContext,
+     ch_out: ChannelPair,
+     chs_in: tuple[ChannelPair, ...],
+     *,
+     partitioning_index: int | None = None,
+ ) -> None:
+     """
+     Pointwise node for rapidsmpf.
+
+     Parameters
+     ----------
+     context
+         The rapidsmpf context.
+     ir
+         The IR node.
+     ir_context
+         The execution context for the IR node.
+     ch_out
+         The output ChannelPair.
+     chs_in
+         Tuple of input ChannelPairs.
+     partitioning_index
+         Index of the input channel to preserve partitioning information for.
+         If None, no partitioning information is preserved.
+     """
+     async with shutdown_on_error(
+         context,
+         *[ch.metadata for ch in chs_in],
+         ch_out.metadata,
+         *[ch.data for ch in chs_in],
+         ch_out.data,
+     ):
+         # Merge and forward basic metadata.
+         metadata = Metadata(1)
+         for idx, ch_in in enumerate(chs_in):
+             md_child = await ch_in.recv_metadata(context)
+             metadata.count = max(md_child.count, metadata.count)
+             metadata.duplicated = metadata.duplicated and md_child.duplicated
+             if idx == partitioning_index:
+                 metadata.partitioned_on = md_child.partitioned_on
+         await ch_out.send_metadata(context, metadata)
+
+         seq_num = 0
+         n_children = len(chs_in)
+         finished_channels: set[int] = set()
+         # Store TableChunk objects to keep data alive and prevent use-after-free
+         # with stream-ordered allocations
+         ready_chunks: list[TableChunk | None] = [None] * n_children
+         chunk_count: list[int] = [0] * n_children
+
+         # Recv/send data.
+         while True:
+             # Receive from all non-finished channels
+             for ch_idx, ch_in in enumerate(chs_in):
+                 if ch_idx in finished_channels:
+                     continue  # This channel already finished, reuse its data
+
+                 msg = await ch_in.data.recv(context)
+                 if msg is None:
+                     # Channel finished - keep its last chunk for reuse
+                     finished_channels.add(ch_idx)
+                 else:
+                     # Store the new chunk (replacing previous if any)
+                     ready_chunks[ch_idx] = TableChunk.from_message(msg)
+                     chunk_count[ch_idx] += 1
+                     del msg
+
+             # If all channels finished, we're done
+             if len(finished_channels) == n_children:
+                 break
+
+             # Check if any channel drained without providing data.
+             # If so, create an empty chunk for that channel.
+             for ch_idx, child in enumerate(ir.children):
+                 if ready_chunks[ch_idx] is None:
+                     # Channel drained without data - create empty chunk
+                     stream = ir_context.get_cuda_stream()
+                     ready_chunks[ch_idx] = empty_table_chunk(child, context, stream)
+
+             # Ensure all table chunks are unspilled and available.
+             ready_chunks = [
+                 chunk.make_available_and_spill(context.br(), allow_overbooking=True)
+                 for chunk in cast(list[TableChunk], ready_chunks)
+             ]
+             dfs = [
+                 DataFrame.from_table(
+                     chunk.table_view(),  # type: ignore[union-attr]
+                     list(child.schema.keys()),
+                     list(child.schema.values()),
+                     chunk.stream,  # type: ignore[union-attr]
+                 )
+                 for chunk, child in zip(ready_chunks, ir.children, strict=True)
+             ]
+
+             input_bytes = sum(
+                 chunk.data_alloc_size(MemoryType.DEVICE)
+                 for chunk in cast(list[TableChunk], ready_chunks)
+             )
+             with opaque_reservation(context, input_bytes):
+                 df = await asyncio.to_thread(
+                     ir.do_evaluate,
+                     *ir._non_child_args,
+                     *dfs,
+                     context=ir_context,
+                 )
+                 await ch_out.data.send(
+                     context,
+                     Message(
+                         seq_num,
+                         TableChunk.from_pylibcudf_table(
+                             df.table,
+                             df.stream,
+                             exclusive_view=True,
+                         ),
+                     ),
+                 )
+             seq_num += 1
+             del df, dfs
+
+         # Drain the output channel
+         del ready_chunks
+         await ch_out.data.drain(context)
+
+
+ @define_py_node()
+ async def fanout_node_bounded(
+     context: Context,
+     ch_in: ChannelPair,
+     *chs_out: ChannelPair,
+ ) -> None:
+     """
+     Bounded fanout node for rapidsmpf.
+
+     Each chunk is broadcasted to all output channels
+     as it arrives.
+
+     Parameters
+     ----------
+     context
+         The rapidsmpf context.
+     ch_in
+         The input ChannelPair.
+     chs_out
+         The output ChannelPairs.
+     """
+     # TODO: Use rapidsmpf fanout node once available.
+     # See: https://github.com/rapidsai/rapidsmpf/issues/560
+     async with shutdown_on_error(
+         context,
+         ch_in.metadata,
+         ch_in.data,
+         *[ch.metadata for ch in chs_out],
+         *[ch.data for ch in chs_out],
+     ):
+         # Forward metadata to all outputs.
+         metadata = await ch_in.recv_metadata(context)
+         await asyncio.gather(*(ch.send_metadata(context, metadata) for ch in chs_out))
+
+         while (msg := await ch_in.data.recv(context)) is not None:
+             table_chunk = TableChunk.from_message(msg).make_available_and_spill(
+                 context.br(), allow_overbooking=True
+             )
+             seq_num = msg.sequence_number
+             del msg
+             for ch_out in chs_out:
+                 await ch_out.data.send(
+                     context,
+                     Message(
+                         seq_num,
+                         TableChunk.from_pylibcudf_table(
+                             table_chunk.table_view(),
+                             table_chunk.stream,
+                             exclusive_view=False,
+                         ),
+                     ),
+                 )
+             del table_chunk
+
+         await asyncio.gather(*(ch.data.drain(context) for ch in chs_out))
+
+
+ @define_py_node()
+ async def fanout_node_unbounded(
+     context: Context,
+     ch_in: ChannelPair,
+     *chs_out: ChannelPair,
+ ) -> None:
+     """
+     Unbounded fanout node for rapidsmpf with spilling support.
+
+     Broadcasts chunks from input to all output channels. This is called
+     "unbounded" because it handles the case where one channel may consume
+     all data before another channel consumes any data.
+
+     The implementation uses adaptive sending with spillable buffers:
+     - Maintains a spillable FIFO buffer for each output channel
+     - Messages are buffered in host memory (spillable to disk)
+     - Sends to all channels concurrently
+     - Receives next chunk as soon as any channel makes progress
+     - Efficient for both balanced and imbalanced consumption patterns
+
+     Parameters
+     ----------
+     context
+         The rapidsmpf context.
+     ch_in
+         The input ChannelPair.
+     chs_out
+         The output ChannelPairs.
+     """
+     # TODO: Use rapidsmpf fanout node once available.
+     # See: https://github.com/rapidsai/rapidsmpf/issues/560
+     async with shutdown_on_error(
+         context,
+         ch_in.metadata,
+         ch_in.data,
+         *[ch.metadata for ch in chs_out],
+         *[ch.data for ch in chs_out],
+     ):
+         # Forward metadata to all outputs.
+         metadata = await ch_in.recv_metadata(context)
+         await asyncio.gather(*(ch.send_metadata(context, metadata) for ch in chs_out))
+
+         # Spillable FIFO buffer for each output channel
+         output_buffers: list[SpillableMessages] = [SpillableMessages() for _ in chs_out]
+         num_outputs = len(chs_out)
+
+         # Track message IDs in FIFO order for each output buffer
+         buffer_ids: list[list[int]] = [[] for _ in chs_out]
+
+         # Register a single spill function for all buffers
+         # This ensures global FIFO ordering when spilling across all outputs
+         spill_func_id = context.br().spill_manager.add_spill_function(
+             make_spill_function(output_buffers, context), priority=0
+         )
+
+         try:
+             # Track active send/drain tasks for each output
+             active_tasks: dict[int, asyncio.Task] = {}
+
+             # Track which outputs need to be drained (set when no more input)
+             needs_drain: set[int] = set()
+
+             # Receive task
+             recv_task: asyncio.Task | None = asyncio.create_task(
+                 ch_in.data.recv(context)
+             )
+
+             # Flag to indicate we should start a new receive (for backpressure)
+             can_receive: bool = True
+
+             async def send_one_from_buffer(idx: int) -> None:
+                 """
+                 Send one buffered message for output idx.
+
+                 The message remains in host memory (spillable) through the channel.
+                 The downstream consumer will call make_available() when needed.
+                 """
+                 if buffer_ids[idx]:
+                     mid = buffer_ids[idx].pop(0)
+                     msg = output_buffers[idx].extract(mid=mid)
+                     await chs_out[idx].data.send(context, msg)
+
+             async def drain_output(idx: int) -> None:
+                 """Drain output channel idx."""
+                 await chs_out[idx].data.drain(context)
+
+             # Main loop: coordinate receiving, sending, and draining
+             while (
+                 recv_task is not None or active_tasks or any(buffer_ids) or needs_drain
+             ):
+                 # Collect all currently active tasks
+                 tasks_to_wait = list(active_tasks.values())
+                 # Only include recv_task if we're allowed to receive
+                 if recv_task is not None and can_receive:
+                     tasks_to_wait.append(recv_task)
+
+                 # Start new tasks for outputs with work to do
+                 for idx in range(len(chs_out)):
+                     if idx not in active_tasks:
+                         if buffer_ids[idx]:
+                             # Send next buffered message
+                             task = asyncio.create_task(send_one_from_buffer(idx))
+                             active_tasks[idx] = task
+                             tasks_to_wait.append(task)
+                         elif idx in needs_drain:
+                             # Buffer empty and no more input - drain this output
+                             task = asyncio.create_task(drain_output(idx))
+                             active_tasks[idx] = task
+                             tasks_to_wait.append(task)
+                             needs_drain.discard(idx)
+
+                 # If nothing to wait for, we're done
+                 if not tasks_to_wait:
+                     break
+
+                 # Wait for ANY task to complete
+                 done, _ = await asyncio.wait(
+                     tasks_to_wait, return_when=asyncio.FIRST_COMPLETED
+                 )
+
+                 # Process completed tasks
+                 for task in done:
+                     if task is recv_task:
+                         # Receive completed
+                         msg = task.result()
+                         if msg is None:
+                             # End of input - mark all outputs as needing drain
+                             recv_task = None
+                             needs_drain.update(range(len(chs_out)))
+                         else:
+                             # Determine where to copy based on:
+                             # 1. Current message location (avoid unnecessary transfers)
+                             # 2. Available memory (avoid OOM)
+                             content_desc = msg.get_content_description()
+                             device_size = content_desc.content_sizes.get(
+                                 MemoryType.DEVICE, 0
+                             )
+                             copy_cost = msg.copy_cost()
+
+                             # Check if we have enough device memory for all copies
+                             # We need (num_outputs - 1) copies since last one reuses original
+                             num_copies = num_outputs - 1
+                             total_copy_cost = copy_cost * num_copies
+                             available_device_mem = context.br().memory_available(
+                                 MemoryType.DEVICE
+                             )
+
+                             # Decide target memory:
+                             # Use device ONLY if message is in device AND we have sufficient headroom.
+                             # TODO: Use further information about the downstream operations to make
+                             # a more informed decision.
+                             required_headroom = total_copy_cost * 2
+                             if (
+                                 device_size > 0
+                                 and available_device_mem >= required_headroom
+                             ):
+                                 # Use reserve_device_memory_and_spill to automatically trigger spilling
+                                 # if needed to make room for the copy
+                                 memory_reservation = (
+                                     context.br().reserve_device_memory_and_spill(
+                                         total_copy_cost,
+                                         allow_overbooking=True,
+                                     )
+                                 )
+                             else:
+                                 # Use host memory for buffering - much safer
+                                 # Downstream consumers will make_available() when they need device memory
+                                 memory_reservation, _ = context.br().reserve(
+                                     MemoryType.HOST,
+                                     total_copy_cost,
+                                     allow_overbooking=True,
+                                 )
+
+                             # Copy message for each output buffer
+                             # Copies are spillable and allow downstream consumers
+                             # to control device memory allocation
+                             for idx, sm in enumerate(output_buffers):
+                                 if idx < num_outputs - 1:
+                                     # Copy to target memory and insert into spillable buffer
+                                     mid = sm.insert(msg.copy(memory_reservation))
+                                 else:
+                                     # Optimization: reuse the original message for last output
+                                     # (no copy needed)
+                                     mid = sm.insert(msg)
+                                 buffer_ids[idx].append(mid)
+
+                             # Don't receive next chunk until at least one send completes
+                             can_receive = False
+                             recv_task = asyncio.create_task(ch_in.data.recv(context))
+                     else:
+                         # Must be a send or drain task - find which output and remove it
+                         for idx, at in list(active_tasks.items()):
+                             if at is task:
+                                 del active_tasks[idx]
+                                 # A send completed - allow receiving again
+                                 can_receive = True
+                                 break
+
+         finally:
+             # Clean up spill function registration
+             context.br().spill_manager.remove_spill_function(spill_func_id)
+
+
+ @generate_ir_sub_network.register(IR)
+ def _(
+     ir: IR, rec: SubNetGenerator
+ ) -> tuple[dict[IR, list[Any]], dict[IR, ChannelManager]]:
+     # Default generate_ir_sub_network logic.
+     # Use simple pointwise node.
+
+     # Process children
+     nodes, channels = process_children(ir, rec)
+
+     # Create output ChannelManager
+     channels[ir] = ChannelManager(rec.state["context"])
+
+     if len(ir.children) == 1:
+         # Single-channel default node
+         preserve_partitioning = isinstance(
+             # TODO: We don't need to worry about
+             # non-pointwise Filter operations here,
+             # because the lowering stage would have
+             # collapsed to one partition anyway.
+             ir,
+             (Cache, Projection, Filter),
+         )
+         nodes[ir] = [
+             default_node_single(
+                 rec.state["context"],
+                 ir,
+                 rec.state["ir_context"],
+                 channels[ir].reserve_input_slot(),
+                 channels[ir.children[0]].reserve_output_slot(),
+                 preserve_partitioning=preserve_partitioning,
+             )
+         ]
+     else:
+         # Multi-channel default node
+         nodes[ir] = [
+             default_node_multi(
+                 rec.state["context"],
+                 ir,
+                 rec.state["ir_context"],
+                 channels[ir].reserve_input_slot(),
+                 tuple(channels[c].reserve_output_slot() for c in ir.children),
+             )
+         ]
+
+     return nodes, channels
+
+
+ @define_py_node()
+ async def empty_node(
+     context: Context,
+     ir: Empty,
+     ir_context: IRExecutionContext,
+     ch_out: ChannelPair,
+ ) -> None:
+     """
+     Empty node for rapidsmpf - produces a single empty chunk.
+
+     Parameters
+     ----------
+     context
+         The rapidsmpf context.
+     ir
+         The Empty node.
+     ir_context
+         The execution context for the IR node.
+     ch_out
+         The output ChannelPair.
+     """
+     async with shutdown_on_error(context, ch_out.metadata, ch_out.data):
+         # Send metadata indicating a single empty chunk
+         await ch_out.send_metadata(context, Metadata(1, duplicated=True))
+
+         # Evaluate the IR node to create an empty DataFrame
+         df: DataFrame = ir.do_evaluate(*ir._non_child_args, context=ir_context)
+
+         # Return the output chunk (empty but with correct schema)
+         chunk = TableChunk.from_pylibcudf_table(
+             df.table, df.stream, exclusive_view=True
+         )
+         await ch_out.data.send(context, Message(0, chunk))
+
+         await ch_out.data.drain(context)
+
+
+ @generate_ir_sub_network.register(Empty)
+ def _(
+     ir: Empty, rec: SubNetGenerator
+ ) -> tuple[dict[IR, list[Any]], dict[IR, ChannelManager]]:
+     """Generate network for Empty node - produces one empty chunk."""
+     context = rec.state["context"]
+     ir_context = rec.state["ir_context"]
+     channels: dict[IR, ChannelManager] = {ir: ChannelManager(rec.state["context"])}
+     nodes: dict[IR, list[Any]] = {
+         ir: [empty_node(context, ir, ir_context, channels[ir].reserve_input_slot())]
+     }
+     return nodes, channels
+
+
+ def generate_ir_sub_network_wrapper(
+     ir: IR, rec: SubNetGenerator
+ ) -> tuple[dict[IR, list[Any]], dict[IR, ChannelManager]]:
+     """
+     Generate a sub-network for the RapidsMPF streaming runtime.
+
+     Parameters
+     ----------
+     ir
+         The IR node.
+     rec
+         Recursive SubNetGenerator callable.
+
+     Returns
+     -------
+     nodes
+         Dictionary mapping each IR node to its list of streaming-network node(s).
+     channels
+         Dictionary mapping between each IR node and its
+         corresponding streaming-network output ChannelManager.
+     """
+     nodes, channels = generate_ir_sub_network(ir, rec)
+
+     # Check if this node needs fanout
+     if (fanout_info := rec.state["fanout_nodes"].get(ir)) is not None:
+         count = fanout_info.num_consumers
+         manager = ChannelManager(rec.state["context"], count=count)
+         fanout_node: Any
+         if fanout_info.unbounded:
+             fanout_node = fanout_node_unbounded(
+                 rec.state["context"],
+                 channels[ir].reserve_output_slot(),
+                 *[manager.reserve_input_slot() for _ in range(count)],
+             )
+         else:  # "bounded"
+             fanout_node = fanout_node_bounded(
+                 rec.state["context"],
+                 channels[ir].reserve_output_slot(),
+                 *[manager.reserve_input_slot() for _ in range(count)],
+             )
+         nodes[ir].append(fanout_node)
+         channels[ir] = manager
+     return nodes, channels
+
+
+ @define_py_node()
+ async def metadata_feeder_node(
+     context: Context,
+     channel: ChannelPair,
+     metadata: Metadata,
+ ) -> None:
+     """
+     Feed metadata to a channel pair.
+
+     Parameters
+     ----------
+     context
+         The rapidsmpf context.
+     channel
+         The channel pair.
+     metadata
+         The metadata to feed.
+     """
+     async with shutdown_on_error(context, channel.metadata, channel.data):
+         await channel.send_metadata(context, metadata)
+
+
+ @define_py_node()
+ async def metadata_drain_node(
+     context: Context,
+     ir: IR,
+     ir_context: IRExecutionContext,
+     ch_in: ChannelPair,
+     ch_out: Any,
+     metadata_collector: list[Metadata] | None,
+ ) -> None:
+     """
+     Drain metadata and forward data to a single channel.
+
+     Parameters
+     ----------
+     context
+         The rapidsmpf context.
+     ir
+         The IR node.
+     ir_context
+         The execution context for the IR node.
+     ch_in
+         The input ChannelPair (with metadata and data channels).
+     ch_out
+         The output data channel.
+     metadata_collector
+         The list to collect the final metadata.
+         This list will be mutated when the network is executed.
+         If None, metadata will not be collected.
+     """
+     async with shutdown_on_error(context, ch_in.metadata, ch_in.data, ch_out):
+         # Drain metadata channel (we don't need it after this point)
+         metadata = await ch_in.recv_metadata(context)
+         send_empty = metadata.duplicated and context.comm().rank != 0
+         if metadata_collector is not None:
+             metadata_collector.append(metadata)
+
+         # Forward non-duplicated data messages
+         while (msg := await ch_in.data.recv(context)) is not None:
+             if not send_empty:
+                 await ch_out.send(context, msg)
+
+         # Send empty data if needed
+         if send_empty:
+             stream = ir_context.get_cuda_stream()
+             await ch_out.send(
+                 context, Message(0, empty_table_chunk(ir, context, stream))
+             )
+
+         await ch_out.drain(context)
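
A note on the Metadata messages used throughout the hunk above: count is the expected number of data chunks, partitioned_on records the partitioning keys (assumed here to be column names), and duplicated marks data replicated on every rank, which is why metadata_drain_node forwards data only on rank 0 and substitutes an empty chunk elsewhere. Constructor usage as it appears in the module, plus one hypothetical example:

Metadata(1, duplicated=True)        # empty_node: one chunk, identical on all ranks
Metadata(4, partitioned_on=("a",))  # hypothetical: four chunks, partitioned on column "a"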
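
The scheduling pattern in fanout_node_unbounded can be isolated from the rapidsmpf machinery. The following self-contained asyncio model (plain Queues and lists standing in for channels and SpillableMessages; all names are illustrative, not from the package) shows the same loop: buffer each received item once per output, send to all outputs concurrently, and pause receiving via can_receive until at least one send completes:

import asyncio


async def consume(q: asyncio.Queue, delay: float, name: str) -> None:
    # A downstream consumer; differing delays model imbalanced consumption.
    while (item := await q.get()) is not None:
        print(f"{name} consumed {item}")
        await asyncio.sleep(delay)


async def fanout(src: asyncio.Queue, outs: list[asyncio.Queue]) -> None:
    buffers: list[list[int]] = [[] for _ in outs]  # per-output FIFO buffers
    recv: asyncio.Task | None = asyncio.create_task(src.get())
    active: dict[int, asyncio.Task] = {}
    can_receive = True
    while recv is not None or active or any(buffers):
        tasks = list(active.values())
        if recv is not None and can_receive:
            tasks.append(recv)
        # Start one send per output that has buffered work and no active task.
        for idx, buf in enumerate(buffers):
            if idx not in active and buf:
                active[idx] = asyncio.create_task(outs[idx].put(buf.pop(0)))
                tasks.append(active[idx])
        if not tasks:
            break
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            if task is recv:
                item = task.result()
                if item is None:
                    recv = None  # end of input
                else:
                    for buf in buffers:
                        buf.append(item)  # the real node copies and spills here
                    can_receive = False  # backpressure until one send completes
                    recv = asyncio.create_task(src.get())
            else:
                for idx, t in list(active.items()):
                    if t is task:
                        del active[idx]
                        can_receive = True  # progress made; receive again
                        break
    # Model of drain: signal end-of-stream to every output.
    for out in outs:
        await out.put(None)


async def main() -> None:
    src: asyncio.Queue = asyncio.Queue()
    for i in range(5):
        src.put_nowait(i)
    src.put_nowait(None)  # end-of-stream marker
    outs = [asyncio.Queue(maxsize=1) for _ in range(2)]
    consumers = [
        asyncio.create_task(consume(outs[0], 0.01, "fast")),
        asyncio.create_task(consume(outs[1], 0.05, "slow")),
    ]
    await fanout(src, outs)
    await asyncio.gather(*consumers)


asyncio.run(main())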