cudf-polars-cu13: 25.10.0 → 26.2.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. cudf_polars/GIT_COMMIT +1 -1
  2. cudf_polars/VERSION +1 -1
  3. cudf_polars/callback.py +60 -15
  4. cudf_polars/containers/column.py +137 -77
  5. cudf_polars/containers/dataframe.py +123 -34
  6. cudf_polars/containers/datatype.py +134 -13
  7. cudf_polars/dsl/expr.py +0 -2
  8. cudf_polars/dsl/expressions/aggregation.py +80 -28
  9. cudf_polars/dsl/expressions/binaryop.py +34 -14
  10. cudf_polars/dsl/expressions/boolean.py +110 -37
  11. cudf_polars/dsl/expressions/datetime.py +59 -30
  12. cudf_polars/dsl/expressions/literal.py +11 -5
  13. cudf_polars/dsl/expressions/rolling.py +460 -119
  14. cudf_polars/dsl/expressions/selection.py +9 -8
  15. cudf_polars/dsl/expressions/slicing.py +1 -1
  16. cudf_polars/dsl/expressions/string.py +256 -114
  17. cudf_polars/dsl/expressions/struct.py +19 -7
  18. cudf_polars/dsl/expressions/ternary.py +33 -3
  19. cudf_polars/dsl/expressions/unary.py +126 -64
  20. cudf_polars/dsl/ir.py +1053 -350
  21. cudf_polars/dsl/to_ast.py +30 -13
  22. cudf_polars/dsl/tracing.py +194 -0
  23. cudf_polars/dsl/translate.py +307 -107
  24. cudf_polars/dsl/utils/aggregations.py +43 -30
  25. cudf_polars/dsl/utils/reshape.py +14 -2
  26. cudf_polars/dsl/utils/rolling.py +12 -8
  27. cudf_polars/dsl/utils/windows.py +35 -20
  28. cudf_polars/experimental/base.py +55 -2
  29. cudf_polars/experimental/benchmarks/pdsds.py +12 -126
  30. cudf_polars/experimental/benchmarks/pdsh.py +792 -2
  31. cudf_polars/experimental/benchmarks/utils.py +596 -39
  32. cudf_polars/experimental/dask_registers.py +47 -20
  33. cudf_polars/experimental/dispatch.py +9 -3
  34. cudf_polars/experimental/distinct.py +2 -0
  35. cudf_polars/experimental/explain.py +15 -2
  36. cudf_polars/experimental/expressions.py +30 -15
  37. cudf_polars/experimental/groupby.py +25 -4
  38. cudf_polars/experimental/io.py +156 -124
  39. cudf_polars/experimental/join.py +53 -23
  40. cudf_polars/experimental/parallel.py +68 -19
  41. cudf_polars/experimental/rapidsmpf/__init__.py +8 -0
  42. cudf_polars/experimental/rapidsmpf/collectives/__init__.py +9 -0
  43. cudf_polars/experimental/rapidsmpf/collectives/allgather.py +90 -0
  44. cudf_polars/experimental/rapidsmpf/collectives/common.py +96 -0
  45. cudf_polars/experimental/rapidsmpf/collectives/shuffle.py +253 -0
  46. cudf_polars/experimental/rapidsmpf/core.py +488 -0
  47. cudf_polars/experimental/rapidsmpf/dask.py +172 -0
  48. cudf_polars/experimental/rapidsmpf/dispatch.py +153 -0
  49. cudf_polars/experimental/rapidsmpf/io.py +696 -0
  50. cudf_polars/experimental/rapidsmpf/join.py +322 -0
  51. cudf_polars/experimental/rapidsmpf/lower.py +74 -0
  52. cudf_polars/experimental/rapidsmpf/nodes.py +735 -0
  53. cudf_polars/experimental/rapidsmpf/repartition.py +216 -0
  54. cudf_polars/experimental/rapidsmpf/union.py +115 -0
  55. cudf_polars/experimental/rapidsmpf/utils.py +374 -0
  56. cudf_polars/experimental/repartition.py +9 -2
  57. cudf_polars/experimental/select.py +177 -14
  58. cudf_polars/experimental/shuffle.py +46 -12
  59. cudf_polars/experimental/sort.py +100 -26
  60. cudf_polars/experimental/spilling.py +1 -1
  61. cudf_polars/experimental/statistics.py +24 -5
  62. cudf_polars/experimental/utils.py +25 -7
  63. cudf_polars/testing/asserts.py +13 -8
  64. cudf_polars/testing/io.py +2 -1
  65. cudf_polars/testing/plugin.py +93 -17
  66. cudf_polars/typing/__init__.py +86 -32
  67. cudf_polars/utils/config.py +473 -58
  68. cudf_polars/utils/cuda_stream.py +70 -0
  69. cudf_polars/utils/versions.py +5 -4
  70. cudf_polars_cu13-26.2.0.dist-info/METADATA +181 -0
  71. cudf_polars_cu13-26.2.0.dist-info/RECORD +108 -0
  72. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/WHEEL +1 -1
  73. cudf_polars_cu13-25.10.0.dist-info/METADATA +0 -136
  74. cudf_polars_cu13-25.10.0.dist-info/RECORD +0 -92
  75. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/licenses/LICENSE +0 -0
  76. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/top_level.txt +0 -0
cudf_polars/experimental/dask_registers.py

@@ -18,6 +18,7 @@ import rmm
 
 from cudf_polars.containers import Column, DataFrame, DataType
 from cudf_polars.dsl.expressions.base import NamedExpr
+from cudf_polars.utils.cuda_stream import get_dask_cuda_stream
 
 if TYPE_CHECKING:
     from collections.abc import Hashable, Mapping
@@ -33,7 +34,7 @@ if TYPE_CHECKING:
 __all__ = ["DaskRegisterManager", "register"]
 
 
-class DaskRegisterManager:  # pragma: no cover; Only used with Distributed scheduler
+class DaskRegisterManager:  # pragma: no cover; Only used with Distributed cluster
     """Manager to ensure ensure serializer is only registered once."""
 
     _registered: bool = False
@@ -73,41 +74,57 @@ def register() -> None:
     @cuda_serialize.register((Column, DataFrame))
     def serialize_column_or_frame(
         x: DataFrame | Column,
-    ) -> tuple[DataFrameHeader | ColumnHeader, list[memoryview]]:
+    ) -> tuple[
+        DataFrameHeader | ColumnHeader, list[memoryview[bytes] | plc.gpumemoryview]
+    ]:
         with log_errors():
-            header, frames = x.serialize()
-            return header, list(frames)  # Dask expect a list of frames
+            header, frames = x.serialize(stream=get_dask_cuda_stream())
+            # Dask expect a list of frames
+            return header, list(frames)
 
     @cuda_deserialize.register(DataFrame)
     def _(
-        header: DataFrameHeader, frames: tuple[memoryview, plc.gpumemoryview]
+        header: DataFrameHeader, frames: tuple[memoryview[bytes], plc.gpumemoryview]
     ) -> DataFrame:
         with log_errors():
             metadata, gpudata = frames  # TODO: check if this is a length-2 list...
-            return DataFrame.deserialize(header, (metadata, plc.gpumemoryview(gpudata)))
+            return DataFrame.deserialize(
+                header,
+                (metadata, plc.gpumemoryview(gpudata)),
+                stream=get_dask_cuda_stream(),
+            )
 
     @cuda_deserialize.register(Column)
-    def _(header: ColumnHeader, frames: tuple[memoryview, plc.gpumemoryview]) -> Column:
+    def _(
+        header: ColumnHeader, frames: tuple[memoryview[bytes], plc.gpumemoryview]
+    ) -> Column:
         with log_errors():
             metadata, gpudata = frames
-            return Column.deserialize(header, (metadata, plc.gpumemoryview(gpudata)))
+            return Column.deserialize(
+                header,
+                (metadata, plc.gpumemoryview(gpudata)),
+                stream=get_dask_cuda_stream(),
+            )
 
     @overload
     def dask_serialize_column_or_frame(
         x: DataFrame,
-    ) -> tuple[DataFrameHeader, tuple[memoryview, memoryview]]: ...
+    ) -> tuple[DataFrameHeader, tuple[memoryview[bytes], memoryview[bytes]]]: ...
 
     @overload
     def dask_serialize_column_or_frame(
         x: Column,
-    ) -> tuple[ColumnHeader, tuple[memoryview, memoryview]]: ...
+    ) -> tuple[ColumnHeader, tuple[memoryview[bytes], memoryview[bytes]]]: ...
 
     @dask_serialize.register(Column)
     def dask_serialize_column_or_frame(
         x: DataFrame | Column,
-    ) -> tuple[DataFrameHeader | ColumnHeader, tuple[memoryview, memoryview]]:
+    ) -> tuple[
+        DataFrameHeader | ColumnHeader, tuple[memoryview[bytes], memoryview[bytes]]
+    ]:
+        stream = get_dask_cuda_stream()
         with log_errors():
-            header, (metadata, gpudata) = x.serialize()
+            header, (metadata, gpudata) = x.serialize(stream=stream)
 
             # For robustness, we check that the gpu data is contiguous
             cai = gpudata.__cuda_array_interface__
@@ -117,23 +134,26 @@ def register() -> None:
             nbytes = cai["shape"][0]
 
             # Copy the gpudata to host memory
-            gpudata_on_host = memoryview(
+            gpudata_on_host: memoryview[bytes] = memoryview(
                 rmm.DeviceBuffer(ptr=gpudata.ptr, size=nbytes).copy_to_host()
             )
             return header, (metadata, gpudata_on_host)
 
     @dask_deserialize.register(Column)
-    def _(header: ColumnHeader, frames: tuple[memoryview, memoryview]) -> Column:
+    def _(header: ColumnHeader, frames: tuple[memoryview[bytes], memoryview]) -> Column:
         with log_errors():
             assert len(frames) == 2
             # Copy the second frame (the gpudata in host memory) back to the gpu
-            frames = frames[0], plc.gpumemoryview(rmm.DeviceBuffer.to_device(frames[1]))
-            return Column.deserialize(header, frames)
+            new_frames = (
+                frames[0],
+                plc.gpumemoryview(rmm.DeviceBuffer.to_device(frames[1])),
+            )
+            return Column.deserialize(header, new_frames, stream=get_dask_cuda_stream())
 
     @dask_serialize.register(DataFrame)
     def _(
         x: DataFrame, context: Mapping[str, Any] | None = None
-    ) -> tuple[DataFrameHeader, tuple[memoryview, memoryview]]:
+    ) -> tuple[DataFrameHeader, tuple[memoryview[bytes], memoryview[bytes]]]:
         # Do regular serialization if no staging buffer is provided.
         if context is None or "staging_device_buffer" not in context:
             return dask_serialize_column_or_frame(x)
@@ -166,12 +186,19 @@ def register() -> None:
             return header, frame
 
     @dask_deserialize.register(DataFrame)
-    def _(header: DataFrameHeader, frames: tuple[memoryview, memoryview]) -> DataFrame:
+    def _(
+        header: DataFrameHeader, frames: tuple[memoryview[bytes], memoryview]
+    ) -> DataFrame:
         with log_errors():
             assert len(frames) == 2
             # Copy the second frame (the gpudata in host memory) back to the gpu
-            frames = frames[0], plc.gpumemoryview(rmm.DeviceBuffer.to_device(frames[1]))
-            return DataFrame.deserialize(header, frames)
+            new_frames = (
+                frames[0],
+                plc.gpumemoryview(rmm.DeviceBuffer.to_device(frames[1])),
+            )
+            return DataFrame.deserialize(
+                header, new_frames, stream=get_dask_cuda_stream()
+            )
 
     @sizeof_dispatch.register(Column)
     def _(x: Column) -> int:
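Note: the functional change throughout this file is that `serialize`/`deserialize` now thread an explicit CUDA stream obtained from the new `cudf_polars.utils.cuda_stream.get_dask_cuda_stream` helper. A minimal sketch of the device-to-host round trip the `dask_serialize`/`dask_deserialize` handlers above perform, using only calls that appear in the diff (`host_roundtrip` itself is illustrative, not part of the package):

```python
import rmm
import pylibcudf as plc

from cudf_polars.containers import Column
from cudf_polars.utils.cuda_stream import get_dask_cuda_stream


def host_roundtrip(col: Column) -> Column:
    # Serialize on an explicit CUDA stream (new in 26.2.0).
    header, (metadata, gpudata) = col.serialize(stream=get_dask_cuda_stream())

    # Device -> host: stage the GPU frame in host memory for Dask transport.
    nbytes = gpudata.__cuda_array_interface__["shape"][0]
    host = memoryview(rmm.DeviceBuffer(ptr=gpudata.ptr, size=nbytes).copy_to_host())

    # Host -> device: rebuild the GPU frame and deserialize on the same stream.
    frames = (metadata, plc.gpumemoryview(rmm.DeviceBuffer.to_device(host)))
    return Column.deserialize(header, frames, stream=get_dask_cuda_stream())
```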
cudf_polars/experimental/dispatch.py

@@ -13,7 +13,7 @@ if TYPE_CHECKING:
     from collections.abc import MutableMapping
 
     from cudf_polars.dsl import ir
-    from cudf_polars.dsl.ir import IR
+    from cudf_polars.dsl.ir import IR, IRExecutionContext
     from cudf_polars.experimental.base import (
         ColumnStats,
         PartitionInfo,
@@ -77,7 +77,9 @@ def lower_ir_node(
 
 @singledispatch
 def generate_ir_tasks(
-    ir: IR, partition_info: MutableMapping[IR, PartitionInfo]
+    ir: IR,
+    partition_info: MutableMapping[IR, PartitionInfo],
+    context: IRExecutionContext,
 ) -> MutableMapping[Any, Any]:
     """
     Generate a task graph for evaluation of an IR node.
@@ -88,6 +90,8 @@ def generate_ir_tasks(
         IR node to generate tasks for.
     partition_info
         Partitioning information, obtained from :func:`lower_ir_graph`.
+    context
+        Runtime context for IR node execution.
 
     Returns
     -------
@@ -139,7 +143,9 @@ def initialize_column_stats(
 
 @singledispatch
 def update_column_stats(
-    ir: IR, stats: StatsCollector, config_options: ConfigOptions
+    ir: IR,
+    stats: StatsCollector,
+    config_options: ConfigOptions,
 ) -> None:
     """
     Finalize local column statistics for an IR node.
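Note: `generate_ir_tasks` implementations now receive a third `context` argument. A hedged sketch of what a conforming `singledispatch` registration looks like after this change; `MyIR` is a hypothetical IR subclass used only for illustration:

```python
from collections.abc import MutableMapping
from typing import Any

from cudf_polars.dsl.ir import IR, IRExecutionContext
from cudf_polars.experimental.base import PartitionInfo
from cudf_polars.experimental.dispatch import generate_ir_tasks


class MyIR(IR):  # placeholder IR subclass, for illustration only
    pass


@generate_ir_tasks.register(MyIR)
def _(
    ir: MyIR,
    partition_info: MutableMapping[IR, PartitionInfo],
    context: IRExecutionContext,  # runtime execution context, new in 26.2.0
) -> MutableMapping[Any, Any]:
    # Build and return the task graph for this node; body elided here.
    return {}
```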
cudf_polars/experimental/distinct.py

@@ -97,6 +97,7 @@ def lower_distinct(
             child.schema,
             shuffle_keys,
             config_options.executor.shuffle_method,
+            config_options.executor.shuffler_insertion_method,
             child,
         )
         partition_info[child] = PartitionInfo(
@@ -150,6 +151,7 @@ def lower_distinct(
             new_node.schema,
             shuffle_keys,
             config_options.executor.shuffle_method,
+            config_options.executor.shuffler_insertion_method,
             new_node,
         )
         partition_info[new_node] = PartitionInfo(count=output_count)
cudf_polars/experimental/explain.py

@@ -34,7 +34,10 @@ if TYPE_CHECKING:
 
 
 def explain_query(
-    q: pl.LazyFrame, engine: pl.GPUEngine, *, physical: bool = True
+    q: pl.LazyFrame,
+    engine: pl.GPUEngine,
+    *,
+    physical: bool = True,
 ) -> str:
     """
     Return a formatted string representation of the IR plan.
@@ -58,7 +61,17 @@ def explain_query(
     ir = Translator(q._ldf.visit(), engine).translate_ir()
 
     if physical:
-        lowered_ir, partition_info = lower_ir_graph(ir, config)
+        if (
+            config.executor.name == "streaming"
+            and config.executor.runtime == "rapidsmpf"
+        ):  # pragma: no cover; rapidsmpf runtime not tested in CI yet
+            from cudf_polars.experimental.rapidsmpf.core import (
+                lower_ir_graph as rapidsmpf_lower_ir_graph,
+            )
+
+            lowered_ir, partition_info, _ = rapidsmpf_lower_ir_graph(ir, config)
+        else:
+            lowered_ir, partition_info, _ = lower_ir_graph(ir, config)
         return _repr_ir_tree(lowered_ir, partition_info)
     else:
         if config.executor.name == "streaming":
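Note: `explain_query` now routes physical-plan lowering through the rapidsmpf path when the streaming executor is configured with `runtime="rapidsmpf"`. A usage sketch, assuming the `executor` keyword that cudf-polars reads from `pl.GPUEngine`:

```python
import polars as pl

from cudf_polars.experimental.explain import explain_query

q = pl.LazyFrame({"a": [1, 2, 2, 3]}).group_by("a").len()
engine = pl.GPUEngine(executor="streaming")

# Physical plan: the lowered IR tree with per-node partitioning info.
print(explain_query(q, engine, physical=True))
# Logical plan: the translated IR before lowering.
print(explain_query(q, engine, physical=False))
```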
cudf_polars/experimental/expressions.py

@@ -38,15 +38,14 @@ from typing import TYPE_CHECKING, TypeAlias, TypedDict
 import pylibcudf as plc
 
 from cudf_polars.dsl.expressions.aggregation import Agg
-from cudf_polars.dsl.expressions.base import Col, Expr, NamedExpr
+from cudf_polars.dsl.expressions.base import Col, ExecutionContext, Expr, NamedExpr
 from cudf_polars.dsl.expressions.binaryop import BinOp
 from cudf_polars.dsl.expressions.literal import Literal
-from cudf_polars.dsl.expressions.unary import Cast, UnaryFunction
+from cudf_polars.dsl.expressions.unary import Cast, Len, UnaryFunction
 from cudf_polars.dsl.ir import IR, Distinct, Empty, HConcat, Select
 from cudf_polars.dsl.traversal import (
     CachingVisitor,
 )
-from cudf_polars.dsl.utils.naming import unique_names
 from cudf_polars.experimental.base import PartitionInfo
 from cudf_polars.experimental.repartition import Repartition
 from cudf_polars.experimental.utils import _get_unique_fractions, _leaf_column_names
@@ -237,7 +236,7 @@
 
 
 def _decompose_agg_node(
-    agg: Agg,
+    agg: Agg | Len,
     input_ir: IR,
     partition_info: MutableMapping[IR, PartitionInfo],
     config_options: ConfigOptions,
@@ -273,7 +272,7 @@
     """
     expr: Expr
    exprs: list[Expr]
-    if agg.name == "count":
+    if isinstance(agg, Len) or agg.name == "count":
         # Chunkwise stage
         columns, input_ir, partition_info = select(
             [agg],
@@ -286,7 +285,7 @@
         # Combined stage
         (column,) = columns
         columns, input_ir, partition_info = select(
-            [Agg(agg.dtype, "sum", None, column)],
+            [Agg(agg.dtype, "sum", None, ExecutionContext.FRAME, column)],
             input_ir,
             partition_info,
             names=names,
@@ -295,8 +294,8 @@
     elif agg.name == "mean":
         # Chunkwise stage
         exprs = [
-            Agg(agg.dtype, "sum", None, *agg.children),
-            Agg(agg.dtype, "count", None, *agg.children),
+            Agg(agg.dtype, "sum", None, ExecutionContext.FRAME, *agg.children),
+            Agg(agg.dtype, "count", None, ExecutionContext.FRAME, *agg.children),
         ]
         columns, input_ir, partition_info = select(
             exprs,
@@ -311,7 +310,10 @@
             BinOp(
                 agg.dtype,
                 plc.binaryop.BinaryOperator.DIV,
-                *(Agg(agg.dtype, "sum", None, column) for column in columns),
+                *(
+                    Agg(agg.dtype, "sum", None, ExecutionContext.FRAME, column)
+                    for column in columns
+                ),
             )
         ]
         columns, input_ir, partition_info = select(
@@ -348,6 +350,7 @@
             input_ir.schema,
             shuffle_on,
             config_options.executor.shuffle_method,
+            config_options.executor.shuffler_insertion_method,
             input_ir,
         )
         partition_info[input_ir] = PartitionInfo(
@@ -357,7 +360,7 @@
 
         # Chunkwise stage
         columns, input_ir, partition_info = select(
-            [Cast(agg.dtype, agg)],
+            [Cast(agg.dtype, True, agg)],  # noqa: FBT003
             input_ir,
             partition_info,
             names=names,
@@ -367,7 +370,7 @@
         # Combined stage
         (column,) = columns
         columns, input_ir, partition_info = select(
-            [Agg(agg.dtype, "sum", None, column)],
+            [Agg(agg.dtype, "sum", None, ExecutionContext.FRAME, column)],
             input_ir,
             partition_info,
             names=names,
@@ -386,7 +389,7 @@
         # Combined stage
         (column,) = columns
         columns, input_ir, partition_info = select(
-            [Agg(agg.dtype, agg.name, agg.options, column)],
+            [Agg(agg.dtype, agg.name, agg.options, ExecutionContext.FRAME, column)],
             input_ir,
             partition_info,
             names=names,
@@ -451,7 +454,9 @@ def _decompose_expr_node(
     if partition_count == 1 or expr.is_pointwise:
         # Single-partition and pointwise expressions are always supported.
         return expr, input_ir, partition_info
-    elif isinstance(expr, Agg) and expr.name in _SUPPORTED_AGGS:
+    elif isinstance(expr, Len) or (
+        isinstance(expr, Agg) and expr.name in _SUPPORTED_AGGS
+    ):
         # This is a supported Agg expression.
         return _decompose_agg_node(
             expr, input_ir, partition_info, config_options, names=names
@@ -515,8 +520,15 @@ def _decompose(
             *unique_input_irs,
         )
         partition_info[input_ir] = PartitionInfo(count=partition_count)
-    else:
+    elif len(unique_input_irs) == 1:
         input_ir = unique_input_irs[0]
+    else:
+        # All child IRs were Empty. Use an Empty({}) with
+        # count=1 to ensure that scalar expressions still
+        # produce one output partition with a single row
+        # See: https://github.com/rapidsai/cudf/pull/20409
+        input_ir = Empty({})
+        partition_info[input_ir] = PartitionInfo(count=1)
 
     # Call into class-specific logic to decompose ``expr``
     return _decompose_expr_node(
@@ -537,6 +549,7 @@ def decompose_expr_graph(
     config_options: ConfigOptions,
     row_count_estimate: ColumnStat[int],
     column_stats: dict[str, ColumnStats],
+    unique_names: Generator[str, None, None],
 ) -> tuple[NamedExpr, IR, MutableMapping[IR, PartitionInfo]]:
     """
     Decompose a NamedExpr into stages.
@@ -557,6 +570,8 @@
         Row-count estimate for the input IR.
     column_stats
         Column statistics for the input IR.
+    unique_names
+        Generator of unique names for temporaries.
 
     Returns
     -------
@@ -581,7 +596,7 @@
             "input_ir": input_ir,
             "input_partition_info": partition_info[input_ir],
             "config_options": config_options,
-            "unique_names": unique_names((named_expr.name, *input_ir.schema.keys())),
+            "unique_names": unique_names,
             "row_count_estimate": row_count_estimate,
             "column_stats": column_stats,
         },
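Note: the recurring mechanical change in this file is the `Agg` constructor, which now takes an `ExecutionContext` member between the aggregation options and the child expressions, i.e. `Agg(dtype, name, options, context, *children)` as inferred from the call sites above. A minimal sketch of the combined-stage reduction built in several branches (`combine_partials` is illustrative, not part of the package):

```python
from cudf_polars.containers import DataType
from cudf_polars.dsl.expressions.aggregation import Agg
from cudf_polars.dsl.expressions.base import ExecutionContext, Expr


def combine_partials(dtype: DataType, column: Expr) -> Agg:
    # Combined stage of the multi-partition decomposition: sum the
    # per-partition partial results in the FRAME execution context.
    return Agg(dtype, "sum", None, ExecutionContext.FRAME, column)
```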
cudf_polars/experimental/groupby.py

@@ -14,6 +14,7 @@ import pylibcudf as plc
 
 from cudf_polars.containers import DataType
 from cudf_polars.dsl.expr import Agg, BinOp, Col, Len, NamedExpr
+from cudf_polars.dsl.expressions.base import ExecutionContext
 from cudf_polars.dsl.ir import GroupBy, Select, Slice
 from cudf_polars.dsl.traversal import traversal
 from cudf_polars.dsl.utils.naming import unique_names
@@ -95,7 +96,12 @@ def decompose(
     if isinstance(expr, Len):
         selection = NamedExpr(name, Col(dtype, name))
         aggregation = [NamedExpr(name, expr)]
-        reduction = [NamedExpr(name, Agg(dtype, "sum", None, Col(dtype, name)))]
+        reduction = [
+            NamedExpr(
+                name,
+                Agg(dtype, "sum", None, ExecutionContext.GROUPBY, Col(dtype, name)),
+            )
+        ]
         return selection, aggregation, reduction, False
     if isinstance(expr, Agg):
         if expr.name in ("sum", "count", "min", "max", "n_unique"):
@@ -105,19 +111,32 @@
             aggfunc = expr.name
             selection = NamedExpr(name, Col(dtype, name))
             aggregation = [NamedExpr(name, expr)]
-            reduction = [NamedExpr(name, Agg(dtype, aggfunc, None, Col(dtype, name)))]
+            reduction = [
+                NamedExpr(
+                    name,
+                    Agg(
+                        dtype, aggfunc, None, ExecutionContext.GROUPBY, Col(dtype, name)
+                    ),
+                )
+            ]
             return selection, aggregation, reduction, expr.name == "n_unique"
         elif expr.name == "mean":
             (child,) = expr.children
             (sum, count), aggregations, reductions, need_preshuffle = combine(
                 decompose(
                     f"{next(names)}__mean_sum",
-                    Agg(dtype, "sum", None, child),
+                    Agg(dtype, "sum", None, ExecutionContext.GROUPBY, child),
                     names=names,
                 ),
                 decompose(
                     f"{next(names)}__mean_count",
-                    Agg(DataType(pl.Int32()), "count", False, child),  # noqa: FBT003
+                    Agg(
+                        DataType(pl.Int32()),
+                        "count",
+                        False,  # noqa: FBT003
+                        ExecutionContext.GROUPBY,
+                        child,
+                    ),
                     names=names,
                 ),
             )
@@ -230,6 +249,7 @@ def _(
             child.schema,
             ir.keys,
             config_options.executor.shuffle_method,
+            config_options.executor.shuffler_insertion_method,
             child,
         )
         partition_info[child] = PartitionInfo(
@@ -272,6 +292,7 @@ def _(
             gb_pwise.schema,
             grouped_keys,
             config_options.executor.shuffle_method,
+            config_options.executor.shuffler_insertion_method,
             gb_pwise,
         )
         partition_info[gb_inter] = PartitionInfo(count=post_aggregation_count)
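Note: grouped decompositions thread `ExecutionContext.GROUPBY` where the frame-level decompositions above use `ExecutionContext.FRAME`. A sketch mirroring the `Len` branch of `decompose` shown above, where a per-partition `len` is reduced by a grouped `sum` of the partial counts (`decompose_len` is illustrative; the assumed `expr.dtype` attribute supplies the count dtype):

```python
from cudf_polars.dsl.expr import Agg, Col, Len, NamedExpr
from cudf_polars.dsl.expressions.base import ExecutionContext


def decompose_len(name: str, expr: Len):
    dtype = expr.dtype
    selection = NamedExpr(name, Col(dtype, name))  # pass the count column through
    aggregation = [NamedExpr(name, expr)]          # per-partition len()
    reduction = [                                  # grouped sum of partial counts
        NamedExpr(
            name,
            Agg(dtype, "sum", None, ExecutionContext.GROUPBY, Col(dtype, name)),
        )
    ]
    # Final False: this aggregation needs no pre-shuffle of the input.
    return selection, aggregation, reduction, False
```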