cudf-polars-cu13 25.10.0__py3-none-any.whl → 26.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cudf_polars/GIT_COMMIT +1 -1
- cudf_polars/VERSION +1 -1
- cudf_polars/callback.py +60 -15
- cudf_polars/containers/column.py +137 -77
- cudf_polars/containers/dataframe.py +123 -34
- cudf_polars/containers/datatype.py +134 -13
- cudf_polars/dsl/expr.py +0 -2
- cudf_polars/dsl/expressions/aggregation.py +80 -28
- cudf_polars/dsl/expressions/binaryop.py +34 -14
- cudf_polars/dsl/expressions/boolean.py +110 -37
- cudf_polars/dsl/expressions/datetime.py +59 -30
- cudf_polars/dsl/expressions/literal.py +11 -5
- cudf_polars/dsl/expressions/rolling.py +460 -119
- cudf_polars/dsl/expressions/selection.py +9 -8
- cudf_polars/dsl/expressions/slicing.py +1 -1
- cudf_polars/dsl/expressions/string.py +256 -114
- cudf_polars/dsl/expressions/struct.py +19 -7
- cudf_polars/dsl/expressions/ternary.py +33 -3
- cudf_polars/dsl/expressions/unary.py +126 -64
- cudf_polars/dsl/ir.py +1053 -350
- cudf_polars/dsl/to_ast.py +30 -13
- cudf_polars/dsl/tracing.py +194 -0
- cudf_polars/dsl/translate.py +307 -107
- cudf_polars/dsl/utils/aggregations.py +43 -30
- cudf_polars/dsl/utils/reshape.py +14 -2
- cudf_polars/dsl/utils/rolling.py +12 -8
- cudf_polars/dsl/utils/windows.py +35 -20
- cudf_polars/experimental/base.py +55 -2
- cudf_polars/experimental/benchmarks/pdsds.py +12 -126
- cudf_polars/experimental/benchmarks/pdsh.py +792 -2
- cudf_polars/experimental/benchmarks/utils.py +596 -39
- cudf_polars/experimental/dask_registers.py +47 -20
- cudf_polars/experimental/dispatch.py +9 -3
- cudf_polars/experimental/distinct.py +2 -0
- cudf_polars/experimental/explain.py +15 -2
- cudf_polars/experimental/expressions.py +30 -15
- cudf_polars/experimental/groupby.py +25 -4
- cudf_polars/experimental/io.py +156 -124
- cudf_polars/experimental/join.py +53 -23
- cudf_polars/experimental/parallel.py +68 -19
- cudf_polars/experimental/rapidsmpf/__init__.py +8 -0
- cudf_polars/experimental/rapidsmpf/collectives/__init__.py +9 -0
- cudf_polars/experimental/rapidsmpf/collectives/allgather.py +90 -0
- cudf_polars/experimental/rapidsmpf/collectives/common.py +96 -0
- cudf_polars/experimental/rapidsmpf/collectives/shuffle.py +253 -0
- cudf_polars/experimental/rapidsmpf/core.py +488 -0
- cudf_polars/experimental/rapidsmpf/dask.py +172 -0
- cudf_polars/experimental/rapidsmpf/dispatch.py +153 -0
- cudf_polars/experimental/rapidsmpf/io.py +696 -0
- cudf_polars/experimental/rapidsmpf/join.py +322 -0
- cudf_polars/experimental/rapidsmpf/lower.py +74 -0
- cudf_polars/experimental/rapidsmpf/nodes.py +735 -0
- cudf_polars/experimental/rapidsmpf/repartition.py +216 -0
- cudf_polars/experimental/rapidsmpf/union.py +115 -0
- cudf_polars/experimental/rapidsmpf/utils.py +374 -0
- cudf_polars/experimental/repartition.py +9 -2
- cudf_polars/experimental/select.py +177 -14
- cudf_polars/experimental/shuffle.py +46 -12
- cudf_polars/experimental/sort.py +100 -26
- cudf_polars/experimental/spilling.py +1 -1
- cudf_polars/experimental/statistics.py +24 -5
- cudf_polars/experimental/utils.py +25 -7
- cudf_polars/testing/asserts.py +13 -8
- cudf_polars/testing/io.py +2 -1
- cudf_polars/testing/plugin.py +93 -17
- cudf_polars/typing/__init__.py +86 -32
- cudf_polars/utils/config.py +473 -58
- cudf_polars/utils/cuda_stream.py +70 -0
- cudf_polars/utils/versions.py +5 -4
- cudf_polars_cu13-26.2.0.dist-info/METADATA +181 -0
- cudf_polars_cu13-26.2.0.dist-info/RECORD +108 -0
- {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/WHEEL +1 -1
- cudf_polars_cu13-25.10.0.dist-info/METADATA +0 -136
- cudf_polars_cu13-25.10.0.dist-info/RECORD +0 -92
- {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/licenses/LICENSE +0 -0
- {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/top_level.txt +0 -0
cudf_polars/dsl/translate.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2024-
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
4
|
"""Translate polars IR representation to ours."""
|
|
@@ -14,8 +14,11 @@ from typing import TYPE_CHECKING, Any
|
|
|
14
14
|
from typing_extensions import assert_never
|
|
15
15
|
|
|
16
16
|
import polars as pl
|
|
17
|
-
|
|
18
|
-
|
|
17
|
+
|
|
18
|
+
# polars.polars is not a part of the public API,
|
|
19
|
+
# so we cannot rely on importing it directly
|
|
20
|
+
# See https://github.com/pola-rs/polars/issues/24826
|
|
21
|
+
from polars import polars as plrs # type: ignore[attr-defined]
|
|
19
22
|
|
|
20
23
|
import pylibcudf as plc
|
|
21
24
|
|
|
@@ -33,6 +36,8 @@ from cudf_polars.utils import config, sorting
|
|
|
33
36
|
from cudf_polars.utils.versions import (
|
|
34
37
|
POLARS_VERSION_LT_131,
|
|
35
38
|
POLARS_VERSION_LT_132,
|
|
39
|
+
POLARS_VERSION_LT_133,
|
|
40
|
+
POLARS_VERSION_LT_134,
|
|
36
41
|
POLARS_VERSION_LT_1323,
|
|
37
42
|
)
|
|
38
43
|
|
|
@@ -61,6 +66,7 @@ class Translator:
|
|
|
61
66
|
self.config_options = config.ConfigOptions.from_polars_engine(engine)
|
|
62
67
|
self.errors: list[Exception] = []
|
|
63
68
|
self._cache_nodes: dict[int, ir.Cache] = {}
|
|
69
|
+
self._expr_context: ExecutionContext = ExecutionContext.FRAME
|
|
64
70
|
|
|
65
71
|
def translate_ir(self, *, n: int | None = None) -> ir.IR:
|
|
66
72
|
"""
|
|
@@ -96,7 +102,7 @@ class Translator:
|
|
|
96
102
|
# IR is versioned with major.minor, minor is bumped for backwards
|
|
97
103
|
# compatible changes (e.g. adding new nodes), major is bumped for
|
|
98
104
|
# incompatible changes (e.g. renaming nodes).
|
|
99
|
-
if (version := self.visitor.version()) >= (
|
|
105
|
+
if (version := self.visitor.version()) >= (11, 1):
|
|
100
106
|
e = NotImplementedError(
|
|
101
107
|
f"No support for polars IR {version=}"
|
|
102
108
|
) # pragma: no cover; no such version for now.
|
|
@@ -201,6 +207,23 @@ class set_node(AbstractContextManager[None]):
|
|
|
201
207
|
noop_context: nullcontext[None] = nullcontext()
|
|
202
208
|
|
|
203
209
|
|
|
210
|
+
class set_expr_context(AbstractContextManager[None]):
|
|
211
|
+
__slots__ = ("_prev", "ctx", "translator")
|
|
212
|
+
|
|
213
|
+
def __init__(self, translator: Translator, ctx: ExecutionContext) -> None:
|
|
214
|
+
self.translator = translator
|
|
215
|
+
self.ctx = ctx
|
|
216
|
+
self._prev: ExecutionContext | None = None
|
|
217
|
+
|
|
218
|
+
def __enter__(self) -> None:
|
|
219
|
+
self._prev = self.translator._expr_context
|
|
220
|
+
self.translator._expr_context = self.ctx
|
|
221
|
+
|
|
222
|
+
def __exit__(self, *args: Any) -> None:
|
|
223
|
+
assert self._prev is not None
|
|
224
|
+
self.translator._expr_context = self._prev
|
|
225
|
+
|
|
226
|
+
|
|
204
227
|
@singledispatch
|
|
205
228
|
def _translate_ir(node: Any, translator: Translator, schema: Schema) -> ir.IR:
|
|
206
229
|
raise NotImplementedError(
|
|
@@ -209,7 +232,7 @@ def _translate_ir(node: Any, translator: Translator, schema: Schema) -> ir.IR:
|
|
|
209
232
|
|
|
210
233
|
|
|
211
234
|
@_translate_ir.register
|
|
212
|
-
def _(node:
|
|
235
|
+
def _(node: plrs._ir_nodes.PythonScan, translator: Translator, schema: Schema) -> ir.IR:
|
|
213
236
|
scan_fn, with_columns, source_type, predicate, nrows = node.options
|
|
214
237
|
options = (scan_fn, with_columns, source_type, nrows)
|
|
215
238
|
predicate = (
|
|
@@ -221,7 +244,7 @@ def _(node: pl_ir.PythonScan, translator: Translator, schema: Schema) -> ir.IR:
|
|
|
221
244
|
|
|
222
245
|
|
|
223
246
|
@_translate_ir.register
|
|
224
|
-
def _(node:
|
|
247
|
+
def _(node: plrs._ir_nodes.Scan, translator: Translator, schema: Schema) -> ir.IR:
|
|
225
248
|
typ, *options = node.scan_type
|
|
226
249
|
paths = node.paths
|
|
227
250
|
# Polars can produce a Scan with an empty ``node.paths`` (eg. the native
|
|
@@ -254,6 +277,9 @@ def _(node: pl_ir.Scan, translator: Translator, schema: Schema) -> ir.IR:
|
|
|
254
277
|
skip_rows = 0
|
|
255
278
|
else:
|
|
256
279
|
skip_rows, n_rows = pre_slice
|
|
280
|
+
if (n_rows == 2**32 - 1) or (n_rows == 2**64 - 1):
|
|
281
|
+
# Polars translates slice(10, None) -> (10, u32/64max)
|
|
282
|
+
n_rows = -1
|
|
257
283
|
|
|
258
284
|
return ir.Scan(
|
|
259
285
|
schema,
|
|
@@ -274,7 +300,7 @@ def _(node: pl_ir.Scan, translator: Translator, schema: Schema) -> ir.IR:
|
|
|
274
300
|
|
|
275
301
|
|
|
276
302
|
@_translate_ir.register
|
|
277
|
-
def _(node:
|
|
303
|
+
def _(node: plrs._ir_nodes.Cache, translator: Translator, schema: Schema) -> ir.IR:
|
|
278
304
|
if POLARS_VERSION_LT_1323: # pragma: no cover
|
|
279
305
|
refcount = node.cache_hits
|
|
280
306
|
else:
|
|
@@ -293,7 +319,9 @@ def _(node: pl_ir.Cache, translator: Translator, schema: Schema) -> ir.IR:
|
|
|
293
319
|
|
|
294
320
|
|
|
295
321
|
@_translate_ir.register
|
|
296
|
-
def _(
|
|
322
|
+
def _(
|
|
323
|
+
node: plrs._ir_nodes.DataFrameScan, translator: Translator, schema: Schema
|
|
324
|
+
) -> ir.IR:
|
|
297
325
|
return ir.DataFrameScan(
|
|
298
326
|
schema,
|
|
299
327
|
node.df,
|
|
@@ -302,7 +330,7 @@ def _(node: pl_ir.DataFrameScan, translator: Translator, schema: Schema) -> ir.I
|
|
|
302
330
|
|
|
303
331
|
|
|
304
332
|
@_translate_ir.register
|
|
305
|
-
def _(node:
|
|
333
|
+
def _(node: plrs._ir_nodes.Select, translator: Translator, schema: Schema) -> ir.IR:
|
|
306
334
|
with set_node(translator.visitor, node.input):
|
|
307
335
|
inp = translator.translate_ir(n=None)
|
|
308
336
|
exprs = [
|
|
@@ -312,15 +340,17 @@ def _(node: pl_ir.Select, translator: Translator, schema: Schema) -> ir.IR:
|
|
|
312
340
|
|
|
313
341
|
|
|
314
342
|
@_translate_ir.register
|
|
315
|
-
def _(node:
|
|
343
|
+
def _(node: plrs._ir_nodes.GroupBy, translator: Translator, schema: Schema) -> ir.IR:
|
|
316
344
|
with set_node(translator.visitor, node.input):
|
|
317
345
|
inp = translator.translate_ir(n=None)
|
|
318
346
|
keys = [
|
|
319
347
|
translate_named_expr(translator, n=e, schema=inp.schema) for e in node.keys
|
|
320
348
|
]
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
349
|
+
with set_expr_context(translator, ExecutionContext.GROUPBY):
|
|
350
|
+
original_aggs = [
|
|
351
|
+
translate_named_expr(translator, n=e, schema=inp.schema)
|
|
352
|
+
for e in node.aggs
|
|
353
|
+
]
|
|
324
354
|
is_rolling = node.options.rolling is not None
|
|
325
355
|
is_dynamic = node.options.dynamic is not None
|
|
326
356
|
if is_dynamic:
|
|
@@ -333,8 +363,34 @@ def _(node: pl_ir.GroupBy, translator: Translator, schema: Schema) -> ir.IR:
|
|
|
333
363
|
return rewrite_groupby(node, schema, keys, original_aggs, inp)
|
|
334
364
|
|
|
335
365
|
|
|
366
|
+
_DECIMAL_TYPES = {plc.TypeId.DECIMAL32, plc.TypeId.DECIMAL64, plc.TypeId.DECIMAL128}
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
def _align_decimal_scales(
|
|
370
|
+
left: expr.Expr, right: expr.Expr
|
|
371
|
+
) -> tuple[expr.Expr, expr.Expr]:
|
|
372
|
+
left_type, right_type = left.dtype, right.dtype
|
|
373
|
+
|
|
374
|
+
if plc.traits.is_fixed_point(left_type.plc_type) and plc.traits.is_fixed_point(
|
|
375
|
+
right_type.plc_type
|
|
376
|
+
):
|
|
377
|
+
target = DataType.common_decimal_dtype(left_type, right_type)
|
|
378
|
+
|
|
379
|
+
if (
|
|
380
|
+
left_type.id() != target.id() or left_type.scale() != target.scale()
|
|
381
|
+
): # pragma: no cover; no test yet
|
|
382
|
+
left = expr.Cast(target, True, left) # noqa: FBT003
|
|
383
|
+
|
|
384
|
+
if (
|
|
385
|
+
right_type.id() != target.id() or right_type.scale() != target.scale()
|
|
386
|
+
): # pragma: no cover; no test yet
|
|
387
|
+
right = expr.Cast(target, True, right) # noqa: FBT003
|
|
388
|
+
|
|
389
|
+
return left, right
|
|
390
|
+
|
|
391
|
+
|
|
336
392
|
@_translate_ir.register
|
|
337
|
-
def _(node:
|
|
393
|
+
def _(node: plrs._ir_nodes.Join, translator: Translator, schema: Schema) -> ir.IR:
|
|
338
394
|
# Join key dtypes are dependent on the schema of the left and
|
|
339
395
|
# right inputs, so these must be translated with the relevant
|
|
340
396
|
# input active.
|
|
@@ -388,22 +444,24 @@ def _(node: pl_ir.Join, translator: Translator, schema: Schema) -> ir.IR:
|
|
|
388
444
|
expr.BinOp(
|
|
389
445
|
dtype,
|
|
390
446
|
expr.BinOp._MAPPING[op],
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
447
|
+
*_align_decimal_scales(
|
|
448
|
+
insert_colrefs(
|
|
449
|
+
left_ne.value,
|
|
450
|
+
table_ref=plc.expressions.TableReference.LEFT,
|
|
451
|
+
name_to_index={
|
|
452
|
+
name: i for i, name in enumerate(inp_left.schema)
|
|
453
|
+
},
|
|
454
|
+
),
|
|
455
|
+
insert_colrefs(
|
|
456
|
+
right_ne.value,
|
|
457
|
+
table_ref=plc.expressions.TableReference.RIGHT,
|
|
458
|
+
name_to_index={
|
|
459
|
+
name: i for i, name in enumerate(inp_right.schema)
|
|
460
|
+
},
|
|
461
|
+
),
|
|
404
462
|
),
|
|
405
463
|
)
|
|
406
|
-
for op,
|
|
464
|
+
for op, left_ne, right_ne in zip(ops, left_on, right_on, strict=True)
|
|
407
465
|
),
|
|
408
466
|
)
|
|
409
467
|
|
|
@@ -411,7 +469,7 @@ def _(node: pl_ir.Join, translator: Translator, schema: Schema) -> ir.IR:
|
|
|
411
469
|
|
|
412
470
|
|
|
413
471
|
@_translate_ir.register
|
|
414
|
-
def _(node:
|
|
472
|
+
def _(node: plrs._ir_nodes.HStack, translator: Translator, schema: Schema) -> ir.IR:
|
|
415
473
|
with set_node(translator.visitor, node.input):
|
|
416
474
|
inp = translator.translate_ir(n=None)
|
|
417
475
|
exprs = [
|
|
@@ -422,7 +480,7 @@ def _(node: pl_ir.HStack, translator: Translator, schema: Schema) -> ir.IR:
|
|
|
422
480
|
|
|
423
481
|
@_translate_ir.register
|
|
424
482
|
def _(
|
|
425
|
-
node:
|
|
483
|
+
node: plrs._ir_nodes.Reduce, translator: Translator, schema: Schema
|
|
426
484
|
) -> ir.IR: # pragma: no cover; polars doesn't emit this node yet
|
|
427
485
|
with set_node(translator.visitor, node.input):
|
|
428
486
|
inp = translator.translate_ir(n=None)
|
|
@@ -433,7 +491,7 @@ def _(
|
|
|
433
491
|
|
|
434
492
|
|
|
435
493
|
@_translate_ir.register
|
|
436
|
-
def _(node:
|
|
494
|
+
def _(node: plrs._ir_nodes.Distinct, translator: Translator, schema: Schema) -> ir.IR:
|
|
437
495
|
(keep, subset, maintain_order, zlice) = node.options
|
|
438
496
|
keep = ir.Distinct._KEEP_MAP[keep]
|
|
439
497
|
subset = frozenset(subset) if subset is not None else None
|
|
@@ -448,7 +506,7 @@ def _(node: pl_ir.Distinct, translator: Translator, schema: Schema) -> ir.IR:
|
|
|
448
506
|
|
|
449
507
|
|
|
450
508
|
@_translate_ir.register
|
|
451
|
-
def _(node:
|
|
509
|
+
def _(node: plrs._ir_nodes.Sort, translator: Translator, schema: Schema) -> ir.IR:
|
|
452
510
|
with set_node(translator.visitor, node.input):
|
|
453
511
|
inp = translator.translate_ir(n=None)
|
|
454
512
|
by = [
|
|
@@ -463,14 +521,14 @@ def _(node: pl_ir.Sort, translator: Translator, schema: Schema) -> ir.IR:
|
|
|
463
521
|
|
|
464
522
|
|
|
465
523
|
@_translate_ir.register
|
|
466
|
-
def _(node:
|
|
524
|
+
def _(node: plrs._ir_nodes.Slice, translator: Translator, schema: Schema) -> ir.IR:
|
|
467
525
|
return ir.Slice(
|
|
468
526
|
schema, node.offset, node.len, translator.translate_ir(n=node.input)
|
|
469
527
|
)
|
|
470
528
|
|
|
471
529
|
|
|
472
530
|
@_translate_ir.register
|
|
473
|
-
def _(node:
|
|
531
|
+
def _(node: plrs._ir_nodes.Filter, translator: Translator, schema: Schema) -> ir.IR:
|
|
474
532
|
with set_node(translator.visitor, node.input):
|
|
475
533
|
inp = translator.translate_ir(n=None)
|
|
476
534
|
mask = translate_named_expr(translator, n=node.predicate, schema=inp.schema)
|
|
@@ -478,12 +536,16 @@ def _(node: pl_ir.Filter, translator: Translator, schema: Schema) -> ir.IR:
|
|
|
478
536
|
|
|
479
537
|
|
|
480
538
|
@_translate_ir.register
|
|
481
|
-
def _(
|
|
539
|
+
def _(
|
|
540
|
+
node: plrs._ir_nodes.SimpleProjection, translator: Translator, schema: Schema
|
|
541
|
+
) -> ir.IR:
|
|
482
542
|
return ir.Projection(schema, translator.translate_ir(n=node.input))
|
|
483
543
|
|
|
484
544
|
|
|
485
545
|
@_translate_ir.register
|
|
486
|
-
def _(
|
|
546
|
+
def _(
|
|
547
|
+
node: plrs._ir_nodes.MergeSorted, translator: Translator, schema: Schema
|
|
548
|
+
) -> ir.IR:
|
|
487
549
|
key = node.key
|
|
488
550
|
inp_left = translator.translate_ir(n=node.input_left)
|
|
489
551
|
inp_right = translator.translate_ir(n=node.input_right)
|
|
@@ -496,7 +558,9 @@ def _(node: pl_ir.MergeSorted, translator: Translator, schema: Schema) -> ir.IR:
|
|
|
496
558
|
|
|
497
559
|
|
|
498
560
|
@_translate_ir.register
|
|
499
|
-
def _(
|
|
561
|
+
def _(
|
|
562
|
+
node: plrs._ir_nodes.MapFunction, translator: Translator, schema: Schema
|
|
563
|
+
) -> ir.IR:
|
|
500
564
|
name, *options = node.function
|
|
501
565
|
return ir.MapFunction(
|
|
502
566
|
schema,
|
|
@@ -507,14 +571,14 @@ def _(node: pl_ir.MapFunction, translator: Translator, schema: Schema) -> ir.IR:
|
|
|
507
571
|
|
|
508
572
|
|
|
509
573
|
@_translate_ir.register
|
|
510
|
-
def _(node:
|
|
574
|
+
def _(node: plrs._ir_nodes.Union, translator: Translator, schema: Schema) -> ir.IR:
|
|
511
575
|
return ir.Union(
|
|
512
576
|
schema, node.options, *(translator.translate_ir(n=n) for n in node.inputs)
|
|
513
577
|
)
|
|
514
578
|
|
|
515
579
|
|
|
516
580
|
@_translate_ir.register
|
|
517
|
-
def _(node:
|
|
581
|
+
def _(node: plrs._ir_nodes.HConcat, translator: Translator, schema: Schema) -> ir.IR:
|
|
518
582
|
return ir.HConcat(
|
|
519
583
|
schema,
|
|
520
584
|
False, # noqa: FBT003
|
|
@@ -523,7 +587,7 @@ def _(node: pl_ir.HConcat, translator: Translator, schema: Schema) -> ir.IR:
|
|
|
523
587
|
|
|
524
588
|
|
|
525
589
|
@_translate_ir.register
|
|
526
|
-
def _(node:
|
|
590
|
+
def _(node: plrs._ir_nodes.Sink, translator: Translator, schema: Schema) -> ir.IR:
|
|
527
591
|
payload = json.loads(node.payload)
|
|
528
592
|
try:
|
|
529
593
|
file = payload["File"]
|
|
@@ -556,7 +620,7 @@ def _(node: pl_ir.Sink, translator: Translator, schema: Schema) -> ir.IR:
|
|
|
556
620
|
|
|
557
621
|
|
|
558
622
|
def translate_named_expr(
|
|
559
|
-
translator: Translator, *, n:
|
|
623
|
+
translator: Translator, *, n: plrs._expr_nodes.PyExprIR, schema: Schema
|
|
560
624
|
) -> expr.NamedExpr:
|
|
561
625
|
"""
|
|
562
626
|
Translate a polars-internal named expression IR object into our representation.
|
|
@@ -602,15 +666,18 @@ def _translate_expr(
|
|
|
602
666
|
|
|
603
667
|
@_translate_expr.register
|
|
604
668
|
def _(
|
|
605
|
-
node:
|
|
669
|
+
node: plrs._expr_nodes.Function,
|
|
670
|
+
translator: Translator,
|
|
671
|
+
dtype: DataType,
|
|
672
|
+
schema: Schema,
|
|
606
673
|
) -> expr.Expr:
|
|
607
674
|
name, *options = node.function_data
|
|
608
675
|
options = tuple(options)
|
|
609
|
-
if isinstance(name,
|
|
676
|
+
if isinstance(name, plrs._expr_nodes.StringFunction):
|
|
610
677
|
if name in {
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
678
|
+
plrs._expr_nodes.StringFunction.StripChars,
|
|
679
|
+
plrs._expr_nodes.StringFunction.StripCharsStart,
|
|
680
|
+
plrs._expr_nodes.StringFunction.StripCharsEnd,
|
|
614
681
|
}:
|
|
615
682
|
column, chars = (
|
|
616
683
|
translator.translate_expr(n=n, schema=schema) for n in node.input
|
|
@@ -639,8 +706,8 @@ def _(
|
|
|
639
706
|
options,
|
|
640
707
|
*(translator.translate_expr(n=n, schema=schema) for n in node.input),
|
|
641
708
|
)
|
|
642
|
-
elif isinstance(name,
|
|
643
|
-
if name ==
|
|
709
|
+
elif isinstance(name, plrs._expr_nodes.BooleanFunction):
|
|
710
|
+
if name == plrs._expr_nodes.BooleanFunction.IsBetween:
|
|
644
711
|
column, lo, hi = (
|
|
645
712
|
translator.translate_expr(n=n, schema=schema) for n in node.input
|
|
646
713
|
)
|
|
@@ -658,19 +725,19 @@ def _(
|
|
|
658
725
|
options,
|
|
659
726
|
*(translator.translate_expr(n=n, schema=schema) for n in node.input),
|
|
660
727
|
)
|
|
661
|
-
elif isinstance(name,
|
|
728
|
+
elif isinstance(name, plrs._expr_nodes.TemporalFunction):
|
|
662
729
|
# functions for which evaluation of the expression may not return
|
|
663
730
|
# the same dtype as polars, either due to libcudf returning a different
|
|
664
731
|
# dtype, or due to our internal processing affecting what libcudf returns
|
|
665
732
|
needs_cast = {
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
733
|
+
plrs._expr_nodes.TemporalFunction.Year,
|
|
734
|
+
plrs._expr_nodes.TemporalFunction.Month,
|
|
735
|
+
plrs._expr_nodes.TemporalFunction.Day,
|
|
736
|
+
plrs._expr_nodes.TemporalFunction.WeekDay,
|
|
737
|
+
plrs._expr_nodes.TemporalFunction.Hour,
|
|
738
|
+
plrs._expr_nodes.TemporalFunction.Minute,
|
|
739
|
+
plrs._expr_nodes.TemporalFunction.Second,
|
|
740
|
+
plrs._expr_nodes.TemporalFunction.Millisecond,
|
|
674
741
|
}
|
|
675
742
|
result_expr = expr.TemporalFunction(
|
|
676
743
|
dtype,
|
|
@@ -679,9 +746,11 @@ def _(
|
|
|
679
746
|
*(translator.translate_expr(n=n, schema=schema) for n in node.input),
|
|
680
747
|
)
|
|
681
748
|
if name in needs_cast:
|
|
682
|
-
return expr.Cast(dtype, result_expr)
|
|
749
|
+
return expr.Cast(dtype, True, result_expr) # noqa: FBT003
|
|
683
750
|
return result_expr
|
|
684
|
-
elif not POLARS_VERSION_LT_131 and isinstance(
|
|
751
|
+
elif not POLARS_VERSION_LT_131 and isinstance(
|
|
752
|
+
name, plrs._expr_nodes.StructFunction
|
|
753
|
+
):
|
|
685
754
|
return expr.StructFunction(
|
|
686
755
|
dtype,
|
|
687
756
|
expr.StructFunction.Name.from_polars(name),
|
|
@@ -690,15 +759,38 @@ def _(
|
|
|
690
759
|
)
|
|
691
760
|
elif isinstance(name, str):
|
|
692
761
|
children = (translator.translate_expr(n=n, schema=schema) for n in node.input)
|
|
693
|
-
if name == "log"
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
762
|
+
if name == "log" or (
|
|
763
|
+
not POLARS_VERSION_LT_133
|
|
764
|
+
and name == "l"
|
|
765
|
+
and isinstance(options[0], str)
|
|
766
|
+
and "".join((name, *options)) == "log"
|
|
767
|
+
):
|
|
768
|
+
if POLARS_VERSION_LT_133: # pragma: no cover
|
|
769
|
+
(base,) = options
|
|
770
|
+
(child,) = children
|
|
771
|
+
return expr.BinOp(
|
|
772
|
+
dtype,
|
|
773
|
+
plc.binaryop.BinaryOperator.LOG_BASE,
|
|
774
|
+
child,
|
|
775
|
+
expr.Literal(dtype, base),
|
|
776
|
+
)
|
|
777
|
+
else:
|
|
778
|
+
(child, base) = children
|
|
779
|
+
res = expr.BinOp(
|
|
780
|
+
dtype,
|
|
781
|
+
plc.binaryop.BinaryOperator.LOG_BASE,
|
|
782
|
+
child,
|
|
783
|
+
expr.Literal(dtype, base.value),
|
|
784
|
+
)
|
|
785
|
+
return (
|
|
786
|
+
res
|
|
787
|
+
if not POLARS_VERSION_LT_134
|
|
788
|
+
else expr.Cast(
|
|
789
|
+
DataType(pl.Float64()),
|
|
790
|
+
True, # noqa: FBT003
|
|
791
|
+
res,
|
|
792
|
+
)
|
|
793
|
+
)
|
|
702
794
|
elif name == "pow":
|
|
703
795
|
return expr.BinOp(dtype, plc.binaryop.BinaryOperator.POW, *children)
|
|
704
796
|
return expr.UnaryFunction(dtype, name, options, *children)
|
|
@@ -709,11 +801,15 @@ def _(
|
|
|
709
801
|
|
|
710
802
|
@_translate_expr.register
|
|
711
803
|
def _(
|
|
712
|
-
node:
|
|
804
|
+
node: plrs._expr_nodes.Window,
|
|
805
|
+
translator: Translator,
|
|
806
|
+
dtype: DataType,
|
|
807
|
+
schema: Schema,
|
|
713
808
|
) -> expr.Expr:
|
|
714
|
-
if isinstance(node.options,
|
|
809
|
+
if isinstance(node.options, plrs._expr_nodes.RollingGroupOptions):
|
|
715
810
|
# pl.col("a").rolling(...)
|
|
716
|
-
|
|
811
|
+
with set_expr_context(translator, ExecutionContext.ROLLING):
|
|
812
|
+
agg = translator.translate_expr(n=node.function, schema=schema)
|
|
717
813
|
name_generator = unique_names(schema)
|
|
718
814
|
aggs, named_post_agg = decompose_single_agg(
|
|
719
815
|
expr.NamedExpr(next(name_generator), agg),
|
|
@@ -723,7 +819,7 @@ def _(
|
|
|
723
819
|
)
|
|
724
820
|
named_aggs = [agg for agg, _ in aggs]
|
|
725
821
|
orderby = node.options.index_column
|
|
726
|
-
orderby_dtype = schema[orderby].
|
|
822
|
+
orderby_dtype = schema[orderby].plc_type
|
|
727
823
|
if plc.traits.is_integral(orderby_dtype):
|
|
728
824
|
# Integer orderby column is cast in implementation to int64 in polars
|
|
729
825
|
orderby_dtype = plc.DataType(plc.TypeId.INT64)
|
|
@@ -752,9 +848,10 @@ def _(
|
|
|
752
848
|
for agg in named_aggs
|
|
753
849
|
}
|
|
754
850
|
return replace([named_post_agg.value], replacements)[0]
|
|
755
|
-
elif isinstance(node.options,
|
|
851
|
+
elif isinstance(node.options, plrs._expr_nodes.WindowMapping):
|
|
756
852
|
# pl.col("a").over(...)
|
|
757
|
-
|
|
853
|
+
with set_expr_context(translator, ExecutionContext.WINDOW):
|
|
854
|
+
agg = translator.translate_expr(n=node.function, schema=schema)
|
|
758
855
|
name_gen = unique_names(schema)
|
|
759
856
|
aggs, post = decompose_single_agg(
|
|
760
857
|
expr.NamedExpr(next(name_gen), agg),
|
|
@@ -779,20 +876,41 @@ def _(
|
|
|
779
876
|
if has_order_by
|
|
780
877
|
else None
|
|
781
878
|
)
|
|
879
|
+
|
|
880
|
+
named_aggs = [agg for agg, _ in aggs]
|
|
881
|
+
|
|
882
|
+
by_exprs = [
|
|
883
|
+
translator.translate_expr(n=n, schema=schema) for n in node.partition_by
|
|
884
|
+
]
|
|
885
|
+
|
|
886
|
+
child_deps = [
|
|
887
|
+
v.children[0]
|
|
888
|
+
for ne in named_aggs
|
|
889
|
+
for v in (ne.value,)
|
|
890
|
+
if isinstance(v, expr.Agg)
|
|
891
|
+
or (
|
|
892
|
+
isinstance(v, expr.UnaryFunction)
|
|
893
|
+
and v.name in {"rank", "fill_null_with_strategy", "cum_sum"}
|
|
894
|
+
)
|
|
895
|
+
]
|
|
896
|
+
children = (*by_exprs, *((order_by_expr,) if has_order_by else ()), *child_deps)
|
|
782
897
|
return expr.GroupedRollingWindow(
|
|
783
898
|
dtype,
|
|
784
899
|
(mapping, has_order_by, descending, nulls_last),
|
|
785
|
-
|
|
900
|
+
named_aggs,
|
|
786
901
|
post,
|
|
787
|
-
|
|
788
|
-
|
|
902
|
+
len(by_exprs),
|
|
903
|
+
*children,
|
|
789
904
|
)
|
|
790
905
|
assert_never(node.options)
|
|
791
906
|
|
|
792
907
|
|
|
793
908
|
@_translate_expr.register
|
|
794
909
|
def _(
|
|
795
|
-
node:
|
|
910
|
+
node: plrs._expr_nodes.Literal,
|
|
911
|
+
translator: Translator,
|
|
912
|
+
dtype: DataType,
|
|
913
|
+
schema: Schema,
|
|
796
914
|
) -> expr.Expr:
|
|
797
915
|
if isinstance(node.value, plrs.PySeries):
|
|
798
916
|
return expr.LiteralColumn(dtype, pl.Series._from_pyseries(node.value))
|
|
@@ -804,7 +922,7 @@ def _(
|
|
|
804
922
|
|
|
805
923
|
@_translate_expr.register
|
|
806
924
|
def _(
|
|
807
|
-
node:
|
|
925
|
+
node: plrs._expr_nodes.Sort, translator: Translator, dtype: DataType, schema: Schema
|
|
808
926
|
) -> expr.Expr:
|
|
809
927
|
# TODO: raise in groupby
|
|
810
928
|
return expr.Sort(
|
|
@@ -814,7 +932,10 @@ def _(
|
|
|
814
932
|
|
|
815
933
|
@_translate_expr.register
|
|
816
934
|
def _(
|
|
817
|
-
node:
|
|
935
|
+
node: plrs._expr_nodes.SortBy,
|
|
936
|
+
translator: Translator,
|
|
937
|
+
dtype: DataType,
|
|
938
|
+
schema: Schema,
|
|
818
939
|
) -> expr.Expr:
|
|
819
940
|
options = node.sort_options
|
|
820
941
|
return expr.SortBy(
|
|
@@ -827,7 +948,10 @@ def _(
|
|
|
827
948
|
|
|
828
949
|
@_translate_expr.register
|
|
829
950
|
def _(
|
|
830
|
-
node:
|
|
951
|
+
node: plrs._expr_nodes.Slice,
|
|
952
|
+
translator: Translator,
|
|
953
|
+
dtype: DataType,
|
|
954
|
+
schema: Schema,
|
|
831
955
|
) -> expr.Expr:
|
|
832
956
|
offset = translator.translate_expr(n=node.offset, schema=schema)
|
|
833
957
|
length = translator.translate_expr(n=node.length, schema=schema)
|
|
@@ -843,7 +967,10 @@ def _(
|
|
|
843
967
|
|
|
844
968
|
@_translate_expr.register
|
|
845
969
|
def _(
|
|
846
|
-
node:
|
|
970
|
+
node: plrs._expr_nodes.Gather,
|
|
971
|
+
translator: Translator,
|
|
972
|
+
dtype: DataType,
|
|
973
|
+
schema: Schema,
|
|
847
974
|
) -> expr.Expr:
|
|
848
975
|
return expr.Gather(
|
|
849
976
|
dtype,
|
|
@@ -854,7 +981,10 @@ def _(
|
|
|
854
981
|
|
|
855
982
|
@_translate_expr.register
|
|
856
983
|
def _(
|
|
857
|
-
node:
|
|
984
|
+
node: plrs._expr_nodes.Filter,
|
|
985
|
+
translator: Translator,
|
|
986
|
+
dtype: DataType,
|
|
987
|
+
schema: Schema,
|
|
858
988
|
) -> expr.Expr:
|
|
859
989
|
return expr.Filter(
|
|
860
990
|
dtype,
|
|
@@ -865,44 +995,70 @@ def _(
|
|
|
865
995
|
|
|
866
996
|
@_translate_expr.register
|
|
867
997
|
def _(
|
|
868
|
-
node:
|
|
998
|
+
node: plrs._expr_nodes.Cast, translator: Translator, dtype: DataType, schema: Schema
|
|
869
999
|
) -> expr.Expr:
|
|
1000
|
+
# TODO: node.options can be 2 meaning wrap_numerical=True
|
|
1001
|
+
# don't necessarily raise because wrapping isn't always needed, but it's unhandled
|
|
1002
|
+
strict = node.options != 1
|
|
870
1003
|
inner = translator.translate_expr(n=node.expr, schema=schema)
|
|
1004
|
+
|
|
1005
|
+
if plc.traits.is_floating_point(inner.dtype.plc_type) and plc.traits.is_fixed_point(
|
|
1006
|
+
dtype.plc_type
|
|
1007
|
+
):
|
|
1008
|
+
return expr.Cast(
|
|
1009
|
+
dtype,
|
|
1010
|
+
strict,
|
|
1011
|
+
expr.UnaryFunction(
|
|
1012
|
+
inner.dtype, "round", (-dtype.plc_type.scale(), "half_to_even"), inner
|
|
1013
|
+
),
|
|
1014
|
+
)
|
|
1015
|
+
|
|
871
1016
|
# Push casts into literals so we can handle Cast(Literal(Null))
|
|
872
1017
|
if isinstance(inner, expr.Literal):
|
|
873
1018
|
return inner.astype(dtype)
|
|
874
|
-
|
|
875
|
-
|
|
876
|
-
# casts if we have one.
|
|
877
|
-
(inner,) = inner.children
|
|
878
|
-
return expr.Cast(dtype, inner)
|
|
1019
|
+
else:
|
|
1020
|
+
return expr.Cast(dtype, strict, inner)
|
|
879
1021
|
|
|
880
1022
|
|
|
881
1023
|
@_translate_expr.register
|
|
882
1024
|
def _(
|
|
883
|
-
node:
|
|
1025
|
+
node: plrs._expr_nodes.Column,
|
|
1026
|
+
translator: Translator,
|
|
1027
|
+
dtype: DataType,
|
|
1028
|
+
schema: Schema,
|
|
884
1029
|
) -> expr.Expr:
|
|
885
1030
|
return expr.Col(dtype, node.name)
|
|
886
1031
|
|
|
887
1032
|
|
|
888
1033
|
@_translate_expr.register
|
|
889
1034
|
def _(
|
|
890
|
-
node:
|
|
1035
|
+
node: plrs._expr_nodes.Agg, translator: Translator, dtype: DataType, schema: Schema
|
|
891
1036
|
) -> expr.Expr:
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
|
|
895
|
-
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
|
|
899
|
-
|
|
1037
|
+
agg_name = node.name
|
|
1038
|
+
args = [translator.translate_expr(n=arg, schema=schema) for arg in node.arguments]
|
|
1039
|
+
|
|
1040
|
+
if agg_name not in ("count", "n_unique", "mean", "median", "quantile"):
|
|
1041
|
+
args = [
|
|
1042
|
+
expr.Cast(dtype, True, arg) # noqa: FBT003
|
|
1043
|
+
if plc.traits.is_fixed_point(arg.dtype.plc_type)
|
|
1044
|
+
and arg.dtype.plc_type != dtype.plc_type
|
|
1045
|
+
else arg
|
|
1046
|
+
for arg in args
|
|
1047
|
+
]
|
|
1048
|
+
|
|
1049
|
+
value = expr.Agg(dtype, agg_name, node.options, translator._expr_context, *args)
|
|
1050
|
+
|
|
1051
|
+
if agg_name in ("count", "n_unique") and value.dtype.id() != plc.TypeId.INT32:
|
|
1052
|
+
return expr.Cast(value.dtype, True, value) # noqa: FBT003
|
|
900
1053
|
return value
|
|
901
1054
|
|
|
902
1055
|
|
|
903
1056
|
@_translate_expr.register
|
|
904
1057
|
def _(
|
|
905
|
-
node:
|
|
1058
|
+
node: plrs._expr_nodes.Ternary,
|
|
1059
|
+
translator: Translator,
|
|
1060
|
+
dtype: DataType,
|
|
1061
|
+
schema: Schema,
|
|
906
1062
|
) -> expr.Expr:
|
|
907
1063
|
return expr.Ternary(
|
|
908
1064
|
dtype,
|
|
@@ -914,26 +1070,70 @@ def _(
|
|
|
914
1070
|
|
|
915
1071
|
@_translate_expr.register
|
|
916
1072
|
def _(
|
|
917
|
-
node:
|
|
1073
|
+
node: plrs._expr_nodes.BinaryExpr,
|
|
918
1074
|
translator: Translator,
|
|
919
1075
|
dtype: DataType,
|
|
920
1076
|
schema: Schema,
|
|
921
1077
|
) -> expr.Expr:
|
|
922
|
-
|
|
923
|
-
|
|
1078
|
+
left = translator.translate_expr(n=node.left, schema=schema)
|
|
1079
|
+
right = translator.translate_expr(n=node.right, schema=schema)
|
|
1080
|
+
if (
|
|
1081
|
+
POLARS_VERSION_LT_133
|
|
1082
|
+
and plc.traits.is_boolean(dtype.plc_type)
|
|
1083
|
+
and node.op == plrs._expr_nodes.Operator.TrueDivide
|
|
1084
|
+
):
|
|
1085
|
+
dtype = DataType(pl.Float64()) # pragma: no cover
|
|
1086
|
+
if node.op == plrs._expr_nodes.Operator.TrueDivide and (
|
|
1087
|
+
plc.traits.is_fixed_point(left.dtype.plc_type)
|
|
1088
|
+
or plc.traits.is_fixed_point(right.dtype.plc_type)
|
|
1089
|
+
):
|
|
1090
|
+
f64 = DataType(pl.Float64())
|
|
1091
|
+
return expr.Cast(
|
|
1092
|
+
dtype,
|
|
1093
|
+
True, # noqa: FBT003
|
|
1094
|
+
expr.BinOp(
|
|
1095
|
+
f64,
|
|
1096
|
+
expr.BinOp._MAPPING[node.op],
|
|
1097
|
+
expr.Cast(f64, True, left), # noqa: FBT003
|
|
1098
|
+
expr.Cast(f64, True, right), # noqa: FBT003
|
|
1099
|
+
),
|
|
1100
|
+
)
|
|
1101
|
+
|
|
1102
|
+
if (
|
|
1103
|
+
not POLARS_VERSION_LT_134
|
|
1104
|
+
and node.op == plrs._expr_nodes.Operator.Multiply
|
|
1105
|
+
and plc.traits.is_fixed_point(left.dtype.plc_type)
|
|
1106
|
+
and plc.traits.is_fixed_point(right.dtype.plc_type)
|
|
1107
|
+
):
|
|
1108
|
+
left_scale = -left.dtype.plc_type.scale()
|
|
1109
|
+
right_scale = -right.dtype.plc_type.scale()
|
|
1110
|
+
out_scale = max(left_scale, right_scale)
|
|
1111
|
+
|
|
1112
|
+
return expr.UnaryFunction(
|
|
1113
|
+
DataType(pl.Decimal(38, out_scale)),
|
|
1114
|
+
"round",
|
|
1115
|
+
(out_scale, "half_to_even"),
|
|
1116
|
+
expr.BinOp(
|
|
1117
|
+
DataType(pl.Decimal(38, left_scale + right_scale)),
|
|
1118
|
+
expr.BinOp._MAPPING[node.op],
|
|
1119
|
+
left,
|
|
1120
|
+
right,
|
|
1121
|
+
),
|
|
1122
|
+
)
|
|
1123
|
+
|
|
924
1124
|
return expr.BinOp(
|
|
925
1125
|
dtype,
|
|
926
1126
|
expr.BinOp._MAPPING[node.op],
|
|
927
|
-
|
|
928
|
-
|
|
1127
|
+
left,
|
|
1128
|
+
right,
|
|
929
1129
|
)
|
|
930
1130
|
|
|
931
1131
|
|
|
932
1132
|
@_translate_expr.register
|
|
933
1133
|
def _(
|
|
934
|
-
node:
|
|
1134
|
+
node: plrs._expr_nodes.Len, translator: Translator, dtype: DataType, schema: Schema
|
|
935
1135
|
) -> expr.Expr:
|
|
936
1136
|
value = expr.Len(dtype)
|
|
937
1137
|
if dtype.id() != plc.TypeId.INT32:
|
|
938
|
-
return expr.Cast(dtype, value)
|
|
1138
|
+
return expr.Cast(dtype, True, value) # noqa: FBT003
|
|
939
1139
|
return value # pragma: no cover; never reached since polars len has uint32 dtype
|