cudf-polars-cu13 25.10.0__py3-none-any.whl → 26.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cudf_polars/GIT_COMMIT +1 -1
- cudf_polars/VERSION +1 -1
- cudf_polars/callback.py +60 -15
- cudf_polars/containers/column.py +137 -77
- cudf_polars/containers/dataframe.py +123 -34
- cudf_polars/containers/datatype.py +134 -13
- cudf_polars/dsl/expr.py +0 -2
- cudf_polars/dsl/expressions/aggregation.py +80 -28
- cudf_polars/dsl/expressions/binaryop.py +34 -14
- cudf_polars/dsl/expressions/boolean.py +110 -37
- cudf_polars/dsl/expressions/datetime.py +59 -30
- cudf_polars/dsl/expressions/literal.py +11 -5
- cudf_polars/dsl/expressions/rolling.py +460 -119
- cudf_polars/dsl/expressions/selection.py +9 -8
- cudf_polars/dsl/expressions/slicing.py +1 -1
- cudf_polars/dsl/expressions/string.py +256 -114
- cudf_polars/dsl/expressions/struct.py +19 -7
- cudf_polars/dsl/expressions/ternary.py +33 -3
- cudf_polars/dsl/expressions/unary.py +126 -64
- cudf_polars/dsl/ir.py +1053 -350
- cudf_polars/dsl/to_ast.py +30 -13
- cudf_polars/dsl/tracing.py +194 -0
- cudf_polars/dsl/translate.py +307 -107
- cudf_polars/dsl/utils/aggregations.py +43 -30
- cudf_polars/dsl/utils/reshape.py +14 -2
- cudf_polars/dsl/utils/rolling.py +12 -8
- cudf_polars/dsl/utils/windows.py +35 -20
- cudf_polars/experimental/base.py +55 -2
- cudf_polars/experimental/benchmarks/pdsds.py +12 -126
- cudf_polars/experimental/benchmarks/pdsh.py +792 -2
- cudf_polars/experimental/benchmarks/utils.py +596 -39
- cudf_polars/experimental/dask_registers.py +47 -20
- cudf_polars/experimental/dispatch.py +9 -3
- cudf_polars/experimental/distinct.py +2 -0
- cudf_polars/experimental/explain.py +15 -2
- cudf_polars/experimental/expressions.py +30 -15
- cudf_polars/experimental/groupby.py +25 -4
- cudf_polars/experimental/io.py +156 -124
- cudf_polars/experimental/join.py +53 -23
- cudf_polars/experimental/parallel.py +68 -19
- cudf_polars/experimental/rapidsmpf/__init__.py +8 -0
- cudf_polars/experimental/rapidsmpf/collectives/__init__.py +9 -0
- cudf_polars/experimental/rapidsmpf/collectives/allgather.py +90 -0
- cudf_polars/experimental/rapidsmpf/collectives/common.py +96 -0
- cudf_polars/experimental/rapidsmpf/collectives/shuffle.py +253 -0
- cudf_polars/experimental/rapidsmpf/core.py +488 -0
- cudf_polars/experimental/rapidsmpf/dask.py +172 -0
- cudf_polars/experimental/rapidsmpf/dispatch.py +153 -0
- cudf_polars/experimental/rapidsmpf/io.py +696 -0
- cudf_polars/experimental/rapidsmpf/join.py +322 -0
- cudf_polars/experimental/rapidsmpf/lower.py +74 -0
- cudf_polars/experimental/rapidsmpf/nodes.py +735 -0
- cudf_polars/experimental/rapidsmpf/repartition.py +216 -0
- cudf_polars/experimental/rapidsmpf/union.py +115 -0
- cudf_polars/experimental/rapidsmpf/utils.py +374 -0
- cudf_polars/experimental/repartition.py +9 -2
- cudf_polars/experimental/select.py +177 -14
- cudf_polars/experimental/shuffle.py +46 -12
- cudf_polars/experimental/sort.py +100 -26
- cudf_polars/experimental/spilling.py +1 -1
- cudf_polars/experimental/statistics.py +24 -5
- cudf_polars/experimental/utils.py +25 -7
- cudf_polars/testing/asserts.py +13 -8
- cudf_polars/testing/io.py +2 -1
- cudf_polars/testing/plugin.py +93 -17
- cudf_polars/typing/__init__.py +86 -32
- cudf_polars/utils/config.py +473 -58
- cudf_polars/utils/cuda_stream.py +70 -0
- cudf_polars/utils/versions.py +5 -4
- cudf_polars_cu13-26.2.0.dist-info/METADATA +181 -0
- cudf_polars_cu13-26.2.0.dist-info/RECORD +108 -0
- {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/WHEEL +1 -1
- cudf_polars_cu13-25.10.0.dist-info/METADATA +0 -136
- cudf_polars_cu13-25.10.0.dist-info/RECORD +0 -92
- {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/licenses/LICENSE +0 -0
- {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/top_level.txt +0 -0
cudf_polars/dsl/expressions/rolling.py (+460 -119):

```diff
@@ -6,23 +6,27 @@
 
 from __future__ import annotations
 
-import itertools
+from collections import defaultdict
 from dataclasses import dataclass
 from functools import singledispatchmethod
 from typing import TYPE_CHECKING, Any
 
-import polars as pl
-
 import pylibcudf as plc
 
 from cudf_polars.containers import Column, DataFrame, DataType
 from cudf_polars.dsl import expr
 from cudf_polars.dsl.expressions.base import ExecutionContext, Expr
 from cudf_polars.dsl.utils.reshape import broadcast
-from cudf_polars.dsl.utils.windows import offsets_to_windows, range_window_bounds
+from cudf_polars.dsl.utils.windows import (
+    duration_to_int,
+    offsets_to_windows,
+    range_window_bounds,
+)
 
 if TYPE_CHECKING:
-    from collections.abc import Generator, Sequence
+    from collections.abc import Sequence
+
+    from rmm.pylibrmm.stream import Stream
 
     from cudf_polars.typing import ClosedInterval, Duration
 
@@ -42,6 +46,16 @@ class RankOp(UnaryOp):
     pass
 
 
+@dataclass(frozen=True)
+class FillNullWithStrategyOp(UnaryOp):
+    policy: plc.replace.ReplacePolicy = plc.replace.ReplacePolicy.PRECEDING
+
+
+@dataclass(frozen=True)
+class CumSumOp(UnaryOp):
+    pass
+
+
 def to_request(
     value: expr.Expr, orderby: Column, df: DataFrame
 ) -> plc.rolling.RollingRequest:
@@ -78,26 +92,15 @@ def to_request(
     return plc.rolling.RollingRequest(col.obj, min_periods, value.agg_request)
 
 
-def _by_exprs(b: Expr | tuple) -> Generator[Expr]:
-    if isinstance(b, Expr):
-        yield b
-    elif isinstance(b, tuple):  # pragma: no cover; tests cover this path when
-        # run with the distributed scheduler only
-        for item in b:
-            yield from _by_exprs(item)
-    else:
-        yield expr.Literal(DataType(pl.Int64()), b)  # pragma: no cover
-
-
 class RollingWindow(Expr):
     __slots__ = (
         "closed_window",
-        "following",
+        "following_ordinal",
         "offset",
         "orderby",
         "orderby_dtype",
         "period",
-        "preceding",
+        "preceding_ordinal",
     )
     _non_child = (
         "dtype",
@@ -111,7 +114,7 @@ class RollingWindow(Expr):
     def __init__(
         self,
         dtype: DataType,
-        orderby_dtype: DataType,
+        orderby_dtype: plc.DataType,
         offset: Duration,
         period: Duration,
         closed_window: ClosedInterval,
@@ -126,9 +129,11 @@ class RollingWindow(Expr):
         # within `__init__`).
         self.offset = offset
         self.period = period
-        self.preceding, self.following = offsets_to_windows(
-            orderby_dtype, offset, period
-        )
+        self.orderby_dtype = orderby_dtype
+        self.offset = offset
+        self.period = period
+        self.preceding_ordinal = duration_to_int(orderby_dtype, *offset)
+        self.following_ordinal = duration_to_int(orderby_dtype, *period)
         self.closed_window = closed_window
         self.orderby = orderby
         self.children = (agg,)
@@ -137,7 +142,9 @@ class RollingWindow(Expr):
             raise NotImplementedError(
                 "Incorrect handling of empty groups for list collection"
             )
-        if not plc.rolling.is_valid_rolling_aggregation(agg.dtype.plc, agg.agg_request):
+        if not plc.rolling.is_valid_rolling_aggregation(
+            agg.dtype.plc_type, agg.agg_request
+        ):
             raise NotImplementedError(f"Unsupported rolling aggregation {agg}")
 
     def do_evaluate(  # noqa: D102
@@ -154,18 +161,28 @@ class RollingWindow(Expr):
             plc.traits.is_integral(orderby.obj.type())
             and orderby.obj.type().id() != plc.TypeId.INT64
         ):
-            orderby_obj = plc.unary.cast(orderby.obj, plc.DataType(plc.TypeId.INT64))
+            orderby_obj = plc.unary.cast(
+                orderby.obj, plc.DataType(plc.TypeId.INT64), stream=df.stream
+            )
         else:
            orderby_obj = orderby.obj
+        preceding_scalar, following_scalar = offsets_to_windows(
+            self.orderby_dtype,
+            self.preceding_ordinal,
+            self.following_ordinal,
+            stream=df.stream,
+        )
         preceding, following = range_window_bounds(
-            self.preceding, self.following, self.closed_window
+            preceding_scalar, following_scalar, self.closed_window
         )
         if orderby.obj.null_count() != 0:
             raise RuntimeError(
                 f"Index column '{self.orderby}' in rolling may not contain nulls"
             )
         if not orderby.check_sorted(
-            order=plc.types.Order.ASCENDING, null_order=plc.types.NullOrder.BEFORE
+            order=plc.types.Order.ASCENDING,
+            null_order=plc.types.NullOrder.BEFORE,
+            stream=df.stream,
         ):
             raise RuntimeError(
                 f"Index column '{self.orderby}' in rolling is not sorted, please sort first"
@@ -178,6 +195,7 @@ class RollingWindow(Expr):
             preceding,
             following,
             [to_request(agg, orderby, df)],
+            stream=df.stream,
         ).columns()
         return Column(result, dtype=self.dtype)
 
@@ -207,7 +225,6 @@ class GroupedRollingWindow(Expr):
         "named_aggs",
         "post",
         "by_count",
-        "_order_by_expr",
     )
 
     def __init__(
@@ -216,15 +233,18 @@ class GroupedRollingWindow(Expr):
         options: Any,
         named_aggs: Sequence[expr.NamedExpr],
         post: expr.NamedExpr,
-        *by_expr: Expr,
-        order_by_expr: Expr | None = None,
+        by_count: int,
+        *children: Expr,
    ) -> None:
         self.dtype = dtype
         self.options = options
         self.named_aggs = tuple(named_aggs)
         self.post = post
+        self.by_count = by_count
+        has_order_by = self.options[1]
         self.is_pointwise = False
-        self._order_by_expr = order_by_expr
+        self.children = tuple(children)
+        self._order_by_expr = children[by_count] if has_order_by else None
 
         unsupported = [
             type(named_expr.value).__name__
@@ -233,7 +253,8 @@ class GroupedRollingWindow(Expr):
                 isinstance(named_expr.value, (expr.Len, expr.Agg))
                 or (
                     isinstance(named_expr.value, expr.UnaryFunction)
-                    and named_expr.value.name in {"rank"}
+                    and named_expr.value.name
+                    in {"rank", "fill_null_with_strategy", "cum_sum"}
                 )
             )
         ]
@@ -242,29 +263,22 @@ class GroupedRollingWindow(Expr):
             raise NotImplementedError(
                 f"Unsupported over(...) only expression: {kinds}="
             )
-
-
-
-
-
-
-        # Expose agg dependencies as children so the streaming
-        # executor retains required source columns
-        child_deps = [
-            v.children[0]
-            for ne in self.named_aggs
-            for v in (ne.value,)
-            if isinstance(v, expr.Agg)
-            or (isinstance(v, expr.UnaryFunction) and v.name in {"rank"})
-        ]
-        self.by_count = len(by_expr)
-        self.children = tuple(
-            itertools.chain(
-                by_expr,
-                (() if self._order_by_expr is None else (self._order_by_expr,)),
-                child_deps,
+        if has_order_by:
+            ob = self._order_by_expr
+            is_multi_order_by = (
+                isinstance(ob, expr.UnaryFunction)
+                and ob.name == "as_struct"
+                and len(ob.children) > 1
             )
-        )
+            has_order_sensitive_agg = any(
+                isinstance(ne.value, expr.Agg)
+                and ne.value.agg_request.kind() == plc.aggregation.Kind.NTH_ELEMENT
+                for ne in self.named_aggs
+            )
+            if is_multi_order_by and has_order_sensitive_agg:
+                raise NotImplementedError(
+                    "Multiple order_by keys with order-sensitive aggregations"
+                )
 
     @staticmethod
     def _sorted_grouper(by_cols_for_scan: list[Column]) -> plc.groupby.GroupBy:
@@ -311,6 +325,7 @@ class GroupedRollingWindow(Expr):
                 plc.Table([val_col]),
                 order_index,
                 plc.copying.OutOfBoundsPolicy.NULLIFY,
+                stream=df.stream,
             ).columns()[0]
         assert isinstance(rank_expr, expr.UnaryFunction)
         method_str, descending, _ = rank_expr.options
@@ -351,6 +366,85 @@ class GroupedRollingWindow(Expr):
         _, rank_tables = grouper.scan(rank_requests)
         return rank_out_names, rank_out_dtypes, rank_tables
 
+    @_apply_unary_op.register
+    def _(  # type: ignore[no-untyped-def]
+        self,
+        op: FillNullWithStrategyOp,
+        df: DataFrame,
+        _,
+    ) -> tuple[list[str], list[DataType], list[plc.Table]]:
+        named_exprs = op.named_exprs
+
+        plc_cols = [
+            ne.value.children[0].evaluate(df, context=ExecutionContext.FRAME).obj
+            for ne in named_exprs
+        ]
+        if op.order_index is not None:
+            vals_tbl = plc.copying.gather(
+                plc.Table(plc_cols),
+                op.order_index,
+                plc.copying.OutOfBoundsPolicy.NULLIFY,
+                stream=df.stream,
+            )
+        else:
+            vals_tbl = plc.Table(plc_cols)
+        local_grouper = op.local_grouper
+        assert isinstance(local_grouper, plc.groupby.GroupBy)
+        _, filled_tbl = local_grouper.replace_nulls(
+            vals_tbl,
+            [op.policy] * len(plc_cols),
+        )
+
+        tables = [plc.Table([column]) for column in filled_tbl.columns()]
+        names = [ne.name for ne in named_exprs]
+        dtypes = [ne.value.dtype for ne in named_exprs]
+        return names, dtypes, tables
+
+    @_apply_unary_op.register
+    def _(  # type: ignore[no-untyped-def]
+        self,
+        op: CumSumOp,
+        df: DataFrame,
+        _,
+    ) -> tuple[list[str], list[DataType], list[plc.Table]]:
+        cum_named = op.named_exprs
+        order_index = op.order_index
+
+        requests: list[plc.groupby.GroupByRequest] = []
+        out_names: list[str] = []
+        out_dtypes: list[DataType] = []
+
+        # Instead of calling self._gather_columns, let's call plc.copying.gather directly
+        # since we need plc.Column objects, not cudf_polars Column objects
+        if order_index is not None:
+            plc_cols = [
+                ne.value.children[0].evaluate(df, context=ExecutionContext.FRAME).obj
+                for ne in cum_named
+            ]
+            val_cols = plc.copying.gather(
+                plc.Table(plc_cols),
+                order_index,
+                plc.copying.OutOfBoundsPolicy.NULLIFY,
+                stream=df.stream,
+            ).columns()
+        else:
+            val_cols = [
+                ne.value.children[0].evaluate(df, context=ExecutionContext.FRAME).obj
+                for ne in cum_named
+            ]
+        agg = plc.aggregation.sum()
+
+        for ne, val_col in zip(cum_named, val_cols, strict=True):
+            requests.append(plc.groupby.GroupByRequest(val_col, [agg]))
+            out_names.append(ne.name)
+            out_dtypes.append(ne.value.dtype)
+
+        local_grouper = op.local_grouper
+        assert isinstance(local_grouper, plc.groupby.GroupBy)
+        _, tables = local_grouper.scan(requests)
+
+        return out_names, out_dtypes, tables
+
     def _reorder_to_input(
         self,
         row_id: plc.Column,
```
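The `FillNullWithStrategyOp` and `CumSumOp` handlers registered above follow the same shape as the existing `RankOp` handler: each grouped window operation is described by a small frozen dataclass, and `_apply_unary_op` is a `functools.singledispatchmethod` that picks a handler from the type annotation of its `op` parameter. A minimal, self-contained sketch of that dispatch pattern, with hypothetical fields and return values rather than the real op classes:

```python
from dataclasses import dataclass
from functools import singledispatchmethod


@dataclass(frozen=True)
class UnaryOp:
    column: str  # hypothetical payload; the real ops carry named_exprs etc.


@dataclass(frozen=True)
class RankOp(UnaryOp):
    pass


@dataclass(frozen=True)
class CumSumOp(UnaryOp):
    pass


class Window:
    @singledispatchmethod
    def _apply_unary_op(self, op: UnaryOp) -> str:
        raise NotImplementedError(type(op).__name__)

    @_apply_unary_op.register
    def _(self, op: RankOp) -> str:
        # dispatched here because the annotation on `op` is RankOp
        return f"rank({op.column})"

    @_apply_unary_op.register
    def _(self, op: CumSumOp) -> str:
        return f"cum_sum({op.column})"


w = Window()
print(w._apply_unary_op(RankOp("x")))    # rank(x)
print(w._apply_unary_op(CumSumOp("x")))  # cum_sum(x)
```

Adding a new grouped window operation then only needs a new op dataclass plus one registered handler, which is how this diff extends the previous rank-only support to forward/backward fill and cumulative sum.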
```diff
@@ -361,6 +455,7 @@ class GroupedRollingWindow(Expr):
         rank_out_dtypes: list[DataType],
         *,
         order_index: plc.Column | None = None,
+        stream: Stream,
     ) -> list[Column]:
         # Reorder scan results from grouped-order back to input row order
         if order_index is None:
@@ -370,6 +465,7 @@ class GroupedRollingWindow(Expr):
                 plc.Table([*(c.obj for c in by_cols), row_id]),
                 [*key_orders, plc.types.Order.ASCENDING],
                 [*key_nulls, plc.types.NullOrder.AFTER],
+                stream=stream,
             )
 
         return [
@@ -380,11 +476,15 @@ class GroupedRollingWindow(Expr):
                 plc.Table(
                     [
                         plc.Column.from_scalar(
-                            plc.Scalar.from_py(None, tbl.columns()[0].type()),
+                            plc.Scalar.from_py(
+                                None, tbl.columns()[0].type(), stream=stream
+                            ),
                             n_rows,
+                            stream=stream,
                         )
                     ]
                 ),
+                stream=stream,
             ).columns()[0],
             name=name,
             dtype=dtype,
@@ -401,6 +501,8 @@ class GroupedRollingWindow(Expr):
         reductions: list[expr.NamedExpr] = []
         unary_window_ops: dict[str, list[expr.NamedExpr]] = {
             "rank": [],
+            "fill_null_with_strategy": [],
+            "cum_sum": [],
         }
 
         for ne in self.named_aggs:
@@ -421,6 +523,7 @@ class GroupedRollingWindow(Expr):
         ob_nulls_last: bool,
         value_col: plc.Column | None = None,
         value_desc: bool = False,
+        stream: Stream,
     ) -> plc.Column:
         """Compute a stable row ordering for unary operations in a grouped context."""
         cols: list[plc.Column] = [c.obj for c in by_cols]
@@ -444,7 +547,7 @@ class GroupedRollingWindow(Expr):
             )
             nulls.append(
                 plc.types.NullOrder.AFTER
-                if ob_nulls_last
+                if ob_desc ^ ob_nulls_last
                 else plc.types.NullOrder.BEFORE
             )
 
@@ -453,31 +556,153 @@ class GroupedRollingWindow(Expr):
             orders.append(plc.types.Order.ASCENDING)
             nulls.append(plc.types.NullOrder.AFTER)
 
-        return plc.sorting.stable_sorted_order(plc.Table(cols), orders, nulls)
+        return plc.sorting.stable_sorted_order(
+            plc.Table(cols), orders, nulls, stream=stream
+        )
 
     def _gather_columns(
-        self,
-        cols: Sequence[Column],
-        order_index: plc.Column,
-    ) -> list[plc.Column] | list[Column]:
+        self, cols: Sequence[Column], order_index: plc.Column, stream: Stream
+    ) -> list[Column]:
         gathered_tbl = plc.copying.gather(
             plc.Table([c.obj for c in cols]),
             order_index,
             plc.copying.OutOfBoundsPolicy.NULLIFY,
+            stream=stream,
         )
 
         return [
             Column(
-
+                gathered,
                 name=c.name,
                 dtype=c.dtype,
                 order=c.order,
                 null_order=c.null_order,
-                is_sorted=
+                is_sorted=c.is_sorted,
             )
-            for
+            for gathered, c in zip(gathered_tbl.columns(), cols, strict=True)
         ]
 
+    def _grouped_window_scan_setup(
+        self,
+        by_cols: list[Column],
+        *,
+        row_id: plc.Column,
+        order_by_col: Column | None,
+        ob_desc: bool,
+        ob_nulls_last: bool,
+        grouper: plc.groupby.GroupBy,
+        stream: Stream,
+    ) -> tuple[plc.Column | None, list[Column] | None, plc.groupby.GroupBy]:
+        if order_by_col is None:
+            # keep the original ordering
+            return None, None, grouper
+        order_index = self._build_window_order_index(
+            by_cols,
+            row_id=row_id,
+            order_by_col=order_by_col,
+            ob_desc=ob_desc,
+            ob_nulls_last=ob_nulls_last,
+            stream=stream,
+        )
+        by_cols_for_scan = self._gather_columns(by_cols, order_index, stream=stream)
+        assert by_cols_for_scan is not None
+        local = self._sorted_grouper(by_cols_for_scan)
+        return order_index, by_cols_for_scan, local
+
+    def _broadcast_agg_results(
+        self,
+        by_tbl: plc.Table,
+        group_keys_tbl: plc.Table,
+        value_tbls: list[plc.Table],
+        names: list[str],
+        dtypes: list[DataType],
+        stream: Stream,
+    ) -> list[Column]:
+        # We do a left-join between the input keys to group-keys
+        # so every input row appears exactly once. left_order is
+        # returned un-ordered by libcudf.
+        left_order, right_order = plc.join.left_join(
+            by_tbl, group_keys_tbl, plc.types.NullEquality.EQUAL, stream
+        )
+
+        # Scatter the right order indices into an all-null table
+        # and at the position of the index in left order. Now we
+        # have the map between rows and groups with the correct ordering.
+        left_rows = left_order.size()
+        target = plc.Column.from_scalar(
+            plc.Scalar.from_py(None, plc.types.SIZE_TYPE, stream), left_rows, stream
+        )
+        aligned_map = plc.copying.scatter(
+            plc.Table([right_order]),
+            left_order,
+            plc.Table([target]),
+            stream,
+        ).columns()[0]
+
+        # Broadcast each scalar aggregated result back to row-shape using
+        # the aligned mapping between row indices and group indices.
+        out_cols = (t.columns()[0] for t in value_tbls)
+        return [
+            Column(
+                plc.copying.gather(
+                    plc.Table([col]),
+                    aligned_map,
+                    plc.copying.OutOfBoundsPolicy.NULLIFY,
+                    stream,
+                ).columns()[0],
+                name=name,
+                dtype=dtype,
+            )
+            for name, dtype, col in zip(names, dtypes, out_cols, strict=True)
+        ]
+
+    def _build_groupby_requests(
+        self,
+        named_exprs: list[expr.NamedExpr],
+        df: DataFrame,
+        order_index: plc.Column | None = None,
+        by_cols: list[Column] | None = None,
+    ) -> tuple[list[plc.groupby.GroupByRequest], list[str], list[DataType]]:
+        assert by_cols is not None
+        gb_requests: list[plc.groupby.GroupByRequest] = []
+        out_names: list[str] = []
+        out_dtypes: list[DataType] = []
+
+        eval_cols: list[plc.Column] = []
+        val_nodes: list[tuple[expr.NamedExpr, expr.Agg]] = []
+
+        for ne in named_exprs:
+            val = ne.value
+            out_names.append(ne.name)
+            out_dtypes.append(val.dtype)
+            if isinstance(val, expr.Agg):
+                (child,) = (
+                    val.children
+                    if getattr(val, "name", None) != "quantile"
+                    else (val.children[0],)
+                )
+                eval_cols.append(child.evaluate(df, context=ExecutionContext.FRAME).obj)
+                val_nodes.append((ne, val))
+
+        if order_index is not None and eval_cols:
+            eval_cols = plc.copying.gather(
+                plc.Table(eval_cols),
+                order_index,
+                plc.copying.OutOfBoundsPolicy.NULLIFY,
+                stream=df.stream,
+            ).columns()
+
+        gathered_iter = iter(eval_cols)
+        for ne in named_exprs:
+            val = ne.value
+            if isinstance(val, expr.Len):
+                col = by_cols[0].obj
+            else:
+                col = next(gathered_iter)
+            gb_requests.append(plc.groupby.GroupByRequest(col, [val.agg_request]))
+
+        return gb_requests, out_names, out_dtypes
+
     def do_evaluate(  # noqa: D102
         self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
     ) -> Column:
```
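The new `_broadcast_agg_results` helper above maps per-group aggregate values back onto input rows in three steps, as its comments describe: left-join the input keys against the group keys, scatter the right-side (group) indices into an all-null column at the left-side (row) positions to get a row-to-group map in input order, then gather the aggregates through that map. A rough NumPy analogue of those steps, for illustration only (`searchsorted` stands in for the join):

```python
import numpy as np

keys = np.array([2, 1, 2, 3, 1])         # per-row group key (like by_tbl)
group_keys = np.array([1, 2, 3])         # unique keys from the group-by
agg_vals = np.array([10.0, 20.0, 30.0])  # one aggregated value per group

# "left join": pair each input row index with its group's index
left_order = np.arange(len(keys))
right_order = np.searchsorted(group_keys, keys)

# scatter: place each group index at its row's position, giving an
# aligned row -> group map in input row order
aligned_map = np.full(len(keys), -1)
aligned_map[left_order] = right_order

# gather: broadcast the scalar aggregates back to row shape
print(agg_vals[aligned_map])  # [20. 10. 20. 30. 10.]
```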
```diff
@@ -493,9 +718,12 @@ class GroupedRollingWindow(Expr):
         by_cols = broadcast(
             *(b.evaluate(df) for b in by_exprs),
             target_length=df.num_rows,
+            stream=df.stream,
         )
         order_by_col = (
-            broadcast(order_by_expr.evaluate(df), target_length=df.num_rows)[0]
+            broadcast(
+                order_by_expr.evaluate(df), target_length=df.num_rows, stream=df.stream
+            )[0]
             if order_by_expr is not None
             else None
         )
@@ -518,67 +746,73 @@ class GroupedRollingWindow(Expr):
         scalar_named, unary_window_ops = self._split_named_expr()
 
         # Build GroupByRequests for scalar aggregations
-        gb_requests: list[plc.groupby.GroupByRequest] = []
-        out_names: list[str] = []
-        out_dtypes: list[DataType] = []
+        order_sensitive: list[expr.NamedExpr] = []
+        other_scalars: list[expr.NamedExpr] = []
         for ne in scalar_named:
             val = ne.value
-            out_names.append(ne.name)
-            out_dtypes.append(val.dtype)
+            if (
+                self._order_by_expr is not None
+                and isinstance(val, expr.Agg)
+                and val.agg_request.kind() == plc.aggregation.Kind.NTH_ELEMENT
+            ):
+                order_sensitive.append(ne)
+            else:
+                other_scalars.append(ne)
 
-            if isinstance(val, expr.Len):
-                col = by_cols[0].obj
-
-                gb_requests.append(plc.groupby.GroupByRequest(col, [val.agg_request]))
-            elif isinstance(val, expr.Agg):
-                (child,) = (
-                    val.children if val.name != "quantile" else (val.children[0],)
-                )
-                col = child.evaluate(df, context=ExecutionContext.FRAME).obj
-                gb_requests.append(plc.groupby.GroupByRequest(col, [val.agg_request]))
+        gb_requests, out_names, out_dtypes = self._build_groupby_requests(
+            other_scalars, df, by_cols=by_cols
+        )
 
         group_keys_tbl, value_tables = grouper.aggregate(gb_requests)
-
-
-
-
-
-
-
+        broadcasted_cols = self._broadcast_agg_results(
+            by_tbl,
+            group_keys_tbl,
+            value_tables,
+            out_names,
+            out_dtypes,
+            df.stream,
+        )
 
-
-
-
-
-
-
-
-
-
-
-
-
+        if order_sensitive:
+            row_id = plc.filling.sequence(
+                df.num_rows,
+                plc.Scalar.from_py(0, plc.types.SIZE_TYPE, stream=df.stream),
+                plc.Scalar.from_py(1, plc.types.SIZE_TYPE, stream=df.stream),
+                stream=df.stream,
+            )
+            _, _, ob_desc, ob_nulls_last = self.options
+            order_index, _, local = self._grouped_window_scan_setup(
+                by_cols,
+                row_id=row_id,
+                order_by_col=order_by_col,
+                ob_desc=ob_desc,
+                ob_nulls_last=ob_nulls_last,
+                grouper=grouper,
+                stream=df.stream,
+            )
+            assert order_index is not None
 
-
-
-        broadcasted_cols = [
-            Column(
-                plc.copying.gather(
-                    plc.Table([col]), aligned_map, plc.copying.OutOfBoundsPolicy.NULLIFY
-                ).columns()[0],
-                name=named_expr.name,
-                dtype=dtype,
+            gb_requests, out_names, out_dtypes = self._build_groupby_requests(
+                order_sensitive, df, order_index=order_index, by_cols=by_cols
             )
-
-
+
+            group_keys_tbl_local, value_tables_local = local.aggregate(gb_requests)
+            broadcasted_cols.extend(
+                self._broadcast_agg_results(
+                    by_tbl,
+                    group_keys_tbl_local,
+                    value_tables_local,
+                    out_names,
+                    out_dtypes,
+                    df.stream,
+                )
            )
-        ]
 
         row_id = plc.filling.sequence(
             df.num_rows,
-            plc.Scalar.from_py(0, plc.types.SIZE_TYPE),
-            plc.Scalar.from_py(1, plc.types.SIZE_TYPE),
+            plc.Scalar.from_py(0, plc.types.SIZE_TYPE, stream=df.stream),
+            plc.Scalar.from_py(1, plc.types.SIZE_TYPE, stream=df.stream),
+            stream=df.stream,
         )
 
         if rank_named := unary_window_ops["rank"]:
@@ -588,7 +822,6 @@ class GroupedRollingWindow(Expr):
                 rank_expr = ne.value
                 assert isinstance(rank_expr, expr.UnaryFunction)
                 (child,) = rank_expr.children
-                val = child.evaluate(df, context=ExecutionContext.FRAME).obj
                 desc = rank_expr.options[1]
 
                 order_index = self._build_window_order_index(
@@ -597,16 +830,21 @@ class GroupedRollingWindow(Expr):
                     order_by_col=order_by_col,
                     ob_desc=ob_desc,
                     ob_nulls_last=ob_nulls_last,
-                    value_col=val,
+                    value_col=child.evaluate(
+                        df, context=ExecutionContext.FRAME
+                    ).obj,
                     value_desc=desc,
+                    stream=df.stream,
                 )
-                by_cols_for_scan = self._gather_columns(by_cols, order_index)
-                local = GroupedRollingWindow._sorted_grouper(by_cols_for_scan)
+                rank_by_cols_for_scan = self._gather_columns(
+                    by_cols, order_index, stream=df.stream
+                )
+                local = GroupedRollingWindow._sorted_grouper(rank_by_cols_for_scan)
                 names, dtypes, tables = self._apply_unary_op(
                     RankOp(
                         named_exprs=[ne],
                         order_index=order_index,
-                        by_cols_for_scan=by_cols_for_scan,
+                        by_cols_for_scan=rank_by_cols_for_scan,
                         local_grouper=local,
                     ),
                     df,
@@ -621,6 +859,7 @@ class GroupedRollingWindow(Expr):
                         names,
                         dtypes,
                         order_index=order_index,
+                        stream=df.stream,
                     )
                 )
             else:
@@ -633,11 +872,113 @@ class GroupedRollingWindow(Expr):
                 )
                 broadcasted_cols.extend(
                     self._reorder_to_input(
-                        row_id, by_cols, df.num_rows, tables, names, dtypes
+                        row_id,
+                        by_cols,
+                        df.num_rows,
+                        tables,
+                        names,
+                        dtypes,
+                        stream=df.stream,
+                    )
+                )
+
+        if fill_named := unary_window_ops["fill_null_with_strategy"]:
+            order_index, fill_null_by_cols_for_scan, local = (
+                self._grouped_window_scan_setup(
+                    by_cols,
+                    row_id=row_id,
+                    order_by_col=order_by_col
+                    if self._order_by_expr is not None
+                    else None,
+                    ob_desc=self.options[2]
+                    if self._order_by_expr is not None
+                    else False,
+                    ob_nulls_last=self.options[3]
+                    if self._order_by_expr is not None
+                    else False,
+                    grouper=grouper,
+                    stream=df.stream,
+                )
+            )
+
+            strategy_exprs: dict[str, list[expr.NamedExpr]] = defaultdict(list)
+            for ne in fill_named:
+                fill_null_expr = ne.value
+                assert isinstance(fill_null_expr, expr.UnaryFunction)
+                strategy_exprs[fill_null_expr.options[0]].append(ne)
+
+            replace_policy = {
+                "forward": plc.replace.ReplacePolicy.PRECEDING,
+                "backward": plc.replace.ReplacePolicy.FOLLOWING,
+            }
+
+            for strategy, fill_exprs in strategy_exprs.items():
+                names, dtypes, tables = self._apply_unary_op(
+                    FillNullWithStrategyOp(
+                        named_exprs=fill_exprs,
+                        order_index=order_index,
+                        by_cols_for_scan=fill_null_by_cols_for_scan,
+                        local_grouper=local,
+                        policy=replace_policy[strategy],
+                    ),
+                    df,
+                    grouper,
+                )
+                broadcasted_cols.extend(
+                    self._reorder_to_input(
+                        row_id,
+                        by_cols,
+                        df.num_rows,
+                        tables,
+                        names,
+                        dtypes,
+                        order_index=order_index,
+                        stream=df.stream,
                     )
                 )
 
+        if cum_named := unary_window_ops["cum_sum"]:
+            order_index, cum_sum_by_cols_for_scan, local = (
+                self._grouped_window_scan_setup(
+                    by_cols,
+                    row_id=row_id,
+                    order_by_col=order_by_col
+                    if self._order_by_expr is not None
+                    else None,
+                    ob_desc=self.options[2]
+                    if self._order_by_expr is not None
+                    else False,
+                    ob_nulls_last=self.options[3]
+                    if self._order_by_expr is not None
+                    else False,
+                    grouper=grouper,
+                    stream=df.stream,
+                )
+            )
+            names, dtypes, tables = self._apply_unary_op(
+                CumSumOp(
+                    named_exprs=cum_named,
+                    order_index=order_index,
+                    by_cols_for_scan=cum_sum_by_cols_for_scan,
+                    local_grouper=local,
+                ),
+                df,
+                grouper,
+            )
+            broadcasted_cols.extend(
+                self._reorder_to_input(
+                    row_id,
+                    by_cols,
+                    df.num_rows,
+                    tables,
+                    names,
+                    dtypes,
+                    order_index=order_index,
+                    stream=df.stream,
+                )
+            )
+
         # Create a temporary DataFrame with the broadcasted columns named by their
         # placeholder names from agg decomposition, then evaluate the post-expression.
-        df = DataFrame(broadcasted_cols)
+        df = DataFrame(broadcasted_cols, stream=df.stream)
         return self.post.value.evaluate(df, context=ExecutionContext.FRAME)
```