cudf_polars_cu13-25.10.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. cudf_polars/GIT_COMMIT +1 -0
  2. cudf_polars/VERSION +1 -0
  3. cudf_polars/__init__.py +28 -0
  4. cudf_polars/_version.py +21 -0
  5. cudf_polars/callback.py +318 -0
  6. cudf_polars/containers/__init__.py +13 -0
  7. cudf_polars/containers/column.py +495 -0
  8. cudf_polars/containers/dataframe.py +361 -0
  9. cudf_polars/containers/datatype.py +137 -0
  10. cudf_polars/dsl/__init__.py +8 -0
  11. cudf_polars/dsl/expr.py +66 -0
  12. cudf_polars/dsl/expressions/__init__.py +8 -0
  13. cudf_polars/dsl/expressions/aggregation.py +226 -0
  14. cudf_polars/dsl/expressions/base.py +272 -0
  15. cudf_polars/dsl/expressions/binaryop.py +120 -0
  16. cudf_polars/dsl/expressions/boolean.py +326 -0
  17. cudf_polars/dsl/expressions/datetime.py +271 -0
  18. cudf_polars/dsl/expressions/literal.py +97 -0
  19. cudf_polars/dsl/expressions/rolling.py +643 -0
  20. cudf_polars/dsl/expressions/selection.py +74 -0
  21. cudf_polars/dsl/expressions/slicing.py +46 -0
  22. cudf_polars/dsl/expressions/sorting.py +85 -0
  23. cudf_polars/dsl/expressions/string.py +1002 -0
  24. cudf_polars/dsl/expressions/struct.py +137 -0
  25. cudf_polars/dsl/expressions/ternary.py +49 -0
  26. cudf_polars/dsl/expressions/unary.py +517 -0
  27. cudf_polars/dsl/ir.py +2607 -0
  28. cudf_polars/dsl/nodebase.py +164 -0
  29. cudf_polars/dsl/to_ast.py +359 -0
  30. cudf_polars/dsl/tracing.py +16 -0
  31. cudf_polars/dsl/translate.py +939 -0
  32. cudf_polars/dsl/traversal.py +224 -0
  33. cudf_polars/dsl/utils/__init__.py +8 -0
  34. cudf_polars/dsl/utils/aggregations.py +481 -0
  35. cudf_polars/dsl/utils/groupby.py +98 -0
  36. cudf_polars/dsl/utils/naming.py +34 -0
  37. cudf_polars/dsl/utils/replace.py +61 -0
  38. cudf_polars/dsl/utils/reshape.py +74 -0
  39. cudf_polars/dsl/utils/rolling.py +121 -0
  40. cudf_polars/dsl/utils/windows.py +192 -0
  41. cudf_polars/experimental/__init__.py +8 -0
  42. cudf_polars/experimental/base.py +386 -0
  43. cudf_polars/experimental/benchmarks/__init__.py +4 -0
  44. cudf_polars/experimental/benchmarks/pdsds.py +220 -0
  45. cudf_polars/experimental/benchmarks/pdsds_queries/__init__.py +4 -0
  46. cudf_polars/experimental/benchmarks/pdsds_queries/q1.py +88 -0
  47. cudf_polars/experimental/benchmarks/pdsds_queries/q10.py +225 -0
  48. cudf_polars/experimental/benchmarks/pdsds_queries/q2.py +244 -0
  49. cudf_polars/experimental/benchmarks/pdsds_queries/q3.py +65 -0
  50. cudf_polars/experimental/benchmarks/pdsds_queries/q4.py +359 -0
  51. cudf_polars/experimental/benchmarks/pdsds_queries/q5.py +462 -0
  52. cudf_polars/experimental/benchmarks/pdsds_queries/q6.py +92 -0
  53. cudf_polars/experimental/benchmarks/pdsds_queries/q7.py +79 -0
  54. cudf_polars/experimental/benchmarks/pdsds_queries/q8.py +524 -0
  55. cudf_polars/experimental/benchmarks/pdsds_queries/q9.py +137 -0
  56. cudf_polars/experimental/benchmarks/pdsh.py +814 -0
  57. cudf_polars/experimental/benchmarks/utils.py +832 -0
  58. cudf_polars/experimental/dask_registers.py +200 -0
  59. cudf_polars/experimental/dispatch.py +156 -0
  60. cudf_polars/experimental/distinct.py +197 -0
  61. cudf_polars/experimental/explain.py +157 -0
  62. cudf_polars/experimental/expressions.py +590 -0
  63. cudf_polars/experimental/groupby.py +327 -0
  64. cudf_polars/experimental/io.py +943 -0
  65. cudf_polars/experimental/join.py +391 -0
  66. cudf_polars/experimental/parallel.py +423 -0
  67. cudf_polars/experimental/repartition.py +69 -0
  68. cudf_polars/experimental/scheduler.py +155 -0
  69. cudf_polars/experimental/select.py +188 -0
  70. cudf_polars/experimental/shuffle.py +354 -0
  71. cudf_polars/experimental/sort.py +609 -0
  72. cudf_polars/experimental/spilling.py +151 -0
  73. cudf_polars/experimental/statistics.py +795 -0
  74. cudf_polars/experimental/utils.py +169 -0
  75. cudf_polars/py.typed +0 -0
  76. cudf_polars/testing/__init__.py +8 -0
  77. cudf_polars/testing/asserts.py +448 -0
  78. cudf_polars/testing/io.py +122 -0
  79. cudf_polars/testing/plugin.py +236 -0
  80. cudf_polars/typing/__init__.py +219 -0
  81. cudf_polars/utils/__init__.py +8 -0
  82. cudf_polars/utils/config.py +741 -0
  83. cudf_polars/utils/conversion.py +40 -0
  84. cudf_polars/utils/dtypes.py +118 -0
  85. cudf_polars/utils/sorting.py +53 -0
  86. cudf_polars/utils/timer.py +39 -0
  87. cudf_polars/utils/versions.py +27 -0
  88. cudf_polars_cu13-25.10.0.dist-info/METADATA +136 -0
  89. cudf_polars_cu13-25.10.0.dist-info/RECORD +92 -0
  90. cudf_polars_cu13-25.10.0.dist-info/WHEEL +5 -0
  91. cudf_polars_cu13-25.10.0.dist-info/licenses/LICENSE +201 -0
  92. cudf_polars_cu13-25.10.0.dist-info/top_level.txt +1 -0
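Entry 31 above, cudf_polars/dsl/translate.py, is shown in full below; it converts polars' internal IR into this package's own IR. End users normally exercise this package only indirectly, through polars' GPU engine. A minimal usage sketch (assuming a CUDA-capable GPU and a polars version compatible with this wheel; the `device` and `raise_on_fail` arguments are optional):

import polars as pl

# Build a lazy query as usual; nothing cudf-specific is needed here.
lf = (
    pl.LazyFrame({"key": ["a", "b", "a"], "value": [1, 2, 3]})
    .group_by("key")
    .agg(pl.col("value").sum())
)

# Collect on the GPU. cudf-polars is picked up automatically when installed;
# raise_on_fail=True surfaces unsupported queries instead of silently falling
# back to the CPU engine.
result = lf.collect(engine=pl.GPUEngine(device=0, raise_on_fail=True))
print(result)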
cudf_polars/dsl/translate.py
@@ -0,0 +1,939 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Translate polars IR representation to ours."""
+
+from __future__ import annotations
+
+import functools
+import json
+from contextlib import AbstractContextManager, nullcontext
+from functools import singledispatch
+from typing import TYPE_CHECKING, Any
+
+from typing_extensions import assert_never
+
+import polars as pl
+import polars.polars as plrs
+from polars.polars import _expr_nodes as pl_expr, _ir_nodes as pl_ir
+
+import pylibcudf as plc
+
+from cudf_polars.containers import DataType
+from cudf_polars.dsl import expr, ir
+from cudf_polars.dsl.expressions.base import ExecutionContext
+from cudf_polars.dsl.to_ast import insert_colrefs
+from cudf_polars.dsl.utils.aggregations import decompose_single_agg
+from cudf_polars.dsl.utils.groupby import rewrite_groupby
+from cudf_polars.dsl.utils.naming import unique_names
+from cudf_polars.dsl.utils.replace import replace
+from cudf_polars.dsl.utils.rolling import rewrite_rolling
+from cudf_polars.typing import Schema
+from cudf_polars.utils import config, sorting
+from cudf_polars.utils.versions import (
+    POLARS_VERSION_LT_131,
+    POLARS_VERSION_LT_132,
+    POLARS_VERSION_LT_1323,
+)
+
+if TYPE_CHECKING:
+    from polars import GPUEngine
+
+    from cudf_polars.typing import NodeTraverser
+
+__all__ = ["Translator", "translate_named_expr"]
+
+
+class Translator:
+    """
+    Translates polars-internal IR nodes and expressions to our representation.
+
+    Parameters
+    ----------
+    visitor
+        Polars NodeTraverser object
+    engine
+        GPU engine configuration.
+    """
+
+    def __init__(self, visitor: NodeTraverser, engine: GPUEngine):
+        self.visitor = visitor
+        self.config_options = config.ConfigOptions.from_polars_engine(engine)
+        self.errors: list[Exception] = []
+        self._cache_nodes: dict[int, ir.Cache] = {}
+
+    def translate_ir(self, *, n: int | None = None) -> ir.IR:
+        """
+        Translate a polars-internal IR node to our representation.
+
+        Parameters
+        ----------
+        visitor
+            Polars NodeTraverser object
+        n
+            Optional node to start traversing from, if not provided uses
+            current polars-internal node.
+
+        Returns
+        -------
+        Translated IR object
+
+        Raises
+        ------
+        NotImplementedError
+            If the version of Polars IR is unsupported.
+
+        Notes
+        -----
+        Any expression nodes that cannot be translated are replaced by
+        :class:`expr.ErrorNode` nodes and collected in the `errors` attribute.
+        After translation is complete, this list of errors should be inspected
+        to determine if the query is supported.
+        """
+        ctx: AbstractContextManager[None] = (
+            set_node(self.visitor, n) if n is not None else noop_context
+        )
+        # IR is versioned with major.minor, minor is bumped for backwards
+        # compatible changes (e.g. adding new nodes), major is bumped for
+        # incompatible changes (e.g. renaming nodes).
+        if (version := self.visitor.version()) >= (10, 1):
+            e = NotImplementedError(
+                f"No support for polars IR {version=}"
+            )  # pragma: no cover; no such version for now.
+            self.errors.append(e)  # pragma: no cover
+            raise e  # pragma: no cover
+
+        with ctx:
+            polars_schema = self.visitor.get_schema()
+            try:
+                schema = {k: DataType(v) for k, v in polars_schema.items()}
+            except Exception as e:
+                self.errors.append(NotImplementedError(str(e)))
+                return ir.ErrorNode({}, str(e))
+            try:
+                node = self.visitor.view_current_node()
+            except Exception as e:
+                self.errors.append(e)
+                return ir.ErrorNode(schema, str(e))
+            try:
+                result = _translate_ir(node, self, schema)
+            except Exception as e:
+                self.errors.append(e)
+                return ir.ErrorNode(schema, str(e))
+            if any(
+                isinstance(dtype, pl.Null)
+                for dtype in pl.datatypes.unpack_dtypes(*polars_schema.values())
+            ):
+                error = NotImplementedError(
+                    f"No GPU support for {result} with Null column dtype."
+                )
+                self.errors.append(error)
+                return ir.ErrorNode(schema, str(error))
+
+            return result
+
+    def translate_expr(self, *, n: int, schema: Schema) -> expr.Expr:
+        """
+        Translate a polars-internal expression IR into our representation.
+
+        Parameters
+        ----------
+        n
+            Node to translate, an integer referencing a polars internal node.
+        schema
+            Schema of the IR node this expression uses as evaluation context.
+
+        Returns
+        -------
+        Translated IR object.
+
+        Notes
+        -----
+        Any expression nodes that cannot be translated are replaced by
+        :class:`expr.ErrorExpr` nodes and collected in the `errors` attribute.
+        After translation is complete, this list of errors should be inspected
+        to determine if the query is supported.
+        """
+        node = self.visitor.view_expression(n)
+        dtype = DataType(self.visitor.get_dtype(n))
+        try:
+            return _translate_expr(node, self, dtype, schema)
+        except Exception as e:
+            self.errors.append(e)
+            return expr.ErrorExpr(dtype, str(e))
+
+
+class set_node(AbstractContextManager[None]):
+    """
+    Run a block with current node set in the visitor.
+
+    Parameters
+    ----------
+    visitor
+        The internal Rust visitor object
+    n
+        The node to set as the current root.
+
+    Notes
+    -----
+    This is useful for translating expressions with a given node
+    active, restoring the node when the block exits.
+    """
+
+    __slots__ = ("n", "visitor")
+    visitor: NodeTraverser
+
+    n: int
+
+    def __init__(self, visitor: NodeTraverser, n: int) -> None:
+        self.visitor = visitor
+        self.n = n
+
+    def __enter__(self) -> None:
+        n = self.visitor.get_node()
+        self.visitor.set_node(self.n)
+        self.n = n
+
+    def __exit__(self, *args: Any) -> None:
+        self.visitor.set_node(self.n)
+
+
+noop_context: nullcontext[None] = nullcontext()
+
+
+@singledispatch
+def _translate_ir(node: Any, translator: Translator, schema: Schema) -> ir.IR:
+    raise NotImplementedError(
+        f"Translation for {type(node).__name__}"
+    )  # pragma: no cover
+
+
+@_translate_ir.register
+def _(node: pl_ir.PythonScan, translator: Translator, schema: Schema) -> ir.IR:
+    scan_fn, with_columns, source_type, predicate, nrows = node.options
+    options = (scan_fn, with_columns, source_type, nrows)
+    predicate = (
+        translate_named_expr(translator, n=predicate, schema=schema)
+        if predicate is not None
+        else None
+    )
+    return ir.PythonScan(schema, options, predicate)
+
+
+@_translate_ir.register
+def _(node: pl_ir.Scan, translator: Translator, schema: Schema) -> ir.IR:
+    typ, *options = node.scan_type
+    paths = node.paths
+    # Polars can produce a Scan with an empty ``node.paths`` (eg. the native
+    # Iceberg reader on a table with no data files yet). In this case, polars returns an
+    # empty DataFrame with the declared schema. Mirror that here by
+    # replacing the Scan with an Empty IR node.
+    if not paths:  # pragma: no cover
+        return ir.Empty(schema)
+    if typ == "ndjson":
+        (reader_options,) = map(json.loads, options)
+        cloud_options = None
+    else:
+        reader_options, cloud_options = map(json.loads, options)
+    file_options = node.file_options
+    with_columns = file_options.with_columns
+    row_index = file_options.row_index
+    include_file_paths = file_options.include_file_paths
+    if not POLARS_VERSION_LT_131:
+        deletion_files = file_options.deletion_files  # pragma: no cover
+        if deletion_files:  # pragma: no cover
+            raise NotImplementedError(
+                "Iceberg format is not supported in cudf-polars. Furthermore, row-level deletions are not supported."
+            )  # pragma: no cover
+    config_options = translator.config_options
+    parquet_options = config_options.parquet_options
+
+    pre_slice = file_options.n_rows
+    if pre_slice is None:
+        n_rows = -1
+        skip_rows = 0
+    else:
+        skip_rows, n_rows = pre_slice
+
+    return ir.Scan(
+        schema,
+        typ,
+        reader_options,
+        cloud_options,
+        paths,
+        with_columns,
+        skip_rows,
+        n_rows,
+        row_index,
+        include_file_paths,
+        translate_named_expr(translator, n=node.predicate, schema=schema)
+        if node.predicate is not None
+        else None,
+        parquet_options,
+    )
+
+
+@_translate_ir.register
+def _(node: pl_ir.Cache, translator: Translator, schema: Schema) -> ir.IR:
+    if POLARS_VERSION_LT_1323:  # pragma: no cover
+        refcount = node.cache_hits
+    else:
+        refcount = None
+
+    # Make sure Cache nodes with the same id_
+    # are actually the same object.
+    if node.id_ not in translator._cache_nodes:
+        translator._cache_nodes[node.id_] = ir.Cache(
+            schema,
+            node.id_,
+            refcount,
+            translator.translate_ir(n=node.input),
+        )
+    return translator._cache_nodes[node.id_]
+
+
+@_translate_ir.register
+def _(node: pl_ir.DataFrameScan, translator: Translator, schema: Schema) -> ir.IR:
+    return ir.DataFrameScan(
+        schema,
+        node.df,
+        node.projection,
+    )
+
+
+@_translate_ir.register
+def _(node: pl_ir.Select, translator: Translator, schema: Schema) -> ir.IR:
+    with set_node(translator.visitor, node.input):
+        inp = translator.translate_ir(n=None)
+        exprs = [
+            translate_named_expr(translator, n=e, schema=inp.schema) for e in node.expr
+        ]
+    return ir.Select(schema, exprs, node.should_broadcast, inp)
+
+
+@_translate_ir.register
+def _(node: pl_ir.GroupBy, translator: Translator, schema: Schema) -> ir.IR:
+    with set_node(translator.visitor, node.input):
+        inp = translator.translate_ir(n=None)
+        keys = [
+            translate_named_expr(translator, n=e, schema=inp.schema) for e in node.keys
+        ]
+        original_aggs = [
+            translate_named_expr(translator, n=e, schema=inp.schema) for e in node.aggs
+        ]
+    is_rolling = node.options.rolling is not None
+    is_dynamic = node.options.dynamic is not None
+    if is_dynamic:
+        raise NotImplementedError("group_by_dynamic")
+    elif is_rolling:
+        return rewrite_rolling(
+            node.options, schema, keys, original_aggs, translator.config_options, inp
+        )
+    else:
+        return rewrite_groupby(node, schema, keys, original_aggs, inp)
+
+
+@_translate_ir.register
+def _(node: pl_ir.Join, translator: Translator, schema: Schema) -> ir.IR:
+    # Join key dtypes are dependent on the schema of the left and
+    # right inputs, so these must be translated with the relevant
+    # input active.
+    with set_node(translator.visitor, node.input_left):
+        inp_left = translator.translate_ir(n=None)
+        left_on = [
+            translate_named_expr(translator, n=e, schema=inp_left.schema)
+            for e in node.left_on
+        ]
+    with set_node(translator.visitor, node.input_right):
+        inp_right = translator.translate_ir(n=None)
+        right_on = [
+            translate_named_expr(translator, n=e, schema=inp_right.schema)
+            for e in node.right_on
+        ]
+
+    if (how := node.options[0]) in {
+        "Inner",
+        "Left",
+        "Right",
+        "Full",
+        "Cross",
+        "Semi",
+        "Anti",
+    }:
+        return ir.Join(
+            schema,
+            left_on,
+            right_on,
+            node.options,
+            inp_left,
+            inp_right,
+        )
+    else:
+        how, op1, op2 = node.options[0]
+        if how != "IEJoin":
+            raise NotImplementedError(
+                f"Unsupported join type {how}"
+            )  # pragma: no cover; asof joins not yet exposed
+        if op2 is None:
+            ops = [op1]
+        else:
+            ops = [op1, op2]
+
+        dtype = DataType(pl.datatypes.Boolean())
+        predicate = functools.reduce(
+            functools.partial(
+                expr.BinOp, dtype, plc.binaryop.BinaryOperator.LOGICAL_AND
+            ),
+            (
+                expr.BinOp(
+                    dtype,
+                    expr.BinOp._MAPPING[op],
+                    insert_colrefs(
+                        left.value,
+                        table_ref=plc.expressions.TableReference.LEFT,
+                        name_to_index={
+                            name: i for i, name in enumerate(inp_left.schema)
+                        },
+                    ),
+                    insert_colrefs(
+                        right.value,
+                        table_ref=plc.expressions.TableReference.RIGHT,
+                        name_to_index={
+                            name: i for i, name in enumerate(inp_right.schema)
+                        },
+                    ),
+                )
+                for op, left, right in zip(ops, left_on, right_on, strict=True)
+            ),
+        )
+
+        return ir.ConditionalJoin(schema, predicate, node.options, inp_left, inp_right)
+
+
+@_translate_ir.register
+def _(node: pl_ir.HStack, translator: Translator, schema: Schema) -> ir.IR:
+    with set_node(translator.visitor, node.input):
+        inp = translator.translate_ir(n=None)
+        exprs = [
+            translate_named_expr(translator, n=e, schema=inp.schema) for e in node.exprs
+        ]
+    return ir.HStack(schema, exprs, node.should_broadcast, inp)
+
+
+@_translate_ir.register
+def _(
+    node: pl_ir.Reduce, translator: Translator, schema: Schema
+) -> ir.IR:  # pragma: no cover; polars doesn't emit this node yet
+    with set_node(translator.visitor, node.input):
+        inp = translator.translate_ir(n=None)
+        exprs = [
+            translate_named_expr(translator, n=e, schema=inp.schema) for e in node.expr
+        ]
+    return ir.Reduce(schema, exprs, inp)
+
+
+@_translate_ir.register
+def _(node: pl_ir.Distinct, translator: Translator, schema: Schema) -> ir.IR:
+    (keep, subset, maintain_order, zlice) = node.options
+    keep = ir.Distinct._KEEP_MAP[keep]
+    subset = frozenset(subset) if subset is not None else None
+    return ir.Distinct(
+        schema,
+        keep,
+        subset,
+        zlice,
+        maintain_order,
+        translator.translate_ir(n=node.input),
+    )
+
+
+@_translate_ir.register
+def _(node: pl_ir.Sort, translator: Translator, schema: Schema) -> ir.IR:
+    with set_node(translator.visitor, node.input):
+        inp = translator.translate_ir(n=None)
+        by = [
+            translate_named_expr(translator, n=e, schema=inp.schema)
+            for e in node.by_column
+        ]
+    stable, nulls_last, descending = node.sort_options
+    order, null_order = sorting.sort_order(
+        descending, nulls_last=nulls_last, num_keys=len(by)
+    )
+    return ir.Sort(schema, by, order, null_order, stable, node.slice, inp)
+
+
+@_translate_ir.register
+def _(node: pl_ir.Slice, translator: Translator, schema: Schema) -> ir.IR:
+    return ir.Slice(
+        schema, node.offset, node.len, translator.translate_ir(n=node.input)
+    )
+
+
+@_translate_ir.register
+def _(node: pl_ir.Filter, translator: Translator, schema: Schema) -> ir.IR:
+    with set_node(translator.visitor, node.input):
+        inp = translator.translate_ir(n=None)
+        mask = translate_named_expr(translator, n=node.predicate, schema=inp.schema)
+    return ir.Filter(schema, mask, inp)
+
+
+@_translate_ir.register
+def _(node: pl_ir.SimpleProjection, translator: Translator, schema: Schema) -> ir.IR:
+    return ir.Projection(schema, translator.translate_ir(n=node.input))
+
+
+@_translate_ir.register
+def _(node: pl_ir.MergeSorted, translator: Translator, schema: Schema) -> ir.IR:
+    key = node.key
+    inp_left = translator.translate_ir(n=node.input_left)
+    inp_right = translator.translate_ir(n=node.input_right)
+    return ir.MergeSorted(
+        schema,
+        key,
+        inp_left,
+        inp_right,
+    )
+
+
+@_translate_ir.register
+def _(node: pl_ir.MapFunction, translator: Translator, schema: Schema) -> ir.IR:
+    name, *options = node.function
+    return ir.MapFunction(
+        schema,
+        name,
+        options,
+        translator.translate_ir(n=node.input),
+    )
+
+
+@_translate_ir.register
+def _(node: pl_ir.Union, translator: Translator, schema: Schema) -> ir.IR:
+    return ir.Union(
+        schema, node.options, *(translator.translate_ir(n=n) for n in node.inputs)
+    )
+
+
+@_translate_ir.register
+def _(node: pl_ir.HConcat, translator: Translator, schema: Schema) -> ir.IR:
+    return ir.HConcat(
+        schema,
+        False,  # noqa: FBT003
+        *(translator.translate_ir(n=n) for n in node.inputs),
+    )
+
+
+@_translate_ir.register
+def _(node: pl_ir.Sink, translator: Translator, schema: Schema) -> ir.IR:
+    payload = json.loads(node.payload)
+    try:
+        file = payload["File"]
+        sink_kind_options = file["file_type"]
+    except KeyError as err:  # pragma: no cover
+        raise NotImplementedError("Unsupported payload structure") from err
+    if isinstance(sink_kind_options, dict):
+        if len(sink_kind_options) != 1:  # pragma: no cover; not sure if this can happen
+            raise NotImplementedError("Sink options dict with more than one entry.")
+        sink_kind, options = next(iter(sink_kind_options.items()))
+    else:
+        raise NotImplementedError(
+            "Unsupported sink options structure"
+        )  # pragma: no cover
+
+    sink_options = file.get("sink_options", {})
+    cloud_options = file.get("cloud_options")
+
+    options.update(sink_options)
+
+    return ir.Sink(
+        schema=schema,
+        kind=sink_kind,
+        path=file["target"] if POLARS_VERSION_LT_132 else file["target"]["Local"],
+        parquet_options=translator.config_options.parquet_options,
+        options=options,
+        cloud_options=cloud_options,
+        df=translator.translate_ir(n=node.input),
+    )
+
+
+def translate_named_expr(
+    translator: Translator, *, n: pl_expr.PyExprIR, schema: Schema
+) -> expr.NamedExpr:
+    """
+    Translate a polars-internal named expression IR object into our representation.
+
+    Parameters
+    ----------
+    translator
+        Translator object
+    n
+        Node to translate, a named expression node.
+    schema
+        Schema of the IR node this expression uses as evaluation context.
+
+    Returns
+    -------
+    Translated IR object.
+
+    Notes
+    -----
+    The datatype of the internal expression will be obtained from the
+    visitor by calling ``get_dtype``, for this to work properly, the
+    caller should arrange that the expression is translated with the
+    node that it references "active" for the visitor (see :class:`set_node`).
+
+    Raises
+    ------
+    NotImplementedError
+        If any translation fails due to unsupported functionality.
+    """
+    return expr.NamedExpr(
+        n.output_name, translator.translate_expr(n=n.node, schema=schema)
+    )
+
+
+@singledispatch
+def _translate_expr(
+    node: Any, translator: Translator, dtype: DataType, schema: Schema
+) -> expr.Expr:
+    raise NotImplementedError(
+        f"Translation for {type(node).__name__}"
+    )  # pragma: no cover
+
+
+@_translate_expr.register
+def _(
+    node: pl_expr.Function, translator: Translator, dtype: DataType, schema: Schema
+) -> expr.Expr:
+    name, *options = node.function_data
+    options = tuple(options)
+    if isinstance(name, pl_expr.StringFunction):
+        if name in {
+            pl_expr.StringFunction.StripChars,
+            pl_expr.StringFunction.StripCharsStart,
+            pl_expr.StringFunction.StripCharsEnd,
+        }:
+            column, chars = (
+                translator.translate_expr(n=n, schema=schema) for n in node.input
+            )
+            if isinstance(chars, expr.Literal):
+                # We check for null first because we want to use the
+                # chars type, but it is invalid to try and
+                # produce a string scalar with a null dtype.
+                if chars.value is None:
+                    # Polars uses None to mean "strip all whitespace"
+                    chars = expr.Literal(column.dtype, "")
+                elif chars.value == "":
+                    # No-op in polars, but libcudf uses empty string
+                    # as signifier to remove whitespace.
+                    return column
+            return expr.StringFunction(
+                dtype,
+                expr.StringFunction.Name.from_polars(name),
+                options,
+                column,
+                chars,
+            )
+        return expr.StringFunction(
+            dtype,
+            expr.StringFunction.Name.from_polars(name),
+            options,
+            *(translator.translate_expr(n=n, schema=schema) for n in node.input),
+        )
+    elif isinstance(name, pl_expr.BooleanFunction):
+        if name == pl_expr.BooleanFunction.IsBetween:
+            column, lo, hi = (
+                translator.translate_expr(n=n, schema=schema) for n in node.input
+            )
+            (closed,) = options
+            lop, rop = expr.BooleanFunction._BETWEEN_OPS[closed]
+            return expr.BinOp(
+                dtype,
+                plc.binaryop.BinaryOperator.LOGICAL_AND,
+                expr.BinOp(dtype, lop, column, lo),
+                expr.BinOp(dtype, rop, column, hi),
+            )
+        return expr.BooleanFunction(
+            dtype,
+            expr.BooleanFunction.Name.from_polars(name),
+            options,
+            *(translator.translate_expr(n=n, schema=schema) for n in node.input),
+        )
+    elif isinstance(name, pl_expr.TemporalFunction):
+        # functions for which evaluation of the expression may not return
+        # the same dtype as polars, either due to libcudf returning a different
+        # dtype, or due to our internal processing affecting what libcudf returns
+        needs_cast = {
+            pl_expr.TemporalFunction.Year,
+            pl_expr.TemporalFunction.Month,
+            pl_expr.TemporalFunction.Day,
+            pl_expr.TemporalFunction.WeekDay,
+            pl_expr.TemporalFunction.Hour,
+            pl_expr.TemporalFunction.Minute,
+            pl_expr.TemporalFunction.Second,
+            pl_expr.TemporalFunction.Millisecond,
+        }
+        result_expr = expr.TemporalFunction(
+            dtype,
+            expr.TemporalFunction.Name.from_polars(name),
+            options,
+            *(translator.translate_expr(n=n, schema=schema) for n in node.input),
+        )
+        if name in needs_cast:
+            return expr.Cast(dtype, result_expr)
+        return result_expr
+    elif not POLARS_VERSION_LT_131 and isinstance(name, pl_expr.StructFunction):
+        return expr.StructFunction(
+            dtype,
+            expr.StructFunction.Name.from_polars(name),
+            options,
+            *(translator.translate_expr(n=n, schema=schema) for n in node.input),
+        )
+    elif isinstance(name, str):
+        children = (translator.translate_expr(n=n, schema=schema) for n in node.input)
+        if name == "log":
+            (base,) = options
+            (child,) = children
+            return expr.BinOp(
+                dtype,
+                plc.binaryop.BinaryOperator.LOG_BASE,
+                child,
+                expr.Literal(dtype, base),
+            )
+        elif name == "pow":
+            return expr.BinOp(dtype, plc.binaryop.BinaryOperator.POW, *children)
+        return expr.UnaryFunction(dtype, name, options, *children)
+    raise NotImplementedError(
+        f"No handler for Expr function node with {name=}"
+    )  # pragma: no cover; polars raises on the rust side for now
+
+
+@_translate_expr.register
+def _(
+    node: pl_expr.Window, translator: Translator, dtype: DataType, schema: Schema
+) -> expr.Expr:
+    if isinstance(node.options, pl_expr.RollingGroupOptions):
+        # pl.col("a").rolling(...)
+        agg = translator.translate_expr(n=node.function, schema=schema)
+        name_generator = unique_names(schema)
+        aggs, named_post_agg = decompose_single_agg(
+            expr.NamedExpr(next(name_generator), agg),
+            name_generator,
+            is_top=True,
+            context=ExecutionContext.ROLLING,
+        )
+        named_aggs = [agg for agg, _ in aggs]
+        orderby = node.options.index_column
+        orderby_dtype = schema[orderby].plc
+        if plc.traits.is_integral(orderby_dtype):
+            # Integer orderby column is cast in implementation to int64 in polars
+            orderby_dtype = plc.DataType(plc.TypeId.INT64)
+        closed_window = node.options.closed_window
+        if isinstance(named_post_agg.value, expr.Col):
+            (named_agg,) = named_aggs
+            return expr.RollingWindow(
+                named_agg.value.dtype,
+                orderby_dtype,
+                node.options.offset,
+                node.options.period,
+                closed_window,
+                orderby,
+                named_agg.value,
+            )
+        replacements: dict[expr.Expr, expr.Expr] = {
+            expr.Col(agg.value.dtype, agg.name): expr.RollingWindow(
+                agg.value.dtype,
+                orderby_dtype,
+                node.options.offset,
+                node.options.period,
+                closed_window,
+                orderby,
+                agg.value,
+            )
+            for agg in named_aggs
+        }
+        return replace([named_post_agg.value], replacements)[0]
+    elif isinstance(node.options, pl_expr.WindowMapping):
+        # pl.col("a").over(...)
+        agg = translator.translate_expr(n=node.function, schema=schema)
+        name_gen = unique_names(schema)
+        aggs, post = decompose_single_agg(
+            expr.NamedExpr(next(name_gen), agg),
+            name_gen,
+            is_top=True,
+            context=ExecutionContext.WINDOW,
+        )
+
+        mapping = node.options.kind
+        has_order_by = node.order_by is not None
+        descending = bool(getattr(node, "order_by_descending", False))
+        nulls_last = bool(getattr(node, "order_by_nulls_last", False))
+
+        if mapping != "groups_to_rows":
+            raise NotImplementedError(
+                f"over(mapping_strategy) not supported yet: {mapping=}; "
+                f"expected 'groups_to_rows'"
+            )
+
+        order_by_expr = (
+            translator.translate_expr(n=node.order_by, schema=schema)
+            if has_order_by
+            else None
+        )
+        return expr.GroupedRollingWindow(
+            dtype,
+            (mapping, has_order_by, descending, nulls_last),
+            [agg for agg, _ in aggs],
+            post,
+            *(translator.translate_expr(n=n, schema=schema) for n in node.partition_by),
+            _order_by_expr=order_by_expr,
+        )
+    assert_never(node.options)
+
+
+@_translate_expr.register
+def _(
+    node: pl_expr.Literal, translator: Translator, dtype: DataType, schema: Schema
+) -> expr.Expr:
+    if isinstance(node.value, plrs.PySeries):
+        return expr.LiteralColumn(dtype, pl.Series._from_pyseries(node.value))
+    if dtype.id() == plc.TypeId.LIST:  # pragma: no cover
+        # TODO: Remove once pylibcudf.Scalar supports lists
+        return expr.LiteralColumn(dtype, pl.Series(node.value))
+    return expr.Literal(dtype, node.value)
+
+
+@_translate_expr.register
+def _(
+    node: pl_expr.Sort, translator: Translator, dtype: DataType, schema: Schema
+) -> expr.Expr:
+    # TODO: raise in groupby
+    return expr.Sort(
+        dtype, node.options, translator.translate_expr(n=node.expr, schema=schema)
+    )
+
+
+@_translate_expr.register
+def _(
+    node: pl_expr.SortBy, translator: Translator, dtype: DataType, schema: Schema
+) -> expr.Expr:
+    options = node.sort_options
+    return expr.SortBy(
+        dtype,
+        (options[0], tuple(options[1]), tuple(options[2])),
+        translator.translate_expr(n=node.expr, schema=schema),
+        *(translator.translate_expr(n=n, schema=schema) for n in node.by),
+    )
+
+
+@_translate_expr.register
+def _(
+    node: pl_expr.Slice, translator: Translator, dtype: DataType, schema: Schema
+) -> expr.Expr:
+    offset = translator.translate_expr(n=node.offset, schema=schema)
+    length = translator.translate_expr(n=node.length, schema=schema)
+    assert isinstance(offset, expr.Literal)
+    assert isinstance(length, expr.Literal)
+    return expr.Slice(
+        dtype,
+        offset.value,
+        length.value,
+        translator.translate_expr(n=node.input, schema=schema),
+    )
+
+
+@_translate_expr.register
+def _(
+    node: pl_expr.Gather, translator: Translator, dtype: DataType, schema: Schema
+) -> expr.Expr:
+    return expr.Gather(
+        dtype,
+        translator.translate_expr(n=node.expr, schema=schema),
+        translator.translate_expr(n=node.idx, schema=schema),
+    )
+
+
+@_translate_expr.register
+def _(
+    node: pl_expr.Filter, translator: Translator, dtype: DataType, schema: Schema
+) -> expr.Expr:
+    return expr.Filter(
+        dtype,
+        translator.translate_expr(n=node.input, schema=schema),
+        translator.translate_expr(n=node.by, schema=schema),
+    )
+
+
+@_translate_expr.register
+def _(
+    node: pl_expr.Cast, translator: Translator, dtype: DataType, schema: Schema
+) -> expr.Expr:
+    inner = translator.translate_expr(n=node.expr, schema=schema)
+    # Push casts into literals so we can handle Cast(Literal(Null))
+    if isinstance(inner, expr.Literal):
+        return inner.astype(dtype)
+    elif isinstance(inner, expr.Cast):
+        # Translation of Len/Count-agg put in a cast, remove double
+        # casts if we have one.
+        (inner,) = inner.children
+    return expr.Cast(dtype, inner)
+
+
+@_translate_expr.register
+def _(
+    node: pl_expr.Column, translator: Translator, dtype: DataType, schema: Schema
+) -> expr.Expr:
+    return expr.Col(dtype, node.name)
+
+
+@_translate_expr.register
+def _(
+    node: pl_expr.Agg, translator: Translator, dtype: DataType, schema: Schema
+) -> expr.Expr:
+    value = expr.Agg(
+        dtype,
+        node.name,
+        node.options,
+        *(translator.translate_expr(n=n, schema=schema) for n in node.arguments),
+    )
+    if value.name in ("count", "n_unique") and value.dtype.id() != plc.TypeId.INT32:
+        return expr.Cast(value.dtype, value)
+    return value
+
+
+@_translate_expr.register
+def _(
+    node: pl_expr.Ternary, translator: Translator, dtype: DataType, schema: Schema
+) -> expr.Expr:
+    return expr.Ternary(
+        dtype,
+        translator.translate_expr(n=node.predicate, schema=schema),
+        translator.translate_expr(n=node.truthy, schema=schema),
+        translator.translate_expr(n=node.falsy, schema=schema),
+    )
+
+
+@_translate_expr.register
+def _(
+    node: pl_expr.BinaryExpr,
+    translator: Translator,
+    dtype: DataType,
+    schema: Schema,
+) -> expr.Expr:
+    if plc.traits.is_boolean(dtype.plc) and node.op == pl_expr.Operator.TrueDivide:
+        dtype = DataType(pl.Float64())
+    return expr.BinOp(
+        dtype,
+        expr.BinOp._MAPPING[node.op],
+        translator.translate_expr(n=node.left, schema=schema),
+        translator.translate_expr(n=node.right, schema=schema),
+    )
+
+
+@_translate_expr.register
+def _(
+    node: pl_expr.Len, translator: Translator, dtype: DataType, schema: Schema
+) -> expr.Expr:
+    value = expr.Len(dtype)
+    if dtype.id() != plc.TypeId.INT32:
+        return expr.Cast(dtype, value)
+    return value  # pragma: no cover; never reached since polars len has uint32 dtype
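
For orientation, the translation entry point shown above is driven roughly as follows. This is an illustrative sketch only, not part of the package: `visitor` stands for the NodeTraverser object that polars hands to the GPU engine callback (see cudf_polars/callback.py in the listing), and the surrounding error handling is simplified.

import polars as pl

from cudf_polars.dsl.translate import Translator


def translate_query(visitor, engine: pl.GPUEngine):
    # `visitor` is assumed to be the NodeTraverser that polars passes to the
    # engine callback; it is not constructed by hand.
    translator = Translator(visitor, engine)
    root = translator.translate_ir()  # translate starting from the current root node
    if translator.errors:
        # Untranslatable nodes become ErrorNode/ErrorExpr placeholders and the
        # original exceptions are collected on `translator.errors`.
        raise NotImplementedError("; ".join(str(e) for e in translator.errors))
    return root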