cudf-polars-cu12 25.2.2-py3-none-any.whl → 25.6.0-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- cudf_polars/VERSION +1 -1
- cudf_polars/callback.py +82 -65
- cudf_polars/containers/column.py +138 -7
- cudf_polars/containers/dataframe.py +26 -39
- cudf_polars/dsl/expr.py +3 -1
- cudf_polars/dsl/expressions/aggregation.py +27 -63
- cudf_polars/dsl/expressions/base.py +40 -72
- cudf_polars/dsl/expressions/binaryop.py +5 -41
- cudf_polars/dsl/expressions/boolean.py +25 -53
- cudf_polars/dsl/expressions/datetime.py +97 -17
- cudf_polars/dsl/expressions/literal.py +27 -33
- cudf_polars/dsl/expressions/rolling.py +110 -9
- cudf_polars/dsl/expressions/selection.py +8 -26
- cudf_polars/dsl/expressions/slicing.py +47 -0
- cudf_polars/dsl/expressions/sorting.py +5 -18
- cudf_polars/dsl/expressions/string.py +33 -36
- cudf_polars/dsl/expressions/ternary.py +3 -10
- cudf_polars/dsl/expressions/unary.py +35 -75
- cudf_polars/dsl/ir.py +749 -212
- cudf_polars/dsl/nodebase.py +8 -1
- cudf_polars/dsl/to_ast.py +5 -3
- cudf_polars/dsl/translate.py +319 -171
- cudf_polars/dsl/utils/__init__.py +8 -0
- cudf_polars/dsl/utils/aggregations.py +292 -0
- cudf_polars/dsl/utils/groupby.py +97 -0
- cudf_polars/dsl/utils/naming.py +34 -0
- cudf_polars/dsl/utils/replace.py +46 -0
- cudf_polars/dsl/utils/rolling.py +113 -0
- cudf_polars/dsl/utils/windows.py +186 -0
- cudf_polars/experimental/base.py +17 -19
- cudf_polars/experimental/benchmarks/__init__.py +4 -0
- cudf_polars/experimental/benchmarks/pdsh.py +1279 -0
- cudf_polars/experimental/dask_registers.py +196 -0
- cudf_polars/experimental/distinct.py +174 -0
- cudf_polars/experimental/explain.py +127 -0
- cudf_polars/experimental/expressions.py +521 -0
- cudf_polars/experimental/groupby.py +288 -0
- cudf_polars/experimental/io.py +58 -29
- cudf_polars/experimental/join.py +353 -0
- cudf_polars/experimental/parallel.py +166 -93
- cudf_polars/experimental/repartition.py +69 -0
- cudf_polars/experimental/scheduler.py +155 -0
- cudf_polars/experimental/select.py +92 -7
- cudf_polars/experimental/shuffle.py +294 -0
- cudf_polars/experimental/sort.py +45 -0
- cudf_polars/experimental/spilling.py +151 -0
- cudf_polars/experimental/utils.py +100 -0
- cudf_polars/testing/asserts.py +146 -6
- cudf_polars/testing/io.py +72 -0
- cudf_polars/testing/plugin.py +78 -76
- cudf_polars/typing/__init__.py +59 -6
- cudf_polars/utils/config.py +353 -0
- cudf_polars/utils/conversion.py +40 -0
- cudf_polars/utils/dtypes.py +22 -5
- cudf_polars/utils/timer.py +39 -0
- cudf_polars/utils/versions.py +5 -4
- {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info}/METADATA +10 -7
- cudf_polars_cu12-25.6.0.dist-info/RECORD +73 -0
- {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info}/WHEEL +1 -1
- cudf_polars/experimental/dask_serialize.py +0 -59
- cudf_polars_cu12-25.2.2.dist-info/RECORD +0 -48
- {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info/licenses}/LICENSE +0 -0
- {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info}/top_level.txt +0 -0
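Beyond the raw line counts, the shape of the release is visible in the file list: the experimental executor gains multi-partition support (shuffle, join, groupby, distinct, repartition, spilling, plus a PDS-H benchmark harness), expression translation gains dedicated utilities under `cudf_polars/dsl/utils/`, and configuration is consolidated into `cudf_polars/utils/config.py`. As a quick orientation, a minimal smoke test of the wheel; this is illustrative and assumes a CUDA 12 environment with a supported driver:

```python
# pip install cudf-polars-cu12==25.6.0
import polars as pl

q = pl.LazyFrame({"a": [1, 2, 3], "b": [4, 5, 6]}).select(
    (pl.col("a") + pl.col("b")).alias("s")
)
# cudf-polars runs the query on the GPU, falling back to the CPU
# engine with a warning if an operation is unsupported.
print(q.collect(engine="gpu"))
```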
cudf_polars/dsl/expressions/string.py:

@@ -21,8 +21,6 @@ from cudf_polars.dsl.expressions.base import ExecutionContext, Expr
 from cudf_polars.dsl.expressions.literal import Literal, LiteralColumn
 
 if TYPE_CHECKING:
-    from collections.abc import Mapping
-
     from typing_extensions import Self
 
     from polars.polars import _expr_nodes as pl_expr
@@ -107,10 +105,10 @@ class StringFunction(Expr):
         self.options = options
         self.name = name
         self.children = children
-        self.is_pointwise =
+        self.is_pointwise = self.name != StringFunction.Name.ConcatVertical
         self._validate_input()
 
-    def _validate_input(self):
+    def _validate_input(self) -> None:
         if self.name not in (
             StringFunction.Name.ConcatVertical,
             StringFunction.Name.Contains,
@@ -138,7 +136,7 @@ class StringFunction(Expr):
                 raise NotImplementedError(
                     "Regex contains only supports a scalar pattern"
                 )
-            pattern = self.children[1].value
+            pattern = self.children[1].value
             try:
                 self._regex_program = plc.strings.regex_program.RegexProgram.create(
                     pattern,
@@ -155,7 +153,9 @@ class StringFunction(Expr):
             if not all(isinstance(expr, Literal) for expr in self.children[1:]):
                 raise NotImplementedError("replace only supports scalar target")
             target = self.children[1]
-            if target
+            # Above, we raise NotImplementedError if the target is not a Literal,
+            # so we can safely access .value here.
+            if target.value == "":  # type: ignore[attr-defined]
                 raise NotImplementedError(
                     "libcudf replace does not support empty strings"
                 )
@@ -170,7 +170,14 @@ class StringFunction(Expr):
             ):
                 raise NotImplementedError("replace_many only supports literal inputs")
             target = self.children[1]
-            if
+            # Above, we raise NotImplementedError if the target is not a Literal,
+            # so we can safely access .value here.
+            if (isinstance(target, Literal) and target.value == "") or (
+                isinstance(target, LiteralColumn)
+                and pc.any(
+                    pc.equal(target.value.cast(pa.string()), "")  # type: ignore[attr-defined]
+                ).as_py()
+            ):
                 raise NotImplementedError(
                     "libcudf replace_many is implemented differently from polars "
                     "for empty strings"
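The widened `replace_many` guard rejects an empty-string target whether it arrives as a scalar `Literal` or inside a `LiteralColumn`, using pyarrow compute for the column case. A standalone sketch of that check; the array contents are illustrative:

```python
import pyarrow as pa
import pyarrow.compute as pc

# A literal column of replacement targets, one of which is empty.
targets = pa.array(["", "b"], type=pa.string())

# The same test the guard performs: does any target equal ""?
has_empty = pc.any(pc.equal(targets.cast(pa.string()), "")).as_py()
print(has_empty)  # True -> cudf-polars raises NotImplementedError
```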
@@ -199,36 +206,32 @@ class StringFunction(Expr):
         )
 
     def do_evaluate(
-        self,
-        df: DataFrame,
-        *,
-        context: ExecutionContext = ExecutionContext.FRAME,
-        mapping: Mapping[Expr, Column] | None = None,
+        self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
     ) -> Column:
         """Evaluate this expression given a dataframe for context."""
         if self.name is StringFunction.Name.ConcatVertical:
             (child,) = self.children
-            column = child.evaluate(df, context=context, mapping=mapping)
+            column = child.evaluate(df, context=context)
             delimiter, ignore_nulls = self.options
-            if column.
+            if column.null_count > 0 and not ignore_nulls:
                 return Column(plc.Column.all_null_like(column.obj, 1))
             return Column(
                 plc.strings.combine.join_strings(
                     column.obj,
-                    plc.
-                    plc.
+                    plc.Scalar.from_py(delimiter, plc.DataType(plc.TypeId.STRING)),
+                    plc.Scalar.from_py(None, plc.DataType(plc.TypeId.STRING)),
                 )
             )
         elif self.name is StringFunction.Name.Contains:
             child, arg = self.children
-            column = child.evaluate(df, context=context, mapping=mapping)
+            column = child.evaluate(df, context=context)
 
             literal, _ = self.options
             if literal:
-                pat = arg.evaluate(df, context=context, mapping=mapping)
+                pat = arg.evaluate(df, context=context)
                 pattern = (
                     pat.obj_scalar
-                    if pat.is_scalar and pat.
+                    if pat.is_scalar and pat.size != column.size
                     else pat.obj
                 )
                 return Column(plc.strings.find.contains(column.obj, pattern))
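Two themes of this hunk recur through the rest of the diff: the `mapping` keyword is dropped from every `do_evaluate` signature, and pylibcudf scalars are now built with `plc.Scalar.from_py`. At the polars level, the `ConcatVertical` branch backs `str.join`; an illustrative query, assuming a working GPU engine:

```python
import polars as pl

q = pl.LazyFrame({"s": ["a", None, "c"]}).select(
    pl.col("s").str.join("-", ignore_nulls=False)
)
# With a null present and ignore_nulls=False, the branch above returns
# an all-null length-1 column, matching the CPU engine's semantics.
print(q.collect(engine="gpu"))
```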
@@ -241,15 +244,15 @@ class StringFunction(Expr):
             assert isinstance(expr_offset, Literal)
             assert isinstance(expr_length, Literal)
 
-            column = child.evaluate(df, context=context, mapping=mapping)
+            column = child.evaluate(df, context=context)
             # libcudf slices via [start,stop).
             # polars slices with offset + length where start == offset
             # stop = start + length. Negative values for start look backward
             # from the last element of the string. If the end index would be
             # below zero, an empty string is returned.
             # Do this maths on the host
-            start = expr_offset.value
-            length = expr_length.value
+            start = expr_offset.value
+            length = expr_length.value
 
             if length == 0:
                 stop = start
@@ -262,8 +265,8 @@ class StringFunction(Expr):
             return Column(
                 plc.strings.slice.slice_strings(
                     column.obj,
-                    plc.
-                    plc.
+                    plc.Scalar.from_py(start, plc.DataType(plc.TypeId.INT32)),
+                    plc.Scalar.from_py(stop, plc.DataType(plc.TypeId.INT32)),
                 )
             )
         elif self.name in {
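The comment block in the slice hunk describes the offset/length arithmetic done on the host: polars' `(offset, length)` is converted to libcudf's `[start, stop)` interval, with negative offsets counting back from the end of each string. A worked example (illustrative):

```python
import polars as pl

q = pl.LazyFrame({"s": ["abcdef"]}).select(
    # offset=-4 resolves to index 2, length=2 gives [2, 4) -> "cd"
    pl.col("s").str.slice(-4, 2).alias("sliced")
)
print(q.collect(engine="gpu"))
```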
@@ -271,9 +274,7 @@ class StringFunction(Expr):
             StringFunction.Name.StripCharsStart,
             StringFunction.Name.StripCharsEnd,
         }:
-            column, chars = (
-                c.evaluate(df, context=context, mapping=mapping) for c in self.children
-            )
+            column, chars = (c.evaluate(df, context=context) for c in self.children)
             if self.name is StringFunction.Name.StripCharsStart:
                 side = plc.strings.SideType.LEFT
             elif self.name is StringFunction.Name.StripCharsEnd:
@@ -282,10 +283,7 @@ class StringFunction(Expr):
                 side = plc.strings.SideType.BOTH
             return Column(plc.strings.strip.strip(column.obj, side, chars.obj_scalar))
 
-        columns = [
-            child.evaluate(df, context=context, mapping=mapping)
-            for child in self.children
-        ]
+        columns = [child.evaluate(df, context=context) for child in self.children]
         if self.name is StringFunction.Name.Lowercase:
             (column,) = columns
             return Column(plc.strings.case.to_lower(column.obj))
@@ -298,7 +296,7 @@ class StringFunction(Expr):
                 plc.strings.find.ends_with(
                     column.obj,
                     suffix.obj_scalar
-                    if column.
+                    if column.size != suffix.size and suffix.is_scalar
                     else suffix.obj,
                 )
             )
@@ -308,14 +306,14 @@ class StringFunction(Expr):
                 plc.strings.find.starts_with(
                     column.obj,
                     prefix.obj_scalar
-                    if column.
+                    if column.size != prefix.size and prefix.is_scalar
                     else prefix.obj,
                 )
             )
         elif self.name is StringFunction.Name.Strptime:
             # TODO: ignores ambiguous
             format, strict, exact, cache = self.options
-            col = self.children[0].evaluate(df, context=context, mapping=mapping)
+            col = self.children[0].evaluate(df, context=context)
 
             is_timestamps = plc.strings.convert.convert_datetime.is_timestamp(
                 col.obj, format
@@ -334,8 +332,7 @@ class StringFunction(Expr):
             not_timestamps = plc.unary.unary_operation(
                 is_timestamps, plc.unary.UnaryOperator.NOT
             )
-
-            null = plc.interop.from_arrow(pa.scalar(None, type=pa.string()))
+            null = plc.Scalar.from_py(None, plc.DataType(plc.TypeId.STRING))
             res = plc.copying.boolean_mask_scatter(
                 [null], plc.Table([col.obj]), not_timestamps
            )
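A recurring mechanical change across these hunks: scalars are constructed directly with `plc.Scalar.from_py` rather than round-tripping through pyarrow, which also lets unary.py drop its `pyarrow` import below. Before and after, using the null-string scalar from the Strptime hunk (a minimal sketch, assuming pylibcudf 25.6):

```python
import pyarrow as pa
import pylibcudf as plc

# 25.2.2: build the scalar via pyarrow interop.
null_old = plc.interop.from_arrow(pa.scalar(None, type=pa.string()))

# 25.6.0: build it directly from a Python value and a pylibcudf type.
null_new = plc.Scalar.from_py(None, plc.DataType(plc.TypeId.STRING))
```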
cudf_polars/dsl/expressions/ternary.py:

@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
 # SPDX-License-Identifier: Apache-2.0
 # TODO: remove need for this
 # ruff: noqa: D101
@@ -17,8 +17,6 @@ from cudf_polars.dsl.expressions.base import (
 )
 
 if TYPE_CHECKING:
-    from collections.abc import Mapping
-
     from cudf_polars.containers import DataFrame
 
 
@@ -37,16 +35,11 @@ class Ternary(Expr):
         self.is_pointwise = True
 
     def do_evaluate(
-        self,
-        df: DataFrame,
-        *,
-        context: ExecutionContext = ExecutionContext.FRAME,
-        mapping: Mapping[Expr, Column] | None = None,
+        self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
     ) -> Column:
         """Evaluate this expression given a dataframe for context."""
         when, then, otherwise = (
-            child.evaluate(df, context=context, mapping=mapping)
-            for child in self.children
+            child.evaluate(df, context=context) for child in self.children
         )
         then_obj = then.obj_scalar if then.is_scalar else then.obj
         otherwise_obj = otherwise.obj_scalar if otherwise.is_scalar else otherwise.obj
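`Ternary` implements polars' `when/then/otherwise`; the edit here is the same signature cleanup applied throughout. For reference, a query that exercises this node (illustrative):

```python
import polars as pl

q = pl.LazyFrame({"a": [1, 2, 3]}).select(
    pl.when(pl.col("a") > 1).then(pl.col("a")).otherwise(0).alias("clipped")
)
print(q.collect(engine="gpu"))
```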
cudf_polars/dsl/expressions/unary.py:

@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
 # SPDX-License-Identifier: Apache-2.0
 # TODO: remove need for this
 """DSL nodes for unary operations."""
@@ -7,18 +7,15 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING, Any, ClassVar
 
-import pyarrow as pa
-
 import pylibcudf as plc
 
 from cudf_polars.containers import Column
-from cudf_polars.dsl.expressions.base import AggInfo, ExecutionContext, Expr
+from cudf_polars.dsl.expressions.base import ExecutionContext, Expr
 from cudf_polars.dsl.expressions.literal import Literal
 from cudf_polars.utils import dtypes
+from cudf_polars.utils.versions import POLARS_VERSION_LT_128
 
 if TYPE_CHECKING:
-    from collections.abc import Mapping
-
     from cudf_polars.containers import DataFrame
 
 __all__ = ["Cast", "Len", "UnaryFunction"]
@@ -40,23 +37,13 @@ class Cast(Expr):
         )
 
     def do_evaluate(
-        self,
-        df: DataFrame,
-        *,
-        context: ExecutionContext = ExecutionContext.FRAME,
-        mapping: Mapping[Expr, Column] | None = None,
+        self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
     ) -> Column:
         """Evaluate this expression given a dataframe for context."""
         (child,) = self.children
-        column = child.evaluate(df, context=context, mapping=mapping)
+        column = child.evaluate(df, context=context)
         return column.astype(self.dtype)
 
-    def collect_agg(self, *, depth: int) -> AggInfo:
-        """Collect information about aggregations in groupbys."""
-        # TODO: Could do with sort-based groupby and segmented filter
-        (child,) = self.children
-        return child.collect_agg(depth=depth)
-
 
 class Len(Expr):
     """Class representing the length of an expression."""
@@ -67,28 +54,19 @@ class Len(Expr):
         self.is_pointwise = False
 
     def do_evaluate(
-        self,
-        df: DataFrame,
-        *,
-        context: ExecutionContext = ExecutionContext.FRAME,
-        mapping: Mapping[Expr, Column] | None = None,
+        self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
     ) -> Column:
         """Evaluate this expression given a dataframe for context."""
         return Column(
             plc.Column.from_scalar(
-                plc.interop.from_arrow(
-                    pa.scalar(df.num_rows, type=plc.interop.to_arrow(self.dtype))
-                ),
+                plc.Scalar.from_py(df.num_rows, self.dtype),
                 1,
             )
         )
 
-    def collect_agg(self, *, depth: int) -> AggInfo:
-        """Collect information about aggregations in groupbys."""
-        return AggInfo(
-            [(None, plc.aggregation.count(plc.types.NullPolicy.INCLUDE), self)]
-        )
+    @property
+    def agg_request(self) -> plc.aggregation.Aggregation:  # noqa: D102
+        return plc.aggregation.count(plc.types.NullPolicy.INCLUDE)
 
 
 class UnaryFunction(Expr):
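`Len` now publishes its groupby aggregation as an `agg_request` property instead of implementing `collect_agg`; judging by the file list, the collection machinery itself moved into the new `cudf_polars/dsl/utils/aggregations.py`. The request is a count with `NullPolicy.INCLUDE`, which matches `pl.len()` counting rows rather than non-null values (illustrative):

```python
import polars as pl

ldf = pl.LazyFrame({"g": ["x", "x", "y"], "a": [1, None, 3]})
# pl.len() counts every row in the group (nulls included);
# pl.col("a").count() would exclude the null.
print(ldf.group_by("g").agg(pl.len()).collect(engine="gpu"))
```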
@@ -119,6 +97,7 @@ class UnaryFunction(Expr):
         "abs": plc.unary.UnaryOperator.ABS,
         "bit_invert": plc.unary.UnaryOperator.BIT_INVERT,
         "not": plc.unary.UnaryOperator.NOT,
+        "negate": plc.unary.UnaryOperator.NEGATE,
     }
     _supported_misc_fns = frozenset(
         {
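The new `negate` entry maps onto `plc.unary.UnaryOperator.NEGATE`. Assuming polars lowers unary minus to this function (the diff itself does not confirm the lowering), an expression like the following would now execute on the GPU:

```python
import polars as pl

q = pl.LazyFrame({"a": [1, -2, 3]}).select((-pl.col("a")).alias("neg"))
print(q.collect(engine="gpu"))
```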
@@ -168,22 +147,15 @@ class UnaryFunction(Expr):
         )
 
     def do_evaluate(
-        self,
-        df: DataFrame,
-        *,
-        context: ExecutionContext = ExecutionContext.FRAME,
-        mapping: Mapping[Expr, Column] | None = None,
+        self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
     ) -> Column:
         """Evaluate this expression given a dataframe for context."""
         if self.name == "mask_nans":
             (child,) = self.children
-            return child.evaluate(df, context=context, mapping=mapping).mask_nans()
+            return child.evaluate(df, context=context).mask_nans()
         if self.name == "round":
             (decimal_places,) = self.options
-            (values,) = (
-                child.evaluate(df, context=context, mapping=mapping)
-                for child in self.children
-            )
+            (values,) = (child.evaluate(df, context=context) for child in self.children)
             return Column(
                 plc.round.round(
                     values.obj, decimal_places, plc.round.RoundingMethod.HALF_UP
@@ -191,10 +163,7 @@ class UnaryFunction(Expr):
             ).sorted_like(values)
         elif self.name == "unique":
             (maintain_order,) = self.options
-            (values,) = (
-                child.evaluate(df, context=context, mapping=mapping)
-                for child in self.children
-            )
+            (values,) = (child.evaluate(df, context=context) for child in self.children)
             # Only one column, so keep_any is the same as keep_first
             # for stable distinct
             keep = plc.stream_compaction.DuplicateKeepOption.KEEP_ANY
@@ -224,10 +193,7 @@ class UnaryFunction(Expr):
                 return Column(column).sorted_like(values)
             return Column(column)
         elif self.name == "set_sorted":
-            (column,) = (
-                child.evaluate(df, context=context, mapping=mapping)
-                for child in self.children
-            )
+            (column,) = (child.evaluate(df, context=context) for child in self.children)
             (asc,) = self.options
             order = (
                 plc.types.Order.ASCENDING
@@ -235,7 +201,7 @@ class UnaryFunction(Expr):
                 else plc.types.Order.DESCENDING
             )
             null_order = plc.types.NullOrder.BEFORE
-            if column.
+            if column.null_count > 0 and (n := column.size) > 1:
                 # PERF: This invokes four stream synchronisations!
                 has_nulls_first = not plc.copying.get_element(column.obj, 0).is_valid()
                 has_nulls_last = not plc.copying.get_element(
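The `set_sorted` branch still has to decide where nulls sit relative to the sorted values; per the PERF comment, it inspects the first and last elements, at the cost of several host-device synchronisations. An expression that reaches this branch (illustrative):

```python
import polars as pl

q = pl.LazyFrame({"a": [None, 1, 2]}).select(
    # set_sorted asserts the column is already sorted, letting later
    # operations (max here) exploit the sortedness flag.
    pl.col("a").set_sorted().max()
)
print(q.collect(engine="gpu"))
```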
@@ -251,34 +217,41 @@ class UnaryFunction(Expr):
                 null_order=null_order,
             )
         elif self.name == "drop_nulls":
-            (column,) = (
-                child.evaluate(df, context=context, mapping=mapping)
-                for child in self.children
-            )
+            (column,) = (child.evaluate(df, context=context) for child in self.children)
+            if column.null_count == 0:
+                return column
             return Column(
                 plc.stream_compaction.drop_nulls(
                     plc.Table([column.obj]), [0], 1
                 ).columns()[0]
             )
         elif self.name == "fill_null":
-            column = self.children[0].evaluate(df, context=context, mapping=mapping)
+            column = self.children[0].evaluate(df, context=context)
+            if column.null_count == 0:
+                return column
             if isinstance(self.children[1], Literal):
-                arg = plc.
+                arg = plc.Scalar.from_py(self.children[1].value, self.children[1].dtype)
             else:
-                evaluated = self.children[1].evaluate(
-                    df, context=context, mapping=mapping
-                )
+                evaluated = self.children[1].evaluate(df, context=context)
                 arg = evaluated.obj_scalar if evaluated.is_scalar else evaluated.obj
+            if (
+                not POLARS_VERSION_LT_128
+                and isinstance(arg, plc.Scalar)
+                and dtypes.can_cast(column.obj.type(), arg.type())
+            ):  # pragma: no cover
+                arg = plc.unary.cast(
+                    plc.Column.from_scalar(arg, 1), column.obj.type()
+                ).to_scalar()
             return Column(plc.replace.replace_nulls(column.obj, arg))
         elif self.name in self._OP_MAPPING:
-            column = self.children[0].evaluate(df, context=context, mapping=mapping)
+            column = self.children[0].evaluate(df, context=context)
             if column.obj.type().id() != self.dtype.id():
                 arg = plc.unary.cast(column.obj, self.dtype)
             else:
                 arg = column.obj
             return Column(plc.unary.unary_operation(arg, self._OP_MAPPING[self.name]))
         elif self.name in UnaryFunction._supported_cum_aggs:
-            column = self.children[0].evaluate(df, context=context, mapping=mapping)
+            column = self.children[0].evaluate(df, context=context)
             plc_col = column.obj
             col_type = column.obj.type()
             # cum_sum casts
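`drop_nulls` and `fill_null` now short-circuit when the column contains no nulls, and under the `POLARS_VERSION_LT_128` gate a scalar fill value is cast to the column's dtype when `dtypes.can_cast` allows it. Illustrative usage where that cast applies (an integer literal filling a Float64 column):

```python
import polars as pl

q = pl.LazyFrame({"a": [1.5, None, 2.5]}).with_columns(
    pl.col("a").fill_null(0)  # 0 is cast to Float64 before replace_nulls
)
print(q.collect(engine="gpu"))
```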
@@ -324,16 +297,3 @@ class UnaryFunction(Expr):
         raise NotImplementedError(
             f"Unimplemented unary function {self.name=}"
         )  # pragma: no cover; init trips first
-
-    def collect_agg(self, *, depth: int) -> AggInfo:
-        """Collect information about aggregations in groupbys."""
-        if self.name in {"unique", "drop_nulls"} | self._supported_cum_aggs:
-            raise NotImplementedError(f"{self.name} in groupby")
-        if depth == 1:
-            # inside aggregation, need to pre-evaluate, groupby
-            # construction has checked that we don't have nested aggs,
-            # so stop the recursion and return ourselves for pre-eval
-            return AggInfo([(self, plc.aggregation.collect_list(), self)])
-        else:
-            (child,) = self.children
-            return child.collect_agg(depth=depth)