cudf-polars-cu12 25.2.2-py3-none-any.whl → 25.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. cudf_polars/VERSION +1 -1
  2. cudf_polars/callback.py +85 -53
  3. cudf_polars/containers/column.py +100 -7
  4. cudf_polars/containers/dataframe.py +16 -24
  5. cudf_polars/dsl/expr.py +3 -1
  6. cudf_polars/dsl/expressions/aggregation.py +3 -3
  7. cudf_polars/dsl/expressions/binaryop.py +2 -2
  8. cudf_polars/dsl/expressions/boolean.py +4 -4
  9. cudf_polars/dsl/expressions/datetime.py +39 -1
  10. cudf_polars/dsl/expressions/literal.py +3 -9
  11. cudf_polars/dsl/expressions/selection.py +2 -2
  12. cudf_polars/dsl/expressions/slicing.py +53 -0
  13. cudf_polars/dsl/expressions/sorting.py +1 -1
  14. cudf_polars/dsl/expressions/string.py +4 -4
  15. cudf_polars/dsl/expressions/unary.py +3 -2
  16. cudf_polars/dsl/ir.py +222 -93
  17. cudf_polars/dsl/nodebase.py +8 -1
  18. cudf_polars/dsl/translate.py +66 -38
  19. cudf_polars/experimental/base.py +18 -12
  20. cudf_polars/experimental/dask_serialize.py +22 -8
  21. cudf_polars/experimental/groupby.py +346 -0
  22. cudf_polars/experimental/io.py +13 -11
  23. cudf_polars/experimental/join.py +318 -0
  24. cudf_polars/experimental/parallel.py +57 -6
  25. cudf_polars/experimental/shuffle.py +194 -0
  26. cudf_polars/testing/plugin.py +23 -34
  27. cudf_polars/typing/__init__.py +33 -2
  28. cudf_polars/utils/config.py +138 -0
  29. cudf_polars/utils/conversion.py +40 -0
  30. cudf_polars/utils/dtypes.py +14 -4
  31. cudf_polars/utils/timer.py +39 -0
  32. cudf_polars/utils/versions.py +4 -3
  33. {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.4.0.dist-info}/METADATA +8 -7
  34. cudf_polars_cu12-25.4.0.dist-info/RECORD +55 -0
  35. {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.4.0.dist-info}/WHEEL +1 -1
  36. cudf_polars_cu12-25.2.2.dist-info/RECORD +0 -48
  37. {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.4.0.dist-info/licenses}/LICENSE +0 -0
  38. {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.4.0.dist-info}/top_level.txt +0 -0
cudf_polars/dsl/ir.py CHANGED
@@ -15,6 +15,8 @@ from __future__ import annotations
 
  import itertools
  import json
+ import random
+ import time
  from functools import cache
  from pathlib import Path
  from typing import TYPE_CHECKING, Any, ClassVar
@@ -38,7 +40,9 @@ if TYPE_CHECKING:
 
  from polars.polars import _expr_nodes as pl_expr
 
- from cudf_polars.typing import Schema
+ from cudf_polars.typing import Schema, Slice as Zlice
+ from cudf_polars.utils.config import ConfigOptions
+ from cudf_polars.utils.timer import Timer
 
 
  __all__ = [
@@ -100,7 +104,7 @@ def broadcast(*columns: Column, target_length: int | None = None) -> list[Column
  """
  if len(columns) == 0:
  return []
- lengths: set[int] = {column.obj.size() for column in columns}
+ lengths: set[int] = {column.size for column in columns}
  if lengths == {1}:
  if target_length is None:
  return list(columns)
@@ -116,7 +120,7 @@ def broadcast(*columns: Column, target_length: int | None = None) -> list[Column
  )
  return [
  column
- if column.obj.size() != 1
+ if column.size != 1
  else Column(
  plc.Column.from_scalar(column.obj_scalar, nrows),
  is_sorted=plc.types.Sorted.YES,
@@ -181,7 +185,9 @@ class IR(Node["IR"]):
  translation phase should fail earlier.
  """
 
- def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
+ def evaluate(
+ self, *, cache: MutableMapping[int, DataFrame], timer: Timer | None
+ ) -> DataFrame:
  """
  Evaluate the node (recursively) and return a dataframe.
 
@@ -190,6 +196,9 @@ class IR(Node["IR"]):
  cache
  Mapping from cached node ids to constructed DataFrames.
  Used to implement evaluation of the `Cache` node.
+ timer
+ If not None, a Timer object to record timings for the
+ evaluation of the node.
 
  Notes
  -----
@@ -208,10 +217,16 @@ class IR(Node["IR"]):
  If evaluation fails. Ideally this should not occur, since the
  translation phase should fail earlier.
  """
- return self.do_evaluate(
- *self._non_child_args,
- *(child.evaluate(cache=cache) for child in self.children),
- )
+ children = [child.evaluate(cache=cache, timer=timer) for child in self.children]
+ if timer is not None:
+ start = time.monotonic_ns()
+ result = self.do_evaluate(*self._non_child_args, *children)
+ end = time.monotonic_ns()
+ # TODO: Set better names on each class object.
+ timer.store(start, end, type(self).__name__)
+ return result
+ else:
+ return self.do_evaluate(*self._non_child_args, *children)
 
 
  class ErrorNode(IR):
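Note: the new `timer` argument threads an optional timing collector through the whole recursive evaluation. The concrete class ships in the new cudf_polars/utils/timer.py (added in this release but not shown in this hunk), so the snippet below is only a sketch of the interface `evaluate` relies on, namely a `store(start, end, name)` method fed with `time.monotonic_ns()` stamps; it is not the shipped implementation.

    import time

    class SimpleTimer:
        """Stand-in with the minimal interface IR.evaluate() appears to need."""

        def __init__(self) -> None:
            self.timings: list[tuple[int, int, str]] = []

        def store(self, start: int, end: int, name: str) -> None:
            # One (start_ns, end_ns, node_name) record per evaluated IR node.
            self.timings.append((start, end, name))

    timer = SimpleTimer()
    start = time.monotonic_ns()
    # ... node.do_evaluate(...) would run here ...
    timer.store(start, time.monotonic_ns(), "Scan")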
@@ -284,7 +299,7 @@ class Scan(IR):
  """Reader-specific options, as dictionary."""
  cloud_options: dict[str, Any] | None
  """Cloud-related authentication options, currently ignored."""
- config_options: dict[str, Any]
+ config_options: ConfigOptions
  """GPU-specific configuration options"""
  paths: list[str]
  """List of paths to read from."""
@@ -308,7 +323,7 @@ class Scan(IR):
  typ: str,
  reader_options: dict[str, Any],
  cloud_options: dict[str, Any] | None,
- config_options: dict[str, Any],
+ config_options: ConfigOptions,
  paths: list[str],
  with_columns: list[str] | None,
  skip_rows: int,
@@ -413,7 +428,7 @@ class Scan(IR):
  self.typ,
  json.dumps(self.reader_options),
  json.dumps(self.cloud_options),
- json.dumps(self.config_options),
+ self.config_options,
  tuple(self.paths),
  tuple(self.with_columns) if self.with_columns is not None else None,
  self.skip_rows,
@@ -428,7 +443,7 @@ class Scan(IR):
  schema: Schema,
  typ: str,
  reader_options: dict[str, Any],
- config_options: dict[str, Any],
+ config_options: ConfigOptions,
  paths: list[str],
  with_columns: list[str] | None,
  skip_rows: int,
@@ -516,11 +531,18 @@ class Scan(IR):
  colnames[0],
  )
  elif typ == "parquet":
- parquet_options = config_options.get("parquet_options", {})
- if parquet_options.get("chunked", True):
- options = plc.io.parquet.ParquetReaderOptions.builder(
- plc.io.SourceInfo(paths)
- ).build()
+ filters = None
+ if predicate is not None and row_index is None:
+ # Can't apply filters during read if we have a row index.
+ filters = to_parquet_filter(predicate.value)
+ options = plc.io.parquet.ParquetReaderOptions.builder(
+ plc.io.SourceInfo(paths)
+ ).build()
+ if with_columns is not None:
+ options.set_columns(with_columns)
+ if filters is not None:
+ options.set_filter(filters)
+ if config_options.get("parquet_options.chunked", default=True):
  # We handle skip_rows != 0 by reading from the
  # up to n_rows + skip_rows and slicing off the
  # first skip_rows entries.
@@ -530,15 +552,15 @@ class Scan(IR):
  nrows = n_rows + skip_rows
  if nrows > -1:
  options.set_num_rows(nrows)
- if with_columns is not None:
- options.set_columns(with_columns)
  reader = plc.io.parquet.ChunkedParquetReader(
  options,
- chunk_read_limit=parquet_options.get(
- "chunk_read_limit", cls.PARQUET_DEFAULT_CHUNK_SIZE
+ chunk_read_limit=config_options.get(
+ "parquet_options.chunk_read_limit",
+ default=cls.PARQUET_DEFAULT_CHUNK_SIZE,
  ),
- pass_read_limit=parquet_options.get(
- "pass_read_limit", cls.PARQUET_DEFAULT_PASS_LIMIT
+ pass_read_limit=config_options.get(
+ "parquet_options.pass_read_limit",
+ default=cls.PARQUET_DEFAULT_PASS_LIMIT,
  ),
  )
  chk = reader.read_chunk()
@@ -575,30 +597,19 @@ class Scan(IR):
  names=names,
  )
  else:
- filters = None
- if predicate is not None and row_index is None:
- # Can't apply filters during read if we have a row index.
- filters = to_parquet_filter(predicate.value)
- options = plc.io.parquet.ParquetReaderOptions.builder(
- plc.io.SourceInfo(paths)
- ).build()
  if n_rows != -1:
  options.set_num_rows(n_rows)
  if skip_rows != 0:
  options.set_skip_rows(skip_rows)
- if with_columns is not None:
- options.set_columns(with_columns)
- if filters is not None:
- options.set_filter(filters)
  tbl_w_meta = plc.io.parquet.read_parquet(options)
  df = DataFrame.from_table(
  tbl_w_meta.tbl,
  # TODO: consider nested column names?
  tbl_w_meta.column_names(include_children=False),
  )
- if filters is not None:
- # Mask must have been applied.
- return df
+ if filters is not None:
+ # Mask must have been applied.
+ return df
 
  elif typ == "ndjson":
  json_schema: list[plc.io.json.NameAndType] = [
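Note the new access pattern above: reader options are now looked up on a `ConfigOptions` object with dotted keys and a keyword `default` (for example `config_options.get("parquet_options.chunked", default=True)`) instead of chained dict `.get` calls. The real class lives in the new cudf_polars/utils/config.py (138 added lines, not shown in this diff), so the snippet below is only a hedged sketch of what such a dotted-key lookup over nested dictionaries could look like.

    from typing import Any

    class DottedOptions:
        """Hypothetical stand-in for ConfigOptions' dotted-key lookup."""

        def __init__(self, options: dict[str, Any]) -> None:
            self._options = options

        def get(self, path: str, *, default: Any = None) -> Any:
            # Walk "parquet_options.chunked" through nested dictionaries.
            node: Any = self._options
            for part in path.split("."):
                if not isinstance(node, dict) or part not in node:
                    return default
                node = node[part]
            return node

    opts = DottedOptions({"parquet_options": {"chunked": False}})
    assert opts.get("parquet_options.chunked", default=True) is False
    assert opts.get("parquet_options.pass_read_limit", default=0) == 0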
@@ -678,7 +689,9 @@ class Cache(IR):
  # return it.
  return df
 
- def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
+ def evaluate(
+ self, *, cache: MutableMapping[int, DataFrame], timer: Timer | None
+ ) -> DataFrame:
  """Evaluate and return a dataframe."""
  # We must override the recursion scheme because we don't want
  # to recurse if we're in the cache.
@@ -686,7 +699,7 @@ class Cache(IR):
  return cache[self.key]
  except KeyError:
  (value,) = self.children
- return cache.setdefault(self.key, value.evaluate(cache=cache))
+ return cache.setdefault(self.key, value.evaluate(cache=cache, timer=timer))
 
 
  class DataFrameScan(IR):
@@ -696,13 +709,13 @@ class DataFrameScan(IR):
  This typically arises from ``q.collect().lazy()``
  """
 
- __slots__ = ("config_options", "df", "projection")
+ __slots__ = ("_id_for_hash", "config_options", "df", "projection")
  _non_child = ("schema", "df", "projection", "config_options")
  df: Any
- """Polars LazyFrame object."""
+ """Polars internal PyDataFrame object."""
  projection: tuple[str, ...] | None
  """List of columns to project out."""
- config_options: dict[str, Any]
+ config_options: ConfigOptions
  """GPU-specific configuration options"""
 
  def __init__(
@@ -710,29 +723,35 @@ class DataFrameScan(IR):
  schema: Schema,
  df: Any,
  projection: Sequence[str] | None,
- config_options: dict[str, Any],
+ config_options: ConfigOptions,
  ):
  self.schema = schema
  self.df = df
  self.projection = tuple(projection) if projection is not None else None
  self.config_options = config_options
- self._non_child_args = (schema, df, self.projection)
+ self._non_child_args = (
+ schema,
+ pl.DataFrame._from_pydf(df),
+ self.projection,
+ )
  self.children = ()
+ self._id_for_hash = random.randint(0, 2**64 - 1)
 
  def get_hashable(self) -> Hashable:
  """
  Hashable representation of the node.
 
- The (heavy) dataframe object is hashed as its id, so this is
- not stable across runs, or repeat instances of the same equal dataframes.
+ The (heavy) dataframe object is not hashed. No two instances of
+ ``DataFrameScan`` will have the same hash, even if they have the
+ same schema, projection, and config options, and data.
  """
  schema_hash = tuple(self.schema.items())
  return (
  type(self),
  schema_hash,
- id(self.df),
+ self._id_for_hash,
  self.projection,
- json.dumps(self.config_options),
+ self.config_options,
  )
 
  @classmethod
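One plausible reading of the `_id_for_hash` change: `id(self.df)` is only unique for the lifetime of a particular Python object, so it changes whenever the node is rebuilt (for example after a pickle round-trip to a worker), whereas a random 64-bit token drawn once in `__init__` travels with the instance. A small illustration of that difference (not cudf-polars code):

    import pickle
    import random

    class Node:
        def __init__(self) -> None:
            self._id_for_hash = random.randint(0, 2**64 - 1)

    node = Node()
    clone = pickle.loads(pickle.dumps(node))
    assert clone._id_for_hash == node._id_for_hash  # survives the round-trip
    assert id(clone) != id(node)                    # id() does not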
@@ -743,10 +762,9 @@ class DataFrameScan(IR):
  projection: tuple[str, ...] | None,
  ) -> DataFrame:
  """Evaluate and return a dataframe."""
- pdf = pl.DataFrame._from_pydf(df)
  if projection is not None:
- pdf = pdf.select(projection)
- df = DataFrame.from_polars(pdf)
+ df = df.select(projection)
+ df = DataFrame.from_polars(df)
  assert all(
  c.obj.type() == dtype
  for c, dtype in zip(df.columns, schema.values(), strict=True)
@@ -820,29 +838,61 @@ class Reduce(IR):
  ) -> DataFrame: # pragma: no cover; not exposed by polars yet
  """Evaluate and return a dataframe."""
  columns = broadcast(*(e.evaluate(df) for e in exprs))
- assert all(column.obj.size() == 1 for column in columns)
+ assert all(column.size == 1 for column in columns)
  return DataFrame(columns)
 
 
  class GroupBy(IR):
  """Perform a groupby."""
 
+ class AggInfos:
+ """Serializable wrapper for GroupBy aggregation info."""
+
+ agg_requests: Sequence[expr.NamedExpr]
+ agg_infos: Sequence[expr.AggInfo]
+
+ def __init__(self, agg_requests: Sequence[expr.NamedExpr]):
+ self.agg_requests = tuple(agg_requests)
+ self.agg_infos = [req.collect_agg(depth=0) for req in self.agg_requests]
+
+ def __reduce__(self):
+ """Pickle an AggInfos object."""
+ return (type(self), (self.agg_requests,))
+
+ class GroupbyOptions:
+ """Serializable wrapper for polars GroupbyOptions."""
+
+ def __init__(self, polars_groupby_options: Any):
+ self.dynamic = polars_groupby_options.dynamic
+ self.rolling = polars_groupby_options.rolling
+ self.slice = polars_groupby_options.slice
+
  __slots__ = (
  "agg_infos",
  "agg_requests",
+ "config_options",
  "keys",
  "maintain_order",
  "options",
  )
- _non_child = ("schema", "keys", "agg_requests", "maintain_order", "options")
+ _non_child = (
+ "schema",
+ "keys",
+ "agg_requests",
+ "maintain_order",
+ "options",
+ "config_options",
+ )
  keys: tuple[expr.NamedExpr, ...]
  """Grouping keys."""
  agg_requests: tuple[expr.NamedExpr, ...]
  """Aggregation expressions."""
  maintain_order: bool
  """Preserve order in groupby."""
- options: Any
+ options: GroupbyOptions
  """Arbitrary options."""
+ config_options: ConfigOptions
+ """GPU-specific configuration options"""
 
  def __init__(
  self,
@@ -851,13 +901,15 @@ class GroupBy(IR):
  agg_requests: Sequence[expr.NamedExpr],
  maintain_order: bool, # noqa: FBT001
  options: Any,
+ config_options: ConfigOptions,
  df: IR,
  ):
  self.schema = schema
  self.keys = tuple(keys)
  self.agg_requests = tuple(agg_requests)
  self.maintain_order = maintain_order
- self.options = options
+ self.options = self.GroupbyOptions(options)
+ self.config_options = config_options
  self.children = (df,)
  if self.options.rolling:
  raise NotImplementedError(
@@ -867,13 +919,12 @@ class GroupBy(IR):
  raise NotImplementedError("dynamic group by")
  if any(GroupBy.check_agg(a.value) > 1 for a in self.agg_requests):
  raise NotImplementedError("Nested aggregations in groupby")
- self.agg_infos = [req.collect_agg(depth=0) for req in self.agg_requests]
  self._non_child_args = (
  self.keys,
  self.agg_requests,
  maintain_order,
- options,
- self.agg_infos,
+ self.options,
+ self.AggInfos(self.agg_requests),
  )
 
  @staticmethod
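The `AggInfos` wrapper above (and the `Predicate` wrapper added to ConditionalJoin further down) relies on `__reduce__` so that pickling records only the constructor arguments; unpickling re-runs `__init__`, which rebuilds the derived state (the collected agg infos, or the libcudf AST) rather than trying to serialize it. A generic sketch of that pattern, not taken from the package:

    import pickle

    class Wrapper:
        def __init__(self, requests: tuple[str, ...]) -> None:
            self.requests = tuple(requests)
            # Imagine this being expensive or unpicklable (e.g. a libcudf AST).
            self.derived = [r.upper() for r in self.requests]

        def __reduce__(self):
            # Pickle as: call type(self)(self.requests) again on load.
            return (type(self), (self.requests,))

    loaded = pickle.loads(pickle.dumps(Wrapper(("sum", "mean"))))
    assert loaded.derived == ["SUM", "MEAN"]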
@@ -910,8 +961,8 @@ class GroupBy(IR):
  keys_in: Sequence[expr.NamedExpr],
  agg_requests: Sequence[expr.NamedExpr],
  maintain_order: bool, # noqa: FBT001
- options: Any,
- agg_infos: Sequence[expr.AggInfo],
+ options: GroupbyOptions,
+ agg_info_wrapper: AggInfos,
  df: DataFrame,
  ):
  """Evaluate and return a dataframe."""
@@ -931,7 +982,7 @@ class GroupBy(IR):
  # TODO: uniquify
  requests = []
  replacements: list[expr.Expr] = []
- for info in agg_infos:
+ for info in agg_info_wrapper.agg_infos:
  for pre_eval, req, rep in info.requests:
  if pre_eval is None:
  # A count aggregation, doesn't touch the column,
@@ -1002,6 +1053,20 @@
  class ConditionalJoin(IR):
  """A conditional inner join of two dataframes on a predicate."""
 
+ class Predicate:
+ """Serializable wrapper for a predicate expression."""
+
+ predicate: expr.Expr
+ ast: plc.expressions.Expression
+
+ def __init__(self, predicate: expr.Expr):
+ self.predicate = predicate
+ self.ast = to_ast(predicate)
+
+ def __reduce__(self):
+ """Pickle a Predicate object."""
+ return (type(self), (self.predicate,))
+
  __slots__ = ("ast_predicate", "options", "predicate")
  _non_child = ("schema", "predicate", "options")
  predicate: expr.Expr
@@ -1012,7 +1077,7 @@ class ConditionalJoin(IR):
  pl_expr.Operator | Iterable[pl_expr.Operator],
  ],
  bool,
- tuple[int, int] | None,
+ Zlice | None,
  str,
  bool,
  Literal["none", "left", "right", "left_right", "right_left"],
@@ -1020,7 +1085,7 @@ class ConditionalJoin(IR):
  """
  tuple of options:
  - predicates: tuple of ir join type (eg. ie_join) and (In)Equality conditions
- - join_nulls: do nulls compare equal?
+ - nulls_equal: do nulls compare equal?
  - slice: optional slice to perform after joining.
  - suffix: string suffix for right columns if names match
  - coalesce: should key columns be coalesced (only makes sense for outer joins)
@@ -1034,30 +1099,34 @@ class ConditionalJoin(IR):
  self.predicate = predicate
  self.options = options
  self.children = (left, right)
- self.ast_predicate = to_ast(predicate)
- _, join_nulls, zlice, suffix, coalesce, maintain_order = self.options
+ predicate_wrapper = self.Predicate(predicate)
+ _, nulls_equal, zlice, suffix, coalesce, maintain_order = self.options
  # Preconditions from polars
- assert not join_nulls
+ assert not nulls_equal
  assert not coalesce
  assert maintain_order == "none"
- if self.ast_predicate is None:
+ if predicate_wrapper.ast is None:
  raise NotImplementedError(
  f"Conditional join with predicate {predicate}"
  ) # pragma: no cover; polars never delivers expressions we can't handle
- self._non_child_args = (self.ast_predicate, zlice, suffix, maintain_order)
+ self._non_child_args = (predicate_wrapper, zlice, suffix, maintain_order)
 
  @classmethod
  def do_evaluate(
  cls,
- predicate: plc.expressions.Expression,
- zlice: tuple[int, int] | None,
+ predicate_wrapper: Predicate,
+ zlice: Zlice | None,
  suffix: str,
  maintain_order: Literal["none", "left", "right", "left_right", "right_left"],
  left: DataFrame,
  right: DataFrame,
  ) -> DataFrame:
  """Evaluate and return a dataframe."""
- lg, rg = plc.join.conditional_inner_join(left.table, right.table, predicate)
+ lg, rg = plc.join.conditional_inner_join(
+ left.table,
+ right.table,
+ predicate_wrapper.ast,
+ )
  left = DataFrame.from_table(
  plc.copying.gather(
  left.table, lg, plc.copying.OutOfBoundsPolicy.DONT_CHECK
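For context, ConditionalJoin is the node behind inequality joins such as polars' join_where; with the `Predicate` wrapper, the libcudf AST is rebuilt from the stored expression after serialization instead of being shipped directly. A small usage sketch (hypothetical data; assumes a working GPU engine, otherwise polars falls back to the CPU):

    import polars as pl

    left = pl.LazyFrame({"a": [1, 2, 3]})
    right = pl.LazyFrame({"b": [2, 3, 4]})
    q = left.join_where(right, pl.col("a") < pl.col("b"))
    q.collect(engine="gpu")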
@@ -1084,8 +1153,8 @@ class Join(IR):
  class Join(IR):
  """A join of two dataframes."""
 
- __slots__ = ("left_on", "options", "right_on")
- _non_child = ("schema", "left_on", "right_on", "options")
+ __slots__ = ("config_options", "left_on", "options", "right_on")
+ _non_child = ("schema", "left_on", "right_on", "options", "config_options")
  left_on: tuple[expr.NamedExpr, ...]
  """List of expressions used as keys in the left frame."""
  right_on: tuple[expr.NamedExpr, ...]
@@ -1093,7 +1162,7 @@ class Join(IR):
  options: tuple[
  Literal["Inner", "Left", "Right", "Full", "Semi", "Anti", "Cross"],
  bool,
- tuple[int, int] | None,
+ Zlice | None,
  str,
  bool,
  Literal["none", "left", "right", "left_right", "right_left"],
@@ -1101,12 +1170,14 @@ class Join(IR):
  """
  tuple of options:
  - how: join type
- - join_nulls: do nulls compare equal?
+ - nulls_equal: do nulls compare equal?
  - slice: optional slice to perform after joining.
  - suffix: string suffix for right columns if names match
  - coalesce: should key columns be coalesced (only makes sense for outer joins)
  - maintain_order: which DataFrame row order to preserve, if any
  """
+ config_options: ConfigOptions
+ """GPU-specific configuration options"""
 
  def __init__(
  self,
@@ -1114,6 +1185,7 @@ class Join(IR):
  left_on: Sequence[expr.NamedExpr],
  right_on: Sequence[expr.NamedExpr],
  options: Any,
+ config_options: ConfigOptions,
  left: IR,
  right: IR,
  ):
@@ -1121,6 +1193,7 @@ class Join(IR):
  self.left_on = tuple(left_on)
  self.right_on = tuple(right_on)
  self.options = options
+ self.config_options = config_options
  self.children = (left, right)
  self._non_child_args = (self.left_on, self.right_on, self.options)
  # TODO: Implement maintain_order
@@ -1227,7 +1300,7 @@ class Join(IR):
  options: tuple[
  Literal["Inner", "Left", "Right", "Full", "Semi", "Anti", "Cross"],
  bool,
- tuple[int, int] | None,
+ Zlice | None,
  str,
  bool,
  Literal["none", "left", "right", "left_right", "right_left"],
@@ -1236,7 +1309,7 @@ class Join(IR):
  right: DataFrame,
  ) -> DataFrame:
  """Evaluate and return a dataframe."""
- how, join_nulls, zlice, suffix, coalesce, _ = options
+ how, nulls_equal, zlice, suffix, coalesce, _ = options
  if how == "Cross":
  # Separate implementation, since cross_join returns the
  # result, not the gather maps
@@ -1264,7 +1337,7 @@ class Join(IR):
  right_on = DataFrame(broadcast(*(e.evaluate(right) for e in right_on_exprs)))
  null_equality = (
  plc.types.NullEquality.EQUAL
- if join_nulls
+ if nulls_equal
  else plc.types.NullEquality.UNEQUAL
  )
  join_fn, left_policy, right_policy = cls._joiners(how)
@@ -1385,7 +1458,7 @@ class Distinct(IR):
  subset: frozenset[str] | None
  """Which columns should be used to define distinctness. If None,
  then all columns are used."""
- zlice: tuple[int, int] | None
+ zlice: Zlice | None
  """Optional slice to apply to the result."""
  stable: bool
  """Should the result maintain ordering."""
@@ -1395,7 +1468,7 @@ class Distinct(IR):
  schema: Schema,
  keep: plc.stream_compaction.DuplicateKeepOption,
  subset: frozenset[str] | None,
- zlice: tuple[int, int] | None,
+ zlice: Zlice | None,
  stable: bool, # noqa: FBT001
  df: IR,
  ):
@@ -1419,7 +1492,7 @@ class Distinct(IR):
  cls,
  keep: plc.stream_compaction.DuplicateKeepOption,
  subset: frozenset[str] | None,
- zlice: tuple[int, int] | None,
+ zlice: Zlice | None,
  stable: bool, # noqa: FBT001
  df: DataFrame,
  ):
@@ -1475,7 +1548,7 @@ class Sort(IR):
  """Null sorting location for each sort key."""
  stable: bool
  """Should the sort be stable?"""
- zlice: tuple[int, int] | None
+ zlice: Zlice | None
  """Optional slice to apply to the result."""
 
  def __init__(
@@ -1485,7 +1558,7 @@ class Sort(IR):
  order: Sequence[plc.types.Order],
  null_order: Sequence[plc.types.NullOrder],
  stable: bool, # noqa: FBT001
- zlice: tuple[int, int] | None,
+ zlice: Zlice | None,
  df: IR,
  ):
  self.schema = schema
@@ -1510,7 +1583,7 @@ class Sort(IR):
  order: Sequence[plc.types.Order],
  null_order: Sequence[plc.types.NullOrder],
  stable: bool, # noqa: FBT001
- zlice: tuple[int, int] | None,
+ zlice: Zlice | None,
  df: DataFrame,
  ) -> DataFrame:
  """Evaluate and return a dataframe."""
@@ -1608,6 +1681,41 @@ class Projection(IR):
  return DataFrame(columns)
 
 
+ class MergeSorted(IR):
+ """Merge sorted operation."""
+
+ __slots__ = ("key",)
+ _non_child = ("schema", "key")
+ key: str
+ """Key that is sorted."""
+
+ def __init__(self, schema: Schema, key: str, left: IR, right: IR):
+ assert isinstance(left, Sort)
+ assert isinstance(right, Sort)
+ assert left.order == right.order
+ assert len(left.schema.keys()) <= len(right.schema.keys())
+ self.schema = schema
+ self.key = key
+ self.children = (left, right)
+ self._non_child_args = (key,)
+
+ @classmethod
+ def do_evaluate(cls, key: str, *dfs: DataFrame) -> DataFrame:
+ left, right = dfs
+ right = right.discard_columns(right.column_names_set - left.column_names_set)
+ on_col_left = left.select_columns({key})[0]
+ on_col_right = right.select_columns({key})[0]
+ return DataFrame.from_table(
+ plc.merge.merge(
+ [right.table, left.table],
+ [left.column_names.index(key), right.column_names.index(key)],
+ [on_col_left.order, on_col_right.order],
+ [on_col_left.null_order, on_col_right.null_order],
+ ),
+ left.column_names,
+ )
+
+
  class MapFunction(IR):
  """Apply some function to a dataframe."""
 
@@ -1621,13 +1729,10 @@ class MapFunction(IR):
  _NAMES: ClassVar[frozenset[str]] = frozenset(
  [
  "rechunk",
- # libcudf merge is not stable wrt order of inputs, since
- # it uses a priority queue to manage the tables it produces.
- # See: https://github.com/rapidsai/cudf/issues/16010
- # "merge_sorted",
  "rename",
  "explode",
  "unpivot",
+ "row_index",
  ]
  )
 
@@ -1636,8 +1741,12 @@ class MapFunction(IR):
  self.name = name
  self.options = options
  self.children = (df,)
- if self.name not in MapFunction._NAMES:
- raise NotImplementedError(f"Unhandled map function {self.name}")
+ if (
+ self.name not in MapFunction._NAMES
+ ): # pragma: no cover; need more polars rust functions
+ raise NotImplementedError(
+ f"Unhandled map function {self.name}"
+ ) # pragma: no cover
  if self.name == "explode":
  (to_explode,) = self.options
  if len(to_explode) > 1:
@@ -1674,6 +1783,9 @@ class MapFunction(IR):
  variable_name,
  value_name,
  )
+ elif self.name == "row_index":
+ col_name, offset = options
+ self.options = (col_name, offset)
  self._non_child_args = (schema, name, self.options)
 
  @classmethod
@@ -1739,6 +1851,23 @@ class MapFunction(IR):
  Column(value_column, name=value_name),
  ]
  )
+ elif name == "row_index":
+ col_name, offset = options
+ dtype = schema[col_name]
+ step = plc.interop.from_arrow(
+ pa.scalar(1, type=plc.interop.to_arrow(dtype))
+ )
+ init = plc.interop.from_arrow(
+ pa.scalar(offset, type=plc.interop.to_arrow(dtype))
+ )
+ index_col = Column(
+ plc.filling.sequence(df.num_rows, init, step),
+ is_sorted=plc.types.Sorted.YES,
+ order=plc.types.Order.ASCENDING,
+ null_order=plc.types.NullOrder.AFTER,
+ name=col_name,
+ )
+ return DataFrame([index_col, *df.columns])
 
  else:
  raise AssertionError("Should never be reached") # pragma: no cover
@@ -1748,10 +1877,10 @@ class Union(IR):
 
  __slots__ = ("zlice",)
  _non_child = ("schema", "zlice")
- zlice: tuple[int, int] | None
+ zlice: Zlice | None
  """Optional slice to apply to the result."""
 
- def __init__(self, schema: Schema, zlice: tuple[int, int] | None, *children: IR):
+ def __init__(self, schema: Schema, zlice: Zlice | None, *children: IR):
  self.schema = schema
  self.zlice = zlice
  self._non_child_args = (zlice,)
@@ -1759,7 +1888,7 @@ class Union(IR):
  schema = self.children[0].schema
 
  @classmethod
- def do_evaluate(cls, zlice: tuple[int, int] | None, *dfs: DataFrame) -> DataFrame:
+ def do_evaluate(cls, zlice: Zlice | None, *dfs: DataFrame) -> DataFrame:
  """Evaluate and return a dataframe."""
  # TODO: only evaluate what we need if we have a slice?
  return DataFrame.from_table(