cudf-polars-cu13 25.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cudf_polars/GIT_COMMIT +1 -0
- cudf_polars/VERSION +1 -0
- cudf_polars/__init__.py +28 -0
- cudf_polars/_version.py +21 -0
- cudf_polars/callback.py +318 -0
- cudf_polars/containers/__init__.py +13 -0
- cudf_polars/containers/column.py +495 -0
- cudf_polars/containers/dataframe.py +361 -0
- cudf_polars/containers/datatype.py +137 -0
- cudf_polars/dsl/__init__.py +8 -0
- cudf_polars/dsl/expr.py +66 -0
- cudf_polars/dsl/expressions/__init__.py +8 -0
- cudf_polars/dsl/expressions/aggregation.py +226 -0
- cudf_polars/dsl/expressions/base.py +272 -0
- cudf_polars/dsl/expressions/binaryop.py +120 -0
- cudf_polars/dsl/expressions/boolean.py +326 -0
- cudf_polars/dsl/expressions/datetime.py +271 -0
- cudf_polars/dsl/expressions/literal.py +97 -0
- cudf_polars/dsl/expressions/rolling.py +643 -0
- cudf_polars/dsl/expressions/selection.py +74 -0
- cudf_polars/dsl/expressions/slicing.py +46 -0
- cudf_polars/dsl/expressions/sorting.py +85 -0
- cudf_polars/dsl/expressions/string.py +1002 -0
- cudf_polars/dsl/expressions/struct.py +137 -0
- cudf_polars/dsl/expressions/ternary.py +49 -0
- cudf_polars/dsl/expressions/unary.py +517 -0
- cudf_polars/dsl/ir.py +2607 -0
- cudf_polars/dsl/nodebase.py +164 -0
- cudf_polars/dsl/to_ast.py +359 -0
- cudf_polars/dsl/tracing.py +16 -0
- cudf_polars/dsl/translate.py +939 -0
- cudf_polars/dsl/traversal.py +224 -0
- cudf_polars/dsl/utils/__init__.py +8 -0
- cudf_polars/dsl/utils/aggregations.py +481 -0
- cudf_polars/dsl/utils/groupby.py +98 -0
- cudf_polars/dsl/utils/naming.py +34 -0
- cudf_polars/dsl/utils/replace.py +61 -0
- cudf_polars/dsl/utils/reshape.py +74 -0
- cudf_polars/dsl/utils/rolling.py +121 -0
- cudf_polars/dsl/utils/windows.py +192 -0
- cudf_polars/experimental/__init__.py +8 -0
- cudf_polars/experimental/base.py +386 -0
- cudf_polars/experimental/benchmarks/__init__.py +4 -0
- cudf_polars/experimental/benchmarks/pdsds.py +220 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/__init__.py +4 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q1.py +88 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q10.py +225 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q2.py +244 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q3.py +65 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q4.py +359 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q5.py +462 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q6.py +92 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q7.py +79 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q8.py +524 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q9.py +137 -0
- cudf_polars/experimental/benchmarks/pdsh.py +814 -0
- cudf_polars/experimental/benchmarks/utils.py +832 -0
- cudf_polars/experimental/dask_registers.py +200 -0
- cudf_polars/experimental/dispatch.py +156 -0
- cudf_polars/experimental/distinct.py +197 -0
- cudf_polars/experimental/explain.py +157 -0
- cudf_polars/experimental/expressions.py +590 -0
- cudf_polars/experimental/groupby.py +327 -0
- cudf_polars/experimental/io.py +943 -0
- cudf_polars/experimental/join.py +391 -0
- cudf_polars/experimental/parallel.py +423 -0
- cudf_polars/experimental/repartition.py +69 -0
- cudf_polars/experimental/scheduler.py +155 -0
- cudf_polars/experimental/select.py +188 -0
- cudf_polars/experimental/shuffle.py +354 -0
- cudf_polars/experimental/sort.py +609 -0
- cudf_polars/experimental/spilling.py +151 -0
- cudf_polars/experimental/statistics.py +795 -0
- cudf_polars/experimental/utils.py +169 -0
- cudf_polars/py.typed +0 -0
- cudf_polars/testing/__init__.py +8 -0
- cudf_polars/testing/asserts.py +448 -0
- cudf_polars/testing/io.py +122 -0
- cudf_polars/testing/plugin.py +236 -0
- cudf_polars/typing/__init__.py +219 -0
- cudf_polars/utils/__init__.py +8 -0
- cudf_polars/utils/config.py +741 -0
- cudf_polars/utils/conversion.py +40 -0
- cudf_polars/utils/dtypes.py +118 -0
- cudf_polars/utils/sorting.py +53 -0
- cudf_polars/utils/timer.py +39 -0
- cudf_polars/utils/versions.py +27 -0
- cudf_polars_cu13-25.10.0.dist-info/METADATA +136 -0
- cudf_polars_cu13-25.10.0.dist-info/RECORD +92 -0
- cudf_polars_cu13-25.10.0.dist-info/WHEEL +5 -0
- cudf_polars_cu13-25.10.0.dist-info/licenses/LICENSE +201 -0
- cudf_polars_cu13-25.10.0.dist-info/top_level.txt +1 -0
+++ cudf_polars/dsl/traversal.py
@@ -0,0 +1,224 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

"""Traversal and visitor utilities for nodes."""

from __future__ import annotations

from collections import deque
from typing import TYPE_CHECKING, Generic

from cudf_polars.typing import (
    StateT_co,
    U_contra,
    V_co,
)

if TYPE_CHECKING:
    from collections.abc import Callable, Generator, MutableMapping, Sequence

    from cudf_polars.typing import GenericTransformer, NodeT


__all__: list[str] = [
    "CachingVisitor",
    "make_recursive",
    "post_traversal",
    "reuse_if_unchanged",
    "traversal",
]

def traversal(nodes: Sequence[NodeT]) -> Generator[NodeT, None, None]:
    """
    Pre-order traversal of nodes in an expression.

    Parameters
    ----------
    nodes
        Roots of expressions to traverse.

    Yields
    ------
    Unique nodes in the expressions, parent before child, children
    in-order from left to right.
    """
    seen: set[NodeT] = set()
    lifo: deque[NodeT] = deque()

    for node in nodes:
        if node not in seen:
            lifo.append(node)
            seen.add(node)

    while lifo:
        node = lifo.pop()
        yield node
        for child in reversed(node.children):
            if child not in seen:
                seen.add(child)
                lifo.append(child)

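As a usage sketch (a toy, assuming only what `traversal` actually touches: a
hashable node exposing a `children` tuple; this `Node` is not the package's
real node type):

    from dataclasses import dataclass

    from cudf_polars.dsl.traversal import traversal

    @dataclass(frozen=True)
    class Node:
        name: str
        children: tuple["Node", ...] = ()

    a, b = Node("a"), Node("b")
    c = Node("c", (a, b))
    d = Node("d", (c,))
    assert [n.name for n in traversal([d])] == ["d", "c", "a", "b"]
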
def post_traversal(nodes: Sequence[NodeT]) -> Generator[NodeT, None, None]:
    """
    Post-order traversal of nodes in an expression.

    Parameters
    ----------
    nodes
        Roots of expressions to traverse.

    Yields
    ------
    Unique nodes in the expressions, child before parent, children
    in-order from left to right.
    """
    seen: set[NodeT] = set()
    lifo: deque[NodeT] = deque()

    for node in nodes:
        if node not in seen:
            lifo.append(node)
            seen.add(node)

    while lifo:
        node = lifo[-1]
        for child in node.children:
            if child not in seen:
                lifo.append(child)
                seen.add(child)
                break
        else:
            yield node
            lifo.pop()

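The same toy nodes through `post_traversal` emit children before parents,
the order a bottom-up rewrite wants:

    from cudf_polars.dsl.traversal import post_traversal

    assert [n.name for n in post_traversal([d])] == ["a", "b", "c", "d"]
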
def reuse_if_unchanged(
    node: NodeT, fn: GenericTransformer[NodeT, NodeT, StateT_co]
) -> NodeT:
    """
    Recipe for transforming nodes that returns the old object if unchanged.

    Parameters
    ----------
    node
        Node to recurse on.
    fn
        Function to transform children.

    Notes
    -----
    This can be used as a generic "base case" handler when
    writing transforms that take nodes and produce new nodes.

    Returns
    -------
    The existing `node` if the transformed children are unchanged,
    otherwise a reconstructed node with the new children.
    """
    new_children = [fn(c) for c in node.children]
    if all(new == old for new, old in zip(new_children, node.children, strict=True)):
        return node
    return node.reconstruct(new_children)

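A sketch of the identity-preserving contract; `reconstruct` mirrors the
interface the real nodes provide, and a plain function suffices for `fn`
here since `reuse_if_unchanged` only ever calls it:

    from dataclasses import dataclass

    from cudf_polars.dsl.traversal import reuse_if_unchanged

    @dataclass(frozen=True)
    class Node:
        name: str
        children: tuple["Node", ...] = ()

        def reconstruct(self, children):
            return Node(self.name, tuple(children))

    def visit(node):
        return reuse_if_unchanged(node, visit)

    tree = Node("c", (Node("a"), Node("b")))
    assert visit(tree) is tree  # children unchanged, original object returned
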
def make_recursive(
    fn: Callable[[U_contra, GenericTransformer[U_contra, V_co, StateT_co]], V_co],
    *,
    # make_recursive is a type constructor with a covariant state parameter,
    # not a normal function for which the parameter would be contravariant,
    # hence the type ignore.
    state: StateT_co,  # type: ignore[misc]
) -> GenericTransformer[U_contra, V_co, StateT_co]:
    """
    No-op wrapper for recursive visitors.

    Facilitates using visitors that don't need caching but are written
    in the same style.

    Parameters
    ----------
    fn
        Function to transform inputs to outputs. Should take as its
        second argument a callable from input to output.
    state
        Arbitrary *immutable* state that should be accessible to the
        visitor through the `state` property.

    Notes
    -----
    All transformation functions *must* be free of side-effects.

    Usually, prefer a :class:`CachingVisitor`, but if we know that we
    don't need caching in a transformation, this no-op approach is
    slightly cheaper.

    Returns
    -------
    Recursive function without caching.

    See Also
    --------
    CachingVisitor
    """

    def rec(node: U_contra) -> V_co:
        return fn(node, rec)  # type: ignore[arg-type]

    rec.state = state  # type: ignore[attr-defined]
    return rec  # type: ignore[return-value]

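A sketch of the calling convention: the handler receives the node and the
recursive visitor, and reads shared (treated-as-immutable) configuration
from `rec.state`; the toy `Node` with `reconstruct` from above is reused.

    from cudf_polars.dsl.traversal import make_recursive

    def rename(node, rec):
        mapping = rec.state
        if not node.children:
            return Node(mapping.get(node.name, node.name))
        return node.reconstruct([rec(c) for c in node.children])

    visit = make_recursive(rename, state={"a": "a2"})
    assert visit(Node("c", (Node("a"), Node("b")))) == Node(
        "c", (Node("a2"), Node("b"))
    )
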
class CachingVisitor(Generic[U_contra, V_co, StateT_co]):
    """
    Caching wrapper for recursive visitors.

    Facilitates writing visitors where already computed results should
    be cached and reused. The cache is managed automatically, and is
    tied to the lifetime of the wrapper.

    Parameters
    ----------
    fn
        Function to transform inputs to outputs. Should take as its
        second argument the recursive cache manager.
    state
        Arbitrary *immutable* state that should be accessible to the
        visitor through the `state` property.

    Notes
    -----
    All transformation functions *must* be free of side-effects.

    Returns
    -------
    Recursive function with caching.
    """

    def __init__(
        self,
        fn: Callable[[U_contra, GenericTransformer[U_contra, V_co, StateT_co]], V_co],
        *,
        state: StateT_co,
    ) -> None:
        self.fn = fn
        self.cache: MutableMapping[U_contra, V_co] = {}
        self.state = state

    def __call__(self, value: U_contra) -> V_co:
        """
        Apply the function to a value.

        Parameters
        ----------
        value
            The value to transform.

        Returns
        -------
        A transformed value.
        """
        try:
            return self.cache[value]
        except KeyError:
            return self.cache.setdefault(value, self.fn(value, self))

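A sketch of the memoization with the toy `Node` from above: each distinct
node is transformed once, and repeated references hit the cache.

    from cudf_polars.dsl.traversal import CachingVisitor

    def count_nodes(node, rec):
        return 1 + sum(rec(c) for c in node.children)

    visit = CachingVisitor(count_nodes, state=None)
    shared = Node("s", (Node("x"),))
    root = Node("r", (shared, shared))
    assert visit(root) == 5       # 1 + 2 + 2: both references contribute
    assert len(visit.cache) == 3  # r, s, x each computed exactly once
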
+++ cudf_polars/dsl/utils/aggregations.py
@@ -0,0 +1,481 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

"""Utilities for rewriting aggregations."""

from __future__ import annotations

import itertools
from functools import partial
from typing import TYPE_CHECKING, Any

import polars as pl

import pylibcudf as plc

from cudf_polars.containers import DataType
from cudf_polars.dsl import expr, ir
from cudf_polars.dsl.expressions.base import ExecutionContext
from cudf_polars.utils.versions import POLARS_VERSION_LT_1323

if TYPE_CHECKING:
    from collections.abc import Callable, Generator, Iterable, Sequence

    from cudf_polars.typing import Schema

__all__ = ["apply_pre_evaluation", "decompose_aggs", "decompose_single_agg"]

def replace_nulls(col: expr.Expr, value: Any, *, is_top: bool) -> expr.Expr:
    """
    Replace nulls with the given scalar if at top level.

    Parameters
    ----------
    col
        Expression to replace nulls in.
    value
        Scalar replacement.
    is_top
        Is this top-level (should replacement be performed)?

    Returns
    -------
    Massaged expression.
    """
    if not is_top:
        return col
    return expr.UnaryFunction(
        col.dtype, "fill_null", (), col, expr.Literal(col.dtype, value)
    )

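A small sketch of the guard: below the top level the expression comes back
untouched; at the top it is wrapped in the `fill_null` UnaryFunction (the
constructors are used exactly as elsewhere in this file).

    import polars as pl

    from cudf_polars.containers import DataType
    from cudf_polars.dsl import expr
    from cudf_polars.dsl.utils.aggregations import replace_nulls

    col = expr.Col(DataType(pl.Int64()), "x")
    assert replace_nulls(col, 0, is_top=False) is col
    filled = replace_nulls(col, 0, is_top=True)  # x.fill_null(0)
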
def decompose_single_agg(
    named_expr: expr.NamedExpr,
    name_generator: Generator[str, None, None],
    *,
    is_top: bool,
    context: ExecutionContext,
) -> tuple[list[tuple[expr.NamedExpr, bool]], expr.NamedExpr]:
    """
    Decompose a single named aggregation.

    Parameters
    ----------
    named_expr
        The named aggregation to decompose.
    name_generator
        Generator of unique names for temporaries introduced during decomposition.
    is_top
        Is this the top of an aggregation expression?
    context
        ExecutionContext in which the aggregation will run.

    Returns
    -------
    aggregations
        Pairs of expressions to apply as grouped aggregations (whose children
        may be evaluated pointwise) and flags indicating if the
        expression contained nested aggregations.
    post_aggregate
        Single expression to apply to post-process the grouped
        aggregations.

    Raises
    ------
    NotImplementedError
        If the expression contains nested aggregations or unsupported
        operations in a grouped aggregation context.
    """
    agg = named_expr.value
    name = named_expr.name
    if isinstance(agg, expr.UnaryFunction) and agg.name in {
        "rank",
    }:
        if context != ExecutionContext.WINDOW:
            raise NotImplementedError(
                f"{agg.name} is not supported in groupby or rolling context"
            )
        # Ensure Polars semantics for dtype:
        # - average -> Float64
        # - min/max/dense/ordinal -> IDX_DTYPE (UInt32/UInt64)
        post_col: expr.Expr = expr.Col(agg.dtype, name)
        if agg.name == "rank":
            post_col = expr.Cast(agg.dtype, post_col)

        return [(named_expr, True)], named_expr.reconstruct(post_col)
    if isinstance(agg, expr.UnaryFunction) and agg.name == "null_count":
        (child,) = agg.children

        is_null_bool = expr.BooleanFunction(
            DataType(pl.Boolean()),
            expr.BooleanFunction.Name.IsNull,
            (),
            child,
        )
        u32 = DataType(pl.UInt32())
        sum_name = next(name_generator)
        sum_agg = expr.NamedExpr(
            sum_name,
            expr.Agg(u32, "sum", (), expr.Cast(u32, is_null_bool)),
        )
        return [(sum_agg, True)], named_expr.reconstruct(
            expr.Cast(u32, expr.Col(u32, sum_name))
        )
    if isinstance(agg, expr.Col):
        # TODO: collect_list produces null for an empty group in libcudf, but an
        # empty list in polars. We need the nested value type to fix this, so we
        # need to track proper dtypes in our DSL.
        return [(named_expr, False)], named_expr.reconstruct(expr.Col(agg.dtype, name))
    if is_top and isinstance(agg, expr.Cast) and isinstance(agg.children[0], expr.Len):
        # Special case to fill nulls with zeros for empty group length calculations
        (child,) = agg.children
        child_agg, post = decompose_single_agg(
            expr.NamedExpr(next(name_generator), child),
            name_generator,
            is_top=True,
            context=context,
        )
        return child_agg, named_expr.reconstruct(
            replace_nulls(
                agg.reconstruct([post.value]),
                0,
                is_top=True,
            )
        )
    if isinstance(agg, expr.Len):
        return [(named_expr, True)], named_expr.reconstruct(expr.Col(agg.dtype, name))
    if isinstance(agg, (expr.Literal, expr.LiteralColumn)):
        return [], named_expr
    if (
        is_top
        and isinstance(agg, expr.UnaryFunction)
        and agg.name == "fill_null_with_strategy"
    ):
        strategy, _ = agg.options
        raise NotImplementedError(
            f"fill_null_with_strategy({strategy!r}) is not supported in groupby aggregations"
        )
    if isinstance(agg, expr.Agg):
        if agg.name == "quantile":
            # The second child is the requested quantile (which is asserted
            # to be a literal on construction).
            child = agg.children[0]
        else:
            (child,) = agg.children
        needs_masking = agg.name in {"min", "max"} and plc.traits.is_floating_point(
            child.dtype.plc
        )
        if needs_masking and agg.options:
            # pl.col("a").nan_max or nan_min
            raise NotImplementedError("Nan propagation in groupby for min/max")
        aggs, _ = decompose_single_agg(
            expr.NamedExpr(next(name_generator), child),
            name_generator,
            is_top=False,
            context=context,
        )
        if any(has_agg for _, has_agg in aggs):
            raise NotImplementedError("Nested aggs in groupby not supported")

        child_dtype = child.dtype.plc
        req = agg.agg_request
        is_median = agg.name == "median"
        is_quantile = agg.name == "quantile"

        # quantile agg on decimal: unsupported -> keep dtype Decimal
        # mean/median on decimal: Polars returns float -> pre-cast
        decimal_unsupported = False
        if plc.traits.is_fixed_point(child_dtype):
            if is_quantile:
                decimal_unsupported = True
            elif agg.name in {"mean", "median"}:
                tid = agg.dtype.plc.id()
                if tid in {plc.TypeId.FLOAT32, plc.TypeId.FLOAT64}:
                    cast_to = (
                        DataType(pl.Float64)
                        if tid == plc.TypeId.FLOAT64
                        else DataType(pl.Float32)
                    )
                    child = expr.Cast(cast_to, child)
                    child_dtype = child.dtype.plc

        is_group_quantile_supported = plc.traits.is_integral(
            child_dtype
        ) or plc.traits.is_floating_point(child_dtype)

        unsupported = (
            decimal_unsupported
            or ((is_median or is_quantile) and not is_group_quantile_supported)
        ) or (not plc.aggregation.is_valid_aggregation(child_dtype, req))
        if unsupported:
            return [], named_expr.reconstruct(expr.Literal(child.dtype, None))
        if needs_masking:
            child = expr.UnaryFunction(child.dtype, "mask_nans", (), child)
        # The aggregation is just reconstructed with the new
        # (potentially masked) child. This is safe because we recursed
        # to ensure there are no nested aggregations.

        # Rebuild the agg with the transformed child.
        new_children = [child] if not is_quantile else [child, agg.children[1]]
        named_expr = named_expr.reconstruct(agg.reconstruct(new_children))

        if agg.name == "sum":
            col = (
                expr.Cast(agg.dtype, expr.Col(DataType(pl.datatypes.Int64()), name))
                if (
                    plc.traits.is_integral(agg.dtype.plc)
                    and agg.dtype.id() != plc.TypeId.INT64
                )
                else expr.Col(agg.dtype, name)
            )
            # Polars semantics for sum differ by context:
            # - GROUPBY: sum(all-null group) => 0; sum(empty group) => 0 (fill-null)
            # - ROLLING: sum(all-null window) => null; sum(empty window) => 0 (fill only if empty)
            #
            # Must post-process because libcudf returns null for both empty and all-null windows/groups
            if not POLARS_VERSION_LT_1323 or context in {
                ExecutionContext.GROUPBY,
                ExecutionContext.WINDOW,
            }:
                # GROUPBY: always fill top-level nulls with 0
                return [(named_expr, True)], expr.NamedExpr(
                    name, replace_nulls(col, 0, is_top=is_top)
                )
            else:  # pragma: no cover
                # ROLLING:
                # Add a second rolling agg to compute the window size, then only
                # replace nulls with 0 when the window size is 0 (i.e. an empty window).
                win_len_name = next(name_generator)
                win_len = expr.NamedExpr(
                    win_len_name,
                    expr.Len(DataType(pl.Int32())),
                )

                win_len_col = expr.Col(DataType(pl.Int32()), win_len_name)
                win_len_filled = replace_nulls(win_len_col, 0, is_top=True)

                is_empty = expr.BinOp(
                    DataType(pl.Boolean()),
                    plc.binaryop.BinaryOperator.EQUAL,
                    win_len_filled,
                    expr.Literal(DataType(pl.Int32()), 0),
                )

                # If empty -> fill 0; else keep libcudf's semantics for all-null windows.
                filled = replace_nulls(col, 0, is_top=is_top)
                post_ternary_expr = expr.Ternary(agg.dtype, is_empty, filled, col)

                return [(named_expr, True), (win_len, True)], expr.NamedExpr(
                    name, post_ternary_expr
                )
        elif agg.name in {"mean", "median", "quantile", "std", "var"}:
            post_agg_col: expr.Expr = expr.Col(
                DataType(pl.Float64()), name
            )  # libcudf promotes to float64
            if agg.dtype.plc.id() == plc.TypeId.FLOAT32:
                # Cast back to float32 to match Polars
                post_agg_col = expr.Cast(agg.dtype, post_agg_col)
            return [(named_expr, True)], named_expr.reconstruct(post_agg_col)
        else:
            return [(named_expr, True)], named_expr.reconstruct(
                expr.Col(agg.dtype, name)
            )
    if isinstance(agg, expr.Ternary):
        when, then, otherwise = agg.children

        when_aggs, when_post = decompose_single_agg(
            expr.NamedExpr(next(name_generator), when),
            name_generator,
            is_top=False,
            context=context,
        )
        then_aggs, then_post = decompose_single_agg(
            expr.NamedExpr(next(name_generator), then),
            name_generator,
            is_top=False,
            context=context,
        )
        otherwise_aggs, otherwise_post = decompose_single_agg(
            expr.NamedExpr(next(name_generator), otherwise),
            name_generator,
            is_top=False,
            context=context,
        )

        when_has = any(h for _, h in when_aggs)
        then_has = any(h for _, h in then_aggs)
        otherwise_has = any(h for _, h in otherwise_aggs)

        if is_top and not (when_has or then_has or otherwise_has):
            raise NotImplementedError(
                "Broadcasted ternary with list output in groupby is not supported"
            )

        for post, has in (
            (when_post, when_has),
            (then_post, then_has),
            (otherwise_post, otherwise_has),
        ):
            if is_top and not has and not isinstance(post.value, expr.Literal):
                raise NotImplementedError(
                    "Broadcasting aggregated expressions in groupby/rolling"
                )

        return [*when_aggs, *then_aggs, *otherwise_aggs], named_expr.reconstruct(
            agg.reconstruct([when_post.value, then_post.value, otherwise_post.value])
        )
    if not agg.is_pointwise and isinstance(agg, expr.BooleanFunction):
        raise NotImplementedError(
            f"Non pointwise boolean function {agg.name!r} not supported in groupby or rolling context"
        )
    if agg.is_pointwise:
        aggs, posts = _decompose_aggs(
            (expr.NamedExpr(next(name_generator), child) for child in agg.children),
            name_generator,
            is_top=False,
            context=context,
        )
        if any(has_agg for _, has_agg in aggs):
            if not all(
                has_agg or isinstance(agg.value, expr.Literal) for agg, has_agg in aggs
            ):
                raise NotImplementedError(
                    "Broadcasting aggregated expressions in groupby/rolling"
                )
            # A pointwise expression sitting above aggregations is handled
            # by post-evaluation after the grouped aggregation.
            return (
                aggs,
                named_expr.reconstruct(agg.reconstruct([p.value for p in posts])),
            )
        else:
            # A pointwise expression containing no aggregations is
            # pre-evaluated before the grouped aggregation instead.
            return (
                [(named_expr, False)],
                named_expr.reconstruct(expr.Col(agg.dtype, name)),
            )
    raise NotImplementedError(f"No support for {type(agg)} in groupby/rolling")

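To make the return shape concrete, a hedged sketch decomposing a plain sum
for a groupby, using only constructors whose signatures appear in this file:

    import itertools

    import polars as pl

    from cudf_polars.containers import DataType
    from cudf_polars.dsl import expr
    from cudf_polars.dsl.expressions.base import ExecutionContext
    from cudf_polars.dsl.utils.aggregations import decompose_single_agg

    names = (f"_tmp{i}" for i in itertools.count())  # unique-name generator
    i64 = DataType(pl.Int64())
    total = expr.NamedExpr("total", expr.Agg(i64, "sum", (), expr.Col(i64, "x")))
    aggs, post = decompose_single_agg(
        total, names, is_top=True, context=ExecutionContext.GROUPBY
    )
    # aggs holds [(sum aggregation, True)] to run in the groupby node; post
    # rebuilds "total" from the aggregated column, filling top-level nulls
    # with 0 to match Polars' sum semantics.
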
def _decompose_aggs(
    aggs: Iterable[expr.NamedExpr],
    name_generator: Generator[str, None, None],
    *,
    is_top: bool,
    context: ExecutionContext,
) -> tuple[list[tuple[expr.NamedExpr, bool]], Sequence[expr.NamedExpr]]:
    new_aggs, post = zip(
        *(
            decompose_single_agg(agg, name_generator, is_top=is_top, context=context)
            for agg in aggs
        ),
        strict=True,
    )
    return list(itertools.chain.from_iterable(new_aggs)), post

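The `zip(*...)` above transposes per-aggregation results: a sequence of
(aggregations, post) pairs becomes one tuple of aggregation lists and one
tuple of post-expressions. A plain-Python illustration:

    import itertools

    pairs = [(["a1"], "p1"), (["a2", "a3"], "p2")]
    new_aggs, post = zip(*pairs, strict=True)
    assert post == ("p1", "p2")
    assert list(itertools.chain.from_iterable(new_aggs)) == ["a1", "a2", "a3"]
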
def decompose_aggs(
    aggs: Iterable[expr.NamedExpr],
    name_generator: Generator[str, None, None],
    *,
    context: ExecutionContext,
) -> tuple[list[expr.NamedExpr], Sequence[expr.NamedExpr]]:
    """
    Process arbitrary aggregations into a form we can handle in grouped aggregations.

    Parameters
    ----------
    aggs
        List of aggregation expressions.
    name_generator
        Generator of unique names for temporaries introduced during decomposition.
    context
        ExecutionContext in which the aggregation will run.

    Returns
    -------
    aggregations
        Aggregations to apply in the groupby node.
    post_aggregations
        Expressions to apply after aggregating (as a ``Select``).

    Notes
    -----
    The aggregation expressions are guaranteed to either be
    expressions that can be pointwise evaluated before the groupby
    operation, or aggregations of such expressions.

    Raises
    ------
    NotImplementedError
        For unsupported aggregation combinations.
    """
    new_aggs, post = _decompose_aggs(aggs, name_generator, is_top=True, context=context)
    return [agg for agg, _ in new_aggs], post

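Continuing the sketch from above (reusing `names` and `i64`), decomposing two
top-level aggregations at once; `expr.Len`'s UInt32 dtype here is an
assumption matching Polars' `len()`:

    aggs, post = decompose_aggs(
        [
            expr.NamedExpr("total", expr.Agg(i64, "sum", (), expr.Col(i64, "x"))),
            expr.NamedExpr("n", expr.Len(DataType(pl.UInt32()))),
        ],
        names,
        context=ExecutionContext.GROUPBY,
    )
    # aggs feed the groupby node; post holds one expression per input name
    # ("total", "n") to evaluate afterwards as a Select.
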
def apply_pre_evaluation(
    output_schema: Schema,
    keys: Sequence[expr.NamedExpr],
    original_aggs: Sequence[expr.NamedExpr],
    name_generator: Generator[str, None, None],
    context: ExecutionContext,
    *extra_columns: expr.NamedExpr,
) -> tuple[Sequence[expr.NamedExpr], Schema, Callable[[ir.IR], ir.IR]]:
    """
    Apply pre-evaluation to aggregations in a grouped or rolling context.

    Parameters
    ----------
    output_schema
        Schema of the plan node we're rewriting.
    keys
        Grouping keys (may be empty).
    original_aggs
        Aggregation expressions to rewrite.
    name_generator
        Generator of unique names for temporaries introduced during decomposition.
    context
        ExecutionContext in which the aggregation will run.
    extra_columns
        Any additional columns to be included in the output (only
        relevant for rolling aggregations). Columns will appear in the
        order `keys, extra_columns, original_aggs`.

    Returns
    -------
    aggregations
        The required aggregations.
    schema
        The new schema of the aggregation node.
    post_process
        Function to apply to the aggregation node to perform any
        post-processing.

    Raises
    ------
    NotImplementedError
        If the aggregations are unsupported.
    """
    aggs, post = decompose_aggs(original_aggs, name_generator, context=context)
    assert len(post) == len(original_aggs), (
        f"Unexpected number of post-aggs {len(post)=} {len(original_aggs)=}"
    )
    # Order-preserving unique
    aggs = list(dict.fromkeys(aggs).keys())
    if any(not isinstance(e.value, expr.Col) for e in post):
        selection = [
            *(key.reconstruct(expr.Col(key.value.dtype, key.name)) for key in keys),
            *extra_columns,
            *post,
        ]
        inter_schema = {
            e.name: e.value.dtype for e in itertools.chain(keys, extra_columns, aggs)
        }
        return (
            aggs,
            inter_schema,
            partial(ir.Select, output_schema, selection, True),  # noqa: FBT003
        )
    else:
        return aggs, output_schema, lambda inp: inp
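
A hedged sketch of the caller's side; the aggregation-node construction is
schematic since its real signature is not shown in this file:

    aggs, schema, post_process = apply_pre_evaluation(
        output_schema, keys, original_aggs, names, ExecutionContext.GROUPBY
    )
    # Build the grouped aggregation against the intermediate `schema`, then
    # let post_process wrap it in the final Select (or return it untouched
    # when every post-expression is already a plain column reference):
    #
    #     node = post_process(make_groupby_node(schema, keys, aggs))
    #
    # where make_groupby_node stands in for the caller's IR construction.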