cudf-polars-cu13 25.12.0-py3-none-any.whl → 26.2.0-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- cudf_polars/GIT_COMMIT +1 -1
- cudf_polars/VERSION +1 -1
- cudf_polars/callback.py +28 -7
- cudf_polars/containers/column.py +51 -26
- cudf_polars/dsl/expressions/binaryop.py +1 -1
- cudf_polars/dsl/expressions/boolean.py +1 -1
- cudf_polars/dsl/expressions/selection.py +1 -1
- cudf_polars/dsl/expressions/string.py +29 -20
- cudf_polars/dsl/expressions/ternary.py +25 -1
- cudf_polars/dsl/expressions/unary.py +11 -8
- cudf_polars/dsl/ir.py +351 -281
- cudf_polars/dsl/translate.py +18 -15
- cudf_polars/dsl/utils/aggregations.py +10 -5
- cudf_polars/experimental/base.py +10 -0
- cudf_polars/experimental/benchmarks/pdsh.py +1 -1
- cudf_polars/experimental/benchmarks/utils.py +83 -2
- cudf_polars/experimental/distinct.py +2 -0
- cudf_polars/experimental/explain.py +1 -1
- cudf_polars/experimental/expressions.py +8 -5
- cudf_polars/experimental/groupby.py +2 -0
- cudf_polars/experimental/io.py +64 -42
- cudf_polars/experimental/join.py +15 -2
- cudf_polars/experimental/parallel.py +10 -7
- cudf_polars/experimental/rapidsmpf/collectives/__init__.py +9 -0
- cudf_polars/experimental/rapidsmpf/collectives/allgather.py +90 -0
- cudf_polars/experimental/rapidsmpf/collectives/common.py +96 -0
- cudf_polars/experimental/rapidsmpf/{shuffle.py → collectives/shuffle.py} +90 -114
- cudf_polars/experimental/rapidsmpf/core.py +194 -67
- cudf_polars/experimental/rapidsmpf/dask.py +172 -0
- cudf_polars/experimental/rapidsmpf/dispatch.py +6 -3
- cudf_polars/experimental/rapidsmpf/io.py +162 -70
- cudf_polars/experimental/rapidsmpf/join.py +162 -77
- cudf_polars/experimental/rapidsmpf/nodes.py +421 -180
- cudf_polars/experimental/rapidsmpf/repartition.py +130 -65
- cudf_polars/experimental/rapidsmpf/union.py +24 -5
- cudf_polars/experimental/rapidsmpf/utils.py +228 -16
- cudf_polars/experimental/shuffle.py +18 -4
- cudf_polars/experimental/sort.py +13 -6
- cudf_polars/experimental/spilling.py +1 -1
- cudf_polars/testing/plugin.py +6 -3
- cudf_polars/utils/config.py +67 -0
- cudf_polars/utils/versions.py +3 -3
- {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/METADATA +9 -10
- {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/RECORD +47 -43
- {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/WHEEL +1 -1
- {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/licenses/LICENSE +0 -0
- {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/top_level.txt +0 -0
cudf_polars/experimental/rapidsmpf/repartition.py

@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
 # SPDX-License-Identifier: Apache-2.0
 """Re-chunking logic for the RapidsMPF streaming runtime."""

@@ -7,14 +7,21 @@ from __future__ import annotations

 import math
 from typing import TYPE_CHECKING, Any

+from rapidsmpf.memory.buffer import MemoryType
 from rapidsmpf.streaming.core.message import Message
 from rapidsmpf.streaming.core.node import define_py_node
 from rapidsmpf.streaming.cudf.table_chunk import TableChunk

 from cudf_polars.containers import DataFrame
+from cudf_polars.experimental.rapidsmpf.collectives.allgather import AllGatherManager
 from cudf_polars.experimental.rapidsmpf.dispatch import generate_ir_sub_network
 from cudf_polars.experimental.rapidsmpf.nodes import shutdown_on_error
-from cudf_polars.experimental.rapidsmpf.utils import ChannelManager
+from cudf_polars.experimental.rapidsmpf.utils import (
+    ChannelManager,
+    Metadata,
+    empty_table_chunk,
+    opaque_reservation,
+)
 from cudf_polars.experimental.repartition import Repartition
 from cudf_polars.experimental.utils import _concat

@@ -34,7 +41,8 @@ async def concatenate_node(
     ch_out: ChannelPair,
     ch_in: ChannelPair,
     *,
-    max_chunks: int | None,
+    output_count: int,
+    collective_id: int,
 ) -> None:
     """
     Concatenate node for rapidsmpf.
@@ -51,66 +59,122 @@ async def concatenate_node(
         The output ChannelPair.
     ch_in
         The input ChannelPair.
-    max_chunks
-        The […]
+    output_count
+        The expected number of output chunks.
+    collective_id
+        Pre-allocated collective ID for this operation.
     """
-    […]
+    async with shutdown_on_error(
+        context, ch_in.metadata, ch_in.data, ch_out.metadata, ch_out.data
+    ):
+        # Receive metadata.
+        input_metadata = await ch_in.recv_metadata(context)
+        metadata = Metadata(output_count)
+
+        # max_chunks corresponds to the number of chunks we can
+        # concatenate together. If None, we must concatenate everything.
+        # Since a single-partition operation gets "special treatment",
+        # we must make sure `output_count == 1` is always satisfied.
+        max_chunks: int | None = None
+        if output_count > 1:
+            # Make sure max_chunks is at least 2.
+            max_chunks = max(2, math.ceil(input_metadata.count / output_count))
+
+        # Check if we need global communication.
+        need_global_repartition = (
+            # Avoid allgather of already-duplicated data
+            context.comm().nranks > 1
+            and not input_metadata.duplicated
+            and output_count == 1
+        )
+
+        chunks: list[TableChunk]
+        msg: TableChunk | None
+        if need_global_repartition:
+            # Assume this means "global repartitioning" for now
+
+            # Send metadata.
+            metadata.duplicated = True
+            await ch_out.send_metadata(context, metadata)
+
+            allgather = AllGatherManager(context, collective_id)
+            stream = context.get_stream_from_pool()
+            seq_num = 0
+            while (msg := await ch_in.data.recv(context)) is not None:
+                allgather.insert(seq_num, TableChunk.from_message(msg))
+                seq_num += 1
+                del msg
+            allgather.insert_finished()
+
+            # Extract concatenated result
+            result_table = await allgather.extract_concatenated(stream)
+
+            # If no chunks were gathered, result_table has 0 columns.
+            # We need to create an empty table with the correct schema.
+            if result_table.num_columns() == 0 and len(ir.schema) > 0:
+                output_chunk = empty_table_chunk(ir, context, stream)
+            else:
+                output_chunk = TableChunk.from_pylibcudf_table(
+                    result_table, stream, exclusive_view=True
                 )

-    […]
+            await ch_out.data.send(context, Message(0, output_chunk))
+        else:
+            # Send metadata.
+            metadata.duplicated = input_metadata.duplicated
+            await ch_out.send_metadata(context, metadata)
+
+            # Local repartitioning
+            seq_num = 0
+            while True:
+                chunks = []
+                done_receiving = False
+
+                # Collect chunks up to max_chunks or until end of stream
+                while len(chunks) < (max_chunks or float("inf")):
+                    msg = await ch_in.data.recv(context)
+                    if msg is None:
+                        done_receiving = True
+                        break
+                    chunks.append(
+                        TableChunk.from_message(msg).make_available_and_spill(
+                            context.br(), allow_overbooking=True
+                        )
                     )
-    […]
-            else _concat(
-                *(
-                    DataFrame.from_table(
-                        chunk.table_view(),
-                        list(ir.schema.keys()),
-                        list(ir.schema.values()),
-                        chunk.stream,
-                    )
-                    for chunk in chunks
-                ),
-                context=ir_context,
-            )
-        )
-        await ch_out.data.send(
-            context,
-            Message(
-                seq_num,
-                TableChunk.from_pylibcudf_table(
-                    df.table, df.stream, exclusive_view=True
-                ),
-            ),
-        )
-        seq_num += 1
+                    del msg

-    […]
+                if chunks:
+                    input_bytes = sum(
+                        chunk.data_alloc_size(MemoryType.DEVICE) for chunk in chunks
+                    )
+                    with opaque_reservation(context, input_bytes):
+                        df = _concat(
+                            *(
+                                DataFrame.from_table(
+                                    chunk.table_view(),
+                                    list(ir.schema.keys()),
+                                    list(ir.schema.values()),
+                                    chunk.stream,
+                                )
+                                for chunk in chunks
+                            ),
+                            context=ir_context,
+                        )
+                    await ch_out.data.send(
+                        context,
+                        Message(
+                            seq_num,
+                            TableChunk.from_pylibcudf_table(
+                                df.table, df.stream, exclusive_view=True
+                            ),
+                        ),
+                    )
+                    seq_num += 1
+                    del df, chunks
+
+                # Break if we reached end of stream
+                if done_receiving:
+                    break

         await ch_out.data.drain(context)

@@ -118,18 +182,15 @@ async def concatenate_node(
 @generate_ir_sub_network.register(Repartition)
 def _(
     ir: Repartition, rec: SubNetGenerator
-) -> tuple[list[Any], dict[IR, ChannelManager]]:
+) -> tuple[dict[IR, list[Any]], dict[IR, ChannelManager]]:
     # Repartition node.

     partition_info = rec.state["partition_info"]
-    max_chunks: int | None = None
     if partition_info[ir].count > 1:
         count_output = partition_info[ir].count
         count_input = partition_info[ir.children[0]].count
         if count_input < count_output:
             raise ValueError("Repartitioning to more chunks is not supported.")
-        # Make sure max_chunks is at least 2
-        max_chunks = max(2, math.ceil(count_input / count_output))

     # Process children
     nodes, channels = rec(ir.children[0])
@@ -137,15 +198,19 @@ def _(
     # Create output ChannelManager
     channels[ir] = ChannelManager(rec.state["context"])

+    # Look up the reserved shuffle ID for this operation
+    collective_id = rec.state["collective_id_map"][ir]
+
     # Add python node
-    nodes.append(
+    nodes[ir] = [
         concatenate_node(
             rec.state["context"],
             ir,
             rec.state["ir_context"],
             channels[ir].reserve_input_slot(),
             channels[ir.children[0]].reserve_output_slot(),
-            max_chunks=max_chunks,
+            output_count=partition_info[ir].count,
+            collective_id=collective_id,
         )
-    )
+    ]
     return nodes, channels
cudf_polars/experimental/rapidsmpf/union.py

@@ -16,6 +16,7 @@ from cudf_polars.experimental.rapidsmpf.dispatch import (
 from cudf_polars.experimental.rapidsmpf.nodes import define_py_node, shutdown_on_error
 from cudf_polars.experimental.rapidsmpf.utils import (
     ChannelManager,
+    Metadata,
     process_children,
 )

@@ -51,8 +52,24 @@ async def union_node(
     chs_in
         The input ChannelPairs.
     """
-    […]
+    async with shutdown_on_error(
+        context,
+        *[ch.metadata for ch in chs_in],
+        *[ch.data for ch in chs_in],
+        ch_out.metadata,
+        ch_out.data,
+    ):
+        # Merge and forward metadata.
+        total_count = 0
+        duplicated = True
+        for ch_in in chs_in:
+            metadata = await ch_in.recv_metadata(context)
+            total_count += metadata.count
+            duplicated = duplicated and metadata.duplicated
+        await ch_out.send_metadata(
+            context, Metadata(total_count, duplicated=duplicated)
+        )
+
         seq_num_offset = 0
         for ch_in in chs_in:
             num_ch_chunks = 0
@@ -73,7 +90,9 @@ async def union_node(


 @generate_ir_sub_network.register(Union)
-def _(ir: Union, rec: SubNetGenerator) -> tuple[list[Any], dict[IR, ChannelManager]]:
+def _(
+    ir: Union, rec: SubNetGenerator
+) -> tuple[dict[IR, list[Any]], dict[IR, ChannelManager]]:
     # Union operation.
     # Pass-through all child chunks in channel order.

@@ -84,7 +103,7 @@ def _(ir: Union, rec: SubNetGenerator) -> tuple[list[Any], dict[IR, ChannelManager]]:
     channels[ir] = ChannelManager(rec.state["context"])

     # Add simple python node
-    nodes.append(
+    nodes[ir] = [
         union_node(
             rec.state["context"],
             ir,
@@ -92,5 +111,5 @@ def _(ir: Union, rec: SubNetGenerator) -> tuple[list[Any], dict[IR, ChannelManager]]:
             channels[ir].reserve_input_slot(),
             *[channels[c].reserve_output_slot() for c in ir.children],
         )
-    )
+    ]
     return nodes, channels
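
The metadata merge in the new union_node reduces to two folds: chunk counts
add up, and the result is only marked duplicated when every input is. A toy
stand-in (Meta is illustrative; the real class is the Metadata added in
utils.py below):

    from dataclasses import dataclass

    @dataclass
    class Meta:  # stand-in for cudf_polars.experimental.rapidsmpf.utils.Metadata
        count: int
        duplicated: bool = False

    def merge_metadata(inputs: list[Meta]) -> Meta:
        # Counts accumulate; "duplicated everywhere" survives only if
        # every input stream is already duplicated on all workers.
        total = sum(m.count for m in inputs)
        duplicated = all(m.duplicated for m in inputs)
        return Meta(total, duplicated)

    assert merge_metadata([Meta(2, True), Meta(3, True)]) == Meta(5, True)
    assert merge_metadata([Meta(2, True), Meta(3, False)]) == Meta(5, False)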
cudf_polars/experimental/rapidsmpf/utils.py

@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
 # SPDX-License-Identifier: Apache-2.0
 """Utility functions and classes for the RapidsMPF streaming runtime."""

@@ -6,26 +6,33 @@ from __future__ import annotations

 import asyncio
 import operator
-from contextlib import asynccontextmanager
+from contextlib import asynccontextmanager, contextmanager
 from dataclasses import dataclass
 from functools import reduce
-from typing import TYPE_CHECKING, Any, TypeAlias
+from typing import TYPE_CHECKING, Any
+
+from rapidsmpf.streaming.chunks.arbitrary import ArbitraryChunk
+from rapidsmpf.streaming.core.message import Message
+from rapidsmpf.streaming.cudf.table_chunk import TableChunk
+
+import pylibcudf as plc
+
+from cudf_polars.containers import DataFrame

 if TYPE_CHECKING:
-    from collections.abc import AsyncIterator
+    from collections.abc import AsyncIterator, Callable, Iterator

+    from rapidsmpf.memory.memory_reservation import MemoryReservation
     from rapidsmpf.streaming.core.channel import Channel
     from rapidsmpf.streaming.core.context import Context
-    from rapidsmpf.streaming.cudf.table_chunk import TableChunk
+    from rapidsmpf.streaming.core.spillable_messages import SpillableMessages
+
+    from rmm.pylibrmm.stream import Stream

     from cudf_polars.dsl.ir import IR
     from cudf_polars.experimental.rapidsmpf.dispatch import SubNetGenerator


-# Type alias for metadata payloads (placeholder - not used yet)
-MetadataPayload: TypeAlias = Any
-
-
 @asynccontextmanager
 async def shutdown_on_error(
     context: Context, *channels: Channel[Any]
@@ -48,6 +55,29 @@ async def shutdown_on_error(
         raise


+class Metadata:
+    """Metadata payload for an individual ChannelPair."""
+
+    __slots__ = ("count", "duplicated", "partitioned_on")
+    count: int
+    """Chunk-count estimate."""
+    partitioned_on: tuple[str, ...]
+    """Partitioned-on columns."""
+    duplicated: bool
+    """Whether the data is duplicated on all workers."""
+
+    def __init__(
+        self,
+        count: int,
+        *,
+        partitioned_on: tuple[str, ...] = (),
+        duplicated: bool = False,
+    ):
+        self.count = count
+        self.partitioned_on = partitioned_on
+        self.duplicated = duplicated
+
+
 @dataclass
 class ChannelPair:
     """
@@ -70,7 +100,7 @@ class ChannelPair:
     in follow-up work.
     """

-    metadata: Channel[MetadataPayload]
+    metadata: Channel[ArbitraryChunk]
     data: Channel[TableChunk]

     @classmethod
@@ -81,6 +111,39 @@ class ChannelPair:
             data=context.create_channel(),
         )

+    async def send_metadata(self, ctx: Context, metadata: Metadata) -> None:
+        """
+        Send metadata and drain the metadata channel.
+
+        Parameters
+        ----------
+        ctx :
+            The streaming context.
+        metadata :
+            The metadata to send.
+        """
+        msg = Message(0, ArbitraryChunk(metadata))
+        await self.metadata.send(ctx, msg)
+        await self.metadata.drain(ctx)
+
+    async def recv_metadata(self, ctx: Context) -> Metadata:
+        """
+        Receive metadata from the metadata channel.
+
+        Parameters
+        ----------
+        ctx :
+            The streaming context.
+
+        Returns
+        -------
+        Metadata
+            The received metadata.
+        """
+        msg = await self.metadata.recv(ctx)
+        assert msg is not None, f"Expected Metadata message, got {msg}."
+        return ArbitraryChunk.from_message(msg).release()
+

 class ChannelManager:
     """A utility class for managing ChannelPair objects."""
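
The send_metadata/recv_metadata pair above is a one-shot handshake: the
producer sends a single message and immediately drains the channel, and the
consumer awaits exactly one message. A toy asyncio model of that protocol,
with asyncio.Queue and a None sentinel standing in for the real rapidsmpf
Channel and its drain semantics:

    import asyncio

    async def producer(ch: asyncio.Queue) -> None:
        await ch.put({"count": 4, "duplicated": False})  # send_metadata
        await ch.put(None)                               # drain marks end-of-channel

    async def consumer(ch: asyncio.Queue) -> dict:
        msg = await ch.get()                             # recv_metadata
        assert msg is not None, "Expected Metadata message."
        return msg

    async def main() -> None:
        ch: asyncio.Queue = asyncio.Queue()
        _, meta = await asyncio.gather(producer(ch), consumer(ch))
        print(meta)  # {'count': 4, 'duplicated': False}

    asyncio.run(main())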
@@ -131,13 +194,13 @@

 def process_children(
     ir: IR, rec: SubNetGenerator
-) -> tuple[list[Any], dict[IR, ChannelManager]]:
+) -> tuple[dict[IR, list[Any]], dict[IR, ChannelManager]]:
     """
     Process children IR nodes and aggregate their nodes and channels.

     This helper function recursively processes all children of an IR node,
-    collects their streaming network nodes into a single list, and merges
-    their channel dictionaries.
+    collects their streaming network nodes into a dictionary mapping IR nodes
+    to their associated nodes, and merges their channel dictionaries.

     Parameters
     ----------
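
The hunk that follows completes this change by aggregating with
reduce(operator.or_, ...) over dicts instead of concatenating lists. Since
Python 3.9, the | operator merges mappings (PEP 584), so the fold combines
the per-child node and channel dictionaries; a toy illustration with string
keys standing in for IR nodes:

    import operator
    from functools import reduce

    per_child = [{"scan": ["node-a"]}, {"select": ["node-b"]}]
    merged = reduce(operator.or_, per_child)
    assert merged == {"scan": ["node-a"], "select": ["node-b"]}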
@@ -149,14 +212,163 @@
     Returns
     -------
     nodes
-        […]
+        Dictionary mapping each IR node to its list of streaming network nodes.
     channels
         Dictionary mapping each child IR node to its ChannelManager.
     """
     if not ir.children:
-        return [], {}
+        return {}, {}

     _nodes_list, _channels_list = zip(*(rec(c) for c in ir.children), strict=True)
-    nodes: list[Any] = […]
+    nodes: dict[IR, list[Any]] = reduce(operator.or_, _nodes_list)
     channels: dict[IR, ChannelManager] = reduce(operator.or_, _channels_list)
     return nodes, channels
+
+
+def empty_table_chunk(ir: IR, context: Context, stream: Stream) -> TableChunk:
+    """
+    Make an empty table chunk.
+
+    Parameters
+    ----------
+    ir
+        The IR node to use for the schema.
+    context
+        The rapidsmpf context.
+    stream
+        The stream to use for the table chunk.
+
+    Returns
+    -------
+    The empty table chunk.
+    """
+    # Create an empty table with the correct schema.
+    # Use dtype.plc_type to get the full DataType (preserves precision/scale for Decimals).
+    empty_columns = [
+        plc.column_factories.make_empty_column(dtype.plc_type, stream=stream)
+        for dtype in ir.schema.values()
+    ]
+    empty_table = plc.Table(empty_columns)
+
+    return TableChunk.from_pylibcudf_table(
+        empty_table,
+        stream,
+        exclusive_view=True,
+    )
+
+
+def chunk_to_frame(chunk: TableChunk, ir: IR) -> DataFrame:
+    """
+    Convert a TableChunk to a DataFrame.
+
+    Parameters
+    ----------
+    chunk
+        The TableChunk to convert.
+    ir
+        The IR node to use for the schema.
+
+    Returns
+    -------
+    A DataFrame.
+    """
+    return DataFrame.from_table(
+        chunk.table_view(),
+        list(ir.schema.keys()),
+        list(ir.schema.values()),
+        chunk.stream,
+    )
+
+
+def make_spill_function(
+    spillable_messages_list: list[SpillableMessages],
+    context: Context,
+) -> Callable[[int], int]:
+    """
+    Create a spill function for a list of SpillableMessages containers.
+
+    This utility creates a spill function that can be registered with a
+    SpillManager. The spill function uses a smart spilling strategy that
+    prioritizes:
+    1. Longest queues first (slow consumers that won't need data soon)
+    2. Newest messages first (just arrived, won't be consumed soon)
+
+    This strategy keeps "hot" data (about to be consumed) in fast memory
+    while spilling "cold" data (won't be needed for a while) to slower tiers.
+
+    Parameters
+    ----------
+    spillable_messages_list
+        List of SpillableMessages containers to create a spill function for.
+    context
+        The RapidsMPF context to use for accessing the BufferResource.
+
+    Returns
+    -------
+    A spill function that takes an amount (in bytes) and returns the
+    actual amount spilled (in bytes).
+
+    Notes
+    -----
+    The spilling strategy is particularly effective for fanout scenarios
+    where different consumers may process messages at different rates. By
+    prioritizing longest queues and newest messages, we maximize the time
+    data can remain in slower memory before it's needed.
+    """
+
+    def spill_func(amount: int) -> int:
+        """Spill messages from the buffers to free device/host memory."""
+        spilled = 0
+
+        # Collect all messages with metadata for smart spilling
+        # Format: (message_id, container_idx, queue_length, sm)
+        all_messages: list[tuple[int, int, int, SpillableMessages]] = []
+        for container_idx, sm in enumerate(spillable_messages_list):
+            content_descriptions = sm.get_content_descriptions()
+            queue_length = len(content_descriptions)
+            all_messages.extend(
+                (message_id, container_idx, queue_length, sm)
+                for message_id in content_descriptions
+            )
+
+        # Spill newest messages first from the longest queues
+        # Sort by: (1) queue length descending, (2) message_id descending
+        # This prioritizes:
+        # - Longest queues (slow consumers that won't need data soon)
+        # - Newest messages (just arrived, won't be consumed soon)
+        all_messages.sort(key=lambda x: (-x[2], -x[0]))
+
+        # Spill messages until we've freed enough memory
+        for message_id, _, _, sm in all_messages:
+            if spilled >= amount:
+                break
+            # Try to spill this message
+            spilled += sm.spill(mid=message_id, br=context.br())
+
+        return spilled
+
+    return spill_func
+
+
+@contextmanager
+def opaque_reservation(
+    context: Context,
+    estimated_bytes: int,
+) -> Iterator[MemoryReservation]:
+    """
+    Reserve memory for opaque allocations.
+
+    Parameters
+    ----------
+    context
+        The RapidsMPF context.
+    estimated_bytes
+        The estimated number of bytes to reserve.
+
+    Yields
+    ------
+    The memory reservation.
+    """
+    yield context.br().reserve_device_memory_and_spill(
+        estimated_bytes, allow_overbooking=True
+    )