cudf-polars-cu12 25.2.2__py3-none-any.whl → 25.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. cudf_polars/VERSION +1 -1
  2. cudf_polars/callback.py +82 -65
  3. cudf_polars/containers/column.py +138 -7
  4. cudf_polars/containers/dataframe.py +26 -39
  5. cudf_polars/dsl/expr.py +3 -1
  6. cudf_polars/dsl/expressions/aggregation.py +27 -63
  7. cudf_polars/dsl/expressions/base.py +40 -72
  8. cudf_polars/dsl/expressions/binaryop.py +5 -41
  9. cudf_polars/dsl/expressions/boolean.py +25 -53
  10. cudf_polars/dsl/expressions/datetime.py +97 -17
  11. cudf_polars/dsl/expressions/literal.py +27 -33
  12. cudf_polars/dsl/expressions/rolling.py +110 -9
  13. cudf_polars/dsl/expressions/selection.py +8 -26
  14. cudf_polars/dsl/expressions/slicing.py +47 -0
  15. cudf_polars/dsl/expressions/sorting.py +5 -18
  16. cudf_polars/dsl/expressions/string.py +33 -36
  17. cudf_polars/dsl/expressions/ternary.py +3 -10
  18. cudf_polars/dsl/expressions/unary.py +35 -75
  19. cudf_polars/dsl/ir.py +749 -212
  20. cudf_polars/dsl/nodebase.py +8 -1
  21. cudf_polars/dsl/to_ast.py +5 -3
  22. cudf_polars/dsl/translate.py +319 -171
  23. cudf_polars/dsl/utils/__init__.py +8 -0
  24. cudf_polars/dsl/utils/aggregations.py +292 -0
  25. cudf_polars/dsl/utils/groupby.py +97 -0
  26. cudf_polars/dsl/utils/naming.py +34 -0
  27. cudf_polars/dsl/utils/replace.py +46 -0
  28. cudf_polars/dsl/utils/rolling.py +113 -0
  29. cudf_polars/dsl/utils/windows.py +186 -0
  30. cudf_polars/experimental/base.py +17 -19
  31. cudf_polars/experimental/benchmarks/__init__.py +4 -0
  32. cudf_polars/experimental/benchmarks/pdsh.py +1279 -0
  33. cudf_polars/experimental/dask_registers.py +196 -0
  34. cudf_polars/experimental/distinct.py +174 -0
  35. cudf_polars/experimental/explain.py +127 -0
  36. cudf_polars/experimental/expressions.py +521 -0
  37. cudf_polars/experimental/groupby.py +288 -0
  38. cudf_polars/experimental/io.py +58 -29
  39. cudf_polars/experimental/join.py +353 -0
  40. cudf_polars/experimental/parallel.py +166 -93
  41. cudf_polars/experimental/repartition.py +69 -0
  42. cudf_polars/experimental/scheduler.py +155 -0
  43. cudf_polars/experimental/select.py +92 -7
  44. cudf_polars/experimental/shuffle.py +294 -0
  45. cudf_polars/experimental/sort.py +45 -0
  46. cudf_polars/experimental/spilling.py +151 -0
  47. cudf_polars/experimental/utils.py +100 -0
  48. cudf_polars/testing/asserts.py +146 -6
  49. cudf_polars/testing/io.py +72 -0
  50. cudf_polars/testing/plugin.py +78 -76
  51. cudf_polars/typing/__init__.py +59 -6
  52. cudf_polars/utils/config.py +353 -0
  53. cudf_polars/utils/conversion.py +40 -0
  54. cudf_polars/utils/dtypes.py +22 -5
  55. cudf_polars/utils/timer.py +39 -0
  56. cudf_polars/utils/versions.py +5 -4
  57. {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info}/METADATA +10 -7
  58. cudf_polars_cu12-25.6.0.dist-info/RECORD +73 -0
  59. {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info}/WHEEL +1 -1
  60. cudf_polars/experimental/dask_serialize.py +0 -59
  61. cudf_polars_cu12-25.2.2.dist-info/RECORD +0 -48
  62. {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info/licenses}/LICENSE +0 -0
  63. {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info}/top_level.txt +0 -0
cudf_polars/dsl/ir.py CHANGED
@@ -15,6 +15,8 @@ from __future__ import annotations
15
15
 
16
16
  import itertools
17
17
  import json
18
+ import random
19
+ import time
18
20
  from functools import cache
19
21
  from pathlib import Path
20
22
  from typing import TYPE_CHECKING, Any, ClassVar
@@ -28,17 +30,25 @@ import pylibcudf as plc
28
30
 
29
31
  import cudf_polars.dsl.expr as expr
30
32
  from cudf_polars.containers import Column, DataFrame
33
+ from cudf_polars.dsl.expressions import rolling
34
+ from cudf_polars.dsl.expressions.base import ExecutionContext
31
35
  from cudf_polars.dsl.nodebase import Node
32
36
  from cudf_polars.dsl.to_ast import to_ast, to_parquet_filter
37
+ from cudf_polars.dsl.utils.windows import range_window_bounds
33
38
  from cudf_polars.utils import dtypes
39
+ from cudf_polars.utils.versions import POLARS_VERSION_LT_128
34
40
 
35
41
  if TYPE_CHECKING:
36
- from collections.abc import Callable, Hashable, Iterable, MutableMapping, Sequence
42
+ from collections.abc import Callable, Hashable, Iterable, Sequence
37
43
  from typing import Literal
38
44
 
45
+ from typing_extensions import Self
46
+
39
47
  from polars.polars import _expr_nodes as pl_expr
40
48
 
41
- from cudf_polars.typing import Schema
49
+ from cudf_polars.typing import CSECache, ClosedInterval, Schema, Slice as Zlice
50
+ from cudf_polars.utils.config import ConfigOptions
51
+ from cudf_polars.utils.timer import Timer
42
52
 
43
53
 
