cudf-polars-cu13 25.10.0__py3-none-any.whl → 26.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. cudf_polars/GIT_COMMIT +1 -1
  2. cudf_polars/VERSION +1 -1
  3. cudf_polars/callback.py +60 -15
  4. cudf_polars/containers/column.py +137 -77
  5. cudf_polars/containers/dataframe.py +123 -34
  6. cudf_polars/containers/datatype.py +134 -13
  7. cudf_polars/dsl/expr.py +0 -2
  8. cudf_polars/dsl/expressions/aggregation.py +80 -28
  9. cudf_polars/dsl/expressions/binaryop.py +34 -14
  10. cudf_polars/dsl/expressions/boolean.py +110 -37
  11. cudf_polars/dsl/expressions/datetime.py +59 -30
  12. cudf_polars/dsl/expressions/literal.py +11 -5
  13. cudf_polars/dsl/expressions/rolling.py +460 -119
  14. cudf_polars/dsl/expressions/selection.py +9 -8
  15. cudf_polars/dsl/expressions/slicing.py +1 -1
  16. cudf_polars/dsl/expressions/string.py +256 -114
  17. cudf_polars/dsl/expressions/struct.py +19 -7
  18. cudf_polars/dsl/expressions/ternary.py +33 -3
  19. cudf_polars/dsl/expressions/unary.py +126 -64
  20. cudf_polars/dsl/ir.py +1053 -350
  21. cudf_polars/dsl/to_ast.py +30 -13
  22. cudf_polars/dsl/tracing.py +194 -0
  23. cudf_polars/dsl/translate.py +307 -107
  24. cudf_polars/dsl/utils/aggregations.py +43 -30
  25. cudf_polars/dsl/utils/reshape.py +14 -2
  26. cudf_polars/dsl/utils/rolling.py +12 -8
  27. cudf_polars/dsl/utils/windows.py +35 -20
  28. cudf_polars/experimental/base.py +55 -2
  29. cudf_polars/experimental/benchmarks/pdsds.py +12 -126
  30. cudf_polars/experimental/benchmarks/pdsh.py +792 -2
  31. cudf_polars/experimental/benchmarks/utils.py +596 -39
  32. cudf_polars/experimental/dask_registers.py +47 -20
  33. cudf_polars/experimental/dispatch.py +9 -3
  34. cudf_polars/experimental/distinct.py +2 -0
  35. cudf_polars/experimental/explain.py +15 -2
  36. cudf_polars/experimental/expressions.py +30 -15
  37. cudf_polars/experimental/groupby.py +25 -4
  38. cudf_polars/experimental/io.py +156 -124
  39. cudf_polars/experimental/join.py +53 -23
  40. cudf_polars/experimental/parallel.py +68 -19
  41. cudf_polars/experimental/rapidsmpf/__init__.py +8 -0
  42. cudf_polars/experimental/rapidsmpf/collectives/__init__.py +9 -0
  43. cudf_polars/experimental/rapidsmpf/collectives/allgather.py +90 -0
  44. cudf_polars/experimental/rapidsmpf/collectives/common.py +96 -0
  45. cudf_polars/experimental/rapidsmpf/collectives/shuffle.py +253 -0
  46. cudf_polars/experimental/rapidsmpf/core.py +488 -0
  47. cudf_polars/experimental/rapidsmpf/dask.py +172 -0
  48. cudf_polars/experimental/rapidsmpf/dispatch.py +153 -0
  49. cudf_polars/experimental/rapidsmpf/io.py +696 -0
  50. cudf_polars/experimental/rapidsmpf/join.py +322 -0
  51. cudf_polars/experimental/rapidsmpf/lower.py +74 -0
  52. cudf_polars/experimental/rapidsmpf/nodes.py +735 -0
  53. cudf_polars/experimental/rapidsmpf/repartition.py +216 -0
  54. cudf_polars/experimental/rapidsmpf/union.py +115 -0
  55. cudf_polars/experimental/rapidsmpf/utils.py +374 -0
  56. cudf_polars/experimental/repartition.py +9 -2
  57. cudf_polars/experimental/select.py +177 -14
  58. cudf_polars/experimental/shuffle.py +46 -12
  59. cudf_polars/experimental/sort.py +100 -26
  60. cudf_polars/experimental/spilling.py +1 -1
  61. cudf_polars/experimental/statistics.py +24 -5
  62. cudf_polars/experimental/utils.py +25 -7
  63. cudf_polars/testing/asserts.py +13 -8
  64. cudf_polars/testing/io.py +2 -1
  65. cudf_polars/testing/plugin.py +93 -17
  66. cudf_polars/typing/__init__.py +86 -32
  67. cudf_polars/utils/config.py +473 -58
  68. cudf_polars/utils/cuda_stream.py +70 -0
  69. cudf_polars/utils/versions.py +5 -4
  70. cudf_polars_cu13-26.2.0.dist-info/METADATA +181 -0
  71. cudf_polars_cu13-26.2.0.dist-info/RECORD +108 -0
  72. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/WHEEL +1 -1
  73. cudf_polars_cu13-25.10.0.dist-info/METADATA +0 -136
  74. cudf_polars_cu13-25.10.0.dist-info/RECORD +0 -92
  75. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/licenses/LICENSE +0 -0
  76. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/top_level.txt +0 -0
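The hunks reproduced below cover cudf_polars/dsl/expressions/rolling.py (entry 13 above, +460 -119); the diffs for the remaining files are omitted from this excerpt.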
@@ -6,23 +6,27 @@

 from __future__ import annotations

-import itertools
+from collections import defaultdict
 from dataclasses import dataclass
 from functools import singledispatchmethod
 from typing import TYPE_CHECKING, Any

-import polars as pl
-
 import pylibcudf as plc

 from cudf_polars.containers import Column, DataFrame, DataType
 from cudf_polars.dsl import expr
 from cudf_polars.dsl.expressions.base import ExecutionContext, Expr
 from cudf_polars.dsl.utils.reshape import broadcast
-from cudf_polars.dsl.utils.windows import offsets_to_windows, range_window_bounds
+from cudf_polars.dsl.utils.windows import (
+    duration_to_int,
+    offsets_to_windows,
+    range_window_bounds,
+)

 if TYPE_CHECKING:
-    from collections.abc import Generator, Sequence
+    from collections.abc import Sequence
+
+    from rmm.pylibrmm.stream import Stream

     from cudf_polars.typing import ClosedInterval, Duration

@@ -42,6 +46,16 @@ class RankOp(UnaryOp):
     pass


+@dataclass(frozen=True)
+class FillNullWithStrategyOp(UnaryOp):
+    policy: plc.replace.ReplacePolicy = plc.replace.ReplacePolicy.PRECEDING
+
+
+@dataclass(frozen=True)
+class CumSumOp(UnaryOp):
+    pass
+
+
 def to_request(
     value: expr.Expr, orderby: Column, df: DataFrame
 ) -> plc.rolling.RollingRequest:
@@ -78,26 +92,15 @@ def to_request(
     return plc.rolling.RollingRequest(col.obj, min_periods, value.agg_request)


-def _by_exprs(b: Expr | tuple) -> Generator[Expr]:
-    if isinstance(b, Expr):
-        yield b
-    elif isinstance(b, tuple):  # pragma: no cover; tests cover this path when
-        # run with the distributed scheduler only
-        for item in b:
-            yield from _by_exprs(item)
-    else:
-        yield expr.Literal(DataType(pl.Int64()), b)  # pragma: no cover
-
-
 class RollingWindow(Expr):
     __slots__ = (
         "closed_window",
-        "following",
+        "following_ordinal",
         "offset",
         "orderby",
         "orderby_dtype",
         "period",
-        "preceding",
+        "preceding_ordinal",
     )
     _non_child = (
         "dtype",
@@ -111,7 +114,7 @@ class RollingWindow(Expr):
     def __init__(
         self,
         dtype: DataType,
-        orderby_dtype: DataType,
+        orderby_dtype: plc.DataType,
         offset: Duration,
         period: Duration,
         closed_window: ClosedInterval,
@@ -126,9 +129,11 @@
         # within `__init__`).
         self.offset = offset
         self.period = period
-        self.preceding, self.following = offsets_to_windows(
-            orderby_dtype, offset, period
-        )
+        self.orderby_dtype = orderby_dtype
+        self.offset = offset
+        self.period = period
+        self.preceding_ordinal = duration_to_int(orderby_dtype, *offset)
+        self.following_ordinal = duration_to_int(orderby_dtype, *period)
         self.closed_window = closed_window
         self.orderby = orderby
         self.children = (agg,)
@@ -137,7 +142,9 @@
             raise NotImplementedError(
                 "Incorrect handling of empty groups for list collection"
            )
-        if not plc.rolling.is_valid_rolling_aggregation(agg.dtype.plc, agg.agg_request):
+        if not plc.rolling.is_valid_rolling_aggregation(
+            agg.dtype.plc_type, agg.agg_request
+        ):
             raise NotImplementedError(f"Unsupported rolling aggregation {agg}")

     def do_evaluate(  # noqa: D102
@@ -154,18 +161,28 @@ class RollingWindow(Expr):
             plc.traits.is_integral(orderby.obj.type())
             and orderby.obj.type().id() != plc.TypeId.INT64
         ):
-            orderby_obj = plc.unary.cast(orderby.obj, plc.DataType(plc.TypeId.INT64))
+            orderby_obj = plc.unary.cast(
+                orderby.obj, plc.DataType(plc.TypeId.INT64), stream=df.stream
+            )
         else:
             orderby_obj = orderby.obj
+        preceding_scalar, following_scalar = offsets_to_windows(
+            self.orderby_dtype,
+            self.preceding_ordinal,
+            self.following_ordinal,
+            stream=df.stream,
+        )
         preceding, following = range_window_bounds(
-            self.preceding, self.following, self.closed_window
+            preceding_scalar, following_scalar, self.closed_window
         )
         if orderby.obj.null_count() != 0:
             raise RuntimeError(
                 f"Index column '{self.orderby}' in rolling may not contain nulls"
             )
         if not orderby.check_sorted(
-            order=plc.types.Order.ASCENDING, null_order=plc.types.NullOrder.BEFORE
+            order=plc.types.Order.ASCENDING,
+            null_order=plc.types.NullOrder.BEFORE,
+            stream=df.stream,
         ):
             raise RuntimeError(
                 f"Index column '{self.orderby}' in rolling is not sorted, please sort first"
@@ -178,6 +195,7 @@ class RollingWindow(Expr):
             preceding,
             following,
             [to_request(agg, orderby, df)],
+            stream=df.stream,
         ).columns()
         return Column(result, dtype=self.dtype)

@@ -207,7 +225,6 @@ class GroupedRollingWindow(Expr):
         "named_aggs",
         "post",
         "by_count",
-        "_order_by_expr",
     )

     def __init__(
@@ -216,15 +233,18 @@
         options: Any,
         named_aggs: Sequence[expr.NamedExpr],
         post: expr.NamedExpr,
-        *by: Expr,
-        _order_by_expr: Expr | None = None,
+        by_count: int,
+        *children: Expr,
     ) -> None:
         self.dtype = dtype
         self.options = options
         self.named_aggs = tuple(named_aggs)
         self.post = post
+        self.by_count = by_count
+        has_order_by = self.options[1]
         self.is_pointwise = False
-        self._order_by_expr = _order_by_expr
+        self.children = tuple(children)
+        self._order_by_expr = children[by_count] if has_order_by else None

         unsupported = [
             type(named_expr.value).__name__
@@ -233,7 +253,8 @@
                 isinstance(named_expr.value, (expr.Len, expr.Agg))
                 or (
                     isinstance(named_expr.value, expr.UnaryFunction)
-                    and named_expr.value.name in {"rank"}
+                    and named_expr.value.name
+                    in {"rank", "fill_null_with_strategy", "cum_sum"}
                 )
             )
         ]
@@ -242,29 +263,22 @@
             raise NotImplementedError(
                 f"Unsupported over(...) only expression: {kinds}="
             )
-
-        # Ensures every partition-by is an Expr
-        # Fixes over(1) cases with the streaming
-        # executor and a small blocksize
-        by_expr = [e for b in by for e in _by_exprs(b)]
-
-        # Expose agg dependencies as children so the streaming
-        # executor retains required source columns
-        child_deps = [
-            v.children[0]
-            for ne in self.named_aggs
-            for v in (ne.value,)
-            if isinstance(v, expr.Agg)
-            or (isinstance(v, expr.UnaryFunction) and v.name in {"rank"})
-        ]
-        self.by_count = len(by_expr)
-        self.children = tuple(
-            itertools.chain(
-                by_expr,
-                (() if self._order_by_expr is None else (self._order_by_expr,)),
-                child_deps,
+        if has_order_by:
+            ob = self._order_by_expr
+            is_multi_order_by = (
+                isinstance(ob, expr.UnaryFunction)
+                and ob.name == "as_struct"
+                and len(ob.children) > 1
             )
-        )
+            has_order_sensitive_agg = any(
+                isinstance(ne.value, expr.Agg)
+                and ne.value.agg_request.kind() == plc.aggregation.Kind.NTH_ELEMENT
+                for ne in self.named_aggs
+            )
+            if is_multi_order_by and has_order_sensitive_agg:
+                raise NotImplementedError(
+                    "Multiple order_by keys with order-sensitive aggregations"
+                )

     @staticmethod
     def _sorted_grouper(by_cols_for_scan: list[Column]) -> plc.groupby.GroupBy:
@@ -311,6 +325,7 @@ class GroupedRollingWindow(Expr):
             plc.Table([val_col]),
             order_index,
             plc.copying.OutOfBoundsPolicy.NULLIFY,
+            stream=df.stream,
         ).columns()[0]
         assert isinstance(rank_expr, expr.UnaryFunction)
         method_str, descending, _ = rank_expr.options
@@ -351,6 +366,85 @@
         _, rank_tables = grouper.scan(rank_requests)
         return rank_out_names, rank_out_dtypes, rank_tables

+    @_apply_unary_op.register
+    def _(  # type: ignore[no-untyped-def]
+        self,
+        op: FillNullWithStrategyOp,
+        df: DataFrame,
+        _,
+    ) -> tuple[list[str], list[DataType], list[plc.Table]]:
+        named_exprs = op.named_exprs
+
+        plc_cols = [
+            ne.value.children[0].evaluate(df, context=ExecutionContext.FRAME).obj
+            for ne in named_exprs
+        ]
+        if op.order_index is not None:
+            vals_tbl = plc.copying.gather(
+                plc.Table(plc_cols),
+                op.order_index,
+                plc.copying.OutOfBoundsPolicy.NULLIFY,
+                stream=df.stream,
+            )
+        else:
+            vals_tbl = plc.Table(plc_cols)
+        local_grouper = op.local_grouper
+        assert isinstance(local_grouper, plc.groupby.GroupBy)
+        _, filled_tbl = local_grouper.replace_nulls(
+            vals_tbl,
+            [op.policy] * len(plc_cols),
+        )
+
+        tables = [plc.Table([column]) for column in filled_tbl.columns()]
+        names = [ne.name for ne in named_exprs]
+        dtypes = [ne.value.dtype for ne in named_exprs]
+        return names, dtypes, tables
+
+    @_apply_unary_op.register
+    def _(  # type: ignore[no-untyped-def]
+        self,
+        op: CumSumOp,
+        df: DataFrame,
+        _,
+    ) -> tuple[list[str], list[DataType], list[plc.Table]]:
+        cum_named = op.named_exprs
+        order_index = op.order_index
+
+        requests: list[plc.groupby.GroupByRequest] = []
+        out_names: list[str] = []
+        out_dtypes: list[DataType] = []
+
+        # Instead of calling self._gather_columns, let's call plc.copying.gather directly
+        # since we need plc.Column objects, not cudf_polars Column objects
+        if order_index is not None:
+            plc_cols = [
+                ne.value.children[0].evaluate(df, context=ExecutionContext.FRAME).obj
+                for ne in cum_named
+            ]
+            val_cols = plc.copying.gather(
+                plc.Table(plc_cols),
+                order_index,
+                plc.copying.OutOfBoundsPolicy.NULLIFY,
+                stream=df.stream,
+            ).columns()
+        else:
+            val_cols = [
+                ne.value.children[0].evaluate(df, context=ExecutionContext.FRAME).obj
+                for ne in cum_named
+            ]
+        agg = plc.aggregation.sum()
+
+        for ne, val_col in zip(cum_named, val_cols, strict=True):
+            requests.append(plc.groupby.GroupByRequest(val_col, [agg]))
+            out_names.append(ne.name)
+            out_dtypes.append(ne.value.dtype)
+
+        local_grouper = op.local_grouper
+        assert isinstance(local_grouper, plc.groupby.GroupBy)
+        _, tables = local_grouper.scan(requests)
+
+        return out_names, out_dtypes, tables
+
     def _reorder_to_input(
         self,
         row_id: plc.Column,
@@ -361,6 +455,7 @@
         rank_out_dtypes: list[DataType],
         *,
         order_index: plc.Column | None = None,
+        stream: Stream,
     ) -> list[Column]:
         # Reorder scan results from grouped-order back to input row order
         if order_index is None:
@@ -370,6 +465,7 @@ class GroupedRollingWindow(Expr):
                 plc.Table([*(c.obj for c in by_cols), row_id]),
                 [*key_orders, plc.types.Order.ASCENDING],
                 [*key_nulls, plc.types.NullOrder.AFTER],
+                stream=stream,
             )

         return [
@@ -380,11 +476,15 @@
                 plc.Table(
                     [
                         plc.Column.from_scalar(
-                            plc.Scalar.from_py(None, tbl.columns()[0].type()),
+                            plc.Scalar.from_py(
+                                None, tbl.columns()[0].type(), stream=stream
+                            ),
                             n_rows,
+                            stream=stream,
                         )
                     ]
                 ),
+                stream=stream,
             ).columns()[0],
             name=name,
             dtype=dtype,
@@ -401,6 +501,8 @@
         reductions: list[expr.NamedExpr] = []
         unary_window_ops: dict[str, list[expr.NamedExpr]] = {
             "rank": [],
+            "fill_null_with_strategy": [],
+            "cum_sum": [],
         }

         for ne in self.named_aggs:
@@ -421,6 +523,7 @@
         ob_nulls_last: bool,
         value_col: plc.Column | None = None,
         value_desc: bool = False,
+        stream: Stream,
     ) -> plc.Column:
         """Compute a stable row ordering for unary operations in a grouped context."""
         cols: list[plc.Column] = [c.obj for c in by_cols]
@@ -444,7 +547,7 @@
             )
             nulls.append(
                 plc.types.NullOrder.AFTER
-                if ob_nulls_last
+                if ob_desc ^ ob_nulls_last
                 else plc.types.NullOrder.BEFORE
             )

@@ -453,31 +556,153 @@
             orders.append(plc.types.Order.ASCENDING)
             nulls.append(plc.types.NullOrder.AFTER)

-        return plc.sorting.stable_sorted_order(plc.Table(cols), orders, nulls)
+        return plc.sorting.stable_sorted_order(
+            plc.Table(cols), orders, nulls, stream=stream
+        )

     def _gather_columns(
-        self,
-        cols: list[Column],
-        order_index: plc.Column,
-    ) -> list[plc.Column] | list[Column]:
+        self, cols: Sequence[Column], order_index: plc.Column, stream: Stream
+    ) -> list[Column]:
         gathered_tbl = plc.copying.gather(
             plc.Table([c.obj for c in cols]),
             order_index,
             plc.copying.OutOfBoundsPolicy.NULLIFY,
+            stream=stream,
         )

         return [
             Column(
-                gathered_tbl.columns()[i],
+                gathered,
                 name=c.name,
                 dtype=c.dtype,
                 order=c.order,
                 null_order=c.null_order,
-                is_sorted=True,
+                is_sorted=c.is_sorted,
             )
-            for i, c in enumerate(cols)
+            for gathered, c in zip(gathered_tbl.columns(), cols, strict=True)
         ]

+    def _grouped_window_scan_setup(
+        self,
+        by_cols: list[Column],
+        *,
+        row_id: plc.Column,
+        order_by_col: Column | None,
+        ob_desc: bool,
+        ob_nulls_last: bool,
+        grouper: plc.groupby.GroupBy,
+        stream: Stream,
+    ) -> tuple[plc.Column | None, list[Column] | None, plc.groupby.GroupBy]:
+        if order_by_col is None:
+            # keep the original ordering
+            return None, None, grouper
+        order_index = self._build_window_order_index(
+            by_cols,
+            row_id=row_id,
+            order_by_col=order_by_col,
+            ob_desc=ob_desc,
+            ob_nulls_last=ob_nulls_last,
+            stream=stream,
+        )
+        by_cols_for_scan = self._gather_columns(by_cols, order_index, stream=stream)
+        assert by_cols_for_scan is not None
+        local = self._sorted_grouper(by_cols_for_scan)
+        return order_index, by_cols_for_scan, local
+
+    def _broadcast_agg_results(
+        self,
+        by_tbl: plc.Table,
+        group_keys_tbl: plc.Table,
+        value_tbls: list[plc.Table],
+        names: list[str],
+        dtypes: list[DataType],
+        stream: Stream,
+    ) -> list[Column]:
+        # We do a left-join between the input keys to group-keys
+        # so every input row appears exactly once. left_order is
+        # returned un-ordered by libcudf.
+        left_order, right_order = plc.join.left_join(
+            by_tbl, group_keys_tbl, plc.types.NullEquality.EQUAL, stream
+        )
+
+        # Scatter the right order indices into an all-null table
+        # and at the position of the index in left order. Now we
+        # have the map between rows and groups with the correct ordering.
+        left_rows = left_order.size()
+        target = plc.Column.from_scalar(
+            plc.Scalar.from_py(None, plc.types.SIZE_TYPE, stream), left_rows, stream
+        )
+        aligned_map = plc.copying.scatter(
+            plc.Table([right_order]),
+            left_order,
+            plc.Table([target]),
+            stream,
+        ).columns()[0]
+
+        # Broadcast each scalar aggregated result back to row-shape using
+        # the aligned mapping between row indices and group indices.
+        out_cols = (t.columns()[0] for t in value_tbls)
+        return [
+            Column(
+                plc.copying.gather(
+                    plc.Table([col]),
+                    aligned_map,
+                    plc.copying.OutOfBoundsPolicy.NULLIFY,
+                    stream,
+                ).columns()[0],
+                name=name,
+                dtype=dtype,
+            )
+            for name, dtype, col in zip(names, dtypes, out_cols, strict=True)
+        ]
+
+    def _build_groupby_requests(
+        self,
+        named_exprs: list[expr.NamedExpr],
+        df: DataFrame,
+        order_index: plc.Column | None = None,
+        by_cols: list[Column] | None = None,
+    ) -> tuple[list[plc.groupby.GroupByRequest], list[str], list[DataType]]:
+        assert by_cols is not None
+        gb_requests: list[plc.groupby.GroupByRequest] = []
+        out_names: list[str] = []
+        out_dtypes: list[DataType] = []
+
+        eval_cols: list[plc.Column] = []
+        val_nodes: list[tuple[expr.NamedExpr, expr.Agg]] = []
+
+        for ne in named_exprs:
+            val = ne.value
+            out_names.append(ne.name)
+            out_dtypes.append(val.dtype)
+            if isinstance(val, expr.Agg):
+                (child,) = (
+                    val.children
+                    if getattr(val, "name", None) != "quantile"
+                    else (val.children[0],)
+                )
+                eval_cols.append(child.evaluate(df, context=ExecutionContext.FRAME).obj)
+                val_nodes.append((ne, val))
+
+        if order_index is not None and eval_cols:
+            eval_cols = plc.copying.gather(
+                plc.Table(eval_cols),
+                order_index,
+                plc.copying.OutOfBoundsPolicy.NULLIFY,
+                stream=df.stream,
+            ).columns()
+
+        gathered_iter = iter(eval_cols)
+        for ne in named_exprs:
+            val = ne.value
+            if isinstance(val, expr.Len):
+                col = by_cols[0].obj
+            else:
+                col = next(gathered_iter)
+            gb_requests.append(plc.groupby.GroupByRequest(col, [val.agg_request]))
+
+        return gb_requests, out_names, out_dtypes
+
     def do_evaluate(  # noqa: D102
         self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
     ) -> Column:
@@ -493,9 +718,12 @@
         by_cols = broadcast(
             *(b.evaluate(df) for b in by_exprs),
             target_length=df.num_rows,
+            stream=df.stream,
         )
         order_by_col = (
-            broadcast(order_by_expr.evaluate(df), target_length=df.num_rows)[0]
+            broadcast(
+                order_by_expr.evaluate(df), target_length=df.num_rows, stream=df.stream
+            )[0]
             if order_by_expr is not None
             else None
         )
@@ -518,67 +746,73 @@
         scalar_named, unary_window_ops = self._split_named_expr()

         # Build GroupByRequests for scalar aggregations
-        gb_requests: list[plc.groupby.GroupByRequest] = []
-        out_names: list[str] = []
-        out_dtypes: list[DataType] = []
+        order_sensitive: list[expr.NamedExpr] = []
+        other_scalars: list[expr.NamedExpr] = []
         for ne in scalar_named:
             val = ne.value
-            out_names.append(ne.name)
-            out_dtypes.append(val.dtype)
+            if (
+                self._order_by_expr is not None
+                and isinstance(val, expr.Agg)
+                and val.agg_request.kind() == plc.aggregation.Kind.NTH_ELEMENT
+            ):
+                order_sensitive.append(ne)
+            else:
+                other_scalars.append(ne)

-            if isinstance(val, expr.Len):
-                # A count aggregation, we need a column so use a key column
-                col = by_cols[0].obj
-                gb_requests.append(plc.groupby.GroupByRequest(col, [val.agg_request]))
-            elif isinstance(val, expr.Agg):
-                (child,) = (
-                    val.children if val.name != "quantile" else (val.children[0],)
-                )
-                col = child.evaluate(df, context=ExecutionContext.FRAME).obj
-                gb_requests.append(plc.groupby.GroupByRequest(col, [val.agg_request]))
+        gb_requests, out_names, out_dtypes = self._build_groupby_requests(
+            other_scalars, df, by_cols=by_cols
+        )

         group_keys_tbl, value_tables = grouper.aggregate(gb_requests)
-        out_cols = (t.columns()[0] for t in value_tables)
-
-        # We do a left-join between the input keys to group-keys
-        # so every input row appears exactly once. left_order is
-        # returned un-ordered by libcudf.
-        left_order, right_order = plc.join.left_join(
-            by_tbl, group_keys_tbl, plc.types.NullEquality.EQUAL
+        broadcasted_cols = self._broadcast_agg_results(
+            by_tbl,
+            group_keys_tbl,
+            value_tables,
+            out_names,
+            out_dtypes,
+            df.stream,
         )

-        # Scatter the right order indices into an all-null table
-        # and at the position of the index in left order. Now we
-        # have the map between rows and groups with the correct ordering.
-        left_rows = left_order.size()
-        target = plc.Column.from_scalar(
-            plc.Scalar.from_py(None, plc.types.SIZE_TYPE), left_rows
-        )
-        aligned_map = plc.copying.scatter(
-            plc.Table([right_order]),
-            left_order,
-            plc.Table([target]),
-        ).columns()[0]
+        if order_sensitive:
+            row_id = plc.filling.sequence(
+                df.num_rows,
+                plc.Scalar.from_py(0, plc.types.SIZE_TYPE, stream=df.stream),
+                plc.Scalar.from_py(1, plc.types.SIZE_TYPE, stream=df.stream),
+                stream=df.stream,
+            )
+            _, _, ob_desc, ob_nulls_last = self.options
+            order_index, _, local = self._grouped_window_scan_setup(
+                by_cols,
+                row_id=row_id,
+                order_by_col=order_by_col,
+                ob_desc=ob_desc,
+                ob_nulls_last=ob_nulls_last,
+                grouper=grouper,
+                stream=df.stream,
+            )
+            assert order_index is not None

-        # Broadcast each scalar aggregated result back to row-shape using
-        # the aligned mapping between row indices and group indices.
-        broadcasted_cols = [
-            Column(
-                plc.copying.gather(
-                    plc.Table([col]), aligned_map, plc.copying.OutOfBoundsPolicy.NULLIFY
-                ).columns()[0],
-                name=named_expr.name,
-                dtype=dtype,
+            gb_requests, out_names, out_dtypes = self._build_groupby_requests(
+                order_sensitive, df, order_index=order_index, by_cols=by_cols
             )
-            for named_expr, dtype, col in zip(
-                scalar_named, out_dtypes, out_cols, strict=True
+
+            group_keys_tbl_local, value_tables_local = local.aggregate(gb_requests)
+            broadcasted_cols.extend(
+                self._broadcast_agg_results(
+                    by_tbl,
+                    group_keys_tbl_local,
+                    value_tables_local,
+                    out_names,
+                    out_dtypes,
+                    df.stream,
+                )
             )
-        ]

         row_id = plc.filling.sequence(
             df.num_rows,
-            plc.Scalar.from_py(0, plc.types.SIZE_TYPE),
-            plc.Scalar.from_py(1, plc.types.SIZE_TYPE),
+            plc.Scalar.from_py(0, plc.types.SIZE_TYPE, stream=df.stream),
+            plc.Scalar.from_py(1, plc.types.SIZE_TYPE, stream=df.stream),
+            stream=df.stream,
         )

         if rank_named := unary_window_ops["rank"]:
@@ -588,7 +822,6 @@ class GroupedRollingWindow(Expr):
                 rank_expr = ne.value
                 assert isinstance(rank_expr, expr.UnaryFunction)
                 (child,) = rank_expr.children
-                val = child.evaluate(df, context=ExecutionContext.FRAME).obj
                 desc = rank_expr.options[1]

                 order_index = self._build_window_order_index(
@@ -597,16 +830,21 @@
                     order_by_col=order_by_col,
                     ob_desc=ob_desc,
                     ob_nulls_last=ob_nulls_last,
-                    value_col=val,
+                    value_col=child.evaluate(
+                        df, context=ExecutionContext.FRAME
+                    ).obj,
                     value_desc=desc,
+                    stream=df.stream,
                 )
-                by_cols_for_scan = self._gather_columns(by_cols, order_index)
-                local = GroupedRollingWindow._sorted_grouper(by_cols_for_scan)
+                rank_by_cols_for_scan = self._gather_columns(
+                    by_cols, order_index, stream=df.stream
+                )
+                local = GroupedRollingWindow._sorted_grouper(rank_by_cols_for_scan)
                 names, dtypes, tables = self._apply_unary_op(
                     RankOp(
                         named_exprs=[ne],
                         order_index=order_index,
-                        by_cols_for_scan=by_cols_for_scan,
+                        by_cols_for_scan=rank_by_cols_for_scan,
                         local_grouper=local,
                     ),
                     df,
621
859
  names,
622
860
  dtypes,
623
861
  order_index=order_index,
862
+ stream=df.stream,
624
863
  )
625
864
  )
626
865
  else:
@@ -633,11 +872,113 @@
                 )
                 broadcasted_cols.extend(
                     self._reorder_to_input(
-                        row_id, by_cols, df.num_rows, tables, names, dtypes
+                        row_id,
+                        by_cols,
+                        df.num_rows,
+                        tables,
+                        names,
+                        dtypes,
+                        stream=df.stream,
+                    )
+                )
+
+        if fill_named := unary_window_ops["fill_null_with_strategy"]:
+            order_index, fill_null_by_cols_for_scan, local = (
+                self._grouped_window_scan_setup(
+                    by_cols,
+                    row_id=row_id,
+                    order_by_col=order_by_col
+                    if self._order_by_expr is not None
+                    else None,
+                    ob_desc=self.options[2]
+                    if self._order_by_expr is not None
+                    else False,
+                    ob_nulls_last=self.options[3]
+                    if self._order_by_expr is not None
+                    else False,
+                    grouper=grouper,
+                    stream=df.stream,
+                )
+            )
+
+            strategy_exprs: dict[str, list[expr.NamedExpr]] = defaultdict(list)
+            for ne in fill_named:
+                fill_null_expr = ne.value
+                assert isinstance(fill_null_expr, expr.UnaryFunction)
+                strategy_exprs[fill_null_expr.options[0]].append(ne)
+
+            replace_policy = {
+                "forward": plc.replace.ReplacePolicy.PRECEDING,
+                "backward": plc.replace.ReplacePolicy.FOLLOWING,
+            }
+
+            for strategy, fill_exprs in strategy_exprs.items():
+                names, dtypes, tables = self._apply_unary_op(
+                    FillNullWithStrategyOp(
+                        named_exprs=fill_exprs,
+                        order_index=order_index,
+                        by_cols_for_scan=fill_null_by_cols_for_scan,
+                        local_grouper=local,
+                        policy=replace_policy[strategy],
+                    ),
+                    df,
+                    grouper,
+                )
+                broadcasted_cols.extend(
+                    self._reorder_to_input(
+                        row_id,
+                        by_cols,
+                        df.num_rows,
+                        tables,
+                        names,
+                        dtypes,
+                        order_index=order_index,
+                        stream=df.stream,
                     )
                 )

+        if cum_named := unary_window_ops["cum_sum"]:
+            order_index, cum_sum_by_cols_for_scan, local = (
+                self._grouped_window_scan_setup(
+                    by_cols,
+                    row_id=row_id,
+                    order_by_col=order_by_col
+                    if self._order_by_expr is not None
+                    else None,
+                    ob_desc=self.options[2]
+                    if self._order_by_expr is not None
+                    else False,
+                    ob_nulls_last=self.options[3]
+                    if self._order_by_expr is not None
+                    else False,
+                    grouper=grouper,
+                    stream=df.stream,
+                )
+            )
+            names, dtypes, tables = self._apply_unary_op(
+                CumSumOp(
+                    named_exprs=cum_named,
+                    order_index=order_index,
+                    by_cols_for_scan=cum_sum_by_cols_for_scan,
+                    local_grouper=local,
+                ),
+                df,
+                grouper,
+            )
+            broadcasted_cols.extend(
+                self._reorder_to_input(
+                    row_id,
+                    by_cols,
+                    df.num_rows,
+                    tables,
+                    names,
+                    dtypes,
+                    order_index=order_index,
+                    stream=df.stream,
+                )
+            )
+
         # Create a temporary DataFrame with the broadcasted columns named by their
         # placeholder names from agg decomposition, then evaluate the post-expression.
-        df = DataFrame(broadcasted_cols)
+        df = DataFrame(broadcasted_cols, stream=df.stream)
         return self.post.value.evaluate(df, context=ExecutionContext.FRAME)
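
Taken together, the rolling.py changes shown above do two things: they thread an explicit CUDA stream (df.stream) through every pylibcudf call, and they extend GroupedRollingWindow so that over(...) windows accept fill_null_with_strategy and cum_sum in addition to rank. As a rough orientation only — the frame, column names, and aliases below are invented for this sketch and are not part of the package — these are the polars query shapes that exercise the new RankOp, FillNullWithStrategyOp, and CumSumOp paths:

import polars as pl

# Illustrative data; any frame with group keys and nullable values works.
lf = pl.LazyFrame({"g": [1, 1, 2, 2], "x": [1.0, None, 3.0, 4.0]})
out = lf.select(
    pl.col("x").rank().over("g").alias("x_rank"),                          # RankOp
    pl.col("x").fill_null(strategy="forward").over("g").alias("x_ffill"),  # FillNullWithStrategyOp
    pl.col("x").cum_sum().over("g").alias("x_cumsum"),                     # CumSumOp
).collect(engine="gpu")

Note that the new __init__ still raises NotImplementedError in the unsupported case it checks for (multiple order_by keys combined with order-sensitive aggregations), so such queries are expected to fall back to polars' CPU engine rather than run on the GPU.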