cudf-polars-cu13 25.10.0__py3-none-any.whl → 26.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cudf_polars/GIT_COMMIT +1 -1
- cudf_polars/VERSION +1 -1
- cudf_polars/callback.py +60 -15
- cudf_polars/containers/column.py +137 -77
- cudf_polars/containers/dataframe.py +123 -34
- cudf_polars/containers/datatype.py +134 -13
- cudf_polars/dsl/expr.py +0 -2
- cudf_polars/dsl/expressions/aggregation.py +80 -28
- cudf_polars/dsl/expressions/binaryop.py +34 -14
- cudf_polars/dsl/expressions/boolean.py +110 -37
- cudf_polars/dsl/expressions/datetime.py +59 -30
- cudf_polars/dsl/expressions/literal.py +11 -5
- cudf_polars/dsl/expressions/rolling.py +460 -119
- cudf_polars/dsl/expressions/selection.py +9 -8
- cudf_polars/dsl/expressions/slicing.py +1 -1
- cudf_polars/dsl/expressions/string.py +256 -114
- cudf_polars/dsl/expressions/struct.py +19 -7
- cudf_polars/dsl/expressions/ternary.py +33 -3
- cudf_polars/dsl/expressions/unary.py +126 -64
- cudf_polars/dsl/ir.py +1053 -350
- cudf_polars/dsl/to_ast.py +30 -13
- cudf_polars/dsl/tracing.py +194 -0
- cudf_polars/dsl/translate.py +307 -107
- cudf_polars/dsl/utils/aggregations.py +43 -30
- cudf_polars/dsl/utils/reshape.py +14 -2
- cudf_polars/dsl/utils/rolling.py +12 -8
- cudf_polars/dsl/utils/windows.py +35 -20
- cudf_polars/experimental/base.py +55 -2
- cudf_polars/experimental/benchmarks/pdsds.py +12 -126
- cudf_polars/experimental/benchmarks/pdsh.py +792 -2
- cudf_polars/experimental/benchmarks/utils.py +596 -39
- cudf_polars/experimental/dask_registers.py +47 -20
- cudf_polars/experimental/dispatch.py +9 -3
- cudf_polars/experimental/distinct.py +2 -0
- cudf_polars/experimental/explain.py +15 -2
- cudf_polars/experimental/expressions.py +30 -15
- cudf_polars/experimental/groupby.py +25 -4
- cudf_polars/experimental/io.py +156 -124
- cudf_polars/experimental/join.py +53 -23
- cudf_polars/experimental/parallel.py +68 -19
- cudf_polars/experimental/rapidsmpf/__init__.py +8 -0
- cudf_polars/experimental/rapidsmpf/collectives/__init__.py +9 -0
- cudf_polars/experimental/rapidsmpf/collectives/allgather.py +90 -0
- cudf_polars/experimental/rapidsmpf/collectives/common.py +96 -0
- cudf_polars/experimental/rapidsmpf/collectives/shuffle.py +253 -0
- cudf_polars/experimental/rapidsmpf/core.py +488 -0
- cudf_polars/experimental/rapidsmpf/dask.py +172 -0
- cudf_polars/experimental/rapidsmpf/dispatch.py +153 -0
- cudf_polars/experimental/rapidsmpf/io.py +696 -0
- cudf_polars/experimental/rapidsmpf/join.py +322 -0
- cudf_polars/experimental/rapidsmpf/lower.py +74 -0
- cudf_polars/experimental/rapidsmpf/nodes.py +735 -0
- cudf_polars/experimental/rapidsmpf/repartition.py +216 -0
- cudf_polars/experimental/rapidsmpf/union.py +115 -0
- cudf_polars/experimental/rapidsmpf/utils.py +374 -0
- cudf_polars/experimental/repartition.py +9 -2
- cudf_polars/experimental/select.py +177 -14
- cudf_polars/experimental/shuffle.py +46 -12
- cudf_polars/experimental/sort.py +100 -26
- cudf_polars/experimental/spilling.py +1 -1
- cudf_polars/experimental/statistics.py +24 -5
- cudf_polars/experimental/utils.py +25 -7
- cudf_polars/testing/asserts.py +13 -8
- cudf_polars/testing/io.py +2 -1
- cudf_polars/testing/plugin.py +93 -17
- cudf_polars/typing/__init__.py +86 -32
- cudf_polars/utils/config.py +473 -58
- cudf_polars/utils/cuda_stream.py +70 -0
- cudf_polars/utils/versions.py +5 -4
- cudf_polars_cu13-26.2.0.dist-info/METADATA +181 -0
- cudf_polars_cu13-26.2.0.dist-info/RECORD +108 -0
- {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/WHEEL +1 -1
- cudf_polars_cu13-25.10.0.dist-info/METADATA +0 -136
- cudf_polars_cu13-25.10.0.dist-info/RECORD +0 -92
- {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/licenses/LICENSE +0 -0
- {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/top_level.txt +0 -0

cudf_polars/experimental/rapidsmpf/repartition.py
@@ -0,0 +1,216 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+"""Re-chunking logic for the RapidsMPF streaming runtime."""
+
+from __future__ import annotations
+
+import math
+from typing import TYPE_CHECKING, Any
+
+from rapidsmpf.memory.buffer import MemoryType
+from rapidsmpf.streaming.core.message import Message
+from rapidsmpf.streaming.core.node import define_py_node
+from rapidsmpf.streaming.cudf.table_chunk import TableChunk
+
+from cudf_polars.containers import DataFrame
+from cudf_polars.experimental.rapidsmpf.collectives.allgather import AllGatherManager
+from cudf_polars.experimental.rapidsmpf.dispatch import generate_ir_sub_network
+from cudf_polars.experimental.rapidsmpf.nodes import shutdown_on_error
+from cudf_polars.experimental.rapidsmpf.utils import (
+    ChannelManager,
+    Metadata,
+    empty_table_chunk,
+    opaque_reservation,
+)
+from cudf_polars.experimental.repartition import Repartition
+from cudf_polars.experimental.utils import _concat
+
+if TYPE_CHECKING:
+    from rapidsmpf.streaming.core.context import Context
+
+    from cudf_polars.dsl.ir import IR, IRExecutionContext
+    from cudf_polars.experimental.rapidsmpf.dispatch import SubNetGenerator
+    from cudf_polars.experimental.rapidsmpf.utils import ChannelPair
+
+
+@define_py_node()
+async def concatenate_node(
+    context: Context,
+    ir: Repartition,
+    ir_context: IRExecutionContext,
+    ch_out: ChannelPair,
+    ch_in: ChannelPair,
+    *,
+    output_count: int,
+    collective_id: int,
+) -> None:
+    """
+    Concatenate node for rapidsmpf.
+
+    Parameters
+    ----------
+    context
+        The rapidsmpf context.
+    ir
+        The Repartition IR node.
+    ir_context
+        The execution context for the IR node.
+    ch_out
+        The output ChannelPair.
+    ch_in
+        The input ChannelPair.
+    output_count
+        The expected number of output chunks.
+    collective_id
+        Pre-allocated collective ID for this operation.
+    """
+    async with shutdown_on_error(
+        context, ch_in.metadata, ch_in.data, ch_out.metadata, ch_out.data
+    ):
+        # Receive metadata.
+        input_metadata = await ch_in.recv_metadata(context)
+        metadata = Metadata(output_count)
+
+        # max_chunks corresponds to the number of chunks we can
+        # concatenate together. If None, we must concatenate everything.
+        # Since a single-partition operation gets "special treatment",
+        # we must make sure `output_count == 1` is always satisfied.
+        max_chunks: int | None = None
+        if output_count > 1:
+            # Make sure max_chunks is at least 2.
+            max_chunks = max(2, math.ceil(input_metadata.count / output_count))
+
+        # Check if we need global communication.
+        need_global_repartition = (
+            # Avoid allgather of already-duplicated data
+            context.comm().nranks > 1
+            and not input_metadata.duplicated
+            and output_count == 1
+        )
+
+        chunks: list[TableChunk]
+        msg: TableChunk | None
+        if need_global_repartition:
+            # Assume this means "global repartitioning" for now
+
+            # Send metadata.
+            metadata.duplicated = True
+            await ch_out.send_metadata(context, metadata)
+
+            allgather = AllGatherManager(context, collective_id)
+            stream = context.get_stream_from_pool()
+            seq_num = 0
+            while (msg := await ch_in.data.recv(context)) is not None:
+                allgather.insert(seq_num, TableChunk.from_message(msg))
+                seq_num += 1
+                del msg
+            allgather.insert_finished()
+
+            # Extract concatenated result
+            result_table = await allgather.extract_concatenated(stream)
+
+            # If no chunks were gathered, result_table has 0 columns.
+            # We need to create an empty table with the correct schema.
+            if result_table.num_columns() == 0 and len(ir.schema) > 0:
+                output_chunk = empty_table_chunk(ir, context, stream)
+            else:
+                output_chunk = TableChunk.from_pylibcudf_table(
+                    result_table, stream, exclusive_view=True
+                )
+
+            await ch_out.data.send(context, Message(0, output_chunk))
+        else:
+            # Send metadata.
+            metadata.duplicated = input_metadata.duplicated
+            await ch_out.send_metadata(context, metadata)
+
+            # Local repartitioning
+            seq_num = 0
+            while True:
+                chunks = []
+                done_receiving = False
+
+                # Collect chunks up to max_chunks or until end of stream
+                while len(chunks) < (max_chunks or float("inf")):
+                    msg = await ch_in.data.recv(context)
+                    if msg is None:
+                        done_receiving = True
+                        break
+                    chunks.append(
+                        TableChunk.from_message(msg).make_available_and_spill(
+                            context.br(), allow_overbooking=True
+                        )
+                    )
+                    del msg
+
+                if chunks:
+                    input_bytes = sum(
+                        chunk.data_alloc_size(MemoryType.DEVICE) for chunk in chunks
+                    )
+                    with opaque_reservation(context, input_bytes):
+                        df = _concat(
+                            *(
+                                DataFrame.from_table(
+                                    chunk.table_view(),
+                                    list(ir.schema.keys()),
+                                    list(ir.schema.values()),
+                                    chunk.stream,
+                                )
+                                for chunk in chunks
+                            ),
+                            context=ir_context,
+                        )
+                    await ch_out.data.send(
+                        context,
+                        Message(
+                            seq_num,
+                            TableChunk.from_pylibcudf_table(
+                                df.table, df.stream, exclusive_view=True
+                            ),
+                        ),
+                    )
+                    seq_num += 1
+                    del df, chunks
+
+                # Break if we reached end of stream
+                if done_receiving:
+                    break
+
+        await ch_out.data.drain(context)
+
+
+@generate_ir_sub_network.register(Repartition)
+def _(
+    ir: Repartition, rec: SubNetGenerator
+) -> tuple[dict[IR, list[Any]], dict[IR, ChannelManager]]:
+    # Repartition node.
+
+    partition_info = rec.state["partition_info"]
+    if partition_info[ir].count > 1:
+        count_output = partition_info[ir].count
+        count_input = partition_info[ir.children[0]].count
+        if count_input < count_output:
+            raise ValueError("Repartitioning to more chunks is not supported.")
+
+    # Process children
+    nodes, channels = rec(ir.children[0])
+
+    # Create output ChannelManager
+    channels[ir] = ChannelManager(rec.state["context"])
+
+    # Look up the reserved shuffle ID for this operation
+    collective_id = rec.state["collective_id_map"][ir]
+
+    # Add python node
+    nodes[ir] = [
+        concatenate_node(
+            rec.state["context"],
+            ir,
+            rec.state["ir_context"],
+            channels[ir].reserve_input_slot(),
+            channels[ir.children[0]].reserve_output_slot(),
+            output_count=partition_info[ir].count,
+            collective_id=collective_id,
+        )
+    ]
+    return nodes, channels
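
The chunk-budget arithmetic near the top of `concatenate_node` is easy to misread, so here is the same rule as a standalone sketch in plain Python. The helper name `max_chunks_per_output` is ours, purely for illustration; only the `max(2, math.ceil(...))` expression comes from the hunk above.

import math

def max_chunks_per_output(input_count: int, output_count: int) -> int | None:
    # None means "concatenate everything into one chunk": the
    # output_count == 1 case gets special treatment in the node.
    if output_count <= 1:
        return None
    # At least 2, so every pass actually reduces the chunk count.
    return max(2, math.ceil(input_count / output_count))

# 10 input chunks -> 4 outputs: up to ceil(10/4) = 3 chunks per output.
assert max_chunks_per_output(10, 4) == 3
# 10 -> 8: the floor of 2 still forces pairwise concatenation.
assert max_chunks_per_output(10, 8) == 2
assert max_chunks_per_output(10, 1) is None
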
cudf_polars/experimental/rapidsmpf/union.py
@@ -0,0 +1,115 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+"""Union logic for the RapidsMPF streaming runtime."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+from rapidsmpf.streaming.core.message import Message
+from rapidsmpf.streaming.cudf.table_chunk import TableChunk
+
+from cudf_polars.dsl.ir import Union
+from cudf_polars.experimental.rapidsmpf.dispatch import (
+    generate_ir_sub_network,
+)
+from cudf_polars.experimental.rapidsmpf.nodes import define_py_node, shutdown_on_error
+from cudf_polars.experimental.rapidsmpf.utils import (
+    ChannelManager,
+    Metadata,
+    process_children,
+)
+
+if TYPE_CHECKING:
+    from rapidsmpf.streaming.core.context import Context
+
+    from cudf_polars.dsl.ir import IR, IRExecutionContext
+    from cudf_polars.experimental.rapidsmpf.core import SubNetGenerator
+    from cudf_polars.experimental.rapidsmpf.utils import ChannelPair
+
+
+@define_py_node()
+async def union_node(
+    context: Context,
+    ir: Union,
+    ir_context: IRExecutionContext,
+    ch_out: ChannelPair,
+    *chs_in: ChannelPair,
+) -> None:
+    """
+    Union node for rapidsmpf.
+
+    Parameters
+    ----------
+    context
+        The rapidsmpf context.
+    ir
+        The Union IR node.
+    ir_context
+        The execution context for the IR node.
+    ch_out
+        The output ChannelPair.
+    chs_in
+        The input ChannelPairs.
+    """
+    async with shutdown_on_error(
+        context,
+        *[ch.metadata for ch in chs_in],
+        *[ch.data for ch in chs_in],
+        ch_out.metadata,
+        ch_out.data,
+    ):
+        # Merge and forward metadata.
+        total_count = 0
+        duplicated = True
+        for ch_in in chs_in:
+            metadata = await ch_in.recv_metadata(context)
+            total_count += metadata.count
+            duplicated = duplicated and metadata.duplicated
+        await ch_out.send_metadata(
+            context, Metadata(total_count, duplicated=duplicated)
+        )
+
+        seq_num_offset = 0
+        for ch_in in chs_in:
+            num_ch_chunks = 0
+            while (msg := await ch_in.data.recv(context)) is not None:
+                num_ch_chunks += 1
+                await ch_out.data.send(
+                    context,
+                    Message(
+                        msg.sequence_number + seq_num_offset,
+                        TableChunk.from_message(msg).make_available_and_spill(
+                            context.br(), allow_overbooking=True
+                        ),
+                    ),
+                )
+            seq_num_offset += num_ch_chunks
+
+        await ch_out.data.drain(context)
+
+
+@generate_ir_sub_network.register(Union)
+def _(
+    ir: Union, rec: SubNetGenerator
+) -> tuple[dict[IR, list[Any]], dict[IR, ChannelManager]]:
+    # Union operation.
+    # Pass-through all child chunks in channel order.
+
+    # Process children
+    nodes, channels = process_children(ir, rec)
+
+    # Create output ChannelManager
+    channels[ir] = ChannelManager(rec.state["context"])
+
+    # Add simple python node
+    nodes[ir] = [
+        union_node(
+            rec.state["context"],
+            ir,
+            rec.state["ir_context"],
+            channels[ir].reserve_input_slot(),
+            *[channels[c].reserve_output_slot() for c in ir.children],
+        )
+    ]
+    return nodes, channels
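
The subtle part of `union_node` is the sequence-number bookkeeping: chunks from each input channel keep their relative order but are renumbered to follow every chunk already forwarded from earlier channels. A minimal sketch with plain lists standing in for channels (the function name is invented for illustration; only the offset logic mirrors the hunk above):

def union_sequence_numbers(per_channel: list[list[int]]) -> list[int]:
    # Mirrors the seq_num_offset loop in union_node.
    out: list[int] = []
    seq_num_offset = 0
    for seq_nums in per_channel:
        num_ch_chunks = 0
        for seq_num in seq_nums:
            num_ch_chunks += 1
            out.append(seq_num + seq_num_offset)
        seq_num_offset += num_ch_chunks
    return out

# Two inputs with 3 and 2 chunks: the second input's chunks are
# shifted past the first input's, giving one ordered output stream.
assert union_sequence_numbers([[0, 1, 2], [0, 1]]) == [0, 1, 2, 3, 4]
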
cudf_polars/experimental/rapidsmpf/utils.py
@@ -0,0 +1,374 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+"""Utility functions and classes for the RapidsMPF streaming runtime."""
+
+from __future__ import annotations
+
+import asyncio
+import operator
+from contextlib import asynccontextmanager, contextmanager
+from dataclasses import dataclass
+from functools import reduce
+from typing import TYPE_CHECKING, Any
+
+from rapidsmpf.streaming.chunks.arbitrary import ArbitraryChunk
+from rapidsmpf.streaming.core.message import Message
+from rapidsmpf.streaming.cudf.table_chunk import TableChunk
+
+import pylibcudf as plc
+
+from cudf_polars.containers import DataFrame
+
+if TYPE_CHECKING:
+    from collections.abc import AsyncIterator, Callable, Iterator
+
+    from rapidsmpf.memory.memory_reservation import MemoryReservation
+    from rapidsmpf.streaming.core.channel import Channel
+    from rapidsmpf.streaming.core.context import Context
+    from rapidsmpf.streaming.core.spillable_messages import SpillableMessages
+
+    from rmm.pylibrmm.stream import Stream
+
+    from cudf_polars.dsl.ir import IR
+    from cudf_polars.experimental.rapidsmpf.dispatch import SubNetGenerator
+
+
+@asynccontextmanager
+async def shutdown_on_error(
+    context: Context, *channels: Channel[Any]
+) -> AsyncIterator[None]:
+    """
+    Shutdown on error for rapidsmpf.
+
+    Parameters
+    ----------
+    context
+        The rapidsmpf context.
+    channels
+        The channels to shutdown.
+    """
+    # TODO: This probably belongs in rapidsmpf.
+    try:
+        yield
+    except BaseException:
+        await asyncio.gather(*(ch.shutdown(context) for ch in channels))
+        raise
+
+
+class Metadata:
+    """Metadata payload for an individual ChannelPair."""
+
+    __slots__ = ("count", "duplicated", "partitioned_on")
+    count: int
+    """Chunk-count estimate."""
+    partitioned_on: tuple[str, ...]
+    """Partitioned-on columns."""
+    duplicated: bool
+    """Whether the data is duplicated on all workers."""
+
+    def __init__(
+        self,
+        count: int,
+        *,
+        partitioned_on: tuple[str, ...] = (),
+        duplicated: bool = False,
+    ):
+        self.count = count
+        self.partitioned_on = partitioned_on
+        self.duplicated = duplicated
+
+
+@dataclass
+class ChannelPair:
+    """
+    A pair of channels for metadata and table data.
+
+    This abstraction ensures that metadata and data are kept separate,
+    avoiding ordering issues and making the code more type-safe.
+
+    Attributes
+    ----------
+    metadata :
+        Channel for metadata.
+    data :
+        Channel for table data chunks.
+
+    Notes
+    -----
+    This is a placeholder implementation. The metadata channel exists
+    but is not used yet. Metadata handling will be fully implemented
+    in follow-up work.
+    """
+
+    metadata: Channel[ArbitraryChunk]
+    data: Channel[TableChunk]
+
+    @classmethod
+    def create(cls, context: Context) -> ChannelPair:
+        """Create a new ChannelPair with fresh channels."""
+        return cls(
+            metadata=context.create_channel(),
+            data=context.create_channel(),
+        )
+
+    async def send_metadata(self, ctx: Context, metadata: Metadata) -> None:
+        """
+        Send metadata and drain the metadata channel.
+
+        Parameters
+        ----------
+        ctx :
+            The streaming context.
+        metadata :
+            The metadata to send.
+        """
+        msg = Message(0, ArbitraryChunk(metadata))
+        await self.metadata.send(ctx, msg)
+        await self.metadata.drain(ctx)
+
+    async def recv_metadata(self, ctx: Context) -> Metadata:
+        """
+        Receive metadata from the metadata channel.
+
+        Parameters
+        ----------
+        ctx :
+            The streaming context.
+
+        Returns
+        -------
+        ChunkMetadata
+            The metadata, or None if channel is drained.
+        """
+        msg = await self.metadata.recv(ctx)
+        assert msg is not None, f"Expected Metadata message, got {msg}."
+        return ArbitraryChunk.from_message(msg).release()
+
+
+class ChannelManager:
+    """A utility class for managing ChannelPair objects."""
+
+    def __init__(self, context: Context, *, count: int = 1):
+        """
+        Initialize the ChannelManager with a given number of ChannelPair slots.
+
+        Parameters
+        ----------
+        context
+            The rapidsmpf context.
+        count: int
+            The number of ChannelPair slots to allocate.
+        """
+        self._channel_slots = [ChannelPair.create(context) for _ in range(count)]
+        self._reserved_output_slots: int = 0
+        self._reserved_input_slots: int = 0
+
+    def reserve_input_slot(self) -> ChannelPair:
+        """
+        Reserve an input channel-pair slot.
+
+        Returns
+        -------
+        The reserved ChannelPair.
+        """
+        if self._reserved_input_slots >= len(self._channel_slots):
+            raise ValueError("No more input channel-pair slots available")
+        pair = self._channel_slots[self._reserved_input_slots]
+        self._reserved_input_slots += 1
+        return pair
+
+    def reserve_output_slot(self) -> ChannelPair:
+        """
+        Reserve an output channel-pair slot.
+
+        Returns
+        -------
+        The reserved ChannelPair.
+        """
+        if self._reserved_output_slots >= len(self._channel_slots):
+            raise ValueError("No more output channel-pair slots available")
+        pair = self._channel_slots[self._reserved_output_slots]
+        self._reserved_output_slots += 1
+        return pair
+
+
+def process_children(
+    ir: IR, rec: SubNetGenerator
+) -> tuple[dict[IR, list[Any]], dict[IR, ChannelManager]]:
+    """
+    Process children IR nodes and aggregate their nodes and channels.
+
+    This helper function recursively processes all children of an IR node,
+    collects their streaming network nodes into a dictionary mapping IR nodes
+    to their associated nodes, and merges their channel dictionaries.
+
+    Parameters
+    ----------
+    ir
+        The IR node whose children should be processed.
+    rec
+        Recursive SubNetGenerator callable.
+
+    Returns
+    -------
+    nodes
+        Dictionary mapping each IR node to its list of streaming network nodes.
+    channels
+        Dictionary mapping each child IR node to its ChannelManager.
+    """
+    if not ir.children:
+        return {}, {}
+
+    _nodes_list, _channels_list = zip(*(rec(c) for c in ir.children), strict=True)
+    nodes: dict[IR, list[Any]] = reduce(operator.or_, _nodes_list)
+    channels: dict[IR, ChannelManager] = reduce(operator.or_, _channels_list)
+    return nodes, channels
+
+
+def empty_table_chunk(ir: IR, context: Context, stream: Stream) -> TableChunk:
+    """
+    Make an empty table chunk.
+
+    Parameters
+    ----------
+    ir
+        The IR node to use for the schema.
+    context
+        The rapidsmpf context.
+    stream
+        The stream to use for the table chunk.
+
+    Returns
+    -------
+    The empty table chunk.
+    """
+    # Create an empty table with the correct schema
+    # Use dtype.plc_type to get the full DataType (preserves precision/scale for Decimals)
+    empty_columns = [
+        plc.column_factories.make_empty_column(dtype.plc_type, stream=stream)
+        for dtype in ir.schema.values()
+    ]
+    empty_table = plc.Table(empty_columns)
+
+    return TableChunk.from_pylibcudf_table(
+        empty_table,
+        stream,
+        exclusive_view=True,
+    )
+
+
+def chunk_to_frame(chunk: TableChunk, ir: IR) -> DataFrame:
+    """
+    Convert a TableChunk to a DataFrame.
+
+    Parameters
+    ----------
+    chunk
+        The TableChunk to convert.
+    ir
+        The IR node to use for the schema.
+
+    Returns
+    -------
+    A DataFrame.
+    """
+    return DataFrame.from_table(
+        chunk.table_view(),
+        list(ir.schema.keys()),
+        list(ir.schema.values()),
+        chunk.stream,
+    )
+
+
+def make_spill_function(
+    spillable_messages_list: list[SpillableMessages],
+    context: Context,
+) -> Callable[[int], int]:
+    """
+    Create a spill function for a list of SpillableMessages containers.
+
+    This utility creates a spill function that can be registered with a
+    SpillManager. The spill function uses a smart spilling strategy that
+    prioritizes:
+    1. Longest queues first (slow consumers that won't need data soon)
+    2. Newest messages first (just arrived, won't be consumed soon)
+
+    This strategy keeps "hot" data (about to be consumed) in fast memory
+    while spilling "cold" data (won't be needed for a while) to slower tiers.
+
+    Parameters
+    ----------
+    spillable_messages_list
+        List of SpillableMessages containers to create a spill function for.
+    context
+        The RapidsMPF context to use for accessing the BufferResource.
+
+    Returns
+    -------
+    A spill function that takes an amount (in bytes) and returns the
+    actual amount spilled (in bytes).
+
+    Notes
+    -----
+    The spilling strategy is particularly effective for fanout scenarios
+    where different consumers may process messages at different rates. By
+    prioritizing longest queues and newest messages, we maximize the time
+    data can remain in slower memory before it's needed.
+    """
+
+    def spill_func(amount: int) -> int:
+        """Spill messages from the buffers to free device/host memory."""
+        spilled = 0
+
+        # Collect all messages with metadata for smart spilling
+        # Format: (message_id, container_idx, queue_length, sm)
+        all_messages: list[tuple[int, int, int, SpillableMessages]] = []
+        for container_idx, sm in enumerate(spillable_messages_list):
+            content_descriptions = sm.get_content_descriptions()
+            queue_length = len(content_descriptions)
+            all_messages.extend(
+                (message_id, container_idx, queue_length, sm)
+                for message_id in content_descriptions
+            )
+
+        # Spill newest messages first from the longest queues
+        # Sort by: (1) queue length descending, (2) message_id descending
+        # This prioritizes:
+        # - Longest queues (slow consumers that won't need data soon)
+        # - Newest messages (just arrived, won't be consumed soon)
+        all_messages.sort(key=lambda x: (-x[2], -x[0]))
+
+        # Spill messages until we've freed enough memory
+        for message_id, _, _, sm in all_messages:
+            if spilled >= amount:
+                break
+            # Try to spill this message
+            spilled += sm.spill(mid=message_id, br=context.br())
+
+        return spilled
+
+    return spill_func
+
+
+@contextmanager
+def opaque_reservation(
+    context: Context,
+    estimated_bytes: int,
+) -> Iterator[MemoryReservation]:
+    """
+    Reserve memory for opaque allocations.
+
+    Parameters
+    ----------
+    context
+        The RapidsMPF context.
+    estimated_bytes
+        The estimated number of bytes to reserve.
+
+    Yields
+    ------
+    The memory reservation.
+    """
+    yield context.br().reserve_device_memory_and_spill(
+        estimated_bytes, allow_overbooking=True
+    )
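
Most of `make_spill_function` is plumbing; the policy lives in the sort key `(-queue_length, -message_id)`. A standalone sketch with toy tuples shows the resulting order (no rapidsmpf types involved; the real candidates also carry the SpillableMessages handle as a fourth element):

# (message_id, container_idx, queue_length) per spillable message.
candidates = [
    (0, 0, 2), (1, 0, 2),             # short queue: 2 messages
    (0, 1, 3), (1, 1, 3), (2, 1, 3),  # long queue: 3 messages
]
candidates.sort(key=lambda x: (-x[2], -x[0]))
# Longest queue first, then the newest (highest id) message within it,
# so "cold" data is spilled before data about to be consumed.
assert candidates == [(2, 1, 3), (1, 1, 3), (0, 1, 3), (1, 0, 2), (0, 0, 2)]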