cudf-polars-cu13 25.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cudf_polars/GIT_COMMIT +1 -0
- cudf_polars/VERSION +1 -0
- cudf_polars/__init__.py +28 -0
- cudf_polars/_version.py +21 -0
- cudf_polars/callback.py +318 -0
- cudf_polars/containers/__init__.py +13 -0
- cudf_polars/containers/column.py +495 -0
- cudf_polars/containers/dataframe.py +361 -0
- cudf_polars/containers/datatype.py +137 -0
- cudf_polars/dsl/__init__.py +8 -0
- cudf_polars/dsl/expr.py +66 -0
- cudf_polars/dsl/expressions/__init__.py +8 -0
- cudf_polars/dsl/expressions/aggregation.py +226 -0
- cudf_polars/dsl/expressions/base.py +272 -0
- cudf_polars/dsl/expressions/binaryop.py +120 -0
- cudf_polars/dsl/expressions/boolean.py +326 -0
- cudf_polars/dsl/expressions/datetime.py +271 -0
- cudf_polars/dsl/expressions/literal.py +97 -0
- cudf_polars/dsl/expressions/rolling.py +643 -0
- cudf_polars/dsl/expressions/selection.py +74 -0
- cudf_polars/dsl/expressions/slicing.py +46 -0
- cudf_polars/dsl/expressions/sorting.py +85 -0
- cudf_polars/dsl/expressions/string.py +1002 -0
- cudf_polars/dsl/expressions/struct.py +137 -0
- cudf_polars/dsl/expressions/ternary.py +49 -0
- cudf_polars/dsl/expressions/unary.py +517 -0
- cudf_polars/dsl/ir.py +2607 -0
- cudf_polars/dsl/nodebase.py +164 -0
- cudf_polars/dsl/to_ast.py +359 -0
- cudf_polars/dsl/tracing.py +16 -0
- cudf_polars/dsl/translate.py +939 -0
- cudf_polars/dsl/traversal.py +224 -0
- cudf_polars/dsl/utils/__init__.py +8 -0
- cudf_polars/dsl/utils/aggregations.py +481 -0
- cudf_polars/dsl/utils/groupby.py +98 -0
- cudf_polars/dsl/utils/naming.py +34 -0
- cudf_polars/dsl/utils/replace.py +61 -0
- cudf_polars/dsl/utils/reshape.py +74 -0
- cudf_polars/dsl/utils/rolling.py +121 -0
- cudf_polars/dsl/utils/windows.py +192 -0
- cudf_polars/experimental/__init__.py +8 -0
- cudf_polars/experimental/base.py +386 -0
- cudf_polars/experimental/benchmarks/__init__.py +4 -0
- cudf_polars/experimental/benchmarks/pdsds.py +220 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/__init__.py +4 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q1.py +88 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q10.py +225 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q2.py +244 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q3.py +65 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q4.py +359 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q5.py +462 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q6.py +92 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q7.py +79 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q8.py +524 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q9.py +137 -0
- cudf_polars/experimental/benchmarks/pdsh.py +814 -0
- cudf_polars/experimental/benchmarks/utils.py +832 -0
- cudf_polars/experimental/dask_registers.py +200 -0
- cudf_polars/experimental/dispatch.py +156 -0
- cudf_polars/experimental/distinct.py +197 -0
- cudf_polars/experimental/explain.py +157 -0
- cudf_polars/experimental/expressions.py +590 -0
- cudf_polars/experimental/groupby.py +327 -0
- cudf_polars/experimental/io.py +943 -0
- cudf_polars/experimental/join.py +391 -0
- cudf_polars/experimental/parallel.py +423 -0
- cudf_polars/experimental/repartition.py +69 -0
- cudf_polars/experimental/scheduler.py +155 -0
- cudf_polars/experimental/select.py +188 -0
- cudf_polars/experimental/shuffle.py +354 -0
- cudf_polars/experimental/sort.py +609 -0
- cudf_polars/experimental/spilling.py +151 -0
- cudf_polars/experimental/statistics.py +795 -0
- cudf_polars/experimental/utils.py +169 -0
- cudf_polars/py.typed +0 -0
- cudf_polars/testing/__init__.py +8 -0
- cudf_polars/testing/asserts.py +448 -0
- cudf_polars/testing/io.py +122 -0
- cudf_polars/testing/plugin.py +236 -0
- cudf_polars/typing/__init__.py +219 -0
- cudf_polars/utils/__init__.py +8 -0
- cudf_polars/utils/config.py +741 -0
- cudf_polars/utils/conversion.py +40 -0
- cudf_polars/utils/dtypes.py +118 -0
- cudf_polars/utils/sorting.py +53 -0
- cudf_polars/utils/timer.py +39 -0
- cudf_polars/utils/versions.py +27 -0
- cudf_polars_cu13-25.10.0.dist-info/METADATA +136 -0
- cudf_polars_cu13-25.10.0.dist-info/RECORD +92 -0
- cudf_polars_cu13-25.10.0.dist-info/WHEEL +5 -0
- cudf_polars_cu13-25.10.0.dist-info/licenses/LICENSE +201 -0
- cudf_polars_cu13-25.10.0.dist-info/top_level.txt +1 -0
cudf_polars/dsl/expressions/rolling.py
@@ -0,0 +1,643 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
# TODO: remove need for this
# ruff: noqa: D101
"""Rolling DSL nodes."""

from __future__ import annotations

import itertools
from dataclasses import dataclass
from functools import singledispatchmethod
from typing import TYPE_CHECKING, Any

import polars as pl

import pylibcudf as plc

from cudf_polars.containers import Column, DataFrame, DataType
from cudf_polars.dsl import expr
from cudf_polars.dsl.expressions.base import ExecutionContext, Expr
from cudf_polars.dsl.utils.reshape import broadcast
from cudf_polars.dsl.utils.windows import offsets_to_windows, range_window_bounds

if TYPE_CHECKING:
    from collections.abc import Generator, Sequence

    from cudf_polars.typing import ClosedInterval, Duration

__all__ = ["GroupedRollingWindow", "RollingWindow", "to_request"]


@dataclass(frozen=True)
class UnaryOp:
    named_exprs: list[expr.NamedExpr]
    order_index: plc.Column | None = None
    by_cols_for_scan: list[Column] | None = None
    local_grouper: plc.groupby.GroupBy | None = None


@dataclass(frozen=True)
class RankOp(UnaryOp):
    pass

def to_request(
    value: expr.Expr, orderby: Column, df: DataFrame
) -> plc.rolling.RollingRequest:
    """
    Produce a rolling request for evaluation with pylibcudf.

    Parameters
    ----------
    value
        The expression to perform the rolling aggregation on.
    orderby
        Orderby column, used as input to the request when the aggregation is Len.
    df
        DataFrame used to evaluate the inputs to the aggregation.
    """
    min_periods = 1
    if isinstance(value, expr.Len):
        # A count aggregation; we need a column, so use the orderby column
        col = orderby
    elif isinstance(value, expr.Agg):
        child = value.children[0]
        col = child.evaluate(df, context=ExecutionContext.ROLLING)
        if value.name == "var":
            # Polars variance produces null if nvalues <= ddof;
            # libcudf produces NaN. However, we can get the polars
            # behaviour by setting the minimum window size to
            # ddof + 1.
            min_periods = value.options + 1
    else:
        col = value.evaluate(
            df, context=ExecutionContext.ROLLING
        )  # pragma: no cover; raise before we get here because we
        # don't do correct handling of empty groups
    return plc.rolling.RollingRequest(col.obj, min_periods, value.agg_request)

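# A minimal illustration of the gap the ddof adjustment above closes (a
# sketch, assuming only that polars is installed; not part of this module):
# with ddof=1 a single observation has no sample variance, and polars
# yields null rather than NaN:
#
#     import polars as pl
#     assert pl.Series([1.0]).var(ddof=1) is None
#
# Setting min_periods = ddof + 1 makes libcudf's rolling variance produce
# null for such windows as well.
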
def _by_exprs(b: Expr | tuple) -> Generator[Expr]:
    if isinstance(b, Expr):
        yield b
    elif isinstance(b, tuple):  # pragma: no cover; tests cover this path when
        # run with the distributed scheduler only
        for item in b:
            yield from _by_exprs(item)
    else:
        yield expr.Literal(DataType(pl.Int64()), b)  # pragma: no cover

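# Illustrative behaviour (hypothetical expressions e1/e2, for the sketch
# only): list(_by_exprs((e1, (e2, 3)))) would yield [e1, e2,
# Literal(Int64, 3)], flattening nested tuples and wrapping plain scalars
# as Int64 literals.
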
class RollingWindow(Expr):
    __slots__ = (
        "closed_window",
        "following",
        "offset",
        "orderby",
        "orderby_dtype",
        "period",
        "preceding",
    )
    _non_child = (
        "dtype",
        "orderby_dtype",
        "offset",
        "period",
        "closed_window",
        "orderby",
    )
    def __init__(
        self,
        dtype: DataType,
        orderby_dtype: DataType,
        offset: Duration,
        period: Duration,
        closed_window: ClosedInterval,
        orderby: str,
        agg: Expr,
    ) -> None:
        self.dtype = dtype
        self.orderby_dtype = orderby_dtype
        # NOTE: Save original `offset` and `period` args,
        # because the `preceding` and `following` attributes
        # cannot be serialized (and must be reconstructed
        # within `__init__`).
        self.offset = offset
        self.period = period
        self.preceding, self.following = offsets_to_windows(
            orderby_dtype, offset, period
        )
        self.closed_window = closed_window
        self.orderby = orderby
        self.children = (agg,)
        self.is_pointwise = False
        if agg.agg_request.kind() == plc.aggregation.Kind.COLLECT_LIST:
            raise NotImplementedError(
                "Incorrect handling of empty groups for list collection"
            )
        if not plc.rolling.is_valid_rolling_aggregation(agg.dtype.plc, agg.agg_request):
            raise NotImplementedError(f"Unsupported rolling aggregation {agg}")
    def do_evaluate(  # noqa: D102
        self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
    ) -> Column:
        if context != ExecutionContext.FRAME:
            raise RuntimeError(
                "Rolling aggregation inside groupby/over/rolling"
            )  # pragma: no cover; translation raises first
        (agg,) = self.children
        orderby = df.column_map[self.orderby]
        # Polars casts integral orderby to int64, but only for calculating window bounds
        if (
            plc.traits.is_integral(orderby.obj.type())
            and orderby.obj.type().id() != plc.TypeId.INT64
        ):
            orderby_obj = plc.unary.cast(orderby.obj, plc.DataType(plc.TypeId.INT64))
        else:
            orderby_obj = orderby.obj
        preceding, following = range_window_bounds(
            self.preceding, self.following, self.closed_window
        )
        if orderby.obj.null_count() != 0:
            raise RuntimeError(
                f"Index column '{self.orderby}' in rolling may not contain nulls"
            )
        if not orderby.check_sorted(
            order=plc.types.Order.ASCENDING, null_order=plc.types.NullOrder.BEFORE
        ):
            raise RuntimeError(
                f"Index column '{self.orderby}' in rolling is not sorted, please sort first"
            )
        (result,) = plc.rolling.grouped_range_rolling_window(
            plc.Table([]),
            orderby_obj,
            plc.types.Order.ASCENDING,
            plc.types.NullOrder.BEFORE,
            preceding,
            following,
            [to_request(agg, orderby, df)],
        ).columns()
        return Column(result, dtype=self.dtype)

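# Illustrative usage (a sketch; assumes a working cudf-polars GPU engine):
# an expression-level rolling aggregation such as
#
#     import polars as pl
#     q = pl.LazyFrame({"t": [1, 2, 3], "x": [1.0, 2.0, 3.0]}).with_columns(
#         pl.col("x").sum().rolling(index_column="t", period="2i")
#     )
#     q.collect(engine="gpu")
#
# is translated into a RollingWindow node and evaluated with
# plc.rolling.grouped_range_rolling_window above.
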
class GroupedRollingWindow(Expr):
    """
    Compute a window ``.over(...)`` aggregation and broadcast to rows.

    Notes
    -----
    - This expression node currently implements **grouped window mapping**
      (aggregate once per group, then broadcast back), not rolling windows.
    - It can be extended later to support `rolling(...).over(...)`
      when polars supports that expression.
    """

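    # For example (illustrative), pl.col("x").sum().over("g") computes one
    # sum of x per group g and broadcasts that scalar back to every row of
    # the group, keeping the frame's original height.
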
    __slots__ = (
        "_order_by_expr",
        "by_count",
        "named_aggs",
        "options",
        "post",
    )
    _non_child = (
        "dtype",
        "options",
        "named_aggs",
        "post",
        "by_count",
        "_order_by_expr",
    )
    def __init__(
        self,
        dtype: DataType,
        options: Any,
        named_aggs: Sequence[expr.NamedExpr],
        post: expr.NamedExpr,
        *by: Expr,
        _order_by_expr: Expr | None = None,
    ) -> None:
        self.dtype = dtype
        self.options = options
        self.named_aggs = tuple(named_aggs)
        self.post = post
        self.is_pointwise = False
        self._order_by_expr = _order_by_expr

        unsupported = [
            type(named_expr.value).__name__
            for named_expr in self.named_aggs
            if not (
                isinstance(named_expr.value, (expr.Len, expr.Agg))
                or (
                    isinstance(named_expr.value, expr.UnaryFunction)
                    and named_expr.value.name in {"rank"}
                )
            )
        ]
        if unsupported:
            kinds = ", ".join(sorted(set(unsupported)))
            raise NotImplementedError(
                f"Unsupported expression in over(...): {kinds}"
            )

        # Ensure every partition-by is an Expr.
        # Fixes over(1) cases with the streaming
        # executor and a small blocksize.
        by_expr = [e for b in by for e in _by_exprs(b)]

        # Expose agg dependencies as children so the streaming
        # executor retains required source columns.
        child_deps = [
            v.children[0]
            for ne in self.named_aggs
            for v in (ne.value,)
            if isinstance(v, expr.Agg)
            or (isinstance(v, expr.UnaryFunction) and v.name in {"rank"})
        ]
        self.by_count = len(by_expr)
        self.children = tuple(
            itertools.chain(
                by_expr,
                (() if self._order_by_expr is None else (self._order_by_expr,)),
                child_deps,
            )
        )
    @staticmethod
    def _sorted_grouper(by_cols_for_scan: list[Column]) -> plc.groupby.GroupBy:
        return plc.groupby.GroupBy(
            plc.Table([c.obj for c in by_cols_for_scan]),
            null_handling=plc.types.NullPolicy.INCLUDE,
            keys_are_sorted=plc.types.Sorted.YES,
            column_order=[k.order for k in by_cols_for_scan],
            null_precedence=[k.null_order for k in by_cols_for_scan],
        )
    @singledispatchmethod
    def _apply_unary_op(
        self,
        op: UnaryOp,
        _: DataFrame,
        __: plc.groupby.GroupBy,
    ) -> tuple[list[str], list[DataType], list[plc.Table]]:
        raise NotImplementedError(
            f"Unsupported unary op: {type(op).__name__}"
        )  # pragma: no cover; translation raises first

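    # Dispatch is on the runtime type of `op`: RankOp instances hit the
    # registered overload below, while any other UnaryOp subtype falls
    # through to the NotImplementedError above.
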
    @_apply_unary_op.register
    def _(
        self,
        op: RankOp,
        df: DataFrame,
        grouper: plc.groupby.GroupBy,
    ) -> tuple[list[str], list[DataType], list[plc.Table]]:
        rank_named = op.named_exprs
        order_index = op.order_index
        by_cols_for_scan = op.by_cols_for_scan

        rank_requests: list[plc.groupby.GroupByRequest] = []
        rank_out_names: list[str] = []
        rank_out_dtypes: list[DataType] = []

        for ne in rank_named:
            rank_expr = ne.value
            (child_expr,) = rank_expr.children
            val_col = child_expr.evaluate(df, context=ExecutionContext.FRAME).obj
            if order_index is not None:
                val_col = plc.copying.gather(
                    plc.Table([val_col]),
                    order_index,
                    plc.copying.OutOfBoundsPolicy.NULLIFY,
                ).columns()[0]
            assert isinstance(rank_expr, expr.UnaryFunction)
            method_str, descending, _ = rank_expr.options

            rank_method = {
                "average": plc.aggregation.RankMethod.AVERAGE,
                "min": plc.aggregation.RankMethod.MIN,
                "max": plc.aggregation.RankMethod.MAX,
                "dense": plc.aggregation.RankMethod.DENSE,
                "ordinal": plc.aggregation.RankMethod.FIRST,
            }[method_str]

            order = (
                plc.types.Order.DESCENDING if descending else plc.types.Order.ASCENDING
            )
            # Polars semantics: exclude nulls from the domain; nulls get null ranks.
            null_precedence = (
                plc.types.NullOrder.BEFORE if descending else plc.types.NullOrder.AFTER
            )
            agg = plc.aggregation.rank(
                rank_method,
                column_order=order,
                null_handling=plc.types.NullPolicy.EXCLUDE,
                null_precedence=null_precedence,
                percentage=plc.aggregation.RankPercentage.NONE,
            )

            rank_requests.append(plc.groupby.GroupByRequest(val_col, [agg]))
            rank_out_names.append(ne.name)
            rank_out_dtypes.append(rank_expr.dtype)

        if order_index is not None and by_cols_for_scan is not None:
            # order_by expressions require us to order each group
            lg = op.local_grouper
            assert isinstance(lg, plc.groupby.GroupBy)
            _, rank_tables = lg.scan(rank_requests)
        else:
            _, rank_tables = grouper.scan(rank_requests)
        return rank_out_names, rank_out_dtypes, rank_tables

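    # A quick check of the method mapping (a sketch, not a test from this
    # package): polars' "ordinal" rank assigns distinct ranks in order of
    # appearance, matching libcudf's RankMethod.FIRST:
    #
    #     import polars as pl
    #     assert pl.Series([10, 10, 5]).rank("ordinal").to_list() == [2, 3, 1]
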
    def _reorder_to_input(
        self,
        row_id: plc.Column,
        by_cols: list[Column],
        n_rows: int,
        rank_tables: list[plc.Table],
        rank_out_names: list[str],
        rank_out_dtypes: list[DataType],
        *,
        order_index: plc.Column | None = None,
    ) -> list[Column]:
        # Reorder scan results from grouped order back to input row order
        if order_index is None:
            key_orders = [k.order for k in by_cols]
            key_nulls = [k.null_order for k in by_cols]
            order_index = plc.sorting.stable_sorted_order(
                plc.Table([*(c.obj for c in by_cols), row_id]),
                [*key_orders, plc.types.Order.ASCENDING],
                [*key_nulls, plc.types.NullOrder.AFTER],
            )

        return [
            Column(
                plc.copying.scatter(
                    plc.Table([tbl.columns()[0]]),
                    order_index,
                    plc.Table(
                        [
                            plc.Column.from_scalar(
                                plc.Scalar.from_py(None, tbl.columns()[0].type()),
                                n_rows,
                            )
                        ]
                    ),
                ).columns()[0],
                name=name,
                dtype=dtype,
            )
            for name, dtype, tbl in zip(
                rank_out_names, rank_out_dtypes, rank_tables, strict=True
            )
        ]

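    # The scatter above inverts the gather implied by order_index: gather
    # with a map m produces out[i] = in[m[i]], while scatter with the same
    # map writes target[m[i]] = source[i]. E.g. if order_index is [2, 0, 1],
    # scan-result rows 0, 1, 2 land back at input rows 2, 0, 1.
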
    def _split_named_expr(
        self,
    ) -> tuple[list[expr.NamedExpr], dict[str, list[expr.NamedExpr]]]:
        """Split into reductions vs unary window operations."""
        reductions: list[expr.NamedExpr] = []
        unary_window_ops: dict[str, list[expr.NamedExpr]] = {
            "rank": [],
        }

        for ne in self.named_aggs:
            v = ne.value
            if isinstance(v, expr.UnaryFunction) and v.name in unary_window_ops:
                unary_window_ops[v.name].append(ne)
            else:
                reductions.append(ne)
        return reductions, unary_window_ops

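    # E.g. (illustrative) for named_aggs holding sum(x) and rank(y), this
    # returns ([sum(x)], {"rank": [rank(y)]}): the reduction is handled via
    # a groupby aggregation, the rank via a grouped scan.
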
    def _build_window_order_index(
        self,
        by_cols: list[Column],
        *,
        row_id: plc.Column,
        order_by_col: Column | None,
        ob_desc: bool,
        ob_nulls_last: bool,
        value_col: plc.Column | None = None,
        value_desc: bool = False,
    ) -> plc.Column:
        """Compute a stable row ordering for unary operations in a grouped context."""
        cols: list[plc.Column] = [c.obj for c in by_cols]
        orders: list[plc.types.Order] = [k.order for k in by_cols]
        nulls: list[plc.types.NullOrder] = [k.null_order for k in by_cols]

        if value_col is not None:
            # For rank(...).over(...) the ranked ("sorted") order takes
            # precedence over order_by
            cols.append(value_col)
            orders.append(
                plc.types.Order.DESCENDING if value_desc else plc.types.Order.ASCENDING
            )
            nulls.append(
                plc.types.NullOrder.BEFORE if value_desc else plc.types.NullOrder.AFTER
            )

        if order_by_col is not None:
            cols.append(order_by_col.obj)
            orders.append(
                plc.types.Order.DESCENDING if ob_desc else plc.types.Order.ASCENDING
            )
            nulls.append(
                plc.types.NullOrder.AFTER
                if ob_nulls_last
                else plc.types.NullOrder.BEFORE
            )

        # Use the row id to break ties
        cols.append(row_id)
        orders.append(plc.types.Order.ASCENDING)
        nulls.append(plc.types.NullOrder.AFTER)

        return plc.sorting.stable_sorted_order(plc.Table(cols), orders, nulls)
    def _gather_columns(
        self,
        cols: list[Column],
        order_index: plc.Column,
    ) -> list[plc.Column] | list[Column]:
        gathered_tbl = plc.copying.gather(
            plc.Table([c.obj for c in cols]),
            order_index,
            plc.copying.OutOfBoundsPolicy.NULLIFY,
        )

        return [
            Column(
                gathered_tbl.columns()[i],
                name=c.name,
                dtype=c.dtype,
                order=c.order,
                null_order=c.null_order,
                is_sorted=True,
            )
            for i, c in enumerate(cols)
        ]
    def do_evaluate(  # noqa: D102
        self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
    ) -> Column:
        if context != ExecutionContext.FRAME:
            raise RuntimeError(
                "Window mapping (.over) can only be evaluated at the frame level"
            )  # pragma: no cover; translation raises first

        by_exprs = self.children[: self.by_count]
        order_by_expr = (
            self.children[self.by_count] if self._order_by_expr is not None else None
        )
        by_cols = broadcast(
            *(b.evaluate(df) for b in by_exprs),
            target_length=df.num_rows,
        )
        order_by_col = (
            broadcast(order_by_expr.evaluate(df), target_length=df.num_rows)[0]
            if order_by_expr is not None
            else None
        )

        by_tbl = plc.Table([c.obj for c in by_cols])

        sorted_flag = (
            plc.types.Sorted.YES
            if all(k.is_sorted for k in by_cols)
            else plc.types.Sorted.NO
        )
        grouper = plc.groupby.GroupBy(
            by_tbl,
            null_handling=plc.types.NullPolicy.INCLUDE,
            keys_are_sorted=sorted_flag,
            column_order=[k.order for k in by_cols],
            null_precedence=[k.null_order for k in by_cols],
        )

        scalar_named, unary_window_ops = self._split_named_expr()

        # Build GroupByRequests for scalar aggregations
        gb_requests: list[plc.groupby.GroupByRequest] = []
        out_names: list[str] = []
        out_dtypes: list[DataType] = []
        for ne in scalar_named:
            val = ne.value
            out_names.append(ne.name)
            out_dtypes.append(val.dtype)

            if isinstance(val, expr.Len):
                # A count aggregation; we need a column, so use a key column
                col = by_cols[0].obj
                gb_requests.append(plc.groupby.GroupByRequest(col, [val.agg_request]))
            elif isinstance(val, expr.Agg):
                (child,) = (
                    val.children if val.name != "quantile" else (val.children[0],)
                )
                col = child.evaluate(df, context=ExecutionContext.FRAME).obj
                gb_requests.append(plc.groupby.GroupByRequest(col, [val.agg_request]))

        group_keys_tbl, value_tables = grouper.aggregate(gb_requests)
        out_cols = (t.columns()[0] for t in value_tables)

        # We do a left join between the input keys and the group keys
        # so every input row appears exactly once. left_order is
        # returned un-ordered by libcudf.
        left_order, right_order = plc.join.left_join(
            by_tbl, group_keys_tbl, plc.types.NullEquality.EQUAL
        )

        # Scatter the right order indices into an all-null table at the
        # positions given by left_order. Now we have the map between rows
        # and groups with the correct ordering.
        left_rows = left_order.size()
        target = plc.Column.from_scalar(
            plc.Scalar.from_py(None, plc.types.SIZE_TYPE), left_rows
        )
        aligned_map = plc.copying.scatter(
            plc.Table([right_order]),
            left_order,
            plc.Table([target]),
        ).columns()[0]

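        # Toy illustration of the alignment (sketch values only): with input
        # keys [b, a, b] and group keys [a, b], left_join pairs rows with
        # groups as (0->1, 1->0, 2->1) in arbitrary row order; scattering
        # the group indices to their row positions gives
        # aligned_map = [1, 0, 1].
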
        # Broadcast each scalar aggregated result back to row-shape using
        # the aligned mapping between row indices and group indices.
        broadcasted_cols = [
            Column(
                plc.copying.gather(
                    plc.Table([col]), aligned_map, plc.copying.OutOfBoundsPolicy.NULLIFY
                ).columns()[0],
                name=named_expr.name,
                dtype=dtype,
            )
            for named_expr, dtype, col in zip(
                scalar_named, out_dtypes, out_cols, strict=True
            )
        ]

        row_id = plc.filling.sequence(
            df.num_rows,
            plc.Scalar.from_py(0, plc.types.SIZE_TYPE),
            plc.Scalar.from_py(1, plc.types.SIZE_TYPE),
        )

        if rank_named := unary_window_ops["rank"]:
            if self._order_by_expr is not None:
                _, _, ob_desc, ob_nulls_last = self.options
                for ne in rank_named:
                    rank_expr = ne.value
                    assert isinstance(rank_expr, expr.UnaryFunction)
                    (child,) = rank_expr.children
                    val = child.evaluate(df, context=ExecutionContext.FRAME).obj
                    desc = rank_expr.options[1]

                    order_index = self._build_window_order_index(
                        by_cols,
                        row_id=row_id,
                        order_by_col=order_by_col,
                        ob_desc=ob_desc,
                        ob_nulls_last=ob_nulls_last,
                        value_col=val,
                        value_desc=desc,
                    )
                    by_cols_for_scan = self._gather_columns(by_cols, order_index)
                    local = GroupedRollingWindow._sorted_grouper(by_cols_for_scan)
                    names, dtypes, tables = self._apply_unary_op(
                        RankOp(
                            named_exprs=[ne],
                            order_index=order_index,
                            by_cols_for_scan=by_cols_for_scan,
                            local_grouper=local,
                        ),
                        df,
                        grouper,
                    )
                    broadcasted_cols.extend(
                        self._reorder_to_input(
                            row_id,
                            by_cols,
                            df.num_rows,
                            tables,
                            names,
                            dtypes,
                            order_index=order_index,
                        )
                    )
            else:
                names, dtypes, tables = self._apply_unary_op(
                    RankOp(
                        named_exprs=rank_named, order_index=None, by_cols_for_scan=None
                    ),
                    df,
                    grouper,
                )
                broadcasted_cols.extend(
                    self._reorder_to_input(
                        row_id, by_cols, df.num_rows, tables, names, dtypes
                    )
                )

        # Create a temporary DataFrame with the broadcasted columns named by their
        # placeholder names from agg decomposition, then evaluate the post-expression.
        df = DataFrame(broadcasted_cols)
        return self.post.value.evaluate(df, context=ExecutionContext.FRAME)
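
# Illustrative end-to-end usage (a sketch; assumes a working GPU engine):
#
#     import polars as pl
#     df = pl.LazyFrame({"g": [1, 1, 2], "x": [3.0, 1.0, 2.0]})
#     out = df.with_columns(
#         total=pl.col("x").sum().over("g"),
#         r=pl.col("x").rank("ordinal").over("g"),
#     ).collect(engine="gpu")
#
# Here the sum is aggregated once per group and broadcast back via the
# aligned row/group mapping, while the rank is computed as a grouped scan
# and reordered to the input row order.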