cudf-polars-cu12 25.2.2__py3-none-any.whl → 25.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cudf_polars/VERSION +1 -1
- cudf_polars/callback.py +82 -65
- cudf_polars/containers/column.py +138 -7
- cudf_polars/containers/dataframe.py +26 -39
- cudf_polars/dsl/expr.py +3 -1
- cudf_polars/dsl/expressions/aggregation.py +27 -63
- cudf_polars/dsl/expressions/base.py +40 -72
- cudf_polars/dsl/expressions/binaryop.py +5 -41
- cudf_polars/dsl/expressions/boolean.py +25 -53
- cudf_polars/dsl/expressions/datetime.py +97 -17
- cudf_polars/dsl/expressions/literal.py +27 -33
- cudf_polars/dsl/expressions/rolling.py +110 -9
- cudf_polars/dsl/expressions/selection.py +8 -26
- cudf_polars/dsl/expressions/slicing.py +47 -0
- cudf_polars/dsl/expressions/sorting.py +5 -18
- cudf_polars/dsl/expressions/string.py +33 -36
- cudf_polars/dsl/expressions/ternary.py +3 -10
- cudf_polars/dsl/expressions/unary.py +35 -75
- cudf_polars/dsl/ir.py +749 -212
- cudf_polars/dsl/nodebase.py +8 -1
- cudf_polars/dsl/to_ast.py +5 -3
- cudf_polars/dsl/translate.py +319 -171
- cudf_polars/dsl/utils/__init__.py +8 -0
- cudf_polars/dsl/utils/aggregations.py +292 -0
- cudf_polars/dsl/utils/groupby.py +97 -0
- cudf_polars/dsl/utils/naming.py +34 -0
- cudf_polars/dsl/utils/replace.py +46 -0
- cudf_polars/dsl/utils/rolling.py +113 -0
- cudf_polars/dsl/utils/windows.py +186 -0
- cudf_polars/experimental/base.py +17 -19
- cudf_polars/experimental/benchmarks/__init__.py +4 -0
- cudf_polars/experimental/benchmarks/pdsh.py +1279 -0
- cudf_polars/experimental/dask_registers.py +196 -0
- cudf_polars/experimental/distinct.py +174 -0
- cudf_polars/experimental/explain.py +127 -0
- cudf_polars/experimental/expressions.py +521 -0
- cudf_polars/experimental/groupby.py +288 -0
- cudf_polars/experimental/io.py +58 -29
- cudf_polars/experimental/join.py +353 -0
- cudf_polars/experimental/parallel.py +166 -93
- cudf_polars/experimental/repartition.py +69 -0
- cudf_polars/experimental/scheduler.py +155 -0
- cudf_polars/experimental/select.py +92 -7
- cudf_polars/experimental/shuffle.py +294 -0
- cudf_polars/experimental/sort.py +45 -0
- cudf_polars/experimental/spilling.py +151 -0
- cudf_polars/experimental/utils.py +100 -0
- cudf_polars/testing/asserts.py +146 -6
- cudf_polars/testing/io.py +72 -0
- cudf_polars/testing/plugin.py +78 -76
- cudf_polars/typing/__init__.py +59 -6
- cudf_polars/utils/config.py +353 -0
- cudf_polars/utils/conversion.py +40 -0
- cudf_polars/utils/dtypes.py +22 -5
- cudf_polars/utils/timer.py +39 -0
- cudf_polars/utils/versions.py +5 -4
- {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info}/METADATA +10 -7
- cudf_polars_cu12-25.6.0.dist-info/RECORD +73 -0
- {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info}/WHEEL +1 -1
- cudf_polars/experimental/dask_serialize.py +0 -59
- cudf_polars_cu12-25.2.2.dist-info/RECORD +0 -48
- {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info/licenses}/LICENSE +0 -0
- {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,292 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
"""Utilities for rewriting aggregations."""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import itertools
|
|
9
|
+
from functools import partial
|
|
10
|
+
from typing import TYPE_CHECKING, Any
|
|
11
|
+
|
|
12
|
+
import pylibcudf as plc
|
|
13
|
+
|
|
14
|
+
from cudf_polars.dsl import expr, ir
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from collections.abc import Callable, Generator, Iterable, Sequence
|
|
18
|
+
|
|
19
|
+
from cudf_polars.typing import Schema
|
|
20
|
+
|
|
21
|
+
# Public API of this module (the names exported by ``from ... import *``).
__all__ = ["apply_pre_evaluation", "decompose_aggs", "decompose_single_agg"]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def replace_nulls(col: expr.Expr, value: Any, *, is_top: bool) -> expr.Expr:
    """
    Fill nulls in an expression with a scalar, but only at the top level.

    Parameters
    ----------
    col
        Expression whose nulls may be filled.
    value
        Scalar fill value; it is wrapped in a literal of ``col``'s dtype.
    is_top
        Whether ``col`` is a top-level aggregation result. The fill is only
        applied when this is true.

    Returns
    -------
    ``col`` itself when ``is_top`` is false, otherwise a ``fill_null`` unary
    function applied to ``col``.
    """
    if is_top:
        fill_value = expr.Literal(col.dtype, value)
        return expr.UnaryFunction(col.dtype, "fill_null", (), col, fill_value)
    return col
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def decompose_single_agg(
    named_expr: expr.NamedExpr,
    name_generator: Generator[str, None, None],
    *,
    is_top: bool,
) -> tuple[list[tuple[expr.NamedExpr, bool]], expr.NamedExpr]:
    """
    Decompose a single named aggregation.

    Parameters
    ----------
    named_expr
        The named aggregation to decompose
    name_generator
        Generator of unique names for temporaries introduced during decomposition.
        Recursive calls draw fresh names from the same generator.
    is_top
        Is this the top of an aggregation expression?

    Returns
    -------
    aggregations
        Pairs of expressions to apply as grouped aggregations (whose children
        may be evaluated pointwise) and flags indicating if the
        expression contained nested aggregations.
    post_aggregate
        Single expression to apply to post-process the grouped
        aggregations.

    Raises
    ------
    NotImplementedError
        If the expression contains nested aggregations or unsupported
        operations in a grouped aggregation context.

    Notes
    -----
    The order of the ``isinstance`` checks matters: the ``Cast(Len)``
    special case must be tested before the plain ``Len`` case.
    """
    agg = named_expr.value
    name = named_expr.name
    if isinstance(agg, expr.Col):
        # A bare column is requested as-is (flag False: no aggregation) and the
        # post-aggregation phase just refers to the result by name.
        # TODO: collect_list produces null for empty group in libcudf, empty list in polars.
        # But we need the nested value type, so need to track proper dtypes in our DSL.
        return [(named_expr, False)], named_expr.reconstruct(expr.Col(agg.dtype, name))
    if is_top and isinstance(agg, expr.Cast) and isinstance(agg.children[0], expr.Len):
        # Special case to fill nulls with zeros for empty group length calculations
        (child,) = agg.children
        child_agg, post = decompose_single_agg(
            expr.NamedExpr(next(name_generator), child), name_generator, is_top=True
        )
        # Re-apply the cast to the (temporary) length column, then fill nulls with 0.
        return child_agg, named_expr.reconstruct(
            replace_nulls(
                agg.reconstruct([post.value]),
                0,
                is_top=True,
            )
        )
    if isinstance(agg, expr.Len):
        # Len is itself an aggregation (flag True); refer to its result by name.
        return [(named_expr, True)], named_expr.reconstruct(expr.Col(agg.dtype, name))
    if isinstance(agg, (expr.Literal, expr.LiteralColumn)):
        # Literals need no grouped aggregation; they are kept unchanged as the
        # post-aggregation expression.
        return [], named_expr
    if isinstance(agg, expr.Agg):
        if agg.name == "quantile":
            # Second child the requested quantile (which is asserted
            # to be a literal on construction)
            child = agg.children[0]
        else:
            (child,) = agg.children
        # Floating-point min/max get their NaNs masked before aggregating.
        needs_masking = agg.name in {"min", "max"} and plc.traits.is_floating_point(
            child.dtype
        )
        if needs_masking and agg.options:
            # pl.col("a").nan_max or nan_min
            raise NotImplementedError("Nan propagation in groupby for min/max")
        # Recurse into the child only to detect nested aggregations; the
        # decomposed child itself is discarded.
        aggs, _ = decompose_single_agg(
            expr.NamedExpr(next(name_generator), child), name_generator, is_top=False
        )
        if any(has_agg for _, has_agg in aggs):
            raise NotImplementedError("Nested aggs in groupby not supported")
        if needs_masking:
            child = expr.UnaryFunction(child.dtype, "mask_nans", (), child)
            # The aggregation is just reconstructed with the new
            # (potentially masked) child. This is safe because we recursed
            # to ensure there are no nested aggregations.
            return (
                [(named_expr.reconstruct(agg.reconstruct([child])), True)],
                named_expr.reconstruct(expr.Col(agg.dtype, name)),
            )
        elif agg.name == "sum":
            # For integral dtypes other than INT64 the grouped result column is
            # read back as INT64 and cast to the requested dtype (presumably
            # because libcudf widens integral sums — TODO confirm).
            col = (
                expr.Cast(agg.dtype, expr.Col(plc.DataType(plc.TypeId.INT64), name))
                if (
                    plc.traits.is_integral(agg.dtype)
                    and agg.dtype.id() != plc.TypeId.INT64
                )
                else expr.Col(agg.dtype, name)
            )
            return [(named_expr, True)], expr.NamedExpr(
                name,
                # In polars sum(empty_group) => 0, but in libcudf
                # sum(empty_group) => null So must post-process by
                # replacing nulls, but only if we're a "top-level"
                # agg.
                replace_nulls(col, 0, is_top=is_top),
            )
        else:
            # Every other aggregation is requested directly and referenced by name.
            return [(named_expr, True)], named_expr.reconstruct(
                expr.Col(agg.dtype, name)
            )
    if isinstance(agg, expr.Ternary):
        raise NotImplementedError("Ternary inside groupby")
    if agg.is_pointwise:
        # Decompose each child under a fresh temporary name (not top-level, so
        # no null replacement is applied to the children).
        aggs, posts = _decompose_aggs(
            (expr.NamedExpr(next(name_generator), child) for child in agg.children),
            name_generator,
            is_top=False,
        )
        if any(has_agg for _, has_agg in aggs):
            # Mixing aggregated and non-aggregated (non-literal) children would
            # require broadcasting, which is not supported. Note the generator's
            # `agg` is local to the generator expression and does not rebind
            # the outer `agg` used below.
            if not all(
                has_agg or isinstance(agg.value, expr.Literal) for agg, has_agg in aggs
            ):
                raise NotImplementedError(
                    "Broadcasting aggregated expressions in groupby/rolling"
                )
            # Any pointwise expression can be handled either by
            # post-evaluation (if outside an aggregation).
            return (
                aggs,
                named_expr.reconstruct(agg.reconstruct([p.value for p in posts])),
            )
        else:
            # Or pre-evaluation if inside an aggregation.
            return (
                [(named_expr, False)],
                named_expr.reconstruct(expr.Col(agg.dtype, name)),
            )
    raise NotImplementedError(f"No support for {type(agg)} in groupby/rolling")
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def _decompose_aggs(
    aggs: Iterable[expr.NamedExpr],
    name_generator: Generator[str, None, None],
    *,
    is_top: bool,
) -> tuple[list[tuple[expr.NamedExpr, bool]], Sequence[expr.NamedExpr]]:
    """
    Decompose several named aggregations and concatenate the results.

    Parameters
    ----------
    aggs
        Named aggregation expressions to decompose (may be empty).
    name_generator
        Generator of unique names for temporaries introduced during decomposition.
    is_top
        Whether these expressions are top-level aggregations; forwarded to
        ``decompose_single_agg``.

    Returns
    -------
    aggregations
        Concatenation, in input order, of the ``(expression, has_agg)`` pairs
        produced for each input.
    post_aggregations
        One post-aggregation expression per input, in input order (an empty
        tuple for empty input).
    """
    decomposed = [
        decompose_single_agg(agg, name_generator, is_top=is_top) for agg in aggs
    ]
    if not decomposed:
        # zip(*[]) produces nothing, so unpacking it into two names would raise
        # a confusing ValueError; an empty input simply has nothing to decompose.
        return [], ()
    new_aggs, post = zip(*decomposed, strict=True)
    return list(itertools.chain.from_iterable(new_aggs)), post
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def decompose_aggs(
    aggs: Iterable[expr.NamedExpr], name_generator: Generator[str, None, None]
) -> tuple[list[expr.NamedExpr], Sequence[expr.NamedExpr]]:
    """
    Split arbitrary aggregations into grouped-aggregation and post-processing parts.

    Parameters
    ----------
    aggs
        Aggregation expressions, each with its output name.
    name_generator
        Source of unique names for any temporaries the decomposition creates.

    Returns
    -------
    aggregations
        The aggregations the groupby node itself must compute.
    post_aggregations
        Expressions to evaluate on the aggregated result (as a ``Select``).

    Notes
    -----
    Each returned aggregation is either an expression that can be evaluated
    pointwise before the groupby, or an aggregation of such an expression.

    Raises
    ------
    NotImplementedError
        For unsupported combinations of aggregations.
    """
    flagged_aggs, post_aggs = _decompose_aggs(aggs, name_generator, is_top=True)
    # The "contains an aggregation" flags are only needed internally.
    pure_aggs = [named for named, _has_agg in flagged_aggs]
    return pure_aggs, post_aggs
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def apply_pre_evaluation(
    output_schema: Schema,
    keys: Sequence[expr.NamedExpr],
    original_aggs: Sequence[expr.NamedExpr],
    name_generator: Generator[str, None, None],
    *extra_columns: expr.NamedExpr,
) -> tuple[Sequence[expr.NamedExpr], Schema, Callable[[ir.IR], ir.IR]]:
    """
    Rewrite grouped or rolling aggregations into aggregate + post-select form.

    Parameters
    ----------
    output_schema
        Schema of the plan node being rewritten.
    keys
        Grouping keys (possibly empty).
    original_aggs
        The aggregation expressions to rewrite.
    name_generator
        Source of unique names for temporaries introduced during decomposition.
    extra_columns
        Additional output columns (only relevant for rolling aggregations).
        Output columns are ordered as `keys, extra_columns, original_aggs`.

    Returns
    -------
    aggregations
        The de-duplicated aggregations the node must compute.
    schema
        Schema of the aggregation node itself.
    post_process
        Callable wrapping the aggregation node with any required
        post-processing (the identity when none is needed).

    Raises
    ------
    NotImplementedError
        If any aggregation is unsupported.
    """
    decomposed, post = decompose_aggs(original_aggs, name_generator)
    assert len(post) == len(original_aggs), (
        f"Unexpected number of post-aggs {len(post)=} {len(original_aggs)=}"
    )
    # De-duplicate while keeping first-seen order.
    unique_aggs = list(dict.fromkeys(decomposed))

    # If every post expression is a plain column reference, the aggregation
    # output already has the requested shape and no extra Select is needed.
    if all(isinstance(named.value, expr.Col) for named in post):
        return unique_aggs, output_schema, lambda inp: inp

    key_columns = [
        key.reconstruct(expr.Col(key.value.dtype, key.name)) for key in keys
    ]
    selection = [*key_columns, *extra_columns, *post]
    inter_schema = {
        named.name: named.value.dtype
        for named in itertools.chain(keys, extra_columns, unique_aggs)
    }
    post_process = partial(ir.Select, output_schema, selection, True)  # noqa: FBT003
    return unique_aggs, inter_schema, post_process
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
"""Utilities for grouped aggregations."""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from typing import TYPE_CHECKING
|
|
9
|
+
|
|
10
|
+
import pylibcudf as plc
|
|
11
|
+
|
|
12
|
+
from cudf_polars.dsl import ir
|
|
13
|
+
from cudf_polars.dsl.utils.aggregations import apply_pre_evaluation
|
|
14
|
+
from cudf_polars.dsl.utils.naming import unique_names
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from collections.abc import Sequence
|
|
18
|
+
from typing import Any
|
|
19
|
+
|
|
20
|
+
from cudf_polars.dsl import expr
|
|
21
|
+
from cudf_polars.utils import config
|
|
22
|
+
|
|
23
|
+
# Public API of this module (the names exported by ``from ... import *``).
__all__ = ["rewrite_groupby"]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def rewrite_groupby(
    node: Any,
    schema: dict[str, plc.DataType],
    keys: Sequence[expr.NamedExpr],
    aggs: Sequence[expr.NamedExpr],
    config_options: config.ConfigOptions,
    inp: ir.IR,
) -> ir.IR:
    """
    Rewrite a polars groupby plan node into IR nodes we can execute.

    Parameters
    ----------
    node
        The polars groupby plan node; its ``options.slice`` and
        ``maintain_order`` attributes are forwarded to the new nodes.
    schema
        Output schema of the groupby plan node.
    keys
        Grouping keys.
    aggs
        Aggregations as originally requested.
    config_options
        Configuration options, passed through to ``ir.GroupBy``.
    inp
        Input plan node of the groupby.

    Returns
    -------
    New plan node computing the grouped aggregations.

    Raises
    ------
    NotImplementedError
        If any requested aggregation is unsupported.

    Notes
    -----
    With no aggregations the groupby reduces to selecting the keys and
    keeping one row per distinct key combination.

    Otherwise, because libcudf can only aggregate columns (not arbitrary
    expressions), each aggregation is split into a pre-selection phase
    (evaluating expressions that live inside an aggregation), the
    aggregation phase (acting on columns only), and a post-selection phase
    (evaluating expressions of aggregated results). This scheme does not
    permit nested aggregations, so those are unsupported.
    """
    if not aggs:
        key_selection = ir.Select(schema, keys, True, inp)  # noqa: FBT003
        return ir.Distinct(
            schema,
            plc.stream_compaction.DuplicateKeepOption.KEEP_ANY,
            None,
            node.options.slice,
            node.maintain_order,
            key_selection,
        )

    # Temporary names are generated relative to the output schema's names.
    temporaries = unique_names(schema.keys())
    group_aggs, group_schema, post_process = apply_pre_evaluation(
        schema, keys, aggs, temporaries
    )
    grouped = ir.GroupBy(
        group_schema,
        keys,
        group_aggs,
        node.maintain_order,
        node.options.slice,
        config_options,
        inp,
    )
    return post_process(grouped)
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
"""Name generation utilities."""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from typing import TYPE_CHECKING
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from collections.abc import Generator, Iterable
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# Public API of this module (the names exported by ``from ... import *``).
__all__ = ["unique_names"]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def unique_names(names: Iterable[str]) -> Generator[str, None, None]:
    """
    Generate unique names relative to some known names.

    Parameters
    ----------
    names
        Names we should be unique with respect to (may be empty). The
        iterable is consumed once, when the first name is requested.

    Yields
    ------
    Unique names (just using sequence numbers)

    Notes
    -----
    Each generated name is a run of underscores as long as the longest
    known name, followed by at least one digit. It is therefore strictly
    longer than every known name and cannot collide with any of them.
    Distinct sequence numbers keep the generated names distinct from each
    other.
    """
    # default=0 so that an empty collection of names does not raise ValueError.
    prefix = "_" * max(map(len, names), default=0)
    i = 0
    while True:
        yield f"{prefix}{i}"
        i += 1
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
"""Utilities for replacing nodes in a DAG."""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from typing import TYPE_CHECKING
|
|
9
|
+
|
|
10
|
+
from cudf_polars.dsl.traversal import CachingVisitor, reuse_if_unchanged
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from collections.abc import Mapping, Sequence
|
|
14
|
+
|
|
15
|
+
from cudf_polars.typing import GenericTransformer, NodeT
|
|
16
|
+
|
|
17
|
+
# Public API of this module (the names exported by ``from ... import *``).
__all__ = ["replace"]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _replace(node: NodeT, fn: GenericTransformer[NodeT, NodeT]) -> NodeT:
    """
    Return the registered replacement for ``node``, or recurse into it.

    Parameters
    ----------
    node
        Node being visited.
    fn
        The transformer; ``fn.state["replacements"]`` maps nodes to their
        replacements.

    Returns
    -------
    The replacement for ``node`` if one is registered, otherwise ``node``
    rebuilt from transformed children via ``reuse_if_unchanged``.

    Raises
    ------
    KeyError
        If the transformer state has no ``"replacements"`` entry. This is a
        configuration error and is no longer silently treated as "nothing
        to replace".
    """
    replacements = fn.state["replacements"]
    # Only the per-node lookup is allowed to miss.
    try:
        return replacements[node]
    except KeyError:
        return reuse_if_unchanged(node, fn)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def replace(nodes: Sequence[NodeT], replacements: Mapping[NodeT, NodeT]) -> list[NodeT]:
    """
    Substitute nodes throughout a collection of expression DAGs.

    Parameters
    ----------
    nodes
        Root nodes in which to perform the substitution.
    replacements
        Mapping from each node to be replaced to its replacement.

    Returns
    -------
    list
        The transformed roots, in the same order as ``nodes``.

    Notes
    -----
    A single caching visitor is shared across all roots, so common
    subexpressions are only transformed once.
    """
    visitor_state = {"replacements": replacements}
    mapper: GenericTransformer[NodeT, NodeT] = CachingVisitor(
        _replace, state=visitor_state
    )
    return list(map(mapper, nodes))
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
"""Utilities for rolling window aggregations."""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from typing import TYPE_CHECKING
|
|
9
|
+
|
|
10
|
+
import pylibcudf as plc
|
|
11
|
+
|
|
12
|
+
from cudf_polars.dsl import expr, ir
|
|
13
|
+
from cudf_polars.dsl.utils.aggregations import apply_pre_evaluation
|
|
14
|
+
from cudf_polars.dsl.utils.naming import unique_names
|
|
15
|
+
from cudf_polars.dsl.utils.windows import offsets_to_windows
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from collections.abc import Sequence
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
from cudf_polars.typing import Schema
|
|
22
|
+
from cudf_polars.utils import config
|
|
23
|
+
|
|
24
|
+
# Public API of this module (the names exported by ``from ... import *``).
__all__ = ["rewrite_rolling"]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def rewrite_rolling(
    options: Any,
    schema: Schema,
    keys: Sequence[expr.NamedExpr],
    aggs: Sequence[expr.NamedExpr],
    config_options: config.ConfigOptions,
    inp: ir.IR,
) -> ir.IR:
    """
    Rewrite a rolling plan node into something we can handle.

    Parameters
    ----------
    options
        Rolling-specific group options.
    schema
        Schema of the rolling plan node.
    keys
        Grouping keys for the rolling node (may be empty).
    aggs
        Originally requested rolling aggregations.
    config_options
        Configuration options (currently unused).
    inp
        Input plan node to the rolling aggregation.

    Returns
    -------
    New plan node representing the rolling aggregations

    Raises
    ------
    NotImplementedError
        If any of the requested aggregations are unsupported.

    Notes
    -----
    Since libcudf can only perform rolling aggregations on columns
    (not arbitrary expressions), the approach is to split each
    aggregation into a pre-selection phase (evaluating expressions
    that live within an aggregation), the aggregation phase (now
    acting on columns only), and a post-selection phase (evaluating
    expressions of aggregated results).

    This scheme does not permit nested aggregations, so those are
    unsupported.
    """
    index_name = options.rolling.index_column
    index_dtype = schema[index_name]
    index_col = expr.Col(index_dtype, index_name)
    # Window offsets for a non-INT64 integral index are computed as INT64;
    # the index column expression itself keeps its original dtype.
    offset_dtype = (
        plc.DataType(plc.TypeId.INT64)
        if plc.traits.is_integral(index_dtype) and index_dtype.id() != plc.TypeId.INT64
        else index_dtype
    )
    index = expr.NamedExpr(index_name, index_col)
    if len(aggs) > 0:
        # Temporaries must not collide with any output column name. The names
        # are therefore generated relative to the schema's column names. (The
        # previous code passed a prefix *string*, whose single-character
        # elements produced names like "_0" that can clash with user columns.)
        aggs, rolling_schema, apply_post_evaluation = apply_pre_evaluation(
            schema, keys, aggs, unique_names(schema.keys()), index
        )
    else:
        rolling_schema = schema
        apply_post_evaluation = lambda inp: inp  # noqa: E731
    preceding, following = offsets_to_windows(
        offset_dtype, options.rolling.offset, options.rolling.period
    )
    if (n := len(keys)) > 0:
        # Grouped rolling in polars sorts the output by the groups.
        inp = ir.Sort(
            inp.schema,
            keys,
            [plc.types.Order.ASCENDING] * n,
            [plc.types.NullOrder.BEFORE] * n,
            True,  # noqa: FBT003
            None,
            inp,
        )
    return apply_post_evaluation(
        ir.Rolling(
            rolling_schema,
            index,
            preceding,
            following,
            options.rolling.closed_window,
            keys,
            aggs,
            options.slice,
            inp,
        )
    )
|