44
54
  __all__ = [
@@ -47,6 +57,7 @@ __all__ = [
47
57
  "ConditionalJoin",
48
58
  "DataFrameScan",
49
59
  "Distinct",
60
+ "Empty",
50
61
  "ErrorNode",
51
62
  "Filter",
52
63
  "GroupBy",
@@ -54,10 +65,14 @@ __all__ = [
54
65
  "HStack",
55
66
  "Join",
56
67
  "MapFunction",
68
+ "MergeSorted",
57
69
  "Projection",
58
70
  "PythonScan",
71
+ "Reduce",
72
+ "Rolling",
59
73
  "Scan",
60
74
  "Select",
75
+ "Sink",
61
76
  "Slice",
62
77
  "Sort",
63
78
  "Union",
@@ -100,7 +115,7 @@ def broadcast(*columns: Column, target_length: int | None = None) -> list[Column
100
115
  """
101
116
  if len(columns) == 0:
102
117
  return []
103
- lengths: set[int] = {column.obj.size() for column in columns}
118
+ lengths: set[int] = {column.size for column in columns}
104
119
  if lengths == {1}:
105
120
  if target_length is None:
106
121
  return list(columns)
@@ -116,7 +131,7 @@ def broadcast(*columns: Column, target_length: int | None = None) -> list[Column
116
131
  )
117
132
  return [
118
133
  column
119
- if column.obj.size() != 1
134
+ if column.size != 1
120
135
  else Column(
121
136
  plc.Column.from_scalar(column.obj_scalar, nrows),
122
137
  is_sorted=plc.types.Sorted.YES,
@@ -181,7 +196,7 @@ class IR(Node["IR"]):
181
196
  translation phase should fail earlier.
182
197
  """
183
198
 
184
- def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
199
+ def evaluate(self, *, cache: CSECache, timer: Timer | None) -> DataFrame:
185
200
  """
186
201
  Evaluate the node (recursively) and return a dataframe.
187
202
 
@@ -190,6 +205,9 @@ class IR(Node["IR"]):
190
205
  cache
191
206
  Mapping from cached node ids to constructed DataFrames.
192
207
  Used to implement evaluation of the `Cache` node.
208
+ timer
209
+ If not None, a Timer object to record timings for the
210
+ evaluation of the node.
193
211
 
194
212
  Notes
195
213
  -----
@@ -208,10 +226,16 @@ class IR(Node["IR"]):
208
226
  If evaluation fails. Ideally this should not occur, since the
209
227
  translation phase should fail earlier.
210
228
  """
211
- return self.do_evaluate(
212
- *self._non_child_args,
213
- *(child.evaluate(cache=cache) for child in self.children),
214
- )
229
+ children = [child.evaluate(cache=cache, timer=timer) for child in self.children]
230
+ if timer is not None:
231
+ start = time.monotonic_ns()
232
+ result = self.do_evaluate(*self._non_child_args, *children)
233
+ end = time.monotonic_ns()
234
+ # TODO: Set better names on each class object.
235
+ timer.store(start, end, type(self).__name__)
236
+ return result
237
+ else:
238
+ return self.do_evaluate(*self._non_child_args, *children)
215
239
 
216
240
 
217
241
  class ErrorNode(IR):
@@ -256,6 +280,7 @@ class Scan(IR):
256
280
  __slots__ = (
257
281
  "cloud_options",
258
282
  "config_options",
283
+ "include_file_paths",
259
284
  "n_rows",
260
285
  "paths",
261
286
  "predicate",
@@ -276,6 +301,7 @@ class Scan(IR):
276
301
  "skip_rows",
277
302
  "n_rows",
278
303
  "row_index",
304
+ "include_file_paths",
279
305
  "predicate",
280
306
  )
281
307
  typ: str
@@ -284,7 +310,7 @@ class Scan(IR):
284
310
  """Reader-specific options, as dictionary."""
285
311
  cloud_options: dict[str, Any] | None
286
312
  """Cloud-related authentication options, currently ignored."""
287
- config_options: dict[str, Any]
313
+ config_options: ConfigOptions
288
314
  """GPU-specific configuration options"""
289
315
  paths: list[str]
290
316
  """List of paths to read from."""
@@ -296,6 +322,8 @@ class Scan(IR):
296
322
  """Number of rows to read after skipping."""
297
323
  row_index: tuple[str, int] | None
298
324
  """If not None add an integer index column of the given name."""
325
+ include_file_paths: str | None
326
+ """Include the path of the source file(s) as a column with this name."""
299
327
  predicate: expr.NamedExpr | None
300
328
  """Mask to apply to the read dataframe."""
301
329
 
@@ -308,12 +336,13 @@ class Scan(IR):
308
336
  typ: str,
309
337
  reader_options: dict[str, Any],
310
338
  cloud_options: dict[str, Any] | None,
311
- config_options: dict[str, Any],
339
+ config_options: ConfigOptions,
312
340
  paths: list[str],
313
341
  with_columns: list[str] | None,
314
342
  skip_rows: int,
315
343
  n_rows: int,
316
344
  row_index: tuple[str, int] | None,
345
+ include_file_paths: str | None,
317
346
  predicate: expr.NamedExpr | None,
318
347
  ):
319
348
  self.schema = schema
@@ -326,6 +355,7 @@ class Scan(IR):
326
355
  self.skip_rows = skip_rows
327
356
  self.n_rows = n_rows
328
357
  self.row_index = row_index
358
+ self.include_file_paths = include_file_paths
329
359
  self.predicate = predicate
330
360
  self._non_child_args = (
331
361
  schema,
@@ -337,6 +367,7 @@ class Scan(IR):
337
367
  skip_rows,
338
368
  n_rows,
339
369
  row_index,
370
+ include_file_paths,
340
371
  predicate,
341
372
  )
342
373
  self.children = ()
@@ -350,7 +381,9 @@ class Scan(IR):
350
381
  # TODO: polars has this implemented for parquet,
351
382
  # maybe we can do this too?
352
383
  raise NotImplementedError("slice pushdown for negative slices")
353
- if self.typ in {"csv"} and self.skip_rows != 0: # pragma: no cover
384
+ if (
385
+ POLARS_VERSION_LT_128 and self.typ in {"csv"} and self.skip_rows != 0
386
+ ): # pragma: no cover
354
387
  # This comes from slice pushdown, but that
355
388
  # optimization doesn't happen right now
356
389
  raise NotImplementedError("skipping rows in CSV reader")
@@ -360,7 +393,7 @@ class Scan(IR):
360
393
  raise NotImplementedError(
361
394
  "Read from cloud storage"
362
395
  ) # pragma: no cover; no test yet
363
- if any(p.startswith("https://") for p in self.paths):
396
+ if any(str(p).startswith("https:/") for p in self.paths):
364
397
  raise NotImplementedError("Read from https")
365
398
  if self.typ == "csv":
366
399
  if self.reader_options["skip_rows_after_header"] != 0:
@@ -379,9 +412,18 @@ class Scan(IR):
379
412
  "Multi-character comment prefix not supported for CSV reader"
380
413
  )
381
414
  if not self.reader_options["has_header"]:
382
- # Need to do some file introspection to get the number
383
- # of columns so that column projection works right.
384
- raise NotImplementedError("Reading CSV without header")
415
+ # TODO: To support reading headerless CSV files without requiring new
416
+ # column names, we would need to do file introspection to infer the number
417
+ # of columns so column projection works right.
418
+ reader_schema = self.reader_options.get("schema")
419
+ if not (
420
+ reader_schema
421
+ and isinstance(schema, dict)
422
+ and "fields" in reader_schema
423
+ ):
424
+ raise NotImplementedError(
425
+ "Reading CSV without header requires user-provided column names via new_columns"
426
+ )
385
427
  elif self.typ == "ndjson":
386
428
  # TODO: consider handling the low memory option here
387
429
  # (maybe use chunked JSON reader)
@@ -389,6 +431,9 @@ class Scan(IR):
389
431
  raise NotImplementedError(
390
432
  "ignore_errors is not supported in the JSON reader"
391
433
  )
434
+ if include_file_paths is not None:
435
+ # TODO: Need to populate num_rows_per_source in read_json in libcudf
436
+ raise NotImplementedError("Including file paths in a json scan.")
392
437
  elif (
393
438
  self.typ == "parquet"
394
439
  and self.row_index is not None
@@ -413,31 +458,60 @@ class Scan(IR):
413
458
  self.typ,
414
459
  json.dumps(self.reader_options),
415
460
  json.dumps(self.cloud_options),
416
- json.dumps(self.config_options),
461
+ self.config_options,
417
462
  tuple(self.paths),
418
463
  tuple(self.with_columns) if self.with_columns is not None else None,
419
464
  self.skip_rows,
420
465
  self.n_rows,
421
466
  self.row_index,
467
+ self.include_file_paths,
422
468
  self.predicate,
423
469
  )
424
470
 
471
+ @staticmethod
472
+ def add_file_paths(
473
+ name: str, paths: list[str], rows_per_path: list[int], df: DataFrame
474
+ ) -> DataFrame:
475
+ """
476
+ Add a Column of file paths to the DataFrame.
477
+
478
+ Each path is repeated according to the number of rows read from it.
479
+ """
480
+ (filepaths,) = plc.filling.repeat(
481
+ # TODO: Remove call from_arrow when we support python list to Column
482
+ plc.Table([plc.interop.from_arrow(pa.array(map(str, paths)))]),
483
+ plc.interop.from_arrow(pa.array(rows_per_path, type=pa.int32())),
484
+ ).columns()
485
+ return df.with_columns([Column(filepaths, name=name)])
486
+
425
487
  @classmethod
426
488
  def do_evaluate(
427
489
  cls,
428
490
  schema: Schema,
429
491
  typ: str,
430
492
  reader_options: dict[str, Any],
431
- config_options: dict[str, Any],
493
+ config_options: ConfigOptions,
432
494
  paths: list[str],
433
495
  with_columns: list[str] | None,
434
496
  skip_rows: int,
435
497
  n_rows: int,
436
498
  row_index: tuple[str, int] | None,
499
+ include_file_paths: str | None,
437
500
  predicate: expr.NamedExpr | None,
438
- ):
501
+ ) -> DataFrame:
439
502
  """Evaluate and return a dataframe."""
440
503
  if typ == "csv":
504
+
505
+ def read_csv_header(
506
+ path: Path | str, sep: str
507
+ ) -> list[str]: # pragma: no cover
508
+ with Path(path).open() as f:
509
+ for line in f:
510
+ stripped = line.strip()
511
+ if stripped:
512
+ return stripped.split(sep)
513
+ return []
514
+
441
515
  parse_options = reader_options["parse_options"]
442
516
  sep = chr(parse_options["separator"])
443
517
  quote = chr(parse_options["quote_char"])
@@ -449,8 +523,8 @@ class Scan(IR):
449
523
  # file provides column names
450
524
  column_names = None
451
525
  usecols = with_columns
452
- # TODO: support has_header=False
453
- header = 0
526
+ has_header = reader_options["has_header"]
527
+ header = 0 if has_header else -1
454
528
 
455
529
  # polars defaults to no null recognition
456
530
  null_values = [""]
@@ -470,6 +544,7 @@ class Scan(IR):
470
544
 
471
545
  # polars skips blank lines at the beginning of the file
472
546
  pieces = []
547
+ seen_paths = []
473
548
  read_partial = n_rows != -1
474
549
  for p in paths:
475
550
  skiprows = reader_options["skip_rows"]
@@ -480,7 +555,9 @@ class Scan(IR):
480
555
  options = (
481
556
  plc.io.csv.CsvReaderOptions.builder(plc.io.SourceInfo([path]))
482
557
  .nrows(n_rows)
483
- .skiprows(skiprows)
558
+ .skiprows(
559
+ skiprows if POLARS_VERSION_LT_128 else skiprows + skip_rows
560
+ ) # pragma: no cover
484
561
  .lineterminator(str(eol))
485
562
  .quotechar(str(quote))
486
563
  .decimal(decimal)
@@ -491,6 +568,13 @@ class Scan(IR):
491
568
  options.set_delimiter(str(sep))
492
569
  if column_names is not None:
493
570
  options.set_names([str(name) for name in column_names])
571
+ else:
572
+ if (
573
+ not POLARS_VERSION_LT_128 and header > -1 and skip_rows > header
574
+ ): # pragma: no cover
575
+ # We need to read the header otherwise we would skip it
576
+ column_names = read_csv_header(path, str(sep))
577
+ options.set_names(column_names)
494
578
  options.set_header(header)
495
579
  options.set_dtypes(schema)
496
580
  if usecols is not None:
@@ -500,6 +584,8 @@ class Scan(IR):
500
584
  options.set_comment(comment)
501
585
  tbl_w_meta = plc.io.csv.read_csv(options)
502
586
  pieces.append(tbl_w_meta)
587
+ if include_file_paths is not None:
588
+ seen_paths.append(p)
503
589
  if read_partial:
504
590
  n_rows -= tbl_w_meta.tbl.num_rows()
505
591
  if n_rows <= 0:
@@ -515,12 +601,26 @@ class Scan(IR):
515
601
  plc.concatenate.concatenate(list(tables)),
516
602
  colnames[0],
517
603
  )
604
+ if include_file_paths is not None:
605
+ df = Scan.add_file_paths(
606
+ include_file_paths,
607
+ seen_paths,
608
+ [t.num_rows() for t in tables],
609
+ df,
610
+ )
518
611
  elif typ == "parquet":
519
- parquet_options = config_options.get("parquet_options", {})
520
- if parquet_options.get("chunked", True):
521
- options = plc.io.parquet.ParquetReaderOptions.builder(
522
- plc.io.SourceInfo(paths)
523
- ).build()
612
+ filters = None
613
+ if predicate is not None and row_index is None:
614
+ # Can't apply filters during read if we have a row index.
615
+ filters = to_parquet_filter(predicate.value)
616
+ options = plc.io.parquet.ParquetReaderOptions.builder(
617
+ plc.io.SourceInfo(paths)
618
+ ).build()
619
+ if with_columns is not None:
620
+ options.set_columns(with_columns)
621
+ if filters is not None:
622
+ options.set_filter(filters)
623
+ if config_options.parquet_options.chunked:
524
624
  # We handle skip_rows != 0 by reading from the
525
625
  # up to n_rows + skip_rows and slicing off the
526
626
  # first skip_rows entries.
@@ -530,21 +630,15 @@ class Scan(IR):
530
630
  nrows = n_rows + skip_rows
531
631
  if nrows > -1:
532
632
  options.set_num_rows(nrows)
533
- if with_columns is not None:
534
- options.set_columns(with_columns)
535
633
  reader = plc.io.parquet.ChunkedParquetReader(
536
634
  options,
537
- chunk_read_limit=parquet_options.get(
538
- "chunk_read_limit", cls.PARQUET_DEFAULT_CHUNK_SIZE
539
- ),
540
- pass_read_limit=parquet_options.get(
541
- "pass_read_limit", cls.PARQUET_DEFAULT_PASS_LIMIT
542
- ),
635
+ chunk_read_limit=config_options.parquet_options.chunk_read_limit,
636
+ pass_read_limit=config_options.parquet_options.pass_read_limit,
543
637
  )
544
- chk = reader.read_chunk()
638
+ chunk = reader.read_chunk()
545
639
  rows_left_to_skip = skip_rows
546
640
 
547
- def slice_skip(tbl: plc.Table):
641
+ def slice_skip(tbl: plc.Table) -> plc.Table:
548
642
  nonlocal rows_left_to_skip
549
643
  if rows_left_to_skip > 0:
550
644
  table_rows = tbl.num_rows()
@@ -556,12 +650,13 @@ class Scan(IR):
556
650
  rows_left_to_skip -= chunk_skip
557
651
  return tbl
558
652
 
559
- tbl = slice_skip(chk.tbl)
653
+ tbl = slice_skip(chunk.tbl)
560
654
  # TODO: Nested column names
561
- names = chk.column_names(include_children=False)
655
+ names = chunk.column_names(include_children=False)
562
656
  concatenated_columns = tbl.columns()
563
657
  while reader.has_next():
564
- tbl = slice_skip(reader.read_chunk().tbl)
658
+ chunk = reader.read_chunk()
659
+ tbl = slice_skip(chunk.tbl)
565
660
 
566
661
  for i in range(tbl.num_columns()):
567
662
  concatenated_columns[i] = plc.concatenate.concatenate(
@@ -574,31 +669,28 @@ class Scan(IR):
574
669
  plc.Table(concatenated_columns),
575
670
  names=names,
576
671
  )
672
+ if include_file_paths is not None:
673
+ df = Scan.add_file_paths(
674
+ include_file_paths, paths, chunk.num_rows_per_source, df
675
+ )
577
676
  else:
578
- filters = None
579
- if predicate is not None and row_index is None:
580
- # Can't apply filters during read if we have a row index.
581
- filters = to_parquet_filter(predicate.value)
582
- options = plc.io.parquet.ParquetReaderOptions.builder(
583
- plc.io.SourceInfo(paths)
584
- ).build()
585
677
  if n_rows != -1:
586
678
  options.set_num_rows(n_rows)
587
679
  if skip_rows != 0:
588
680
  options.set_skip_rows(skip_rows)
589
- if with_columns is not None:
590
- options.set_columns(with_columns)
591
- if filters is not None:
592
- options.set_filter(filters)
593
681
  tbl_w_meta = plc.io.parquet.read_parquet(options)
594
682
  df = DataFrame.from_table(
595
683
  tbl_w_meta.tbl,
596
684
  # TODO: consider nested column names?
597
685
  tbl_w_meta.column_names(include_children=False),
598
686
  )
599
- if filters is not None:
600
- # Mask must have been applied.
601
- return df
687
+ if include_file_paths is not None:
688
+ df = Scan.add_file_paths(
689
+ include_file_paths, paths, tbl_w_meta.num_rows_per_source, df
690
+ )
691
+ if filters is not None:
692
+ # Mask must have been applied.
693
+ return df
602
694
 
603
695
  elif typ == "ndjson":
604
696
  json_schema: list[plc.io.json.NameAndType] = [
@@ -629,20 +721,18 @@ class Scan(IR):
629
721
  name, offset = row_index
630
722
  offset += skip_rows
631
723
  dtype = schema[name]
632
- step = plc.interop.from_arrow(
633
- pa.scalar(1, type=plc.interop.to_arrow(dtype))
634
- )
635
- init = plc.interop.from_arrow(
636
- pa.scalar(offset, type=plc.interop.to_arrow(dtype))
637
- )
638
- index = Column(
724
+ step = plc.Scalar.from_py(1, dtype)
725
+ init = plc.Scalar.from_py(offset, dtype)
726
+ index_col = Column(
639
727
  plc.filling.sequence(df.num_rows, init, step),
640
728
  is_sorted=plc.types.Sorted.YES,
641
729
  order=plc.types.Order.ASCENDING,
642
730
  null_order=plc.types.NullOrder.AFTER,
643
731
  name=name,
644
732
  )
645
- df = DataFrame([index, *df.columns])
733
+ df = DataFrame([index_col, *df.columns])
734
+ if next(iter(schema)) != name:
735
+ df = df.select(schema)
646
736
  assert all(c.obj.type() == schema[name] for name, c in df.column_map.items())
647
737
  if predicate is None:
648
738
  return df
@@ -651,6 +741,193 @@ class Scan(IR):
651
741
  return df.filter(mask)
652
742
 
653
743
 
744
+ class Sink(IR):
745
+ """Sink a dataframe to a file."""
746
+
747
+ __slots__ = ("cloud_options", "kind", "options", "path")
748
+ _non_child = ("schema", "kind", "path", "options", "cloud_options")
749
+
750
+ kind: str
751
+ path: str
752
+ options: dict[str, Any]
753
+
754
+ def __init__(
755
+ self,
756
+ schema: Schema,
757
+ kind: str,
758
+ path: str,
759
+ options: dict[str, Any],
760
+ cloud_options: dict[str, Any],
761
+ df: IR,
762
+ ):
763
+ self.schema = schema
764
+ self.kind = kind
765
+ self.path = path
766
+ self.options = options
767
+ self.cloud_options = cloud_options
768
+ self.children = (df,)
769
+ self._non_child_args = (schema, kind, path, options)
770
+ if self.cloud_options is not None and any(
771
+ self.cloud_options.get(k) is not None
772
+ for k in ("config", "credential_provider")
773
+ ):
774
+ raise NotImplementedError(
775
+ "Write to cloud storage"
776
+ ) # pragma: no cover; no test yet
777
+ sync_on_close = options.get("sync_on_close")
778
+ if sync_on_close not in {"None", None}:
779
+ raise NotImplementedError(
780
+ f"sync_on_close='{sync_on_close}' is not supported."
781
+ ) # pragma: no cover; no test yet
782
+ child_schema = df.schema.values()
783
+ if kind == "Csv":
784
+ if not all(
785
+ plc.io.csv.is_supported_write_csv(dtype) for dtype in child_schema
786
+ ):
787
+ # Nested types are unsupported in polars and libcudf
788
+ raise NotImplementedError(
789
+ "Contains unsupported types for CSV writing"
790
+ ) # pragma: no cover
791
+ serialize = options["serialize_options"]
792
+ if options["include_bom"]:
793
+ raise NotImplementedError("include_bom is not supported.")
794
+ for key in (
795
+ "date_format",
796
+ "time_format",
797
+ "datetime_format",
798
+ "float_scientific",
799
+ "float_precision",
800
+ ):
801
+ if serialize[key] is not None:
802
+ raise NotImplementedError(f"{key} is not supported.")
803
+ if serialize["quote_style"] != "Necessary":
804
+ raise NotImplementedError("Only quote_style='Necessary' is supported.")
805
+ if chr(serialize["quote_char"]) != '"':
806
+ raise NotImplementedError("Only quote_char='\"' is supported.")
807
+ elif kind == "Parquet":
808
+ compression = options["compression"]
809
+ if isinstance(compression, dict):
810
+ if len(compression) != 1:
811
+ raise NotImplementedError(
812
+ "Compression dict with more than one entry."
813
+ ) # pragma: no cover
814
+ compression, compression_level = next(iter(compression.items()))
815
+ options["compression"] = compression
816
+ if compression_level is not None:
817
+ raise NotImplementedError(
818
+ "Setting compression_level is not supported."
819
+ )
820
+ if compression == "Lz4Raw":
821
+ compression = "Lz4"
822
+ options["compression"] = compression
823
+ if (
824
+ compression != "Uncompressed"
825
+ and not plc.io.parquet.is_supported_write_parquet(
826
+ getattr(plc.io.types.CompressionType, compression.upper())
827
+ )
828
+ ):
829
+ raise NotImplementedError(
830
+ f"Compression type '{compression}' is not supported."
831
+ )
832
+ elif (
833
+ kind == "Json"
834
+ ): # pragma: no cover; options are validated on the polars side
835
+ if not all(
836
+ plc.io.json.is_supported_write_json(dtype) for dtype in child_schema
837
+ ):
838
+ # Nested types are unsupported in polars and libcudf
839
+ raise NotImplementedError(
840
+ "Contains unsupported types for JSON writing"
841
+ ) # pragma: no cover
842
+ shared_writer_options = {"sync_on_close", "maintain_order", "mkdir"}
843
+ if set(options) - shared_writer_options:
844
+ raise NotImplementedError("Unsupported options passed JSON writer.")
845
+ else:
846
+ raise NotImplementedError(
847
+ f"Unhandled sink kind: {kind}"
848
+ ) # pragma: no cover
849
+
850
+ def get_hashable(self) -> Hashable:
851
+ """
852
+ Hashable representation of the node.
853
+
854
+ The option dictionary is serialised for hashing purposes.
855
+ """
856
+ schema_hash = tuple(self.schema.items()) # pragma: no cover
857
+ return (
858
+ type(self),
859
+ schema_hash,
860
+ self.kind,
861
+ self.path,
862
+ json.dumps(self.options),
863
+ json.dumps(self.cloud_options),
864
+ ) # pragma: no cover
865
+
866
+ @classmethod
867
+ def do_evaluate(
868
+ cls,
869
+ schema: Schema,
870
+ kind: str,
871
+ path: str,
872
+ options: dict[str, Any],
873
+ df: DataFrame,
874
+ ) -> DataFrame:
875
+ """Write the dataframe to a file."""
876
+ target = plc.io.SinkInfo([path])
877
+
878
+ if options.get("mkdir", False):
879
+ Path(path).parent.mkdir(parents=True, exist_ok=True)
880
+ if kind == "Csv":
881
+ serialize = options["serialize_options"]
882
+ options = (
883
+ plc.io.csv.CsvWriterOptions.builder(target, df.table)
884
+ .include_header(options["include_header"])
885
+ .names(df.column_names if options["include_header"] else [])
886
+ .na_rep(serialize["null"])
887
+ .line_terminator(serialize["line_terminator"])
888
+ .inter_column_delimiter(chr(serialize["separator"]))
889
+ .build()
890
+ )
891
+ plc.io.csv.write_csv(options)
892
+
893
+ elif kind == "Parquet":
894
+ metadata = plc.io.types.TableInputMetadata(df.table)
895
+ for i, name in enumerate(df.column_names):
896
+ metadata.column_metadata[i].set_name(name)
897
+
898
+ builder = plc.io.parquet.ParquetWriterOptions.builder(target, df.table)
899
+ compression = options["compression"]
900
+ if compression != "Uncompressed":
901
+ builder.compression(
902
+ getattr(plc.io.types.CompressionType, compression.upper())
903
+ )
904
+
905
+ writer_options = builder.metadata(metadata).build()
906
+ if options["data_page_size"] is not None:
907
+ writer_options.set_max_page_size_bytes(options["data_page_size"])
908
+ if options["row_group_size"] is not None:
909
+ writer_options.set_row_group_size_rows(options["row_group_size"])
910
+
911
+ plc.io.parquet.write_parquet(writer_options)
912
+
913
+ elif kind == "Json":
914
+ metadata = plc.io.TableWithMetadata(
915
+ df.table, [(col, []) for col in df.column_names]
916
+ )
917
+ options = (
918
+ plc.io.json.JsonWriterOptions.builder(target, df.table)
919
+ .lines(val=True)
920
+ .na_rep("null")
921
+ .include_nulls(val=True)
922
+ .metadata(metadata)
923
+ .utf8_escaped(val=False)
924
+ .build()
925
+ )
926
+ plc.io.json.write_json(options)
927
+
928
+ return DataFrame([])
929
+
930
+
654
931
  class Cache(IR):
655
932
  """
656
933
  Return a cached plan node.
@@ -658,35 +935,59 @@ class Cache(IR):
658
935
  Used for CSE at the plan level.
659
936
  """
660
937
 
661
- __slots__ = ("key",)
662
- _non_child = ("schema", "key")
938
+ __slots__ = ("key", "refcount")
939
+ _non_child = ("schema", "key", "refcount")
663
940
  key: int
664
941
  """The cache key."""
942
+ refcount: int
943
+ """The number of cache hits."""
665
944
 
666
- def __init__(self, schema: Schema, key: int, value: IR):
945
+ def __init__(self, schema: Schema, key: int, refcount: int, value: IR):
667
946
  self.schema = schema
668
947
  self.key = key
948
+ self.refcount = refcount
669
949
  self.children = (value,)
670
- self._non_child_args = (key,)
950
+ self._non_child_args = (key, refcount)
951
+
952
+ def get_hashable(self) -> Hashable: # noqa: D102
953
+ # Polars arranges that the keys are unique across all cache
954
+ # nodes that reference the same child, so we don't need to
955
+ # hash the child.
956
+ return (type(self), self.key, self.refcount)
957
+
958
+ def is_equal(self, other: Self) -> bool: # noqa: D102
959
+ if self.key == other.key and self.refcount == other.refcount:
960
+ self.children = other.children
961
+ return True
962
+ return False
671
963
 
672
964
  @classmethod
673
965
  def do_evaluate(
674
- cls, key: int, df: DataFrame
966
+ cls, key: int, refcount: int, df: DataFrame
675
967
  ) -> DataFrame: # pragma: no cover; basic evaluation never calls this
676
968
  """Evaluate and return a dataframe."""
677
969
  # Our value has already been computed for us, so let's just
678
970
  # return it.
679
971
  return df
680
972
 
681
- def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
973
+ def evaluate(self, *, cache: CSECache, timer: Timer | None) -> DataFrame:
682
974
  """Evaluate and return a dataframe."""
683
975
  # We must override the recursion scheme because we don't want
684
976
  # to recurse if we're in the cache.
685
977
  try:
686
- return cache[self.key]
978
+ (result, hits) = cache[self.key]
687
979
  except KeyError:
688
980
  (value,) = self.children
689
- return cache.setdefault(self.key, value.evaluate(cache=cache))
981
+ result = value.evaluate(cache=cache, timer=timer)
982
+ cache[self.key] = (result, 0)
983
+ return result
984
+ else:
985
+ hits += 1
986
+ if hits == self.refcount:
987
+ del cache[self.key]
988
+ else:
989
+ cache[self.key] = (result, hits)
990
+ return result
690
991
 
691
992
 
692
993
  class DataFrameScan(IR):
@@ -696,13 +997,13 @@ class DataFrameScan(IR):
696
997
  This typically arises from ``q.collect().lazy()``
697
998
  """
698
999
 
699
- __slots__ = ("config_options", "df", "projection")
1000
+ __slots__ = ("_id_for_hash", "config_options", "df", "projection")
700
1001
  _non_child = ("schema", "df", "projection", "config_options")
701
1002
  df: Any
702
- """Polars LazyFrame object."""
1003
+ """Polars internal PyDataFrame object."""
703
1004
  projection: tuple[str, ...] | None
704
1005
  """List of columns to project out."""
705
- config_options: dict[str, Any]
1006
+ config_options: ConfigOptions
706
1007
  """GPU-specific configuration options"""
707
1008
 
708
1009
  def __init__(
@@ -710,29 +1011,35 @@ class DataFrameScan(IR):
710
1011
  schema: Schema,
711
1012
  df: Any,
712
1013
  projection: Sequence[str] | None,
713
- config_options: dict[str, Any],
1014
+ config_options: ConfigOptions,
714
1015
  ):
715
1016
  self.schema = schema
716
1017
  self.df = df
717
1018
  self.projection = tuple(projection) if projection is not None else None
718
1019
  self.config_options = config_options
719
- self._non_child_args = (schema, df, self.projection)
1020
+ self._non_child_args = (
1021
+ schema,
1022
+ pl.DataFrame._from_pydf(df),
1023
+ self.projection,
1024
+ )
720
1025
  self.children = ()
1026
+ self._id_for_hash = random.randint(0, 2**64 - 1)
721
1027
 
722
1028
  def get_hashable(self) -> Hashable:
723
1029
  """
724
1030
  Hashable representation of the node.
725
1031
 
726
- The (heavy) dataframe object is hashed as its id, so this is
727
- not stable across runs, or repeat instances of the same equal dataframes.
1032
+ The (heavy) dataframe object is not hashed. No two instances of
1033
+ ``DataFrameScan`` will have the same hash, even if they have the
1034
+ same schema, projection, and config options, and data.
728
1035
  """
729
1036
  schema_hash = tuple(self.schema.items())
730
1037
  return (
731
1038
  type(self),
732
1039
  schema_hash,
733
- id(self.df),
1040
+ self._id_for_hash,
734
1041
  self.projection,
735
- json.dumps(self.config_options),
1042
+ self.config_options,
736
1043
  )
737
1044
 
738
1045
  @classmethod
@@ -743,10 +1050,9 @@ class DataFrameScan(IR):
743
1050
  projection: tuple[str, ...] | None,
744
1051
  ) -> DataFrame:
745
1052
  """Evaluate and return a dataframe."""
746
- pdf = pl.DataFrame._from_pydf(df)
747
1053
  if projection is not None:
748
- pdf = pdf.select(projection)
749
- df = DataFrame.from_polars(pdf)
1054
+ df = df.select(projection)
1055
+ df = DataFrame.from_polars(df)
750
1056
  assert all(
751
1057
  c.obj.type() == dtype
752
1058
  for c, dtype in zip(df.columns, schema.values(), strict=True)
@@ -820,29 +1126,191 @@ class Reduce(IR):
820
1126
  ) -> DataFrame: # pragma: no cover; not exposed by polars yet
821
1127
  """Evaluate and return a dataframe."""
822
1128
  columns = broadcast(*(e.evaluate(df) for e in exprs))
823
- assert all(column.obj.size() == 1 for column in columns)
1129
+ assert all(column.size == 1 for column in columns)
824
1130
  return DataFrame(columns)
825
1131
 
826
1132
 
1133
+ class Rolling(IR):
1134
+ """Perform a (possibly grouped) rolling aggregation."""
1135
+
1136
+ __slots__ = (
1137
+ "agg_requests",
1138
+ "closed_window",
1139
+ "following",
1140
+ "index",
1141
+ "keys",
1142
+ "preceding",
1143
+ "zlice",
1144
+ )
1145
+ _non_child = (
1146
+ "schema",
1147
+ "index",
1148
+ "preceding",
1149
+ "following",
1150
+ "closed_window",
1151
+ "keys",
1152
+ "agg_requests",
1153
+ "zlice",
1154
+ )
1155
+ index: expr.NamedExpr
1156
+ """Column being rolled over."""
1157
+ preceding: plc.Scalar
1158
+ """Preceding window extent defining start of window."""
1159
+ following: plc.Scalar
1160
+ """Following window extent defining end of window."""
1161
+ closed_window: ClosedInterval
1162
+ """Treatment of window endpoints."""
1163
+ keys: tuple[expr.NamedExpr, ...]
1164
+ """Grouping keys."""
1165
+ agg_requests: tuple[expr.NamedExpr, ...]
1166
+ """Aggregation expressions."""
1167
+ zlice: Zlice | None
1168
+ """Optional slice"""
1169
+
1170
+ def __init__(
1171
+ self,
1172
+ schema: Schema,
1173
+ index: expr.NamedExpr,
1174
+ preceding: plc.Scalar,
1175
+ following: plc.Scalar,
1176
+ closed_window: ClosedInterval,
1177
+ keys: Sequence[expr.NamedExpr],
1178
+ agg_requests: Sequence[expr.NamedExpr],
1179
+ zlice: Zlice | None,
1180
+ df: IR,
1181
+ ):
1182
+ self.schema = schema
1183
+ self.index = index
1184
+ self.preceding = preceding
1185
+ self.following = following
1186
+ self.closed_window = closed_window
1187
+ self.keys = tuple(keys)
1188
+ self.agg_requests = tuple(agg_requests)
1189
+ if not all(
1190
+ plc.rolling.is_valid_rolling_aggregation(
1191
+ agg.value.dtype, agg.value.agg_request
1192
+ )
1193
+ for agg in self.agg_requests
1194
+ ):
1195
+ raise NotImplementedError("Unsupported rolling aggregation")
1196
+ if any(
1197
+ agg.value.agg_request.kind() == plc.aggregation.Kind.COLLECT_LIST
1198
+ for agg in self.agg_requests
1199
+ ):
1200
+ raise NotImplementedError(
1201
+ "Incorrect handling of empty groups for list collection"
1202
+ )
1203
+
1204
+ self.zlice = zlice
1205
+ self.children = (df,)
1206
+ self._non_child_args = (
1207
+ index,
1208
+ preceding,
1209
+ following,
1210
+ closed_window,
1211
+ keys,
1212
+ agg_requests,
1213
+ zlice,
1214
+ )
1215
+
1216
+ @classmethod
1217
+ def do_evaluate(
1218
+ cls,
1219
+ index: expr.NamedExpr,
1220
+ preceding: plc.Scalar,
1221
+ following: plc.Scalar,
1222
+ closed_window: ClosedInterval,
1223
+ keys_in: Sequence[expr.NamedExpr],
1224
+ aggs: Sequence[expr.NamedExpr],
1225
+ zlice: Zlice | None,
1226
+ df: DataFrame,
1227
+ ) -> DataFrame:
1228
+ """Evaluate and return a dataframe."""
1229
+ keys = broadcast(*(k.evaluate(df) for k in keys_in), target_length=df.num_rows)
1230
+ orderby = index.evaluate(df)
1231
+ # Polars casts integral orderby to int64, but only for calculating window bounds
1232
+ if (
1233
+ plc.traits.is_integral(orderby.obj.type())
1234
+ and orderby.obj.type().id() != plc.TypeId.INT64
1235
+ ):
1236
+ orderby_obj = plc.unary.cast(orderby.obj, plc.DataType(plc.TypeId.INT64))
1237
+ else:
1238
+ orderby_obj = orderby.obj
1239
+ preceding_window, following_window = range_window_bounds(
1240
+ preceding, following, closed_window
1241
+ )
1242
+ if orderby.obj.null_count() != 0:
1243
+ raise RuntimeError(
1244
+ f"Index column '{index.name}' in rolling may not contain nulls"
1245
+ )
1246
+ if len(keys_in) > 0:
1247
+ # Must always check sortedness
1248
+ table = plc.Table([*(k.obj for k in keys), orderby_obj])
1249
+ n = table.num_columns()
1250
+ if not plc.sorting.is_sorted(
1251
+ table, [plc.types.Order.ASCENDING] * n, [plc.types.NullOrder.BEFORE] * n
1252
+ ):
1253
+ raise RuntimeError("Input for grouped rolling is not sorted")
1254
+ else:
1255
+ if not orderby.check_sorted(
1256
+ order=plc.types.Order.ASCENDING, null_order=plc.types.NullOrder.BEFORE
1257
+ ):
1258
+ raise RuntimeError(
1259
+ f"Index column '{index.name}' in rolling is not sorted, please sort first"
1260
+ )
1261
+ values = plc.rolling.grouped_range_rolling_window(
1262
+ plc.Table([k.obj for k in keys]),
1263
+ orderby_obj,
1264
+ plc.types.Order.ASCENDING, # Polars requires ascending orderby.
1265
+ plc.types.NullOrder.BEFORE, # Doesn't matter, polars doesn't allow nulls in orderby
1266
+ preceding_window,
1267
+ following_window,
1268
+ [rolling.to_request(request.value, orderby, df) for request in aggs],
1269
+ )
1270
+ return DataFrame(
1271
+ itertools.chain(
1272
+ keys,
1273
+ [orderby],
1274
+ (
1275
+ Column(col, name=name)
1276
+ for col, name in zip(
1277
+ values.columns(),
1278
+ (request.name for request in aggs),
1279
+ strict=True,
1280
+ )
1281
+ ),
1282
+ )
1283
+ ).slice(zlice)
1284
+
1285
+
827
1286
  class GroupBy(IR):
828
1287
  """Perform a groupby."""
829
1288
 
830
1289
  __slots__ = (
831
- "agg_infos",
832
1290
  "agg_requests",
1291
+ "config_options",
833
1292
  "keys",
834
1293
  "maintain_order",
835
- "options",
1294
+ "zlice",
1295
+ )
1296
+ _non_child = (
1297
+ "schema",
1298
+ "keys",
1299
+ "agg_requests",
1300
+ "maintain_order",
1301
+ "zlice",
1302
+ "config_options",
836
1303
  )
837
- _non_child = ("schema", "keys", "agg_requests", "maintain_order", "options")
838
1304
  keys: tuple[expr.NamedExpr, ...]
839
1305
  """Grouping keys."""
840
1306
  agg_requests: tuple[expr.NamedExpr, ...]
841
1307
  """Aggregation expressions."""
842
1308
  maintain_order: bool
843
1309
  """Preserve order in groupby."""
844
- options: Any
845
- """Arbitrary options."""
1310
+ zlice: Zlice | None
1311
+ """Optional slice to apply after grouping."""
1312
+ config_options: ConfigOptions
1313
+ """GPU-specific configuration options"""
846
1314
 
847
1315
  def __init__(
848
1316
  self,
@@ -850,70 +1318,33 @@ class GroupBy(IR):
850
1318
  keys: Sequence[expr.NamedExpr],
851
1319
  agg_requests: Sequence[expr.NamedExpr],
852
1320
  maintain_order: bool, # noqa: FBT001
853
- options: Any,
1321
+ zlice: Zlice | None,
1322
+ config_options: ConfigOptions,
854
1323
  df: IR,
855
1324
  ):
856
1325
  self.schema = schema
857
1326
  self.keys = tuple(keys)
858
1327
  self.agg_requests = tuple(agg_requests)
859
1328
  self.maintain_order = maintain_order
860
- self.options = options
1329
+ self.zlice = zlice
1330
+ self.config_options = config_options
861
1331
  self.children = (df,)
862
- if self.options.rolling:
863
- raise NotImplementedError(
864
- "rolling window/groupby"
865
- ) # pragma: no cover; rollingwindow constructor has already raised
866
- if self.options.dynamic:
867
- raise NotImplementedError("dynamic group by")
868
- if any(GroupBy.check_agg(a.value) > 1 for a in self.agg_requests):
869
- raise NotImplementedError("Nested aggregations in groupby")
870
- self.agg_infos = [req.collect_agg(depth=0) for req in self.agg_requests]
871
1332
  self._non_child_args = (
872
1333
  self.keys,
873
1334
  self.agg_requests,
874
1335
  maintain_order,
875
- options,
876
- self.agg_infos,
1336
+ self.zlice,
877
1337
  )
878
1338
 
879
- @staticmethod
880
- def check_agg(agg: expr.Expr) -> int:
881
- """
882
- Determine if we can handle an aggregation expression.
883
-
884
- Parameters
885
- ----------
886
- agg
887
- Expression to check
888
-
889
- Returns
890
- -------
891
- depth of nesting
892
-
893
- Raises
894
- ------
895
- NotImplementedError
896
- For unsupported expression nodes.
897
- """
898
- if isinstance(agg, (expr.BinOp, expr.Cast, expr.UnaryFunction)):
899
- return max(GroupBy.check_agg(child) for child in agg.children)
900
- elif isinstance(agg, expr.Agg):
901
- return 1 + max(GroupBy.check_agg(child) for child in agg.children)
902
- elif isinstance(agg, (expr.Len, expr.Col, expr.Literal, expr.LiteralColumn)):
903
- return 0
904
- else:
905
- raise NotImplementedError(f"No handler for {agg=}")
906
-
907
1339
  @classmethod
908
1340
  def do_evaluate(
909
1341
  cls,
910
1342
  keys_in: Sequence[expr.NamedExpr],
911
1343
  agg_requests: Sequence[expr.NamedExpr],
912
1344
  maintain_order: bool, # noqa: FBT001
913
- options: Any,
914
- agg_infos: Sequence[expr.AggInfo],
1345
+ zlice: Zlice | None,
915
1346
  df: DataFrame,
916
- ):
1347
+ ) -> DataFrame:
917
1348
  """Evaluate and return a dataframe."""
918
1349
  keys = broadcast(*(k.evaluate(df) for k in keys_in), target_length=df.num_rows)
919
1350
  sorted = (
@@ -928,32 +1359,38 @@ class GroupBy(IR):
928
1359
  column_order=[k.order for k in keys],
929
1360
  null_precedence=[k.null_order for k in keys],
930
1361
  )
931
- # TODO: uniquify
932
1362
  requests = []
933
- replacements: list[expr.Expr] = []
934
- for info in agg_infos:
935
- for pre_eval, req, rep in info.requests:
936
- if pre_eval is None:
937
- # A count aggregation, doesn't touch the column,
938
- # but we need to have one. Rather than evaluating
939
- # one, just use one of the key columns.
940
- col = keys[0].obj
1363
+ names = []
1364
+ for request in agg_requests:
1365
+ name = request.name
1366
+ value = request.value
1367
+ if isinstance(value, expr.Len):
1368
+ # A count aggregation, we need a column so use a key column
1369
+ col = keys[0].obj
1370
+ elif isinstance(value, expr.Agg):
1371
+ if value.name == "quantile":
1372
+ child = value.children[0]
941
1373
  else:
942
- col = pre_eval.evaluate(df).obj
943
- requests.append(plc.groupby.GroupByRequest(col, [req]))
944
- replacements.append(rep)
1374
+ (child,) = value.children
1375
+ col = child.evaluate(df, context=ExecutionContext.GROUPBY).obj
1376
+ else:
1377
+ # Anything else, we pre-evaluate
1378
+ col = value.evaluate(df, context=ExecutionContext.GROUPBY).obj
1379
+ requests.append(plc.groupby.GroupByRequest(col, [value.agg_request]))
1380
+ names.append(name)
945
1381
  group_keys, raw_tables = grouper.aggregate(requests)
946
- raw_columns: list[Column] = []
947
- for i, table in enumerate(raw_tables):
948
- (column,) = table.columns()
949
- raw_columns.append(Column(column, name=f"tmp{i}"))
950
- mapping = dict(zip(replacements, raw_columns, strict=True))
1382
+ results = [
1383
+ Column(column, name=name)
1384
+ for name, column in zip(
1385
+ names,
1386
+ itertools.chain.from_iterable(t.columns() for t in raw_tables),
1387
+ strict=True,
1388
+ )
1389
+ ]
951
1390
  result_keys = [
952
1391
  Column(grouped_key, name=key.name)
953
1392
  for key, grouped_key in zip(keys, group_keys.columns(), strict=True)
954
1393
  ]
955
- result_subs = DataFrame(raw_columns)
956
- results = [req.evaluate(result_subs, mapping=mapping) for req in agg_requests]
957
1394
  broadcasted = broadcast(*result_keys, *results)
958
1395
  # Handle order preservation of groups
959
1396
  if maintain_order and not sorted:
@@ -996,12 +1433,26 @@ class GroupBy(IR):
996
1433
  ordered_table.columns(), broadcasted, strict=True
997
1434
  )
998
1435
  ]
999
- return DataFrame(broadcasted).slice(options.slice)
1436
+ return DataFrame(broadcasted).slice(zlice)
1000
1437
 
1001
1438
 
1002
1439
  class ConditionalJoin(IR):
1003
1440
  """A conditional inner join of two dataframes on a predicate."""
1004
1441
 
1442
+ class Predicate:
1443
+ """Serializable wrapper for a predicate expression."""
1444
+
1445
+ predicate: expr.Expr
1446
+ ast: plc.expressions.Expression
1447
+
1448
+ def __init__(self, predicate: expr.Expr):
1449
+ self.predicate = predicate
1450
+ self.ast = to_ast(predicate)
1451
+
1452
+ def __reduce__(self) -> tuple[Any, ...]:
1453
+ """Pickle a Predicate object."""
1454
+ return (type(self), (self.predicate,))
1455
+
1005
1456
  __slots__ = ("ast_predicate", "options", "predicate")
1006
1457
  _non_child = ("schema", "predicate", "options")
1007
1458
  predicate: expr.Expr
@@ -1012,7 +1463,7 @@ class ConditionalJoin(IR):
1012
1463
  pl_expr.Operator | Iterable[pl_expr.Operator],
1013
1464
  ],
1014
1465
  bool,
1015
- tuple[int, int] | None,
1466
+ Zlice | None,
1016
1467
  str,
1017
1468
  bool,
1018
1469
  Literal["none", "left", "right", "left_right", "right_left"],
@@ -1020,7 +1471,7 @@ class ConditionalJoin(IR):
1020
1471
  """
1021
1472
  tuple of options:
1022
1473
  - predicates: tuple of ir join type (eg. ie_join) and (In)Equality conditions
1023
- - join_nulls: do nulls compare equal?
1474
+ - nulls_equal: do nulls compare equal?
1024
1475
  - slice: optional slice to perform after joining.
1025
1476
  - suffix: string suffix for right columns if names match
1026
1477
  - coalesce: should key columns be coalesced (only makes sense for outer joins)
@@ -1034,30 +1485,34 @@ class ConditionalJoin(IR):
1034
1485
  self.predicate = predicate
1035
1486
  self.options = options
1036
1487
  self.children = (left, right)
1037
- self.ast_predicate = to_ast(predicate)
1038
- _, join_nulls, zlice, suffix, coalesce, maintain_order = self.options
1488
+ predicate_wrapper = self.Predicate(predicate)
1489
+ _, nulls_equal, zlice, suffix, coalesce, maintain_order = self.options
1039
1490
  # Preconditions from polars
1040
- assert not join_nulls
1491
+ assert not nulls_equal
1041
1492
  assert not coalesce
1042
1493
  assert maintain_order == "none"
1043
- if self.ast_predicate is None:
1494
+ if predicate_wrapper.ast is None:
1044
1495
  raise NotImplementedError(
1045
1496
  f"Conditional join with predicate {predicate}"
1046
1497
  ) # pragma: no cover; polars never delivers expressions we can't handle
1047
- self._non_child_args = (self.ast_predicate, zlice, suffix, maintain_order)
1498
+ self._non_child_args = (predicate_wrapper, zlice, suffix, maintain_order)
1048
1499
 
1049
1500
  @classmethod
1050
1501
  def do_evaluate(
1051
1502
  cls,
1052
- predicate: plc.expressions.Expression,
1053
- zlice: tuple[int, int] | None,
1503
+ predicate_wrapper: Predicate,
1504
+ zlice: Zlice | None,
1054
1505
  suffix: str,
1055
1506
  maintain_order: Literal["none", "left", "right", "left_right", "right_left"],
1056
1507
  left: DataFrame,
1057
1508
  right: DataFrame,
1058
1509
  ) -> DataFrame:
1059
1510
  """Evaluate and return a dataframe."""
1060
- lg, rg = plc.join.conditional_inner_join(left.table, right.table, predicate)
1511
+ lg, rg = plc.join.conditional_inner_join(
1512
+ left.table,
1513
+ right.table,
1514
+ predicate_wrapper.ast,
1515
+ )
1061
1516
  left = DataFrame.from_table(
1062
1517
  plc.copying.gather(
1063
1518
  left.table, lg, plc.copying.OutOfBoundsPolicy.DONT_CHECK
@@ -1084,8 +1539,8 @@ class ConditionalJoin(IR):
1084
1539
  class Join(IR):
1085
1540
  """A join of two dataframes."""
1086
1541
 
1087
- __slots__ = ("left_on", "options", "right_on")
1088
- _non_child = ("schema", "left_on", "right_on", "options")
1542
+ __slots__ = ("config_options", "left_on", "options", "right_on")
1543
+ _non_child = ("schema", "left_on", "right_on", "options", "config_options")
1089
1544
  left_on: tuple[expr.NamedExpr, ...]
1090
1545
  """List of expressions used as keys in the left frame."""
1091
1546
  right_on: tuple[expr.NamedExpr, ...]
@@ -1093,7 +1548,7 @@ class Join(IR):
1093
1548
  options: tuple[
1094
1549
  Literal["Inner", "Left", "Right", "Full", "Semi", "Anti", "Cross"],
1095
1550
  bool,
1096
- tuple[int, int] | None,
1551
+ Zlice | None,
1097
1552
  str,
1098
1553
  bool,
1099
1554
  Literal["none", "left", "right", "left_right", "right_left"],
@@ -1101,12 +1556,14 @@ class Join(IR):
1101
1556
  """
1102
1557
  tuple of options:
1103
1558
  - how: join type
1104
- - join_nulls: do nulls compare equal?
1559
+ - nulls_equal: do nulls compare equal?
1105
1560
  - slice: optional slice to perform after joining.
1106
1561
  - suffix: string suffix for right columns if names match
1107
1562
  - coalesce: should key columns be coalesced (only makes sense for outer joins)
1108
1563
  - maintain_order: which DataFrame row order to preserve, if any
1109
1564
  """
1565
+ config_options: ConfigOptions
1566
+ """GPU-specific configuration options"""
1110
1567
 
1111
1568
  def __init__(
1112
1569
  self,
@@ -1114,6 +1571,7 @@ class Join(IR):
1114
1571
  left_on: Sequence[expr.NamedExpr],
1115
1572
  right_on: Sequence[expr.NamedExpr],
1116
1573
  options: Any,
1574
+ config_options: ConfigOptions,
1117
1575
  left: IR,
1118
1576
  right: IR,
1119
1577
  ):
@@ -1121,6 +1579,7 @@ class Join(IR):
1121
1579
  self.left_on = tuple(left_on)
1122
1580
  self.right_on = tuple(right_on)
1123
1581
  self.options = options
1582
+ self.config_options = config_options
1124
1583
  self.children = (left, right)
1125
1584
  self._non_child_args = (self.left_on, self.right_on, self.options)
1126
1585
  # TODO: Implement maintain_order
@@ -1203,9 +1662,8 @@ class Join(IR):
1203
1662
  left keys, and is stable wrt the right keys. For all other
1204
1663
  joins, there is no order obligation.
1205
1664
  """
1206
- dt = plc.interop.to_arrow(plc.types.SIZE_TYPE)
1207
- init = plc.interop.from_arrow(pa.scalar(0, type=dt))
1208
- step = plc.interop.from_arrow(pa.scalar(1, type=dt))
1665
+ init = plc.Scalar.from_py(0, plc.types.SIZE_TYPE)
1666
+ step = plc.Scalar.from_py(1, plc.types.SIZE_TYPE)
1209
1667
  left_order = plc.copying.gather(
1210
1668
  plc.Table([plc.filling.sequence(left_rows, init, step)]), lg, left_policy
1211
1669
  )
@@ -1227,7 +1685,7 @@ class Join(IR):
1227
1685
  options: tuple[
1228
1686
  Literal["Inner", "Left", "Right", "Full", "Semi", "Anti", "Cross"],
1229
1687
  bool,
1230
- tuple[int, int] | None,
1688
+ Zlice | None,
1231
1689
  str,
1232
1690
  bool,
1233
1691
  Literal["none", "left", "right", "left_right", "right_left"],
@@ -1236,7 +1694,7 @@ class Join(IR):
1236
1694
  right: DataFrame,
1237
1695
  ) -> DataFrame:
1238
1696
  """Evaluate and return a dataframe."""
1239
- how, join_nulls, zlice, suffix, coalesce, _ = options
1697
+ how, nulls_equal, zlice, suffix, coalesce, _ = options
1240
1698
  if how == "Cross":
1241
1699
  # Separate implementation, since cross_join returns the
1242
1700
  # result, not the gather maps
@@ -1264,7 +1722,7 @@ class Join(IR):
1264
1722
  right_on = DataFrame(broadcast(*(e.evaluate(right) for e in right_on_exprs)))
1265
1723
  null_equality = (
1266
1724
  plc.types.NullEquality.EQUAL
1267
- if join_nulls
1725
+ if nulls_equal
1268
1726
  else plc.types.NullEquality.UNEQUAL
1269
1727
  )
1270
1728
  join_fn, left_policy, right_policy = cls._joiners(how)
@@ -1385,7 +1843,7 @@ class Distinct(IR):
1385
1843
  subset: frozenset[str] | None
1386
1844
  """Which columns should be used to define distinctness. If None,
1387
1845
  then all columns are used."""
1388
- zlice: tuple[int, int] | None
1846
+ zlice: Zlice | None
1389
1847
  """Optional slice to apply to the result."""
1390
1848
  stable: bool
1391
1849
  """Should the result maintain ordering."""
@@ -1395,7 +1853,7 @@ class Distinct(IR):
1395
1853
  schema: Schema,
1396
1854
  keep: plc.stream_compaction.DuplicateKeepOption,
1397
1855
  subset: frozenset[str] | None,
1398
- zlice: tuple[int, int] | None,
1856
+ zlice: Zlice | None,
1399
1857
  stable: bool, # noqa: FBT001
1400
1858
  df: IR,
1401
1859
  ):
@@ -1419,10 +1877,10 @@ class Distinct(IR):
1419
1877
  cls,
1420
1878
  keep: plc.stream_compaction.DuplicateKeepOption,
1421
1879
  subset: frozenset[str] | None,
1422
- zlice: tuple[int, int] | None,
1880
+ zlice: Zlice | None,
1423
1881
  stable: bool, # noqa: FBT001
1424
1882
  df: DataFrame,
1425
- ):
1883
+ ) -> DataFrame:
1426
1884
  """Evaluate and return a dataframe."""
1427
1885
  if subset is None:
1428
1886
  indices = list(range(df.num_columns))
@@ -1475,7 +1933,7 @@ class Sort(IR):
1475
1933
  """Null sorting location for each sort key."""
1476
1934
  stable: bool
1477
1935
  """Should the sort be stable?"""
1478
- zlice: tuple[int, int] | None
1936
+ zlice: Zlice | None
1479
1937
  """Optional slice to apply to the result."""
1480
1938
 
1481
1939
  def __init__(
@@ -1485,7 +1943,7 @@ class Sort(IR):
1485
1943
  order: Sequence[plc.types.Order],
1486
1944
  null_order: Sequence[plc.types.NullOrder],
1487
1945
  stable: bool, # noqa: FBT001
1488
- zlice: tuple[int, int] | None,
1946
+ zlice: Zlice | None,
1489
1947
  df: IR,
1490
1948
  ):
1491
1949
  self.schema = schema
@@ -1510,17 +1968,11 @@ class Sort(IR):
1510
1968
  order: Sequence[plc.types.Order],
1511
1969
  null_order: Sequence[plc.types.NullOrder],
1512
1970
  stable: bool, # noqa: FBT001
1513
- zlice: tuple[int, int] | None,
1971
+ zlice: Zlice | None,
1514
1972
  df: DataFrame,
1515
1973
  ) -> DataFrame:
1516
1974
  """Evaluate and return a dataframe."""
1517
1975
  sort_keys = broadcast(*(k.evaluate(df) for k in by), target_length=df.num_rows)
1518
- # TODO: More robust identification here.
1519
- keys_in_result = {
1520
- k.name: i
1521
- for i, k in enumerate(sort_keys)
1522
- if k.name in df.column_map and k.obj is df.column_map[k.name].obj
1523
- }
1524
1976
  do_sort = plc.sorting.stable_sort_by_key if stable else plc.sorting.sort_by_key
1525
1977
  table = do_sort(
1526
1978
  df.table,
@@ -1528,19 +1980,17 @@ class Sort(IR):
1528
1980
  list(order),
1529
1981
  list(null_order),
1530
1982
  )
1531
- columns: list[Column] = []
1532
- for name, c in zip(df.column_map, table.columns(), strict=True):
1533
- column = Column(c, name=name)
1534
- # If a sort key is in the result table, set the sortedness property
1535
- if name in keys_in_result:
1536
- i = keys_in_result[name]
1537
- column = column.set_sorted(
1538
- is_sorted=plc.types.Sorted.YES,
1539
- order=order[i],
1540
- null_order=null_order[i],
1541
- )
1542
- columns.append(column)
1543
- return DataFrame(columns).slice(zlice)
1983
+ result = DataFrame.from_table(table, df.column_names)
1984
+ first_key = sort_keys[0]
1985
+ name = by[0].name
1986
+ first_key_in_result = (
1987
+ name in df.column_map and first_key.obj is df.column_map[name].obj
1988
+ )
1989
+ if first_key_in_result:
1990
+ result.column_map[name].set_sorted(
1991
+ is_sorted=plc.types.Sorted.YES, order=order[0], null_order=null_order[0]
1992
+ )
1993
+ return result.slice(zlice)
1544
1994
 
1545
1995
 
1546
1996
  class Slice(IR):
@@ -1608,6 +2058,42 @@ class Projection(IR):
1608
2058
  return DataFrame(columns)
1609
2059
 
1610
2060
 
2061
+ class MergeSorted(IR):
2062
+ """Merge sorted operation."""
2063
+
2064
+ __slots__ = ("key",)
2065
+ _non_child = ("schema", "key")
2066
+ key: str
2067
+ """Key that is sorted."""
2068
+
2069
+ def __init__(self, schema: Schema, key: str, left: IR, right: IR):
2070
+ assert isinstance(left, Sort)
2071
+ assert isinstance(right, Sort)
2072
+ assert left.order == right.order
2073
+ assert len(left.schema.keys()) <= len(right.schema.keys())
2074
+ self.schema = schema
2075
+ self.key = key
2076
+ self.children = (left, right)
2077
+ self._non_child_args = (key,)
2078
+
2079
+ @classmethod
2080
+ def do_evaluate(cls, key: str, *dfs: DataFrame) -> DataFrame:
2081
+ """Evaluate and return a dataframe."""
2082
+ left, right = dfs
2083
+ right = right.discard_columns(right.column_names_set - left.column_names_set)
2084
+ on_col_left = left.select_columns({key})[0]
2085
+ on_col_right = right.select_columns({key})[0]
2086
+ return DataFrame.from_table(
2087
+ plc.merge.merge(
2088
+ [right.table, left.table],
2089
+ [left.column_names.index(key), right.column_names.index(key)],
2090
+ [on_col_left.order, on_col_right.order],
2091
+ [on_col_left.null_order, on_col_right.null_order],
2092
+ ),
2093
+ left.column_names,
2094
+ )
2095
+
2096
+
1611
2097
  class MapFunction(IR):
1612
2098
  """Apply some function to a dataframe."""
1613
2099
 
@@ -1621,13 +2107,10 @@ class MapFunction(IR):
1621
2107
  _NAMES: ClassVar[frozenset[str]] = frozenset(
1622
2108
  [
1623
2109
  "rechunk",
1624
- # libcudf merge is not stable wrt order of inputs, since
1625
- # it uses a priority queue to manage the tables it produces.
1626
- # See: https://github.com/rapidsai/cudf/issues/16010
1627
- # "merge_sorted",
1628
2110
  "rename",
1629
2111
  "explode",
1630
2112
  "unpivot",
2113
+ "row_index",
1631
2114
  ]
1632
2115
  )
1633
2116
 
@@ -1636,8 +2119,12 @@ class MapFunction(IR):
1636
2119
  self.name = name
1637
2120
  self.options = options
1638
2121
  self.children = (df,)
1639
- if self.name not in MapFunction._NAMES:
1640
- raise NotImplementedError(f"Unhandled map function {self.name}")
2122
+ if (
2123
+ self.name not in MapFunction._NAMES
2124
+ ): # pragma: no cover; need more polars rust functions
2125
+ raise NotImplementedError(
2126
+ f"Unhandled map function {self.name}"
2127
+ ) # pragma: no cover
1641
2128
  if self.name == "explode":
1642
2129
  (to_explode,) = self.options
1643
2130
  if len(to_explode) > 1:
@@ -1674,6 +2161,9 @@ class MapFunction(IR):
1674
2161
  variable_name,
1675
2162
  value_name,
1676
2163
  )
2164
+ elif self.name == "row_index":
2165
+ col_name, offset = options
2166
+ self.options = (col_name, offset)
1677
2167
  self._non_child_args = (schema, name, self.options)
1678
2168
 
1679
2169
  @classmethod
@@ -1739,6 +2229,19 @@ class MapFunction(IR):
1739
2229
  Column(value_column, name=value_name),
1740
2230
  ]
1741
2231
  )
2232
+ elif name == "row_index":
2233
+ col_name, offset = options
2234
+ dtype = schema[col_name]
2235
+ step = plc.Scalar.from_py(1, dtype)
2236
+ init = plc.Scalar.from_py(offset, dtype)
2237
+ index_col = Column(
2238
+ plc.filling.sequence(df.num_rows, init, step),
2239
+ is_sorted=plc.types.Sorted.YES,
2240
+ order=plc.types.Order.ASCENDING,
2241
+ null_order=plc.types.NullOrder.AFTER,
2242
+ name=col_name,
2243
+ )
2244
+ return DataFrame([index_col, *df.columns])
1742
2245
  else:
1743
2246
  raise AssertionError("Should never be reached") # pragma: no cover
1744
2247
 
@@ -1748,10 +2251,10 @@ class Union(IR):
1748
2251
 
1749
2252
  __slots__ = ("zlice",)
1750
2253
  _non_child = ("schema", "zlice")
1751
- zlice: tuple[int, int] | None
2254
+ zlice: Zlice | None
1752
2255
  """Optional slice to apply to the result."""
1753
2256
 
1754
- def __init__(self, schema: Schema, zlice: tuple[int, int] | None, *children: IR):
2257
+ def __init__(self, schema: Schema, zlice: Zlice | None, *children: IR):
1755
2258
  self.schema = schema
1756
2259
  self.zlice = zlice
1757
2260
  self._non_child_args = (zlice,)
@@ -1759,7 +2262,7 @@ class Union(IR):
1759
2262
  schema = self.children[0].schema
1760
2263
 
1761
2264
  @classmethod
1762
- def do_evaluate(cls, zlice: tuple[int, int] | None, *dfs: DataFrame) -> DataFrame:
2265
+ def do_evaluate(cls, zlice: Zlice | None, *dfs: DataFrame) -> DataFrame:
1763
2266
  """Evaluate and return a dataframe."""
1764
2267
  # TODO: only evaluate what we need if we have a slice?
1765
2268
  return DataFrame.from_table(
@@ -1771,12 +2274,18 @@ class Union(IR):
1771
2274
  class HConcat(IR):
1772
2275
  """Concatenate dataframes horizontally."""
1773
2276
 
1774
- __slots__ = ()
1775
- _non_child = ("schema",)
2277
+ __slots__ = ("should_broadcast",)
2278
+ _non_child = ("schema", "should_broadcast")
1776
2279
 
1777
- def __init__(self, schema: Schema, *children: IR):
2280
+ def __init__(
2281
+ self,
2282
+ schema: Schema,
2283
+ should_broadcast: bool, # noqa: FBT001
2284
+ *children: IR,
2285
+ ):
1778
2286
  self.schema = schema
1779
- self._non_child_args = ()
2287
+ self.should_broadcast = should_broadcast
2288
+ self._non_child_args = (should_broadcast,)
1780
2289
  self.children = children
1781
2290
 
1782
2291
  @staticmethod
@@ -1808,8 +2317,19 @@ class HConcat(IR):
1808
2317
  )
1809
2318
 
1810
2319
  @classmethod
1811
- def do_evaluate(cls, *dfs: DataFrame) -> DataFrame:
2320
+ def do_evaluate(
2321
+ cls,
2322
+ should_broadcast: bool, # noqa: FBT001
2323
+ *dfs: DataFrame,
2324
+ ) -> DataFrame:
1812
2325
  """Evaluate and return a dataframe."""
2326
+ # Special should_broadcast case.
2327
+ # Used to recombine decomposed expressions
2328
+ if should_broadcast:
2329
+ return DataFrame(
2330
+ broadcast(*itertools.chain.from_iterable(df.columns for df in dfs))
2331
+ )
2332
+
1813
2333
  max_rows = max(df.num_rows for df in dfs)
1814
2334
  # Horizontal concatenation extends shorter tables with nulls
1815
2335
  return DataFrame(
@@ -1826,3 +2346,20 @@ class HConcat(IR):
1826
2346
  )
1827
2347
  )
1828
2348
  )
2349
+
2350
+
2351
+ class Empty(IR):
2352
+ """Represents an empty DataFrame."""
2353
+
2354
+ __slots__ = ()
2355
+ _non_child = ()
2356
+
2357
+ def __init__(self) -> None:
2358
+ self.schema = {}
2359
+ self._non_child_args = ()
2360
+ self.children = ()
2361
+
2362
+ @classmethod
2363
+ def do_evaluate(cls) -> DataFrame: # pragma: no cover
2364
+ """Evaluate and return a dataframe."""
2365
+ return DataFrame([])