cudf-polars-cu13 25.10.0__py3-none-any.whl → 26.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cudf_polars/GIT_COMMIT +1 -1
- cudf_polars/VERSION +1 -1
- cudf_polars/callback.py +60 -15
- cudf_polars/containers/column.py +137 -77
- cudf_polars/containers/dataframe.py +123 -34
- cudf_polars/containers/datatype.py +134 -13
- cudf_polars/dsl/expr.py +0 -2
- cudf_polars/dsl/expressions/aggregation.py +80 -28
- cudf_polars/dsl/expressions/binaryop.py +34 -14
- cudf_polars/dsl/expressions/boolean.py +110 -37
- cudf_polars/dsl/expressions/datetime.py +59 -30
- cudf_polars/dsl/expressions/literal.py +11 -5
- cudf_polars/dsl/expressions/rolling.py +460 -119
- cudf_polars/dsl/expressions/selection.py +9 -8
- cudf_polars/dsl/expressions/slicing.py +1 -1
- cudf_polars/dsl/expressions/string.py +256 -114
- cudf_polars/dsl/expressions/struct.py +19 -7
- cudf_polars/dsl/expressions/ternary.py +33 -3
- cudf_polars/dsl/expressions/unary.py +126 -64
- cudf_polars/dsl/ir.py +1053 -350
- cudf_polars/dsl/to_ast.py +30 -13
- cudf_polars/dsl/tracing.py +194 -0
- cudf_polars/dsl/translate.py +307 -107
- cudf_polars/dsl/utils/aggregations.py +43 -30
- cudf_polars/dsl/utils/reshape.py +14 -2
- cudf_polars/dsl/utils/rolling.py +12 -8
- cudf_polars/dsl/utils/windows.py +35 -20
- cudf_polars/experimental/base.py +55 -2
- cudf_polars/experimental/benchmarks/pdsds.py +12 -126
- cudf_polars/experimental/benchmarks/pdsh.py +792 -2
- cudf_polars/experimental/benchmarks/utils.py +596 -39
- cudf_polars/experimental/dask_registers.py +47 -20
- cudf_polars/experimental/dispatch.py +9 -3
- cudf_polars/experimental/distinct.py +2 -0
- cudf_polars/experimental/explain.py +15 -2
- cudf_polars/experimental/expressions.py +30 -15
- cudf_polars/experimental/groupby.py +25 -4
- cudf_polars/experimental/io.py +156 -124
- cudf_polars/experimental/join.py +53 -23
- cudf_polars/experimental/parallel.py +68 -19
- cudf_polars/experimental/rapidsmpf/__init__.py +8 -0
- cudf_polars/experimental/rapidsmpf/collectives/__init__.py +9 -0
- cudf_polars/experimental/rapidsmpf/collectives/allgather.py +90 -0
- cudf_polars/experimental/rapidsmpf/collectives/common.py +96 -0
- cudf_polars/experimental/rapidsmpf/collectives/shuffle.py +253 -0
- cudf_polars/experimental/rapidsmpf/core.py +488 -0
- cudf_polars/experimental/rapidsmpf/dask.py +172 -0
- cudf_polars/experimental/rapidsmpf/dispatch.py +153 -0
- cudf_polars/experimental/rapidsmpf/io.py +696 -0
- cudf_polars/experimental/rapidsmpf/join.py +322 -0
- cudf_polars/experimental/rapidsmpf/lower.py +74 -0
- cudf_polars/experimental/rapidsmpf/nodes.py +735 -0
- cudf_polars/experimental/rapidsmpf/repartition.py +216 -0
- cudf_polars/experimental/rapidsmpf/union.py +115 -0
- cudf_polars/experimental/rapidsmpf/utils.py +374 -0
- cudf_polars/experimental/repartition.py +9 -2
- cudf_polars/experimental/select.py +177 -14
- cudf_polars/experimental/shuffle.py +46 -12
- cudf_polars/experimental/sort.py +100 -26
- cudf_polars/experimental/spilling.py +1 -1
- cudf_polars/experimental/statistics.py +24 -5
- cudf_polars/experimental/utils.py +25 -7
- cudf_polars/testing/asserts.py +13 -8
- cudf_polars/testing/io.py +2 -1
- cudf_polars/testing/plugin.py +93 -17
- cudf_polars/typing/__init__.py +86 -32
- cudf_polars/utils/config.py +473 -58
- cudf_polars/utils/cuda_stream.py +70 -0
- cudf_polars/utils/versions.py +5 -4
- cudf_polars_cu13-26.2.0.dist-info/METADATA +181 -0
- cudf_polars_cu13-26.2.0.dist-info/RECORD +108 -0
- {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/WHEEL +1 -1
- cudf_polars_cu13-25.10.0.dist-info/METADATA +0 -136
- cudf_polars_cu13-25.10.0.dist-info/RECORD +0 -92
- {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/licenses/LICENSE +0 -0
- {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,322 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Join logic for the RapidsMPF streaming runtime."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import asyncio
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Literal
|
|
9
|
+
|
|
10
|
+
from rapidsmpf.memory.buffer import MemoryType
|
|
11
|
+
from rapidsmpf.streaming.core.message import Message
|
|
12
|
+
from rapidsmpf.streaming.cudf.table_chunk import TableChunk
|
|
13
|
+
|
|
14
|
+
from cudf_polars.containers import DataFrame
|
|
15
|
+
from cudf_polars.dsl.ir import IR, Join
|
|
16
|
+
from cudf_polars.experimental.rapidsmpf.collectives.allgather import AllGatherManager
|
|
17
|
+
from cudf_polars.experimental.rapidsmpf.dispatch import (
|
|
18
|
+
generate_ir_sub_network,
|
|
19
|
+
)
|
|
20
|
+
from cudf_polars.experimental.rapidsmpf.nodes import (
|
|
21
|
+
default_node_multi,
|
|
22
|
+
define_py_node,
|
|
23
|
+
shutdown_on_error,
|
|
24
|
+
)
|
|
25
|
+
from cudf_polars.experimental.rapidsmpf.utils import (
|
|
26
|
+
ChannelManager,
|
|
27
|
+
Metadata,
|
|
28
|
+
chunk_to_frame,
|
|
29
|
+
empty_table_chunk,
|
|
30
|
+
opaque_reservation,
|
|
31
|
+
process_children,
|
|
32
|
+
)
|
|
33
|
+
from cudf_polars.experimental.utils import _concat
|
|
34
|
+
|
|
35
|
+
if TYPE_CHECKING:
|
|
36
|
+
from rapidsmpf.streaming.core.context import Context
|
|
37
|
+
|
|
38
|
+
from cudf_polars.dsl.ir import IR, IRExecutionContext
|
|
39
|
+
from cudf_polars.experimental.rapidsmpf.core import SubNetGenerator
|
|
40
|
+
from cudf_polars.experimental.rapidsmpf.utils import ChannelPair
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@define_py_node()
async def broadcast_join_node(
    context: Context,
    ir: Join,
    ir_context: IRExecutionContext,
    ch_out: ChannelPair,
    ch_left: ChannelPair,
    ch_right: ChannelPair,
    broadcast_side: Literal["left", "right"],
    collective_id: int,
    target_partition_size: int,
) -> None:
    """
    Broadcast-join node for rapidsmpf.

    The ``broadcast_side`` input (the "small" side) is collected in full.
    On multi-rank runs it is also all-gathered across ranks, unless its
    metadata marks it as already duplicated. Every chunk received on the
    other ("large") side is then joined against it. One output chunk is sent
    per large-side chunk, reusing that chunk's sequence number.

    Parameters
    ----------
    context
        The rapidsmpf context.
    ir
        The Join IR node.
    ir_context
        The execution context for the IR node.
    ch_out
        The output ChannelPair.
    ch_left
        The left input ChannelPair.
    ch_right
        The right input ChannelPair.
    broadcast_side
        The side to broadcast (collect in full); the other side is streamed.
    collective_id
        Pre-allocated collective ID for the all-gather of the broadcast side.
    target_partition_size
        The target partition size in bytes. For inner joins, a broadcast
        side made of several local chunks is concatenated only when its
        local device size is below this threshold.
    """
    async with shutdown_on_error(
        context,
        ch_left.metadata,
        ch_left.data,
        ch_right.metadata,
        ch_right.data,
        ch_out.metadata,
        ch_out.data,
    ):
        # Receive metadata from both inputs concurrently.
        left_metadata, right_metadata = await asyncio.gather(
            ch_left.recv_metadata(context),
            ch_right.recv_metadata(context),
        )

        partitioned_on: tuple[str, ...] = ()
        if broadcast_side == "right":
            # Broadcast right, stream left
            small_ch = ch_right
            large_ch = ch_left
            small_child = ir.children[1]
            large_child = ir.children[0]
            chunk_count = left_metadata.count
            # Output chunks are produced one-per-left-chunk, so report the
            # left partitioning for the output.
            partitioned_on = left_metadata.partitioned_on
            small_duplicated = right_metadata.duplicated
        else:
            # Broadcast left, stream right
            small_ch = ch_left
            large_ch = ch_right
            small_child = ir.children[0]
            large_child = ir.children[1]
            chunk_count = right_metadata.count
            small_duplicated = left_metadata.duplicated
            # Only forward the right-side partitioning for right joins;
            # otherwise report no partitioning.
            if ir.options[0] == "Right":
                partitioned_on = right_metadata.partitioned_on

        # Send metadata.
        output_metadata = Metadata(
            chunk_count,
            partitioned_on=partitioned_on,
            duplicated=left_metadata.duplicated and right_metadata.duplicated,
        )
        await ch_out.send_metadata(context, output_metadata)

        # Collect the small side (may be empty if no data is received).
        small_chunks: list[TableChunk] = []
        small_size = 0
        while (msg := await small_ch.data.recv(context)) is not None:
            small_chunks.append(
                TableChunk.from_message(msg).make_available_and_spill(
                    context.br(), allow_overbooking=True
                )
            )
            del msg
            small_size += small_chunks[-1].data_alloc_size(MemoryType.DEVICE)

        # Allgather is a collective - all ranks must participate even with no local data
        need_allgather = context.comm().nranks > 1 and not small_duplicated
        if need_allgather:
            allgather = AllGatherManager(context, collective_id)
            # Insert the chunks in their original order (ids 0..n-1). Each
            # list reference is dropped as soon as the chunk is handed over,
            # so it can be released early. Reversing first lets us pop from
            # the end in O(1) instead of the O(n) ``pop(0)``.
            small_chunks.reverse()
            s_id = 0
            while small_chunks:
                allgather.insert(s_id, small_chunks.pop())
                s_id += 1
            allgather.insert_finished()
            stream = ir_context.get_cuda_stream()
            # extract_concatenated returns a plc.Table, not a TableChunk
            small_dfs = [
                DataFrame.from_table(
                    await allgather.extract_concatenated(stream),
                    list(small_child.schema.keys()),
                    list(small_child.schema.values()),
                    stream,
                )
            ]
        elif len(small_chunks) > 1 and (
            ir.options[0] != "Inner" or small_size < target_partition_size
        ):
            # Non-inner joins must see the whole small side at once. Joining
            # it piece by piece would emit spurious unmatched/duplicated rows.
            # Inner joins are also pre-concatenated when the small side is
            # below the target partition size.
            small_dfs = [
                _concat(
                    *[chunk_to_frame(chunk, small_child) for chunk in small_chunks],
                    context=ir_context,
                )
            ]
            small_chunks.clear()  # small_dfs is not a view of small_chunks anymore
        else:
            # These frames are built from small_chunks, which are kept alive
            # until the end of the node.
            small_dfs = [
                chunk_to_frame(small_chunk, small_child) for small_chunk in small_chunks
            ]

        # Stream through the large side, joining each chunk with the small side.
        seq_num = 0
        large_chunk_processed = False
        receiving_large_chunks = True
        while receiving_large_chunks:
            msg = await large_ch.data.recv(context)
            if msg is None:
                receiving_large_chunks = False
                if large_chunk_processed:
                    # Normal exit - We've processed all large-table data
                    break
                elif small_dfs:
                    # We received small-table data, but no large-table data.
                    # This may never happen, but we handle it by joining
                    # against an empty large-table chunk (sequence number 0).
                    stream = ir_context.get_cuda_stream()
                    large_chunk = empty_table_chunk(large_child, context, stream)
                else:
                    # We received no data for either the small or large table.
                    # Drain the output channel and return
                    await ch_out.data.drain(context)
                    return
            else:
                large_chunk_processed = True
                # Read the sequence number up front so ``msg`` is never
                # accessed after being handed to ``TableChunk.from_message``
                # (which may consume it - NOTE(review): confirm in rapidsmpf).
                seq_num = msg.sequence_number
                large_chunk = TableChunk.from_message(msg).make_available_and_spill(
                    context.br(), allow_overbooking=True
                )
                del msg

            large_df = DataFrame.from_table(
                large_chunk.table_view(),
                list(large_child.schema.keys()),
                list(large_child.schema.values()),
                large_chunk.stream,
            )

            # Lazily create an empty small table if small_dfs is empty
            if not small_dfs:
                stream = ir_context.get_cuda_stream()
                empty_small_chunk = empty_table_chunk(small_child, context, stream)
                small_dfs = [chunk_to_frame(empty_small_chunk, small_child)]

            large_chunk_size = large_chunk.data_alloc_size(MemoryType.DEVICE)
            # NOTE(review): after an all-gather, ``small_size`` only counts
            # this rank's local small-side bytes, so this reservation may
            # underestimate the broadcast table's footprint.
            input_bytes = large_chunk_size + small_size
            with opaque_reservation(context, input_bytes):
                # Join the large chunk with every small frame (off the event
                # loop thread), keeping the IR's left/right child order, then
                # concatenate the partial results.
                df = _concat(
                    *[
                        await asyncio.to_thread(
                            ir.do_evaluate,
                            *ir._non_child_args,
                            *(
                                [large_df, small_df]
                                if broadcast_side == "right"
                                else [small_df, large_df]
                            ),
                            context=ir_context,
                        )
                        for small_df in small_dfs
                    ],
                    context=ir_context,
                )

                # Send output chunk
                await ch_out.data.send(
                    context,
                    Message(
                        seq_num,
                        TableChunk.from_pylibcudf_table(
                            df.table, df.stream, exclusive_view=True
                        ),
                    ),
                )
                del df, large_df, large_chunk

        del small_dfs, small_chunks
        await ch_out.data.drain(context)
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
@generate_ir_sub_network.register(Join)
def _(
    ir: Join, rec: SubNetGenerator
) -> tuple[dict[IR, list[Any]], dict[IR, ChannelManager]]:
    """
    Build the streaming sub-network for a Join node.

    A partition-wise join (``default_node_multi``) is emitted in two cases:
    the output has a single partition, or both children are already
    partitioned on their join keys with the same partition count as the
    output. Otherwise a ``broadcast_join_node`` is emitted, broadcasting the
    child with fewer partitions (the right child on ties).

    Parameters
    ----------
    ir
        The Join IR node.
    rec
        Recursive sub-network generator. Its ``state`` supplies the
        partition info, rapidsmpf context, IR execution context, config
        options and collective-ID map.

    Returns
    -------
    The accumulated ``(nodes, channels)`` mappings, including this join.
    """
    left, right = ir.children
    partition_info = rec.state["partition_info"]
    output_count = partition_info[ir].count
    left_info = partition_info[left]
    right_info = partition_info[right]

    # A child "lines up" with the output when it is already partitioned on
    # its join keys and has exactly as many partitions as the output.
    left_aligned = (
        left_info.partitioned_on == ir.left_on and left_info.count == output_count
    )
    right_aligned = (
        right_info.partitioned_on == ir.right_on
        and right_info.count == output_count
    )

    # Build the children's sub-networks first, then register our own output.
    nodes, channels = process_children(ir, rec)
    channels[ir] = ChannelManager(rec.state["context"])

    if output_count == 1 or (left_aligned and right_aligned):
        # Partition-wise join. ``partitioning_index`` selects the right child
        # for right joins and the left child otherwise. Presumably this is
        # the child whose partitioning the output inherits; see
        # default_node_multi.
        partitioning_index = 1 if ir.options[0] == "Right" else 0
        nodes[ir] = [
            default_node_multi(
                rec.state["context"],
                ir,
                rec.state["ir_context"],
                channels[ir].reserve_input_slot(),
                (
                    channels[left].reserve_output_slot(),
                    channels[right].reserve_output_slot(),
                ),
                partitioning_index=partitioning_index,
            )
        ]
        return nodes, channels

    # Broadcast join: collect the side with fewer partitions, stream the other.
    broadcast_side: Literal["left", "right"] = (
        "right" if left_info.count >= right_info.count else "left"
    )

    executor = rec.state["config_options"].executor
    assert executor.name == "streaming", "Join node requires streaming executor"
    target_partition_size = executor.target_partition_size

    nodes[ir] = [
        broadcast_join_node(
            rec.state["context"],
            ir,
            rec.state["ir_context"],
            channels[ir].reserve_input_slot(),
            channels[left].reserve_output_slot(),
            channels[right].reserve_output_slot(),
            broadcast_side=broadcast_side,
            collective_id=rec.state["collective_id_map"][ir],
            target_partition_size=target_partition_size,
        )
    ]
    return nodes, channels
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Core lowering logic for the RapidsMPF streaming runtime."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from typing import TYPE_CHECKING
|
|
8
|
+
|
|
9
|
+
import cudf_polars.experimental.rapidsmpf.io # noqa: F401
|
|
10
|
+
from cudf_polars.dsl.ir import IR, Sort
|
|
11
|
+
from cudf_polars.experimental.base import PartitionInfo
|
|
12
|
+
from cudf_polars.experimental.io import StreamingSink
|
|
13
|
+
from cudf_polars.experimental.parallel import _lower_ir_pwise
|
|
14
|
+
from cudf_polars.experimental.rapidsmpf.dispatch import (
|
|
15
|
+
lower_ir_node,
|
|
16
|
+
)
|
|
17
|
+
from cudf_polars.experimental.repartition import Repartition
|
|
18
|
+
from cudf_polars.experimental.sort import ShuffleSorted
|
|
19
|
+
from cudf_polars.experimental.utils import _lower_ir_fallback
|
|
20
|
+
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
from collections.abc import MutableMapping
|
|
23
|
+
|
|
24
|
+
from cudf_polars.experimental.rapidsmpf.dispatch import LowerIRTransformer
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@lower_ir_node.register(IR)
def _lower_ir_node_task_engine(
    ir: IR, rec: LowerIRTransformer
) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
    """
    Provide the default rapidsmpf lowering by delegating to the task engine.

    Any IR type without a more specific rapidsmpf registration is lowered
    exactly as the task-based executor's ``lower_ir_node`` lowers it.
    """
    # Function-local import (kept from the original). Presumably this avoids
    # a circular import between the two dispatch modules; verify before moving.
    from cudf_polars.experimental.dispatch import lower_ir_node as task_engine_lower

    lowered = task_engine_lower(ir, rec)
    return lowered
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@lower_ir_node.register(ShuffleSorted)
@lower_ir_node.register(StreamingSink)
def _unsupported(
    ir: IR, rec: LowerIRTransformer
) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
    """
    Lower ``ir`` to a single partition/chunk.

    Used for IR types that the rapidsmpf runtime cannot yet run with
    multiple partitions; the work is delegated to ``_lower_ir_fallback``.
    """
    node_type = type(ir)
    reason = f"Class {node_type} does not support multiple partitions."
    return _lower_ir_fallback(ir, rec, msg=reason)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@lower_ir_node.register(Sort)
def _(
    ir: Sort, rec: LowerIRTransformer
) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
    """
    Lower a Sort node for the rapidsmpf runtime.

    Sorts with a top-k / bottom-k slice (a slice that starts at row 0, or
    that runs from a negative offset to the end) are lowered in two steps.
    First, each partition is sorted and sliced. Then, if there is more than
    one partition, the results are collapsed to one partition with a
    ``Repartition`` and sorted and sliced again. Every other Sort falls back
    to a single partition.
    """
    zlice = ir.zlice
    if zlice is not None:
        offset = zlice[0]
        # Decide whether the slice skips rows at the start of the selection.
        # Per-partition slicing is only valid when it does not.
        if offset > 0:
            has_offset = True
        elif offset < 0:
            # Negative offsets are fine as long as the slice reaches the end.
            has_offset = zlice[1] is not None and offset + zlice[1] < 0
        else:
            has_offset = False

        if not has_offset:
            # Sort (and slice) every input partition independently.
            lowered, partition_info = _lower_ir_pwise(ir, rec)
            if partition_info[lowered].count <= 1:
                return lowered, partition_info

            # Collapse the per-partition results into a single partition,
            # then sort and slice that reduced partition once more.
            collapsed = Repartition(lowered.schema, lowered)
            partition_info[collapsed] = PartitionInfo(count=1)
            final = ir.reconstruct([collapsed])
            partition_info[final] = PartitionInfo(count=1)
            return final, partition_info

    # TODO: Add general multi-partition Sort support
    return _lower_ir_fallback(
        ir, rec, msg=f"Class {type(ir)} does not support multiple partitions."
    )
|