cudf-polars-cu12 25.2.1-py3-none-any.whl → 25.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cudf_polars/VERSION +1 -1
- cudf_polars/callback.py +85 -53
- cudf_polars/containers/column.py +100 -7
- cudf_polars/containers/dataframe.py +16 -24
- cudf_polars/dsl/expr.py +3 -1
- cudf_polars/dsl/expressions/aggregation.py +3 -3
- cudf_polars/dsl/expressions/binaryop.py +2 -2
- cudf_polars/dsl/expressions/boolean.py +4 -4
- cudf_polars/dsl/expressions/datetime.py +39 -1
- cudf_polars/dsl/expressions/literal.py +3 -9
- cudf_polars/dsl/expressions/selection.py +2 -2
- cudf_polars/dsl/expressions/slicing.py +53 -0
- cudf_polars/dsl/expressions/sorting.py +1 -1
- cudf_polars/dsl/expressions/string.py +4 -4
- cudf_polars/dsl/expressions/unary.py +3 -2
- cudf_polars/dsl/ir.py +222 -93
- cudf_polars/dsl/nodebase.py +8 -1
- cudf_polars/dsl/translate.py +66 -38
- cudf_polars/experimental/base.py +18 -12
- cudf_polars/experimental/dask_serialize.py +22 -8
- cudf_polars/experimental/groupby.py +346 -0
- cudf_polars/experimental/io.py +13 -11
- cudf_polars/experimental/join.py +318 -0
- cudf_polars/experimental/parallel.py +57 -6
- cudf_polars/experimental/shuffle.py +194 -0
- cudf_polars/testing/plugin.py +23 -34
- cudf_polars/typing/__init__.py +33 -2
- cudf_polars/utils/config.py +138 -0
- cudf_polars/utils/conversion.py +40 -0
- cudf_polars/utils/dtypes.py +14 -4
- cudf_polars/utils/timer.py +39 -0
- cudf_polars/utils/versions.py +4 -3
- {cudf_polars_cu12-25.2.1.dist-info → cudf_polars_cu12-25.4.0.dist-info}/METADATA +8 -7
- cudf_polars_cu12-25.4.0.dist-info/RECORD +55 -0
- {cudf_polars_cu12-25.2.1.dist-info → cudf_polars_cu12-25.4.0.dist-info}/WHEEL +1 -1
- cudf_polars_cu12-25.2.1.dist-info/RECORD +0 -48
- {cudf_polars_cu12-25.2.1.dist-info → cudf_polars_cu12-25.4.0.dist-info/licenses}/LICENSE +0 -0
- {cudf_polars_cu12-25.2.1.dist-info → cudf_polars_cu12-25.4.0.dist-info}/top_level.txt +0 -0
cudf_polars/dsl/nodebase.py
CHANGED
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
 # SPDX-License-Identifier: Apache-2.0

 """Base class for IR nodes, and utilities."""
@@ -58,6 +58,13 @@ class Node(Generic[T]):
         """
         return type(self)(*self._ctor_arguments(children))

+    def __reduce__(self):
+        """Pickle a Node object."""
+        return (
+            type(self),
+            self._ctor_arguments(self.children),
+        )
+
     def get_hashable(self) -> Hashable:
         """
         Return a hashable object for the node.
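The new `__reduce__` hook makes IR nodes picklable by rebuilding them from their constructor arguments, which is what lets plans and expressions be shipped to Dask workers by the experimental executor. A minimal standalone sketch of the same pattern (the `SimpleNode` class below is illustrative, not a cudf-polars type):

import pickle


class SimpleNode:
    """Toy stand-in for cudf_polars.dsl.nodebase.Node."""

    def __init__(self, *args):
        self.args = args  # non-child constructor arguments
        self.children = ()

    def _ctor_arguments(self, children):
        # Constructor arguments followed by the children (none here).
        return (*self.args, *children)

    def __reduce__(self):
        # Unpickling calls type(self)(*ctor_arguments) to rebuild the node.
        return (type(self), self._ctor_arguments(self.children))


node = SimpleNode("sum", 42)
clone = pickle.loads(pickle.dumps(node))
assert clone.args == node.args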
cudf_polars/dsl/translate.py
CHANGED
@@ -5,6 +5,7 @@

 from __future__ import annotations

+import copy
 import functools
 import json
 from contextlib import AbstractContextManager, nullcontext
@@ -23,7 +24,7 @@ import pylibcudf as plc
 from cudf_polars.dsl import expr, ir
 from cudf_polars.dsl.to_ast import insert_colrefs
 from cudf_polars.typing import NodeTraverser
-from cudf_polars.utils import dtypes, sorting
+from cudf_polars.utils import config, dtypes, sorting

 if TYPE_CHECKING:
     from polars import GPUEngine
@@ -41,13 +42,13 @@ class Translator:
     ----------
     visitor
         Polars NodeTraverser object
-
+    engine
         GPU engine configuration.
     """

-    def __init__(self, visitor: NodeTraverser,
+    def __init__(self, visitor: NodeTraverser, engine: GPUEngine):
         self.visitor = visitor
-        self.
+        self.config_options = config.ConfigOptions(copy.deepcopy(engine.config))
         self.errors: list[Exception] = []

     def translate_ir(self, *, n: int | None = None) -> ir.IR:
@@ -84,7 +85,7 @@ class Translator:
         # IR is versioned with major.minor, minor is bumped for backwards
         # compatible changes (e.g. adding new nodes), major is bumped for
         # incompatible changes (e.g. renaming nodes).
-        if (version := self.visitor.version()) >= (
+        if (version := self.visitor.version()) >= (6, 1):
             e = NotImplementedError(
                 f"No support for polars IR {version=}"
             )  # pragma: no cover; no such version for now.
@@ -227,13 +228,15 @@ def _(
     # TODO: with versioning, rename on the rust side
     skip_rows, n_rows = n_rows

+    if file_options.include_file_paths is not None:
+        raise NotImplementedError("No support for including file path in scan")
     row_index = file_options.row_index
     return ir.Scan(
         schema,
         typ,
         reader_options,
         cloud_options,
-        translator.
+        translator.config_options,
         node.paths,
         with_columns,
         skip_rows,
@@ -260,7 +263,7 @@ def _(
         schema,
         node.df,
         node.projection,
-        translator.
+        translator.config_options,
     )


@@ -288,6 +291,7 @@ def _(
         aggs,
         node.maintain_order,
         node.options,
+        translator.config_options,
         inp,
     )

@@ -299,38 +303,12 @@ def _(
     # Join key dtypes are dependent on the schema of the left and
     # right inputs, so these must be translated with the relevant
     # input active.
-    def adjust_literal_dtype(literal: expr.Literal) -> expr.Literal:
-        if literal.dtype.id() == plc.types.TypeId.INT32:
-            plc_int64 = plc.types.DataType(plc.types.TypeId.INT64)
-            return expr.Literal(
-                plc_int64,
-                pa.scalar(literal.value.as_py(), type=plc.interop.to_arrow(plc_int64)),
-            )
-        return literal
-
-    def maybe_adjust_binop(e) -> expr.Expr:
-        if isinstance(e.value, expr.BinOp):
-            left, right = e.value.children
-            if isinstance(left, expr.Col) and isinstance(right, expr.Literal):
-                e.value.children = (left, adjust_literal_dtype(right))
-            elif isinstance(left, expr.Literal) and isinstance(right, expr.Col):
-                e.value.children = (adjust_literal_dtype(left), right)
-        return e
-
-    def translate_expr_and_maybe_fix_binop_args(translator, exprs):
-        return [
-            maybe_adjust_binop(translate_named_expr(translator, n=e)) for e in exprs
-        ]
-
     with set_node(translator.visitor, node.input_left):
         inp_left = translator.translate_ir(n=None)
-
-    # translate_named_expr directly once it is resolved.
-    # Tracking issue: https://github.com/pola-rs/polars/issues/20935
-    left_on = translate_expr_and_maybe_fix_binop_args(translator, node.left_on)
+    left_on = [translate_named_expr(translator, n=e) for e in node.left_on]
     with set_node(translator.visitor, node.input_right):
         inp_right = translator.translate_ir(n=None)
-    right_on =
+    right_on = [translate_named_expr(translator, n=e) for e in node.right_on]

     if (how := node.options[0]) in {
         "Inner",
@@ -341,7 +319,15 @@ def _(
         "Semi",
         "Anti",
     }:
-        return ir.Join(
+        return ir.Join(
+            schema,
+            left_on,
+            right_on,
+            node.options,
+            translator.config_options,
+            inp_left,
+            inp_right,
+        )
     else:
         how, op1, op2 = node.options[0]
         if how != "IEJoin":
@@ -463,6 +449,21 @@ def _(
     return ir.Projection(schema, translator.translate_ir(n=node.input))


+@_translate_ir.register
+def _(
+    node: pl_ir.MergeSorted, translator: Translator, schema: dict[str, plc.DataType]
+) -> ir.IR:
+    key = node.key
+    inp_left = translator.translate_ir(n=node.input_left)
+    inp_right = translator.translate_ir(n=node.input_right)
+    return ir.MergeSorted(
+        schema,
+        key,
+        inp_left,
+        inp_right,
+    )
+
+
 @_translate_ir.register
 def _(
     node: pl_ir.MapFunction, translator: Translator, schema: dict[str, plc.DataType]
@@ -472,7 +473,6 @@ def _(
         schema,
         name,
         options,
-        # TODO: merge_sorted breaks this pattern
         translator.translate_ir(n=node.input),
     )

@@ -623,6 +623,17 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex
         )
     elif name == "pow":
         return expr.BinOp(dtype, plc.binaryop.BinaryOperator.POW, *children)
+    elif name in "top_k":
+        (col, k) = children
+        assert isinstance(k, expr.Literal)
+        (descending,) = options
+        return expr.Slice(
+            dtype,
+            0,
+            k.value.as_py(),
+            expr.Sort(dtype, (False, True, not descending), col),
+        )
+
         return expr.UnaryFunction(dtype, name, options, *children)
     raise NotImplementedError(
         f"No handler for Expr function node with {name=}"
@@ -651,7 +662,10 @@ def _(node: pl_expr.Window, translator: Translator, dtype: plc.DataType) -> expr
 @_translate_expr.register
 def _(node: pl_expr.Literal, translator: Translator, dtype: plc.DataType) -> expr.Expr:
     if isinstance(node.value, plrs.PySeries):
-
+        data = pl.Series._from_pyseries(node.value).to_arrow()
+        return expr.LiteralColumn(
+            dtype, data.cast(dtypes.downcast_arrow_lists(data.type))
+        )
     value = pa.scalar(node.value, type=plc.interop.to_arrow(dtype))
     return expr.Literal(dtype, value)
@@ -673,6 +687,20 @@ def _(node: pl_expr.SortBy, translator: Translator, dtype: plc.DataType) -> expr
     )


+@_translate_expr.register
+def _(node: pl_expr.Slice, translator: Translator, dtype: plc.DataType) -> expr.Expr:
+    offset = translator.translate_expr(n=node.offset)
+    length = translator.translate_expr(n=node.length)
+    assert isinstance(offset, expr.Literal)
+    assert isinstance(length, expr.Literal)
+    return expr.Slice(
+        dtype,
+        offset.value.as_py(),
+        length.value.as_py(),
+        translator.translate_expr(n=node.input),
+    )
+
+
 @_translate_expr.register
 def _(node: pl_expr.Gather, translator: Translator, dtype: plc.DataType) -> expr.Expr:
     return expr.Gather(
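`Translator` now receives the whole `GPUEngine` and snapshots `engine.config` into a `ConfigOptions` wrapper, which is then threaded into `Scan`, `DataFrameScan`, `GroupBy`, and `Join` nodes. Below is a hedged sketch of how such options could be supplied from the polars side; the `executor` name and whether this cudf-polars release accepts these exact keyword arguments are assumptions, only the `executor_options.cardinality_factor` and `executor_options.groupby_n_ary` lookups are confirmed by this diff:

import polars as pl

# Keyword arguments passed to GPUEngine end up in `engine.config`;
# cudf-polars deep-copies that dict into its ConfigOptions during translation.
engine = pl.GPUEngine(
    raise_on_fail=True,  # surface translation failures instead of CPU fallback
    executor="dask-experimental",  # assumed name of the multi-partition executor
    executor_options={
        # Estimated fraction of unique values per group-by key column,
        # read via config_options.get("executor_options.cardinality_factor").
        "cardinality_factor": {"id": 0.01},
        # Fan-in of the group-by tree reduction,
        # read via config_options.get("executor_options.groupby_n_ary").
        "groupby_n_ary": 32,
    },
)

lf = pl.LazyFrame({"id": [1, 1, 2], "x": [1.0, 2.0, 3.0]})
print(lf.group_by("id").agg(pl.col("x").mean()).collect(engine=engine))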
cudf_polars/experimental/base.py
CHANGED
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES.
 # SPDX-License-Identifier: Apache-2.0
 """Multi-partition base classes."""

@@ -9,23 +9,29 @@ from typing import TYPE_CHECKING
 from cudf_polars.dsl.ir import Union

 if TYPE_CHECKING:
-    from collections.abc import Iterator
+    from collections.abc import Iterator

     from cudf_polars.containers import DataFrame
+    from cudf_polars.dsl.expr import NamedExpr
     from cudf_polars.dsl.nodebase import Node


 class PartitionInfo:
-    """
-
-
-
-    """
-
-
-
-    def __init__(
+    """Partitioning information."""
+
+    __slots__ = ("count", "partitioned_on")
+    count: int
+    """Partition count."""
+    partitioned_on: tuple[NamedExpr, ...]
+    """Columns the data is hash-partitioned on."""
+
+    def __init__(
+        self,
+        count: int,
+        partitioned_on: tuple[NamedExpr, ...] = (),
+    ):
         self.count = count
+        self.partitioned_on = partitioned_on

     def keys(self, node: Node) -> Iterator[tuple[str, int]]:
         """Return the partitioned keys for a given node."""
@@ -38,6 +44,6 @@ def get_key_name(node: Node) -> str:
     return f"{type(node).__name__.lower()}-{hash(node)}"


-def _concat(dfs:
+def _concat(*dfs: DataFrame) -> DataFrame:
     # Concatenate a sequence of DataFrames vertically
     return Union.do_evaluate(None, *dfs)
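`PartitionInfo` now tracks both the partition count and the key columns a node is hash-partitioned on, and `keys(node)` expands that into the `(name, index)` task keys used when building the task graph. A small self-contained sketch of that naming scheme (the `FakeNode` and `PartitionInfo` classes here are stand-ins, not the cudf-polars implementations):

class FakeNode:
    """Stand-in for an IR node; only used to derive a task-name prefix."""


def get_key_name(node) -> str:
    # One name per node: lowercase class name plus a hash.
    return f"{type(node).__name__.lower()}-{hash(node)}"


class PartitionInfo:
    """Partition count plus the columns the data is hash-partitioned on."""

    def __init__(self, count: int, partitioned_on: tuple[str, ...] = ()):
        self.count = count
        self.partitioned_on = partitioned_on

    def keys(self, node):
        # Task keys are (task-name, partition-index) tuples.
        name = get_key_name(node)
        return ((name, i) for i in range(self.count))


node = FakeNode()
info = PartitionInfo(count=3, partitioned_on=("id",))
print(list(info.keys(node)))  # [('fakenode-…', 0), ('fakenode-…', 1), ('fakenode-…', 2)]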
cudf_polars/experimental/dask_serialize.py
CHANGED
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
 # SPDX-License-Identifier: Apache-2.0

 """Dask serialization."""
@@ -12,7 +12,7 @@ from distributed.utils import log_errors
 import pylibcudf as plc
 import rmm

-from cudf_polars.containers import DataFrame
+from cudf_polars.containers import Column, DataFrame

 __all__ = ["register"]

@@ -20,8 +20,8 @@ __all__ = ["register"]
 def register() -> None:
     """Register dask serialization routines for DataFrames."""

-    @cuda_serialize.register(DataFrame)
-    def _(x: DataFrame):
+    @cuda_serialize.register((Column, DataFrame))
+    def _(x: DataFrame | Column):
         with log_errors():
             header, frames = x.serialize()
             return header, list(frames)  # Dask expect a list of frames
@@ -29,11 +29,17 @@ def register() -> None:
     @cuda_deserialize.register(DataFrame)
     def _(header, frames):
         with log_errors():
-
-            return DataFrame.deserialize(header,
+            metadata, gpudata = frames
+            return DataFrame.deserialize(header, (metadata, plc.gpumemoryview(gpudata)))

-    @
-    def _(
+    @cuda_deserialize.register(Column)
+    def _(header, frames):
+        with log_errors():
+            metadata, gpudata = frames
+            return Column.deserialize(header, (metadata, plc.gpumemoryview(gpudata)))
+
+    @dask_serialize.register((Column, DataFrame))
+    def _(x: DataFrame | Column):
         with log_errors():
             header, (metadata, gpudata) = x.serialize()

@@ -57,3 +63,11 @@ def register() -> None:
             # Copy the second frame (the gpudata in host memory) back to the gpu
             frames = frames[0], plc.gpumemoryview(rmm.DeviceBuffer.to_device(frames[1]))
             return DataFrame.deserialize(header, frames)
+
+    @dask_deserialize.register(Column)
+    def _(header, frames) -> Column:
+        with log_errors():
+            assert len(frames) == 2
+            # Copy the second frame (the gpudata in host memory) back to the gpu
+            frames = frames[0], plc.gpumemoryview(rmm.DeviceBuffer.to_device(frames[1]))
+            return Column.deserialize(header, frames)
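Both `Column` and `DataFrame` are now registered with Dask's CUDA and pickle-based serialization families; each object is split into a header plus a `(metadata, gpudata)` pair of frames, and the host-memory path copies the data frame back to the GPU on deserialize. A rough standard-library-only sketch of the header/frames pattern (the `Blob` type and its frame layout are illustrative, not the cudf-polars containers):

import pickle


class Blob:
    """Toy container serialized as a header plus two frames."""

    def __init__(self, metadata: dict, payload: bytes):
        self.metadata = metadata
        self.payload = payload

    def serialize(self):
        # The header describes the object; the frames carry the bytes.
        header = {"type": "Blob"}
        frames = (pickle.dumps(self.metadata), self.payload)
        return header, frames

    @classmethod
    def deserialize(cls, header, frames):
        assert header["type"] == "Blob"
        metadata_frame, payload = frames
        return cls(pickle.loads(metadata_frame), payload)


blob = Blob({"name": "x"}, b"\x00\x01\x02")
header, frames = blob.serialize()
restored = Blob.deserialize(header, frames)
assert restored.metadata == blob.metadata and restored.payload == blob.payload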
cudf_polars/experimental/groupby.py
ADDED
@@ -0,0 +1,346 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+"""Parallel GroupBy Logic."""
+
+from __future__ import annotations
+
+import itertools
+import uuid
+from typing import TYPE_CHECKING, Any
+
+import pylibcudf as plc
+
+from cudf_polars.dsl.expr import (
+    Agg,
+    BinOp,
+    Cast,
+    Col,
+    Len,
+    Literal,
+    NamedExpr,
+    UnaryFunction,
+)
+from cudf_polars.dsl.ir import GroupBy, Select
+from cudf_polars.dsl.traversal import traversal
+from cudf_polars.experimental.base import PartitionInfo, _concat, get_key_name
+from cudf_polars.experimental.dispatch import generate_ir_tasks, lower_ir_node
+from cudf_polars.experimental.shuffle import Shuffle
+
+if TYPE_CHECKING:
+    from collections.abc import MutableMapping
+
+    from cudf_polars.dsl.expr import Expr
+    from cudf_polars.dsl.ir import IR
+    from cudf_polars.experimental.parallel import LowerIRTransformer
+
+
+# Supported multi-partition aggregations
+_GB_AGG_SUPPORTED = ("sum", "count", "mean", "min", "max")
+
+
+def combine(
+    *decompositions: tuple[NamedExpr, list[NamedExpr], list[NamedExpr]],
+) -> tuple[list[NamedExpr], list[NamedExpr], list[NamedExpr]]:
+    """
+    Combine multiple groupby-aggregation decompositions.
+
+    Parameters
+    ----------
+    decompositions
+        Packed sequence of `decompose` results.
+
+    Returns
+    -------
+    Unified groupby-aggregation decomposition.
+    """
+    selections, aggregations, reductions = zip(*decompositions, strict=True)
+    assert all(isinstance(ne, NamedExpr) for ne in selections)
+    return (
+        list(selections),
+        list(itertools.chain.from_iterable(aggregations)),
+        list(itertools.chain.from_iterable(reductions)),
+    )
+
+
+def decompose(
+    name: str, expr: Expr
+) -> tuple[NamedExpr, list[NamedExpr], list[NamedExpr]]:
+    """
+    Decompose a groupby-aggregation expression.
+
+    Parameters
+    ----------
+    name
+        Output schema name.
+    expr
+        The aggregation expression for a single column.
+
+    Returns
+    -------
+    NamedExpr
+        The expression selecting the *output* column or columns.
+    list[NamedExpr]
+        The initial aggregation expressions.
+    list[NamedExpr]
+        The reduction expressions.
+    """
+    dtype = expr.dtype
+    expr = expr.children[0] if isinstance(expr, Cast) else expr
+
+    unary_op: list[Any] = []
+    if isinstance(expr, UnaryFunction) and expr.is_pointwise:
+        # TODO: Handle multiple/sequential unary ops
+        unary_op = [expr.name, expr.options]
+        expr = expr.children[0]
+
+    def _wrap_unary(select):
+        # Helper function to wrap the final selection
+        # in a UnaryFunction (when necessary)
+        if unary_op:
+            return UnaryFunction(select.dtype, *unary_op, select)
+        return select
+
+    if isinstance(expr, Len):
+        selection = NamedExpr(name, _wrap_unary(Col(dtype, name)))
+        aggregation = [NamedExpr(name, expr)]
+        reduction = [
+            NamedExpr(
+                name,
+                # Sum reduction may require casting.
+                # Do it for all cases to be safe (for now)
+                Cast(dtype, Agg(dtype, "sum", None, Col(dtype, name))),
+            )
+        ]
+        return selection, aggregation, reduction
+    if isinstance(expr, Agg):
+        if expr.name in ("sum", "count", "min", "max"):
+            if expr.name in ("sum", "count"):
+                aggfunc = "sum"
+            else:
+                aggfunc = expr.name
+            selection = NamedExpr(name, _wrap_unary(Col(dtype, name)))
+            aggregation = [NamedExpr(name, expr)]
+            reduction = [
+                NamedExpr(
+                    name,
+                    # Sum reduction may require casting.
+                    # Do it for all cases to be safe (for now)
+                    Cast(dtype, Agg(dtype, aggfunc, None, Col(dtype, name))),
+                )
+            ]
+            return selection, aggregation, reduction
+        elif expr.name == "mean":
+            (child,) = expr.children
+            token = str(uuid.uuid4().hex)  # prevent collisions with user's names
+            (sum, count), aggregations, reductions = combine(
+                decompose(f"{name}__mean_sum_{token}", Agg(dtype, "sum", None, child)),
+                decompose(f"{name}__mean_count_{token}", Len(dtype)),
+            )
+            selection = NamedExpr(
+                name,
+                _wrap_unary(
+                    BinOp(
+                        dtype, plc.binaryop.BinaryOperator.DIV, sum.value, count.value
+                    )
+                ),
+            )
+            return selection, aggregations, reductions
+        else:
+            raise NotImplementedError(
+                "GroupBy does not support multiple partitions "
+                f"for this aggregation type:\n{type(expr)}\n"
+                f"Only {_GB_AGG_SUPPORTED} are supported."
+            )
+    elif isinstance(expr, BinOp):
+        # The expectation is that each operand of the BinOp is decomposable.
+        # We can then combine the decompositions of the operands to form the
+        # decomposition of the BinOp.
+        (left, right) = expr.children
+        token = str(uuid.uuid4().hex)  # prevent collisions with user's names
+        (left_selection, right_selection), aggregations, reductions = combine(
+            decompose(f"{name}__left_{token}", left),
+            decompose(f"{name}__right_{token}", right),
+        )
+
+        selection = NamedExpr(
+            name,
+            _wrap_unary(
+                BinOp(dtype, expr.op, left_selection.value, right_selection.value)
+            ),
+        )
+        return selection, aggregations, reductions
+
+    elif isinstance(expr, Literal):
+        selection = NamedExpr(name, _wrap_unary(Col(dtype, name)))
+        aggregation = []
+        reduction = [NamedExpr(name, expr)]
+        return selection, aggregation, reduction
+
+    else:  # pragma: no cover
+        # Unsupported expression
+        raise NotImplementedError(
+            f"GroupBy does not support multiple partitions for this expression:\n{expr}"
+        )
+
+
+@lower_ir_node.register(GroupBy)
+def _(
+    ir: GroupBy, rec: LowerIRTransformer
+) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
+    # Extract child partitioning
+    child, partition_info = rec(ir.children[0])
+
+    # Handle single-partition case
+    if partition_info[child].count == 1:
+        single_part_node = ir.reconstruct([child])
+        partition_info[single_part_node] = partition_info[child]
+        return single_part_node, partition_info
+
+    # Check group-by keys
+    if not all(expr.is_pointwise for expr in traversal([e.value for e in ir.keys])):
+        raise NotImplementedError(
+            f"GroupBy does not support multiple partitions for keys:\n{ir.keys}"
+        )  # pragma: no cover
+
+    # Check if we are dealing with any high-cardinality columns
+    post_aggregation_count = 1  # Default tree reduction
+    groupby_key_columns = [ne.name for ne in ir.keys]
+    cardinality_factor = {
+        c: min(f, 1.0)
+        for c, f in ir.config_options.get(
+            "executor_options.cardinality_factor", default={}
+        ).items()
+        if c in groupby_key_columns
+    }
+    if cardinality_factor:
+        # The `cardinality_factor` dictionary can be used
+        # to specify a mapping between column names and
+        # cardinality "factors". Each factor estimates the
+        # fractional number of unique values in the column.
+        # Each value should be in the range (0, 1].
+        child_count = partition_info[child].count
+        post_aggregation_count = max(
+            int(max(cardinality_factor.values()) * child_count),
+            1,
+        )
+
+    # Decompose the aggregation requests into three distinct phases
+    selection_exprs, piecewise_exprs, reduction_exprs = combine(
+        *(decompose(agg.name, agg.value) for agg in ir.agg_requests)
+    )
+
+    # Partition-wise groupby operation
+    pwise_schema = {k.name: k.value.dtype for k in ir.keys} | {
+        k.name: k.value.dtype for k in piecewise_exprs
+    }
+    gb_pwise = GroupBy(
+        pwise_schema,
+        ir.keys,
+        piecewise_exprs,
+        ir.maintain_order,
+        ir.options,
+        ir.config_options,
+        child,
+    )
+    child_count = partition_info[child].count
+    partition_info[gb_pwise] = PartitionInfo(count=child_count)
+
+    # Add Shuffle node if necessary
+    gb_inter: GroupBy | Shuffle = gb_pwise
+    if post_aggregation_count > 1:
+        if ir.maintain_order:  # pragma: no cover
+            raise NotImplementedError(
+                "maintain_order not supported for multiple output partitions."
+            )
+
+        gb_inter = Shuffle(
+            pwise_schema,
+            ir.keys,
+            ir.config_options,
+            gb_pwise,
+        )
+        partition_info[gb_inter] = PartitionInfo(count=post_aggregation_count)
+
+    # Tree reduction if post_aggregation_count==1
+    # (Otherwise, this is another partition-wise op)
+    gb_reduce = GroupBy(
+        {k.name: k.value.dtype for k in ir.keys}
+        | {k.name: k.value.dtype for k in reduction_exprs},
+        ir.keys,
+        reduction_exprs,
+        ir.maintain_order,
+        ir.options,
+        ir.config_options,
+        gb_inter,
+    )
+    partition_info[gb_reduce] = PartitionInfo(count=post_aggregation_count)
+
+    # Final Select phase
+    aggregated = {ne.name: ne for ne in selection_exprs}
+    new_node = Select(
+        ir.schema,
+        [
+            # Select the aggregated data or the original column
+            aggregated.get(name, NamedExpr(name, Col(dtype, name)))
+            for name, dtype in ir.schema.items()
+        ],
+        False,  # noqa: FBT003
+        gb_reduce,
+    )
+    partition_info[new_node] = PartitionInfo(count=post_aggregation_count)
+    return new_node, partition_info
+
+
+def _tree_node(do_evaluate, nbatch, *args):
+    return do_evaluate(*args[nbatch:], _concat(*args[:nbatch]))
+
+
+@generate_ir_tasks.register(GroupBy)
+def _(
+    ir: GroupBy, partition_info: MutableMapping[IR, PartitionInfo]
+) -> MutableMapping[Any, Any]:
+    (child,) = ir.children
+    child_count = partition_info[child].count
+    child_name = get_key_name(child)
+    output_count = partition_info[ir].count
+
+    if output_count == child_count:
+        return {
+            key: (
+                ir.do_evaluate,
+                *ir._non_child_args,
+                (child_name, i),
+            )
+            for i, key in enumerate(partition_info[ir].keys(ir))
+        }
+    elif output_count != 1:  # pragma: no cover
+        raise ValueError(f"Expected single partition, got {output_count}")
+
+    # Simple N-ary tree reduction
+    j = 0
+    n_ary = ir.config_options.get("executor_options.groupby_n_ary", default=32)
+    graph: MutableMapping[Any, Any] = {}
+    name = get_key_name(ir)
+    keys: list[Any] = [(child_name, i) for i in range(child_count)]
+    while len(keys) > n_ary:
+        new_keys: list[Any] = []
+        for i, k in enumerate(range(0, len(keys), n_ary)):
+            batch = keys[k : k + n_ary]
+            graph[(name, j, i)] = (
+                _tree_node,
+                ir.do_evaluate,
+                len(batch),
+                *batch,
+                *ir._non_child_args,
+            )
+            new_keys.append((name, j, i))
+        j += 1
+        keys = new_keys
+    graph[(name, 0)] = (
+        _tree_node,
+        ir.do_evaluate,
+        len(keys),
+        *keys,
+        *ir._non_child_args,
+    )
+    return graph
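The new module's core idea is to split an aggregation into a partition-wise phase, a cross-partition reduction, and a final selection: `mean`, for instance, becomes per-partition `sum` and `count` columns, a global sum of both, and a closing divide (shuffled or tree-reduced as shown above). A minimal pure-Python sketch of that three-phase plan (illustrative only; the real implementation builds cudf-polars IR nodes and task graphs):

# Three-phase mean over "partitions" of a single column.
partitions = [[1.0, 2.0, 3.0], [4.0, 5.0], [6.0]]

# Phase 1: partition-wise aggregation (sum and count per partition).
piecewise = [(sum(p), len(p)) for p in partitions]

# Phase 2: reduction across partitions (sum of sums, sum of counts).
total_sum = sum(s for s, _ in piecewise)
total_count = sum(c for _, c in piecewise)

# Phase 3: final selection combines the reduced columns (sum / count).
mean = total_sum / total_count
print(mean)  # 3.5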