cudf-polars-cu12 25.2.2__py3-none-any.whl → 25.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cudf_polars/VERSION +1 -1
- cudf_polars/callback.py +82 -65
- cudf_polars/containers/column.py +138 -7
- cudf_polars/containers/dataframe.py +26 -39
- cudf_polars/dsl/expr.py +3 -1
- cudf_polars/dsl/expressions/aggregation.py +27 -63
- cudf_polars/dsl/expressions/base.py +40 -72
- cudf_polars/dsl/expressions/binaryop.py +5 -41
- cudf_polars/dsl/expressions/boolean.py +25 -53
- cudf_polars/dsl/expressions/datetime.py +97 -17
- cudf_polars/dsl/expressions/literal.py +27 -33
- cudf_polars/dsl/expressions/rolling.py +110 -9
- cudf_polars/dsl/expressions/selection.py +8 -26
- cudf_polars/dsl/expressions/slicing.py +47 -0
- cudf_polars/dsl/expressions/sorting.py +5 -18
- cudf_polars/dsl/expressions/string.py +33 -36
- cudf_polars/dsl/expressions/ternary.py +3 -10
- cudf_polars/dsl/expressions/unary.py +35 -75
- cudf_polars/dsl/ir.py +749 -212
- cudf_polars/dsl/nodebase.py +8 -1
- cudf_polars/dsl/to_ast.py +5 -3
- cudf_polars/dsl/translate.py +319 -171
- cudf_polars/dsl/utils/__init__.py +8 -0
- cudf_polars/dsl/utils/aggregations.py +292 -0
- cudf_polars/dsl/utils/groupby.py +97 -0
- cudf_polars/dsl/utils/naming.py +34 -0
- cudf_polars/dsl/utils/replace.py +46 -0
- cudf_polars/dsl/utils/rolling.py +113 -0
- cudf_polars/dsl/utils/windows.py +186 -0
- cudf_polars/experimental/base.py +17 -19
- cudf_polars/experimental/benchmarks/__init__.py +4 -0
- cudf_polars/experimental/benchmarks/pdsh.py +1279 -0
- cudf_polars/experimental/dask_registers.py +196 -0
- cudf_polars/experimental/distinct.py +174 -0
- cudf_polars/experimental/explain.py +127 -0
- cudf_polars/experimental/expressions.py +521 -0
- cudf_polars/experimental/groupby.py +288 -0
- cudf_polars/experimental/io.py +58 -29
- cudf_polars/experimental/join.py +353 -0
- cudf_polars/experimental/parallel.py +166 -93
- cudf_polars/experimental/repartition.py +69 -0
- cudf_polars/experimental/scheduler.py +155 -0
- cudf_polars/experimental/select.py +92 -7
- cudf_polars/experimental/shuffle.py +294 -0
- cudf_polars/experimental/sort.py +45 -0
- cudf_polars/experimental/spilling.py +151 -0
- cudf_polars/experimental/utils.py +100 -0
- cudf_polars/testing/asserts.py +146 -6
- cudf_polars/testing/io.py +72 -0
- cudf_polars/testing/plugin.py +78 -76
- cudf_polars/typing/__init__.py +59 -6
- cudf_polars/utils/config.py +353 -0
- cudf_polars/utils/conversion.py +40 -0
- cudf_polars/utils/dtypes.py +22 -5
- cudf_polars/utils/timer.py +39 -0
- cudf_polars/utils/versions.py +5 -4
- {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info}/METADATA +10 -7
- cudf_polars_cu12-25.6.0.dist-info/RECORD +73 -0
- {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info}/WHEEL +1 -1
- cudf_polars/experimental/dask_serialize.py +0 -59
- cudf_polars_cu12-25.2.2.dist-info/RECORD +0 -48
- {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info/licenses}/LICENSE +0 -0
- {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info}/top_level.txt +0 -0
cudf_polars/dsl/translate.py
CHANGED
|
@@ -22,8 +22,14 @@ import pylibcudf as plc
|
|
|
22
22
|
|
|
23
23
|
from cudf_polars.dsl import expr, ir
|
|
24
24
|
from cudf_polars.dsl.to_ast import insert_colrefs
|
|
25
|
-
from cudf_polars.
|
|
26
|
-
from cudf_polars.utils import
|
|
25
|
+
from cudf_polars.dsl.utils.aggregations import decompose_single_agg
|
|
26
|
+
from cudf_polars.dsl.utils.groupby import rewrite_groupby
|
|
27
|
+
from cudf_polars.dsl.utils.naming import unique_names
|
|
28
|
+
from cudf_polars.dsl.utils.replace import replace
|
|
29
|
+
from cudf_polars.dsl.utils.rolling import rewrite_rolling
|
|
30
|
+
from cudf_polars.dsl.utils.windows import offsets_to_windows
|
|
31
|
+
from cudf_polars.typing import Schema
|
|
32
|
+
from cudf_polars.utils import config, dtypes, sorting
|
|
27
33
|
|
|
28
34
|
if TYPE_CHECKING:
|
|
29
35
|
from polars import GPUEngine
|
|
@@ -41,13 +47,13 @@ class Translator:
|
|
|
41
47
|
----------
|
|
42
48
|
visitor
|
|
43
49
|
Polars NodeTraverser object
|
|
44
|
-
|
|
50
|
+
engine
|
|
45
51
|
GPU engine configuration.
|
|
46
52
|
"""
|
|
47
53
|
|
|
48
|
-
def __init__(self, visitor: NodeTraverser,
|
|
54
|
+
def __init__(self, visitor: NodeTraverser, engine: GPUEngine):
|
|
49
55
|
self.visitor = visitor
|
|
50
|
-
self.
|
|
56
|
+
self.config_options = config.ConfigOptions.from_polars_engine(engine)
|
|
51
57
|
self.errors: list[Exception] = []
|
|
52
58
|
|
|
53
59
|
def translate_ir(self, *, n: int | None = None) -> ir.IR:
|
|
@@ -84,7 +90,7 @@ class Translator:
|
|
|
84
90
|
# IR is versioned with major.minor, minor is bumped for backwards
|
|
85
91
|
# compatible changes (e.g. adding new nodes), major is bumped for
|
|
86
92
|
# incompatible changes (e.g. renaming nodes).
|
|
87
|
-
if (version := self.visitor.version()) >= (
|
|
93
|
+
if (version := self.visitor.version()) >= (7, 1):
|
|
88
94
|
e = NotImplementedError(
|
|
89
95
|
f"No support for polars IR {version=}"
|
|
90
96
|
) # pragma: no cover; no such version for now.
|
|
@@ -120,7 +126,7 @@ class Translator:
|
|
|
120
126
|
|
|
121
127
|
return result
|
|
122
128
|
|
|
123
|
-
def translate_expr(self, *, n: int) -> expr.Expr:
|
|
129
|
+
def translate_expr(self, *, n: int, schema: Schema) -> expr.Expr:
|
|
124
130
|
"""
|
|
125
131
|
Translate a polars-internal expression IR into our representation.
|
|
126
132
|
|
|
@@ -128,6 +134,8 @@ class Translator:
|
|
|
128
134
|
----------
|
|
129
135
|
n
|
|
130
136
|
Node to translate, an integer referencing a polars internal node.
|
|
137
|
+
schema
|
|
138
|
+
Schema of the IR node this expression uses as evaluation context.
|
|
131
139
|
|
|
132
140
|
Returns
|
|
133
141
|
-------
|
|
@@ -143,7 +151,7 @@ class Translator:
|
|
|
143
151
|
node = self.visitor.view_expression(n)
|
|
144
152
|
dtype = dtypes.from_polars(self.visitor.get_dtype(n))
|
|
145
153
|
try:
|
|
146
|
-
return _translate_expr(node, self, dtype)
|
|
154
|
+
return _translate_expr(node, self, dtype, schema)
|
|
147
155
|
except Exception as e:
|
|
148
156
|
self.errors.append(e)
|
|
149
157
|
return expr.ErrorExpr(dtype, str(e))
|
|
@@ -168,6 +176,7 @@ class set_node(AbstractContextManager[None]):
|
|
|
168
176
|
|
|
169
177
|
__slots__ = ("n", "visitor")
|
|
170
178
|
visitor: NodeTraverser
|
|
179
|
+
|
|
171
180
|
n: int
|
|
172
181
|
|
|
173
182
|
def __init__(self, visitor: NodeTraverser, n: int) -> None:
|
|
@@ -187,30 +196,26 @@ noop_context: nullcontext[None] = nullcontext()
|
|
|
187
196
|
|
|
188
197
|
|
|
189
198
|
@singledispatch
|
|
190
|
-
def _translate_ir(
|
|
191
|
-
node: Any, translator: Translator, schema: dict[str, plc.DataType]
|
|
192
|
-
) -> ir.IR:
|
|
199
|
+
def _translate_ir(node: Any, translator: Translator, schema: Schema) -> ir.IR:
|
|
193
200
|
raise NotImplementedError(
|
|
194
201
|
f"Translation for {type(node).__name__}"
|
|
195
202
|
) # pragma: no cover
|
|
196
203
|
|
|
197
204
|
|
|
198
205
|
@_translate_ir.register
|
|
199
|
-
def _(
|
|
200
|
-
node: pl_ir.PythonScan, translator: Translator, schema: dict[str, plc.DataType]
|
|
201
|
-
) -> ir.IR:
|
|
206
|
+
def _(node: pl_ir.PythonScan, translator: Translator, schema: Schema) -> ir.IR:
|
|
202
207
|
scan_fn, with_columns, source_type, predicate, nrows = node.options
|
|
203
208
|
options = (scan_fn, with_columns, source_type, nrows)
|
|
204
209
|
predicate = (
|
|
205
|
-
translate_named_expr(translator, n=predicate)
|
|
210
|
+
translate_named_expr(translator, n=predicate, schema=schema)
|
|
211
|
+
if predicate is not None
|
|
212
|
+
else None
|
|
206
213
|
)
|
|
207
214
|
return ir.PythonScan(schema, options, predicate)
|
|
208
215
|
|
|
209
216
|
|
|
210
217
|
@_translate_ir.register
|
|
211
|
-
def _(
|
|
212
|
-
node: pl_ir.Scan, translator: Translator, schema: dict[str, plc.DataType]
|
|
213
|
-
) -> ir.IR:
|
|
218
|
+
def _(node: pl_ir.Scan, translator: Translator, schema: Schema) -> ir.IR:
|
|
214
219
|
typ, *options = node.scan_type
|
|
215
220
|
if typ == "ndjson":
|
|
216
221
|
(reader_options,) = map(json.loads, options)
|
|
@@ -219,118 +224,102 @@ def _(
|
|
|
219
224
|
reader_options, cloud_options = map(json.loads, options)
|
|
220
225
|
file_options = node.file_options
|
|
221
226
|
with_columns = file_options.with_columns
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
227
|
+
row_index = file_options.row_index
|
|
228
|
+
include_file_paths = file_options.include_file_paths
|
|
229
|
+
|
|
230
|
+
pre_slice = file_options.n_rows
|
|
231
|
+
if pre_slice is None:
|
|
232
|
+
n_rows = -1
|
|
233
|
+
skip_rows = 0
|
|
226
234
|
else:
|
|
227
|
-
|
|
228
|
-
skip_rows, n_rows = n_rows
|
|
235
|
+
skip_rows, n_rows = pre_slice
|
|
229
236
|
|
|
230
|
-
row_index = file_options.row_index
|
|
231
237
|
return ir.Scan(
|
|
232
238
|
schema,
|
|
233
239
|
typ,
|
|
234
240
|
reader_options,
|
|
235
241
|
cloud_options,
|
|
236
|
-
translator.
|
|
242
|
+
translator.config_options,
|
|
237
243
|
node.paths,
|
|
238
244
|
with_columns,
|
|
239
245
|
skip_rows,
|
|
240
246
|
n_rows,
|
|
241
247
|
row_index,
|
|
242
|
-
|
|
248
|
+
include_file_paths,
|
|
249
|
+
translate_named_expr(translator, n=node.predicate, schema=schema)
|
|
243
250
|
if node.predicate is not None
|
|
244
251
|
else None,
|
|
245
252
|
)
|
|
246
253
|
|
|
247
254
|
|
|
248
255
|
@_translate_ir.register
|
|
249
|
-
def _(
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
256
|
+
def _(node: pl_ir.Cache, translator: Translator, schema: Schema) -> ir.IR:
|
|
257
|
+
return ir.Cache(
|
|
258
|
+
schema, node.id_, node.cache_hits, translator.translate_ir(n=node.input)
|
|
259
|
+
)
|
|
253
260
|
|
|
254
261
|
|
|
255
262
|
@_translate_ir.register
|
|
256
|
-
def _(
|
|
257
|
-
node: pl_ir.DataFrameScan, translator: Translator, schema: dict[str, plc.DataType]
|
|
258
|
-
) -> ir.IR:
|
|
263
|
+
def _(node: pl_ir.DataFrameScan, translator: Translator, schema: Schema) -> ir.IR:
|
|
259
264
|
return ir.DataFrameScan(
|
|
260
265
|
schema,
|
|
261
266
|
node.df,
|
|
262
267
|
node.projection,
|
|
263
|
-
translator.
|
|
268
|
+
translator.config_options,
|
|
264
269
|
)
|
|
265
270
|
|
|
266
271
|
|
|
267
272
|
@_translate_ir.register
|
|
268
|
-
def _(
|
|
269
|
-
node: pl_ir.Select, translator: Translator, schema: dict[str, plc.DataType]
|
|
270
|
-
) -> ir.IR:
|
|
273
|
+
def _(node: pl_ir.Select, translator: Translator, schema: Schema) -> ir.IR:
|
|
271
274
|
with set_node(translator.visitor, node.input):
|
|
272
275
|
inp = translator.translate_ir(n=None)
|
|
273
|
-
exprs = [
|
|
276
|
+
exprs = [
|
|
277
|
+
translate_named_expr(translator, n=e, schema=inp.schema) for e in node.expr
|
|
278
|
+
]
|
|
274
279
|
return ir.Select(schema, exprs, node.should_broadcast, inp)
|
|
275
280
|
|
|
276
281
|
|
|
277
282
|
@_translate_ir.register
|
|
278
|
-
def _(
|
|
279
|
-
node: pl_ir.GroupBy, translator: Translator, schema: dict[str, plc.DataType]
|
|
280
|
-
) -> ir.IR:
|
|
283
|
+
def _(node: pl_ir.GroupBy, translator: Translator, schema: Schema) -> ir.IR:
|
|
281
284
|
with set_node(translator.visitor, node.input):
|
|
282
285
|
inp = translator.translate_ir(n=None)
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
286
|
+
keys = [
|
|
287
|
+
translate_named_expr(translator, n=e, schema=inp.schema) for e in node.keys
|
|
288
|
+
]
|
|
289
|
+
original_aggs = [
|
|
290
|
+
translate_named_expr(translator, n=e, schema=inp.schema) for e in node.aggs
|
|
291
|
+
]
|
|
292
|
+
is_rolling = node.options.rolling is not None
|
|
293
|
+
is_dynamic = node.options.dynamic is not None
|
|
294
|
+
if is_dynamic:
|
|
295
|
+
raise NotImplementedError("group_by_dynamic")
|
|
296
|
+
elif is_rolling:
|
|
297
|
+
return rewrite_rolling(
|
|
298
|
+
node.options, schema, keys, original_aggs, translator.config_options, inp
|
|
299
|
+
)
|
|
300
|
+
else:
|
|
301
|
+
return rewrite_groupby(
|
|
302
|
+
node, schema, keys, original_aggs, translator.config_options, inp
|
|
303
|
+
)
|
|
293
304
|
|
|
294
305
|
|
|
295
306
|
@_translate_ir.register
|
|
296
|
-
def _(
|
|
297
|
-
node: pl_ir.Join, translator: Translator, schema: dict[str, plc.DataType]
|
|
298
|
-
) -> ir.IR:
|
|
307
|
+
def _(node: pl_ir.Join, translator: Translator, schema: Schema) -> ir.IR:
|
|
299
308
|
# Join key dtypes are dependent on the schema of the left and
|
|
300
309
|
# right inputs, so these must be translated with the relevant
|
|
301
310
|
# input active.
|
|
302
|
-
def adjust_literal_dtype(literal: expr.Literal) -> expr.Literal:
|
|
303
|
-
if literal.dtype.id() == plc.types.TypeId.INT32:
|
|
304
|
-
plc_int64 = plc.types.DataType(plc.types.TypeId.INT64)
|
|
305
|
-
return expr.Literal(
|
|
306
|
-
plc_int64,
|
|
307
|
-
pa.scalar(literal.value.as_py(), type=plc.interop.to_arrow(plc_int64)),
|
|
308
|
-
)
|
|
309
|
-
return literal
|
|
310
|
-
|
|
311
|
-
def maybe_adjust_binop(e) -> expr.Expr:
|
|
312
|
-
if isinstance(e.value, expr.BinOp):
|
|
313
|
-
left, right = e.value.children
|
|
314
|
-
if isinstance(left, expr.Col) and isinstance(right, expr.Literal):
|
|
315
|
-
e.value.children = (left, adjust_literal_dtype(right))
|
|
316
|
-
elif isinstance(left, expr.Literal) and isinstance(right, expr.Col):
|
|
317
|
-
e.value.children = (adjust_literal_dtype(left), right)
|
|
318
|
-
return e
|
|
319
|
-
|
|
320
|
-
def translate_expr_and_maybe_fix_binop_args(translator, exprs):
|
|
321
|
-
return [
|
|
322
|
-
maybe_adjust_binop(translate_named_expr(translator, n=e)) for e in exprs
|
|
323
|
-
]
|
|
324
|
-
|
|
325
311
|
with set_node(translator.visitor, node.input_left):
|
|
326
312
|
inp_left = translator.translate_ir(n=None)
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
313
|
+
left_on = [
|
|
314
|
+
translate_named_expr(translator, n=e, schema=inp_left.schema)
|
|
315
|
+
for e in node.left_on
|
|
316
|
+
]
|
|
331
317
|
with set_node(translator.visitor, node.input_right):
|
|
332
318
|
inp_right = translator.translate_ir(n=None)
|
|
333
|
-
right_on =
|
|
319
|
+
right_on = [
|
|
320
|
+
translate_named_expr(translator, n=e, schema=inp_right.schema)
|
|
321
|
+
for e in node.right_on
|
|
322
|
+
]
|
|
334
323
|
|
|
335
324
|
if (how := node.options[0]) in {
|
|
336
325
|
"Inner",
|
|
@@ -341,7 +330,15 @@ def _(
|
|
|
341
330
|
"Semi",
|
|
342
331
|
"Anti",
|
|
343
332
|
}:
|
|
344
|
-
return ir.Join(
|
|
333
|
+
return ir.Join(
|
|
334
|
+
schema,
|
|
335
|
+
left_on,
|
|
336
|
+
right_on,
|
|
337
|
+
node.options,
|
|
338
|
+
translator.config_options,
|
|
339
|
+
inp_left,
|
|
340
|
+
inp_right,
|
|
341
|
+
)
|
|
345
342
|
else:
|
|
346
343
|
how, op1, op2 = node.options[0]
|
|
347
344
|
if how != "IEJoin":
|
|
@@ -385,29 +382,29 @@ def _(
|
|
|
385
382
|
|
|
386
383
|
|
|
387
384
|
@_translate_ir.register
|
|
388
|
-
def _(
|
|
389
|
-
node: pl_ir.HStack, translator: Translator, schema: dict[str, plc.DataType]
|
|
390
|
-
) -> ir.IR:
|
|
385
|
+
def _(node: pl_ir.HStack, translator: Translator, schema: Schema) -> ir.IR:
|
|
391
386
|
with set_node(translator.visitor, node.input):
|
|
392
387
|
inp = translator.translate_ir(n=None)
|
|
393
|
-
exprs = [
|
|
388
|
+
exprs = [
|
|
389
|
+
translate_named_expr(translator, n=e, schema=inp.schema) for e in node.exprs
|
|
390
|
+
]
|
|
394
391
|
return ir.HStack(schema, exprs, node.should_broadcast, inp)
|
|
395
392
|
|
|
396
393
|
|
|
397
394
|
@_translate_ir.register
|
|
398
395
|
def _(
|
|
399
|
-
node: pl_ir.Reduce, translator: Translator, schema:
|
|
396
|
+
node: pl_ir.Reduce, translator: Translator, schema: Schema
|
|
400
397
|
) -> ir.IR: # pragma: no cover; polars doesn't emit this node yet
|
|
401
398
|
with set_node(translator.visitor, node.input):
|
|
402
399
|
inp = translator.translate_ir(n=None)
|
|
403
|
-
exprs = [
|
|
400
|
+
exprs = [
|
|
401
|
+
translate_named_expr(translator, n=e, schema=inp.schema) for e in node.expr
|
|
402
|
+
]
|
|
404
403
|
return ir.Reduce(schema, exprs, inp)
|
|
405
404
|
|
|
406
405
|
|
|
407
406
|
@_translate_ir.register
|
|
408
|
-
def _(
|
|
409
|
-
node: pl_ir.Distinct, translator: Translator, schema: dict[str, plc.DataType]
|
|
410
|
-
) -> ir.IR:
|
|
407
|
+
def _(node: pl_ir.Distinct, translator: Translator, schema: Schema) -> ir.IR:
|
|
411
408
|
(keep, subset, maintain_order, zlice) = node.options
|
|
412
409
|
keep = ir.Distinct._KEEP_MAP[keep]
|
|
413
410
|
subset = frozenset(subset) if subset is not None else None
|
|
@@ -422,12 +419,13 @@ def _(
|
|
|
422
419
|
|
|
423
420
|
|
|
424
421
|
@_translate_ir.register
|
|
425
|
-
def _(
|
|
426
|
-
node: pl_ir.Sort, translator: Translator, schema: dict[str, plc.DataType]
|
|
427
|
-
) -> ir.IR:
|
|
422
|
+
def _(node: pl_ir.Sort, translator: Translator, schema: Schema) -> ir.IR:
|
|
428
423
|
with set_node(translator.visitor, node.input):
|
|
429
424
|
inp = translator.translate_ir(n=None)
|
|
430
|
-
by = [
|
|
425
|
+
by = [
|
|
426
|
+
translate_named_expr(translator, n=e, schema=inp.schema)
|
|
427
|
+
for e in node.by_column
|
|
428
|
+
]
|
|
431
429
|
stable, nulls_last, descending = node.sort_options
|
|
432
430
|
order, null_order = sorting.sort_order(
|
|
433
431
|
descending, nulls_last=nulls_last, num_keys=len(by)
|
|
@@ -436,65 +434,101 @@ def _(
|
|
|
436
434
|
|
|
437
435
|
|
|
438
436
|
@_translate_ir.register
|
|
439
|
-
def _(
|
|
440
|
-
node: pl_ir.Slice, translator: Translator, schema: dict[str, plc.DataType]
|
|
441
|
-
) -> ir.IR:
|
|
437
|
+
def _(node: pl_ir.Slice, translator: Translator, schema: Schema) -> ir.IR:
|
|
442
438
|
return ir.Slice(
|
|
443
439
|
schema, node.offset, node.len, translator.translate_ir(n=node.input)
|
|
444
440
|
)
|
|
445
441
|
|
|
446
442
|
|
|
447
443
|
@_translate_ir.register
|
|
448
|
-
def _(
|
|
449
|
-
node: pl_ir.Filter, translator: Translator, schema: dict[str, plc.DataType]
|
|
450
|
-
) -> ir.IR:
|
|
444
|
+
def _(node: pl_ir.Filter, translator: Translator, schema: Schema) -> ir.IR:
|
|
451
445
|
with set_node(translator.visitor, node.input):
|
|
452
446
|
inp = translator.translate_ir(n=None)
|
|
453
|
-
mask = translate_named_expr(translator, n=node.predicate)
|
|
447
|
+
mask = translate_named_expr(translator, n=node.predicate, schema=inp.schema)
|
|
454
448
|
return ir.Filter(schema, mask, inp)
|
|
455
449
|
|
|
456
450
|
|
|
457
451
|
@_translate_ir.register
|
|
458
|
-
def _(
|
|
459
|
-
node: pl_ir.SimpleProjection,
|
|
460
|
-
translator: Translator,
|
|
461
|
-
schema: dict[str, plc.DataType],
|
|
462
|
-
) -> ir.IR:
|
|
452
|
+
def _(node: pl_ir.SimpleProjection, translator: Translator, schema: Schema) -> ir.IR:
|
|
463
453
|
return ir.Projection(schema, translator.translate_ir(n=node.input))
|
|
464
454
|
|
|
465
455
|
|
|
466
456
|
@_translate_ir.register
|
|
467
|
-
def _(
|
|
468
|
-
|
|
469
|
-
|
|
457
|
+
def _(node: pl_ir.MergeSorted, translator: Translator, schema: Schema) -> ir.IR:
|
|
458
|
+
key = node.key
|
|
459
|
+
inp_left = translator.translate_ir(n=node.input_left)
|
|
460
|
+
inp_right = translator.translate_ir(n=node.input_right)
|
|
461
|
+
return ir.MergeSorted(
|
|
462
|
+
schema,
|
|
463
|
+
key,
|
|
464
|
+
inp_left,
|
|
465
|
+
inp_right,
|
|
466
|
+
)
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
@_translate_ir.register
|
|
470
|
+
def _(node: pl_ir.MapFunction, translator: Translator, schema: Schema) -> ir.IR:
|
|
470
471
|
name, *options = node.function
|
|
471
472
|
return ir.MapFunction(
|
|
472
473
|
schema,
|
|
473
474
|
name,
|
|
474
475
|
options,
|
|
475
|
-
# TODO: merge_sorted breaks this pattern
|
|
476
476
|
translator.translate_ir(n=node.input),
|
|
477
477
|
)
|
|
478
478
|
|
|
479
479
|
|
|
480
480
|
@_translate_ir.register
|
|
481
|
-
def _(
|
|
482
|
-
node: pl_ir.Union, translator: Translator, schema: dict[str, plc.DataType]
|
|
483
|
-
) -> ir.IR:
|
|
481
|
+
def _(node: pl_ir.Union, translator: Translator, schema: Schema) -> ir.IR:
|
|
484
482
|
return ir.Union(
|
|
485
483
|
schema, node.options, *(translator.translate_ir(n=n) for n in node.inputs)
|
|
486
484
|
)
|
|
487
485
|
|
|
488
486
|
|
|
487
|
+
@_translate_ir.register
|
|
488
|
+
def _(node: pl_ir.HConcat, translator: Translator, schema: Schema) -> ir.IR:
|
|
489
|
+
return ir.HConcat(
|
|
490
|
+
schema,
|
|
491
|
+
False, # noqa: FBT003
|
|
492
|
+
*(translator.translate_ir(n=n) for n in node.inputs),
|
|
493
|
+
)
|
|
494
|
+
|
|
495
|
+
|
|
489
496
|
@_translate_ir.register
|
|
490
497
|
def _(
|
|
491
|
-
node: pl_ir.
|
|
498
|
+
node: pl_ir.Sink, translator: Translator, schema: dict[str, plc.DataType]
|
|
492
499
|
) -> ir.IR:
|
|
493
|
-
|
|
500
|
+
payload = json.loads(node.payload)
|
|
501
|
+
try:
|
|
502
|
+
file = payload["File"]
|
|
503
|
+
sink_kind_options = file["file_type"]
|
|
504
|
+
except KeyError as err: # pragma: no cover
|
|
505
|
+
raise NotImplementedError("Unsupported payload structure") from err
|
|
506
|
+
if isinstance(sink_kind_options, dict):
|
|
507
|
+
if len(sink_kind_options) != 1: # pragma: no cover; not sure if this can happen
|
|
508
|
+
raise NotImplementedError("Sink options dict with more than one entry.")
|
|
509
|
+
sink_kind, options = next(iter(sink_kind_options.items()))
|
|
510
|
+
else:
|
|
511
|
+
raise NotImplementedError(
|
|
512
|
+
"Unsupported sink options structure"
|
|
513
|
+
) # pragma: no cover
|
|
514
|
+
|
|
515
|
+
sink_options = file.get("sink_options", {})
|
|
516
|
+
cloud_options = file.get("cloud_options")
|
|
517
|
+
|
|
518
|
+
options.update(sink_options)
|
|
519
|
+
|
|
520
|
+
return ir.Sink(
|
|
521
|
+
schema=schema,
|
|
522
|
+
kind=sink_kind,
|
|
523
|
+
path=file["target"],
|
|
524
|
+
options=options,
|
|
525
|
+
cloud_options=cloud_options,
|
|
526
|
+
df=translator.translate_ir(n=node.input),
|
|
527
|
+
)
|
|
494
528
|
|
|
495
529
|
|
|
496
530
|
def translate_named_expr(
|
|
497
|
-
translator: Translator, *, n: pl_expr.PyExprIR
|
|
531
|
+
translator: Translator, *, n: pl_expr.PyExprIR, schema: Schema
|
|
498
532
|
) -> expr.NamedExpr:
|
|
499
533
|
"""
|
|
500
534
|
Translate a polars-internal named expression IR object into our representation.
|
|
@@ -505,6 +539,8 @@ def translate_named_expr(
|
|
|
505
539
|
Translator object
|
|
506
540
|
n
|
|
507
541
|
Node to translate, a named expression node.
|
|
542
|
+
schema
|
|
543
|
+
Schema of the IR node this expression uses as evaluation context.
|
|
508
544
|
|
|
509
545
|
Returns
|
|
510
546
|
-------
|
|
@@ -522,12 +558,14 @@ def translate_named_expr(
|
|
|
522
558
|
NotImplementedError
|
|
523
559
|
If any translation fails due to unsupported functionality.
|
|
524
560
|
"""
|
|
525
|
-
return expr.NamedExpr(
|
|
561
|
+
return expr.NamedExpr(
|
|
562
|
+
n.output_name, translator.translate_expr(n=n.node, schema=schema)
|
|
563
|
+
)
|
|
526
564
|
|
|
527
565
|
|
|
528
566
|
@singledispatch
|
|
529
567
|
def _translate_expr(
|
|
530
|
-
node: Any, translator: Translator, dtype: plc.DataType
|
|
568
|
+
node: Any, translator: Translator, dtype: plc.DataType, schema: Schema
|
|
531
569
|
) -> expr.Expr:
|
|
532
570
|
raise NotImplementedError(
|
|
533
571
|
f"Translation for {type(node).__name__}"
|
|
@@ -535,7 +573,9 @@ def _translate_expr(
|
|
|
535
573
|
|
|
536
574
|
|
|
537
575
|
@_translate_expr.register
|
|
538
|
-
def _(
|
|
576
|
+
def _(
|
|
577
|
+
node: pl_expr.Function, translator: Translator, dtype: plc.DataType, schema: Schema
|
|
578
|
+
) -> expr.Expr:
|
|
539
579
|
name, *options = node.function_data
|
|
540
580
|
options = tuple(options)
|
|
541
581
|
if isinstance(name, pl_expr.StringFunction):
|
|
@@ -544,18 +584,20 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex
|
|
|
544
584
|
pl_expr.StringFunction.StripCharsStart,
|
|
545
585
|
pl_expr.StringFunction.StripCharsEnd,
|
|
546
586
|
}:
|
|
547
|
-
column, chars = (
|
|
587
|
+
column, chars = (
|
|
588
|
+
translator.translate_expr(n=n, schema=schema) for n in node.input
|
|
589
|
+
)
|
|
548
590
|
if isinstance(chars, expr.Literal):
|
|
549
|
-
|
|
591
|
+
# We check for null first because we want to use the
|
|
592
|
+
# chars pyarrow type, but it is invalid to try and
|
|
593
|
+
# produce a string scalar with a null dtype.
|
|
594
|
+
if chars.value is None:
|
|
595
|
+
# Polars uses None to mean "strip all whitespace"
|
|
596
|
+
chars = expr.Literal(column.dtype, "")
|
|
597
|
+
elif chars.value == "":
|
|
550
598
|
# No-op in polars, but libcudf uses empty string
|
|
551
599
|
# as signifier to remove whitespace.
|
|
552
600
|
return column
|
|
553
|
-
elif chars.value == pa.scalar(None):
|
|
554
|
-
# Polars uses None to mean "strip all whitespace"
|
|
555
|
-
chars = expr.Literal(
|
|
556
|
-
column.dtype,
|
|
557
|
-
pa.scalar("", type=plc.interop.to_arrow(column.dtype)),
|
|
558
|
-
)
|
|
559
601
|
return expr.StringFunction(
|
|
560
602
|
dtype,
|
|
561
603
|
expr.StringFunction.Name.from_polars(name),
|
|
@@ -567,11 +609,13 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex
|
|
|
567
609
|
dtype,
|
|
568
610
|
expr.StringFunction.Name.from_polars(name),
|
|
569
611
|
options,
|
|
570
|
-
*(translator.translate_expr(n=n) for n in node.input),
|
|
612
|
+
*(translator.translate_expr(n=n, schema=schema) for n in node.input),
|
|
571
613
|
)
|
|
572
614
|
elif isinstance(name, pl_expr.BooleanFunction):
|
|
573
615
|
if name == pl_expr.BooleanFunction.IsBetween:
|
|
574
|
-
column, lo, hi = (
|
|
616
|
+
column, lo, hi = (
|
|
617
|
+
translator.translate_expr(n=n, schema=schema) for n in node.input
|
|
618
|
+
)
|
|
575
619
|
(closed,) = options
|
|
576
620
|
lop, rop = expr.BooleanFunction._BETWEEN_OPS[closed]
|
|
577
621
|
return expr.BinOp(
|
|
@@ -584,7 +628,7 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex
|
|
|
584
628
|
dtype,
|
|
585
629
|
expr.BooleanFunction.Name.from_polars(name),
|
|
586
630
|
options,
|
|
587
|
-
*(translator.translate_expr(n=n) for n in node.input),
|
|
631
|
+
*(translator.translate_expr(n=n, schema=schema) for n in node.input),
|
|
588
632
|
)
|
|
589
633
|
elif isinstance(name, pl_expr.TemporalFunction):
|
|
590
634
|
# functions for which evaluation of the expression may not return
|
|
@@ -604,14 +648,14 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex
|
|
|
604
648
|
dtype,
|
|
605
649
|
expr.TemporalFunction.Name.from_polars(name),
|
|
606
650
|
options,
|
|
607
|
-
*(translator.translate_expr(n=n) for n in node.input),
|
|
651
|
+
*(translator.translate_expr(n=n, schema=schema) for n in node.input),
|
|
608
652
|
)
|
|
609
653
|
if name in needs_cast:
|
|
610
654
|
return expr.Cast(dtype, result_expr)
|
|
611
655
|
return result_expr
|
|
612
656
|
|
|
613
657
|
elif isinstance(name, str):
|
|
614
|
-
children = (translator.translate_expr(n=n) for n in node.input)
|
|
658
|
+
children = (translator.translate_expr(n=n, schema=schema) for n in node.input)
|
|
615
659
|
if name == "log":
|
|
616
660
|
(base,) = options
|
|
617
661
|
(child,) = children
|
|
@@ -619,10 +663,21 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex
|
|
|
619
663
|
dtype,
|
|
620
664
|
plc.binaryop.BinaryOperator.LOG_BASE,
|
|
621
665
|
child,
|
|
622
|
-
expr.Literal(dtype,
|
|
666
|
+
expr.Literal(dtype, base),
|
|
623
667
|
)
|
|
624
668
|
elif name == "pow":
|
|
625
669
|
return expr.BinOp(dtype, plc.binaryop.BinaryOperator.POW, *children)
|
|
670
|
+
elif name in "top_k":
|
|
671
|
+
(col, k) = children
|
|
672
|
+
assert isinstance(k, expr.Literal)
|
|
673
|
+
(descending,) = options
|
|
674
|
+
return expr.Slice(
|
|
675
|
+
dtype,
|
|
676
|
+
0,
|
|
677
|
+
k.value,
|
|
678
|
+
expr.Sort(dtype, (False, True, not descending), col),
|
|
679
|
+
)
|
|
680
|
+
|
|
626
681
|
return expr.UnaryFunction(dtype, name, options, *children)
|
|
627
682
|
raise NotImplementedError(
|
|
628
683
|
f"No handler for Expr function node with {name=}"
|
|
@@ -630,73 +685,155 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex
|
|
|
630
685
|
|
|
631
686
|
|
|
632
687
|
@_translate_expr.register
|
|
633
|
-
def _(
|
|
634
|
-
|
|
688
|
+
def _(
|
|
689
|
+
node: pl_expr.Window, translator: Translator, dtype: plc.DataType, schema: Schema
|
|
690
|
+
) -> expr.Expr:
|
|
635
691
|
if isinstance(node.options, pl_expr.RollingGroupOptions):
|
|
636
692
|
# pl.col("a").rolling(...)
|
|
637
|
-
|
|
638
|
-
|
|
693
|
+
agg = translator.translate_expr(n=node.function, schema=schema)
|
|
694
|
+
name_generator = unique_names(schema)
|
|
695
|
+
aggs, named_post_agg = decompose_single_agg(
|
|
696
|
+
expr.NamedExpr(next(name_generator), agg), name_generator, is_top=True
|
|
639
697
|
)
|
|
698
|
+
named_aggs = [agg for agg, _ in aggs]
|
|
699
|
+
orderby = node.options.index_column
|
|
700
|
+
orderby_dtype = schema[orderby]
|
|
701
|
+
if plc.traits.is_integral(orderby_dtype):
|
|
702
|
+
# Integer orderby column is cast in implementation to int64 in polars
|
|
703
|
+
orderby_dtype = plc.DataType(plc.TypeId.INT64)
|
|
704
|
+
preceding, following = offsets_to_windows(
|
|
705
|
+
orderby_dtype,
|
|
706
|
+
node.options.offset,
|
|
707
|
+
node.options.period,
|
|
708
|
+
)
|
|
709
|
+
closed_window = node.options.closed_window
|
|
710
|
+
if isinstance(named_post_agg.value, expr.Col):
|
|
711
|
+
(named_agg,) = named_aggs
|
|
712
|
+
return expr.RollingWindow(
|
|
713
|
+
named_agg.value.dtype,
|
|
714
|
+
preceding,
|
|
715
|
+
following,
|
|
716
|
+
closed_window,
|
|
717
|
+
orderby,
|
|
718
|
+
named_agg.value,
|
|
719
|
+
)
|
|
720
|
+
replacements: dict[expr.Expr, expr.Expr] = {
|
|
721
|
+
expr.Col(agg.value.dtype, agg.name): expr.RollingWindow(
|
|
722
|
+
agg.value.dtype,
|
|
723
|
+
preceding,
|
|
724
|
+
following,
|
|
725
|
+
closed_window,
|
|
726
|
+
orderby,
|
|
727
|
+
agg.value,
|
|
728
|
+
)
|
|
729
|
+
for agg in named_aggs
|
|
730
|
+
}
|
|
731
|
+
return replace([named_post_agg.value], replacements)[0]
|
|
640
732
|
elif isinstance(node.options, pl_expr.WindowMapping):
|
|
641
733
|
# pl.col("a").over(...)
|
|
642
734
|
return expr.GroupedRollingWindow(
|
|
643
735
|
dtype,
|
|
644
736
|
node.options,
|
|
645
|
-
translator.translate_expr(n=node.function),
|
|
646
|
-
*(translator.translate_expr(n=n) for n in node.partition_by),
|
|
737
|
+
translator.translate_expr(n=node.function, schema=schema),
|
|
738
|
+
*(translator.translate_expr(n=n, schema=schema) for n in node.partition_by),
|
|
647
739
|
)
|
|
648
740
|
assert_never(node.options)
|
|
649
741
|
|
|
650
742
|
|
|
651
743
|
@_translate_expr.register
|
|
652
|
-
def _(
|
|
744
|
+
def _(
|
|
745
|
+
node: pl_expr.Literal, translator: Translator, dtype: plc.DataType, schema: Schema
|
|
746
|
+
) -> expr.Expr:
|
|
653
747
|
if isinstance(node.value, plrs.PySeries):
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
748
|
+
data = pl.Series._from_pyseries(node.value).to_arrow(
|
|
749
|
+
compat_level=dtypes.TO_ARROW_COMPAT_LEVEL
|
|
750
|
+
)
|
|
751
|
+
return expr.LiteralColumn(
|
|
752
|
+
dtype, data.cast(dtypes.downcast_arrow_lists(data.type))
|
|
753
|
+
)
|
|
754
|
+
if dtype.id() == plc.TypeId.LIST: # pragma: no cover
|
|
755
|
+
# TODO: Find an alternative to pa.infer_type
|
|
756
|
+
data = pa.array(node.value, type=pa.infer_type(node.value))
|
|
757
|
+
return expr.LiteralColumn(dtype, data)
|
|
758
|
+
return expr.Literal(dtype, node.value)
|
|
657
759
|
|
|
658
760
|
|
|
659
761
|
@_translate_expr.register
|
|
660
|
-
def _(
|
|
762
|
+
def _(
|
|
763
|
+
node: pl_expr.Sort, translator: Translator, dtype: plc.DataType, schema: Schema
|
|
764
|
+
) -> expr.Expr:
|
|
661
765
|
# TODO: raise in groupby
|
|
662
|
-
return expr.Sort(
|
|
766
|
+
return expr.Sort(
|
|
767
|
+
dtype, node.options, translator.translate_expr(n=node.expr, schema=schema)
|
|
768
|
+
)
|
|
663
769
|
|
|
664
770
|
|
|
665
771
|
@_translate_expr.register
|
|
666
|
-
def _(
|
|
772
|
+
def _(
|
|
773
|
+
node: pl_expr.SortBy, translator: Translator, dtype: plc.DataType, schema: Schema
|
|
774
|
+
) -> expr.Expr:
|
|
667
775
|
options = node.sort_options
|
|
668
776
|
return expr.SortBy(
|
|
669
777
|
dtype,
|
|
670
778
|
(options[0], tuple(options[1]), tuple(options[2])),
|
|
671
|
-
translator.translate_expr(n=node.expr),
|
|
672
|
-
*(translator.translate_expr(n=n) for n in node.by),
|
|
779
|
+
translator.translate_expr(n=node.expr, schema=schema),
|
|
780
|
+
*(translator.translate_expr(n=n, schema=schema) for n in node.by),
|
|
673
781
|
)
|
|
674
782
|
|
|
675
783
|
|
|
676
784
|
@_translate_expr.register
|
|
677
|
-
def _(
|
|
785
|
+
def _(
|
|
786
|
+
node: pl_expr.Slice, translator: Translator, dtype: plc.DataType, schema: Schema
|
|
787
|
+
) -> expr.Expr:
|
|
788
|
+
offset = translator.translate_expr(n=node.offset, schema=schema)
|
|
789
|
+
length = translator.translate_expr(n=node.length, schema=schema)
|
|
790
|
+
assert isinstance(offset, expr.Literal)
|
|
791
|
+
assert isinstance(length, expr.Literal)
|
|
792
|
+
return expr.Slice(
|
|
793
|
+
dtype,
|
|
794
|
+
offset.value,
|
|
795
|
+
length.value,
|
|
796
|
+
translator.translate_expr(n=node.input, schema=schema),
|
|
797
|
+
)
|
|
798
|
+
|
|
799
|
+
|
|
800
|
+
@_translate_expr.register
|
|
801
|
+
def _(
|
|
802
|
+
node: pl_expr.Gather, translator: Translator, dtype: plc.DataType, schema: Schema
|
|
803
|
+
) -> expr.Expr:
|
|
678
804
|
return expr.Gather(
|
|
679
805
|
dtype,
|
|
680
|
-
translator.translate_expr(n=node.expr),
|
|
681
|
-
translator.translate_expr(n=node.idx),
|
|
806
|
+
translator.translate_expr(n=node.expr, schema=schema),
|
|
807
|
+
translator.translate_expr(n=node.idx, schema=schema),
|
|
682
808
|
)
|
|
683
809
|
|
|
684
810
|
|
|
685
811
|
@_translate_expr.register
|
|
686
|
-
def _(
|
|
812
|
+
def _(
|
|
813
|
+
node: pl_expr.Filter, translator: Translator, dtype: plc.DataType, schema: Schema
|
|
814
|
+
) -> expr.Expr:
|
|
687
815
|
return expr.Filter(
|
|
688
816
|
dtype,
|
|
689
|
-
translator.translate_expr(n=node.input),
|
|
690
|
-
translator.translate_expr(n=node.by),
|
|
817
|
+
translator.translate_expr(n=node.input, schema=schema),
|
|
818
|
+
translator.translate_expr(n=node.by, schema=schema),
|
|
691
819
|
)
|
|
692
820
|
|
|
693
821
|
|
|
694
822
|
@_translate_expr.register
|
|
695
|
-
def _(
|
|
696
|
-
|
|
823
|
+
def _(
|
|
824
|
+
node: pl_expr.Cast, translator: Translator, dtype: plc.DataType, schema: Schema
|
|
825
|
+
) -> expr.Expr:
|
|
826
|
+
inner = translator.translate_expr(n=node.expr, schema=schema)
|
|
697
827
|
# Push casts into literals so we can handle Cast(Literal(Null))
|
|
698
828
|
if isinstance(inner, expr.Literal):
|
|
699
|
-
|
|
829
|
+
plc_column = plc.Column.from_scalar(
|
|
830
|
+
plc.Scalar.from_py(inner.value, inner.dtype), 1
|
|
831
|
+
)
|
|
832
|
+
casted_column = plc.unary.cast(plc_column, dtype)
|
|
833
|
+
casted_py_scalar = plc.interop.to_arrow(
|
|
834
|
+
plc.copying.get_element(casted_column, 0)
|
|
835
|
+
).as_py()
|
|
836
|
+
return expr.Literal(dtype, casted_py_scalar)
|
|
700
837
|
elif isinstance(inner, expr.Cast):
|
|
701
838
|
# Translation of Len/Count-agg put in a cast, remove double
|
|
702
839
|
# casts if we have one.
|
|
@@ -705,17 +842,21 @@ def _(node: pl_expr.Cast, translator: Translator, dtype: plc.DataType) -> expr.E
|
|
|
705
842
|
|
|
706
843
|
|
|
707
844
|
@_translate_expr.register
|
|
708
|
-
def _(
|
|
845
|
+
def _(
|
|
846
|
+
node: pl_expr.Column, translator: Translator, dtype: plc.DataType, schema: Schema
|
|
847
|
+
) -> expr.Expr:
|
|
709
848
|
return expr.Col(dtype, node.name)
|
|
710
849
|
|
|
711
850
|
|
|
712
851
|
@_translate_expr.register
|
|
713
|
-
def _(
|
|
852
|
+
def _(
|
|
853
|
+
node: pl_expr.Agg, translator: Translator, dtype: plc.DataType, schema: Schema
|
|
854
|
+
) -> expr.Expr:
|
|
714
855
|
value = expr.Agg(
|
|
715
856
|
dtype,
|
|
716
857
|
node.name,
|
|
717
858
|
node.options,
|
|
718
|
-
*(translator.translate_expr(n=n) for n in node.arguments),
|
|
859
|
+
*(translator.translate_expr(n=n, schema=schema) for n in node.arguments),
|
|
719
860
|
)
|
|
720
861
|
if value.name == "count" and value.dtype.id() != plc.TypeId.INT32:
|
|
721
862
|
return expr.Cast(value.dtype, value)
|
|
@@ -723,29 +864,36 @@ def _(node: pl_expr.Agg, translator: Translator, dtype: plc.DataType) -> expr.Ex
|
|
|
723
864
|
|
|
724
865
|
|
|
725
866
|
@_translate_expr.register
|
|
726
|
-
def _(
|
|
867
|
+
def _(
|
|
868
|
+
node: pl_expr.Ternary, translator: Translator, dtype: plc.DataType, schema: Schema
|
|
869
|
+
) -> expr.Expr:
|
|
727
870
|
return expr.Ternary(
|
|
728
871
|
dtype,
|
|
729
|
-
translator.translate_expr(n=node.predicate),
|
|
730
|
-
translator.translate_expr(n=node.truthy),
|
|
731
|
-
translator.translate_expr(n=node.falsy),
|
|
872
|
+
translator.translate_expr(n=node.predicate, schema=schema),
|
|
873
|
+
translator.translate_expr(n=node.truthy, schema=schema),
|
|
874
|
+
translator.translate_expr(n=node.falsy, schema=schema),
|
|
732
875
|
)
|
|
733
876
|
|
|
734
877
|
|
|
735
878
|
@_translate_expr.register
|
|
736
879
|
def _(
|
|
737
|
-
node: pl_expr.BinaryExpr,
|
|
880
|
+
node: pl_expr.BinaryExpr,
|
|
881
|
+
translator: Translator,
|
|
882
|
+
dtype: plc.DataType,
|
|
883
|
+
schema: Schema,
|
|
738
884
|
) -> expr.Expr:
|
|
739
885
|
return expr.BinOp(
|
|
740
886
|
dtype,
|
|
741
887
|
expr.BinOp._MAPPING[node.op],
|
|
742
|
-
translator.translate_expr(n=node.left),
|
|
743
|
-
translator.translate_expr(n=node.right),
|
|
888
|
+
translator.translate_expr(n=node.left, schema=schema),
|
|
889
|
+
translator.translate_expr(n=node.right, schema=schema),
|
|
744
890
|
)
|
|
745
891
|
|
|
746
892
|
|
|
747
893
|
@_translate_expr.register
|
|
748
|
-
def _(
|
|
894
|
+
def _(
|
|
895
|
+
node: pl_expr.Len, translator: Translator, dtype: plc.DataType, schema: Schema
|
|
896
|
+
) -> expr.Expr:
|
|
749
897
|
value = expr.Len(dtype)
|
|
750
898
|
if dtype.id() != plc.TypeId.INT32:
|
|
751
899
|
return expr.Cast(dtype, value)
|