cudf-polars-cu12 25.2.2__py3-none-any.whl → 25.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cudf_polars/VERSION +1 -1
- cudf_polars/callback.py +82 -65
- cudf_polars/containers/column.py +138 -7
- cudf_polars/containers/dataframe.py +26 -39
- cudf_polars/dsl/expr.py +3 -1
- cudf_polars/dsl/expressions/aggregation.py +27 -63
- cudf_polars/dsl/expressions/base.py +40 -72
- cudf_polars/dsl/expressions/binaryop.py +5 -41
- cudf_polars/dsl/expressions/boolean.py +25 -53
- cudf_polars/dsl/expressions/datetime.py +97 -17
- cudf_polars/dsl/expressions/literal.py +27 -33
- cudf_polars/dsl/expressions/rolling.py +110 -9
- cudf_polars/dsl/expressions/selection.py +8 -26
- cudf_polars/dsl/expressions/slicing.py +47 -0
- cudf_polars/dsl/expressions/sorting.py +5 -18
- cudf_polars/dsl/expressions/string.py +33 -36
- cudf_polars/dsl/expressions/ternary.py +3 -10
- cudf_polars/dsl/expressions/unary.py +35 -75
- cudf_polars/dsl/ir.py +749 -212
- cudf_polars/dsl/nodebase.py +8 -1
- cudf_polars/dsl/to_ast.py +5 -3
- cudf_polars/dsl/translate.py +319 -171
- cudf_polars/dsl/utils/__init__.py +8 -0
- cudf_polars/dsl/utils/aggregations.py +292 -0
- cudf_polars/dsl/utils/groupby.py +97 -0
- cudf_polars/dsl/utils/naming.py +34 -0
- cudf_polars/dsl/utils/replace.py +46 -0
- cudf_polars/dsl/utils/rolling.py +113 -0
- cudf_polars/dsl/utils/windows.py +186 -0
- cudf_polars/experimental/base.py +17 -19
- cudf_polars/experimental/benchmarks/__init__.py +4 -0
- cudf_polars/experimental/benchmarks/pdsh.py +1279 -0
- cudf_polars/experimental/dask_registers.py +196 -0
- cudf_polars/experimental/distinct.py +174 -0
- cudf_polars/experimental/explain.py +127 -0
- cudf_polars/experimental/expressions.py +521 -0
- cudf_polars/experimental/groupby.py +288 -0
- cudf_polars/experimental/io.py +58 -29
- cudf_polars/experimental/join.py +353 -0
- cudf_polars/experimental/parallel.py +166 -93
- cudf_polars/experimental/repartition.py +69 -0
- cudf_polars/experimental/scheduler.py +155 -0
- cudf_polars/experimental/select.py +92 -7
- cudf_polars/experimental/shuffle.py +294 -0
- cudf_polars/experimental/sort.py +45 -0
- cudf_polars/experimental/spilling.py +151 -0
- cudf_polars/experimental/utils.py +100 -0
- cudf_polars/testing/asserts.py +146 -6
- cudf_polars/testing/io.py +72 -0
- cudf_polars/testing/plugin.py +78 -76
- cudf_polars/typing/__init__.py +59 -6
- cudf_polars/utils/config.py +353 -0
- cudf_polars/utils/conversion.py +40 -0
- cudf_polars/utils/dtypes.py +22 -5
- cudf_polars/utils/timer.py +39 -0
- cudf_polars/utils/versions.py +5 -4
- {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info}/METADATA +10 -7
- cudf_polars_cu12-25.6.0.dist-info/RECORD +73 -0
- {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info}/WHEEL +1 -1
- cudf_polars/experimental/dask_serialize.py +0 -59
- cudf_polars_cu12-25.2.2.dist-info/RECORD +0 -48
- {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info/licenses}/LICENSE +0 -0
- {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,521 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""
|
|
4
|
+
Multi-partition Expr classes and utilities.
|
|
5
|
+
|
|
6
|
+
This module includes the necessary functionality to
|
|
7
|
+
decompose a non-pointwise expression graph into stages
|
|
8
|
+
that can each be mapped onto a simple partition-wise
|
|
9
|
+
task graph at execution time.
|
|
10
|
+
|
|
11
|
+
For example, if ``Select.exprs`` contains an ``expr.Agg``
|
|
12
|
+
node, ``decompose_expr_graph`` will decompose the complex
|
|
13
|
+
NamedExpr node into a sequence of three new IR nodes::
|
|
14
|
+
|
|
15
|
+
- Select: Partition-wise aggregation logic.
|
|
16
|
+
- Repartition: Concatenate the results of each partition.
|
|
17
|
+
- Select: Final aggregation on the combined results.
|
|
18
|
+
|
|
19
|
+
In this example, the Select stages are mapped onto a simple
|
|
20
|
+
partition-wise task graph at execution time, and the Repartition
|
|
21
|
+
stage is used to capture the data-movement required for a global
|
|
22
|
+
aggregation. At the moment, data movement is always introduced
|
|
23
|
+
by either repartitioning or shuffling the data.
|
|
24
|
+
|
|
25
|
+
Since we are introducing intermediate IR nodes, we are also
|
|
26
|
+
introducing a temporary column for each intermediate result.
|
|
27
|
+
In order to avoid column-name collisions with the original
|
|
28
|
+
input-IR node, we generate unique names for temporary columns
|
|
29
|
+
and concatenate them to the input-IR node using ``HConcat``.
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
from __future__ import annotations
|
|
33
|
+
|
|
34
|
+
import operator
|
|
35
|
+
from functools import reduce
|
|
36
|
+
from typing import TYPE_CHECKING
|
|
37
|
+
|
|
38
|
+
import pylibcudf as plc
|
|
39
|
+
|
|
40
|
+
from cudf_polars.dsl.expressions.aggregation import Agg
|
|
41
|
+
from cudf_polars.dsl.expressions.base import Col, Expr, NamedExpr
|
|
42
|
+
from cudf_polars.dsl.expressions.binaryop import BinOp
|
|
43
|
+
from cudf_polars.dsl.expressions.literal import Literal
|
|
44
|
+
from cudf_polars.dsl.expressions.unary import Cast, UnaryFunction
|
|
45
|
+
from cudf_polars.dsl.ir import Distinct, Empty, HConcat, Select
|
|
46
|
+
from cudf_polars.dsl.traversal import (
|
|
47
|
+
CachingVisitor,
|
|
48
|
+
)
|
|
49
|
+
from cudf_polars.dsl.utils.naming import unique_names
|
|
50
|
+
from cudf_polars.experimental.base import PartitionInfo
|
|
51
|
+
from cudf_polars.experimental.repartition import Repartition
|
|
52
|
+
from cudf_polars.experimental.utils import _leaf_column_names
|
|
53
|
+
|
|
54
|
+
if TYPE_CHECKING:
|
|
55
|
+
from collections.abc import Generator, MutableMapping, Sequence
|
|
56
|
+
from typing import TypeAlias
|
|
57
|
+
|
|
58
|
+
from cudf_polars.dsl.expressions.base import Expr
|
|
59
|
+
from cudf_polars.dsl.ir import IR
|
|
60
|
+
from cudf_polars.typing import GenericTransformer, Schema
|
|
61
|
+
from cudf_polars.utils.config import ConfigOptions
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
# Transformer signature used by ``decompose_expr_graph``: maps an Expr to
# (decomposed Expr, rewritten input IR, partition-info mapping).
ExprDecomposer: TypeAlias = (
    "GenericTransformer[Expr, tuple[Expr, IR, MutableMapping[IR, PartitionInfo]]]"
)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def select(
    exprs: Sequence[Expr],
    input_ir: IR,
    partition_info: MutableMapping[IR, PartitionInfo],
    *,
    names: Generator[str, None, None],
    repartition: bool = False,
) -> tuple[list[Col], IR, MutableMapping[IR, PartitionInfo]]:
    """
    Select expressions from an IR node, introducing temporaries.

    Parameters
    ----------
    exprs
        Expressions to select.
    input_ir
        The input IR node to select from.
    partition_info
        A mapping from all unique IR nodes to the
        associated partitioning information.
    names
        Generator of unique names for temporaries.
    repartition
        Whether to add a Repartition node after the
        new selection.

    Returns
    -------
    columns
        Expressions to select from the new IR output.
    new_ir
        The new IR node that will introduce temporaries.
    partition_info
        A mapping from unique nodes in the new graph to associated
        partitioning information.
    """
    # Bind each expression to a fresh temporary name.
    named_exprs = [NamedExpr(next(names), expr) for expr in exprs]

    # The Select inherits the partition count of its input.
    new_ir: IR = Select(
        {ne.name: ne.value.dtype for ne in named_exprs},
        named_exprs,
        True,  # noqa: FBT003
        input_ir,
    )
    partition_info[new_ir] = PartitionInfo(count=partition_info[input_ir].count)

    if repartition:
        # Optionally collapse the selection into a single output partition.
        new_ir = Repartition(new_ir.schema, new_ir)
        partition_info[new_ir] = PartitionInfo(count=1)

    return (
        [Col(ne.value.dtype, ne.name) for ne in named_exprs],
        new_ir,
        partition_info,
    )
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def _decompose_unique(
    unique: UnaryFunction,
    input_ir: IR,
    partition_info: MutableMapping[IR, PartitionInfo],
    config_options: ConfigOptions,
    *,
    names: Generator[str, None, None],
) -> tuple[Expr, IR, MutableMapping[IR, PartitionInfo]]:
    """
    Decompose a 'unique' UnaryFunction into partition-wise stages.

    Parameters
    ----------
    unique
        The expression node to decompose.
    input_ir
        The original input-IR node that ``unique`` will evaluate.
    partition_info
        A mapping from all unique IR nodes to the
        associated partitioning information.
    config_options
        GPUEngine configuration options.
    names
        Generator of unique names for temporaries.

    Returns
    -------
    expr
        Decomposed expression node.
    input_ir
        The rewritten ``input_ir`` to be evaluated by ``expr``.
    partition_info
        A mapping from unique nodes in the new graph to associated
        partitioning information.
    """
    # Imported lazily to avoid a circular dependency at module load.
    from cudf_polars.experimental.distinct import lower_distinct

    (child,) = unique.children
    (maintain_order,) = unique.options

    # Materialize the child expression as a temporary column.
    (column,), input_ir, partition_info = select(
        [child],
        input_ir,
        partition_info,
        names=names,
    )

    assert config_options.executor.name == "streaming", (
        "'in-memory' executor not supported in '_decompose_unique'"
    )

    # Derive an optional cardinality hint from any user-provided
    # cardinality factors matching the leaf columns of ``child``.
    cardinality: float | None = None
    factors = {
        max(min(v, 1.0), 0.00001)  # clamp to (0, 1]
        for k, v in config_options.executor.cardinality_factor.items()
        if k in _leaf_column_names(child)
    }
    if factors:
        cardinality = max(factors)

    # Lower a Distinct node over the temporary column into
    # multi-partition form.
    input_ir, partition_info = lower_distinct(
        Distinct(
            {column.name: column.dtype},
            plc.stream_compaction.DuplicateKeepOption.KEEP_ANY,
            None,
            None,
            maintain_order,
            input_ir,
        ),
        input_ir,
        partition_info,
        config_options,
        cardinality=cardinality,
    )

    return column, input_ir, partition_info
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def _decompose_agg_node(
    agg: Agg,
    input_ir: IR,
    partition_info: MutableMapping[IR, PartitionInfo],
    config_options: ConfigOptions,
    *,
    names: Generator[str, None, None],
) -> tuple[Expr, IR, MutableMapping[IR, PartitionInfo]]:
    """
    Decompose an agg expression into partition-wise stages.

    Parameters
    ----------
    agg
        The Agg node to decompose.
    input_ir
        The original input-IR node that ``agg`` will evaluate.
    partition_info
        A mapping from all unique IR nodes to the
        associated partitioning information.
    config_options
        GPUEngine configuration options.
    names
        Generator of unique names for temporaries.

    Returns
    -------
    expr
        Decomposed Agg node.
    input_ir
        The rewritten ``input_ir`` to be evaluated by ``expr``.
    partition_info
        A mapping from unique nodes in the new graph to associated
        partitioning information.
    """
    expr: Expr
    exprs: list[Expr]
    if agg.name == "count":
        # Chunkwise stage: count within each partition, then concatenate
        # the partial results into one partition.
        columns, input_ir, partition_info = select(
            [agg],
            input_ir,
            partition_info,
            names=names,
            repartition=True,
        )

        # Combined stage: sum the partition-wise counts.
        (column,) = columns
        columns, input_ir, partition_info = select(
            [Agg(agg.dtype, "sum", None, column)],
            input_ir,
            partition_info,
            names=names,
        )
        (expr,) = columns
    elif agg.name == "mean":
        # Chunkwise stage: partial sum and partial count per partition.
        exprs = [
            Agg(agg.dtype, "sum", None, *agg.children),
            Agg(agg.dtype, "count", None, *agg.children),
        ]
        columns, input_ir, partition_info = select(
            exprs,
            input_ir,
            partition_info,
            names=names,
            repartition=True,
        )

        # Combined stage: global sum divided by global count.
        exprs = [
            BinOp(
                agg.dtype,
                plc.binaryop.BinaryOperator.DIV,
                *(Agg(agg.dtype, "sum", None, column) for column in columns),
            )
        ]
        columns, input_ir, partition_info = select(
            exprs,
            input_ir,
            partition_info,
            names=names,
            repartition=True,
        )
        (expr,) = columns
    elif agg.name == "n_unique":
        # Get uniques and shuffle (if necessary)
        # TODO: Should this be a tree reduction by default?
        (child,) = agg.children
        pi = partition_info[input_ir]
        # Fix: compare the current partitioning keys against the aggregated
        # child *expression*. The previous comparison (``!= [input_ir]``)
        # matched a list of Exprs against an IR node, which could never be
        # equal, so a Shuffle was inserted even when the data was already
        # partitioned on ``child``.
        if pi.count > 1 and [ne.value for ne in pi.partitioned_on] != [child]:
            from cudf_polars.experimental.shuffle import Shuffle

            # Drop local duplicates first to reduce shuffled data volume.
            children, input_ir, partition_info = select(
                [UnaryFunction(agg.dtype, "unique", (False,), child)],
                input_ir,
                partition_info,
                names=names,
            )
            (child,) = children
            agg = agg.reconstruct([child])
            shuffle_on = (NamedExpr(next(names), child),)
            input_ir = Shuffle(
                input_ir.schema,
                shuffle_on,
                config_options,
                input_ir,
            )
            partition_info[input_ir] = PartitionInfo(
                count=pi.count,
                partitioned_on=shuffle_on,
            )

        # Chunkwise stage: local n_unique, cast to the output dtype.
        columns, input_ir, partition_info = select(
            [Cast(agg.dtype, agg)],
            input_ir,
            partition_info,
            names=names,
            repartition=True,
        )

        # Combined stage: sum the per-partition unique counts (valid
        # because each distinct value now lives in exactly one partition).
        (column,) = columns
        columns, input_ir, partition_info = select(
            [Agg(agg.dtype, "sum", None, column)],
            input_ir,
            partition_info,
            names=names,
        )
        (expr,) = columns
    else:
        # min/max/sum: the same aggregation is applied both chunkwise
        # and on the combined result.
        columns, input_ir, partition_info = select(
            [agg],
            input_ir,
            partition_info,
            names=names,
            repartition=True,
        )

        # Combined stage.
        (column,) = columns
        columns, input_ir, partition_info = select(
            [Agg(agg.dtype, agg.name, agg.options, column)],
            input_ir,
            partition_info,
            names=names,
        )
        (expr,) = columns

    return expr, input_ir, partition_info
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
# Aggregations that ``_decompose_agg_node`` knows how to split into
# chunkwise + combined stages.
_SUPPORTED_AGGS = ("count", "min", "max", "sum", "mean", "n_unique")
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
def _decompose_expr_node(
    expr: Expr,
    input_ir: IR,
    partition_info: MutableMapping[IR, PartitionInfo],
    config_options: ConfigOptions,
    *,
    names: Generator[str, None, None],
) -> tuple[Expr, IR, MutableMapping[IR, PartitionInfo]]:
    """
    Decompose an expression into partition-wise stages.

    Parameters
    ----------
    expr
        The Expr node to decompose.
    input_ir
        The input IR node that ``expr`` will evaluate.
    partition_info
        A mapping from all unique IR nodes to the
        associated partitioning information.
    config_options
        GPUEngine configuration options.
    names
        Generator of unique names for temporaries.

    Returns
    -------
    expr
        Decomposed Expr node.
    input_ir
        The rewritten ``input_ir`` to be evaluated by ``expr``.
    partition_info
        A mapping from unique nodes in the new graph to associated
        partitioning information.
    """
    if isinstance(expr, Literal):
        # A Literal needs no real input columns; feeding it the original
        # input IR would mess up the result of ``HConcat``.
        input_ir = Empty()
        partition_info[input_ir] = PartitionInfo(count=1)

    # Single-partition and pointwise expressions are always supported.
    if partition_info[input_ir].count == 1 or expr.is_pointwise:
        return expr, input_ir, partition_info

    # Supported multi-partition aggregation.
    if isinstance(expr, Agg) and expr.name in _SUPPORTED_AGGS:
        return _decompose_agg_node(
            expr, input_ir, partition_info, config_options, names=names
        )

    # Supported multi-partition 'unique'.
    if isinstance(expr, UnaryFunction) and expr.name == "unique":
        return _decompose_unique(
            expr, input_ir, partition_info, config_options, names=names
        )

    # Anything else is un-supported for multiple partitions.
    raise NotImplementedError(
        f"{type(expr)} not supported for multiple partitions."
    )
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
def _decompose(
    expr: Expr, rec: ExprDecomposer
) -> tuple[Expr, IR, MutableMapping[IR, PartitionInfo]]:
    # Used by `decompose_expr_graph``
    state = rec.state

    if not expr.children:
        # Leaf node: decompose directly against the original input IR.
        return _decompose_expr_node(
            expr,
            state["input_ir"],
            {state["input_ir"]: state["input_partition_info"]},
            state["config_options"],
            names=state["unique_names"],
        )

    # Recurse into the children first.
    results = [rec(child) for child in expr.children]
    children = tuple(res[0] for res in results)
    input_irs = tuple(res[1] for res in results)
    # Merge the per-child partition-info mappings.
    partition_info = reduce(operator.or_, (res[2] for res in results))

    input_ir: IR
    assert len(input_irs) > 0  # Must have at least one input IR
    # Assume the partition count is the maximum input-IR partition count.
    partition_count = max(partition_info[ir].count for ir in input_irs)
    # De-duplicate (order-preserving) and drop placeholder Empty inputs.
    unique_input_irs = [
        ir for ir in dict.fromkeys(input_irs) if not isinstance(ir, Empty)
    ]
    if len(unique_input_irs) > 1:
        # Need to make sure we only have a single input IR.
        # TODO: Check that we aren't concatenating misaligned
        # columns that cannot be broadcasted. For example, what
        # if one of the columns is sorted?
        schema: Schema = {}
        for ir in unique_input_irs:
            schema.update(ir.schema)
        input_ir = HConcat(
            schema,
            True,  # noqa: FBT003
            *unique_input_irs,
        )
        partition_info[input_ir] = PartitionInfo(count=partition_count)
    else:
        input_ir = unique_input_irs[0]

    # Call into class-specific logic to decompose ``expr``.
    return _decompose_expr_node(
        expr.reconstruct(children),
        input_ir,
        partition_info,
        state["config_options"],
        names=state["unique_names"],
    )
|
|
474
|
+
|
|
475
|
+
|
|
476
|
+
def decompose_expr_graph(
    named_expr: NamedExpr,
    input_ir: IR,
    partition_info: MutableMapping[IR, PartitionInfo],
    config_options: ConfigOptions,
) -> tuple[NamedExpr, IR, MutableMapping[IR, PartitionInfo]]:
    """
    Decompose a NamedExpr into stages.

    Parameters
    ----------
    named_expr
        The original NamedExpr to decompose.
    input_ir
        The input-IR node that ``named_expr`` will be
        evaluated on.
    partition_info
        A mapping from all unique IR nodes to the
        associated partitioning information.
    config_options
        GPUEngine configuration options.

    Returns
    -------
    named_expr
        Decomposed NamedExpr object.
    input_ir
        The rewritten ``input_ir`` to be evaluated by ``named_expr``.
    partition_info
        A mapping from unique nodes in the new graph to associated
        partitioning information.

    Notes
    -----
    This function recursively decomposes ``named_expr.value`` and
    ``input_ir`` into multiple partition-wise stages.
    """
    # Seed temporary-name generation with every name already in scope so
    # the new columns cannot collide with existing ones.
    visitor = CachingVisitor(
        _decompose,
        state={
            "input_ir": input_ir,
            "input_partition_info": partition_info[input_ir],
            "config_options": config_options,
            "unique_names": unique_names((named_expr.name, *input_ir.schema.keys())),
        },
    )
    new_value, new_input_ir, new_partition_info = visitor(named_expr.value)
    return named_expr.reconstruct(new_value), new_input_ir, new_partition_info
|