cudf-polars-cu13 25.10.0__py3-none-any.whl → 26.2.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
- cudf_polars/GIT_COMMIT +1 -1
- cudf_polars/VERSION +1 -1
- cudf_polars/callback.py +60 -15
- cudf_polars/containers/column.py +137 -77
- cudf_polars/containers/dataframe.py +123 -34
- cudf_polars/containers/datatype.py +134 -13
- cudf_polars/dsl/expr.py +0 -2
- cudf_polars/dsl/expressions/aggregation.py +80 -28
- cudf_polars/dsl/expressions/binaryop.py +34 -14
- cudf_polars/dsl/expressions/boolean.py +110 -37
- cudf_polars/dsl/expressions/datetime.py +59 -30
- cudf_polars/dsl/expressions/literal.py +11 -5
- cudf_polars/dsl/expressions/rolling.py +460 -119
- cudf_polars/dsl/expressions/selection.py +9 -8
- cudf_polars/dsl/expressions/slicing.py +1 -1
- cudf_polars/dsl/expressions/string.py +256 -114
- cudf_polars/dsl/expressions/struct.py +19 -7
- cudf_polars/dsl/expressions/ternary.py +33 -3
- cudf_polars/dsl/expressions/unary.py +126 -64
- cudf_polars/dsl/ir.py +1053 -350
- cudf_polars/dsl/to_ast.py +30 -13
- cudf_polars/dsl/tracing.py +194 -0
- cudf_polars/dsl/translate.py +307 -107
- cudf_polars/dsl/utils/aggregations.py +43 -30
- cudf_polars/dsl/utils/reshape.py +14 -2
- cudf_polars/dsl/utils/rolling.py +12 -8
- cudf_polars/dsl/utils/windows.py +35 -20
- cudf_polars/experimental/base.py +55 -2
- cudf_polars/experimental/benchmarks/pdsds.py +12 -126
- cudf_polars/experimental/benchmarks/pdsh.py +792 -2
- cudf_polars/experimental/benchmarks/utils.py +596 -39
- cudf_polars/experimental/dask_registers.py +47 -20
- cudf_polars/experimental/dispatch.py +9 -3
- cudf_polars/experimental/distinct.py +2 -0
- cudf_polars/experimental/explain.py +15 -2
- cudf_polars/experimental/expressions.py +30 -15
- cudf_polars/experimental/groupby.py +25 -4
- cudf_polars/experimental/io.py +156 -124
- cudf_polars/experimental/join.py +53 -23
- cudf_polars/experimental/parallel.py +68 -19
- cudf_polars/experimental/rapidsmpf/__init__.py +8 -0
- cudf_polars/experimental/rapidsmpf/collectives/__init__.py +9 -0
- cudf_polars/experimental/rapidsmpf/collectives/allgather.py +90 -0
- cudf_polars/experimental/rapidsmpf/collectives/common.py +96 -0
- cudf_polars/experimental/rapidsmpf/collectives/shuffle.py +253 -0
- cudf_polars/experimental/rapidsmpf/core.py +488 -0
- cudf_polars/experimental/rapidsmpf/dask.py +172 -0
- cudf_polars/experimental/rapidsmpf/dispatch.py +153 -0
- cudf_polars/experimental/rapidsmpf/io.py +696 -0
- cudf_polars/experimental/rapidsmpf/join.py +322 -0
- cudf_polars/experimental/rapidsmpf/lower.py +74 -0
- cudf_polars/experimental/rapidsmpf/nodes.py +735 -0
- cudf_polars/experimental/rapidsmpf/repartition.py +216 -0
- cudf_polars/experimental/rapidsmpf/union.py +115 -0
- cudf_polars/experimental/rapidsmpf/utils.py +374 -0
- cudf_polars/experimental/repartition.py +9 -2
- cudf_polars/experimental/select.py +177 -14
- cudf_polars/experimental/shuffle.py +46 -12
- cudf_polars/experimental/sort.py +100 -26
- cudf_polars/experimental/spilling.py +1 -1
- cudf_polars/experimental/statistics.py +24 -5
- cudf_polars/experimental/utils.py +25 -7
- cudf_polars/testing/asserts.py +13 -8
- cudf_polars/testing/io.py +2 -1
- cudf_polars/testing/plugin.py +93 -17
- cudf_polars/typing/__init__.py +86 -32
- cudf_polars/utils/config.py +473 -58
- cudf_polars/utils/cuda_stream.py +70 -0
- cudf_polars/utils/versions.py +5 -4
- cudf_polars_cu13-26.2.0.dist-info/METADATA +181 -0
- cudf_polars_cu13-26.2.0.dist-info/RECORD +108 -0
- {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/WHEEL +1 -1
- cudf_polars_cu13-25.10.0.dist-info/METADATA +0 -136
- cudf_polars_cu13-25.10.0.dist-info/RECORD +0 -92
- {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/licenses/LICENSE +0 -0
- {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/top_level.txt +0 -0
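
A large part of this release is the new `cudf_polars/experimental/rapidsmpf/` package, which adds a streaming runtime on top of RapidsMPF; of these files, only `cudf_polars/experimental/rapidsmpf/nodes.py` is reproduced in full below. For orientation, the short sketch that follows shows how cudf-polars is normally driven, through the Polars GPU engine. The `executor` keyword is an assumption suggested by the `cudf_polars.utils.config` changes in this diff, not API confirmed by it.

import polars as pl

# Minimal usage sketch: cudf-polars is the backend behind Polars' GPU engine, so
# the code in this diff runs when a LazyFrame is collected with pl.GPUEngine.
# The executor keyword below is an assumption (hypothetical name) and may differ
# from the released API.
q = (
    pl.scan_parquet("data/*.parquet")
    .group_by("key")
    .agg(pl.col("value").sum())
)

result = q.collect(
    engine=pl.GPUEngine(
        raise_on_fail=True,  # surface GPU-engine failures instead of silently falling back
        executor="streaming",  # assumed: the multi-partition/streaming executor
    )
)
print(result)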
cudf_polars/experimental/rapidsmpf/nodes.py (new file, +735 lines)

@@ -0,0 +1,735 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
"""Core node definitions for the RapidsMPF streaming runtime."""

from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING, Any, cast

from rapidsmpf.memory.buffer import MemoryType
from rapidsmpf.streaming.core.message import Message
from rapidsmpf.streaming.core.node import define_py_node
from rapidsmpf.streaming.core.spillable_messages import SpillableMessages
from rapidsmpf.streaming.cudf.table_chunk import TableChunk

from cudf_polars.containers import DataFrame
from cudf_polars.dsl.ir import IR, Cache, Empty, Filter, Projection
from cudf_polars.experimental.rapidsmpf.dispatch import (
    generate_ir_sub_network,
)
from cudf_polars.experimental.rapidsmpf.utils import (
    ChannelManager,
    Metadata,
    empty_table_chunk,
    make_spill_function,
    opaque_reservation,
    process_children,
    shutdown_on_error,
)

if TYPE_CHECKING:
    from rapidsmpf.streaming.core.context import Context

    from cudf_polars.dsl.ir import IRExecutionContext
    from cudf_polars.experimental.rapidsmpf.dispatch import SubNetGenerator
    from cudf_polars.experimental.rapidsmpf.utils import ChannelPair


@define_py_node()
async def default_node_single(
    context: Context,
    ir: IR,
    ir_context: IRExecutionContext,
    ch_out: ChannelPair,
    ch_in: ChannelPair,
    *,
    preserve_partitioning: bool = False,
) -> None:
    """
    Single-channel default node for rapidsmpf.

    Parameters
    ----------
    context
        The rapidsmpf context.
    ir
        The IR node.
    ir_context
        The execution context for the IR node.
    ch_out
        The output ChannelPair.
    ch_in
        The input ChannelPair.
    preserve_partitioning
        Whether to preserve the partitioning metadata of the input chunks.

    Notes
    -----
    Chunks are processed in the order they are received.
    """
    async with shutdown_on_error(
        context, ch_in.metadata, ch_in.data, ch_out.metadata, ch_out.data
    ):
        # Recv/send metadata.
        metadata_in = await ch_in.recv_metadata(context)
        metadata_out = Metadata(
            metadata_in.count,
            partitioned_on=metadata_in.partitioned_on if preserve_partitioning else (),
            duplicated=metadata_in.duplicated,
        )
        await ch_out.send_metadata(context, metadata_out)

        # Recv/send data.
        seq_num = 0
        receiving = True
        received_any = False
        while receiving:
            msg = await ch_in.data.recv(context)
            if msg is None:
                receiving = False
                if received_any:
                    break
                else:
                    # Make sure we have an empty chunk in case do_evaluate
                    # always produces rows (e.g. aggregation)
                    stream = ir_context.get_cuda_stream()
                    chunk = empty_table_chunk(ir.children[0], context, stream)
            else:
                received_any = True
                chunk = TableChunk.from_message(msg).make_available_and_spill(
                    context.br(), allow_overbooking=True
                )
                seq_num = msg.sequence_number
                del msg

            input_bytes = chunk.data_alloc_size(MemoryType.DEVICE)
            with opaque_reservation(context, input_bytes):
                df = await asyncio.to_thread(
                    ir.do_evaluate,
                    *ir._non_child_args,
                    DataFrame.from_table(
                        chunk.table_view(),
                        list(ir.children[0].schema.keys()),
                        list(ir.children[0].schema.values()),
                        chunk.stream,
                    ),
                    context=ir_context,
                )
            await ch_out.data.send(
                context,
                Message(
                    seq_num,
                    TableChunk.from_pylibcudf_table(
                        df.table, chunk.stream, exclusive_view=True
                    ),
                ),
            )
            del df, chunk

        await ch_out.data.drain(context)


@define_py_node()
async def default_node_multi(
    context: Context,
    ir: IR,
    ir_context: IRExecutionContext,
    ch_out: ChannelPair,
    chs_in: tuple[ChannelPair, ...],
    *,
    partitioning_index: int | None = None,
) -> None:
    """
    Pointwise node for rapidsmpf.

    Parameters
    ----------
    context
        The rapidsmpf context.
    ir
        The IR node.
    ir_context
        The execution context for the IR node.
    ch_out
        The output ChannelPair.
    chs_in
        Tuple of input ChannelPairs.
    partitioning_index
        Index of the input channel to preserve partitioning information for.
        If None, no partitioning information is preserved.
    """
    async with shutdown_on_error(
        context,
        *[ch.metadata for ch in chs_in],
        ch_out.metadata,
        *[ch.data for ch in chs_in],
        ch_out.data,
    ):
        # Merge and forward basic metadata.
        metadata = Metadata(1)
        for idx, ch_in in enumerate(chs_in):
            md_child = await ch_in.recv_metadata(context)
            metadata.count = max(md_child.count, metadata.count)
            metadata.duplicated = metadata.duplicated and md_child.duplicated
            if idx == partitioning_index:
                metadata.partitioned_on = md_child.partitioned_on
        await ch_out.send_metadata(context, metadata)

        seq_num = 0
        n_children = len(chs_in)
        finished_channels: set[int] = set()
        # Store TableChunk objects to keep data alive and prevent use-after-free
        # with stream-ordered allocations
        ready_chunks: list[TableChunk | None] = [None] * n_children
        chunk_count: list[int] = [0] * n_children

        # Recv/send data.
        while True:
            # Receive from all non-finished channels
            for ch_idx, ch_in in enumerate(chs_in):
                if ch_idx in finished_channels:
                    continue  # This channel already finished, reuse its data

                msg = await ch_in.data.recv(context)
                if msg is None:
                    # Channel finished - keep its last chunk for reuse
                    finished_channels.add(ch_idx)
                else:
                    # Store the new chunk (replacing previous if any)
                    ready_chunks[ch_idx] = TableChunk.from_message(msg)
                    chunk_count[ch_idx] += 1
                    del msg

            # If all channels finished, we're done
            if len(finished_channels) == n_children:
                break

            # Check if any channel drained without providing data.
            # If so, create an empty chunk for that channel.
            for ch_idx, child in enumerate(ir.children):
                if ready_chunks[ch_idx] is None:
                    # Channel drained without data - create empty chunk
                    stream = ir_context.get_cuda_stream()
                    ready_chunks[ch_idx] = empty_table_chunk(child, context, stream)

            # Ensure all table chunks are unspilled and available.
            ready_chunks = [
                chunk.make_available_and_spill(context.br(), allow_overbooking=True)
                for chunk in cast(list[TableChunk], ready_chunks)
            ]
            dfs = [
                DataFrame.from_table(
                    chunk.table_view(),  # type: ignore[union-attr]
                    list(child.schema.keys()),
                    list(child.schema.values()),
                    chunk.stream,  # type: ignore[union-attr]
                )
                for chunk, child in zip(ready_chunks, ir.children, strict=True)
            ]

            input_bytes = sum(
                chunk.data_alloc_size(MemoryType.DEVICE)
                for chunk in cast(list[TableChunk], ready_chunks)
            )
            with opaque_reservation(context, input_bytes):
                df = await asyncio.to_thread(
                    ir.do_evaluate,
                    *ir._non_child_args,
                    *dfs,
                    context=ir_context,
                )
            await ch_out.data.send(
                context,
                Message(
                    seq_num,
                    TableChunk.from_pylibcudf_table(
                        df.table,
                        df.stream,
                        exclusive_view=True,
                    ),
                ),
            )
            seq_num += 1
            del df, dfs

        # Drain the output channel
        del ready_chunks
        await ch_out.data.drain(context)


@define_py_node()
async def fanout_node_bounded(
    context: Context,
    ch_in: ChannelPair,
    *chs_out: ChannelPair,
) -> None:
    """
    Bounded fanout node for rapidsmpf.

    Each chunk is broadcasted to all output channels
    as it arrives.

    Parameters
    ----------
    context
        The rapidsmpf context.
    ch_in
        The input ChannelPair.
    chs_out
        The output ChannelPairs.
    """
    # TODO: Use rapidsmpf fanout node once available.
    # See: https://github.com/rapidsai/rapidsmpf/issues/560
    async with shutdown_on_error(
        context,
        ch_in.metadata,
        ch_in.data,
        *[ch.metadata for ch in chs_out],
        *[ch.data for ch in chs_out],
    ):
        # Forward metadata to all outputs.
        metadata = await ch_in.recv_metadata(context)
        await asyncio.gather(*(ch.send_metadata(context, metadata) for ch in chs_out))

        while (msg := await ch_in.data.recv(context)) is not None:
            table_chunk = TableChunk.from_message(msg).make_available_and_spill(
                context.br(), allow_overbooking=True
            )
            seq_num = msg.sequence_number
            del msg
            for ch_out in chs_out:
                await ch_out.data.send(
                    context,
                    Message(
                        seq_num,
                        TableChunk.from_pylibcudf_table(
                            table_chunk.table_view(),
                            table_chunk.stream,
                            exclusive_view=False,
                        ),
                    ),
                )
            del table_chunk

        await asyncio.gather(*(ch.data.drain(context) for ch in chs_out))


@define_py_node()
async def fanout_node_unbounded(
    context: Context,
    ch_in: ChannelPair,
    *chs_out: ChannelPair,
) -> None:
    """
    Unbounded fanout node for rapidsmpf with spilling support.

    Broadcasts chunks from input to all output channels. This is called
    "unbounded" because it handles the case where one channel may consume
    all data before another channel consumes any data.

    The implementation uses adaptive sending with spillable buffers:
    - Maintains a spillable FIFO buffer for each output channel
    - Messages are buffered in host memory (spillable to disk)
    - Sends to all channels concurrently
    - Receives next chunk as soon as any channel makes progress
    - Efficient for both balanced and imbalanced consumption patterns

    Parameters
    ----------
    context
        The rapidsmpf context.
    ch_in
        The input ChannelPair.
    chs_out
        The output ChannelPairs.
    """
    # TODO: Use rapidsmpf fanout node once available.
    # See: https://github.com/rapidsai/rapidsmpf/issues/560
    async with shutdown_on_error(
        context,
        ch_in.metadata,
        ch_in.data,
        *[ch.metadata for ch in chs_out],
        *[ch.data for ch in chs_out],
    ):
        # Forward metadata to all outputs.
        metadata = await ch_in.recv_metadata(context)
        await asyncio.gather(*(ch.send_metadata(context, metadata) for ch in chs_out))

        # Spillable FIFO buffer for each output channel
        output_buffers: list[SpillableMessages] = [SpillableMessages() for _ in chs_out]
        num_outputs = len(chs_out)

        # Track message IDs in FIFO order for each output buffer
        buffer_ids: list[list[int]] = [[] for _ in chs_out]

        # Register a single spill function for all buffers
        # This ensures global FIFO ordering when spilling across all outputs
        spill_func_id = context.br().spill_manager.add_spill_function(
            make_spill_function(output_buffers, context), priority=0
        )

        try:
            # Track active send/drain tasks for each output
            active_tasks: dict[int, asyncio.Task] = {}

            # Track which outputs need to be drained (set when no more input)
            needs_drain: set[int] = set()

            # Receive task
            recv_task: asyncio.Task | None = asyncio.create_task(
                ch_in.data.recv(context)
            )

            # Flag to indicate we should start a new receive (for backpressure)
            can_receive: bool = True

            async def send_one_from_buffer(idx: int) -> None:
                """
                Send one buffered message for output idx.

                The message remains in host memory (spillable) through the channel.
                The downstream consumer will call make_available() when needed.
                """
                if buffer_ids[idx]:
                    mid = buffer_ids[idx].pop(0)
                    msg = output_buffers[idx].extract(mid=mid)
                    await chs_out[idx].data.send(context, msg)

            async def drain_output(idx: int) -> None:
                """Drain output channel idx."""
                await chs_out[idx].data.drain(context)

            # Main loop: coordinate receiving, sending, and draining
            while (
                recv_task is not None or active_tasks or any(buffer_ids) or needs_drain
            ):
                # Collect all currently active tasks
                tasks_to_wait = list(active_tasks.values())
                # Only include recv_task if we're allowed to receive
                if recv_task is not None and can_receive:
                    tasks_to_wait.append(recv_task)

                # Start new tasks for outputs with work to do
                for idx in range(len(chs_out)):
                    if idx not in active_tasks:
                        if buffer_ids[idx]:
                            # Send next buffered message
                            task = asyncio.create_task(send_one_from_buffer(idx))
                            active_tasks[idx] = task
                            tasks_to_wait.append(task)
                        elif idx in needs_drain:
                            # Buffer empty and no more input - drain this output
                            task = asyncio.create_task(drain_output(idx))
                            active_tasks[idx] = task
                            tasks_to_wait.append(task)
                            needs_drain.discard(idx)

                # If nothing to wait for, we're done
                if not tasks_to_wait:
                    break

                # Wait for ANY task to complete
                done, _ = await asyncio.wait(
                    tasks_to_wait, return_when=asyncio.FIRST_COMPLETED
                )

                # Process completed tasks
                for task in done:
                    if task is recv_task:
                        # Receive completed
                        msg = task.result()
                        if msg is None:
                            # End of input - mark all outputs as needing drain
                            recv_task = None
                            needs_drain.update(range(len(chs_out)))
                        else:
                            # Determine where to copy based on:
                            # 1. Current message location (avoid unnecessary transfers)
                            # 2. Available memory (avoid OOM)
                            content_desc = msg.get_content_description()
                            device_size = content_desc.content_sizes.get(
                                MemoryType.DEVICE, 0
                            )
                            copy_cost = msg.copy_cost()

                            # Check if we have enough device memory for all copies
                            # We need (num_outputs - 1) copies since last one reuses original
                            num_copies = num_outputs - 1
                            total_copy_cost = copy_cost * num_copies
                            available_device_mem = context.br().memory_available(
                                MemoryType.DEVICE
                            )

                            # Decide target memory:
                            # Use device ONLY if message is in device AND we have sufficient headroom.
                            # TODO: Use further information about the downstream operations to make
                            # a more informed decision.
                            required_headroom = total_copy_cost * 2
                            if (
                                device_size > 0
                                and available_device_mem >= required_headroom
                            ):
                                # Use reserve_device_memory_and_spill to automatically trigger spilling
                                # if needed to make room for the copy
                                memory_reservation = (
                                    context.br().reserve_device_memory_and_spill(
                                        total_copy_cost,
                                        allow_overbooking=True,
                                    )
                                )
                            else:
                                # Use host memory for buffering - much safer
                                # Downstream consumers will make_available() when they need device memory
                                memory_reservation, _ = context.br().reserve(
                                    MemoryType.HOST,
                                    total_copy_cost,
                                    allow_overbooking=True,
                                )

                            # Copy message for each output buffer
                            # Copies are spillable and allow downstream consumers
                            # to control device memory allocation
                            for idx, sm in enumerate(output_buffers):
                                if idx < num_outputs - 1:
                                    # Copy to target memory and insert into spillable buffer
                                    mid = sm.insert(msg.copy(memory_reservation))
                                else:
                                    # Optimization: reuse the original message for last output
                                    # (no copy needed)
                                    mid = sm.insert(msg)
                                buffer_ids[idx].append(mid)

                            # Don't receive next chunk until at least one send completes
                            can_receive = False
                            recv_task = asyncio.create_task(ch_in.data.recv(context))
                    else:
                        # Must be a send or drain task - find which output and remove it
                        for idx, at in list(active_tasks.items()):
                            if at is task:
                                del active_tasks[idx]
                                # A send completed - allow receiving again
                                can_receive = True
                                break

        finally:
            # Clean up spill function registration
            context.br().spill_manager.remove_spill_function(spill_func_id)


@generate_ir_sub_network.register(IR)
def _(
    ir: IR, rec: SubNetGenerator
) -> tuple[dict[IR, list[Any]], dict[IR, ChannelManager]]:
    # Default generate_ir_sub_network logic.
    # Use simple pointwise node.

    # Process children
    nodes, channels = process_children(ir, rec)

    # Create output ChannelManager
    channels[ir] = ChannelManager(rec.state["context"])

    if len(ir.children) == 1:
        # Single-channel default node
        preserve_partitioning = isinstance(
            # TODO: We don't need to worry about
            # non-pointwise Filter operations here,
            # because the lowering stage would have
            # collapsed to one partition anyway.
            ir,
            (Cache, Projection, Filter),
        )
        nodes[ir] = [
            default_node_single(
                rec.state["context"],
                ir,
                rec.state["ir_context"],
                channels[ir].reserve_input_slot(),
                channels[ir.children[0]].reserve_output_slot(),
                preserve_partitioning=preserve_partitioning,
            )
        ]
    else:
        # Multi-channel default node
        nodes[ir] = [
            default_node_multi(
                rec.state["context"],
                ir,
                rec.state["ir_context"],
                channels[ir].reserve_input_slot(),
                tuple(channels[c].reserve_output_slot() for c in ir.children),
            )
        ]

    return nodes, channels


@define_py_node()
async def empty_node(
    context: Context,
    ir: Empty,
    ir_context: IRExecutionContext,
    ch_out: ChannelPair,
) -> None:
    """
    Empty node for rapidsmpf - produces a single empty chunk.

    Parameters
    ----------
    context
        The rapidsmpf context.
    ir
        The Empty node.
    ir_context
        The execution context for the IR node.
    ch_out
        The output ChannelPair.
    """
    async with shutdown_on_error(context, ch_out.metadata, ch_out.data):
        # Send metadata indicating a single empty chunk
        await ch_out.send_metadata(context, Metadata(1, duplicated=True))

        # Evaluate the IR node to create an empty DataFrame
        df: DataFrame = ir.do_evaluate(*ir._non_child_args, context=ir_context)

        # Return the output chunk (empty but with correct schema)
        chunk = TableChunk.from_pylibcudf_table(
            df.table, df.stream, exclusive_view=True
        )
        await ch_out.data.send(context, Message(0, chunk))

        await ch_out.data.drain(context)


@generate_ir_sub_network.register(Empty)
def _(
    ir: Empty, rec: SubNetGenerator
) -> tuple[dict[IR, list[Any]], dict[IR, ChannelManager]]:
    """Generate network for Empty node - produces one empty chunk."""
    context = rec.state["context"]
    ir_context = rec.state["ir_context"]
    channels: dict[IR, ChannelManager] = {ir: ChannelManager(rec.state["context"])}
    nodes: dict[IR, list[Any]] = {
        ir: [empty_node(context, ir, ir_context, channels[ir].reserve_input_slot())]
    }
    return nodes, channels


def generate_ir_sub_network_wrapper(
    ir: IR, rec: SubNetGenerator
) -> tuple[dict[IR, list[Any]], dict[IR, ChannelManager]]:
    """
    Generate a sub-network for the RapidsMPF streaming runtime.

    Parameters
    ----------
    ir
        The IR node.
    rec
        Recursive SubNetGenerator callable.

    Returns
    -------
    nodes
        Dictionary mapping each IR node to its list of streaming-network node(s).
    channels
        Dictionary mapping between each IR node and its
        corresponding streaming-network output ChannelManager.
    """
    nodes, channels = generate_ir_sub_network(ir, rec)

    # Check if this node needs fanout
    if (fanout_info := rec.state["fanout_nodes"].get(ir)) is not None:
        count = fanout_info.num_consumers
        manager = ChannelManager(rec.state["context"], count=count)
        fanout_node: Any
        if fanout_info.unbounded:
            fanout_node = fanout_node_unbounded(
                rec.state["context"],
                channels[ir].reserve_output_slot(),
                *[manager.reserve_input_slot() for _ in range(count)],
            )
        else:  # "bounded"
            fanout_node = fanout_node_bounded(
                rec.state["context"],
                channels[ir].reserve_output_slot(),
                *[manager.reserve_input_slot() for _ in range(count)],
            )
        nodes[ir].append(fanout_node)
        channels[ir] = manager
    return nodes, channels


@define_py_node()
async def metadata_feeder_node(
    context: Context,
    channel: ChannelPair,
    metadata: Metadata,
) -> None:
    """
    Feed metadata to a channel pair.

    Parameters
    ----------
    context
        The rapidsmpf context.
    channel
        The channel pair.
    metadata
        The metadata to feed.
    """
    async with shutdown_on_error(context, channel.metadata, channel.data):
        await channel.send_metadata(context, metadata)


@define_py_node()
async def metadata_drain_node(
    context: Context,
    ir: IR,
    ir_context: IRExecutionContext,
    ch_in: ChannelPair,
    ch_out: Any,
    metadata_collector: list[Metadata] | None,
) -> None:
    """
    Drain metadata and forward data to a single channel.

    Parameters
    ----------
    context
        The rapidsmpf context.
    ir
        The IR node.
    ir_context
        The execution context for the IR node.
    ch_in
        The input ChannelPair (with metadata and data channels).
    ch_out
        The output data channel.
    metadata_collector
        The list to collect the final metadata.
        This list will be mutated when the network is executed.
        If None, metadata will not be collected.
    """
    async with shutdown_on_error(context, ch_in.metadata, ch_in.data, ch_out):
        # Drain metadata channel (we don't need it after this point)
        metadata = await ch_in.recv_metadata(context)
        send_empty = metadata.duplicated and context.comm().rank != 0
        if metadata_collector is not None:
            metadata_collector.append(metadata)

        # Forward non-duplicated data messages
        while (msg := await ch_in.data.recv(context)) is not None:
            if not send_empty:
                await ch_out.send(context, msg)

        # Send empty data if needed
        if send_empty:
            stream = ir_context.get_cuda_stream()
            await ch_out.send(
                context, Message(0, empty_table_chunk(ir, context, stream))
            )

        await ch_out.drain(context)
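
Most of the new module above is plumbing around rapidsmpf channels; the densest piece is the coordination loop in fanout_node_unbounded. The following self-contained sketch reproduces just that receive/buffer/send skeleton with plain asyncio.Queue objects standing in for the channels. The fanout helper, the None end-of-input marker, and the queue sizes are illustrative assumptions only; the real node additionally applies backpressure (can_receive), per-output message copies with memory reservations, host-memory spilling, and explicit drain handling.

import asyncio


async def fanout(ch_in: asyncio.Queue, chs_out: list[asyncio.Queue]) -> None:
    # Per-output FIFO buffers (the real node uses SpillableMessages instead).
    buffers: list[list[object]] = [[] for _ in chs_out]
    active: dict[int, asyncio.Task] = {}
    recv_task: asyncio.Task | None = asyncio.create_task(ch_in.get())

    while recv_task is not None or active or any(buffers):
        # Start a send task for every idle output that has buffered work.
        for idx, buf in enumerate(buffers):
            if idx not in active and buf:
                active[idx] = asyncio.create_task(chs_out[idx].put(buf.pop(0)))

        tasks = list(active.values())
        if recv_task is not None:
            tasks.append(recv_task)
        # React to whichever task (the receive or any send) finishes first.
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

        for task in done:
            if task is recv_task:
                item = task.result()
                if item is None:  # None marks end-of-input in this sketch
                    recv_task = None
                else:
                    for buf in buffers:
                        buf.append(item)  # the real node copies/spills per output
                    recv_task = asyncio.create_task(ch_in.get())
            else:
                # A send finished; free that output slot for its next buffered item.
                for idx, t in list(active.items()):
                    if t is task:
                        del active[idx]
                        break


async def main() -> None:
    ch_in: asyncio.Queue = asyncio.Queue()
    outs = [asyncio.Queue(maxsize=1) for _ in range(2)]
    for i in range(3):
        ch_in.put_nowait(i)
    ch_in.put_nowait(None)  # end-of-input marker

    async def consume(q: asyncio.Queue, name: str) -> None:
        for _ in range(3):
            print(name, await q.get())

    await asyncio.gather(
        fanout(ch_in, outs), *(consume(q, f"out{i}") for i, q in enumerate(outs))
    )


asyncio.run(main())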