cudf-polars-cu12 25.2.2-py3-none-any.whl → 25.6.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. cudf_polars/VERSION +1 -1
  2. cudf_polars/callback.py +82 -65
  3. cudf_polars/containers/column.py +138 -7
  4. cudf_polars/containers/dataframe.py +26 -39
  5. cudf_polars/dsl/expr.py +3 -1
  6. cudf_polars/dsl/expressions/aggregation.py +27 -63
  7. cudf_polars/dsl/expressions/base.py +40 -72
  8. cudf_polars/dsl/expressions/binaryop.py +5 -41
  9. cudf_polars/dsl/expressions/boolean.py +25 -53
  10. cudf_polars/dsl/expressions/datetime.py +97 -17
  11. cudf_polars/dsl/expressions/literal.py +27 -33
  12. cudf_polars/dsl/expressions/rolling.py +110 -9
  13. cudf_polars/dsl/expressions/selection.py +8 -26
  14. cudf_polars/dsl/expressions/slicing.py +47 -0
  15. cudf_polars/dsl/expressions/sorting.py +5 -18
  16. cudf_polars/dsl/expressions/string.py +33 -36
  17. cudf_polars/dsl/expressions/ternary.py +3 -10
  18. cudf_polars/dsl/expressions/unary.py +35 -75
  19. cudf_polars/dsl/ir.py +749 -212
  20. cudf_polars/dsl/nodebase.py +8 -1
  21. cudf_polars/dsl/to_ast.py +5 -3
  22. cudf_polars/dsl/translate.py +319 -171
  23. cudf_polars/dsl/utils/__init__.py +8 -0
  24. cudf_polars/dsl/utils/aggregations.py +292 -0
  25. cudf_polars/dsl/utils/groupby.py +97 -0
  26. cudf_polars/dsl/utils/naming.py +34 -0
  27. cudf_polars/dsl/utils/replace.py +46 -0
  28. cudf_polars/dsl/utils/rolling.py +113 -0
  29. cudf_polars/dsl/utils/windows.py +186 -0
  30. cudf_polars/experimental/base.py +17 -19
  31. cudf_polars/experimental/benchmarks/__init__.py +4 -0
  32. cudf_polars/experimental/benchmarks/pdsh.py +1279 -0
  33. cudf_polars/experimental/dask_registers.py +196 -0
  34. cudf_polars/experimental/distinct.py +174 -0
  35. cudf_polars/experimental/explain.py +127 -0
  36. cudf_polars/experimental/expressions.py +521 -0
  37. cudf_polars/experimental/groupby.py +288 -0
  38. cudf_polars/experimental/io.py +58 -29
  39. cudf_polars/experimental/join.py +353 -0
  40. cudf_polars/experimental/parallel.py +166 -93
  41. cudf_polars/experimental/repartition.py +69 -0
  42. cudf_polars/experimental/scheduler.py +155 -0
  43. cudf_polars/experimental/select.py +92 -7
  44. cudf_polars/experimental/shuffle.py +294 -0
  45. cudf_polars/experimental/sort.py +45 -0
  46. cudf_polars/experimental/spilling.py +151 -0
  47. cudf_polars/experimental/utils.py +100 -0
  48. cudf_polars/testing/asserts.py +146 -6
  49. cudf_polars/testing/io.py +72 -0
  50. cudf_polars/testing/plugin.py +78 -76
  51. cudf_polars/typing/__init__.py +59 -6
  52. cudf_polars/utils/config.py +353 -0
  53. cudf_polars/utils/conversion.py +40 -0
  54. cudf_polars/utils/dtypes.py +22 -5
  55. cudf_polars/utils/timer.py +39 -0
  56. cudf_polars/utils/versions.py +5 -4
  57. {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info}/METADATA +10 -7
  58. cudf_polars_cu12-25.6.0.dist-info/RECORD +73 -0
  59. {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info}/WHEEL +1 -1
  60. cudf_polars/experimental/dask_serialize.py +0 -59
  61. cudf_polars_cu12-25.2.2.dist-info/RECORD +0 -48
  62. {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info/licenses}/LICENSE +0 -0
  63. {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info}/top_level.txt +0 -0
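The largest single change in this diff is to cudf_polars/dsl/translate.py (shown below): Translator now takes the polars GPUEngine object itself and normalizes it into ConfigOptions, and expression translation is threaded with the schema of the IR node it is evaluated against. The following is a minimal sketch of driving the new translator entry points; `visitor` is assumed to be the polars NodeTraverser that polars hands to the GPU engine callback, and how you obtain it is outside the scope of this diff.

    import polars as pl

    from cudf_polars.dsl.translate import Translator

    def lower_to_gpu_ir(visitor, engine: pl.GPUEngine):
        # 25.6.0: pass the GPUEngine directly; Translator converts it to
        # ConfigOptions via config.ConfigOptions.from_polars_engine(engine).
        translator = Translator(visitor, engine)
        root = translator.translate_ir()  # n=None translates the current root node
        if translator.errors:
            # Translation failures are collected on the translator rather than raised eagerly.
            raise NotImplementedError("query unsupported on GPU") from translator.errors[0]
        return root

Expression translation follows the same pattern but now also needs the evaluation context, mirroring the translate_named_expr(translator, n=e, schema=inp.schema) calls in the hunks below.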
cudf_polars/dsl/translate.py
@@ -22,8 +22,14 @@ import pylibcudf as plc
 
  from cudf_polars.dsl import expr, ir
  from cudf_polars.dsl.to_ast import insert_colrefs
- from cudf_polars.typing import NodeTraverser
- from cudf_polars.utils import dtypes, sorting
+ from cudf_polars.dsl.utils.aggregations import decompose_single_agg
+ from cudf_polars.dsl.utils.groupby import rewrite_groupby
+ from cudf_polars.dsl.utils.naming import unique_names
+ from cudf_polars.dsl.utils.replace import replace
+ from cudf_polars.dsl.utils.rolling import rewrite_rolling
+ from cudf_polars.dsl.utils.windows import offsets_to_windows
+ from cudf_polars.typing import Schema
+ from cudf_polars.utils import config, dtypes, sorting
 
  if TYPE_CHECKING:
  from polars import GPUEngine
@@ -41,13 +47,13 @@ class Translator:
  ----------
  visitor
  Polars NodeTraverser object
- config
+ engine
  GPU engine configuration.
  """
 
- def __init__(self, visitor: NodeTraverser, config: GPUEngine):
+ def __init__(self, visitor: NodeTraverser, engine: GPUEngine):
  self.visitor = visitor
- self.config = config
+ self.config_options = config.ConfigOptions.from_polars_engine(engine)
  self.errors: list[Exception] = []
 
  def translate_ir(self, *, n: int | None = None) -> ir.IR:
@@ -84,7 +90,7 @@ class Translator:
  # IR is versioned with major.minor, minor is bumped for backwards
  # compatible changes (e.g. adding new nodes), major is bumped for
  # incompatible changes (e.g. renaming nodes).
- if (version := self.visitor.version()) >= (5, 1):
+ if (version := self.visitor.version()) >= (7, 1):
  e = NotImplementedError(
  f"No support for polars IR {version=}"
  ) # pragma: no cover; no such version for now.
@@ -120,7 +126,7 @@ class Translator:
 
  return result
 
- def translate_expr(self, *, n: int) -> expr.Expr:
+ def translate_expr(self, *, n: int, schema: Schema) -> expr.Expr:
  """
  Translate a polars-internal expression IR into our representation.
 
@@ -128,6 +134,8 @@
  ----------
  n
  Node to translate, an integer referencing a polars internal node.
+ schema
+ Schema of the IR node this expression uses as evaluation context.
 
  Returns
  -------
@@ -143,7 +151,7 @@ class Translator:
  node = self.visitor.view_expression(n)
  dtype = dtypes.from_polars(self.visitor.get_dtype(n))
  try:
- return _translate_expr(node, self, dtype)
+ return _translate_expr(node, self, dtype, schema)
  except Exception as e:
  self.errors.append(e)
  return expr.ErrorExpr(dtype, str(e))
@@ -168,6 +176,7 @@ class set_node(AbstractContextManager[None]):
 
  __slots__ = ("n", "visitor")
  visitor: NodeTraverser
+
  n: int
 
  def __init__(self, visitor: NodeTraverser, n: int) -> None:
@@ -187,30 +196,26 @@ noop_context: nullcontext[None] = nullcontext()
 
 
  @singledispatch
- def _translate_ir(
- node: Any, translator: Translator, schema: dict[str, plc.DataType]
- ) -> ir.IR:
+ def _translate_ir(node: Any, translator: Translator, schema: Schema) -> ir.IR:
  raise NotImplementedError(
  f"Translation for {type(node).__name__}"
  ) # pragma: no cover
 
 
  @_translate_ir.register
- def _(
- node: pl_ir.PythonScan, translator: Translator, schema: dict[str, plc.DataType]
- ) -> ir.IR:
+ def _(node: pl_ir.PythonScan, translator: Translator, schema: Schema) -> ir.IR:
  scan_fn, with_columns, source_type, predicate, nrows = node.options
  options = (scan_fn, with_columns, source_type, nrows)
  predicate = (
- translate_named_expr(translator, n=predicate) if predicate is not None else None
+ translate_named_expr(translator, n=predicate, schema=schema)
+ if predicate is not None
+ else None
  )
  return ir.PythonScan(schema, options, predicate)
 
 
  @_translate_ir.register
- def _(
- node: pl_ir.Scan, translator: Translator, schema: dict[str, plc.DataType]
- ) -> ir.IR:
+ def _(node: pl_ir.Scan, translator: Translator, schema: Schema) -> ir.IR:
  typ, *options = node.scan_type
  if typ == "ndjson":
  (reader_options,) = map(json.loads, options)
@@ -219,118 +224,102 @@ def _(
  reader_options, cloud_options = map(json.loads, options)
  file_options = node.file_options
  with_columns = file_options.with_columns
- n_rows = file_options.n_rows
- if n_rows is None:
- n_rows = -1 # All rows
- skip_rows = 0 # Don't skip
+ row_index = file_options.row_index
+ include_file_paths = file_options.include_file_paths
+
+ pre_slice = file_options.n_rows
+ if pre_slice is None:
+ n_rows = -1
+ skip_rows = 0
  else:
- # TODO: with versioning, rename on the rust side
- skip_rows, n_rows = n_rows
+ skip_rows, n_rows = pre_slice
 
- row_index = file_options.row_index
  return ir.Scan(
  schema,
  typ,
  reader_options,
  cloud_options,
- translator.config.config.copy(),
+ translator.config_options,
  node.paths,
  with_columns,
  skip_rows,
  n_rows,
  row_index,
- translate_named_expr(translator, n=node.predicate)
+ include_file_paths,
+ translate_named_expr(translator, n=node.predicate, schema=schema)
  if node.predicate is not None
  else None,
  )
 
 
  @_translate_ir.register
- def _(
- node: pl_ir.Cache, translator: Translator, schema: dict[str, plc.DataType]
- ) -> ir.IR:
- return ir.Cache(schema, node.id_, translator.translate_ir(n=node.input))
+ def _(node: pl_ir.Cache, translator: Translator, schema: Schema) -> ir.IR:
+ return ir.Cache(
+ schema, node.id_, node.cache_hits, translator.translate_ir(n=node.input)
+ )
 
 
  @_translate_ir.register
- def _(
- node: pl_ir.DataFrameScan, translator: Translator, schema: dict[str, plc.DataType]
- ) -> ir.IR:
+ def _(node: pl_ir.DataFrameScan, translator: Translator, schema: Schema) -> ir.IR:
  return ir.DataFrameScan(
  schema,
  node.df,
  node.projection,
- translator.config.config.copy(),
+ translator.config_options,
  )
 
 
  @_translate_ir.register
- def _(
- node: pl_ir.Select, translator: Translator, schema: dict[str, plc.DataType]
- ) -> ir.IR:
+ def _(node: pl_ir.Select, translator: Translator, schema: Schema) -> ir.IR:
  with set_node(translator.visitor, node.input):
  inp = translator.translate_ir(n=None)
- exprs = [translate_named_expr(translator, n=e) for e in node.expr]
+ exprs = [
+ translate_named_expr(translator, n=e, schema=inp.schema) for e in node.expr
+ ]
  return ir.Select(schema, exprs, node.should_broadcast, inp)
 
 
  @_translate_ir.register
- def _(
- node: pl_ir.GroupBy, translator: Translator, schema: dict[str, plc.DataType]
- ) -> ir.IR:
+ def _(node: pl_ir.GroupBy, translator: Translator, schema: Schema) -> ir.IR:
  with set_node(translator.visitor, node.input):
  inp = translator.translate_ir(n=None)
- aggs = [translate_named_expr(translator, n=e) for e in node.aggs]
- keys = [translate_named_expr(translator, n=e) for e in node.keys]
- return ir.GroupBy(
- schema,
- keys,
- aggs,
- node.maintain_order,
- node.options,
- inp,
- )
+ keys = [
+ translate_named_expr(translator, n=e, schema=inp.schema) for e in node.keys
+ ]
+ original_aggs = [
+ translate_named_expr(translator, n=e, schema=inp.schema) for e in node.aggs
+ ]
+ is_rolling = node.options.rolling is not None
+ is_dynamic = node.options.dynamic is not None
+ if is_dynamic:
+ raise NotImplementedError("group_by_dynamic")
+ elif is_rolling:
+ return rewrite_rolling(
+ node.options, schema, keys, original_aggs, translator.config_options, inp
+ )
+ else:
+ return rewrite_groupby(
+ node, schema, keys, original_aggs, translator.config_options, inp
+ )
 
 
  @_translate_ir.register
- def _(
- node: pl_ir.Join, translator: Translator, schema: dict[str, plc.DataType]
- ) -> ir.IR:
+ def _(node: pl_ir.Join, translator: Translator, schema: Schema) -> ir.IR:
  # Join key dtypes are dependent on the schema of the left and
  # right inputs, so these must be translated with the relevant
  # input active.
- def adjust_literal_dtype(literal: expr.Literal) -> expr.Literal:
- if literal.dtype.id() == plc.types.TypeId.INT32:
- plc_int64 = plc.types.DataType(plc.types.TypeId.INT64)
- return expr.Literal(
- plc_int64,
- pa.scalar(literal.value.as_py(), type=plc.interop.to_arrow(plc_int64)),
- )
- return literal
-
- def maybe_adjust_binop(e) -> expr.Expr:
- if isinstance(e.value, expr.BinOp):
- left, right = e.value.children
- if isinstance(left, expr.Col) and isinstance(right, expr.Literal):
- e.value.children = (left, adjust_literal_dtype(right))
- elif isinstance(left, expr.Literal) and isinstance(right, expr.Col):
- e.value.children = (adjust_literal_dtype(left), right)
- return e
-
- def translate_expr_and_maybe_fix_binop_args(translator, exprs):
- return [
- maybe_adjust_binop(translate_named_expr(translator, n=e)) for e in exprs
- ]
-
  with set_node(translator.visitor, node.input_left):
  inp_left = translator.translate_ir(n=None)
- # TODO: There's bug in the polars type coercion phase. Use
- # translate_named_expr directly once it is resolved.
- # Tracking issue: https://github.com/pola-rs/polars/issues/20935
- left_on = translate_expr_and_maybe_fix_binop_args(translator, node.left_on)
+ left_on = [
+ translate_named_expr(translator, n=e, schema=inp_left.schema)
+ for e in node.left_on
+ ]
  with set_node(translator.visitor, node.input_right):
  inp_right = translator.translate_ir(n=None)
- right_on = translate_expr_and_maybe_fix_binop_args(translator, node.right_on)
+ right_on = [
+ translate_named_expr(translator, n=e, schema=inp_right.schema)
+ for e in node.right_on
+ ]
 
  if (how := node.options[0]) in {
  "Inner",
@@ -341,7 +330,15 @@ def _(
  "Semi",
  "Anti",
  }:
- return ir.Join(schema, left_on, right_on, node.options, inp_left, inp_right)
+ return ir.Join(
+ schema,
+ left_on,
+ right_on,
+ node.options,
+ translator.config_options,
+ inp_left,
+ inp_right,
+ )
  else:
  how, op1, op2 = node.options[0]
  if how != "IEJoin":
@@ -385,29 +382,29 @@ def _(
 
 
  @_translate_ir.register
- def _(
- node: pl_ir.HStack, translator: Translator, schema: dict[str, plc.DataType]
- ) -> ir.IR:
+ def _(node: pl_ir.HStack, translator: Translator, schema: Schema) -> ir.IR:
  with set_node(translator.visitor, node.input):
  inp = translator.translate_ir(n=None)
- exprs = [translate_named_expr(translator, n=e) for e in node.exprs]
+ exprs = [
+ translate_named_expr(translator, n=e, schema=inp.schema) for e in node.exprs
+ ]
  return ir.HStack(schema, exprs, node.should_broadcast, inp)
 
 
  @_translate_ir.register
  def _(
- node: pl_ir.Reduce, translator: Translator, schema: dict[str, plc.DataType]
+ node: pl_ir.Reduce, translator: Translator, schema: Schema
  ) -> ir.IR: # pragma: no cover; polars doesn't emit this node yet
  with set_node(translator.visitor, node.input):
  inp = translator.translate_ir(n=None)
- exprs = [translate_named_expr(translator, n=e) for e in node.expr]
+ exprs = [
+ translate_named_expr(translator, n=e, schema=inp.schema) for e in node.expr
+ ]
  return ir.Reduce(schema, exprs, inp)
 
 
  @_translate_ir.register
- def _(
- node: pl_ir.Distinct, translator: Translator, schema: dict[str, plc.DataType]
- ) -> ir.IR:
+ def _(node: pl_ir.Distinct, translator: Translator, schema: Schema) -> ir.IR:
  (keep, subset, maintain_order, zlice) = node.options
  keep = ir.Distinct._KEEP_MAP[keep]
  subset = frozenset(subset) if subset is not None else None
@@ -422,12 +419,13 @@ def _(
 
 
  @_translate_ir.register
- def _(
- node: pl_ir.Sort, translator: Translator, schema: dict[str, plc.DataType]
- ) -> ir.IR:
+ def _(node: pl_ir.Sort, translator: Translator, schema: Schema) -> ir.IR:
  with set_node(translator.visitor, node.input):
  inp = translator.translate_ir(n=None)
- by = [translate_named_expr(translator, n=e) for e in node.by_column]
+ by = [
+ translate_named_expr(translator, n=e, schema=inp.schema)
+ for e in node.by_column
+ ]
  stable, nulls_last, descending = node.sort_options
  order, null_order = sorting.sort_order(
  descending, nulls_last=nulls_last, num_keys=len(by)
@@ -436,65 +434,101 @@ def _(
 
 
  @_translate_ir.register
- def _(
- node: pl_ir.Slice, translator: Translator, schema: dict[str, plc.DataType]
- ) -> ir.IR:
+ def _(node: pl_ir.Slice, translator: Translator, schema: Schema) -> ir.IR:
  return ir.Slice(
  schema, node.offset, node.len, translator.translate_ir(n=node.input)
  )
 
 
  @_translate_ir.register
- def _(
- node: pl_ir.Filter, translator: Translator, schema: dict[str, plc.DataType]
- ) -> ir.IR:
+ def _(node: pl_ir.Filter, translator: Translator, schema: Schema) -> ir.IR:
  with set_node(translator.visitor, node.input):
  inp = translator.translate_ir(n=None)
- mask = translate_named_expr(translator, n=node.predicate)
+ mask = translate_named_expr(translator, n=node.predicate, schema=inp.schema)
  return ir.Filter(schema, mask, inp)
 
 
  @_translate_ir.register
- def _(
- node: pl_ir.SimpleProjection,
- translator: Translator,
- schema: dict[str, plc.DataType],
- ) -> ir.IR:
+ def _(node: pl_ir.SimpleProjection, translator: Translator, schema: Schema) -> ir.IR:
  return ir.Projection(schema, translator.translate_ir(n=node.input))
 
 
  @_translate_ir.register
- def _(
- node: pl_ir.MapFunction, translator: Translator, schema: dict[str, plc.DataType]
- ) -> ir.IR:
+ def _(node: pl_ir.MergeSorted, translator: Translator, schema: Schema) -> ir.IR:
+ key = node.key
+ inp_left = translator.translate_ir(n=node.input_left)
+ inp_right = translator.translate_ir(n=node.input_right)
+ return ir.MergeSorted(
+ schema,
+ key,
+ inp_left,
+ inp_right,
+ )
+
+
+ @_translate_ir.register
+ def _(node: pl_ir.MapFunction, translator: Translator, schema: Schema) -> ir.IR:
  name, *options = node.function
  return ir.MapFunction(
  schema,
  name,
  options,
- # TODO: merge_sorted breaks this pattern
  translator.translate_ir(n=node.input),
  )
 
 
  @_translate_ir.register
- def _(
- node: pl_ir.Union, translator: Translator, schema: dict[str, plc.DataType]
- ) -> ir.IR:
+ def _(node: pl_ir.Union, translator: Translator, schema: Schema) -> ir.IR:
  return ir.Union(
  schema, node.options, *(translator.translate_ir(n=n) for n in node.inputs)
  )
 
 
+ @_translate_ir.register
+ def _(node: pl_ir.HConcat, translator: Translator, schema: Schema) -> ir.IR:
+ return ir.HConcat(
+ schema,
+ False, # noqa: FBT003
+ *(translator.translate_ir(n=n) for n in node.inputs),
+ )
+
+
  @_translate_ir.register
  def _(
- node: pl_ir.HConcat, translator: Translator, schema: dict[str, plc.DataType]
+ node: pl_ir.Sink, translator: Translator, schema: dict[str, plc.DataType]
  ) -> ir.IR:
- return ir.HConcat(schema, *(translator.translate_ir(n=n) for n in node.inputs))
+ payload = json.loads(node.payload)
+ try:
+ file = payload["File"]
+ sink_kind_options = file["file_type"]
+ except KeyError as err: # pragma: no cover
+ raise NotImplementedError("Unsupported payload structure") from err
+ if isinstance(sink_kind_options, dict):
+ if len(sink_kind_options) != 1: # pragma: no cover; not sure if this can happen
+ raise NotImplementedError("Sink options dict with more than one entry.")
+ sink_kind, options = next(iter(sink_kind_options.items()))
+ else:
+ raise NotImplementedError(
+ "Unsupported sink options structure"
+ ) # pragma: no cover
+
+ sink_options = file.get("sink_options", {})
+ cloud_options = file.get("cloud_options")
+
+ options.update(sink_options)
+
+ return ir.Sink(
+ schema=schema,
+ kind=sink_kind,
+ path=file["target"],
+ options=options,
+ cloud_options=cloud_options,
+ df=translator.translate_ir(n=node.input),
+ )
 
 
  def translate_named_expr(
- translator: Translator, *, n: pl_expr.PyExprIR
+ translator: Translator, *, n: pl_expr.PyExprIR, schema: Schema
  ) -> expr.NamedExpr:
  """
  Translate a polars-internal named expression IR object into our representation.
@@ -505,6 +539,8 @@ def translate_named_expr(
  Translator object
  n
  Node to translate, a named expression node.
+ schema
+ Schema of the IR node this expression uses as evaluation context.
 
  Returns
  -------
@@ -522,12 +558,14 @@ def translate_named_expr(
  NotImplementedError
  If any translation fails due to unsupported functionality.
  """
- return expr.NamedExpr(n.output_name, translator.translate_expr(n=n.node))
+ return expr.NamedExpr(
+ n.output_name, translator.translate_expr(n=n.node, schema=schema)
+ )
 
 
  @singledispatch
  def _translate_expr(
- node: Any, translator: Translator, dtype: plc.DataType
+ node: Any, translator: Translator, dtype: plc.DataType, schema: Schema
  ) -> expr.Expr:
  raise NotImplementedError(
  f"Translation for {type(node).__name__}"
@@ -535,7 +573,9 @@ def _translate_expr(
 
 
  @_translate_expr.register
- def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> expr.Expr:
+ def _(
+ node: pl_expr.Function, translator: Translator, dtype: plc.DataType, schema: Schema
+ ) -> expr.Expr:
  name, *options = node.function_data
  options = tuple(options)
  if isinstance(name, pl_expr.StringFunction):
@@ -544,18 +584,20 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex
  pl_expr.StringFunction.StripCharsStart,
  pl_expr.StringFunction.StripCharsEnd,
  }:
- column, chars = (translator.translate_expr(n=n) for n in node.input)
+ column, chars = (
+ translator.translate_expr(n=n, schema=schema) for n in node.input
+ )
  if isinstance(chars, expr.Literal):
- if chars.value == pa.scalar(""):
+ # We check for null first because we want to use the
+ # chars pyarrow type, but it is invalid to try and
+ # produce a string scalar with a null dtype.
+ if chars.value is None:
+ # Polars uses None to mean "strip all whitespace"
+ chars = expr.Literal(column.dtype, "")
+ elif chars.value == "":
  # No-op in polars, but libcudf uses empty string
  # as signifier to remove whitespace.
  return column
- elif chars.value == pa.scalar(None):
- # Polars uses None to mean "strip all whitespace"
- chars = expr.Literal(
- column.dtype,
- pa.scalar("", type=plc.interop.to_arrow(column.dtype)),
- )
  return expr.StringFunction(
  dtype,
  expr.StringFunction.Name.from_polars(name),
@@ -567,11 +609,13 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex
  dtype,
  expr.StringFunction.Name.from_polars(name),
  options,
- *(translator.translate_expr(n=n) for n in node.input),
+ *(translator.translate_expr(n=n, schema=schema) for n in node.input),
  )
  elif isinstance(name, pl_expr.BooleanFunction):
  if name == pl_expr.BooleanFunction.IsBetween:
- column, lo, hi = (translator.translate_expr(n=n) for n in node.input)
+ column, lo, hi = (
+ translator.translate_expr(n=n, schema=schema) for n in node.input
+ )
  (closed,) = options
  lop, rop = expr.BooleanFunction._BETWEEN_OPS[closed]
  return expr.BinOp(
@@ -584,7 +628,7 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex
  dtype,
  expr.BooleanFunction.Name.from_polars(name),
  options,
- *(translator.translate_expr(n=n) for n in node.input),
+ *(translator.translate_expr(n=n, schema=schema) for n in node.input),
  )
  elif isinstance(name, pl_expr.TemporalFunction):
  # functions for which evaluation of the expression may not return
@@ -604,14 +648,14 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex
  dtype,
  expr.TemporalFunction.Name.from_polars(name),
  options,
- *(translator.translate_expr(n=n) for n in node.input),
+ *(translator.translate_expr(n=n, schema=schema) for n in node.input),
  )
  if name in needs_cast:
  return expr.Cast(dtype, result_expr)
  return result_expr
 
  elif isinstance(name, str):
- children = (translator.translate_expr(n=n) for n in node.input)
+ children = (translator.translate_expr(n=n, schema=schema) for n in node.input)
  if name == "log":
  (base,) = options
  (child,) = children
@@ -619,10 +663,21 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex
  dtype,
  plc.binaryop.BinaryOperator.LOG_BASE,
  child,
- expr.Literal(dtype, pa.scalar(base, type=plc.interop.to_arrow(dtype))),
+ expr.Literal(dtype, base),
  )
  elif name == "pow":
  return expr.BinOp(dtype, plc.binaryop.BinaryOperator.POW, *children)
+ elif name in "top_k":
+ (col, k) = children
+ assert isinstance(k, expr.Literal)
+ (descending,) = options
+ return expr.Slice(
+ dtype,
+ 0,
+ k.value,
+ expr.Sort(dtype, (False, True, not descending), col),
+ )
+
  return expr.UnaryFunction(dtype, name, options, *children)
  raise NotImplementedError(
  f"No handler for Expr function node with {name=}"
@@ -630,73 +685,155 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex
 
 
  @_translate_expr.register
- def _(node: pl_expr.Window, translator: Translator, dtype: plc.DataType) -> expr.Expr:
- # TODO: raise in groupby?
+ def _(
+ node: pl_expr.Window, translator: Translator, dtype: plc.DataType, schema: Schema
+ ) -> expr.Expr:
  if isinstance(node.options, pl_expr.RollingGroupOptions):
  # pl.col("a").rolling(...)
- return expr.RollingWindow(
- dtype, node.options, translator.translate_expr(n=node.function)
+ agg = translator.translate_expr(n=node.function, schema=schema)
+ name_generator = unique_names(schema)
+ aggs, named_post_agg = decompose_single_agg(
+ expr.NamedExpr(next(name_generator), agg), name_generator, is_top=True
  )
+ named_aggs = [agg for agg, _ in aggs]
+ orderby = node.options.index_column
+ orderby_dtype = schema[orderby]
+ if plc.traits.is_integral(orderby_dtype):
+ # Integer orderby column is cast in implementation to int64 in polars
+ orderby_dtype = plc.DataType(plc.TypeId.INT64)
+ preceding, following = offsets_to_windows(
+ orderby_dtype,
+ node.options.offset,
+ node.options.period,
+ )
+ closed_window = node.options.closed_window
+ if isinstance(named_post_agg.value, expr.Col):
+ (named_agg,) = named_aggs
+ return expr.RollingWindow(
+ named_agg.value.dtype,
+ preceding,
+ following,
+ closed_window,
+ orderby,
+ named_agg.value,
+ )
+ replacements: dict[expr.Expr, expr.Expr] = {
+ expr.Col(agg.value.dtype, agg.name): expr.RollingWindow(
+ agg.value.dtype,
+ preceding,
+ following,
+ closed_window,
+ orderby,
+ agg.value,
+ )
+ for agg in named_aggs
+ }
+ return replace([named_post_agg.value], replacements)[0]
  elif isinstance(node.options, pl_expr.WindowMapping):
  # pl.col("a").over(...)
  return expr.GroupedRollingWindow(
  dtype,
  node.options,
- translator.translate_expr(n=node.function),
- *(translator.translate_expr(n=n) for n in node.partition_by),
+ translator.translate_expr(n=node.function, schema=schema),
+ *(translator.translate_expr(n=n, schema=schema) for n in node.partition_by),
  )
  assert_never(node.options)
 
 
  @_translate_expr.register
- def _(node: pl_expr.Literal, translator: Translator, dtype: plc.DataType) -> expr.Expr:
+ def _(
+ node: pl_expr.Literal, translator: Translator, dtype: plc.DataType, schema: Schema
+ ) -> expr.Expr:
  if isinstance(node.value, plrs.PySeries):
- return expr.LiteralColumn(dtype, pl.Series._from_pyseries(node.value))
- value = pa.scalar(node.value, type=plc.interop.to_arrow(dtype))
- return expr.Literal(dtype, value)
+ data = pl.Series._from_pyseries(node.value).to_arrow(
+ compat_level=dtypes.TO_ARROW_COMPAT_LEVEL
+ )
+ return expr.LiteralColumn(
+ dtype, data.cast(dtypes.downcast_arrow_lists(data.type))
+ )
+ if dtype.id() == plc.TypeId.LIST: # pragma: no cover
+ # TODO: Find an alternative to pa.infer_type
+ data = pa.array(node.value, type=pa.infer_type(node.value))
+ return expr.LiteralColumn(dtype, data)
+ return expr.Literal(dtype, node.value)
 
 
  @_translate_expr.register
- def _(node: pl_expr.Sort, translator: Translator, dtype: plc.DataType) -> expr.Expr:
+ def _(
+ node: pl_expr.Sort, translator: Translator, dtype: plc.DataType, schema: Schema
+ ) -> expr.Expr:
  # TODO: raise in groupby
- return expr.Sort(dtype, node.options, translator.translate_expr(n=node.expr))
+ return expr.Sort(
+ dtype, node.options, translator.translate_expr(n=node.expr, schema=schema)
+ )
 
 
  @_translate_expr.register
- def _(node: pl_expr.SortBy, translator: Translator, dtype: plc.DataType) -> expr.Expr:
+ def _(
+ node: pl_expr.SortBy, translator: Translator, dtype: plc.DataType, schema: Schema
+ ) -> expr.Expr:
  options = node.sort_options
  return expr.SortBy(
  dtype,
  (options[0], tuple(options[1]), tuple(options[2])),
- translator.translate_expr(n=node.expr),
- *(translator.translate_expr(n=n) for n in node.by),
+ translator.translate_expr(n=node.expr, schema=schema),
+ *(translator.translate_expr(n=n, schema=schema) for n in node.by),
  )
 
 
  @_translate_expr.register
- def _(node: pl_expr.Gather, translator: Translator, dtype: plc.DataType) -> expr.Expr:
+ def _(
+ node: pl_expr.Slice, translator: Translator, dtype: plc.DataType, schema: Schema
+ ) -> expr.Expr:
+ offset = translator.translate_expr(n=node.offset, schema=schema)
+ length = translator.translate_expr(n=node.length, schema=schema)
+ assert isinstance(offset, expr.Literal)
+ assert isinstance(length, expr.Literal)
+ return expr.Slice(
+ dtype,
+ offset.value,
+ length.value,
+ translator.translate_expr(n=node.input, schema=schema),
+ )
+
+
+ @_translate_expr.register
+ def _(
+ node: pl_expr.Gather, translator: Translator, dtype: plc.DataType, schema: Schema
+ ) -> expr.Expr:
  return expr.Gather(
  dtype,
- translator.translate_expr(n=node.expr),
- translator.translate_expr(n=node.idx),
+ translator.translate_expr(n=node.expr, schema=schema),
+ translator.translate_expr(n=node.idx, schema=schema),
  )
 
 
  @_translate_expr.register
- def _(node: pl_expr.Filter, translator: Translator, dtype: plc.DataType) -> expr.Expr:
+ def _(
+ node: pl_expr.Filter, translator: Translator, dtype: plc.DataType, schema: Schema
+ ) -> expr.Expr:
  return expr.Filter(
  dtype,
- translator.translate_expr(n=node.input),
- translator.translate_expr(n=node.by),
+ translator.translate_expr(n=node.input, schema=schema),
+ translator.translate_expr(n=node.by, schema=schema),
  )
 
 
  @_translate_expr.register
- def _(node: pl_expr.Cast, translator: Translator, dtype: plc.DataType) -> expr.Expr:
- inner = translator.translate_expr(n=node.expr)
+ def _(
+ node: pl_expr.Cast, translator: Translator, dtype: plc.DataType, schema: Schema
+ ) -> expr.Expr:
+ inner = translator.translate_expr(n=node.expr, schema=schema)
  # Push casts into literals so we can handle Cast(Literal(Null))
  if isinstance(inner, expr.Literal):
- return expr.Literal(dtype, inner.value.cast(plc.interop.to_arrow(dtype)))
+ plc_column = plc.Column.from_scalar(
+ plc.Scalar.from_py(inner.value, inner.dtype), 1
+ )
+ casted_column = plc.unary.cast(plc_column, dtype)
+ casted_py_scalar = plc.interop.to_arrow(
+ plc.copying.get_element(casted_column, 0)
+ ).as_py()
+ return expr.Literal(dtype, casted_py_scalar)
  elif isinstance(inner, expr.Cast):
  # Translation of Len/Count-agg put in a cast, remove double
  # casts if we have one.
@@ -705,17 +842,21 @@ def _(node: pl_expr.Cast, translator: Translator, dtype: plc.DataType) -> expr.E
 
 
  @_translate_expr.register
- def _(node: pl_expr.Column, translator: Translator, dtype: plc.DataType) -> expr.Expr:
+ def _(
+ node: pl_expr.Column, translator: Translator, dtype: plc.DataType, schema: Schema
+ ) -> expr.Expr:
  return expr.Col(dtype, node.name)
 
 
  @_translate_expr.register
- def _(node: pl_expr.Agg, translator: Translator, dtype: plc.DataType) -> expr.Expr:
+ def _(
+ node: pl_expr.Agg, translator: Translator, dtype: plc.DataType, schema: Schema
+ ) -> expr.Expr:
  value = expr.Agg(
  dtype,
  node.name,
  node.options,
- *(translator.translate_expr(n=n) for n in node.arguments),
+ *(translator.translate_expr(n=n, schema=schema) for n in node.arguments),
  )
  if value.name == "count" and value.dtype.id() != plc.TypeId.INT32:
  return expr.Cast(value.dtype, value)
@@ -723,29 +864,36 @@ def _(node: pl_expr.Agg, translator: Translator, dtype: plc.DataType) -> expr.Ex
 
 
  @_translate_expr.register
- def _(node: pl_expr.Ternary, translator: Translator, dtype: plc.DataType) -> expr.Expr:
+ def _(
+ node: pl_expr.Ternary, translator: Translator, dtype: plc.DataType, schema: Schema
+ ) -> expr.Expr:
  return expr.Ternary(
  dtype,
- translator.translate_expr(n=node.predicate),
- translator.translate_expr(n=node.truthy),
- translator.translate_expr(n=node.falsy),
+ translator.translate_expr(n=node.predicate, schema=schema),
+ translator.translate_expr(n=node.truthy, schema=schema),
+ translator.translate_expr(n=node.falsy, schema=schema),
  )
 
 
  @_translate_expr.register
  def _(
- node: pl_expr.BinaryExpr, translator: Translator, dtype: plc.DataType
+ node: pl_expr.BinaryExpr,
+ translator: Translator,
+ dtype: plc.DataType,
+ schema: Schema,
  ) -> expr.Expr:
  return expr.BinOp(
  dtype,
  expr.BinOp._MAPPING[node.op],
- translator.translate_expr(n=node.left),
- translator.translate_expr(n=node.right),
+ translator.translate_expr(n=node.left, schema=schema),
+ translator.translate_expr(n=node.right, schema=schema),
  )
 
 
  @_translate_expr.register
- def _(node: pl_expr.Len, translator: Translator, dtype: plc.DataType) -> expr.Expr:
+ def _(
+ node: pl_expr.Len, translator: Translator, dtype: plc.DataType, schema: Schema
+ ) -> expr.Expr:
  value = expr.Len(dtype)
  if dtype.id() != plc.TypeId.INT32:
  return expr.Cast(dtype, value)