cudf-polars-cu13 25.10.0__py3-none-any.whl → 26.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. cudf_polars/GIT_COMMIT +1 -1
  2. cudf_polars/VERSION +1 -1
  3. cudf_polars/callback.py +60 -15
  4. cudf_polars/containers/column.py +137 -77
  5. cudf_polars/containers/dataframe.py +123 -34
  6. cudf_polars/containers/datatype.py +134 -13
  7. cudf_polars/dsl/expr.py +0 -2
  8. cudf_polars/dsl/expressions/aggregation.py +80 -28
  9. cudf_polars/dsl/expressions/binaryop.py +34 -14
  10. cudf_polars/dsl/expressions/boolean.py +110 -37
  11. cudf_polars/dsl/expressions/datetime.py +59 -30
  12. cudf_polars/dsl/expressions/literal.py +11 -5
  13. cudf_polars/dsl/expressions/rolling.py +460 -119
  14. cudf_polars/dsl/expressions/selection.py +9 -8
  15. cudf_polars/dsl/expressions/slicing.py +1 -1
  16. cudf_polars/dsl/expressions/string.py +256 -114
  17. cudf_polars/dsl/expressions/struct.py +19 -7
  18. cudf_polars/dsl/expressions/ternary.py +33 -3
  19. cudf_polars/dsl/expressions/unary.py +126 -64
  20. cudf_polars/dsl/ir.py +1053 -350
  21. cudf_polars/dsl/to_ast.py +30 -13
  22. cudf_polars/dsl/tracing.py +194 -0
  23. cudf_polars/dsl/translate.py +307 -107
  24. cudf_polars/dsl/utils/aggregations.py +43 -30
  25. cudf_polars/dsl/utils/reshape.py +14 -2
  26. cudf_polars/dsl/utils/rolling.py +12 -8
  27. cudf_polars/dsl/utils/windows.py +35 -20
  28. cudf_polars/experimental/base.py +55 -2
  29. cudf_polars/experimental/benchmarks/pdsds.py +12 -126
  30. cudf_polars/experimental/benchmarks/pdsh.py +792 -2
  31. cudf_polars/experimental/benchmarks/utils.py +596 -39
  32. cudf_polars/experimental/dask_registers.py +47 -20
  33. cudf_polars/experimental/dispatch.py +9 -3
  34. cudf_polars/experimental/distinct.py +2 -0
  35. cudf_polars/experimental/explain.py +15 -2
  36. cudf_polars/experimental/expressions.py +30 -15
  37. cudf_polars/experimental/groupby.py +25 -4
  38. cudf_polars/experimental/io.py +156 -124
  39. cudf_polars/experimental/join.py +53 -23
  40. cudf_polars/experimental/parallel.py +68 -19
  41. cudf_polars/experimental/rapidsmpf/__init__.py +8 -0
  42. cudf_polars/experimental/rapidsmpf/collectives/__init__.py +9 -0
  43. cudf_polars/experimental/rapidsmpf/collectives/allgather.py +90 -0
  44. cudf_polars/experimental/rapidsmpf/collectives/common.py +96 -0
  45. cudf_polars/experimental/rapidsmpf/collectives/shuffle.py +253 -0
  46. cudf_polars/experimental/rapidsmpf/core.py +488 -0
  47. cudf_polars/experimental/rapidsmpf/dask.py +172 -0
  48. cudf_polars/experimental/rapidsmpf/dispatch.py +153 -0
  49. cudf_polars/experimental/rapidsmpf/io.py +696 -0
  50. cudf_polars/experimental/rapidsmpf/join.py +322 -0
  51. cudf_polars/experimental/rapidsmpf/lower.py +74 -0
  52. cudf_polars/experimental/rapidsmpf/nodes.py +735 -0
  53. cudf_polars/experimental/rapidsmpf/repartition.py +216 -0
  54. cudf_polars/experimental/rapidsmpf/union.py +115 -0
  55. cudf_polars/experimental/rapidsmpf/utils.py +374 -0
  56. cudf_polars/experimental/repartition.py +9 -2
  57. cudf_polars/experimental/select.py +177 -14
  58. cudf_polars/experimental/shuffle.py +46 -12
  59. cudf_polars/experimental/sort.py +100 -26
  60. cudf_polars/experimental/spilling.py +1 -1
  61. cudf_polars/experimental/statistics.py +24 -5
  62. cudf_polars/experimental/utils.py +25 -7
  63. cudf_polars/testing/asserts.py +13 -8
  64. cudf_polars/testing/io.py +2 -1
  65. cudf_polars/testing/plugin.py +93 -17
  66. cudf_polars/typing/__init__.py +86 -32
  67. cudf_polars/utils/config.py +473 -58
  68. cudf_polars/utils/cuda_stream.py +70 -0
  69. cudf_polars/utils/versions.py +5 -4
  70. cudf_polars_cu13-26.2.0.dist-info/METADATA +181 -0
  71. cudf_polars_cu13-26.2.0.dist-info/RECORD +108 -0
  72. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/WHEEL +1 -1
  73. cudf_polars_cu13-25.10.0.dist-info/METADATA +0 -136
  74. cudf_polars_cu13-25.10.0.dist-info/RECORD +0 -92
  75. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/licenses/LICENSE +0 -0
  76. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
4
  """Translate polars IR representation to ours."""
@@ -14,8 +14,11 @@ from typing import TYPE_CHECKING, Any
14
14
  from typing_extensions import assert_never
