cudf-polars-cu12 25.2.2__py3-none-any.whl → 25.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. cudf_polars/VERSION +1 -1
  2. cudf_polars/callback.py +85 -53
  3. cudf_polars/containers/column.py +100 -7
  4. cudf_polars/containers/dataframe.py +16 -24
  5. cudf_polars/dsl/expr.py +3 -1
  6. cudf_polars/dsl/expressions/aggregation.py +3 -3
  7. cudf_polars/dsl/expressions/binaryop.py +2 -2
  8. cudf_polars/dsl/expressions/boolean.py +4 -4
  9. cudf_polars/dsl/expressions/datetime.py +39 -1
  10. cudf_polars/dsl/expressions/literal.py +3 -9
  11. cudf_polars/dsl/expressions/selection.py +2 -2
  12. cudf_polars/dsl/expressions/slicing.py +53 -0
  13. cudf_polars/dsl/expressions/sorting.py +1 -1
  14. cudf_polars/dsl/expressions/string.py +4 -4
  15. cudf_polars/dsl/expressions/unary.py +3 -2
  16. cudf_polars/dsl/ir.py +222 -93
  17. cudf_polars/dsl/nodebase.py +8 -1
  18. cudf_polars/dsl/translate.py +66 -38
  19. cudf_polars/experimental/base.py +18 -12
  20. cudf_polars/experimental/dask_serialize.py +22 -8
  21. cudf_polars/experimental/groupby.py +346 -0
  22. cudf_polars/experimental/io.py +13 -11
  23. cudf_polars/experimental/join.py +318 -0
  24. cudf_polars/experimental/parallel.py +57 -6
  25. cudf_polars/experimental/shuffle.py +194 -0
  26. cudf_polars/testing/plugin.py +23 -34
  27. cudf_polars/typing/__init__.py +33 -2
  28. cudf_polars/utils/config.py +138 -0
  29. cudf_polars/utils/conversion.py +40 -0
  30. cudf_polars/utils/dtypes.py +14 -4
  31. cudf_polars/utils/timer.py +39 -0
  32. cudf_polars/utils/versions.py +4 -3
  33. {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.4.0.dist-info}/METADATA +8 -7
  34. cudf_polars_cu12-25.4.0.dist-info/RECORD +55 -0
  35. {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.4.0.dist-info}/WHEEL +1 -1
  36. cudf_polars_cu12-25.2.2.dist-info/RECORD +0 -48
  37. {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.4.0.dist-info/licenses}/LICENSE +0 -0
  38. {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.4.0.dist-info}/top_level.txt +0 -0
cudf_polars/dsl/nodebase.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
 # SPDX-License-Identifier: Apache-2.0
 
 """Base class for IR nodes, and utilities."""
@@ -58,6 +58,13 @@ class Node(Generic[T]):
         """
         return type(self)(*self._ctor_arguments(children))
 
+    def __reduce__(self):
+        """Pickle a Node object."""
+        return (
+            type(self),
+            self._ctor_arguments(self.children),
+        )
+
     def get_hashable(self) -> Hashable:
         """
        Return a hashable object for the node.
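The new `__reduce__` makes `Node` subclasses picklable by re-invoking the constructor with `_ctor_arguments(self.children)`, which is useful when IR graphs are shipped between processes (e.g. by the experimental distributed executor). A minimal, self-contained sketch of the same protocol; the `NodeSketch`/`Wrap` classes and their `_non_child` layout are hypothetical stand-ins, not cudf-polars classes:

    import pickle


    class NodeSketch:
        _non_child: tuple[str, ...] = ()
        children: tuple = ()

        def _ctor_arguments(self, children):
            # Constructor arguments are the non-child attributes followed by the children.
            return (*(getattr(self, attr) for attr in self._non_child), *children)

        def __reduce__(self):
            # Pickle by re-calling the constructor with the same arguments.
            return (type(self), self._ctor_arguments(self.children))


    class Wrap(NodeSketch):
        _non_child = ("name",)

        def __init__(self, name, *children):
            self.name = name
            self.children = tuple(children)


    tree = Wrap("root", Wrap("leaf"))
    restored = pickle.loads(pickle.dumps(tree))
    assert restored.name == "root" and restored.children[0].name == "leaf"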
cudf_polars/dsl/translate.py
@@ -5,6 +5,7 @@
 
 from __future__ import annotations
 
+import copy
 import functools
 import json
 from contextlib import AbstractContextManager, nullcontext
@@ -23,7 +24,7 @@ import pylibcudf as plc
 from cudf_polars.dsl import expr, ir
 from cudf_polars.dsl.to_ast import insert_colrefs
 from cudf_polars.typing import NodeTraverser
-from cudf_polars.utils import dtypes, sorting
+from cudf_polars.utils import config, dtypes, sorting
 
 if TYPE_CHECKING:
     from polars import GPUEngine
@@ -41,13 +42,13 @@ class Translator:
     ----------
     visitor
         Polars NodeTraverser object
-    config
+    engine
         GPU engine configuration.
     """
 
-    def __init__(self, visitor: NodeTraverser, config: GPUEngine):
+    def __init__(self, visitor: NodeTraverser, engine: GPUEngine):
        self.visitor = visitor
-        self.config = config
+        self.config_options = config.ConfigOptions(copy.deepcopy(engine.config))
        self.errors: list[Exception] = []
 
     def translate_ir(self, *, n: int | None = None) -> ir.IR:
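For context, `GPUEngine` collects any extra keyword arguments into `engine.config`; the translator now deep-copies that mapping into a `ConfigOptions` wrapper (from the new `cudf_polars/utils/config.py`) instead of handing raw dicts to individual IR nodes. A small illustration of the input side; the `parquet_options` key is just one example of an engine option and may not apply to every workload:

    import polars as pl

    # Extra keyword arguments to GPUEngine are gathered into engine.config.
    engine = pl.GPUEngine(raise_on_fail=True, parquet_options={"chunked": True})
    print(engine.config)
    # e.g. {'raise_on_fail': True, 'parquet_options': {'chunked': True}}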
@@ -84,7 +85,7 @@
         # IR is versioned with major.minor, minor is bumped for backwards
         # compatible changes (e.g. adding new nodes), major is bumped for
         # incompatible changes (e.g. renaming nodes).
-        if (version := self.visitor.version()) >= (5, 1):
+        if (version := self.visitor.version()) >= (6, 1):
             e = NotImplementedError(
                 f"No support for polars IR {version=}"
             )  # pragma: no cover; no such version for now.
@@ -227,13 +228,15 @@ def _(
     # TODO: with versioning, rename on the rust side
     skip_rows, n_rows = n_rows
 
+    if file_options.include_file_paths is not None:
+        raise NotImplementedError("No support for including file path in scan")
     row_index = file_options.row_index
     return ir.Scan(
         schema,
         typ,
         reader_options,
         cloud_options,
-        translator.config.config.copy(),
+        translator.config_options,
         node.paths,
         with_columns,
         skip_rows,
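One user-visible consequence: scans that request `include_file_paths` are now rejected explicitly during translation, so such queries fall back to the CPU engine (or raise if `raise_on_fail=True` is set). A hedged illustration with `scan_parquet`; the glob path is a placeholder:

    import polars as pl

    q = pl.scan_parquet("data/*.parquet", include_file_paths="path")
    # Translation raises NotImplementedError for include_file_paths, so with
    # raise_on_fail=True this collect errors instead of silently falling back.
    q.collect(engine=pl.GPUEngine(raise_on_fail=True))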
@@ -260,7 +263,7 @@ def _(
         schema,
         node.df,
         node.projection,
-        translator.config.config.copy(),
+        translator.config_options,
     )
 
 
@@ -288,6 +291,7 @@ def _(
         aggs,
         node.maintain_order,
         node.options,
+        translator.config_options,
         inp,
     )
 
@@ -299,38 +303,12 @@ def _(
     # Join key dtypes are dependent on the schema of the left and
     # right inputs, so these must be translated with the relevant
     # input active.
-    def adjust_literal_dtype(literal: expr.Literal) -> expr.Literal:
-        if literal.dtype.id() == plc.types.TypeId.INT32:
-            plc_int64 = plc.types.DataType(plc.types.TypeId.INT64)
-            return expr.Literal(
-                plc_int64,
-                pa.scalar(literal.value.as_py(), type=plc.interop.to_arrow(plc_int64)),
-            )
-        return literal
-
-    def maybe_adjust_binop(e) -> expr.Expr:
-        if isinstance(e.value, expr.BinOp):
-            left, right = e.value.children
-            if isinstance(left, expr.Col) and isinstance(right, expr.Literal):
-                e.value.children = (left, adjust_literal_dtype(right))
-            elif isinstance(left, expr.Literal) and isinstance(right, expr.Col):
-                e.value.children = (adjust_literal_dtype(left), right)
-        return e
-
-    def translate_expr_and_maybe_fix_binop_args(translator, exprs):
-        return [
-            maybe_adjust_binop(translate_named_expr(translator, n=e)) for e in exprs
-        ]
-
     with set_node(translator.visitor, node.input_left):
         inp_left = translator.translate_ir(n=None)
-        # TODO: There's bug in the polars type coercion phase. Use
-        # translate_named_expr directly once it is resolved.
-        # Tracking issue: https://github.com/pola-rs/polars/issues/20935
-        left_on = translate_expr_and_maybe_fix_binop_args(translator, node.left_on)
+        left_on = [translate_named_expr(translator, n=e) for e in node.left_on]
     with set_node(translator.visitor, node.input_right):
         inp_right = translator.translate_ir(n=None)
-        right_on = translate_expr_and_maybe_fix_binop_args(translator, node.right_on)
+        right_on = [translate_named_expr(translator, n=e) for e in node.right_on]
 
     if (how := node.options[0]) in {
         "Inner",
@@ -341,7 +319,15 @@ def _(
         "Semi",
         "Anti",
     }:
-        return ir.Join(schema, left_on, right_on, node.options, inp_left, inp_right)
+        return ir.Join(
+            schema,
+            left_on,
+            right_on,
+            node.options,
+            translator.config_options,
+            inp_left,
+            inp_right,
+        )
     else:
         how, op1, op2 = node.options[0]
         if how != "IEJoin":
@@ -463,6 +449,21 @@ def _(
     return ir.Projection(schema, translator.translate_ir(n=node.input))
 
 
+@_translate_ir.register
+def _(
+    node: pl_ir.MergeSorted, translator: Translator, schema: dict[str, plc.DataType]
+) -> ir.IR:
+    key = node.key
+    inp_left = translator.translate_ir(n=node.input_left)
+    inp_right = translator.translate_ir(n=node.input_right)
+    return ir.MergeSorted(
+        schema,
+        key,
+        inp_left,
+        inp_right,
+    )
+
+
 @_translate_ir.register
 def _(
     node: pl_ir.MapFunction, translator: Translator, schema: dict[str, plc.DataType]
@@ -472,7 +473,6 @@ def _(
         schema,
         name,
         options,
-        # TODO: merge_sorted breaks this pattern
         translator.translate_ir(n=node.input),
     )
 
@@ -623,6 +623,17 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex
             )
         elif name == "pow":
             return expr.BinOp(dtype, plc.binaryop.BinaryOperator.POW, *children)
+        elif name in "top_k":
+            (col, k) = children
+            assert isinstance(k, expr.Literal)
+            (descending,) = options
+            return expr.Slice(
+                dtype,
+                0,
+                k.value.as_py(),
+                expr.Sort(dtype, (False, True, not descending), col),
+            )
+
         return expr.UnaryFunction(dtype, name, options, *children)
     raise NotImplementedError(
        f"No handler for Expr function node with {name=}"
@@ -651,7 +662,10 @@ def _(node: pl_expr.Window, translator: Translator, dtype: plc.DataType) -> expr
 @_translate_expr.register
 def _(node: pl_expr.Literal, translator: Translator, dtype: plc.DataType) -> expr.Expr:
     if isinstance(node.value, plrs.PySeries):
-        return expr.LiteralColumn(dtype, pl.Series._from_pyseries(node.value))
+        data = pl.Series._from_pyseries(node.value).to_arrow()
+        return expr.LiteralColumn(
+            dtype, data.cast(dtypes.downcast_arrow_lists(data.type))
+        )
     value = pa.scalar(node.value, type=plc.interop.to_arrow(dtype))
     return expr.Literal(dtype, value)
 
@@ -673,6 +687,20 @@ def _(node: pl_expr.SortBy, translator: Translator, dtype: plc.DataType) -> expr
     )
 
 
+@_translate_expr.register
+def _(node: pl_expr.Slice, translator: Translator, dtype: plc.DataType) -> expr.Expr:
+    offset = translator.translate_expr(n=node.offset)
+    length = translator.translate_expr(n=node.length)
+    assert isinstance(offset, expr.Literal)
+    assert isinstance(length, expr.Literal)
+    return expr.Slice(
+        dtype,
+        offset.value.as_py(),
+        length.value.as_py(),
+        translator.translate_expr(n=node.input),
+    )
+
+
 @_translate_expr.register
 def _(node: pl_expr.Gather, translator: Translator, dtype: plc.DataType) -> expr.Expr:
     return expr.Gather(
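Together with the new `cudf_polars/dsl/expressions/slicing.py` (see the file list above), this adds support for expression-level slices with literal offset and length, for example:

    import polars as pl

    lf = pl.LazyFrame({"x": [10, 20, 30, 40]})
    # pl.col("x").slice(1, 2) now maps to expr.Slice instead of being unsupported.
    out = lf.select(pl.col("x").slice(1, 2)).collect(engine="gpu")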
cudf_polars/experimental/base.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES.
 # SPDX-License-Identifier: Apache-2.0
 """Multi-partition base classes."""
 
@@ -9,23 +9,29 @@ from typing import TYPE_CHECKING
 from cudf_polars.dsl.ir import Union
 
 if TYPE_CHECKING:
-    from collections.abc import Iterator, Sequence
+    from collections.abc import Iterator
 
     from cudf_polars.containers import DataFrame
+    from cudf_polars.dsl.expr import NamedExpr
     from cudf_polars.dsl.nodebase import Node
 
 
 class PartitionInfo:
-    """
-    Partitioning information.
-
-    This class only tracks the partition count (for now).
-    """
-
-    __slots__ = ("count",)
-
-    def __init__(self, count: int):
+    """Partitioning information."""
+
+    __slots__ = ("count", "partitioned_on")
+    count: int
+    """Partition count."""
+    partitioned_on: tuple[NamedExpr, ...]
+    """Columns the data is hash-partitioned on."""
+
+    def __init__(
+        self,
+        count: int,
+        partitioned_on: tuple[NamedExpr, ...] = (),
+    ):
         self.count = count
+        self.partitioned_on = partitioned_on
 
     def keys(self, node: Node) -> Iterator[tuple[str, int]]:
         """Return the partitioned keys for a given node."""
@@ -38,6 +44,6 @@ def get_key_name(node: Node) -> str:
     return f"{type(node).__name__.lower()}-{hash(node)}"
 
 
-def _concat(dfs: Sequence[DataFrame]) -> DataFrame:
+def _concat(*dfs: DataFrame) -> DataFrame:
     # Concatenate a sequence of DataFrames vertically
    return Union.do_evaluate(None, *dfs)
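`PartitionInfo` now records both the partition count and, optionally, the expressions the data is hash-partitioned on; `keys()` still yields one `(task-name, index)` pair per partition for the task graph. A self-contained sketch of that keying pattern; this stand-in class is illustrative only, not the real `PartitionInfo`:

    from dataclasses import dataclass


    @dataclass
    class PartitionInfoSketch:
        count: int
        partitioned_on: tuple = ()

        def keys(self, name: str):
            # One task key per partition, mirroring PartitionInfo.keys(node).
            yield from ((name, i) for i in range(self.count))


    info = PartitionInfoSketch(count=4)
    assert list(info.keys("groupby-123")) == [("groupby-123", i) for i in range(4)]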
cudf_polars/experimental/dask_serialize.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
 # SPDX-License-Identifier: Apache-2.0
 
 """Dask serialization."""
@@ -12,7 +12,7 @@ from distributed.utils import log_errors
 import pylibcudf as plc
 import rmm
 
-from cudf_polars.containers import DataFrame
+from cudf_polars.containers import Column, DataFrame
 
 __all__ = ["register"]
 
@@ -20,8 +20,8 @@ __all__ = ["register"]
 def register() -> None:
     """Register dask serialization routines for DataFrames."""
 
-    @cuda_serialize.register(DataFrame)
-    def _(x: DataFrame):
+    @cuda_serialize.register((Column, DataFrame))
+    def _(x: DataFrame | Column):
         with log_errors():
             header, frames = x.serialize()
             return header, list(frames)  # Dask expect a list of frames
@@ -29,11 +29,17 @@ def register() -> None:
     @cuda_deserialize.register(DataFrame)
     def _(header, frames):
         with log_errors():
-            assert len(frames) == 2
-            return DataFrame.deserialize(header, tuple(frames))
+            metadata, gpudata = frames
+            return DataFrame.deserialize(header, (metadata, plc.gpumemoryview(gpudata)))
 
-    @dask_serialize.register(DataFrame)
-    def _(x: DataFrame):
+    @cuda_deserialize.register(Column)
+    def _(header, frames):
+        with log_errors():
+            metadata, gpudata = frames
+            return Column.deserialize(header, (metadata, plc.gpumemoryview(gpudata)))
+
+    @dask_serialize.register((Column, DataFrame))
+    def _(x: DataFrame | Column):
         with log_errors():
             header, (metadata, gpudata) = x.serialize()
 
@@ -57,3 +63,11 @@ def register() -> None:
             # Copy the second frame (the gpudata in host memory) back to the gpu
             frames = frames[0], plc.gpumemoryview(rmm.DeviceBuffer.to_device(frames[1]))
             return DataFrame.deserialize(header, frames)
+
+    @dask_deserialize.register(Column)
+    def _(header, frames) -> Column:
+        with log_errors():
+            assert len(frames) == 2
+            # Copy the second frame (the gpudata in host memory) back to the gpu
+            frames = frames[0], plc.gpumemoryview(rmm.DeviceBuffer.to_device(frames[1]))
+            return Column.deserialize(header, frames)
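With `Column` registered alongside `DataFrame`, individual columns can now round-trip through Dask's serialization machinery (used when the distributed scheduler moves shuffled pieces between workers). A hedged sketch, assuming a CUDA-capable environment with `distributed` installed; the exact `Column` construction shown here may differ between versions:

    import pyarrow as pa
    import pylibcudf as plc
    from distributed.protocol import deserialize, serialize

    from cudf_polars.containers import Column
    from cudf_polars.experimental.dask_serialize import register

    register()

    # Wrap a device column; Column here is cudf-polars' container, not a cudf column.
    col = Column(plc.interop.from_arrow(pa.array([1, 2, 3])))
    header, frames = serialize(col, serializers=("cuda", "dask", "pickle"))
    restored = deserialize(header, frames)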
cudf_polars/experimental/groupby.py
@@ -0,0 +1,346 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+"""Parallel GroupBy Logic."""
+
+from __future__ import annotations
+
+import itertools
+import uuid
+from typing import TYPE_CHECKING, Any
+
+import pylibcudf as plc
+
+from cudf_polars.dsl.expr import (
+    Agg,
+    BinOp,
+    Cast,
+    Col,
+    Len,
+    Literal,
+    NamedExpr,
+    UnaryFunction,
+)
+from cudf_polars.dsl.ir import GroupBy, Select
+from cudf_polars.dsl.traversal import traversal
+from cudf_polars.experimental.base import PartitionInfo, _concat, get_key_name
+from cudf_polars.experimental.dispatch import generate_ir_tasks, lower_ir_node
+from cudf_polars.experimental.shuffle import Shuffle
+
+if TYPE_CHECKING:
+    from collections.abc import MutableMapping
+
+    from cudf_polars.dsl.expr import Expr
+    from cudf_polars.dsl.ir import IR
+    from cudf_polars.experimental.parallel import LowerIRTransformer
+
+
+# Supported multi-partition aggregations
+_GB_AGG_SUPPORTED = ("sum", "count", "mean", "min", "max")
+
+
+def combine(
+    *decompositions: tuple[NamedExpr, list[NamedExpr], list[NamedExpr]],
+) -> tuple[list[NamedExpr], list[NamedExpr], list[NamedExpr]]:
+    """
+    Combine multiple groupby-aggregation decompositions.
+
+    Parameters
+    ----------
+    decompositions
+        Packed sequence of `decompose` results.
+
+    Returns
+    -------
+    Unified groupby-aggregation decomposition.
+    """
+    selections, aggregations, reductions = zip(*decompositions, strict=True)
+    assert all(isinstance(ne, NamedExpr) for ne in selections)
+    return (
+        list(selections),
+        list(itertools.chain.from_iterable(aggregations)),
+        list(itertools.chain.from_iterable(reductions)),
+    )
+
+
+def decompose(
+    name: str, expr: Expr
+) -> tuple[NamedExpr, list[NamedExpr], list[NamedExpr]]:
+    """
+    Decompose a groupby-aggregation expression.
+
+    Parameters
+    ----------
+    name
+        Output schema name.
+    expr
+        The aggregation expression for a single column.
+
+    Returns
+    -------
+    NamedExpr
+        The expression selecting the *output* column or columns.
+    list[NamedExpr]
+        The initial aggregation expressions.
+    list[NamedExpr]
+        The reduction expressions.
+    """
+    dtype = expr.dtype
+    expr = expr.children[0] if isinstance(expr, Cast) else expr
+
+    unary_op: list[Any] = []
+    if isinstance(expr, UnaryFunction) and expr.is_pointwise:
+        # TODO: Handle multiple/sequential unary ops
+        unary_op = [expr.name, expr.options]
+        expr = expr.children[0]
+
+    def _wrap_unary(select):
+        # Helper function to wrap the final selection
+        # in a UnaryFunction (when necessary)
+        if unary_op:
+            return UnaryFunction(select.dtype, *unary_op, select)
+        return select
+
+    if isinstance(expr, Len):
+        selection = NamedExpr(name, _wrap_unary(Col(dtype, name)))
+        aggregation = [NamedExpr(name, expr)]
+        reduction = [
+            NamedExpr(
+                name,
+                # Sum reduction may require casting.
+                # Do it for all cases to be safe (for now)
+                Cast(dtype, Agg(dtype, "sum", None, Col(dtype, name))),
+            )
+        ]
+        return selection, aggregation, reduction
+    if isinstance(expr, Agg):
+        if expr.name in ("sum", "count", "min", "max"):
+            if expr.name in ("sum", "count"):
+                aggfunc = "sum"
+            else:
+                aggfunc = expr.name
+            selection = NamedExpr(name, _wrap_unary(Col(dtype, name)))
+            aggregation = [NamedExpr(name, expr)]
+            reduction = [
+                NamedExpr(
+                    name,
+                    # Sum reduction may require casting.
+                    # Do it for all cases to be safe (for now)
+                    Cast(dtype, Agg(dtype, aggfunc, None, Col(dtype, name))),
+                )
+            ]
+            return selection, aggregation, reduction
+        elif expr.name == "mean":
+            (child,) = expr.children
+            token = str(uuid.uuid4().hex)  # prevent collisions with user's names
+            (sum, count), aggregations, reductions = combine(
+                decompose(f"{name}__mean_sum_{token}", Agg(dtype, "sum", None, child)),
+                decompose(f"{name}__mean_count_{token}", Len(dtype)),
+            )
+            selection = NamedExpr(
+                name,
+                _wrap_unary(
+                    BinOp(
+                        dtype, plc.binaryop.BinaryOperator.DIV, sum.value, count.value
+                    )
+                ),
+            )
+            return selection, aggregations, reductions
+        else:
+            raise NotImplementedError(
+                "GroupBy does not support multiple partitions "
+                f"for this aggregation type:\n{type(expr)}\n"
+                f"Only {_GB_AGG_SUPPORTED} are supported."
+            )
+    elif isinstance(expr, BinOp):
+        # The expectation is that each operand of the BinOp is decomposable.
+        # We can then combine the decompositions of the operands to form the
+        # decomposition of the BinOp.
+        (left, right) = expr.children
+        token = str(uuid.uuid4().hex)  # prevent collisions with user's names
+        (left_selection, right_selection), aggregations, reductions = combine(
+            decompose(f"{name}__left_{token}", left),
+            decompose(f"{name}__right_{token}", right),
+        )
+
+        selection = NamedExpr(
+            name,
+            _wrap_unary(
+                BinOp(dtype, expr.op, left_selection.value, right_selection.value)
+            ),
+        )
+        return selection, aggregations, reductions
+
+    elif isinstance(expr, Literal):
+        selection = NamedExpr(name, _wrap_unary(Col(dtype, name)))
+        aggregation = []
+        reduction = [NamedExpr(name, expr)]
+        return selection, aggregation, reduction
+
+    else:  # pragma: no cover
+        # Unsupported expression
+        raise NotImplementedError(
+            f"GroupBy does not support multiple partitions for this expression:\n{expr}"
+        )
+
+
+@lower_ir_node.register(GroupBy)
+def _(
+    ir: GroupBy, rec: LowerIRTransformer
+) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
+    # Extract child partitioning
+    child, partition_info = rec(ir.children[0])
+
+    # Handle single-partition case
+    if partition_info[child].count == 1:
+        single_part_node = ir.reconstruct([child])
+        partition_info[single_part_node] = partition_info[child]
+        return single_part_node, partition_info
+
+    # Check group-by keys
+    if not all(expr.is_pointwise for expr in traversal([e.value for e in ir.keys])):
+        raise NotImplementedError(
+            f"GroupBy does not support multiple partitions for keys:\n{ir.keys}"
+        )  # pragma: no cover
+
+    # Check if we are dealing with any high-cardinality columns
+    post_aggregation_count = 1  # Default tree reduction
+    groupby_key_columns = [ne.name for ne in ir.keys]
+    cardinality_factor = {
+        c: min(f, 1.0)
+        for c, f in ir.config_options.get(
+            "executor_options.cardinality_factor", default={}
+        ).items()
+        if c in groupby_key_columns
+    }
+    if cardinality_factor:
+        # The `cardinality_factor` dictionary can be used
+        # to specify a mapping between column names and
+        # cardinality "factors". Each factor estimates the
+        # fractional number of unique values in the column.
+        # Each value should be in the range (0, 1].
+        child_count = partition_info[child].count
+        post_aggregation_count = max(
+            int(max(cardinality_factor.values()) * child_count),
+            1,
+        )
+
+    # Decompose the aggregation requests into three distinct phases
+    selection_exprs, piecewise_exprs, reduction_exprs = combine(
+        *(decompose(agg.name, agg.value) for agg in ir.agg_requests)
+    )
+
+    # Partition-wise groupby operation
+    pwise_schema = {k.name: k.value.dtype for k in ir.keys} | {
+        k.name: k.value.dtype for k in piecewise_exprs
+    }
+    gb_pwise = GroupBy(
+        pwise_schema,
+        ir.keys,
+        piecewise_exprs,
+        ir.maintain_order,
+        ir.options,
+        ir.config_options,
+        child,
+    )
+    child_count = partition_info[child].count
+    partition_info[gb_pwise] = PartitionInfo(count=child_count)
+
+    # Add Shuffle node if necessary
+    gb_inter: GroupBy | Shuffle = gb_pwise
+    if post_aggregation_count > 1:
+        if ir.maintain_order:  # pragma: no cover
+            raise NotImplementedError(
+                "maintain_order not supported for multiple output partitions."
+            )
+
+        gb_inter = Shuffle(
+            pwise_schema,
+            ir.keys,
+            ir.config_options,
+            gb_pwise,
+        )
+        partition_info[gb_inter] = PartitionInfo(count=post_aggregation_count)
+
+    # Tree reduction if post_aggregation_count==1
+    # (Otherwise, this is another partition-wise op)
+    gb_reduce = GroupBy(
+        {k.name: k.value.dtype for k in ir.keys}
+        | {k.name: k.value.dtype for k in reduction_exprs},
+        ir.keys,
+        reduction_exprs,
+        ir.maintain_order,
+        ir.options,
+        ir.config_options,
+        gb_inter,
+    )
+    partition_info[gb_reduce] = PartitionInfo(count=post_aggregation_count)
+
+    # Final Select phase
+    aggregated = {ne.name: ne for ne in selection_exprs}
+    new_node = Select(
+        ir.schema,
+        [
+            # Select the aggregated data or the original column
+            aggregated.get(name, NamedExpr(name, Col(dtype, name)))
+            for name, dtype in ir.schema.items()
+        ],
+        False,  # noqa: FBT003
+        gb_reduce,
+    )
+    partition_info[new_node] = PartitionInfo(count=post_aggregation_count)
+    return new_node, partition_info
+
+
+def _tree_node(do_evaluate, nbatch, *args):
+    return do_evaluate(*args[nbatch:], _concat(*args[:nbatch]))
+
+
+@generate_ir_tasks.register(GroupBy)
+def _(
+    ir: GroupBy, partition_info: MutableMapping[IR, PartitionInfo]
+) -> MutableMapping[Any, Any]:
+    (child,) = ir.children
+    child_count = partition_info[child].count
+    child_name = get_key_name(child)
+    output_count = partition_info[ir].count
+
+    if output_count == child_count:
+        return {
+            key: (
+                ir.do_evaluate,
+                *ir._non_child_args,
+                (child_name, i),
+            )
+            for i, key in enumerate(partition_info[ir].keys(ir))
+        }
+    elif output_count != 1:  # pragma: no cover
+        raise ValueError(f"Expected single partition, got {output_count}")
+
+    # Simple N-ary tree reduction
+    j = 0
+    n_ary = ir.config_options.get("executor_options.groupby_n_ary", default=32)
+    graph: MutableMapping[Any, Any] = {}
+    name = get_key_name(ir)
+    keys: list[Any] = [(child_name, i) for i in range(child_count)]
+    while len(keys) > n_ary:
+        new_keys: list[Any] = []
+        for i, k in enumerate(range(0, len(keys), n_ary)):
+            batch = keys[k : k + n_ary]
+            graph[(name, j, i)] = (
+                _tree_node,
+                ir.do_evaluate,
+                len(batch),
+                *batch,
+                *ir._non_child_args,
+            )
+            new_keys.append((name, j, i))
+        j += 1
+        keys = new_keys
+    graph[(name, 0)] = (
+        _tree_node,
+        ir.do_evaluate,
+        len(keys),
+        *keys,
+        *ir._non_child_args,
+    )
+    return graph
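The task-graph construction at the end is a plain N-ary tree reduction: partitions are grouped into batches of `groupby_n_ary`, each batch is concatenated and re-aggregated, and the process repeats until one partition remains. The shape of that recursion, stripped of the IR machinery:

    # Self-contained sketch of the N-ary tree reduction pattern (combine is any
    # associative "concatenate then re-aggregate" step; sum stands in for it here).
    def tree_reduce(parts, combine, n_ary=32):
        while len(parts) > 1:
            parts = [combine(parts[i : i + n_ary]) for i in range(0, len(parts), n_ary)]
        return parts[0]


    assert tree_reduce(list(range(100)), sum, n_ary=4) == sum(range(100))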
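The lowering logic above consults `executor_options` via dotted-path lookups (`executor_options.cardinality_factor`, `executor_options.groupby_n_ary`). Those options travel in through `GPUEngine`; a hedged sketch of what that configuration might look like (the executor name "dask-experimental" and the exact option spellings are assumptions for this release, not documented guarantees):

    import polars as pl

    lf = pl.LazyFrame({"id": [1, 1, 2, 2], "x": [1.0, 2.0, 3.0, 4.0]})
    engine = pl.GPUEngine(
        executor="dask-experimental",  # assumption: multi-partition executor name
        executor_options={
            # Estimate that ~50% of rows carry a unique "id"; values are clamped to (0, 1].
            "cardinality_factor": {"id": 0.5},
            # Fan-in of the N-ary tree reduction used when output is a single partition.
            "groupby_n_ary": 16,
        },
    )
    result = lf.group_by("id").agg(pl.col("x").mean()).collect(engine=engine)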