cudf_polars_cu13-25.10.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. cudf_polars/GIT_COMMIT +1 -0
  2. cudf_polars/VERSION +1 -0
  3. cudf_polars/__init__.py +28 -0
  4. cudf_polars/_version.py +21 -0
  5. cudf_polars/callback.py +318 -0
  6. cudf_polars/containers/__init__.py +13 -0
  7. cudf_polars/containers/column.py +495 -0
  8. cudf_polars/containers/dataframe.py +361 -0
  9. cudf_polars/containers/datatype.py +137 -0
  10. cudf_polars/dsl/__init__.py +8 -0
  11. cudf_polars/dsl/expr.py +66 -0
  12. cudf_polars/dsl/expressions/__init__.py +8 -0
  13. cudf_polars/dsl/expressions/aggregation.py +226 -0
  14. cudf_polars/dsl/expressions/base.py +272 -0
  15. cudf_polars/dsl/expressions/binaryop.py +120 -0
  16. cudf_polars/dsl/expressions/boolean.py +326 -0
  17. cudf_polars/dsl/expressions/datetime.py +271 -0
  18. cudf_polars/dsl/expressions/literal.py +97 -0
  19. cudf_polars/dsl/expressions/rolling.py +643 -0
  20. cudf_polars/dsl/expressions/selection.py +74 -0
  21. cudf_polars/dsl/expressions/slicing.py +46 -0
  22. cudf_polars/dsl/expressions/sorting.py +85 -0
  23. cudf_polars/dsl/expressions/string.py +1002 -0
  24. cudf_polars/dsl/expressions/struct.py +137 -0
  25. cudf_polars/dsl/expressions/ternary.py +49 -0
  26. cudf_polars/dsl/expressions/unary.py +517 -0
  27. cudf_polars/dsl/ir.py +2607 -0
  28. cudf_polars/dsl/nodebase.py +164 -0
  29. cudf_polars/dsl/to_ast.py +359 -0
  30. cudf_polars/dsl/tracing.py +16 -0
  31. cudf_polars/dsl/translate.py +939 -0
  32. cudf_polars/dsl/traversal.py +224 -0
  33. cudf_polars/dsl/utils/__init__.py +8 -0
  34. cudf_polars/dsl/utils/aggregations.py +481 -0
  35. cudf_polars/dsl/utils/groupby.py +98 -0
  36. cudf_polars/dsl/utils/naming.py +34 -0
  37. cudf_polars/dsl/utils/replace.py +61 -0
  38. cudf_polars/dsl/utils/reshape.py +74 -0
  39. cudf_polars/dsl/utils/rolling.py +121 -0
  40. cudf_polars/dsl/utils/windows.py +192 -0
  41. cudf_polars/experimental/__init__.py +8 -0
  42. cudf_polars/experimental/base.py +386 -0
  43. cudf_polars/experimental/benchmarks/__init__.py +4 -0
  44. cudf_polars/experimental/benchmarks/pdsds.py +220 -0
  45. cudf_polars/experimental/benchmarks/pdsds_queries/__init__.py +4 -0
  46. cudf_polars/experimental/benchmarks/pdsds_queries/q1.py +88 -0
  47. cudf_polars/experimental/benchmarks/pdsds_queries/q10.py +225 -0
  48. cudf_polars/experimental/benchmarks/pdsds_queries/q2.py +244 -0
  49. cudf_polars/experimental/benchmarks/pdsds_queries/q3.py +65 -0
  50. cudf_polars/experimental/benchmarks/pdsds_queries/q4.py +359 -0
  51. cudf_polars/experimental/benchmarks/pdsds_queries/q5.py +462 -0
  52. cudf_polars/experimental/benchmarks/pdsds_queries/q6.py +92 -0
  53. cudf_polars/experimental/benchmarks/pdsds_queries/q7.py +79 -0
  54. cudf_polars/experimental/benchmarks/pdsds_queries/q8.py +524 -0
  55. cudf_polars/experimental/benchmarks/pdsds_queries/q9.py +137 -0
  56. cudf_polars/experimental/benchmarks/pdsh.py +814 -0
  57. cudf_polars/experimental/benchmarks/utils.py +832 -0
  58. cudf_polars/experimental/dask_registers.py +200 -0
  59. cudf_polars/experimental/dispatch.py +156 -0
  60. cudf_polars/experimental/distinct.py +197 -0
  61. cudf_polars/experimental/explain.py +157 -0
  62. cudf_polars/experimental/expressions.py +590 -0
  63. cudf_polars/experimental/groupby.py +327 -0
  64. cudf_polars/experimental/io.py +943 -0
  65. cudf_polars/experimental/join.py +391 -0
  66. cudf_polars/experimental/parallel.py +423 -0
  67. cudf_polars/experimental/repartition.py +69 -0
  68. cudf_polars/experimental/scheduler.py +155 -0
  69. cudf_polars/experimental/select.py +188 -0
  70. cudf_polars/experimental/shuffle.py +354 -0
  71. cudf_polars/experimental/sort.py +609 -0
  72. cudf_polars/experimental/spilling.py +151 -0
  73. cudf_polars/experimental/statistics.py +795 -0
  74. cudf_polars/experimental/utils.py +169 -0
  75. cudf_polars/py.typed +0 -0
  76. cudf_polars/testing/__init__.py +8 -0
  77. cudf_polars/testing/asserts.py +448 -0
  78. cudf_polars/testing/io.py +122 -0
  79. cudf_polars/testing/plugin.py +236 -0
  80. cudf_polars/typing/__init__.py +219 -0
  81. cudf_polars/utils/__init__.py +8 -0
  82. cudf_polars/utils/config.py +741 -0
  83. cudf_polars/utils/conversion.py +40 -0
  84. cudf_polars/utils/dtypes.py +118 -0
  85. cudf_polars/utils/sorting.py +53 -0
  86. cudf_polars/utils/timer.py +39 -0
  87. cudf_polars/utils/versions.py +27 -0
  88. cudf_polars_cu13-25.10.0.dist-info/METADATA +136 -0
  89. cudf_polars_cu13-25.10.0.dist-info/RECORD +92 -0
  90. cudf_polars_cu13-25.10.0.dist-info/WHEEL +5 -0
  91. cudf_polars_cu13-25.10.0.dist-info/licenses/LICENSE +201 -0
  92. cudf_polars_cu13-25.10.0.dist-info/top_level.txt +1 -0
cudf_polars/experimental/expressions.py
@@ -0,0 +1,590 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+ """
+ Multi-partition Expr classes and utilities.
+
+ This module includes the necessary functionality to
+ decompose a non-pointwise expression graph into stages
+ that can each be mapped onto a simple partition-wise
+ task graph at execution time.
+
+ For example, if ``Select.exprs`` contains an ``expr.Agg``
+ node, ``decompose_expr_graph`` will decompose the complex
+ NamedExpr node into a sequence of three new IR nodes::
+
+     - Select: Partition-wise aggregation logic.
+     - Repartition: Concatenate the results of each partition.
+     - Select: Final aggregation on the combined results.
+
+ In this example, the Select stages are mapped onto a simple
+ partition-wise task graph at execution time, and the Repartition
+ stage is used to capture the data movement required for a global
+ aggregation. At the moment, data movement is always introduced
+ by either repartitioning or shuffling the data.
+
+ Since we are introducing intermediate IR nodes, we are also
+ introducing a temporary column for each intermediate result.
+ In order to avoid column-name collisions with the original
+ input-IR node, we generate unique names for temporary columns
+ and concatenate them to the input-IR node using ``HConcat``.
+ """
+
+ from __future__ import annotations
+
+ import operator
+ from functools import reduce
+ from typing import TYPE_CHECKING, TypeAlias, TypedDict
+
+ import pylibcudf as plc
+
+ from cudf_polars.dsl.expressions.aggregation import Agg
+ from cudf_polars.dsl.expressions.base import Col, Expr, NamedExpr
+ from cudf_polars.dsl.expressions.binaryop import BinOp
+ from cudf_polars.dsl.expressions.literal import Literal
+ from cudf_polars.dsl.expressions.unary import Cast, UnaryFunction
+ from cudf_polars.dsl.ir import IR, Distinct, Empty, HConcat, Select
+ from cudf_polars.dsl.traversal import CachingVisitor
+ from cudf_polars.dsl.utils.naming import unique_names
+ from cudf_polars.experimental.base import PartitionInfo
+ from cudf_polars.experimental.repartition import Repartition
+ from cudf_polars.experimental.utils import _get_unique_fractions, _leaf_column_names
+
+ if TYPE_CHECKING:
+     from collections.abc import Generator, MutableMapping, Sequence
+
+     from cudf_polars.dsl.expressions.base import Expr
+     from cudf_polars.dsl.ir import IR
+     from cudf_polars.experimental.base import ColumnStat, ColumnStats
+     from cudf_polars.typing import GenericTransformer, Schema
+     from cudf_polars.utils.config import ConfigOptions
+
+
+ class State(TypedDict):
+     """
+     State for decomposing expressions.
+
+     Parameters
+     ----------
+     input_ir
+         IR of the input expression.
+     input_partition_info
+         Partition info of the input expression.
+     config_options
+         GPUEngine configuration options.
+     unique_names
+         Generator of unique names for temporaries.
+     row_count_estimate
+         Row-count estimate for the input IR.
+     column_stats
+         Column statistics for the input IR.
+     """
+
+     input_ir: IR
+     input_partition_info: PartitionInfo
+     config_options: ConfigOptions
+     unique_names: Generator[str, None, None]
+     row_count_estimate: ColumnStat[int]
+     column_stats: dict[str, ColumnStats]
+
+
+ ExprDecomposer: TypeAlias = "GenericTransformer[Expr, tuple[Expr, IR, MutableMapping[IR, PartitionInfo]], State]"
+ """Protocol for decomposing expressions."""
+
+
+ def select(
+     exprs: Sequence[Expr],
+     input_ir: IR,
+     partition_info: MutableMapping[IR, PartitionInfo],
+     *,
+     names: Generator[str, None, None],
+     repartition: bool = False,
+ ) -> tuple[list[Col], IR, MutableMapping[IR, PartitionInfo]]:
+     """
+     Select expressions from an IR node, introducing temporaries.
+
+     Parameters
+     ----------
+     exprs
+         Expressions to select.
+     input_ir
+         The input IR node to select from.
+     partition_info
+         A mapping from all unique IR nodes to the
+         associated partitioning information.
+     names
+         Generator of unique names for temporaries.
+     repartition
+         Whether to add a Repartition node after the
+         new selection.
+
+     Returns
+     -------
+     columns
+         Expressions to select from the new IR output.
+     new_ir
+         The new IR node that will introduce temporaries.
+     partition_info
+         A mapping from unique nodes in the new graph to associated
+         partitioning information.
+     """
+     output_names = [next(names) for _ in range(len(exprs))]
+     named_exprs = [
+         NamedExpr(name, expr) for name, expr in zip(output_names, exprs, strict=True)
+     ]
+     new_ir: IR = Select(
+         {ne.name: ne.value.dtype for ne in named_exprs},
+         named_exprs,
+         True,  # noqa: FBT003
+         input_ir,
+     )
+     partition_info[new_ir] = PartitionInfo(count=partition_info[input_ir].count)
+
+     # Optionally collapse into one output partition
+     if repartition:
+         new_ir = Repartition(new_ir.schema, new_ir)
+         partition_info[new_ir] = PartitionInfo(count=1)
+
+     columns = [Col(ne.value.dtype, ne.name) for ne in named_exprs]
+     return columns, new_ir, partition_info
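The ``names`` generator passed to ``select`` must never yield a name that collides with an existing column. A minimal sketch of that contract (the real helper is ``cudf_polars.dsl.utils.naming.unique_names``; the naming scheme below is hypothetical):

# Editor's sketch of a collision-free temporary-name generator.
from collections.abc import Generator

def unique_names_sketch(reserved: tuple[str, ...]) -> Generator[str, None, None]:
    i = 0
    while True:
        name = f"_tmp_{i}"  # hypothetical scheme; the real one may differ
        if name not in reserved:
            yield name
        i += 1

names = unique_names_sketch(("a", "b"))
print(next(names), next(names))  # _tmp_0 _tmp_1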
+
+
+ def _decompose_unique(
+     unique: UnaryFunction,
+     input_ir: IR,
+     partition_info: MutableMapping[IR, PartitionInfo],
+     config_options: ConfigOptions,
+     row_count_estimate: ColumnStat[int],
+     column_stats: dict[str, ColumnStats],
+     *,
+     names: Generator[str, None, None],
+ ) -> tuple[Expr, IR, MutableMapping[IR, PartitionInfo]]:
+     """
+     Decompose a 'unique' UnaryFunction into partition-wise stages.
+
+     Parameters
+     ----------
+     unique
+         The expression node to decompose.
+     input_ir
+         The original input-IR node that ``unique`` will evaluate.
+     partition_info
+         A mapping from all unique IR nodes to the
+         associated partitioning information.
+     config_options
+         GPUEngine configuration options.
+     row_count_estimate
+         Row-count estimate for the input IR.
+     column_stats
+         Column statistics for the input IR.
+     names
+         Generator of unique names for temporaries.
+
+     Returns
+     -------
+     expr
+         Decomposed expression node.
+     input_ir
+         The rewritten ``input_ir`` to be evaluated by ``expr``.
+     partition_info
+         A mapping from unique nodes in the new graph to associated
+         partitioning information.
+     """
+     from cudf_polars.experimental.distinct import lower_distinct
+
+     (child,) = unique.children
+     (maintain_order,) = unique.options
+     columns, input_ir, partition_info = select(
+         [child],
+         input_ir,
+         partition_info,
+         names=names,
+     )
+     (column,) = columns
+
+     assert config_options.executor.name == "streaming", (
+         "'in-memory' executor not supported in '_decompose_unique'"
+     )
+
+     unique_fraction_dict = _get_unique_fractions(
+         _leaf_column_names(child),
+         config_options.executor.unique_fraction,
+         row_count=row_count_estimate,
+         column_stats=column_stats,
+     )
+
+     unique_fraction = (
+         max(unique_fraction_dict.values()) if unique_fraction_dict else None
+     )
+
+     input_ir, partition_info = lower_distinct(
+         Distinct(
+             {column.name: column.dtype},
+             plc.stream_compaction.DuplicateKeepOption.KEEP_ANY,
+             None,
+             None,
+             maintain_order,
+             input_ir,
+         ),
+         input_ir,
+         partition_info,
+         config_options,
+         unique_fraction=unique_fraction,
+     )
+
+     return column, input_ir, partition_info
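In plain polars terms, the decomposition above amounts to a partition-wise de-duplication followed by a global distinct over the combined chunks (the role ``lower_distinct`` plays for the wrapped ``Distinct`` node). An editor's sketch, not the cudf_polars code path:

import polars as pl

parts = [pl.Series("x", [1, 2, 2]), pl.Series("x", [2, 3])]
combined = pl.concat([p.unique() for p in parts])  # chunkwise uniques
result = combined.unique()  # global distinct over the combined chunks
assert sorted(result.to_list()) == [1, 2, 3]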
+
+
+ def _decompose_agg_node(
+     agg: Agg,
+     input_ir: IR,
+     partition_info: MutableMapping[IR, PartitionInfo],
+     config_options: ConfigOptions,
+     *,
+     names: Generator[str, None, None],
+ ) -> tuple[Expr, IR, MutableMapping[IR, PartitionInfo]]:
+     """
+     Decompose an Agg expression into partition-wise stages.
+
+     Parameters
+     ----------
+     agg
+         The Agg node to decompose.
+     input_ir
+         The original input-IR node that ``agg`` will evaluate.
+     partition_info
+         A mapping from all unique IR nodes to the
+         associated partitioning information.
+     config_options
+         GPUEngine configuration options.
+     names
+         Generator of unique names for temporaries.
+
+     Returns
+     -------
+     expr
+         Decomposed Agg node.
+     input_ir
+         The rewritten ``input_ir`` to be evaluated by ``expr``.
+     partition_info
+         A mapping from unique nodes in the new graph to associated
+         partitioning information.
+     """
+     expr: Expr
+     exprs: list[Expr]
+     if agg.name == "count":
+         # Chunkwise stage
+         columns, input_ir, partition_info = select(
+             [agg],
+             input_ir,
+             partition_info,
+             names=names,
+             repartition=True,
+         )
+
+         # Combined stage
+         (column,) = columns
+         columns, input_ir, partition_info = select(
+             [Agg(agg.dtype, "sum", None, column)],
+             input_ir,
+             partition_info,
+             names=names,
+         )
+         (expr,) = columns
+     elif agg.name == "mean":
+         # Chunkwise stage
+         exprs = [
+             Agg(agg.dtype, "sum", None, *agg.children),
+             Agg(agg.dtype, "count", None, *agg.children),
+         ]
+         columns, input_ir, partition_info = select(
+             exprs,
+             input_ir,
+             partition_info,
+             names=names,
+             repartition=True,
+         )
+
+         # Combined stage
+         exprs = [
+             BinOp(
+                 agg.dtype,
+                 plc.binaryop.BinaryOperator.DIV,
+                 *(Agg(agg.dtype, "sum", None, column) for column in columns),
+             )
+         ]
+         columns, input_ir, partition_info = select(
+             exprs,
+             input_ir,
+             partition_info,
+             names=names,
+             repartition=True,
+         )
+         (expr,) = columns
+     elif agg.name == "n_unique":
+         # Get uniques and shuffle (if necessary)
+         # TODO: Should this be a tree reduction by default?
+         (child,) = agg.children
+         pi = partition_info[input_ir]
+         if pi.count > 1 and [ne.value for ne in pi.partitioned_on] != [child]:
+             from cudf_polars.experimental.shuffle import Shuffle
+
+             children, input_ir, partition_info = select(
+                 [UnaryFunction(agg.dtype, "unique", (False,), child)],
+                 input_ir,
+                 partition_info,
+                 names=names,
+             )
+             (child,) = children
+             agg = agg.reconstruct([child])
+             shuffle_on = (NamedExpr(next(names), child),)
+
+             assert config_options.executor.name == "streaming", (
+                 "'in-memory' executor not supported in '_decompose_agg_node'"
+             )
+
+             input_ir = Shuffle(
+                 input_ir.schema,
+                 shuffle_on,
+                 config_options.executor.shuffle_method,
+                 input_ir,
+             )
+             partition_info[input_ir] = PartitionInfo(
+                 count=pi.count,
+                 partitioned_on=shuffle_on,
+             )
+
+         # Chunkwise stage
+         columns, input_ir, partition_info = select(
+             [Cast(agg.dtype, agg)],
+             input_ir,
+             partition_info,
+             names=names,
+             repartition=True,
+         )
+
+         # Combined stage
+         (column,) = columns
+         columns, input_ir, partition_info = select(
+             [Agg(agg.dtype, "sum", None, column)],
+             input_ir,
+             partition_info,
+             names=names,
+         )
+         (expr,) = columns
+     else:
+         # Chunkwise stage
+         columns, input_ir, partition_info = select(
+             [agg],
+             input_ir,
+             partition_info,
+             names=names,
+             repartition=True,
+         )
+
+         # Combined stage
+         (column,) = columns
+         columns, input_ir, partition_info = select(
+             [Agg(agg.dtype, agg.name, agg.options, column)],
+             input_ir,
+             partition_info,
+             names=names,
+         )
+         (expr,) = columns
+
+     return expr, input_ir, partition_info
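The 'mean' branch above is the clearest worked example of the chunkwise/combined split: per-partition sums and counts, then a sum of those partials and a division (the ``BinOp`` ``DIV`` node). An editor's sketch in plain polars:

import polars as pl

parts = [pl.DataFrame({"x": [1.0, 2.0]}), pl.DataFrame({"x": [3.0, 4.0, 5.0]})]
# Chunkwise stage: per-partition sum and count.
partials = [
    p.select(pl.col("x").sum().alias("s"), pl.col("x").count().alias("c"))
    for p in parts
]
# Repartition: concatenate the partial results.
combined = pl.concat(partials)
# Combined stage: sum the partials, then divide.
assert combined.select(pl.col("s").sum() / pl.col("c").sum()).item() == 3.0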
+
+
+ _SUPPORTED_AGGS = ("count", "min", "max", "sum", "mean", "n_unique")
+
+
+ def _decompose_expr_node(
+     expr: Expr,
+     input_ir: IR,
+     partition_info: MutableMapping[IR, PartitionInfo],
+     config_options: ConfigOptions,
+     row_count_estimate: ColumnStat[int],
+     column_stats: dict[str, ColumnStats],
+     *,
+     names: Generator[str, None, None],
+ ) -> tuple[Expr, IR, MutableMapping[IR, PartitionInfo]]:
+     """
+     Decompose an expression into partition-wise stages.
+
+     Parameters
+     ----------
+     expr
+         The Expr node to decompose.
+     input_ir
+         The input IR node that ``expr`` will evaluate.
+     partition_info
+         A mapping from all unique IR nodes to the
+         associated partitioning information.
+     config_options
+         GPUEngine configuration options.
+     row_count_estimate
+         Row-count estimate for the input IR.
+     column_stats
+         Column statistics for the input IR.
+     names
+         Generator of unique names for temporaries.
+
+     Returns
+     -------
+     expr
+         Decomposed Expr node.
+     input_ir
+         The rewritten ``input_ir`` to be evaluated by ``expr``.
+     partition_info
+         A mapping from unique nodes in the new graph to associated
+         partitioning information.
+     """
+     if isinstance(expr, Literal):
+         # For Literal nodes, we don't actually want an
+         # input IR with real columns, because it will
+         # mess up the result of ``HConcat``.
+         input_ir = Empty({})
+         partition_info[input_ir] = PartitionInfo(count=1)
+
+     partition_count = partition_info[input_ir].count
+     if partition_count == 1 or expr.is_pointwise:
+         # Single-partition and pointwise expressions are always supported.
+         return expr, input_ir, partition_info
+     elif isinstance(expr, Agg) and expr.name in _SUPPORTED_AGGS:
+         # This is a supported Agg expression.
+         return _decompose_agg_node(
+             expr, input_ir, partition_info, config_options, names=names
+         )
+     elif isinstance(expr, UnaryFunction) and expr.name == "unique":
+         return _decompose_unique(
+             expr,
+             input_ir,
+             partition_info,
+             config_options,
+             row_count_estimate,
+             column_stats,
+             names=names,
+         )
+     else:
+         # This is an unsupported expression - raise.
+         raise NotImplementedError(
+             f"{type(expr)} not supported for multiple partitions."
+         )
+
+
+ def _decompose(
+     expr: Expr, rec: ExprDecomposer
+ ) -> tuple[Expr, IR, MutableMapping[IR, PartitionInfo]]:
+     # Used by ``decompose_expr_graph``
+
+     if not expr.children:
+         # Leaf node
+         return _decompose_expr_node(
+             expr,
+             rec.state["input_ir"],
+             {rec.state["input_ir"]: rec.state["input_partition_info"]},
+             rec.state["config_options"],
+             rec.state["row_count_estimate"],
+             rec.state["column_stats"],
+             names=rec.state["unique_names"],
+         )
+
+     # Process child Exprs first
+     children, input_irs, _partition_info = zip(
+         *(rec(c) for c in expr.children), strict=True
+     )
+     partition_info = reduce(operator.or_, _partition_info)
+
+     # Assume the partition count is the maximum input-IR partition count
+     input_ir: IR
+     assert len(input_irs) > 0  # Must have at least one input IR
+     partition_count = max(partition_info[ir].count for ir in input_irs)
+     unique_input_irs = [k for k in dict.fromkeys(input_irs) if not isinstance(k, Empty)]
+     if len(unique_input_irs) > 1:
+         # Need to make sure we only have a single input IR
+         # TODO: Check that we aren't concatenating misaligned
+         # columns that cannot be broadcast. For example, what
+         # if one of the columns is sorted?
+         schema: Schema = {}
+         for ir in unique_input_irs:
+             schema.update(ir.schema)
+         input_ir = HConcat(
+             schema,
+             True,  # noqa: FBT003
+             *unique_input_irs,
+         )
+         partition_info[input_ir] = PartitionInfo(count=partition_count)
+     else:
+         input_ir = unique_input_irs[0]
+
+     # Call into class-specific logic to decompose ``expr``
+     return _decompose_expr_node(
+         expr.reconstruct(children),
+         input_ir,
+         partition_info,
+         rec.state["config_options"],
+         rec.state["row_count_estimate"],
+         rec.state["column_stats"],
+         names=rec.state["unique_names"],
+     )
+
+
+ def decompose_expr_graph(
+     named_expr: NamedExpr,
+     input_ir: IR,
+     partition_info: MutableMapping[IR, PartitionInfo],
+     config_options: ConfigOptions,
+     row_count_estimate: ColumnStat[int],
+     column_stats: dict[str, ColumnStats],
+ ) -> tuple[NamedExpr, IR, MutableMapping[IR, PartitionInfo]]:
+     """
+     Decompose a NamedExpr into stages.
+
+     Parameters
+     ----------
+     named_expr
+         The original NamedExpr to decompose.
+     input_ir
+         The input-IR node that ``named_expr`` will be
+         evaluated on.
+     partition_info
+         A mapping from all unique IR nodes to the
+         associated partitioning information.
+     config_options
+         GPUEngine configuration options.
+     row_count_estimate
+         Row-count estimate for the input IR.
+     column_stats
+         Column statistics for the input IR.
+
+     Returns
+     -------
+     named_expr
+         Decomposed NamedExpr object.
+     input_ir
+         The rewritten ``input_ir`` to be evaluated by ``named_expr``.
+     partition_info
+         A mapping from unique nodes in the new graph to associated
+         partitioning information.
+
+     Notes
+     -----
+     This function recursively decomposes ``named_expr.value`` and
+     ``input_ir`` into multiple partition-wise stages.
+
+     The state dictionary is an instance of :class:`State`.
+     """
+     mapper: ExprDecomposer = CachingVisitor(
+         _decompose,
+         state={
+             "input_ir": input_ir,
+             "input_partition_info": partition_info[input_ir],
+             "config_options": config_options,
+             "unique_names": unique_names((named_expr.name, *input_ir.schema.keys())),
+             "row_count_estimate": row_count_estimate,
+             "column_stats": column_stats,
+         },
+     )
+     expr, input_ir, partition_info = mapper(named_expr.value)
+     return named_expr.reconstruct(expr), input_ir, partition_info
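``decompose_expr_graph`` drives ``_decompose`` through a ``CachingVisitor`` so that shared subexpressions in the DAG are decomposed only once. A toy stand-in for that memoized bottom-up rewrite (an editor's sketch; the real ``CachingVisitor`` API differs):

from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    name: str
    children: tuple[Node, ...] = ()


def rewrite(node: Node, cache: dict[Node, Node] | None = None) -> Node:
    cache = {} if cache is None else cache
    if node in cache:
        return cache[node]  # shared subexpression: reuse the earlier result
    children = tuple(rewrite(c, cache) for c in node.children)
    result = Node(node.name.upper(), children)  # node-specific rewrite goes here
    cache[node] = result
    return result


shared = Node("col_x")
print(rewrite(Node("add", (shared, shared))))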