15
15
 
16
16
  import polars as pl
17
- import polars.polars as plrs
18
- from polars.polars import _expr_nodes as pl_expr, _ir_nodes as pl_ir
17
+
18
+ # polars.polars is not a part of the public API,
19
+ # so we cannot rely on importing it directly
20
+ # See https://github.com/pola-rs/polars/issues/24826
21
+ from polars import polars as plrs # type: ignore[attr-defined]
19
22
 
20
23
  import pylibcudf as plc
21
24
 
@@ -33,6 +36,8 @@ from cudf_polars.utils import config, sorting
33
36
  from cudf_polars.utils.versions import (
34
37
  POLARS_VERSION_LT_131,
35
38
  POLARS_VERSION_LT_132,
39
+ POLARS_VERSION_LT_133,
40
+ POLARS_VERSION_LT_134,
36
41
  POLARS_VERSION_LT_1323,
37
42
  )
38
43
 
@@ -61,6 +66,7 @@ class Translator:
61
66
  self.config_options = config.ConfigOptions.from_polars_engine(engine)
62
67
  self.errors: list[Exception] = []
63
68
  self._cache_nodes: dict[int, ir.Cache] = {}
69
+ self._expr_context: ExecutionContext = ExecutionContext.FRAME
64
70
 
65
71
  def translate_ir(self, *, n: int | None = None) -> ir.IR:
