cudf-polars-cu13 25.12.0__py3-none-any.whl → 26.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cudf_polars/GIT_COMMIT +1 -1
- cudf_polars/VERSION +1 -1
- cudf_polars/callback.py +28 -7
- cudf_polars/containers/column.py +51 -26
- cudf_polars/dsl/expressions/binaryop.py +1 -1
- cudf_polars/dsl/expressions/boolean.py +1 -1
- cudf_polars/dsl/expressions/selection.py +1 -1
- cudf_polars/dsl/expressions/string.py +29 -20
- cudf_polars/dsl/expressions/ternary.py +25 -1
- cudf_polars/dsl/expressions/unary.py +11 -8
- cudf_polars/dsl/ir.py +351 -281
- cudf_polars/dsl/translate.py +18 -15
- cudf_polars/dsl/utils/aggregations.py +10 -5
- cudf_polars/experimental/base.py +10 -0
- cudf_polars/experimental/benchmarks/pdsh.py +1 -1
- cudf_polars/experimental/benchmarks/utils.py +83 -2
- cudf_polars/experimental/distinct.py +2 -0
- cudf_polars/experimental/explain.py +1 -1
- cudf_polars/experimental/expressions.py +8 -5
- cudf_polars/experimental/groupby.py +2 -0
- cudf_polars/experimental/io.py +64 -42
- cudf_polars/experimental/join.py +15 -2
- cudf_polars/experimental/parallel.py +10 -7
- cudf_polars/experimental/rapidsmpf/collectives/__init__.py +9 -0
- cudf_polars/experimental/rapidsmpf/collectives/allgather.py +90 -0
- cudf_polars/experimental/rapidsmpf/collectives/common.py +96 -0
- cudf_polars/experimental/rapidsmpf/{shuffle.py → collectives/shuffle.py} +90 -114
- cudf_polars/experimental/rapidsmpf/core.py +194 -67
- cudf_polars/experimental/rapidsmpf/dask.py +172 -0
- cudf_polars/experimental/rapidsmpf/dispatch.py +6 -3
- cudf_polars/experimental/rapidsmpf/io.py +162 -70
- cudf_polars/experimental/rapidsmpf/join.py +162 -77
- cudf_polars/experimental/rapidsmpf/nodes.py +421 -180
- cudf_polars/experimental/rapidsmpf/repartition.py +130 -65
- cudf_polars/experimental/rapidsmpf/union.py +24 -5
- cudf_polars/experimental/rapidsmpf/utils.py +228 -16
- cudf_polars/experimental/shuffle.py +18 -4
- cudf_polars/experimental/sort.py +13 -6
- cudf_polars/experimental/spilling.py +1 -1
- cudf_polars/testing/plugin.py +6 -3
- cudf_polars/utils/config.py +67 -0
- cudf_polars/utils/versions.py +3 -3
- {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/METADATA +9 -10
- {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/RECORD +47 -43
- {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/WHEEL +1 -1
- {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/licenses/LICENSE +0 -0
- {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/top_level.txt +0 -0
cudf_polars/experimental/rapidsmpf/nodes.py

@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
 # SPDX-License-Identifier: Apache-2.0
 """Core node definitions for the RapidsMPF streaming runtime."""
 
@@ -7,17 +7,23 @@ from __future__ import annotations
 import asyncio
 from typing import TYPE_CHECKING, Any, cast
 
+from rapidsmpf.memory.buffer import MemoryType
 from rapidsmpf.streaming.core.message import Message
 from rapidsmpf.streaming.core.node import define_py_node
+from rapidsmpf.streaming.core.spillable_messages import SpillableMessages
 from rapidsmpf.streaming.cudf.table_chunk import TableChunk
 
 from cudf_polars.containers import DataFrame
-from cudf_polars.dsl.ir import IR, Empty
+from cudf_polars.dsl.ir import IR, Cache, Empty, Filter, Projection
 from cudf_polars.experimental.rapidsmpf.dispatch import (
     generate_ir_sub_network,
 )
 from cudf_polars.experimental.rapidsmpf.utils import (
     ChannelManager,
+    Metadata,
+    empty_table_chunk,
+    make_spill_function,
+    opaque_reservation,
     process_children,
     shutdown_on_error,
 )
@@ -37,6 +43,8 @@ async def default_node_single(
     ir_context: IRExecutionContext,
     ch_out: ChannelPair,
     ch_in: ChannelPair,
+    *,
+    preserve_partitioning: bool = False,
 ) -> None:
     """
     Single-channel default node for rapidsmpf.
@@ -53,32 +61,71 @@ async def default_node_single(
         The output ChannelPair.
     ch_in
         The input ChannelPair.
+    preserve_partitioning
+        Whether to preserve the partitioning metadata of the input chunks.
 
     Notes
     -----
     Chunks are processed in the order they are received.
     """
-    async with shutdown_on_error(
-        …
+    async with shutdown_on_error(
+        context, ch_in.metadata, ch_in.data, ch_out.metadata, ch_out.data
+    ):
+        # Recv/send metadata.
+        metadata_in = await ch_in.recv_metadata(context)
+        metadata_out = Metadata(
+            metadata_in.count,
+            partitioned_on=metadata_in.partitioned_on if preserve_partitioning else (),
+            duplicated=metadata_in.duplicated,
+        )
+        await ch_out.send_metadata(context, metadata_out)
+
+        # Recv/send data.
+        seq_num = 0
+        receiving = True
+        received_any = False
+        while receiving:
+            msg = await ch_in.data.recv(context)
+            if msg is None:
+                receiving = False
+                if received_any:
+                    break
+                else:
+                    # Make sure we have an empty chunk in case do_evaluate
+                    # always produces rows (e.g. aggregation)
+                    stream = ir_context.get_cuda_stream()
+                    chunk = empty_table_chunk(ir.children[0], context, stream)
+            else:
+                received_any = True
+                chunk = TableChunk.from_message(msg).make_available_and_spill(
+                    context.br(), allow_overbooking=True
+                )
+                seq_num = msg.sequence_number
+                del msg
+
+            input_bytes = chunk.data_alloc_size(MemoryType.DEVICE)
+            with opaque_reservation(context, input_bytes):
+                df = await asyncio.to_thread(
+                    ir.do_evaluate,
+                    *ir._non_child_args,
+                    DataFrame.from_table(
+                        chunk.table_view(),
+                        list(ir.children[0].schema.keys()),
+                        list(ir.children[0].schema.values()),
+                        chunk.stream,
+                    ),
+                    context=ir_context,
+                )
+            await ch_out.data.send(
+                context,
+                Message(
+                    seq_num,
+                    TableChunk.from_pylibcudf_table(
+                        df.table, chunk.stream, exclusive_view=True
+                    ),
+                ),
+            )
+            del df, chunk
 
         await ch_out.data.drain(context)
 
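The rewritten default_node_single follows a fixed channel protocol: one metadata message up front, data chunks until a None sentinel, then a drain. The toy sketch below mirrors that protocol with plain asyncio.Queue objects standing in for rapidsmpf channels; every name in it is a hypothetical illustration, not cudf-polars API.

import asyncio


async def node(ch_in: asyncio.Queue, ch_out: asyncio.Queue) -> None:
    # Metadata is always the first message on a channel.
    metadata = await ch_in.get()
    await ch_out.put(metadata)
    # Data messages follow until the None end-of-stream sentinel.
    while (msg := await ch_in.get()) is not None:
        await ch_out.put(msg * 2)  # stand-in for ir.do_evaluate(...)
    await ch_out.put(None)  # "drain": signal end-of-stream downstream


async def main() -> None:
    ch_in: asyncio.Queue = asyncio.Queue()
    ch_out: asyncio.Queue = asyncio.Queue()
    for item in ({"count": 2}, 10, 20, None):
        await ch_in.put(item)
    await node(ch_in, ch_out)
    print(ch_out.get_nowait())  # {'count': 2}
    print(ch_out.get_nowait(), ch_out.get_nowait())  # 20 40


asyncio.run(main())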
@@ -90,6 +137,8 @@ async def default_node_multi(
     ir_context: IRExecutionContext,
     ch_out: ChannelPair,
     chs_in: tuple[ChannelPair, ...],
+    *,
+    partitioning_index: int | None = None,
 ) -> None:
     """
     Pointwise node for rapidsmpf.
@@ -103,17 +152,30 @@ async def default_node_multi(
     ir_context
         The execution context for the IR node.
     ch_out
-        The output ChannelPair
+        The output ChannelPair.
     chs_in
-        Tuple of input ChannelPairs
-        …
-        Input chunks must be aligned for evaluation. Messages from each input
-        channel are assumed to arrive in sequence number order, so we only need
-        to hold one chunk per channel at a time.
+        Tuple of input ChannelPairs.
+    partitioning_index
+        Index of the input channel to preserve partitioning information for.
+        If None, no partitioning information is preserved.
     """
-    async with shutdown_on_error(…
+    async with shutdown_on_error(
+        context,
+        *[ch.metadata for ch in chs_in],
+        ch_out.metadata,
+        *[ch.data for ch in chs_in],
+        ch_out.data,
+    ):
+        # Merge and forward basic metadata.
+        metadata = Metadata(1)
+        for idx, ch_in in enumerate(chs_in):
+            md_child = await ch_in.recv_metadata(context)
+            metadata.count = max(md_child.count, metadata.count)
+            metadata.duplicated = metadata.duplicated and md_child.duplicated
+            if idx == partitioning_index:
+                metadata.partitioned_on = md_child.partitioned_on
+        await ch_out.send_metadata(context, metadata)
+
         seq_num = 0
         n_children = len(chs_in)
         finished_channels: set[int] = set()
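The metadata merge in default_node_multi reduces the per-input metadata to a single output: the chunk count is the maximum across inputs, the duplicated flag survives only if every input is duplicated, and partitioning survives only through the input selected by partitioning_index. A toy version of that reduction, using a simplified stand-in for the real Metadata class from cudf_polars.experimental.rapidsmpf.utils:

from dataclasses import dataclass


@dataclass
class Metadata:
    count: int                  # expected number of data chunks
    partitioned_on: tuple = ()  # keys the chunks are partitioned on
    duplicated: bool = False    # True if every rank holds a full copy


def merge(children: list[Metadata], partitioning_index: int | None) -> Metadata:
    merged = Metadata(1, duplicated=True)
    for idx, md in enumerate(children):
        merged.count = max(md.count, merged.count)
        # The output is duplicated only if *every* input is duplicated.
        merged.duplicated = merged.duplicated and md.duplicated
        if idx == partitioning_index:
            # Partitioning survives only through the designated input.
            merged.partitioned_on = md.partitioned_on
    return merged


print(merge([Metadata(4, ("a",), duplicated=True), Metadata(2)], partitioning_index=0))
# Metadata(count=4, partitioned_on=('a',), duplicated=False)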
@@ -122,11 +184,10 @@ async def default_node_multi(
         ready_chunks: list[TableChunk | None] = [None] * n_children
         chunk_count: list[int] = [0] * n_children
 
+        # Recv/send data.
         while True:
            # Receive from all non-finished channels
-            for ch_idx, …
-                zip(chs_in, ir.children, strict=True)
-            ):
+            for ch_idx, ch_in in enumerate(chs_in):
                 if ch_idx in finished_channels:
                     continue  # This channel already finished, reuse its data
 
@@ -138,19 +199,20 @@ async def default_node_multi(
             # Store the new chunk (replacing previous if any)
             ready_chunks[ch_idx] = TableChunk.from_message(msg)
             chunk_count[ch_idx] += 1
-            …
-                f"Channel {ch_idx} has no data after receive loop."
-            )
+            del msg
 
         # If all channels finished, we're done
         if len(finished_channels) == n_children:
             break
 
-        # …
-        # …
+        # Check if any channel drained without providing data.
+        # If so, create an empty chunk for that channel.
+        for ch_idx, child in enumerate(ir.children):
+            if ready_chunks[ch_idx] is None:
+                # Channel drained without data - create empty chunk
+                stream = ir_context.get_cuda_stream()
+                ready_chunks[ch_idx] = empty_table_chunk(child, context, stream)
 
         # Ensure all table chunks are unspilled and available.
         ready_chunks = [
             chunk.make_available_and_spill(context.br(), allow_overbooking=True)
@@ -166,27 +228,33 @@ async def default_node_multi(
             for chunk, child in zip(ready_chunks, ir.children, strict=True)
         ]
 
-        …
-            *ir._non_child_args,
-            *dfs,
-            context=ir_context,
-        )
-        …
+        input_bytes = sum(
+            chunk.data_alloc_size(MemoryType.DEVICE)
+            for chunk in cast(list[TableChunk], ready_chunks)
+        )
+        with opaque_reservation(context, input_bytes):
+            df = await asyncio.to_thread(
+                ir.do_evaluate,
+                *ir._non_child_args,
+                *dfs,
+                context=ir_context,
+            )
+        await ch_out.data.send(
+            context,
+            Message(
+                seq_num,
+                TableChunk.from_pylibcudf_table(
+                    df.table,
+                    df.stream,
+                    exclusive_view=True,
+                ),
+            ),
+        )
+        seq_num += 1
+        del df, dfs
 
         # Drain the output channel
+        del ready_chunks
         await ch_out.data.drain(context)
 
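Both default nodes wrap ir.do_evaluate in opaque_reservation, sized by the device footprint of the inputs. The real helper lives in cudf_polars.experimental.rapidsmpf.utils and its exact semantics are not shown in this diff; the sketch below only illustrates the assumed shape of the pattern: reserve an estimate before running opaque work, release it afterwards even on error. All names here are hypothetical stand-ins.

from contextlib import contextmanager
from typing import Iterator


class FakeBufferResource:
    """Hypothetical stand-in for a rapidsmpf buffer resource."""

    def __init__(self, budget: int) -> None:
        self.available = budget

    def reserve(self, nbytes: int) -> None:
        self.available -= nbytes  # may go negative when overbooking is allowed

    def release(self, nbytes: int) -> None:
        self.available += nbytes


@contextmanager
def opaque_reservation(br: FakeBufferResource, nbytes: int) -> Iterator[None]:
    # Reserve an estimate of the memory the wrapped work will touch ...
    br.reserve(nbytes)
    try:
        yield  # ... run the (opaque) evaluation ...
    finally:
        br.release(nbytes)  # ... and release the estimate afterwards.


br = FakeBufferResource(budget=1 << 30)
with opaque_reservation(br, 128 << 20):
    pass  # stand-in for asyncio.to_thread(ir.do_evaluate, ...)
print(br.available == 1 << 30)  # True: the reservation was released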
@@ -213,12 +281,23 @@ async def fanout_node_bounded(
     """
     # TODO: Use rapidsmpf fanout node once available.
     # See: https://github.com/rapidsai/rapidsmpf/issues/560
-    async with shutdown_on_error(…
+    async with shutdown_on_error(
+        context,
+        ch_in.metadata,
+        ch_in.data,
+        *[ch.metadata for ch in chs_out],
+        *[ch.data for ch in chs_out],
+    ):
+        # Forward metadata to all outputs.
+        metadata = await ch_in.recv_metadata(context)
+        await asyncio.gather(*(ch.send_metadata(context, metadata) for ch in chs_out))
+
         while (msg := await ch_in.data.recv(context)) is not None:
             table_chunk = TableChunk.from_message(msg).make_available_and_spill(
                 context.br(), allow_overbooking=True
             )
             seq_num = msg.sequence_number
+            del msg
             for ch_out in chs_out:
                 await ch_out.data.send(
                     context,
@@ -231,6 +310,7 @@ async def fanout_node_bounded(
                     ),
                 ),
             )
+            del table_chunk
 
 
     await asyncio.gather(*(ch.data.drain(context) for ch in chs_out))
@@ -242,14 +322,15 @@ async def fanout_node_unbounded(
     *chs_out: ChannelPair,
 ) -> None:
     """
-    Unbounded fanout node for rapidsmpf.
+    Unbounded fanout node for rapidsmpf with spilling support.
 
     Broadcasts chunks from input to all output channels. This is called
     "unbounded" because it handles the case where one channel may consume
     all data before another channel consumes any data.
 
-    The implementation uses adaptive sending:
-    - Maintains a FIFO buffer for each output channel
+    The implementation uses adaptive sending with spillable buffers:
+    - Maintains a spillable FIFO buffer for each output channel
+    - Messages are buffered in host memory (spillable to disk)
     - Sends to all channels concurrently
     - Receives next chunk as soon as any channel makes progress
     - Efficient for both balanced and imbalanced consumption patterns
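Before the full implementation below, here is a self-contained sketch of the adaptive-sending loop those bullets describe. Plain asyncio.Queue objects stand in for rapidsmpf channels and ordinary deques for the spillable buffers, so all spilling machinery is omitted; every name is illustrative.

import asyncio
from collections import deque


async def fanout(ch_in: asyncio.Queue, chs_out: list[asyncio.Queue]) -> None:
    buffers = [deque() for _ in chs_out]  # FIFO buffer per output

    async def send_one(idx: int) -> None:
        await chs_out[idx].put(buffers[idx].popleft())

    recv: asyncio.Task | None = asyncio.create_task(ch_in.get())
    active: dict[int, asyncio.Task] = {}
    can_receive = True
    while recv is not None or active or any(buffers):
        # Start a send task for every output that has buffered work.
        for idx, buf in enumerate(buffers):
            if buf and idx not in active:
                active[idx] = asyncio.create_task(send_one(idx))
        waiting = list(active.values())
        if recv is not None and can_receive:
            waiting.append(recv)
        if not waiting:  # defensive; should not happen
            break
        done, _ = await asyncio.wait(waiting, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            if task is recv:
                msg = task.result()
                if msg is None:
                    recv = None  # end of input
                else:
                    for buf in buffers:
                        buf.append(msg)  # the real node copies the chunk
                    can_receive = False  # backpressure until a send completes
                    recv = asyncio.create_task(ch_in.get())
            else:
                active = {i: t for i, t in active.items() if t is not task}
                can_receive = True  # a send completed; receive again
    for ch in chs_out:
        await ch.put(None)  # drain every output


async def main() -> None:
    ch_in: asyncio.Queue = asyncio.Queue()
    outs = [asyncio.Queue(maxsize=1) for _ in range(2)]

    async def consume(ch: asyncio.Queue, delay: float) -> list:
        seen = []
        while (m := await ch.get()) is not None:
            seen.append(m)
            await asyncio.sleep(delay)  # one consumer is slower
        return seen

    for item in (1, 2, 3, None):
        await ch_in.put(item)
    _, fast, slow = await asyncio.gather(
        fanout(ch_in, outs), consume(outs[0], 0.0), consume(outs[1], 0.01)
    )
    print(fast, slow)  # both consumers see every chunk: [1, 2, 3] [1, 2, 3]


asyncio.run(main())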
@@ -265,107 +346,182 @@ async def fanout_node_unbounded(
     """
     # TODO: Use rapidsmpf fanout node once available.
     # See: https://github.com/rapidsai/rapidsmpf/issues/560
-    async with shutdown_on_error(
-        …
-            buffer.append(message)
-        …
+    async with shutdown_on_error(
+        context,
+        ch_in.metadata,
+        ch_in.data,
+        *[ch.metadata for ch in chs_out],
+        *[ch.data for ch in chs_out],
+    ):
+        # Forward metadata to all outputs.
+        metadata = await ch_in.recv_metadata(context)
+        await asyncio.gather(*(ch.send_metadata(context, metadata) for ch in chs_out))
+
+        # Spillable FIFO buffer for each output channel
+        output_buffers: list[SpillableMessages] = [SpillableMessages() for _ in chs_out]
+        num_outputs = len(chs_out)
+
+        # Track message IDs in FIFO order for each output buffer
+        buffer_ids: list[list[int]] = [[] for _ in chs_out]
+
+        # Register a single spill function for all buffers
+        # This ensures global FIFO ordering when spilling across all outputs
+        spill_func_id = context.br().spill_manager.add_spill_function(
+            make_spill_function(output_buffers, context), priority=0
+        )
+
+        try:
+            # Track active send/drain tasks for each output
+            active_tasks: dict[int, asyncio.Task] = {}
+
+            # Track which outputs need to be drained (set when no more input)
+            needs_drain: set[int] = set()
+
+            # Receive task
+            recv_task: asyncio.Task | None = asyncio.create_task(
+                ch_in.data.recv(context)
+            )
+
+            # Flag to indicate we should start a new receive (for backpressure)
+            can_receive: bool = True
+
+            async def send_one_from_buffer(idx: int) -> None:
+                """
+                Send one buffered message for output idx.
+
+                The message remains in host memory (spillable) through the channel.
+                The downstream consumer will call make_available() when needed.
+                """
+                if buffer_ids[idx]:
+                    mid = buffer_ids[idx].pop(0)
+                    msg = output_buffers[idx].extract(mid=mid)
+                    await chs_out[idx].data.send(context, msg)
+
+            async def drain_output(idx: int) -> None:
+                """Drain output channel idx."""
+                await chs_out[idx].data.drain(context)
+
+            # Main loop: coordinate receiving, sending, and draining
+            while (
+                recv_task is not None or active_tasks or any(buffer_ids) or needs_drain
+            ):
+                # Collect all currently active tasks
+                tasks_to_wait = list(active_tasks.values())
+                # Only include recv_task if we're allowed to receive
+                if recv_task is not None and can_receive:
+                    tasks_to_wait.append(recv_task)
+
+                # Start new tasks for outputs with work to do
+                for idx in range(len(chs_out)):
+                    if idx not in active_tasks:
+                        if buffer_ids[idx]:
+                            # Send next buffered message
+                            task = asyncio.create_task(send_one_from_buffer(idx))
+                            active_tasks[idx] = task
+                            tasks_to_wait.append(task)
+                        elif idx in needs_drain:
+                            # Buffer empty and no more input - drain this output
+                            task = asyncio.create_task(drain_output(idx))
+                            active_tasks[idx] = task
+                            tasks_to_wait.append(task)
+                            needs_drain.discard(idx)
+
+                # If nothing to wait for, we're done
+                if not tasks_to_wait:
+                    break
+
+                # Wait for ANY task to complete
+                done, _ = await asyncio.wait(
+                    tasks_to_wait, return_when=asyncio.FIRST_COMPLETED
+                )
+
+                # Process completed tasks
+                for task in done:
+                    if task is recv_task:
+                        # Receive completed
+                        msg = task.result()
+                        if msg is None:
+                            # End of input - mark all outputs as needing drain
+                            recv_task = None
+                            needs_drain.update(range(len(chs_out)))
+                        else:
+                            # Determine where to copy based on:
+                            # 1. Current message location (avoid unnecessary transfers)
+                            # 2. Available memory (avoid OOM)
+                            content_desc = msg.get_content_description()
+                            device_size = content_desc.content_sizes.get(
+                                MemoryType.DEVICE, 0
+                            )
+                            copy_cost = msg.copy_cost()
+
+                            # Check if we have enough device memory for all copies
+                            # We need (num_outputs - 1) copies since last one reuses original
+                            num_copies = num_outputs - 1
+                            total_copy_cost = copy_cost * num_copies
+                            available_device_mem = context.br().memory_available(
+                                MemoryType.DEVICE
+                            )
+
+                            # Decide target memory:
+                            # Use device ONLY if message is in device AND we have sufficient headroom.
+                            # TODO: Use further information about the downstream operations to make
+                            # a more informed decision.
+                            required_headroom = total_copy_cost * 2
+                            if (
+                                device_size > 0
+                                and available_device_mem >= required_headroom
+                            ):
+                                # Use reserve_device_memory_and_spill to automatically trigger spilling
+                                # if needed to make room for the copy
+                                memory_reservation = (
+                                    context.br().reserve_device_memory_and_spill(
+                                        total_copy_cost,
+                                        allow_overbooking=True,
+                                    )
+                                )
+                            else:
+                                # Use host memory for buffering - much safer
+                                # Downstream consumers will make_available() when they need device memory
+                                memory_reservation, _ = context.br().reserve(
+                                    MemoryType.HOST,
+                                    total_copy_cost,
+                                    allow_overbooking=True,
+                                )
+
+                            # Copy message for each output buffer
+                            # Copies are spillable and allow downstream consumers
+                            # to control device memory allocation
+                            for idx, sm in enumerate(output_buffers):
+                                if idx < num_outputs - 1:
+                                    # Copy to target memory and insert into spillable buffer
+                                    mid = sm.insert(msg.copy(memory_reservation))
+                                else:
+                                    # Optimization: reuse the original message for last output
+                                    # (no copy needed)
+                                    mid = sm.insert(msg)
+                                buffer_ids[idx].append(mid)
+
+                            # Don't receive next chunk until at least one send completes
+                            can_receive = False
+                            recv_task = asyncio.create_task(ch_in.data.recv(context))
+                    else:
+                        # Must be a send or drain task - find which output and remove it
+                        for idx, at in list(active_tasks.items()):
+                            if at is task:
+                                del active_tasks[idx]
+                                # A send completed - allow receiving again
+                                can_receive = True
+                                break
+        finally:
+            # Clean up spill function registration
+            context.br().spill_manager.remove_spill_function(spill_func_id)
 
 
 @generate_ir_sub_network.register(IR)
-def _(ir: IR, rec: SubNetGenerator) -> tuple[list[Any], dict[IR, ChannelManager]]:
+def _(
+    ir: IR, rec: SubNetGenerator
+) -> tuple[dict[IR, list[Any]], dict[IR, ChannelManager]]:
     # Default generate_ir_sub_network logic.
     # Use simple pointwise node.
 
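To make the headroom rule above concrete: with three consumers, a device-resident chunk needs two copies (the original is reused for the last output), and device buffering is chosen only when free device memory is at least twice the total copy cost; otherwise the copies land in spillable host memory. The sizes below are hypothetical.

MiB = 1 << 20
copy_cost = 256 * MiB  # size of one device-resident chunk
num_outputs = 3

num_copies = num_outputs - 1               # last output reuses the original
total_copy_cost = copy_cost * num_copies   # 512 MiB
required_headroom = total_copy_cost * 2    # 1024 MiB

for available in (2048 * MiB, 768 * MiB):
    target = "device" if available >= required_headroom else "host"
    print(f"{available // MiB:5d} MiB free -> buffer copies in {target} memory")
# 2048 MiB free -> buffer copies in device memory
#  768 MiB free -> buffer copies in host memory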
@@ -377,18 +533,27 @@ def _(ir: IR, rec: SubNetGenerator) -> tuple[list[Any], dict[IR, ChannelManager]
 
     if len(ir.children) == 1:
         # Single-channel default node
-        …
+        preserve_partitioning = isinstance(
+            # TODO: We don't need to worry about
+            # non-pointwise Filter operations here,
+            # because the lowering stage would have
+            # collapsed to one partition anyway.
+            ir,
+            (Cache, Projection, Filter),
+        )
+        nodes[ir] = [
             default_node_single(
                 rec.state["context"],
                 ir,
                 rec.state["ir_context"],
                 channels[ir].reserve_input_slot(),
                 channels[ir.children[0]].reserve_output_slot(),
+                preserve_partitioning=preserve_partitioning,
             )
-        …
+        ]
     else:
         # Multi-channel default node
-        nodes…
+        nodes[ir] = [
             default_node_multi(
                 rec.state["context"],
                 ir,
@@ -396,7 +561,7 @@ def _(ir: IR, rec: SubNetGenerator) -> tuple[list[Any], dict[IR, ChannelManager]
                 channels[ir].reserve_input_slot(),
                 tuple(channels[c].reserve_output_slot() for c in ir.children),
             )
-        …
+        ]
 
     return nodes, channels
 
@@ -422,7 +587,10 @@ async def empty_node(
     ch_out
         The output ChannelPair.
     """
-    async with shutdown_on_error(context, ch_out.data):
+    async with shutdown_on_error(context, ch_out.metadata, ch_out.data):
+        # Send metadata indicating a single empty chunk
+        await ch_out.send_metadata(context, Metadata(1, duplicated=True))
+
         # Evaluate the IR node to create an empty DataFrame
         df: DataFrame = ir.do_evaluate(*ir._non_child_args, context=ir_context)
@@ -436,20 +604,22 @@
 
 
 @generate_ir_sub_network.register(Empty)
-def _(…
+def _(
+    ir: Empty, rec: SubNetGenerator
+) -> tuple[dict[IR, list[Any]], dict[IR, ChannelManager]]:
     """Generate network for Empty node - produces one empty chunk."""
     context = rec.state["context"]
     ir_context = rec.state["ir_context"]
     channels: dict[IR, ChannelManager] = {ir: ChannelManager(rec.state["context"])}
-    nodes: list[Any] = …
-        empty_node(context, ir, ir_context, channels[ir].reserve_input_slot())
-    …
+    nodes: dict[IR, list[Any]] = {
+        ir: [empty_node(context, ir, ir_context, channels[ir].reserve_input_slot())]
+    }
     return nodes, channels
 
 
 def generate_ir_sub_network_wrapper(
     ir: IR, rec: SubNetGenerator
-) -> tuple[list[Any], dict[IR, ChannelManager]]:
+) -> tuple[dict[IR, list[Any]], dict[IR, ChannelManager]]:
     """
     Generate a sub-network for the RapidsMPF streaming runtime.
 
@@ -463,7 +633,7 @@ def generate_ir_sub_network_wrapper(
     Returns
     -------
     nodes
-        …
+        Dictionary mapping each IR node to its list of streaming-network node(s).
     channels
         Dictionary mapping between each IR node and its
         corresponding streaming-network output ChannelManager.
@@ -474,21 +644,92 @@ def generate_ir_sub_network_wrapper(
     if (fanout_info := rec.state["fanout_nodes"].get(ir)) is not None:
         count = fanout_info.num_consumers
         manager = ChannelManager(rec.state["context"], count=count)
+        fanout_node: Any
         if fanout_info.unbounded:
-            …
-                *[manager.reserve_input_slot() for _ in range(count)],
-            )
+            fanout_node = fanout_node_unbounded(
+                rec.state["context"],
+                channels[ir].reserve_output_slot(),
+                *[manager.reserve_input_slot() for _ in range(count)],
             )
         else:  # "bounded"
-            …
-                *[manager.reserve_input_slot() for _ in range(count)],
-            )
+            fanout_node = fanout_node_bounded(
+                rec.state["context"],
+                channels[ir].reserve_output_slot(),
+                *[manager.reserve_input_slot() for _ in range(count)],
             )
+        nodes[ir].append(fanout_node)
         channels[ir] = manager
     return nodes, channels
+
+
+@define_py_node()
+async def metadata_feeder_node(
+    context: Context,
+    channel: ChannelPair,
+    metadata: Metadata,
+) -> None:
+    """
+    Feed metadata to a channel pair.
+
+    Parameters
+    ----------
+    context
+        The rapidsmpf context.
+    channel
+        The channel pair.
+    metadata
+        The metadata to feed.
+    """
+    async with shutdown_on_error(context, channel.metadata, channel.data):
+        await channel.send_metadata(context, metadata)
+
+
+@define_py_node()
+async def metadata_drain_node(
+    context: Context,
+    ir: IR,
+    ir_context: IRExecutionContext,
+    ch_in: ChannelPair,
+    ch_out: Any,
+    metadata_collector: list[Metadata] | None,
+) -> None:
+    """
+    Drain metadata and forward data to a single channel.
+
+    Parameters
+    ----------
+    context
+        The rapidsmpf context.
+    ir
+        The IR node.
+    ir_context
+        The execution context for the IR node.
+    ch_in
+        The input ChannelPair (with metadata and data channels).
+    ch_out
+        The output data channel.
+    metadata_collector
+        The list to collect the final metadata.
+        This list will be mutated when the network is executed.
+        If None, metadata will not be collected.
+    """
+    async with shutdown_on_error(context, ch_in.metadata, ch_in.data, ch_out):
+        # Drain metadata channel (we don't need it after this point)
+        metadata = await ch_in.recv_metadata(context)
+        send_empty = metadata.duplicated and context.comm().rank != 0
+        if metadata_collector is not None:
+            metadata_collector.append(metadata)
+
+        # Forward non-duplicated data messages
+        while (msg := await ch_in.data.recv(context)) is not None:
+            if not send_empty:
+                await ch_out.send(context, msg)
+
+        # Send empty data if needed
+        if send_empty:
+            stream = ir_context.get_cuda_stream()
+            await ch_out.send(
+                context, Message(0, empty_table_chunk(ir, context, stream))
+            )
+
+        await ch_out.drain(context)