cudf_polars_cu13-25.10.0-py3-none-any.whl

This diff shows the content of publicly available package versions as released to their public registries. It is provided for informational purposes only and reflects the changes between package versions.
Files changed (92)
  1. cudf_polars/GIT_COMMIT +1 -0
  2. cudf_polars/VERSION +1 -0
  3. cudf_polars/__init__.py +28 -0
  4. cudf_polars/_version.py +21 -0
  5. cudf_polars/callback.py +318 -0
  6. cudf_polars/containers/__init__.py +13 -0
  7. cudf_polars/containers/column.py +495 -0
  8. cudf_polars/containers/dataframe.py +361 -0
  9. cudf_polars/containers/datatype.py +137 -0
  10. cudf_polars/dsl/__init__.py +8 -0
  11. cudf_polars/dsl/expr.py +66 -0
  12. cudf_polars/dsl/expressions/__init__.py +8 -0
  13. cudf_polars/dsl/expressions/aggregation.py +226 -0
  14. cudf_polars/dsl/expressions/base.py +272 -0
  15. cudf_polars/dsl/expressions/binaryop.py +120 -0
  16. cudf_polars/dsl/expressions/boolean.py +326 -0
  17. cudf_polars/dsl/expressions/datetime.py +271 -0
  18. cudf_polars/dsl/expressions/literal.py +97 -0
  19. cudf_polars/dsl/expressions/rolling.py +643 -0
  20. cudf_polars/dsl/expressions/selection.py +74 -0
  21. cudf_polars/dsl/expressions/slicing.py +46 -0
  22. cudf_polars/dsl/expressions/sorting.py +85 -0
  23. cudf_polars/dsl/expressions/string.py +1002 -0
  24. cudf_polars/dsl/expressions/struct.py +137 -0
  25. cudf_polars/dsl/expressions/ternary.py +49 -0
  26. cudf_polars/dsl/expressions/unary.py +517 -0
  27. cudf_polars/dsl/ir.py +2607 -0
  28. cudf_polars/dsl/nodebase.py +164 -0
  29. cudf_polars/dsl/to_ast.py +359 -0
  30. cudf_polars/dsl/tracing.py +16 -0
  31. cudf_polars/dsl/translate.py +939 -0
  32. cudf_polars/dsl/traversal.py +224 -0
  33. cudf_polars/dsl/utils/__init__.py +8 -0
  34. cudf_polars/dsl/utils/aggregations.py +481 -0
  35. cudf_polars/dsl/utils/groupby.py +98 -0
  36. cudf_polars/dsl/utils/naming.py +34 -0
  37. cudf_polars/dsl/utils/replace.py +61 -0
  38. cudf_polars/dsl/utils/reshape.py +74 -0
  39. cudf_polars/dsl/utils/rolling.py +121 -0
  40. cudf_polars/dsl/utils/windows.py +192 -0
  41. cudf_polars/experimental/__init__.py +8 -0
  42. cudf_polars/experimental/base.py +386 -0
  43. cudf_polars/experimental/benchmarks/__init__.py +4 -0
  44. cudf_polars/experimental/benchmarks/pdsds.py +220 -0
  45. cudf_polars/experimental/benchmarks/pdsds_queries/__init__.py +4 -0
  46. cudf_polars/experimental/benchmarks/pdsds_queries/q1.py +88 -0
  47. cudf_polars/experimental/benchmarks/pdsds_queries/q10.py +225 -0
  48. cudf_polars/experimental/benchmarks/pdsds_queries/q2.py +244 -0
  49. cudf_polars/experimental/benchmarks/pdsds_queries/q3.py +65 -0
  50. cudf_polars/experimental/benchmarks/pdsds_queries/q4.py +359 -0
  51. cudf_polars/experimental/benchmarks/pdsds_queries/q5.py +462 -0
  52. cudf_polars/experimental/benchmarks/pdsds_queries/q6.py +92 -0
  53. cudf_polars/experimental/benchmarks/pdsds_queries/q7.py +79 -0
  54. cudf_polars/experimental/benchmarks/pdsds_queries/q8.py +524 -0
  55. cudf_polars/experimental/benchmarks/pdsds_queries/q9.py +137 -0
  56. cudf_polars/experimental/benchmarks/pdsh.py +814 -0
  57. cudf_polars/experimental/benchmarks/utils.py +832 -0
  58. cudf_polars/experimental/dask_registers.py +200 -0
  59. cudf_polars/experimental/dispatch.py +156 -0
  60. cudf_polars/experimental/distinct.py +197 -0
  61. cudf_polars/experimental/explain.py +157 -0
  62. cudf_polars/experimental/expressions.py +590 -0
  63. cudf_polars/experimental/groupby.py +327 -0
  64. cudf_polars/experimental/io.py +943 -0
  65. cudf_polars/experimental/join.py +391 -0
  66. cudf_polars/experimental/parallel.py +423 -0
  67. cudf_polars/experimental/repartition.py +69 -0
  68. cudf_polars/experimental/scheduler.py +155 -0
  69. cudf_polars/experimental/select.py +188 -0
  70. cudf_polars/experimental/shuffle.py +354 -0
  71. cudf_polars/experimental/sort.py +609 -0
  72. cudf_polars/experimental/spilling.py +151 -0
  73. cudf_polars/experimental/statistics.py +795 -0
  74. cudf_polars/experimental/utils.py +169 -0
  75. cudf_polars/py.typed +0 -0
  76. cudf_polars/testing/__init__.py +8 -0
  77. cudf_polars/testing/asserts.py +448 -0
  78. cudf_polars/testing/io.py +122 -0
  79. cudf_polars/testing/plugin.py +236 -0
  80. cudf_polars/typing/__init__.py +219 -0
  81. cudf_polars/utils/__init__.py +8 -0
  82. cudf_polars/utils/config.py +741 -0
  83. cudf_polars/utils/conversion.py +40 -0
  84. cudf_polars/utils/dtypes.py +118 -0
  85. cudf_polars/utils/sorting.py +53 -0
  86. cudf_polars/utils/timer.py +39 -0
  87. cudf_polars/utils/versions.py +27 -0
  88. cudf_polars_cu13-25.10.0.dist-info/METADATA +136 -0
  89. cudf_polars_cu13-25.10.0.dist-info/RECORD +92 -0
  90. cudf_polars_cu13-25.10.0.dist-info/WHEEL +5 -0
  91. cudf_polars_cu13-25.10.0.dist-info/licenses/LICENSE +201 -0
  92. cudf_polars_cu13-25.10.0.dist-info/top_level.txt +1 -0
cudf_polars/dsl/expressions/struct.py
@@ -0,0 +1,137 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+# TODO: Document StructFunction to remove noqa
+# ruff: noqa: D101
+"""Struct DSL nodes."""
+
+from __future__ import annotations
+
+from enum import IntEnum, auto
+from io import StringIO
+from typing import TYPE_CHECKING, Any, ClassVar
+
+import pylibcudf as plc
+
+from cudf_polars.containers import Column
+from cudf_polars.dsl.expressions.base import ExecutionContext, Expr
+
+if TYPE_CHECKING:
+    from typing_extensions import Self
+
+    from polars.polars import _expr_nodes as pl_expr
+
+    from cudf_polars.containers import DataFrame, DataType
+
+__all__ = ["StructFunction"]
+
+
+class StructFunction(Expr):
+    class Name(IntEnum):
+        """Internal and picklable representation of polars' `StructFunction`."""
+
+        FieldByName = auto()
+        RenameFields = auto()
+        PrefixFields = auto()
+        SuffixFields = auto()
+        JsonEncode = auto()
+        WithFields = auto()  # TODO: https://github.com/rapidsai/cudf/issues/19284
+        MapFieldNames = auto()  # TODO: https://github.com/rapidsai/cudf/issues/19285
+        FieldByIndex = auto()
+        MultipleFields = (
+            auto()
+        )  # https://github.com/pola-rs/polars/pull/23022#issuecomment-2933910958
+
+        @classmethod
+        def from_polars(cls, obj: pl_expr.StructFunction) -> Self:
+            """Convert from polars' `StructFunction`."""
+            try:
+                function, name = str(obj).split(".", maxsplit=1)
+            except ValueError:
+                # Failed to unpack string
+                function = None
+            if function != "StructFunction":
+                raise ValueError("StructFunction required")
+            return getattr(cls, name)
+
+    __slots__ = ("name", "options")
+    _non_child = ("dtype", "name", "options")
+
+    _supported_ops: ClassVar[set[Name]] = {
+        Name.FieldByName,
+        Name.RenameFields,
+        Name.PrefixFields,
+        Name.SuffixFields,
+        Name.JsonEncode,
+    }
+
+    def __init__(
+        self,
+        dtype: DataType,
+        name: StructFunction.Name,
+        options: tuple[Any, ...],
+        *children: Expr,
+    ) -> None:
+        self.dtype = dtype
+        self.options = options
+        self.name = name
+        self.children = children
+        self.is_pointwise = True
+        if self.name not in self._supported_ops:
+            raise NotImplementedError(
+                f"Struct function {self.name}"
+            )  # pragma: no cover
+
+    def do_evaluate(
+        self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
+    ) -> Column:
+        """Evaluate this expression given a dataframe for context."""
+        columns = [child.evaluate(df, context=context) for child in self.children]
+        (column,) = columns
+        if self.name == StructFunction.Name.FieldByName:
+            field_index = next(
+                (
+                    i
+                    for i, field in enumerate(self.children[0].dtype.polars.fields)
+                    if field.name == self.options[0]
+                ),
+                None,
+            )
+            assert field_index is not None
+            return Column(
+                column.obj.children()[field_index],
+                dtype=self.dtype,
+            )
+        elif self.name == StructFunction.Name.JsonEncode:
+            # Once https://github.com/rapidsai/cudf/issues/19338 is implemented,
+            # we can avoid doing this conversion on the host.
+            buff = StringIO()
+            target = plc.io.SinkInfo([buff])
+            table = plc.Table(column.obj.children())
+            metadata = plc.io.TableWithMetadata(
+                table,
+                [(field.name, []) for field in self.children[0].dtype.polars.fields],
+            )
+            options = (
+                plc.io.json.JsonWriterOptions.builder(target, table)
+                .lines(val=True)
+                .na_rep("null")
+                .include_nulls(val=True)
+                .metadata(metadata)
+                .utf8_escaped(val=False)
+                .build()
+            )
+            plc.io.json.write_json(options)
+            return Column(
+                plc.Column.from_iterable_of_py(buff.getvalue().split()),
+                dtype=self.dtype,
+            )
+        elif self.name in {
+            StructFunction.Name.RenameFields,
+            StructFunction.Name.PrefixFields,
+            StructFunction.Name.SuffixFields,
+        }:
+            return column
+        else:
+            raise NotImplementedError(
+                f"Struct function {self.name}"
+            )  # pragma: no cover
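
The StructFunction node above backs polars' struct namespace on the GPU engine. A minimal usage sketch (not part of the wheel; it assumes a CUDA-capable GPU with this package installed alongside a GPU-enabled polars, so that collect(engine="gpu") routes through cudf_polars):

    import polars as pl

    # Struct column inferred from dicts; each expression below maps to one
    # of the supported StructFunction.Name variants.
    lf = pl.LazyFrame({"s": [{"a": 1, "b": "x"}, {"a": 2, "b": "y"}]})
    out = lf.select(
        pl.col("s").struct.field("a"),                            # FieldByName
        pl.col("s").struct.rename_fields(["c", "d"]).alias("r"),  # RenameFields
        pl.col("s").struct.json_encode().alias("j"),              # JsonEncode
    ).collect(engine="gpu")
    print(out)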
cudf_polars/dsl/expressions/ternary.py
@@ -0,0 +1,49 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+# TODO: remove need for this
+# ruff: noqa: D101
+"""DSL nodes for ternary operations."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import pylibcudf as plc
+
+from cudf_polars.containers import Column
+from cudf_polars.dsl.expressions.base import (
+    ExecutionContext,
+    Expr,
+)
+
+if TYPE_CHECKING:
+    from cudf_polars.containers import DataFrame, DataType
+
+
+__all__ = ["Ternary"]
+
+
+class Ternary(Expr):
+    __slots__ = ()
+    _non_child = ("dtype",)
+
+    def __init__(
+        self, dtype: DataType, when: Expr, then: Expr, otherwise: Expr
+    ) -> None:
+        self.dtype = dtype
+        self.children = (when, then, otherwise)
+        self.is_pointwise = True
+
+    def do_evaluate(
+        self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
+    ) -> Column:
+        """Evaluate this expression given a dataframe for context."""
+        when, then, otherwise = (
+            child.evaluate(df, context=context) for child in self.children
+        )
+        then_obj = then.obj_scalar if then.is_scalar else then.obj
+        otherwise_obj = otherwise.obj_scalar if otherwise.is_scalar else otherwise.obj
+        return Column(
+            plc.copying.copy_if_else(then_obj, otherwise_obj, when.obj),
+            dtype=self.dtype,
+        )
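
Ternary is the lowering of polars' when/then/otherwise: the three evaluated children feed a single plc.copying.copy_if_else, with scalar branches broadcast via obj_scalar. A minimal sketch (same GPU-install assumptions as above):

    import polars as pl

    lf = pl.LazyFrame({"x": [1, 2, 3, 4]})
    out = lf.select(
        # Lowers to one Ternary node, i.e. one copy_if_else call.
        pl.when(pl.col("x") > 2).then(pl.col("x") * 10).otherwise(0).alias("y")
    ).collect(engine="gpu")
    print(out)  # y: [0, 0, 30, 40]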
cudf_polars/dsl/expressions/unary.py
@@ -0,0 +1,517 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+# TODO: remove need for this
+"""DSL nodes for unary operations."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, ClassVar, cast
+
+from typing_extensions import assert_never
+
+import pylibcudf as plc
+
+from cudf_polars.containers import Column
+from cudf_polars.dsl.expressions.base import ExecutionContext, Expr
+from cudf_polars.dsl.expressions.literal import Literal
+from cudf_polars.utils import dtypes
+from cudf_polars.utils.versions import POLARS_VERSION_LT_129
+
+if TYPE_CHECKING:
+    from cudf_polars.containers import DataFrame, DataType
+
+__all__ = ["Cast", "Len", "UnaryFunction"]
+
+
+class Cast(Expr):
+    """Class representing a cast of an expression."""
+
+    __slots__ = ()
+    _non_child = ("dtype",)
+
+    def __init__(self, dtype: DataType, value: Expr) -> None:
+        self.dtype = dtype
+        self.children = (value,)
+        self.is_pointwise = True
+        if not dtypes.can_cast(value.dtype.plc, self.dtype.plc):
+            raise NotImplementedError(
+                f"Can't cast {value.dtype.id().name} to {self.dtype.id().name}"
+            )
+
+    def do_evaluate(
+        self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
+    ) -> Column:
+        """Evaluate this expression given a dataframe for context."""
+        (child,) = self.children
+        column = child.evaluate(df, context=context)
+        return column.astype(self.dtype)
+
+
+class Len(Expr):
+    """Class representing the length of an expression."""
+
+    def __init__(self, dtype: DataType) -> None:
+        self.dtype = dtype
+        self.children = ()
+        self.is_pointwise = False
+
+    def do_evaluate(
+        self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
+    ) -> Column:
+        """Evaluate this expression given a dataframe for context."""
+        return Column(
+            plc.Column.from_scalar(
+                plc.Scalar.from_py(df.num_rows, self.dtype.plc),
+                1,
+            ),
+            dtype=self.dtype,
+        )
+
+    @property
+    def agg_request(self) -> plc.aggregation.Aggregation:  # noqa: D102
+        return plc.aggregation.count(plc.types.NullPolicy.INCLUDE)
+
+
+class UnaryFunction(Expr):
+    """Class representing unary functions of an expression."""
+
+    __slots__ = ("name", "options")
+    _non_child = ("dtype", "name", "options")
+
+    # Note: log and pow are handled via translation to binops
+    _OP_MAPPING: ClassVar[dict[str, plc.unary.UnaryOperator]] = {
+        "sin": plc.unary.UnaryOperator.SIN,
+        "cos": plc.unary.UnaryOperator.COS,
+        "tan": plc.unary.UnaryOperator.TAN,
+        "arcsin": plc.unary.UnaryOperator.ARCSIN,
+        "arccos": plc.unary.UnaryOperator.ARCCOS,
+        "arctan": plc.unary.UnaryOperator.ARCTAN,
+        "sinh": plc.unary.UnaryOperator.SINH,
+        "cosh": plc.unary.UnaryOperator.COSH,
+        "tanh": plc.unary.UnaryOperator.TANH,
+        "arcsinh": plc.unary.UnaryOperator.ARCSINH,
+        "arccosh": plc.unary.UnaryOperator.ARCCOSH,
+        "arctanh": plc.unary.UnaryOperator.ARCTANH,
+        "exp": plc.unary.UnaryOperator.EXP,
+        "sqrt": plc.unary.UnaryOperator.SQRT,
+        "cbrt": plc.unary.UnaryOperator.CBRT,
+        "ceil": plc.unary.UnaryOperator.CEIL,
+        "floor": plc.unary.UnaryOperator.FLOOR,
+        "abs": plc.unary.UnaryOperator.ABS,
+        "bit_invert": plc.unary.UnaryOperator.BIT_INVERT,
+        "not": plc.unary.UnaryOperator.NOT,
+        "negate": plc.unary.UnaryOperator.NEGATE,
+    }
+    _supported_misc_fns = frozenset(
+        {
+            "as_struct",
+            "drop_nulls",
+            "fill_null",
+            "fill_null_with_strategy",
+            "mask_nans",
+            "null_count",
+            "rank",
+            "round",
+            "set_sorted",
+            "top_k",
+            "unique",
+            "value_counts",
+        }
+    )
+    _supported_cum_aggs = frozenset(
+        {
+            "cum_min",
+            "cum_max",
+            "cum_prod",
+            "cum_sum",
+        }
+    )
+    _supported_fns = frozenset().union(
+        _supported_misc_fns, _supported_cum_aggs, _OP_MAPPING.keys()
+    )
+
+    def __init__(
+        self, dtype: DataType, name: str, options: tuple[Any, ...], *children: Expr
+    ) -> None:
+        self.dtype = dtype
+        self.name = name
+        self.options = options
+        self.children = children
+        self.is_pointwise = self.name not in (
+            "as_struct",
+            "cum_max",
+            "cum_min",
+            "cum_prod",
+            "cum_sum",
+            "drop_nulls",
+            "rank",
+            "top_k",
+            "unique",
+        )
+
+        if self.name not in UnaryFunction._supported_fns:
+            raise NotImplementedError(f"Unary function {name=}")
+        if self.name in UnaryFunction._supported_cum_aggs:
+            (reverse,) = self.options
+            if reverse:
+                raise NotImplementedError(
+                    "reverse=True is not supported for cumulative aggregations"
+                )
+        if self.name == "fill_null_with_strategy" and self.options[1] not in {0, None}:
+            raise NotImplementedError(
+                "Filling null values with limit specified is not yet supported."
+            )
+        if self.name == "rank":
+            method, _, _ = self.options
+            if method not in {"average", "min", "max", "dense", "ordinal"}:
+                raise NotImplementedError(
+                    f"ranking with {method=} is not yet supported"
+                )
+
+    def do_evaluate(
+        self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
+    ) -> Column:
+        """Evaluate this expression given a dataframe for context."""
+        if self.name == "mask_nans":
+            (child,) = self.children
+            return child.evaluate(df, context=context).mask_nans()
+        if self.name == "null_count":
+            (column,) = (child.evaluate(df, context=context) for child in self.children)
+            return Column(
+                plc.Column.from_scalar(
+                    plc.Scalar.from_py(column.null_count, self.dtype.plc),
+                    1,
+                ),
+                dtype=self.dtype,
+            )
+        if self.name == "round":
+            round_mode = "half_away_from_zero"
+            if POLARS_VERSION_LT_129:
+                (decimal_places,) = self.options  # pragma: no cover
+            else:
+                # pragma: no cover
+                (
+                    decimal_places,
+                    round_mode,
+                ) = self.options
+            (values,) = (child.evaluate(df, context=context) for child in self.children)
+            return Column(
+                plc.round.round(
+                    values.obj,
+                    decimal_places,
+                    (
+                        plc.round.RoundingMethod.HALF_EVEN
+                        if round_mode == "half_to_even"
+                        else plc.round.RoundingMethod.HALF_UP
+                    ),
+                ),
+                dtype=self.dtype,
+            ).sorted_like(values)  # pragma: no cover
+        elif self.name == "unique":
+            (maintain_order,) = self.options
+            (values,) = (child.evaluate(df, context=context) for child in self.children)
+            # Only one column, so keep_any is the same as keep_first
+            # for stable distinct
+            keep = plc.stream_compaction.DuplicateKeepOption.KEEP_ANY
+            if values.is_sorted:
+                maintain_order = True
+                result = plc.stream_compaction.unique(
+                    plc.Table([values.obj]),
+                    [0],
+                    keep,
+                    plc.types.NullEquality.EQUAL,
+                )
+            else:
+                distinct = (
+                    plc.stream_compaction.stable_distinct
+                    if maintain_order
+                    else plc.stream_compaction.distinct
+                )
+                result = distinct(
+                    plc.Table([values.obj]),
+                    [0],
+                    keep,
+                    plc.types.NullEquality.EQUAL,
+                    plc.types.NanEquality.ALL_EQUAL,
+                )
+            (column,) = result.columns()
+            result = Column(column, dtype=self.dtype)
+            if maintain_order:
+                result = result.sorted_like(values)
+            return result
+        elif self.name == "set_sorted":
+            (column,) = (child.evaluate(df, context=context) for child in self.children)
+            (asc,) = self.options
+            order = (
+                plc.types.Order.ASCENDING
+                if asc == "ascending"
+                else plc.types.Order.DESCENDING
+            )
+            null_order = plc.types.NullOrder.BEFORE
+            if column.null_count > 0 and (n := column.size) > 1:
+                # PERF: This invokes four stream synchronisations!
+                has_nulls_first = not plc.copying.get_element(column.obj, 0).is_valid()
+                has_nulls_last = not plc.copying.get_element(
+                    column.obj, n - 1
+                ).is_valid()
+                if (order == plc.types.Order.DESCENDING and has_nulls_first) or (
+                    order == plc.types.Order.ASCENDING and has_nulls_last
+                ):
+                    null_order = plc.types.NullOrder.AFTER
+            return column.set_sorted(
+                is_sorted=plc.types.Sorted.YES,
+                order=order,
+                null_order=null_order,
+            )
+        elif self.name == "value_counts":
+            (sort, _, _, normalize) = self.options
+            count_agg = [plc.aggregation.count(plc.types.NullPolicy.INCLUDE)]
+            gb_requests = [
+                plc.groupby.GroupByRequest(
+                    child.evaluate(df, context=context).obj, count_agg
+                )
+                for child in self.children
+            ]
+            (keys_table, (counts_table,)) = plc.groupby.GroupBy(
+                df.table, null_handling=plc.types.NullPolicy.INCLUDE
+            ).aggregate(gb_requests)
+            if sort:
+                sort_indices = plc.sorting.stable_sorted_order(
+                    counts_table,
+                    [plc.types.Order.DESCENDING],
+                    [plc.types.NullOrder.BEFORE],
+                )
+                counts_table = plc.copying.gather(
+                    counts_table, sort_indices, plc.copying.OutOfBoundsPolicy.DONT_CHECK
+                )
+                keys_table = plc.copying.gather(
+                    keys_table, sort_indices, plc.copying.OutOfBoundsPolicy.DONT_CHECK
+                )
+            keys_col = keys_table.columns()[0]
+            counts_col = counts_table.columns()[0]
+            if normalize:
+                total_counts = plc.reduce.reduce(
+                    counts_col, plc.aggregation.sum(), plc.DataType(plc.TypeId.UINT64)
+                )
+                counts_col = plc.binaryop.binary_operation(
+                    counts_col,
+                    total_counts,
+                    plc.binaryop.BinaryOperator.DIV,
+                    plc.DataType(plc.TypeId.FLOAT64),
+                )
+            elif counts_col.type().id() == plc.TypeId.INT32:
+                counts_col = plc.unary.cast(counts_col, plc.DataType(plc.TypeId.UINT32))
+
+            plc_column = plc.Column(
+                self.dtype.plc,
+                counts_col.size(),
+                None,
+                None,
+                0,
+                0,
+                [keys_col, counts_col],
+            )
+            return Column(plc_column, dtype=self.dtype)
+        elif self.name == "drop_nulls":
+            (column,) = (child.evaluate(df, context=context) for child in self.children)
+            if column.null_count == 0:
+                return column
+            return Column(
+                plc.stream_compaction.drop_nulls(
+                    plc.Table([column.obj]), [0], 1
+                ).columns()[0],
+                dtype=self.dtype,
+            )
+        elif self.name == "fill_null":
+            column = self.children[0].evaluate(df, context=context)
+            if column.null_count == 0:
+                return column
+            fill_value = self.children[1]
+            if isinstance(fill_value, Literal):
+                arg = plc.Scalar.from_py(fill_value.value, fill_value.dtype.plc)
+            else:
+                evaluated = fill_value.evaluate(df, context=context)
+                arg = evaluated.obj_scalar if evaluated.is_scalar else evaluated.obj
+            if isinstance(arg, plc.Scalar) and dtypes.can_cast(
+                column.dtype.plc, arg.type()
+            ):  # pragma: no cover
+                arg = (
+                    Column(plc.Column.from_scalar(arg, 1), dtype=fill_value.dtype)
+                    .astype(column.dtype)
+                    .obj.to_scalar()
+                )
+            return Column(plc.replace.replace_nulls(column.obj, arg), dtype=self.dtype)
+        elif self.name == "fill_null_with_strategy":
+            column = self.children[0].evaluate(df, context=context)
+            strategy, limit = self.options
+            if (
+                column.null_count == 0
+                or limit == 0
+                or (
+                    column.null_count == column.size and strategy not in {"zero", "one"}
+                )
+            ):
+                return column
+            if strategy == "forward":
+                replacement = plc.replace.ReplacePolicy.PRECEDING
+            elif strategy == "backward":
+                replacement = plc.replace.ReplacePolicy.FOLLOWING
+            elif strategy == "min":
+                replacement = plc.reduce.reduce(
+                    column.obj,
+                    plc.aggregation.min(),
+                    column.dtype.plc,
+                )
+            elif strategy == "max":
+                replacement = plc.reduce.reduce(
+                    column.obj,
+                    plc.aggregation.max(),
+                    column.dtype.plc,
+                )
+            elif strategy == "mean":
+                replacement = plc.reduce.reduce(
+                    column.obj,
+                    plc.aggregation.mean(),
+                    plc.DataType(plc.TypeId.FLOAT64),
+                )
+            elif strategy == "zero":
+                replacement = plc.scalar.Scalar.from_py(0, dtype=column.dtype.plc)
+            elif strategy == "one":
+                replacement = plc.scalar.Scalar.from_py(1, dtype=column.dtype.plc)
+            else:
+                assert_never(strategy)  # pragma: no cover
+
+            if strategy == "mean":
+                return Column(
+                    plc.replace.replace_nulls(
+                        plc.unary.cast(column.obj, plc.DataType(plc.TypeId.FLOAT64)),
+                        replacement,
+                    ),
+                    dtype=self.dtype,
+                ).astype(self.dtype)
+            return Column(
+                plc.replace.replace_nulls(column.obj, replacement),
+                dtype=self.dtype,
+            )
+        elif self.name == "as_struct":
+            children = [
+                child.evaluate(df, context=context).obj for child in self.children
+            ]
+            return Column(
+                plc.Column(
+                    data_type=self.dtype.plc,
+                    size=children[0].size(),
+                    data=None,
+                    mask=None,
+                    null_count=0,
+                    offset=0,
+                    children=children,
+                ),
+                dtype=self.dtype,
+            )
+        elif self.name == "rank":
+            (column,) = (child.evaluate(df, context=context) for child in self.children)
+            method_str, descending, _ = self.options
+
+            method = {
+                "average": plc.aggregation.RankMethod.AVERAGE,
+                "min": plc.aggregation.RankMethod.MIN,
+                "max": plc.aggregation.RankMethod.MAX,
+                "dense": plc.aggregation.RankMethod.DENSE,
+                "ordinal": plc.aggregation.RankMethod.FIRST,
+            }[method_str]
+
+            order = (
+                plc.types.Order.DESCENDING if descending else plc.types.Order.ASCENDING
+            )
+
+            ranked: plc.Column = plc.sorting.rank(
+                column.obj,
+                method,
+                order,
+                plc.types.NullPolicy.EXCLUDE,
+                plc.types.NullOrder.BEFORE if descending else plc.types.NullOrder.AFTER,
+                percentage=False,
+            )
+
+            # Min/Max/Dense/Ordinal -> IDX_DTYPE
+            # See https://github.com/pola-rs/polars/blob/main/crates/polars-ops/src/series/ops/rank.rs
+            if method_str in {"min", "max", "dense", "ordinal"}:
+                dest = self.dtype.plc.id()
+                src = ranked.type().id()
+                if dest == plc.TypeId.UINT32 and src != plc.TypeId.UINT32:
+                    ranked = plc.unary.cast(ranked, plc.DataType(plc.TypeId.UINT32))
+                elif (
+                    dest == plc.TypeId.UINT64 and src != plc.TypeId.UINT64
+                ):  # pragma: no cover
+                    ranked = plc.unary.cast(ranked, plc.DataType(plc.TypeId.UINT64))
+
+            return Column(ranked, dtype=self.dtype)
+        elif self.name == "top_k":
+            (column, k) = (
+                child.evaluate(df, context=context) for child in self.children
+            )
+            (reverse,) = self.options
+            return Column(
+                plc.sorting.top_k(
+                    column.obj,
+                    cast(Literal, self.children[1]).value,
+                    plc.types.Order.ASCENDING
+                    if reverse
+                    else plc.types.Order.DESCENDING,
+                ),
+                dtype=self.dtype,
+            )
+        elif self.name in self._OP_MAPPING:
+            column = self.children[0].evaluate(df, context=context)
+            if column.dtype.plc.id() != self.dtype.id():
+                arg = plc.unary.cast(column.obj, self.dtype.plc)
+            else:
+                arg = column.obj
+            return Column(
+                plc.unary.unary_operation(arg, self._OP_MAPPING[self.name]),
+                dtype=self.dtype,
+            )
+        elif self.name in UnaryFunction._supported_cum_aggs:
+            column = self.children[0].evaluate(df, context=context)
+            plc_col = column.obj
+            col_type = column.dtype.plc
+            # cum_sum casts
+            # Int8, UInt8, Int16, UInt16 -> Int64 for overflow prevention
+            # Bool -> UInt32
+            # cum_prod casts integer dtypes < int64 and bool to int64
+            # See:
+            # https://github.com/pola-rs/polars/blob/main/crates/polars-ops/src/series/ops/cum_agg.rs
+            if (
+                self.name == "cum_sum"
+                and col_type.id()
+                in {
+                    plc.TypeId.INT8,
+                    plc.TypeId.UINT8,
+                    plc.TypeId.INT16,
+                    plc.TypeId.UINT16,
+                }
+            ) or (
+                self.name == "cum_prod"
+                and plc.traits.is_integral(col_type)
+                and plc.types.size_of(col_type) <= 4
+            ):
+                plc_col = plc.unary.cast(plc_col, plc.DataType(plc.TypeId.INT64))
+            elif self.name == "cum_sum" and column.dtype.plc.id() == plc.TypeId.BOOL8:
+                plc_col = plc.unary.cast(plc_col, plc.DataType(plc.TypeId.UINT32))
+            if self.name == "cum_sum":
+                agg = plc.aggregation.sum()
+            elif self.name == "cum_prod":
+                agg = plc.aggregation.product()
+            elif self.name == "cum_min":
+                agg = plc.aggregation.min()
+            elif self.name == "cum_max":
+                agg = plc.aggregation.max()
+
+            return Column(
+                plc.reduce.scan(plc_col, agg, plc.reduce.ScanType.INCLUSIVE),
+                dtype=self.dtype,
+            )
+        raise NotImplementedError(
+            f"Unimplemented unary function {self.name=}"
+        )  # pragma: no cover; init trips first
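
UnaryFunction dispatches along three paths: direct libcudf unary operators via _OP_MAPPING, cumulative aggregations implemented as inclusive scans (with the dtype promotions noted in the comments above), and miscellaneous functions such as fill_null, rank, and unique. A minimal sketch touching each path (same GPU-install assumptions as above; hypothetical data):

    import polars as pl

    lf = pl.LazyFrame({"x": [1, None, 3], "y": [0.0, 0.5, 1.0]})
    out = lf.select(
        pl.col("y").sin().alias("sin_y"),                 # _OP_MAPPING -> plc.unary.unary_operation
        pl.col("x").cast(pl.Int8).cum_sum().alias("cs"),  # inclusive scan; Int8 promoted to Int64
        pl.col("x").fill_null(0).alias("filled"),         # plc.replace.replace_nulls
    ).collect(engine="gpu")
    print(out)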