66
72
  """
@@ -96,7 +102,7 @@ class Translator:
96
102
  # IR is versioned with major.minor, minor is bumped for backwards
97
103
  # compatible changes (e.g. adding new nodes), major is bumped for
98
104
  # incompatible changes (e.g. renaming nodes).
99
- if (version := self.visitor.version()) >= (10, 1):
105
+ if (version := self.visitor.version()) >= (11, 1):
100
106
  e = NotImplementedError(
101
107
  f"No support for polars IR {version=}"
102
108
  ) # pragma: no cover; no such version for now.
@@ -201,6 +207,23 @@ class set_node(AbstractContextManager[None]):
201
207
  noop_context: nullcontext[None] = nullcontext()
202
208
 
203
209
 
210
+ class set_expr_context(AbstractContextManager[None]):
211
+ __slots__ = ("_prev", "ctx", "translator")
212
+
213
+ def __init__(self, translator: Translator, ctx: ExecutionContext) -> None:
214
+ self.translator = translator
215
+ self.ctx = ctx
216
+ self._prev: ExecutionContext | None = None
217
+
218
+ def __enter__(self) -> None:
219
+ self._prev = self.translator._expr_context
220
+ self.translator._expr_context = self.ctx
221
+
222
+ def __exit__(self, *args: Any) -> None:
223
+ assert self._prev is not None
224
+ self.translator._expr_context = self._prev
225
+
226
+
204
227
  @singledispatch
205
228
  def _translate_ir(node: Any, translator: Translator, schema: Schema) -> ir.IR:
206
229
  raise NotImplementedError(
@@ -209,7 +232,7 @@ def _translate_ir(node: Any, translator: Translator, schema: Schema) -> ir.IR:
209
232
 
210
233
 
211
234
  @_translate_ir.register
212
- def _(node: pl_ir.PythonScan, translator: Translator, schema: Schema) -> ir.IR:
235
+ def _(node: plrs._ir_nodes.PythonScan, translator: Translator, schema: Schema) -> ir.IR:
213
236
  scan_fn, with_columns, source_type, predicate, nrows = node.options
214
237
  options = (scan_fn, with_columns, source_type, nrows)
215
238
  predicate = (
@@ -221,7 +244,7 @@ def _(node: pl_ir.PythonScan, translator: Translator, schema: Schema) -> ir.IR:
221
244
 
222
245
 
223
246
  @_translate_ir.register
224
- def _(node: pl_ir.Scan, translator: Translator, schema: Schema) -> ir.IR:
247
+ def _(node: plrs._ir_nodes.Scan, translator: Translator, schema: Schema) -> ir.IR:
225
248
  typ, *options = node.scan_type
226
249
  paths = node.paths
227
250
  # Polars can produce a Scan with an empty ``node.paths`` (eg. the native
@@ -254,6 +277,9 @@ def _(node: pl_ir.Scan, translator: Translator, schema: Schema) -> ir.IR:
254
277
  skip_rows = 0
255
278
  else:
256
279
  skip_rows, n_rows = pre_slice
280
+ if (n_rows == 2**32 - 1) or (n_rows == 2**64 - 1):
281
+ # Polars translates slice(10, None) -> (10, u32/64max)
282
+ n_rows = -1
257
283
 
258
284
  return ir.Scan(
259
285
  schema,
@@ -274,7 +300,7 @@ def _(node: pl_ir.Scan, translator: Translator, schema: Schema) -> ir.IR:
274
300
 
275
301
 
276
302
  @_translate_ir.register
277
- def _(node: pl_ir.Cache, translator: Translator, schema: Schema) -> ir.IR:
303
+ def _(node: plrs._ir_nodes.Cache, translator: Translator, schema: Schema) -> ir.IR:
278
304
  if POLARS_VERSION_LT_1323: # pragma: no cover
279
305
  refcount = node.cache_hits
280
306
  else:
@@ -293,7 +319,9 @@ def _(node: pl_ir.Cache, translator: Translator, schema: Schema) -> ir.IR:
293
319
 
294
320
 
295
321
  @_translate_ir.register
296
- def _(node: pl_ir.DataFrameScan, translator: Translator, schema: Schema) -> ir.IR:
322
+ def _(
323
+ node: plrs._ir_nodes.DataFrameScan, translator: Translator, schema: Schema
324
+ ) -> ir.IR:
297
325
  return ir.DataFrameScan(
298
326
  schema,
299
327
  node.df,
@@ -302,7 +330,7 @@ def _(node: pl_ir.DataFrameScan, translator: Translator, schema: Schema) -> ir.I
302
330
 
303
331
 
304
332
  @_translate_ir.register
305
- def _(node: pl_ir.Select, translator: Translator, schema: Schema) -> ir.IR:
333
+ def _(node: plrs._ir_nodes.Select, translator: Translator, schema: Schema) -> ir.IR:
306
334
  with set_node(translator.visitor, node.input):
307
335
  inp = translator.translate_ir(n=None)
308
336
  exprs = [
@@ -312,15 +340,17 @@ def _(node: pl_ir.Select, translator: Translator, schema: Schema) -> ir.IR:
312
340
 
313
341
 
314
342
  @_translate_ir.register
315
- def _(node: pl_ir.GroupBy, translator: Translator, schema: Schema) -> ir.IR:
343
+ def _(node: plrs._ir_nodes.GroupBy, translator: Translator, schema: Schema) -> ir.IR:
316
344
  with set_node(translator.visitor, node.input):
317
345
  inp = translator.translate_ir(n=None)
318
346
  keys = [
319
347
  translate_named_expr(translator, n=e, schema=inp.schema) for e in node.keys
320
348
  ]
321
- original_aggs = [
322
- translate_named_expr(translator, n=e, schema=inp.schema) for e in node.aggs
323
- ]
349
+ with set_expr_context(translator, ExecutionContext.GROUPBY):
350
+ original_aggs = [
351
+ translate_named_expr(translator, n=e, schema=inp.schema)
352
+ for e in node.aggs
353
+ ]
324
354
  is_rolling = node.options.rolling is not None
325
355
  is_dynamic = node.options.dynamic is not None
326
356
  if is_dynamic:
@@ -333,8 +363,34 @@ def _(node: pl_ir.GroupBy, translator: Translator, schema: Schema) -> ir.IR:
333
363
  return rewrite_groupby(node, schema, keys, original_aggs, inp)
334
364
 
335
365
 
366
+ _DECIMAL_TYPES = {plc.TypeId.DECIMAL32, plc.TypeId.DECIMAL64, plc.TypeId.DECIMAL128}
367
+
368
+
369
+ def _align_decimal_scales(
370
+ left: expr.Expr, right: expr.Expr
371
+ ) -> tuple[expr.Expr, expr.Expr]:
372
+ left_type, right_type = left.dtype, right.dtype
373
+
374
+ if plc.traits.is_fixed_point(left_type.plc_type) and plc.traits.is_fixed_point(
375
+ right_type.plc_type
376
+ ):
377
+ target = DataType.common_decimal_dtype(left_type, right_type)
378
+
379
+ if (
380
+ left_type.id() != target.id() or left_type.scale() != target.scale()
381
+ ): # pragma: no cover; no test yet
382
+ left = expr.Cast(target, True, left) # noqa: FBT003
383
+
384
+ if (
385
+ right_type.id() != target.id() or right_type.scale() != target.scale()
386
+ ): # pragma: no cover; no test yet
387
+ right = expr.Cast(target, True, right) # noqa: FBT003
388
+
389
+ return left, right
390
+
391
+
336
392
  @_translate_ir.register
337
- def _(node: pl_ir.Join, translator: Translator, schema: Schema) -> ir.IR:
393
+ def _(node: plrs._ir_nodes.Join, translator: Translator, schema: Schema) -> ir.IR:
338
394
  # Join key dtypes are dependent on the schema of the left and
339
395
  # right inputs, so these must be translated with the relevant
340
396
  # input active.
@@ -388,22 +444,24 @@ def _(node: pl_ir.Join, translator: Translator, schema: Schema) -> ir.IR:
388
444
  expr.BinOp(
389
445
  dtype,
390
446
  expr.BinOp._MAPPING[op],
391
- insert_colrefs(
392
- left.value,
393
- table_ref=plc.expressions.TableReference.LEFT,
394
- name_to_index={
395
- name: i for i, name in enumerate(inp_left.schema)
396
- },
397
- ),
398
- insert_colrefs(
399
- right.value,
400
- table_ref=plc.expressions.TableReference.RIGHT,
401
- name_to_index={
402
- name: i for i, name in enumerate(inp_right.schema)
403
- },
447
+ *_align_decimal_scales(
448
+ insert_colrefs(
449
+ left_ne.value,
450
+ table_ref=plc.expressions.TableReference.LEFT,
451
+ name_to_index={
452
+ name: i for i, name in enumerate(inp_left.schema)
453
+ },
454
+ ),
455
+ insert_colrefs(
456
+ right_ne.value,
457
+ table_ref=plc.expressions.TableReference.RIGHT,
458
+ name_to_index={
459
+ name: i for i, name in enumerate(inp_right.schema)
460
+ },
461
+ ),
404
462
  ),
405
463
  )
406
- for op, left, right in zip(ops, left_on, right_on, strict=True)
464
+ for op, left_ne, right_ne in zip(ops, left_on, right_on, strict=True)
407
465
  ),
408
466
  )
409
467
 
@@ -411,7 +469,7 @@ def _(node: pl_ir.Join, translator: Translator, schema: Schema) -> ir.IR:
411
469
 
412
470
 
413
471
  @_translate_ir.register
414
- def _(node: pl_ir.HStack, translator: Translator, schema: Schema) -> ir.IR:
472
+ def _(node: plrs._ir_nodes.HStack, translator: Translator, schema: Schema) -> ir.IR:
415
473
  with set_node(translator.visitor, node.input):
416
474
  inp = translator.translate_ir(n=None)
417
475
  exprs = [
@@ -422,7 +480,7 @@ def _(node: pl_ir.HStack, translator: Translator, schema: Schema) -> ir.IR:
422
480
 
423
481
  @_translate_ir.register
424
482
  def _(
425
- node: pl_ir.Reduce, translator: Translator, schema: Schema
483
+ node: plrs._ir_nodes.Reduce, translator: Translator, schema: Schema
426
484
  ) -> ir.IR: # pragma: no cover; polars doesn't emit this node yet
427
485
  with set_node(translator.visitor, node.input):
428
486
  inp = translator.translate_ir(n=None)
@@ -433,7 +491,7 @@ def _(
433
491
 
434
492
 
435
493
  @_translate_ir.register
436
- def _(node: pl_ir.Distinct, translator: Translator, schema: Schema) -> ir.IR:
494
+ def _(node: plrs._ir_nodes.Distinct, translator: Translator, schema: Schema) -> ir.IR:
437
495
  (keep, subset, maintain_order, zlice) = node.options
438
496
  keep = ir.Distinct._KEEP_MAP[keep]
439
497
  subset = frozenset(subset) if subset is not None else None
@@ -448,7 +506,7 @@ def _(node: pl_ir.Distinct, translator: Translator, schema: Schema) -> ir.IR:
448
506
 
449
507
 
450
508
  @_translate_ir.register
451
- def _(node: pl_ir.Sort, translator: Translator, schema: Schema) -> ir.IR:
509
+ def _(node: plrs._ir_nodes.Sort, translator: Translator, schema: Schema) -> ir.IR:
452
510
  with set_node(translator.visitor, node.input):
453
511
  inp = translator.translate_ir(n=None)
454
512
  by = [
@@ -463,14 +521,14 @@ def _(node: pl_ir.Sort, translator: Translator, schema: Schema) -> ir.IR:
463
521
 
464
522
 
465
523
  @_translate_ir.register
466
- def _(node: pl_ir.Slice, translator: Translator, schema: Schema) -> ir.IR:
524
+ def _(node: plrs._ir_nodes.Slice, translator: Translator, schema: Schema) -> ir.IR:
467
525
  return ir.Slice(
468
526
  schema, node.offset, node.len, translator.translate_ir(n=node.input)
469
527
  )
470
528
 
471
529
 
472
530
  @_translate_ir.register
473
- def _(node: pl_ir.Filter, translator: Translator, schema: Schema) -> ir.IR:
531
+ def _(node: plrs._ir_nodes.Filter, translator: Translator, schema: Schema) -> ir.IR:
474
532
  with set_node(translator.visitor, node.input):
475
533
  inp = translator.translate_ir(n=None)
476
534
  mask = translate_named_expr(translator, n=node.predicate, schema=inp.schema)
@@ -478,12 +536,16 @@ def _(node: pl_ir.Filter, translator: Translator, schema: Schema) -> ir.IR:
478
536
 
479
537
 
480
538
  @_translate_ir.register
481
- def _(node: pl_ir.SimpleProjection, translator: Translator, schema: Schema) -> ir.IR:
539
+ def _(
540
+ node: plrs._ir_nodes.SimpleProjection, translator: Translator, schema: Schema
541
+ ) -> ir.IR:
482
542
  return ir.Projection(schema, translator.translate_ir(n=node.input))
483
543
 
484
544
 
485
545
  @_translate_ir.register
486
- def _(node: pl_ir.MergeSorted, translator: Translator, schema: Schema) -> ir.IR:
546
+ def _(
547
+ node: plrs._ir_nodes.MergeSorted, translator: Translator, schema: Schema
548
+ ) -> ir.IR:
487
549
  key = node.key
488
550
  inp_left = translator.translate_ir(n=node.input_left)
489
551
  inp_right = translator.translate_ir(n=node.input_right)
@@ -496,7 +558,9 @@ def _(node: pl_ir.MergeSorted, translator: Translator, schema: Schema) -> ir.IR:
496
558
 
497
559
 
498
560
  @_translate_ir.register
499
- def _(node: pl_ir.MapFunction, translator: Translator, schema: Schema) -> ir.IR:
561
+ def _(
562
+ node: plrs._ir_nodes.MapFunction, translator: Translator, schema: Schema
563
+ ) -> ir.IR:
500
564
  name, *options = node.function
501
565
  return ir.MapFunction(
502
566
  schema,
@@ -507,14 +571,14 @@ def _(node: pl_ir.MapFunction, translator: Translator, schema: Schema) -> ir.IR:
507
571
 
508
572
 
509
573
  @_translate_ir.register
510
- def _(node: pl_ir.Union, translator: Translator, schema: Schema) -> ir.IR:
574
+ def _(node: plrs._ir_nodes.Union, translator: Translator, schema: Schema) -> ir.IR:
511
575
  return ir.Union(
512
576
  schema, node.options, *(translator.translate_ir(n=n) for n in node.inputs)
513
577
  )
514
578
 
515
579
 
516
580
  @_translate_ir.register
517
- def _(node: pl_ir.HConcat, translator: Translator, schema: Schema) -> ir.IR:
581
+ def _(node: plrs._ir_nodes.HConcat, translator: Translator, schema: Schema) -> ir.IR:
518
582
  return ir.HConcat(
519
583
  schema,
520
584
  False, # noqa: FBT003
@@ -523,7 +587,7 @@ def _(node: pl_ir.HConcat, translator: Translator, schema: Schema) -> ir.IR:
523
587
 
524
588
 
525
589
  @_translate_ir.register
526
- def _(node: pl_ir.Sink, translator: Translator, schema: Schema) -> ir.IR:
590
+ def _(node: plrs._ir_nodes.Sink, translator: Translator, schema: Schema) -> ir.IR:
527
591
  payload = json.loads(node.payload)
528
592
  try:
529
593
  file = payload["File"]
@@ -556,7 +620,7 @@ def _(node: pl_ir.Sink, translator: Translator, schema: Schema) -> ir.IR:
556
620
 
557
621
 
558
622
  def translate_named_expr(
559
- translator: Translator, *, n: pl_expr.PyExprIR, schema: Schema
623
+ translator: Translator, *, n: plrs._expr_nodes.PyExprIR, schema: Schema
560
624
  ) -> expr.NamedExpr:
561
625
  """
