cudf-polars-cu13 25.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cudf_polars/GIT_COMMIT +1 -0
- cudf_polars/VERSION +1 -0
- cudf_polars/__init__.py +28 -0
- cudf_polars/_version.py +21 -0
- cudf_polars/callback.py +318 -0
- cudf_polars/containers/__init__.py +13 -0
- cudf_polars/containers/column.py +495 -0
- cudf_polars/containers/dataframe.py +361 -0
- cudf_polars/containers/datatype.py +137 -0
- cudf_polars/dsl/__init__.py +8 -0
- cudf_polars/dsl/expr.py +66 -0
- cudf_polars/dsl/expressions/__init__.py +8 -0
- cudf_polars/dsl/expressions/aggregation.py +226 -0
- cudf_polars/dsl/expressions/base.py +272 -0
- cudf_polars/dsl/expressions/binaryop.py +120 -0
- cudf_polars/dsl/expressions/boolean.py +326 -0
- cudf_polars/dsl/expressions/datetime.py +271 -0
- cudf_polars/dsl/expressions/literal.py +97 -0
- cudf_polars/dsl/expressions/rolling.py +643 -0
- cudf_polars/dsl/expressions/selection.py +74 -0
- cudf_polars/dsl/expressions/slicing.py +46 -0
- cudf_polars/dsl/expressions/sorting.py +85 -0
- cudf_polars/dsl/expressions/string.py +1002 -0
- cudf_polars/dsl/expressions/struct.py +137 -0
- cudf_polars/dsl/expressions/ternary.py +49 -0
- cudf_polars/dsl/expressions/unary.py +517 -0
- cudf_polars/dsl/ir.py +2607 -0
- cudf_polars/dsl/nodebase.py +164 -0
- cudf_polars/dsl/to_ast.py +359 -0
- cudf_polars/dsl/tracing.py +16 -0
- cudf_polars/dsl/translate.py +939 -0
- cudf_polars/dsl/traversal.py +224 -0
- cudf_polars/dsl/utils/__init__.py +8 -0
- cudf_polars/dsl/utils/aggregations.py +481 -0
- cudf_polars/dsl/utils/groupby.py +98 -0
- cudf_polars/dsl/utils/naming.py +34 -0
- cudf_polars/dsl/utils/replace.py +61 -0
- cudf_polars/dsl/utils/reshape.py +74 -0
- cudf_polars/dsl/utils/rolling.py +121 -0
- cudf_polars/dsl/utils/windows.py +192 -0
- cudf_polars/experimental/__init__.py +8 -0
- cudf_polars/experimental/base.py +386 -0
- cudf_polars/experimental/benchmarks/__init__.py +4 -0
- cudf_polars/experimental/benchmarks/pdsds.py +220 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/__init__.py +4 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q1.py +88 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q10.py +225 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q2.py +244 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q3.py +65 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q4.py +359 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q5.py +462 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q6.py +92 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q7.py +79 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q8.py +524 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q9.py +137 -0
- cudf_polars/experimental/benchmarks/pdsh.py +814 -0
- cudf_polars/experimental/benchmarks/utils.py +832 -0
- cudf_polars/experimental/dask_registers.py +200 -0
- cudf_polars/experimental/dispatch.py +156 -0
- cudf_polars/experimental/distinct.py +197 -0
- cudf_polars/experimental/explain.py +157 -0
- cudf_polars/experimental/expressions.py +590 -0
- cudf_polars/experimental/groupby.py +327 -0
- cudf_polars/experimental/io.py +943 -0
- cudf_polars/experimental/join.py +391 -0
- cudf_polars/experimental/parallel.py +423 -0
- cudf_polars/experimental/repartition.py +69 -0
- cudf_polars/experimental/scheduler.py +155 -0
- cudf_polars/experimental/select.py +188 -0
- cudf_polars/experimental/shuffle.py +354 -0
- cudf_polars/experimental/sort.py +609 -0
- cudf_polars/experimental/spilling.py +151 -0
- cudf_polars/experimental/statistics.py +795 -0
- cudf_polars/experimental/utils.py +169 -0
- cudf_polars/py.typed +0 -0
- cudf_polars/testing/__init__.py +8 -0
- cudf_polars/testing/asserts.py +448 -0
- cudf_polars/testing/io.py +122 -0
- cudf_polars/testing/plugin.py +236 -0
- cudf_polars/typing/__init__.py +219 -0
- cudf_polars/utils/__init__.py +8 -0
- cudf_polars/utils/config.py +741 -0
- cudf_polars/utils/conversion.py +40 -0
- cudf_polars/utils/dtypes.py +118 -0
- cudf_polars/utils/sorting.py +53 -0
- cudf_polars/utils/timer.py +39 -0
- cudf_polars/utils/versions.py +27 -0
- cudf_polars_cu13-25.10.0.dist-info/METADATA +136 -0
- cudf_polars_cu13-25.10.0.dist-info/RECORD +92 -0
- cudf_polars_cu13-25.10.0.dist-info/WHEEL +5 -0
- cudf_polars_cu13-25.10.0.dist-info/licenses/LICENSE +201 -0
- cudf_polars_cu13-25.10.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,939 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
"""Translate polars IR representation to ours."""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import functools
|
|
9
|
+
import json
|
|
10
|
+
from contextlib import AbstractContextManager, nullcontext
|
|
11
|
+
from functools import singledispatch
|
|
12
|
+
from typing import TYPE_CHECKING, Any
|
|
13
|
+
|
|
14
|
+
from typing_extensions import assert_never
|
|
15
|
+
|
|
16
|
+
import polars as pl
|
|
17
|
+
import polars.polars as plrs
|
|
18
|
+
from polars.polars import _expr_nodes as pl_expr, _ir_nodes as pl_ir
|
|
19
|
+
|
|
20
|
+
import pylibcudf as plc
|
|
21
|
+
|
|
22
|
+
from cudf_polars.containers import DataType
|
|
23
|
+
from cudf_polars.dsl import expr, ir
|
|
24
|
+
from cudf_polars.dsl.expressions.base import ExecutionContext
|
|
25
|
+
from cudf_polars.dsl.to_ast import insert_colrefs
|
|
26
|
+
from cudf_polars.dsl.utils.aggregations import decompose_single_agg
|
|
27
|
+
from cudf_polars.dsl.utils.groupby import rewrite_groupby
|
|
28
|
+
from cudf_polars.dsl.utils.naming import unique_names
|
|
29
|
+
from cudf_polars.dsl.utils.replace import replace
|
|
30
|
+
from cudf_polars.dsl.utils.rolling import rewrite_rolling
|
|
31
|
+
from cudf_polars.typing import Schema
|
|
32
|
+
from cudf_polars.utils import config, sorting
|
|
33
|
+
from cudf_polars.utils.versions import (
|
|
34
|
+
POLARS_VERSION_LT_131,
|
|
35
|
+
POLARS_VERSION_LT_132,
|
|
36
|
+
POLARS_VERSION_LT_1323,
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
if TYPE_CHECKING:
|
|
40
|
+
from polars import GPUEngine
|
|
41
|
+
|
|
42
|
+
from cudf_polars.typing import NodeTraverser
|
|
43
|
+
|
|
44
|
+
# Public API of this module: the Translator class and the named-expression helper.
__all__ = ["Translator", "translate_named_expr"]
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class Translator:
    """
    Translate polars-internal IR nodes and expressions into our representation.

    Parameters
    ----------
    visitor
        Polars NodeTraverser object used to walk the polars plan.
    engine
        GPU engine configuration; converted into ``config_options``.

    Notes
    -----
    Most translation failures do not raise. They are appended to ``errors``
    and the offending node/expression is replaced by an error placeholder,
    so callers should inspect ``errors`` once translation is complete.
    """

    def __init__(self, visitor: NodeTraverser, engine: GPUEngine):
        self.visitor = visitor
        self.config_options = config.ConfigOptions.from_polars_engine(engine)
        # Exceptions gathered while translating.
        self.errors: list[Exception] = []
        # polars Cache ``id_`` -> translated node, so that Cache nodes sharing
        # an id are translated to a single shared object.
        self._cache_nodes: dict[int, ir.Cache] = {}

    def translate_ir(self, *, n: int | None = None) -> ir.IR:
        """
        Translate a polars-internal IR node to our representation.

        Parameters
        ----------
        n
            Optional node to start traversing from. When omitted, the node
            currently active in the visitor is translated.

        Returns
        -------
        Translated IR object, or an :class:`ir.ErrorNode` if translation failed.

        Raises
        ------
        NotImplementedError
            If the version of Polars IR is unsupported.

        Notes
        -----
        IR nodes that cannot be translated are replaced by
        :class:`ir.ErrorNode` nodes and the causing exception is collected in
        the `errors` attribute. After translation is complete, this list of
        errors should be inspected to determine if the query is supported.
        """
        ctx: AbstractContextManager[None]
        if n is None:
            ctx = noop_context
        else:
            ctx = set_node(self.visitor, n)

        # IR is versioned with major.minor: minor is bumped for backwards
        # compatible changes (e.g. adding new nodes), major is bumped for
        # incompatible changes (e.g. renaming nodes).
        version = self.visitor.version()
        if version >= (10, 1):  # pragma: no cover; no such version for now.
            unsupported = NotImplementedError(f"No support for polars IR {version=}")
            self.errors.append(unsupported)
            raise unsupported

        with ctx:
            polars_schema = self.visitor.get_schema()
            try:
                schema = {
                    column: DataType(pl_dtype)
                    for column, pl_dtype in polars_schema.items()
                }
            except Exception as err:
                # No usable schema at all, hence the empty schema on the node.
                self.errors.append(NotImplementedError(str(err)))
                return ir.ErrorNode({}, str(err))

            # Failing to view the node and failing to translate it are
            # reported identically.
            try:
                node = self.visitor.view_current_node()
                result = _translate_ir(node, self, schema)
            except Exception as err:
                self.errors.append(err)
                return ir.ErrorNode(schema, str(err))

            contains_null = any(
                isinstance(pl_dtype, pl.Null)
                for pl_dtype in pl.datatypes.unpack_dtypes(*polars_schema.values())
            )
            if not contains_null:
                return result
            null_error = NotImplementedError(
                f"No GPU support for {result} with Null column dtype."
            )
            self.errors.append(null_error)
            return ir.ErrorNode(schema, str(null_error))

    def translate_expr(self, *, n: int, schema: Schema) -> expr.Expr:
        """
        Translate a polars-internal expression node into our representation.

        Parameters
        ----------
        n
            Node to translate, an integer referencing a polars internal node.
        schema
            Schema of the IR node this expression uses as evaluation context.

        Returns
        -------
        Translated expression, or an :class:`expr.ErrorExpr` on failure.

        Notes
        -----
        Exceptions raised during translation are collected in the `errors`
        attribute and the expression is replaced by :class:`expr.ErrorExpr`.
        After translation is complete, this list of errors should be
        inspected to determine if the query is supported. Errors raised while
        looking up the expression's dtype are not caught.
        """
        polars_node = self.visitor.view_expression(n)
        dtype = DataType(self.visitor.get_dtype(n))
        try:
            return _translate_expr(polars_node, self, dtype, schema)
        except Exception as err:
            self.errors.append(err)
            return expr.ErrorExpr(dtype, str(err))
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
class set_node(AbstractContextManager[None]):
    """
    Run a block with a given node set as the visitor's current node.

    Parameters
    ----------
    visitor
        The internal Rust visitor object
    n
        The node to set as the current root.

    Notes
    -----
    This is useful for translating expressions with a given node
    active, restoring the previously active node when the block exits
    (also when it exits via an exception).

    The target node (``n``) and the node to restore are stored separately,
    so an instance can be entered again after it has exited. It is not
    reentrant: entering it again while it is still active would overwrite
    the node to restore.
    """

    __slots__ = ("_previous", "n", "visitor")
    visitor: NodeTraverser
    n: int
    # Node that was active when the block was entered; set by __enter__.
    _previous: int

    def __init__(self, visitor: NodeTraverser, n: int) -> None:
        self.visitor = visitor
        self.n = n

    def __enter__(self) -> None:
        # Remember the currently active node without clobbering the target,
        # so later entries still activate ``self.n``.
        self._previous = self.visitor.get_node()
        self.visitor.set_node(self.n)

    def __exit__(self, *args: Any) -> None:
        self.visitor.set_node(self._previous)
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
# Shared do-nothing context manager, used in place of ``set_node`` when no
# node switch is needed (e.g. ``translate_ir`` called without ``n``).
noop_context: nullcontext[None] = nullcontext()
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
@singledispatch
|
|
205
|
+
def _translate_ir(node: Any, translator: Translator, schema: Schema) -> ir.IR:
|
|
206
|
+
raise NotImplementedError(
|
|
207
|
+
f"Translation for {type(node).__name__}"
|
|
208
|
+
) # pragma: no cover
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
@_translate_ir.register
def _(node: pl_ir.PythonScan, translator: Translator, schema: Schema) -> ir.IR:
    # Polars bundles the predicate with the scan options; split it out so it
    # can be translated, and forward the remaining options as a tuple.
    scan_fn, with_columns, source_type, predicate_node, nrows = node.options
    scan_options = (scan_fn, with_columns, source_type, nrows)
    if predicate_node is None:
        predicate = None
    else:
        predicate = translate_named_expr(
            translator, n=predicate_node, schema=schema
        )
    return ir.PythonScan(schema, scan_options, predicate)
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
@_translate_ir.register
def _(node: pl_ir.Scan, translator: Translator, schema: Schema) -> ir.IR:
    typ, *encoded_options = node.scan_type
    paths = node.paths
    # Polars can produce a Scan with an empty ``node.paths`` (e.g. the native
    # Iceberg reader on a table with no data files yet). In that case polars
    # returns an empty DataFrame with the declared schema; mirror this by
    # producing an Empty IR node instead of a Scan.
    if not paths:  # pragma: no cover
        return ir.Empty(schema)

    # Options arrive JSON-encoded: ndjson carries only reader options, the
    # other formats carry reader options followed by cloud options.
    decoded = map(json.loads, encoded_options)
    if typ == "ndjson":
        (reader_options,) = decoded
        cloud_options = None
    else:
        reader_options, cloud_options = decoded

    file_options = node.file_options
    with_columns = file_options.with_columns
    row_index = file_options.row_index
    include_file_paths = file_options.include_file_paths
    if not POLARS_VERSION_LT_131 and file_options.deletion_files:  # pragma: no cover
        raise NotImplementedError(
            "Iceberg format is not supported in cudf-polars. Furthermore, row-level deletions are not supported."
        )
    parquet_options = translator.config_options.parquet_options

    # ``n_rows`` is either None (encoded here as skip 0 / n_rows -1) or a
    # (skip_rows, n_rows) pair.
    pre_slice = file_options.n_rows
    skip_rows, n_rows = (0, -1) if pre_slice is None else pre_slice

    predicate = (
        None
        if node.predicate is None
        else translate_named_expr(translator, n=node.predicate, schema=schema)
    )
    return ir.Scan(
        schema,
        typ,
        reader_options,
        cloud_options,
        paths,
        with_columns,
        skip_rows,
        n_rows,
        row_index,
        include_file_paths,
        predicate,
        parquet_options,
    )
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
@_translate_ir.register
def _(node: pl_ir.Cache, translator: Translator, schema: Schema) -> ir.IR:
    # ``cache_hits`` is only read on polars < 1.32.3; newer versions get None.
    refcount = node.cache_hits if POLARS_VERSION_LT_1323 else None

    # Every Cache node sharing an ``id_`` must translate to the very same
    # object, so build it on first sight and hand it out afterwards.
    known = translator._cache_nodes
    cached = known.get(node.id_)
    if cached is None:
        cached = ir.Cache(
            schema,
            node.id_,
            refcount,
            translator.translate_ir(n=node.input),
        )
        known[node.id_] = cached
    return cached
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
@_translate_ir.register
def _(node: pl_ir.DataFrameScan, translator: Translator, schema: Schema) -> ir.IR:
    # The in-memory frame and its projection are forwarded unchanged.
    frame = node.df
    projection = node.projection
    return ir.DataFrameScan(schema, frame, projection)
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
@_translate_ir.register
def _(node: pl_ir.Select, translator: Translator, schema: Schema) -> ir.IR:
    # Expression dtypes are looked up through the visitor, so the input node
    # must be active while the selected expressions are translated.
    with set_node(translator.visitor, node.input):
        child = translator.translate_ir(n=None)
        child_schema = child.schema
        selected = [
            translate_named_expr(translator, n=named, schema=child_schema)
            for named in node.expr
        ]
    return ir.Select(schema, selected, node.should_broadcast, child)
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
@_translate_ir.register
def _(node: pl_ir.GroupBy, translator: Translator, schema: Schema) -> ir.IR:
    # Keys and aggregations are translated with the input node active so
    # that their dtypes resolve against the input.
    with set_node(translator.visitor, node.input):
        inp = translator.translate_ir(n=None)
        keys, original_aggs = (
            [translate_named_expr(translator, n=e, schema=inp.schema) for e in exprs]
            for exprs in (node.keys, node.aggs)
        )

    group_options = node.options
    # group_by_dynamic is unsupported; rolling group-bys and plain group-bys
    # are each rewritten by their own helper.
    if group_options.dynamic is not None:
        raise NotImplementedError("group_by_dynamic")
    if group_options.rolling is not None:
        return rewrite_rolling(
            group_options,
            schema,
            keys,
            original_aggs,
            translator.config_options,
            inp,
        )
    return rewrite_groupby(node, schema, keys, original_aggs, inp)
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
@_translate_ir.register
def _(node: pl_ir.Join, translator: Translator, schema: Schema) -> ir.IR:
    """
    Translate a polars ``Join`` node.

    Equi-joins (and cross/semi/anti joins) become :class:`ir.Join`. Inequality
    joins (``"IEJoin"``) become :class:`ir.ConditionalJoin` with a predicate
    AND-ing one comparison per (operator, left key, right key) triple.
    """
    # Join key dtypes are dependent on the schema of the left and
    # right inputs, so these must be translated with the relevant
    # input active.
    with set_node(translator.visitor, node.input_left):
        inp_left = translator.translate_ir(n=None)
        left_on = [
            translate_named_expr(translator, n=e, schema=inp_left.schema)
            for e in node.left_on
        ]
    with set_node(translator.visitor, node.input_right):
        inp_right = translator.translate_ir(n=None)
        right_on = [
            translate_named_expr(translator, n=e, schema=inp_right.schema)
            for e in node.right_on
        ]

    # ``node.options[0]`` is either a join-type string (handled here) or a
    # (how, op1, op2) tuple (handled in the else branch).
    if (how := node.options[0]) in {
        "Inner",
        "Left",
        "Right",
        "Full",
        "Cross",
        "Semi",
        "Anti",
    }:
        return ir.Join(
            schema,
            left_on,
            right_on,
            node.options,
            inp_left,
            inp_right,
        )
    else:
        how, op1, op2 = node.options[0]
        if how != "IEJoin":
            raise NotImplementedError(
                f"Unsupported join type {how}"
            )  # pragma: no cover; asof joins not yet exposed
        # The second inequality is optional (None when absent).
        if op2 is None:
            ops = [op1]
        else:
            ops = [op1, op2]

        dtype = DataType(pl.datatypes.Boolean())
        # Build ``cmp(l0, r0) AND cmp(l1, r1) ...``. Column references are
        # rewritten to positional indices into the left/right tables, as
        # the conditional join evaluates the predicate across both inputs.
        # ``strict=True`` makes a mismatch between the number of operators
        # and join keys an error rather than silently truncating.
        predicate = functools.reduce(
            functools.partial(
                expr.BinOp, dtype, plc.binaryop.BinaryOperator.LOGICAL_AND
            ),
            (
                expr.BinOp(
                    dtype,
                    expr.BinOp._MAPPING[op],
                    insert_colrefs(
                        left.value,
                        table_ref=plc.expressions.TableReference.LEFT,
                        name_to_index={
                            name: i for i, name in enumerate(inp_left.schema)
                        },
                    ),
                    insert_colrefs(
                        right.value,
                        table_ref=plc.expressions.TableReference.RIGHT,
                        name_to_index={
                            name: i for i, name in enumerate(inp_right.schema)
                        },
                    ),
                )
                for op, left, right in zip(ops, left_on, right_on, strict=True)
            ),
        )

        return ir.ConditionalJoin(schema, predicate, node.options, inp_left, inp_right)
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
@_translate_ir.register
def _(node: pl_ir.HStack, translator: Translator, schema: Schema) -> ir.IR:
    # The new columns' expressions are typed against the input node, which
    # therefore has to be active while they are translated.
    with set_node(translator.visitor, node.input):
        child = translator.translate_ir(n=None)
        child_schema = child.schema
        new_columns = [
            translate_named_expr(translator, n=column, schema=child_schema)
            for column in node.exprs
        ]
    return ir.HStack(schema, new_columns, node.should_broadcast, child)
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
@_translate_ir.register
def _(
    node: pl_ir.Reduce, translator: Translator, schema: Schema
) -> ir.IR:  # pragma: no cover; polars doesn't emit this node yet
    # Expression dtypes are looked up through the visitor, so the input node
    # must be active while the reductions are translated.
    with set_node(translator.visitor, node.input):
        child = translator.translate_ir(n=None)
        child_schema = child.schema
        reductions = [
            translate_named_expr(translator, n=reduction, schema=child_schema)
            for reduction in node.expr
        ]
    return ir.Reduce(schema, reductions, child)
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
@_translate_ir.register
def _(node: pl_ir.Distinct, translator: Translator, schema: Schema) -> ir.IR:
    polars_keep, subset_names, maintain_order, zlice = node.options
    # Convert polars' keep strategy through the lookup table on ir.Distinct.
    keep = ir.Distinct._KEEP_MAP[polars_keep]
    # A subset of None means "all columns"; otherwise store it as a frozenset.
    subset = None if subset_names is None else frozenset(subset_names)
    child = translator.translate_ir(n=node.input)
    return ir.Distinct(schema, keep, subset, zlice, maintain_order, child)
|
|
448
|
+
|
|
449
|
+
|
|
450
|
+
@_translate_ir.register
def _(node: pl_ir.Sort, translator: Translator, schema: Schema) -> ir.IR:
    # Sort keys are expressions over the input, so translate them while the
    # input node is active.
    with set_node(translator.visitor, node.input):
        child = translator.translate_ir(n=None)
        sort_keys = [
            translate_named_expr(translator, n=key, schema=child.schema)
            for key in node.by_column
        ]
    stable, nulls_last, descending = node.sort_options
    num_keys = len(sort_keys)
    order, null_order = sorting.sort_order(
        descending, nulls_last=nulls_last, num_keys=num_keys
    )
    return ir.Sort(schema, sort_keys, order, null_order, stable, node.slice, child)
|
|
463
|
+
|
|
464
|
+
|
|
465
|
+
@_translate_ir.register
def _(node: pl_ir.Slice, translator: Translator, schema: Schema) -> ir.IR:
    # Offset and length are forwarded as-is; only the input needs translating.
    offset, length = node.offset, node.len
    child = translator.translate_ir(n=node.input)
    return ir.Slice(schema, offset, length, child)
|
|
470
|
+
|
|
471
|
+
|
|
472
|
+
@_translate_ir.register
def _(node: pl_ir.Filter, translator: Translator, schema: Schema) -> ir.IR:
    # The filter predicate is typed against the input, so keep it active.
    with set_node(translator.visitor, node.input):
        child = translator.translate_ir(n=None)
        predicate = translate_named_expr(
            translator, n=node.predicate, schema=child.schema
        )
    return ir.Filter(schema, predicate, child)
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
@_translate_ir.register
def _(node: pl_ir.SimpleProjection, translator: Translator, schema: Schema) -> ir.IR:
    # Only the target ``schema`` and the translated input are passed on; the
    # projected columns are conveyed by ``schema``.
    child = translator.translate_ir(n=node.input)
    return ir.Projection(schema, child)
|
|
483
|
+
|
|
484
|
+
|
|
485
|
+
@_translate_ir.register
def _(node: pl_ir.MergeSorted, translator: Translator, schema: Schema) -> ir.IR:
    key = node.key
    # Translate the left input first, then the right.
    left, right = (
        translator.translate_ir(n=side)
        for side in (node.input_left, node.input_right)
    )
    return ir.MergeSorted(schema, key, left, right)
|
|
496
|
+
|
|
497
|
+
|
|
498
|
+
@_translate_ir.register
def _(node: pl_ir.MapFunction, translator: Translator, schema: Schema) -> ir.IR:
    # ``node.function`` is the function name followed by its options; the
    # options are forwarded as a list.
    function_name, *function_options = node.function
    child = translator.translate_ir(n=node.input)
    return ir.MapFunction(schema, function_name, function_options, child)
|
|
507
|
+
|
|
508
|
+
|
|
509
|
+
@_translate_ir.register
def _(node: pl_ir.Union, translator: Translator, schema: Schema) -> ir.IR:
    options = node.options
    # Inputs are translated in order and passed positionally.
    children = [translator.translate_ir(n=child) for child in node.inputs]
    return ir.Union(schema, options, *children)
|
|
514
|
+
|
|
515
|
+
|
|
516
|
+
@_translate_ir.register
def _(node: pl_ir.HConcat, translator: Translator, schema: Schema) -> ir.IR:
    children = [translator.translate_ir(n=child) for child in node.inputs]
    # The second argument is always a literal ``False`` here; its meaning is
    # defined by ``ir.HConcat``.
    return ir.HConcat(schema, False, *children)  # noqa: FBT003
|
|
523
|
+
|
|
524
|
+
|
|
525
|
+
@_translate_ir.register
def _(node: pl_ir.Sink, translator: Translator, schema: Schema) -> ir.IR:
    """
    Translate a polars file ``Sink`` node.

    The node's JSON ``payload`` must describe a file sink: a ``"File"`` entry
    whose ``"file_type"`` is a single-entry ``{kind: options}`` mapping. Any
    ``"sink_options"`` in the file entry are merged into those per-kind
    options before building :class:`ir.Sink`.

    Raises
    ------
    NotImplementedError
        If the payload does not have the expected structure.
    """
    payload = json.loads(node.payload)
    try:
        file = payload["File"]
        sink_kind_options = file["file_type"]
    except KeyError as err:  # pragma: no cover
        raise NotImplementedError("Unsupported payload structure") from err
    if isinstance(sink_kind_options, dict):
        if len(sink_kind_options) != 1:  # pragma: no cover; not sure if this can happen
            raise NotImplementedError("Sink options dict with more than one entry.")
        sink_kind, options = next(iter(sink_kind_options.items()))
    else:
        raise NotImplementedError(
            "Unsupported sink options structure"
        )  # pragma: no cover

    sink_options = file.get("sink_options", {})
    cloud_options = file.get("cloud_options")

    # NOTE(review): assumes the per-kind ``options`` value is a dict; a
    # non-dict value here would fail on ``update``.
    options.update(sink_options)

    return ir.Sink(
        schema=schema,
        kind=sink_kind,
        # On newer polars (not LT_132) ``target`` is a mapping and only its
        # "Local" entry is read; presumably other targets raise KeyError here,
        # which translate_ir records as an error.
        path=file["target"] if POLARS_VERSION_LT_132 else file["target"]["Local"],
        parquet_options=translator.config_options.parquet_options,
        options=options,
        cloud_options=cloud_options,
        df=translator.translate_ir(n=node.input),
    )
|
|
556
|
+
|
|
557
|
+
|
|
558
|
+
def translate_named_expr(
    translator: Translator, *, n: pl_expr.PyExprIR, schema: Schema
) -> expr.NamedExpr:
    """
    Translate a polars-internal named expression into a :class:`expr.NamedExpr`.

    Parameters
    ----------
    translator
        Translator object that performs the work and collects errors.
    n
        Named expression node: its ``output_name`` becomes the name and its
        ``node`` is translated to become the value.
    schema
        Schema of the IR node this expression uses as evaluation context.

    Returns
    -------
    The translated named expression.

    Notes
    -----
    The dtype of the inner expression is obtained from the visitor via
    ``get_dtype``. For this to work, the caller must make the node that
    the expression references "active" for the visitor while translating
    (see :class:`set_node`).

    Failures inside the expression translation are generally not raised:
    ``translator.translate_expr`` records them in ``translator.errors`` and
    substitutes an :class:`expr.ErrorExpr`. Errors while looking up the
    dtype itself still propagate.
    """
    output_name = n.output_name
    value = translator.translate_expr(n=n.node, schema=schema)
    return expr.NamedExpr(output_name, value)
|
|
592
|
+
|
|
593
|
+
|
|
594
|
+
@singledispatch
|
|
595
|
+
def _translate_expr(
|
|
596
|
+
node: Any, translator: Translator, dtype: DataType, schema: Schema
|
|
597
|
+
) -> expr.Expr:
|
|
598
|
+
raise NotImplementedError(
|
|
599
|
+
f"Translation for {type(node).__name__}"
|
|
600
|
+
) # pragma: no cover
|
|
601
|
+
|
|
602
|
+
|
|
603
|
+
@_translate_expr.register
def _(
    node: pl_expr.Function, translator: Translator, dtype: DataType, schema: Schema
) -> expr.Expr:
    """
    Translate a polars ``Function`` expression node.

    ``node.function_data`` is unpacked into a function identifier ``name``
    and a tuple of options, and dispatch is on the type of ``name``:

    - string, boolean and temporal functions map to the matching expression
      classes;
    - struct functions are handled when polars is not older than 1.31;
    - plain-string names are unary functions.

    Some cases are rewritten into binary operations instead: strip-chars
    with special literal arguments, ``IsBetween``, ``log`` and ``pow``.

    Raises
    ------
    NotImplementedError
        If ``name`` matches none of the handled function kinds.
    """
    name, *options = node.function_data
    options = tuple(options)
    if isinstance(name, pl_expr.StringFunction):
        if name in {
            pl_expr.StringFunction.StripChars,
            pl_expr.StringFunction.StripCharsStart,
            pl_expr.StringFunction.StripCharsEnd,
        }:
            column, chars = (
                translator.translate_expr(n=n, schema=schema) for n in node.input
            )
            if isinstance(chars, expr.Literal):
                # We check for null first: a string scalar cannot be produced
                # with a null dtype, so the replacement literal below uses the
                # column's dtype instead of the chars dtype.
                if chars.value is None:
                    # Polars uses None to mean "strip all whitespace"
                    chars = expr.Literal(column.dtype, "")
                elif chars.value == "":
                    # No-op in polars, but libcudf uses empty string
                    # as signifier to remove whitespace.
                    return column
            return expr.StringFunction(
                dtype,
                expr.StringFunction.Name.from_polars(name),
                options,
                column,
                chars,
            )
        return expr.StringFunction(
            dtype,
            expr.StringFunction.Name.from_polars(name),
            options,
            *(translator.translate_expr(n=n, schema=schema) for n in node.input),
        )
    elif isinstance(name, pl_expr.BooleanFunction):
        if name == pl_expr.BooleanFunction.IsBetween:
            # Rewrite ``lo <op> column <op> hi`` as two comparisons joined by
            # LOGICAL_AND; ``closed`` selects the comparison operators.
            column, lo, hi = (
                translator.translate_expr(n=n, schema=schema) for n in node.input
            )
            (closed,) = options
            lop, rop = expr.BooleanFunction._BETWEEN_OPS[closed]
            return expr.BinOp(
                dtype,
                plc.binaryop.BinaryOperator.LOGICAL_AND,
                expr.BinOp(dtype, lop, column, lo),
                expr.BinOp(dtype, rop, column, hi),
            )
        return expr.BooleanFunction(
            dtype,
            expr.BooleanFunction.Name.from_polars(name),
            options,
            *(translator.translate_expr(n=n, schema=schema) for n in node.input),
        )
    elif isinstance(name, pl_expr.TemporalFunction):
        # functions for which evaluation of the expression may not return
        # the same dtype as polars, either due to libcudf returning a different
        # dtype, or due to our internal processing affecting what libcudf returns
        needs_cast = {
            pl_expr.TemporalFunction.Year,
            pl_expr.TemporalFunction.Month,
            pl_expr.TemporalFunction.Day,
            pl_expr.TemporalFunction.WeekDay,
            pl_expr.TemporalFunction.Hour,
            pl_expr.TemporalFunction.Minute,
            pl_expr.TemporalFunction.Second,
            pl_expr.TemporalFunction.Millisecond,
        }
        result_expr = expr.TemporalFunction(
            dtype,
            expr.TemporalFunction.Name.from_polars(name),
            options,
            *(translator.translate_expr(n=n, schema=schema) for n in node.input),
        )
        if name in needs_cast:
            # Wrap in a Cast so the result has the dtype polars expects.
            return expr.Cast(dtype, result_expr)
        return result_expr
    elif not POLARS_VERSION_LT_131 and isinstance(name, pl_expr.StructFunction):
        return expr.StructFunction(
            dtype,
            expr.StructFunction.Name.from_polars(name),
            options,
            *(translator.translate_expr(n=n, schema=schema) for n in node.input),
        )
    elif isinstance(name, str):
        # Lazily translated children; each branch below consumes them once.
        children = (translator.translate_expr(n=n, schema=schema) for n in node.input)
        if name == "log":
            # log(x, base) becomes a LOG_BASE binary op with a literal base.
            (base,) = options
            (child,) = children
            return expr.BinOp(
                dtype,
                plc.binaryop.BinaryOperator.LOG_BASE,
                child,
                expr.Literal(dtype, base),
            )
        elif name == "pow":
            return expr.BinOp(dtype, plc.binaryop.BinaryOperator.POW, *children)
        return expr.UnaryFunction(dtype, name, options, *children)
    raise NotImplementedError(
        f"No handler for Expr function node with {name=}"
    )  # pragma: no cover; polars raises on the rust side for now
|
|
708
|
+
|
|
709
|
+
|
|
710
|
+
@_translate_expr.register
def _(
    node: pl_expr.Window, translator: Translator, dtype: DataType, schema: Schema
) -> expr.Expr:
    """
    Translate a polars window expression.

    The flavour is selected by the type of ``node.options``:

    - ``RollingGroupOptions`` (``pl.col("a").rolling(...)``): the function is
      decomposed into aggregations; each aggregation is wrapped in a
      ``RollingWindow`` over the index column and the post-aggregation
      expression is rebuilt on top of those windows.
    - ``WindowMapping`` (``pl.col("a").over(...)``): the function is
      decomposed into aggregations plus a post-aggregation expression, which
      are handed to ``GroupedRollingWindow`` together with the partition-by
      (and optional order-by) expressions.

    Raises
    ------
    NotImplementedError
        For an ``over`` mapping strategy other than ``"groups_to_rows"``.
    """
    if isinstance(node.options, pl_expr.RollingGroupOptions):
        # pl.col("a").rolling(...)
        agg = translator.translate_expr(n=node.function, schema=schema)
        name_generator = unique_names(schema)
        aggs, named_post_agg = decompose_single_agg(
            expr.NamedExpr(next(name_generator), agg),
            name_generator,
            is_top=True,
            context=ExecutionContext.ROLLING,
        )
        named_aggs = [named for named, _ in aggs]
        orderby = node.options.index_column
        orderby_dtype = schema[orderby].plc
        if plc.traits.is_integral(orderby_dtype):
            # Integer orderby column is cast in implementation to int64 in polars
            orderby_dtype = plc.DataType(plc.TypeId.INT64)
        offset = node.options.offset
        period = node.options.period
        closed_window = node.options.closed_window

        def rolling_window(named: expr.NamedExpr) -> expr.Expr:
            # Wrap one decomposed aggregation in a rolling window over the
            # index column; the window takes the aggregation's dtype.
            return expr.RollingWindow(
                named.value.dtype,
                orderby_dtype,
                offset,
                period,
                closed_window,
                orderby,
                named.value,
            )

        if isinstance(named_post_agg.value, expr.Col):
            # A bare column post-aggregation means there is exactly one
            # aggregation; its rolling window is the whole result.
            (named_agg,) = named_aggs
            return rolling_window(named_agg)
        # Otherwise substitute every reference to an aggregation column in
        # the post-aggregation expression with that aggregation's window.
        replacements: dict[expr.Expr, expr.Expr] = {
            expr.Col(named.value.dtype, named.name): rolling_window(named)
            for named in named_aggs
        }
        return replace([named_post_agg.value], replacements)[0]
    elif isinstance(node.options, pl_expr.WindowMapping):
        # pl.col("a").over(...)
        agg = translator.translate_expr(n=node.function, schema=schema)
        name_gen = unique_names(schema)
        aggs, post = decompose_single_agg(
            expr.NamedExpr(next(name_gen), agg),
            name_gen,
            is_top=True,
            context=ExecutionContext.WINDOW,
        )

        mapping = node.options.kind
        has_order_by = node.order_by is not None
        # Read defensively with a default: these attributes may be absent
        # on the node (presumably depending on the polars version).
        descending = bool(getattr(node, "order_by_descending", False))
        nulls_last = bool(getattr(node, "order_by_nulls_last", False))

        if mapping != "groups_to_rows":
            raise NotImplementedError(
                f"over(mapping_strategy) not supported yet: {mapping=}; "
                "expected 'groups_to_rows'"
            )

        order_by_expr = (
            translator.translate_expr(n=node.order_by, schema=schema)
            if has_order_by
            else None
        )
        return expr.GroupedRollingWindow(
            dtype,
            (mapping, has_order_by, descending, nulls_last),
            [named for named, _ in aggs],
            post,
            *(translator.translate_expr(n=n, schema=schema) for n in node.partition_by),
            _order_by_expr=order_by_expr,
        )
    assert_never(node.options)
|
|
791
|
+
|
|
792
|
+
|
|
793
|
+
@_translate_expr.register
def _(
    node: pl_expr.Literal, translator: Translator, dtype: DataType, schema: Schema
) -> expr.Expr:
    """
    Translate a literal.

    Series-valued literals and list-typed literals become ``LiteralColumn``
    nodes; every other literal becomes a scalar ``Literal``.
    """
    raw = node.value
    if isinstance(raw, plrs.PySeries):
        # Re-wrap the raw PySeries as a polars Series for the column literal.
        series = pl.Series._from_pyseries(raw)
        return expr.LiteralColumn(dtype, series)
    if dtype.id() == plc.TypeId.LIST:  # pragma: no cover
        # TODO: Remove once pylibcudf.Scalar supports lists
        return expr.LiteralColumn(dtype, pl.Series(raw))
    return expr.Literal(dtype, raw)
|
|
803
|
+
|
|
804
|
+
|
|
805
|
+
@_translate_expr.register
def _(
    node: pl_expr.Sort, translator: Translator, dtype: DataType, schema: Schema
) -> expr.Expr:
    """Translate a sort of a single expression, forwarding the sort options."""
    # TODO: raise in groupby
    sorted_child = translator.translate_expr(n=node.expr, schema=schema)
    return expr.Sort(dtype, node.options, sorted_child)
|
|
813
|
+
|
|
814
|
+
|
|
815
|
+
@_translate_expr.register
def _(
    node: pl_expr.SortBy, translator: Translator, dtype: DataType, schema: Schema
) -> expr.Expr:
    """
    Translate a sort of one expression keyed by one or more other expressions.

    The first sort option is passed through as-is; the second and third are
    converted to tuples before being handed to ``expr.SortBy``.
    """
    sort_options = node.sort_options
    normalised_options = (
        sort_options[0],
        tuple(sort_options[1]),
        tuple(sort_options[2]),
    )
    target = translator.translate_expr(n=node.expr, schema=schema)
    keys = [translator.translate_expr(n=n, schema=schema) for n in node.by]
    return expr.SortBy(dtype, normalised_options, target, *keys)
|
|
826
|
+
|
|
827
|
+
|
|
828
|
+
@_translate_expr.register
def _(
    node: pl_expr.Slice, translator: Translator, dtype: DataType, schema: Schema
) -> expr.Expr:
    """
    Translate ``expr.slice(offset, length)``.

    Only literal offsets and lengths are supported; their values are baked
    into the ``expr.Slice`` node.

    Raises
    ------
    NotImplementedError
        If the offset or the length does not translate to a literal.
    """
    offset = translator.translate_expr(n=node.offset, schema=schema)
    length = translator.translate_expr(n=node.length, schema=schema)
    # Explicit checks rather than ``assert``: asserts are stripped under
    # ``python -O`` and a non-literal would then fail later with a less
    # informative error when ``.value`` is accessed.
    if not isinstance(offset, expr.Literal):
        raise NotImplementedError("slice with a non-literal offset is not supported")
    if not isinstance(length, expr.Literal):
        raise NotImplementedError("slice with a non-literal length is not supported")
    return expr.Slice(
        dtype,
        offset.value,
        length.value,
        translator.translate_expr(n=node.input, schema=schema),
    )
|
|
842
|
+
|
|
843
|
+
|
|
844
|
+
@_translate_expr.register
def _(
    node: pl_expr.Gather, translator: Translator, dtype: DataType, schema: Schema
) -> expr.Expr:
    """Translate a gather of values at the positions given by an index expression."""
    source = translator.translate_expr(n=node.expr, schema=schema)
    indices = translator.translate_expr(n=node.idx, schema=schema)
    return expr.Gather(dtype, source, indices)
|
|
853
|
+
|
|
854
|
+
|
|
855
|
+
@_translate_expr.register
def _(
    node: pl_expr.Filter, translator: Translator, dtype: DataType, schema: Schema
) -> expr.Expr:
    """Translate a filter of an input expression by a predicate expression."""
    values = translator.translate_expr(n=node.input, schema=schema)
    predicate = translator.translate_expr(n=node.by, schema=schema)
    return expr.Filter(dtype, values, predicate)
|
|
864
|
+
|
|
865
|
+
|
|
866
|
+
@_translate_expr.register
def _(
    node: pl_expr.Cast, translator: Translator, dtype: DataType, schema: Schema
) -> expr.Expr:
    """
    Translate ``expr.cast(dtype)``.

    Casts of literals are folded into the literal itself. A cast applied on
    top of the cast that this translator inserts around ``Len`` and
    ``count``/``n_unique`` aggregations replaces that inner cast, so only
    one cast is evaluated. All other nested casts are preserved, because an
    intermediate cast can change values (e.g. float -> int truncation in
    ``col.cast(pl.Int32).cast(pl.Float64)``).
    """
    inner = translator.translate_expr(n=node.expr, schema=schema)
    # Push casts into literals so we can handle Cast(Literal(Null))
    if isinstance(inner, expr.Literal):
        return inner.astype(dtype)
    elif isinstance(inner, expr.Cast):
        # Translation of Len/Count-agg put in a cast (to the child's own
        # dtype); remove that double cast if we have one. Any other inner
        # cast (user-written) must be kept.
        (child,) = inner.children
        translator_inserted = inner.dtype == child.dtype and (
            isinstance(child, expr.Len)
            or (isinstance(child, expr.Agg) and child.name in ("count", "n_unique"))
        )
        if translator_inserted:
            inner = child
    return expr.Cast(dtype, inner)
|
|
879
|
+
|
|
880
|
+
|
|
881
|
+
@_translate_expr.register
def _(
    node: pl_expr.Column, translator: Translator, dtype: DataType, schema: Schema
) -> expr.Expr:
    """Translate a reference to a named column."""
    column_name = node.name
    return expr.Col(dtype, column_name)
|
|
886
|
+
|
|
887
|
+
|
|
888
|
+
@_translate_expr.register
def _(
    node: pl_expr.Agg, translator: Translator, dtype: DataType, schema: Schema
) -> expr.Expr:
    """
    Translate an aggregation.

    ``count`` and ``n_unique`` aggregations are wrapped in a cast to their
    declared dtype unless that dtype is already INT32 (presumably because the
    underlying computation yields INT32 counts; confirm against expr.Agg).
    """
    arguments = [translator.translate_expr(n=n, schema=schema) for n in node.arguments]
    aggregation = expr.Agg(dtype, node.name, node.options, *arguments)
    is_counting = aggregation.name in ("count", "n_unique")
    if not is_counting or aggregation.dtype.id() == plc.TypeId.INT32:
        return aggregation
    return expr.Cast(aggregation.dtype, aggregation)
|
|
901
|
+
|
|
902
|
+
|
|
903
|
+
@_translate_expr.register
def _(
    node: pl_expr.Ternary, translator: Translator, dtype: DataType, schema: Schema
) -> expr.Expr:
    """Translate a ``when/then/otherwise`` ternary expression."""
    condition = translator.translate_expr(n=node.predicate, schema=schema)
    when_true = translator.translate_expr(n=node.truthy, schema=schema)
    when_false = translator.translate_expr(n=node.falsy, schema=schema)
    return expr.Ternary(dtype, condition, when_true, when_false)
|
|
913
|
+
|
|
914
|
+
|
|
915
|
+
@_translate_expr.register
def _(
    node: pl_expr.BinaryExpr,
    translator: Translator,
    dtype: DataType,
    schema: Schema,
) -> expr.Expr:
    """
    Translate a binary operation via ``expr.BinOp._MAPPING``.

    When the reported result dtype is boolean for a true division, the
    operation is built with a Float64 result dtype instead.
    """
    result_dtype = (
        DataType(pl.Float64())
        if plc.traits.is_boolean(dtype.plc) and node.op == pl_expr.Operator.TrueDivide
        else dtype
    )
    operator = expr.BinOp._MAPPING[node.op]
    lhs = translator.translate_expr(n=node.left, schema=schema)
    rhs = translator.translate_expr(n=node.right, schema=schema)
    return expr.BinOp(result_dtype, operator, lhs, rhs)
|
|
930
|
+
|
|
931
|
+
|
|
932
|
+
@_translate_expr.register
def _(
    node: pl_expr.Len, translator: Translator, dtype: DataType, schema: Schema
) -> expr.Expr:
    """Translate ``pl.len()``, wrapping it in a cast unless the dtype is INT32."""
    length = expr.Len(dtype)
    if dtype.id() == plc.TypeId.INT32:
        return length  # pragma: no cover; never reached since polars len has uint32 dtype
    return expr.Cast(dtype, length)
|