cudf_polars_cu13-25.10.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. cudf_polars/GIT_COMMIT +1 -0
  2. cudf_polars/VERSION +1 -0
  3. cudf_polars/__init__.py +28 -0
  4. cudf_polars/_version.py +21 -0
  5. cudf_polars/callback.py +318 -0
  6. cudf_polars/containers/__init__.py +13 -0
  7. cudf_polars/containers/column.py +495 -0
  8. cudf_polars/containers/dataframe.py +361 -0
  9. cudf_polars/containers/datatype.py +137 -0
  10. cudf_polars/dsl/__init__.py +8 -0
  11. cudf_polars/dsl/expr.py +66 -0
  12. cudf_polars/dsl/expressions/__init__.py +8 -0
  13. cudf_polars/dsl/expressions/aggregation.py +226 -0
  14. cudf_polars/dsl/expressions/base.py +272 -0
  15. cudf_polars/dsl/expressions/binaryop.py +120 -0
  16. cudf_polars/dsl/expressions/boolean.py +326 -0
  17. cudf_polars/dsl/expressions/datetime.py +271 -0
  18. cudf_polars/dsl/expressions/literal.py +97 -0
  19. cudf_polars/dsl/expressions/rolling.py +643 -0
  20. cudf_polars/dsl/expressions/selection.py +74 -0
  21. cudf_polars/dsl/expressions/slicing.py +46 -0
  22. cudf_polars/dsl/expressions/sorting.py +85 -0
  23. cudf_polars/dsl/expressions/string.py +1002 -0
  24. cudf_polars/dsl/expressions/struct.py +137 -0
  25. cudf_polars/dsl/expressions/ternary.py +49 -0
  26. cudf_polars/dsl/expressions/unary.py +517 -0
  27. cudf_polars/dsl/ir.py +2607 -0
  28. cudf_polars/dsl/nodebase.py +164 -0
  29. cudf_polars/dsl/to_ast.py +359 -0
  30. cudf_polars/dsl/tracing.py +16 -0
  31. cudf_polars/dsl/translate.py +939 -0
  32. cudf_polars/dsl/traversal.py +224 -0
  33. cudf_polars/dsl/utils/__init__.py +8 -0
  34. cudf_polars/dsl/utils/aggregations.py +481 -0
  35. cudf_polars/dsl/utils/groupby.py +98 -0
  36. cudf_polars/dsl/utils/naming.py +34 -0
  37. cudf_polars/dsl/utils/replace.py +61 -0
  38. cudf_polars/dsl/utils/reshape.py +74 -0
  39. cudf_polars/dsl/utils/rolling.py +121 -0
  40. cudf_polars/dsl/utils/windows.py +192 -0
  41. cudf_polars/experimental/__init__.py +8 -0
  42. cudf_polars/experimental/base.py +386 -0
  43. cudf_polars/experimental/benchmarks/__init__.py +4 -0
  44. cudf_polars/experimental/benchmarks/pdsds.py +220 -0
  45. cudf_polars/experimental/benchmarks/pdsds_queries/__init__.py +4 -0
  46. cudf_polars/experimental/benchmarks/pdsds_queries/q1.py +88 -0
  47. cudf_polars/experimental/benchmarks/pdsds_queries/q10.py +225 -0
  48. cudf_polars/experimental/benchmarks/pdsds_queries/q2.py +244 -0
  49. cudf_polars/experimental/benchmarks/pdsds_queries/q3.py +65 -0
  50. cudf_polars/experimental/benchmarks/pdsds_queries/q4.py +359 -0
  51. cudf_polars/experimental/benchmarks/pdsds_queries/q5.py +462 -0
  52. cudf_polars/experimental/benchmarks/pdsds_queries/q6.py +92 -0
  53. cudf_polars/experimental/benchmarks/pdsds_queries/q7.py +79 -0
  54. cudf_polars/experimental/benchmarks/pdsds_queries/q8.py +524 -0
  55. cudf_polars/experimental/benchmarks/pdsds_queries/q9.py +137 -0
  56. cudf_polars/experimental/benchmarks/pdsh.py +814 -0
  57. cudf_polars/experimental/benchmarks/utils.py +832 -0
  58. cudf_polars/experimental/dask_registers.py +200 -0
  59. cudf_polars/experimental/dispatch.py +156 -0
  60. cudf_polars/experimental/distinct.py +197 -0
  61. cudf_polars/experimental/explain.py +157 -0
  62. cudf_polars/experimental/expressions.py +590 -0
  63. cudf_polars/experimental/groupby.py +327 -0
  64. cudf_polars/experimental/io.py +943 -0
  65. cudf_polars/experimental/join.py +391 -0
  66. cudf_polars/experimental/parallel.py +423 -0
  67. cudf_polars/experimental/repartition.py +69 -0
  68. cudf_polars/experimental/scheduler.py +155 -0
  69. cudf_polars/experimental/select.py +188 -0
  70. cudf_polars/experimental/shuffle.py +354 -0
  71. cudf_polars/experimental/sort.py +609 -0
  72. cudf_polars/experimental/spilling.py +151 -0
  73. cudf_polars/experimental/statistics.py +795 -0
  74. cudf_polars/experimental/utils.py +169 -0
  75. cudf_polars/py.typed +0 -0
  76. cudf_polars/testing/__init__.py +8 -0
  77. cudf_polars/testing/asserts.py +448 -0
  78. cudf_polars/testing/io.py +122 -0
  79. cudf_polars/testing/plugin.py +236 -0
  80. cudf_polars/typing/__init__.py +219 -0
  81. cudf_polars/utils/__init__.py +8 -0
  82. cudf_polars/utils/config.py +741 -0
  83. cudf_polars/utils/conversion.py +40 -0
  84. cudf_polars/utils/dtypes.py +118 -0
  85. cudf_polars/utils/sorting.py +53 -0
  86. cudf_polars/utils/timer.py +39 -0
  87. cudf_polars/utils/versions.py +27 -0
  88. cudf_polars_cu13-25.10.0.dist-info/METADATA +136 -0
  89. cudf_polars_cu13-25.10.0.dist-info/RECORD +92 -0
  90. cudf_polars_cu13-25.10.0.dist-info/WHEEL +5 -0
  91. cudf_polars_cu13-25.10.0.dist-info/licenses/LICENSE +201 -0
  92. cudf_polars_cu13-25.10.0.dist-info/top_level.txt +1 -0
cudf_polars/dsl/traversal.py
@@ -0,0 +1,224 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
+ """Traversal and visitor utilities for nodes."""
+
+ from __future__ import annotations
+
+ from collections import deque
+ from typing import TYPE_CHECKING, Generic
+
+ from cudf_polars.typing import (
+     StateT_co,
+     U_contra,
+     V_co,
+ )
+
+ if TYPE_CHECKING:
+     from collections.abc import Callable, Generator, MutableMapping, Sequence
+
+     from cudf_polars.typing import GenericTransformer, NodeT
+
+
+ __all__: list[str] = [
+     "CachingVisitor",
+     "make_recursive",
+     "post_traversal",
+     "reuse_if_unchanged",
+     "traversal",
+ ]
+
+
+ def traversal(nodes: Sequence[NodeT]) -> Generator[NodeT, None, None]:
+     """
+     Pre-order traversal of nodes in an expression.
+
+     Parameters
+     ----------
+     nodes
+         Roots of expressions to traverse.
+
+     Yields
+     ------
+     Unique nodes in the expressions, parent before child, children
+     in-order from left to right.
+     """
+     seen: set[NodeT] = set()
+     lifo: deque[NodeT] = deque()
+
+     for node in nodes:
+         if node not in seen:
+             lifo.append(node)
+             seen.add(node)
+
+     while lifo:
+         node = lifo.pop()
+         yield node
+         for child in reversed(node.children):
+             if child not in seen:
+                 seen.add(child)
+                 lifo.append(child)
+
+
+ def post_traversal(nodes: Sequence[NodeT]) -> Generator[NodeT, None, None]:
+     """
+     Post-order traversal of nodes in an expression.
+
+     Parameters
+     ----------
+     nodes
+         Roots of expressions to traverse.
+
+     Yields
+     ------
+     Unique nodes in the expressions, child before parent, children
+     in-order from left to right.
+     """
+     seen: set[NodeT] = set()
+     lifo: deque[NodeT] = deque()
+
+     for node in nodes:
+         if node not in seen:
+             lifo.append(node)
+             seen.add(node)
+
+     while lifo:
+         node = lifo[-1]
+         for child in node.children:
+             if child not in seen:
+                 lifo.append(child)
+                 seen.add(child)
+                 break
+         else:
+             yield node
+             lifo.pop()
+
+
+ def reuse_if_unchanged(
+     node: NodeT, fn: GenericTransformer[NodeT, NodeT, StateT_co]
+ ) -> NodeT:
+     """
+     Recipe for transforming nodes that returns the old object if unchanged.
+
+     Parameters
+     ----------
+     node
+         Node to recurse on.
+     fn
+         Function to transform children.
+
+     Notes
+     -----
+     This can be used as a generic "base case" handler when
+     writing transforms that take nodes and produce new nodes.
+
+     Returns
+     -------
+     The existing node `node` if the transformed children are unchanged,
+     otherwise a reconstructed node with the new children.
+     """
+     new_children = [fn(c) for c in node.children]
+     if all(new == old for new, old in zip(new_children, node.children, strict=True)):
+         return node
+     return node.reconstruct(new_children)
+
+
+ def make_recursive(
+     fn: Callable[[U_contra, GenericTransformer[U_contra, V_co, StateT_co]], V_co],
+     *,
+     # make_recursive is a type constructor with covariant state parameter,
+     # not a normal function for which the parameter would be contravariant,
+     # hence the type ignore
+     state: StateT_co,  # type: ignore[misc]
+ ) -> GenericTransformer[U_contra, V_co, StateT_co]:
+     """
+     No-op wrapper for recursive visitors.
+
+     Facilitates using visitors that don't need caching but are written
+     in the same style.
+
+     Parameters
+     ----------
+     fn
+         Function to transform inputs to outputs. Should take as its
+         second argument a callable from input to output.
+     state
+         Arbitrary *immutable* state that should be accessible to the
+         visitor through the `state` property.
+
+     Notes
+     -----
+     All transformation functions *must* be free of side-effects.
+
+     Usually, prefer a :class:`CachingVisitor`, but if we know that a
+     transformation doesn't need caching, then this no-op approach is
+     slightly cheaper.
+
+     Returns
+     -------
+     Recursive function without caching.
+
+     See Also
+     --------
+     CachingVisitor
+     """
+
+     def rec(node: U_contra) -> V_co:
+         return fn(node, rec)  # type: ignore[arg-type]
+
+     rec.state = state  # type: ignore[attr-defined]
+     return rec  # type: ignore[return-value]
+
+
+ class CachingVisitor(Generic[U_contra, V_co, StateT_co]):
+     """
+     Caching wrapper for recursive visitors.
+
+     Facilitates writing visitors where already computed results should
+     be cached and reused. The cache is managed automatically, and is
+     tied to the lifetime of the wrapper.
+
+     Parameters
+     ----------
+     fn
+         Function to transform inputs to outputs. Should take as its
+         second argument the recursive cache manager.
+     state
+         Arbitrary *immutable* state that should be accessible to the
+         visitor through the `state` property.
+
+     Notes
+     -----
+     All transformation functions *must* be free of side-effects.
+
+     Returns
+     -------
+     Recursive function with caching.
+     """
+
+     def __init__(
+         self,
+         fn: Callable[[U_contra, GenericTransformer[U_contra, V_co, StateT_co]], V_co],
+         *,
+         state: StateT_co,
+     ) -> None:
+         self.fn = fn
+         self.cache: MutableMapping[U_contra, V_co] = {}
+         self.state = state
+
+     def __call__(self, value: U_contra) -> V_co:
+         """
+         Apply the function to a value.
+
+         Parameters
+         ----------
+         value
+             The value to transform.
+
+         Returns
+         -------
+         A transformed value.
+         """
+         try:
+             return self.cache[value]
+         except KeyError:
+             return self.cache.setdefault(value, self.fn(value, self))
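
The traversal and visitor utilities above are generic over any hashable node type that exposes a `children` sequence and a `reconstruct` method. The following standalone sketch (not part of the wheel) shows how they compose; `ToyNode` is a hypothetical stand-in for the package's node protocol.

from dataclasses import dataclass

from cudf_polars.dsl.traversal import CachingVisitor, traversal


@dataclass(frozen=True)
class ToyNode:
    """Minimal hashable node: a name plus a tuple of children."""

    name: str
    children: tuple["ToyNode", ...] = ()

    def reconstruct(self, children):
        # Copy of this node with new children.
        return ToyNode(self.name, tuple(children))


root = ToyNode("add", (ToyNode("a"), ToyNode("b")))

# Pre-order traversal: parent before child, children left to right.
print([n.name for n in traversal([root])])  # ['add', 'a', 'b']


def rename(node, rec):
    # Rewrite leaves using the mapping carried in the visitor's state;
    # otherwise rebuild the node from its transformed children.
    if not node.children and node.name in rec.state:
        return ToyNode(rec.state[node.name])
    return node.reconstruct([rec(c) for c in node.children])


mapper = CachingVisitor(rename, state={"a": "x"})
print([n.name for n in traversal([mapper(root)])])  # ['add', 'x', 'b']

Because `CachingVisitor` memoises results per node, shared subexpressions are rewritten once and reused, keeping rewrites linear in the number of unique nodes.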
cudf_polars/dsl/utils/__init__.py
@@ -0,0 +1,8 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
+ """DSL utilities."""
+
+ from __future__ import annotations
+
+ __all__: list[str] = []
cudf_polars/dsl/utils/aggregations.py
@@ -0,0 +1,481 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
+ """Utilities for rewriting aggregations."""
+
+ from __future__ import annotations
+
+ import itertools
+ from functools import partial
+ from typing import TYPE_CHECKING, Any
+
+ import polars as pl
+
+ import pylibcudf as plc
+
+ from cudf_polars.containers import DataType
+ from cudf_polars.dsl import expr, ir
+ from cudf_polars.dsl.expressions.base import ExecutionContext
+ from cudf_polars.utils.versions import POLARS_VERSION_LT_1323
+
+ if TYPE_CHECKING:
+     from collections.abc import Callable, Generator, Iterable, Sequence
+
+     from cudf_polars.typing import Schema
+
+ __all__ = ["apply_pre_evaluation", "decompose_aggs", "decompose_single_agg"]
+
+
+ def replace_nulls(col: expr.Expr, value: Any, *, is_top: bool) -> expr.Expr:
+     """
+     Replace nulls with the given scalar if at top level.
+
+     Parameters
+     ----------
+     col
+         Expression to replace nulls in.
+     value
+         Scalar replacement value.
+     is_top
+         Whether this is the top level (replacement is only performed
+         at the top level).
+
+     Returns
+     -------
+     Massaged expression.
+     """
+     if not is_top:
+         return col
+     return expr.UnaryFunction(
+         col.dtype, "fill_null", (), col, expr.Literal(col.dtype, value)
+     )
+
+
+ def decompose_single_agg(
+     named_expr: expr.NamedExpr,
+     name_generator: Generator[str, None, None],
+     *,
+     is_top: bool,
+     context: ExecutionContext,
+ ) -> tuple[list[tuple[expr.NamedExpr, bool]], expr.NamedExpr]:
+     """
+     Decompose a single named aggregation.
+
+     Parameters
+     ----------
+     named_expr
+         The named aggregation to decompose.
+     name_generator
+         Generator of unique names for temporaries introduced during decomposition.
+     is_top
+         Is this the top of an aggregation expression?
+     context
+         ExecutionContext in which the aggregation will run.
+
+     Returns
+     -------
+     aggregations
+         Pairs of expressions to apply as grouped aggregations (whose children
+         may be evaluated pointwise) and flags indicating if the
+         expression contained nested aggregations.
+     post_aggregate
+         Single expression to apply to post-process the grouped
+         aggregations.
+
+     Raises
+     ------
+     NotImplementedError
+         If the expression contains nested aggregations or unsupported
+         operations in a grouped aggregation context.
+     """
+     agg = named_expr.value
+     name = named_expr.name
+     if isinstance(agg, expr.UnaryFunction) and agg.name in {
+         "rank",
+     }:
+         if context != ExecutionContext.WINDOW:
+             raise NotImplementedError(
+                 f"{agg.name} is not supported in groupby or rolling context"
+             )
+         # Ensure Polars semantics for dtype:
+         # - average -> Float64
+         # - min/max/dense/ordinal -> IDX_DTYPE (UInt32/UInt64)
+         post_col: expr.Expr = expr.Col(agg.dtype, name)
+         if agg.name == "rank":
+             post_col = expr.Cast(agg.dtype, post_col)
+
+         return [(named_expr, True)], named_expr.reconstruct(post_col)
+     if isinstance(agg, expr.UnaryFunction) and agg.name == "null_count":
+         (child,) = agg.children
+
+         is_null_bool = expr.BooleanFunction(
+             DataType(pl.Boolean()),
+             expr.BooleanFunction.Name.IsNull,
+             (),
+             child,
+         )
+         u32 = DataType(pl.UInt32())
+         sum_name = next(name_generator)
+         sum_agg = expr.NamedExpr(
+             sum_name,
+             expr.Agg(u32, "sum", (), expr.Cast(u32, is_null_bool)),
+         )
+         return [(sum_agg, True)], named_expr.reconstruct(
+             expr.Cast(u32, expr.Col(u32, sum_name))
+         )
+     if isinstance(agg, expr.Col):
+         # TODO: collect_list produces null for empty group in libcudf, empty list in polars.
+         # But we need the nested value type, so need to track proper dtypes in our DSL.
+         return [(named_expr, False)], named_expr.reconstruct(expr.Col(agg.dtype, name))
+     if is_top and isinstance(agg, expr.Cast) and isinstance(agg.children[0], expr.Len):
+         # Special case to fill nulls with zeros for empty group length calculations
+         (child,) = agg.children
+         child_agg, post = decompose_single_agg(
+             expr.NamedExpr(next(name_generator), child),
+             name_generator,
+             is_top=True,
+             context=context,
+         )
+         return child_agg, named_expr.reconstruct(
+             replace_nulls(
+                 agg.reconstruct([post.value]),
+                 0,
+                 is_top=True,
+             )
+         )
+     if isinstance(agg, expr.Len):
+         return [(named_expr, True)], named_expr.reconstruct(expr.Col(agg.dtype, name))
+     if isinstance(agg, (expr.Literal, expr.LiteralColumn)):
+         return [], named_expr
+     if (
+         is_top
+         and isinstance(agg, expr.UnaryFunction)
+         and agg.name == "fill_null_with_strategy"
+     ):
+         strategy, _ = agg.options
+         raise NotImplementedError(
+             f"fill_null_with_strategy({strategy!r}) is not supported in groupby aggregations"
+         )
+     if isinstance(agg, expr.Agg):
+         if agg.name == "quantile":
+             # Second child is the requested quantile (which is asserted
+             # to be a literal on construction).
+             child = agg.children[0]
+         else:
+             (child,) = agg.children
+         needs_masking = agg.name in {"min", "max"} and plc.traits.is_floating_point(
+             child.dtype.plc
+         )
+         if needs_masking and agg.options:
+             # pl.col("a").nan_max or nan_min
+             raise NotImplementedError("Nan propagation in groupby for min/max")
+         aggs, _ = decompose_single_agg(
+             expr.NamedExpr(next(name_generator), child),
+             name_generator,
+             is_top=False,
+             context=context,
+         )
+         if any(has_agg for _, has_agg in aggs):
+             raise NotImplementedError("Nested aggs in groupby not supported")
+
+         child_dtype = child.dtype.plc
+         req = agg.agg_request
+         is_median = agg.name == "median"
+         is_quantile = agg.name == "quantile"
+
+         # quantile agg on decimal: unsupported -> keep dtype Decimal
+         # mean/median on decimal: Polars returns float -> pre-cast
+         decimal_unsupported = False
+         if plc.traits.is_fixed_point(child_dtype):
+             if is_quantile:
+                 decimal_unsupported = True
+             elif agg.name in {"mean", "median"}:
+                 tid = agg.dtype.plc.id()
+                 if tid in {plc.TypeId.FLOAT32, plc.TypeId.FLOAT64}:
+                     cast_to = (
+                         DataType(pl.Float64)
+                         if tid == plc.TypeId.FLOAT64
+                         else DataType(pl.Float32)
+                     )
+                     child = expr.Cast(cast_to, child)
+                     child_dtype = child.dtype.plc
+
+         is_group_quantile_supported = plc.traits.is_integral(
+             child_dtype
+         ) or plc.traits.is_floating_point(child_dtype)
+
+         unsupported = (
+             decimal_unsupported
+             or ((is_median or is_quantile) and not is_group_quantile_supported)
+         ) or (not plc.aggregation.is_valid_aggregation(child_dtype, req))
+         if unsupported:
+             return [], named_expr.reconstruct(expr.Literal(child.dtype, None))
+         if needs_masking:
+             child = expr.UnaryFunction(child.dtype, "mask_nans", (), child)
+
+         # Rebuild the agg with the new (potentially masked or cast) child.
+         # This is safe because we recursed to ensure there are no nested
+         # aggregations.
+         new_children = [child] if not is_quantile else [child, agg.children[1]]
+         named_expr = named_expr.reconstruct(agg.reconstruct(new_children))
+
+         if agg.name == "sum":
+             col = (
+                 expr.Cast(agg.dtype, expr.Col(DataType(pl.datatypes.Int64()), name))
+                 if (
+                     plc.traits.is_integral(agg.dtype.plc)
+                     and agg.dtype.id() != plc.TypeId.INT64
+                 )
+                 else expr.Col(agg.dtype, name)
+             )
+             # Polars semantics for sum differ by context:
+             # - GROUPBY: sum(all-null group) => 0; sum(empty group) => 0 (fill-null)
+             # - ROLLING: sum(all-null window) => null; sum(empty window) => 0 (fill only if empty)
+             #
+             # We must post-process because libcudf returns null for both empty
+             # and all-null windows/groups.
+             if not POLARS_VERSION_LT_1323 or context in {
+                 ExecutionContext.GROUPBY,
+                 ExecutionContext.WINDOW,
+             }:
+                 # GROUPBY: always fill top-level nulls with 0
+                 return [(named_expr, True)], expr.NamedExpr(
+                     name, replace_nulls(col, 0, is_top=is_top)
+                 )
+             else:  # pragma: no cover
+                 # ROLLING:
+                 # Add a second rolling agg to compute the window size, then only
+                 # replace nulls with 0 when the window size is 0 (i.e. empty window).
+                 win_len_name = next(name_generator)
+                 win_len = expr.NamedExpr(
+                     win_len_name,
+                     expr.Len(DataType(pl.Int32())),
+                 )
+
+                 win_len_col = expr.Col(DataType(pl.Int32()), win_len_name)
+                 win_len_filled = replace_nulls(win_len_col, 0, is_top=True)
+
+                 is_empty = expr.BinOp(
+                     DataType(pl.Boolean()),
+                     plc.binaryop.BinaryOperator.EQUAL,
+                     win_len_filled,
+                     expr.Literal(DataType(pl.Int32()), 0),
+                 )
+
+                 # If empty -> fill 0; else keep libcudf's semantics for all-null windows.
+                 filled = replace_nulls(col, 0, is_top=is_top)
+                 post_ternary_expr = expr.Ternary(agg.dtype, is_empty, filled, col)
+
+                 return [(named_expr, True), (win_len, True)], expr.NamedExpr(
+                     name, post_ternary_expr
+                 )
+         elif agg.name in {"mean", "median", "quantile", "std", "var"}:
+             post_agg_col: expr.Expr = expr.Col(
+                 DataType(pl.Float64()), name
+             )  # libcudf promotes to float64
+             if agg.dtype.plc.id() == plc.TypeId.FLOAT32:
+                 # Cast back to float32 to match Polars
+                 post_agg_col = expr.Cast(agg.dtype, post_agg_col)
+             return [(named_expr, True)], named_expr.reconstruct(post_agg_col)
+         else:
+             return [(named_expr, True)], named_expr.reconstruct(
+                 expr.Col(agg.dtype, name)
+             )
+     if isinstance(agg, expr.Ternary):
+         when, then, otherwise = agg.children
+
+         when_aggs, when_post = decompose_single_agg(
+             expr.NamedExpr(next(name_generator), when),
+             name_generator,
+             is_top=False,
+             context=context,
+         )
+         then_aggs, then_post = decompose_single_agg(
+             expr.NamedExpr(next(name_generator), then),
+             name_generator,
+             is_top=False,
+             context=context,
+         )
+         otherwise_aggs, otherwise_post = decompose_single_agg(
+             expr.NamedExpr(next(name_generator), otherwise),
+             name_generator,
+             is_top=False,
+             context=context,
+         )
+
+         when_has = any(h for _, h in when_aggs)
+         then_has = any(h for _, h in then_aggs)
+         otherwise_has = any(h for _, h in otherwise_aggs)
+
+         if is_top and not (when_has or then_has or otherwise_has):
+             raise NotImplementedError(
+                 "Broadcasted ternary with list output in groupby is not supported"
+             )
+
+         for post, has in (
+             (when_post, when_has),
+             (then_post, then_has),
+             (otherwise_post, otherwise_has),
+         ):
+             if is_top and not has and not isinstance(post.value, expr.Literal):
+                 raise NotImplementedError(
+                     "Broadcasting aggregated expressions in groupby/rolling"
+                 )
+
+         return [*when_aggs, *then_aggs, *otherwise_aggs], named_expr.reconstruct(
+             agg.reconstruct([when_post.value, then_post.value, otherwise_post.value])
+         )
+     if not agg.is_pointwise and isinstance(agg, expr.BooleanFunction):
+         raise NotImplementedError(
+             f"Non pointwise boolean function {agg.name!r} not supported in groupby or rolling context"
+         )
+     if agg.is_pointwise:
+         aggs, posts = _decompose_aggs(
+             (expr.NamedExpr(next(name_generator), child) for child in agg.children),
+             name_generator,
+             is_top=False,
+             context=context,
+         )
+         if any(has_agg for _, has_agg in aggs):
+             if not all(
+                 has_agg or isinstance(agg.value, expr.Literal) for agg, has_agg in aggs
+             ):
+                 raise NotImplementedError(
+                     "Broadcasting aggregated expressions in groupby/rolling"
+                 )
+             # Any pointwise expression can be handled either by
+             # post-evaluation (if it sits outside an aggregation) ...
+             return (
+                 aggs,
+                 named_expr.reconstruct(agg.reconstruct([p.value for p in posts])),
+             )
+         else:
+             # ... or by pre-evaluation (if it sits inside an aggregation).
+             return (
+                 [(named_expr, False)],
+                 named_expr.reconstruct(expr.Col(agg.dtype, name)),
+             )
+     raise NotImplementedError(f"No support for {type(agg)} in groupby/rolling")
+
+
+ def _decompose_aggs(
+     aggs: Iterable[expr.NamedExpr],
+     name_generator: Generator[str, None, None],
+     *,
+     is_top: bool,
+     context: ExecutionContext,
+ ) -> tuple[list[tuple[expr.NamedExpr, bool]], Sequence[expr.NamedExpr]]:
+     new_aggs, post = zip(
+         *(
+             decompose_single_agg(agg, name_generator, is_top=is_top, context=context)
+             for agg in aggs
+         ),
+         strict=True,
+     )
+     return list(itertools.chain.from_iterable(new_aggs)), post
+
+
+ def decompose_aggs(
+     aggs: Iterable[expr.NamedExpr],
+     name_generator: Generator[str, None, None],
+     *,
+     context: ExecutionContext,
+ ) -> tuple[list[expr.NamedExpr], Sequence[expr.NamedExpr]]:
+     """
+     Process arbitrary aggregations into a form we can handle in grouped aggregations.
+
+     Parameters
+     ----------
+     aggs
+         List of aggregation expressions.
+     name_generator
+         Generator of unique names for temporaries introduced during decomposition.
+     context
+         ExecutionContext in which the aggregation will run.
+
+     Returns
+     -------
+     aggregations
+         Aggregations to apply in the groupby node.
+     post_aggregations
+         Expressions to apply after aggregating (as a ``Select``).
+
+     Notes
+     -----
+     The aggregation expressions are guaranteed to either be
+     expressions that can be pointwise evaluated before the groupby
+     operation, or aggregations of such expressions.
+
+     Raises
+     ------
+     NotImplementedError
+         For unsupported aggregation combinations.
+     """
+     new_aggs, post = _decompose_aggs(aggs, name_generator, is_top=True, context=context)
+     return [agg for agg, _ in new_aggs], post
+
+
+ def apply_pre_evaluation(
+     output_schema: Schema,
+     keys: Sequence[expr.NamedExpr],
+     original_aggs: Sequence[expr.NamedExpr],
+     name_generator: Generator[str, None, None],
+     context: ExecutionContext,
+     *extra_columns: expr.NamedExpr,
+ ) -> tuple[Sequence[expr.NamedExpr], Schema, Callable[[ir.IR], ir.IR]]:
+     """
+     Apply pre-evaluation to aggregations in a grouped or rolling context.
+
+     Parameters
+     ----------
+     output_schema
+         Schema of the plan node we're rewriting.
+     keys
+         Grouping keys (may be empty).
+     original_aggs
+         Aggregation expressions to rewrite.
+     name_generator
+         Generator of unique names for temporaries introduced during decomposition.
+     context
+         ExecutionContext in which the aggregation will run.
+     extra_columns
+         Any additional columns to be included in the output (only
+         relevant for rolling aggregations). Columns will appear in the
+         order `keys, extra_columns, original_aggs`.
+
+     Returns
+     -------
+     aggregations
+         The required aggregations.
+     schema
+         The new schema of the aggregation node.
+     post_process
+         Function to apply to the aggregation node to perform any
+         post-processing.
+
+     Raises
+     ------
+     NotImplementedError
+         If the aggregations are somehow unsupported.
+     """
+     aggs, post = decompose_aggs(original_aggs, name_generator, context=context)
+     assert len(post) == len(original_aggs), (
+         f"Unexpected number of post-aggs {len(post)=} {len(original_aggs)=}"
+     )
+     # Order-preserving unique
+     aggs = list(dict.fromkeys(aggs).keys())
+     if any(not isinstance(e.value, expr.Col) for e in post):
+         selection = [
+             *(key.reconstruct(expr.Col(key.value.dtype, key.name)) for key in keys),
+             *extra_columns,
+             *post,
+         ]
+         inter_schema = {
+             e.name: e.value.dtype for e in itertools.chain(keys, extra_columns, aggs)
+         }
+         return (
+             aggs,
+             inter_schema,
+             partial(ir.Select, output_schema, selection, True),  # noqa: FBT003
+         )
+     else:
+         return aggs, output_schema, lambda inp: inp
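
As a reference point for the sum handling above, this small Polars (CPU engine) snippet demonstrates the semantics the decomposition reproduces on the GPU: a grouped sum over an all-null (or empty) group is 0, not null, whereas libcudf returns null in both cases, hence the fill_null post-processing step. A minimal sketch; the column and key names are illustrative only.

import polars as pl

df = pl.DataFrame({"key": [1, 1, 2, 2], "x": [None, None, 3, 4]})
out = df.group_by("key", maintain_order=True).agg(pl.col("x").sum())
print(out)
# key 1 (an all-null group) sums to 0, not null; key 2 sums to 7.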