562
626
  Translate a polars-internal named expression IR object into our representation.
@@ -602,15 +666,18 @@ def _translate_expr(
602
666
 
603
667
  @_translate_expr.register
604
668
  def _(
605
- node: pl_expr.Function, translator: Translator, dtype: DataType, schema: Schema
669
+ node: plrs._expr_nodes.Function,
670
+ translator: Translator,
671
+ dtype: DataType,
672
+ schema: Schema,
606
673
  ) -> expr.Expr:
607
674
  name, *options = node.function_data
608
675
  options = tuple(options)
609
- if isinstance(name, pl_expr.StringFunction):
676
+ if isinstance(name, plrs._expr_nodes.StringFunction):
610
677
  if name in {
611
- pl_expr.StringFunction.StripChars,
612
- pl_expr.StringFunction.StripCharsStart,
613
- pl_expr.StringFunction.StripCharsEnd,
678
+ plrs._expr_nodes.StringFunction.StripChars,
679
+ plrs._expr_nodes.StringFunction.StripCharsStart,
680
+ plrs._expr_nodes.StringFunction.StripCharsEnd,
614
681
  }:
615
682
  column, chars = (
616
683
  translator.translate_expr(n=n, schema=schema) for n in node.input
@@ -639,8 +706,8 @@ def _(
639
706
  options,
640
707
  *(translator.translate_expr(n=n, schema=schema) for n in node.input),
641
708
  )
642
- elif isinstance(name, pl_expr.BooleanFunction):
643
- if name == pl_expr.BooleanFunction.IsBetween:
709
+ elif isinstance(name, plrs._expr_nodes.BooleanFunction):
710
+ if name == plrs._expr_nodes.BooleanFunction.IsBetween:
644
711
  column, lo, hi = (
645
712
  translator.translate_expr(n=n, schema=schema) for n in node.input
646
713
  )
@@ -658,19 +725,19 @@ def _(
658
725
  options,
659
726
  *(translator.translate_expr(n=n, schema=schema) for n in node.input),
660
727
  )
661
- elif isinstance(name, pl_expr.TemporalFunction):
728
+ elif isinstance(name, plrs._expr_nodes.TemporalFunction):
662
729
  # functions for which evaluation of the expression may not return
663
730
  # the same dtype as polars, either due to libcudf returning a different
664
731
  # dtype, or due to our internal processing affecting what libcudf returns
665
732
  needs_cast = {
666
- pl_expr.TemporalFunction.Year,
667
- pl_expr.TemporalFunction.Month,
668
- pl_expr.TemporalFunction.Day,
669
- pl_expr.TemporalFunction.WeekDay,
670
- pl_expr.TemporalFunction.Hour,
671
- pl_expr.TemporalFunction.Minute,
672
- pl_expr.TemporalFunction.Second,
673
- pl_expr.TemporalFunction.Millisecond,
733
+ plrs._expr_nodes.TemporalFunction.Year,
734
+ plrs._expr_nodes.TemporalFunction.Month,
735
+ plrs._expr_nodes.TemporalFunction.Day,
736
+ plrs._expr_nodes.TemporalFunction.WeekDay,
737
+ plrs._expr_nodes.TemporalFunction.Hour,
738
+ plrs._expr_nodes.TemporalFunction.Minute,
739
+ plrs._expr_nodes.TemporalFunction.Second,
740
+ plrs._expr_nodes.TemporalFunction.Millisecond,
674
741
  }
675
742
  result_expr = expr.TemporalFunction(
676
743
  dtype,
@@ -679,9 +746,11 @@ def _(
679
746
  *(translator.translate_expr(n=n, schema=schema) for n in node.input),
680
747
  )
681
748
  if name in needs_cast:
682
- return expr.Cast(dtype, result_expr)
749
+ return expr.Cast(dtype, True, result_expr) # noqa: FBT003
683
750
  return result_expr
684
- elif not POLARS_VERSION_LT_131 and isinstance(name, pl_expr.StructFunction):
751
+ elif not POLARS_VERSION_LT_131 and isinstance(
752
+ name, plrs._expr_nodes.StructFunction
753
+ ):
685
754
  return expr.StructFunction(
686
755
  dtype,
687
756
  expr.StructFunction.Name.from_polars(name),
@@ -690,15 +759,38 @@ def _(
690
759
  )
691
760
  elif isinstance(name, str):
692
761
  children = (translator.translate_expr(n=n, schema=schema) for n in node.input)
693
- if name == "log":
694
- (base,) = options
695
- (child,) = children
696
- return expr.BinOp(
697
- dtype,
698
- plc.binaryop.BinaryOperator.LOG_BASE,
699
- child,
700
- expr.Literal(dtype, base),
701
- )
762
+ if name == "log" or (
763
+ not POLARS_VERSION_LT_133
764
+ and name == "l"
765
+ and isinstance(options[0], str)
766
+ and "".join((name, *options)) == "log"
767
+ ):
768
+ if POLARS_VERSION_LT_133: # pragma: no cover
769
+ (base,) = options
770
+ (child,) = children
771
+ return expr.BinOp(
772
+ dtype,
773
+ plc.binaryop.BinaryOperator.LOG_BASE,
774
+ child,
775
+ expr.Literal(dtype, base),
776
+ )
777
+ else:
778
+ (child, base) = children
779
+ res = expr.BinOp(
780
+ dtype,
781
+ plc.binaryop.BinaryOperator.LOG_BASE,
782
+ child,
783
+ expr.Literal(dtype, base.value),
784
+ )
785
+ return (
786
+ res
787
+ if not POLARS_VERSION_LT_134
788
+ else expr.Cast(
789
+ DataType(pl.Float64()),
790
+ True, # noqa: FBT003
791
+ res,
792
+ )
793
+ )
702
794
  elif name == "pow":
703
795
  return expr.BinOp(dtype, plc.binaryop.BinaryOperator.POW, *children)
704
796
  return expr.UnaryFunction(dtype, name, options, *children)
@@ -709,11 +801,15 @@ def _(
709
801
 
710
802
  @_translate_expr.register
711
803
  def _(
712
- node: pl_expr.Window, translator: Translator, dtype: DataType, schema: Schema
804
+ node: plrs._expr_nodes.Window,
805
+ translator: Translator,
806
+ dtype: DataType,
807
+ schema: Schema,
713
808
  ) -> expr.Expr:
714
- if isinstance(node.options, pl_expr.RollingGroupOptions):
809
+ if isinstance(node.options, plrs._expr_nodes.RollingGroupOptions):
715
810
  # pl.col("a").rolling(...)
716
- agg = translator.translate_expr(n=node.function, schema=schema)
811
+ with set_expr_context(translator, ExecutionContext.ROLLING):
812
+ agg = translator.translate_expr(n=node.function, schema=schema)
717
813
  name_generator = unique_names(schema)
718
814
  aggs, named_post_agg = decompose_single_agg(
719
815
  expr.NamedExpr(next(name_generator), agg),
@@ -723,7 +819,7 @@ def _(
723
819
  )
724
820
  named_aggs = [agg for agg, _ in aggs]
725
821
  orderby = node.options.index_column
726
- orderby_dtype = schema[orderby].plc
822
+ orderby_dtype = schema[orderby].plc_type
727
823
  if plc.traits.is_integral(orderby_dtype):
728
824
  # Integer orderby column is cast in implementation to int64 in polars
729
825
  orderby_dtype = plc.DataType(plc.TypeId.INT64)
@@ -752,9 +848,10 @@ def _(
752
848
  for agg in named_aggs
753
849
  }
754
850
  return replace([named_post_agg.value], replacements)[0]
755
- elif isinstance(node.options, pl_expr.WindowMapping):
851
+ elif isinstance(node.options, plrs._expr_nodes.WindowMapping):
756
852
  # pl.col("a").over(...)
757
- agg = translator.translate_expr(n=node.function, schema=schema)
853
+ with set_expr_context(translator, ExecutionContext.WINDOW):
854
+ agg = translator.translate_expr(n=node.function, schema=schema)
758
855
  name_gen = unique_names(schema)
759
856
  aggs, post = decompose_single_agg(
760
857
  expr.NamedExpr(next(name_gen), agg),
@@ -779,20 +876,41 @@ def _(
779
876
  if has_order_by
780
877
  else None
781
878
  )
879
+
880
+ named_aggs = [agg for agg, _ in aggs]
881
+
882
+ by_exprs = [
883
+ translator.translate_expr(n=n, schema=schema) for n in node.partition_by
884
+ ]
885
+
886
+ child_deps = [
887
+ v.children[0]
888
+ for ne in named_aggs
889
+ for v in (ne.value,)
890
+ if isinstance(v, expr.Agg)
891
+ or (
892
+ isinstance(v, expr.UnaryFunction)
893
+ and v.name in {"rank", "fill_null_with_strategy", "cum_sum"}
894
+ )
895
+ ]
896
+ children = (*by_exprs, *((order_by_expr,) if has_order_by else ()), *child_deps)
782
897
  return expr.GroupedRollingWindow(
783
898
  dtype,
784
899
  (mapping, has_order_by, descending, nulls_last),
785
- [agg for agg, _ in aggs],
900
+ named_aggs,
786
901
  post,
787
- *(translator.translate_expr(n=n, schema=schema) for n in node.partition_by),
788
- _order_by_expr=order_by_expr,
902
+ len(by_exprs),
903
+ *children,
789
904
  )
790
905
  assert_never(node.options)
791
906
 
792
907
 
793
908
  @_translate_expr.register
794
909
  def _(
795
- node: pl_expr.Literal, translator: Translator, dtype: DataType, schema: Schema
910
+ node: plrs._expr_nodes.Literal,
911
+ translator: Translator,
912
+ dtype: DataType,
913
+ schema: Schema,
796
914
  ) -> expr.Expr:
797
915
  if isinstance(node.value, plrs.PySeries):
798
916
  return expr.LiteralColumn(dtype, pl.Series._from_pyseries(node.value))
@@ -804,7 +922,7 @@ def _(
804
922
 
805
923
  @_translate_expr.register
806
924
  def _(
807
- node: pl_expr.Sort, translator: Translator, dtype: DataType, schema: Schema
925
+ node: plrs._expr_nodes.Sort, translator: Translator, dtype: DataType, schema: Schema
808
926
  ) -> expr.Expr:
809
927
  # TODO: raise in groupby
810
928
  return expr.Sort(
@@ -814,7 +932,10 @@ def _(
814
932
 
815
933
  @_translate_expr.register
816
934
  def _(
817
- node: pl_expr.SortBy, translator: Translator, dtype: DataType, schema: Schema
935
+ node: plrs._expr_nodes.SortBy,
936
+ translator: Translator,
937
+ dtype: DataType,
938
+ schema: Schema,
818
939
  ) -> expr.Expr:
819
940
  options = node.sort_options
820
941
  return expr.SortBy(
@@ -827,7 +948,10 @@ def _(
827
948
 
828
949
  @_translate_expr.register
829
950
  def _(
830
- node: pl_expr.Slice, translator: Translator, dtype: DataType, schema: Schema
951
+ node: plrs._expr_nodes.Slice,
952
+ translator: Translator,
953
+ dtype: DataType,
954
+ schema: Schema,
831
955
  ) -> expr.Expr:
832
956
  offset = translator.translate_expr(n=node.offset, schema=schema)
833
957
  length = translator.translate_expr(n=node.length, schema=schema)
@@ -843,7 +967,10 @@ def _(
843
967
 
844
968
  @_translate_expr.register
845
969
  def _(
846
- node: pl_expr.Gather, translator: Translator, dtype: DataType, schema: Schema
970
+ node: plrs._expr_nodes.Gather,
971
+ translator: Translator,
972
+ dtype: DataType,
973
+ schema: Schema,
847
974
  ) -> expr.Expr:
848
975
  return expr.Gather(
849
976
  dtype,
@@ -854,7 +981,10 @@ def _(
854
981
 
855
982
  @_translate_expr.register
856
983
  def _(
857
- node: pl_expr.Filter, translator: Translator, dtype: DataType, schema: Schema
984
+ node: plrs._expr_nodes.Filter,
985
+ translator: Translator,
986
+ dtype: DataType,
987
+ schema: Schema,
858
988
  ) -> expr.Expr:
859
989
  return expr.Filter(
860
990
  dtype,
@@ -865,44 +995,70 @@ def _(
865
995
 
866
996
  @_translate_expr.register
867
997
  def _(
868
- node: pl_expr.Cast, translator: Translator, dtype: DataType, schema: Schema
998
+ node: plrs._expr_nodes.Cast, translator: Translator, dtype: DataType, schema: Schema
869
999
  ) -> expr.Expr:
1000
+ # TODO: node.options can be 2 meaning wrap_numerical=True
1001
+ # don't necessarily raise because wrapping isn't always needed, but it's unhandled
1002
+ strict = node.options != 1
870
1003
  inner = translator.translate_expr(n=node.expr, schema=schema)
1004
+
1005
+ if plc.traits.is_floating_point(inner.dtype.plc_type) and plc.traits.is_fixed_point(
1006
+ dtype.plc_type
1007
+ ):
1008
+ return expr.Cast(
1009
+ dtype,
1010
+ strict,
1011
+ expr.UnaryFunction(
1012
+ inner.dtype, "round", (-dtype.plc_type.scale(), "half_to_even"), inner
1013
+ ),
1014
+ )
1015
+
871
1016
  # Push casts into literals so we can handle Cast(Literal(Null))
872
1017
  if isinstance(inner, expr.Literal):
873
1018
  return inner.astype(dtype)
874
- elif isinstance(inner, expr.Cast):
875
- # Translation of Len/Count-agg put in a cast, remove double
876
- # casts if we have one.
877
- (inner,) = inner.children
878
- return expr.Cast(dtype, inner)
1019
+ else:
1020
+ return expr.Cast(dtype, strict, inner)
879
1021
 
880
1022
 
881
1023
  @_translate_expr.register
882
1024
  def _(
883
- node: pl_expr.Column, translator: Translator, dtype: DataType, schema: Schema
1025
+ node: plrs._expr_nodes.Column,
1026
+ translator: Translator,
1027
+ dtype: DataType,
1028
+ schema: Schema,
884
1029
  ) -> expr.Expr:
885
1030
  return expr.Col(dtype, node.name)
886
1031
 
887
1032
 
888
1033
  @_translate_expr.register
889
1034
  def _(
890
- node: pl_expr.Agg, translator: Translator, dtype: DataType, schema: Schema
1035
+ node: plrs._expr_nodes.Agg, translator: Translator, dtype: DataType, schema: Schema
891
1036
  ) -> expr.Expr:
892
- value = expr.Agg(
893
- dtype,
894
- node.name,
895
- node.options,
896
- *(translator.translate_expr(n=n, schema=schema) for n in node.arguments),
897
- )
898
- if value.name in ("count", "n_unique") and value.dtype.id() != plc.TypeId.INT32:
899
- return expr.Cast(value.dtype, value)
1037
+ agg_name = node.name
1038
+ args = [translator.translate_expr(n=arg, schema=schema) for arg in node.arguments]
1039
+
1040
+ if agg_name not in ("count", "n_unique", "mean", "median", "quantile"):
1041
+ args = [
1042
+ expr.Cast(dtype, True, arg) # noqa: FBT003
1043
+ if plc.traits.is_fixed_point(arg.dtype.plc_type)
1044
+ and arg.dtype.plc_type != dtype.plc_type
1045
+ else arg
1046
+ for arg in args
1047
+ ]
1048
+
1049
+ value = expr.Agg(dtype, agg_name, node.options, translator._expr_context, *args)
1050
+
1051
+ if agg_name in ("count", "n_unique") and value.dtype.id() != plc.TypeId.INT32:
1052
+ return expr.Cast(value.dtype, True, value) # noqa: FBT003
900
1053
  return value
901
1054
 
902
1055
 
903
1056
  @_translate_expr.register
904
1057
  def _(
905
- node: pl_expr.Ternary, translator: Translator, dtype: DataType, schema: Schema
1058
+ node: plrs._expr_nodes.Ternary,
1059
+ translator: Translator,
1060
+ dtype: DataType,
1061
+ schema: Schema,
906
1062
  ) -> expr.Expr:
907
1063
  return expr.Ternary(
908
1064
  dtype,
@@ -914,26 +1070,70 @@ def _(
914
1070
 
915
1071
  @_translate_expr.register
916
1072
  def _(
917
- node: pl_expr.BinaryExpr,
1073
+ node: plrs._expr_nodes.BinaryExpr,
918
1074
  translator: Translator,
919
1075
  dtype: DataType,
920
1076
  schema: Schema,
921
1077
  ) -> expr.Expr:
922
- if plc.traits.is_boolean(dtype.plc) and node.op == pl_expr.Operator.TrueDivide:
923
- dtype = DataType(pl.Float64())
1078
+ left = translator.translate_expr(n=node.left, schema=schema)
1079
+ right = translator.translate_expr(n=node.right, schema=schema)
1080
+ if (
1081
+ POLARS_VERSION_LT_133
1082
+ and plc.traits.is_boolean(dtype.plc_type)
1083
+ and node.op == plrs._expr_nodes.Operator.TrueDivide
1084
+ ):
1085
+ dtype = DataType(pl.Float64()) # pragma: no cover
1086
+ if node.op == plrs._expr_nodes.Operator.TrueDivide and (
1087
+ plc.traits.is_fixed_point(left.dtype.plc_type)
1088
+ or plc.traits.is_fixed_point(right.dtype.plc_type)
1089
+ ):
1090
+ f64 = DataType(pl.Float64())
1091
+ return expr.Cast(
1092
+ dtype,
1093
+ True, # noqa: FBT003
1094
+ expr.BinOp(
1095
+ f64,
1096
+ expr.BinOp._MAPPING[node.op],
1097
+ expr.Cast(f64, True, left), # noqa: FBT003
1098
+ expr.Cast(f64, True, right), # noqa: FBT003
1099
+ ),
1100
+ )
1101
+
1102
+ if (
1103
+ not POLARS_VERSION_LT_134
1104
+ and node.op == plrs._expr_nodes.Operator.Multiply
1105
+ and plc.traits.is_fixed_point(left.dtype.plc_type)
1106
+ and plc.traits.is_fixed_point(right.dtype.plc_type)
1107
+ ):
1108
+ left_scale = -left.dtype.plc_type.scale()
1109
+ right_scale = -right.dtype.plc_type.scale()
1110
+ out_scale = max(left_scale, right_scale)
1111
+
1112
+ return expr.UnaryFunction(
1113
+ DataType(pl.Decimal(38, out_scale)),
1114
+ "round",
1115
+ (out_scale, "half_to_even"),
1116
+ expr.BinOp(
1117
+ DataType(pl.Decimal(38, left_scale + right_scale)),
1118
+ expr.BinOp._MAPPING[node.op],
1119
+ left,
1120
+ right,
1121
+ ),
1122
+ )
1123
+
924
1124
  return expr.BinOp(
925
1125
  dtype,
926
1126
  expr.BinOp._MAPPING[node.op],
927
- translator.translate_expr(n=node.left, schema=schema),
928
- translator.translate_expr(n=node.right, schema=schema),
1127
+ left,
1128
+ right,
929
1129
  )
930
1130
 
931
1131
 
932
1132
  @_translate_expr.register
933
1133
  def _(
934
- node: pl_expr.Len, translator: Translator, dtype: DataType, schema: Schema
1134
+ node: plrs._expr_nodes.Len, translator: Translator, dtype: DataType, schema: Schema
935
1135
  ) -> expr.Expr:
936
1136
  value = expr.Len(dtype)
937
1137
  if dtype.id() != plc.TypeId.INT32:
938
- return expr.Cast(dtype, value)
1138
+ return expr.Cast(dtype, True, value) # noqa: FBT003
939
1139
  return value # pragma: no cover; never reached since polars len has uint32 dtype