cudf-polars-cu12 24.12.0__py3-none-any.whl → 25.2.0__py3-none-any.whl

This diff compares two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
Files changed (37)
  1. cudf_polars/VERSION +1 -1
  2. cudf_polars/__init__.py +1 -1
  3. cudf_polars/callback.py +28 -3
  4. cudf_polars/containers/__init__.py +1 -1
  5. cudf_polars/dsl/expr.py +16 -16
  6. cudf_polars/dsl/expressions/aggregation.py +21 -4
  7. cudf_polars/dsl/expressions/base.py +7 -2
  8. cudf_polars/dsl/expressions/binaryop.py +1 -0
  9. cudf_polars/dsl/expressions/boolean.py +65 -22
  10. cudf_polars/dsl/expressions/datetime.py +82 -20
  11. cudf_polars/dsl/expressions/literal.py +2 -0
  12. cudf_polars/dsl/expressions/rolling.py +3 -1
  13. cudf_polars/dsl/expressions/selection.py +3 -1
  14. cudf_polars/dsl/expressions/sorting.py +2 -0
  15. cudf_polars/dsl/expressions/string.py +118 -39
  16. cudf_polars/dsl/expressions/ternary.py +1 -0
  17. cudf_polars/dsl/expressions/unary.py +11 -1
  18. cudf_polars/dsl/ir.py +173 -122
  19. cudf_polars/dsl/to_ast.py +4 -6
  20. cudf_polars/dsl/translate.py +53 -21
  21. cudf_polars/dsl/traversal.py +10 -10
  22. cudf_polars/experimental/base.py +43 -0
  23. cudf_polars/experimental/dispatch.py +84 -0
  24. cudf_polars/experimental/io.py +325 -0
  25. cudf_polars/experimental/parallel.py +253 -0
  26. cudf_polars/experimental/select.py +36 -0
  27. cudf_polars/testing/asserts.py +14 -5
  28. cudf_polars/testing/plugin.py +60 -4
  29. cudf_polars/typing/__init__.py +5 -5
  30. cudf_polars/utils/dtypes.py +9 -7
  31. cudf_polars/utils/versions.py +4 -7
  32. {cudf_polars_cu12-24.12.0.dist-info → cudf_polars_cu12-25.2.0.dist-info}/METADATA +6 -6
  33. cudf_polars_cu12-25.2.0.dist-info/RECORD +48 -0
  34. {cudf_polars_cu12-24.12.0.dist-info → cudf_polars_cu12-25.2.0.dist-info}/WHEEL +1 -1
  35. cudf_polars_cu12-24.12.0.dist-info/RECORD +0 -43
  36. {cudf_polars_cu12-24.12.0.dist-info → cudf_polars_cu12-25.2.0.dist-info}/LICENSE +0 -0
  37. {cudf_polars_cu12-24.12.0.dist-info → cudf_polars_cu12-25.2.0.dist-info}/top_level.txt +0 -0
cudf_polars/dsl/ir.py CHANGED
@@ -1,4 +1,4 @@
- # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
  # SPDX-License-Identifier: Apache-2.0
  """
  DSL nodes for the LogicalPlan of polars.
@@ -31,35 +31,36 @@ from cudf_polars.containers import Column, DataFrame
  from cudf_polars.dsl.nodebase import Node
  from cudf_polars.dsl.to_ast import to_ast, to_parquet_filter
  from cudf_polars.utils import dtypes
- from cudf_polars.utils.versions import POLARS_VERSION_GT_112

  if TYPE_CHECKING:
-     from collections.abc import Callable, Hashable, MutableMapping, Sequence
+     from collections.abc import Callable, Hashable, Iterable, MutableMapping, Sequence
      from typing import Literal

+     from polars.polars import _expr_nodes as pl_expr
+
      from cudf_polars.typing import Schema


  __all__ = [
      "IR",
-     "ErrorNode",
-     "PythonScan",
-     "Scan",
      "Cache",
-     "DataFrameScan",
-     "Select",
-     "GroupBy",
-     "Join",
      "ConditionalJoin",
-     "HStack",
+     "DataFrameScan",
      "Distinct",
-     "Sort",
-     "Slice",
+     "ErrorNode",
      "Filter",
-     "Projection",
+     "GroupBy",
+     "HConcat",
+     "HStack",
+     "Join",
      "MapFunction",
+     "Projection",
+     "PythonScan",
+     "Scan",
+     "Select",
+     "Slice",
+     "Sort",
      "Union",
-     "HConcat",
  ]

@@ -130,7 +131,7 @@ def broadcast(*columns: Column, target_length: int | None = None) -> list[Column
  class IR(Node["IR"]):
      """Abstract plan node, representing an unevaluated dataframe."""

-     __slots__ = ("schema", "_non_child_args")
+     __slots__ = ("_non_child_args", "schema")
      # This annotation is needed because of https://github.com/python/mypy/issues/17981
      _non_child: ClassVar[tuple[str, ...]] = ("schema",)
      # Concrete classes should set this up with the arguments that will
@@ -253,16 +254,16 @@ class Scan(IR):
      """Input from files."""

      __slots__ = (
-         "typ",
-         "reader_options",
          "cloud_options",
          "config_options",
-         "paths",
-         "with_columns",
-         "skip_rows",
          "n_rows",
-         "row_index",
+         "paths",
          "predicate",
+         "reader_options",
+         "row_index",
+         "skip_rows",
+         "typ",
+         "with_columns",
      )
      _non_child = (
          "schema",
@@ -476,23 +477,28 @@ class Scan(IR):
                  with path.open() as f:
                      while f.readline() == "\n":
                          skiprows += 1
-                 tbl_w_meta = plc.io.csv.read_csv(
-                     plc.io.SourceInfo([path]),
-                     delimiter=sep,
-                     quotechar=quote,
-                     lineterminator=eol,
-                     col_names=column_names,
-                     header=header,
-                     usecols=usecols,
-                     na_filter=True,
-                     na_values=null_values,
-                     keep_default_na=False,
-                     skiprows=skiprows,
-                     comment=comment,
-                     decimal=decimal,
-                     dtypes=schema,
-                     nrows=n_rows,
+                 options = (
+                     plc.io.csv.CsvReaderOptions.builder(plc.io.SourceInfo([path]))
+                     .nrows(n_rows)
+                     .skiprows(skiprows)
+                     .lineterminator(str(eol))
+                     .quotechar(str(quote))
+                     .decimal(decimal)
+                     .keep_default_na(keep_default_na=False)
+                     .na_filter(na_filter=True)
+                     .build()
                  )
+                 options.set_delimiter(str(sep))
+                 if column_names is not None:
+                     options.set_names([str(name) for name in column_names])
+                 options.set_header(header)
+                 options.set_dtypes(schema)
+                 if usecols is not None:
+                     options.set_use_cols_names([str(name) for name in usecols])
+                 options.set_na_values(null_values)
+                 if comment is not None:
+                     options.set_comment(comment)
+                 tbl_w_meta = plc.io.csv.read_csv(options)
                  pieces.append(tbl_w_meta)
                  if read_partial:
                      n_rows -= tbl_w_meta.tbl.num_rows()
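
Note: the CSV path above switches from keyword arguments on plc.io.csv.read_csv to pylibcudf's CsvReaderOptions builder. A minimal sketch of the new pattern, using only the builder and setter calls that appear in this hunk (the file name and column names are hypothetical):

    import pylibcudf as plc

    # Build base options from a source, then tweak via setters, as the
    # diff does. "data.csv" and the column names are made-up examples.
    options = (
        plc.io.csv.CsvReaderOptions.builder(plc.io.SourceInfo(["data.csv"]))
        .nrows(100)        # cap the number of rows read
        .skiprows(1)       # skip a leading line
        .lineterminator("\n")
        .quotechar('"')
        .build()
    )
    options.set_delimiter(",")
    options.set_names(["a", "b"])  # explicit column names
    tbl_w_meta = plc.io.csv.read_csv(options)
    print(tbl_w_meta.tbl.num_rows())
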
@@ -512,17 +518,22 @@ class Scan(IR):
          elif typ == "parquet":
              parquet_options = config_options.get("parquet_options", {})
              if parquet_options.get("chunked", True):
+                 options = plc.io.parquet.ParquetReaderOptions.builder(
+                     plc.io.SourceInfo(paths)
+                 ).build()
+                 # We handle skip_rows != 0 by reading from the
+                 # up to n_rows + skip_rows and slicing off the
+                 # first skip_rows entries.
+                 # TODO: Remove this workaround once
+                 # https://github.com/rapidsai/cudf/issues/16186
+                 # is fixed
+                 nrows = n_rows + skip_rows
+                 if nrows > -1:
+                     options.set_num_rows(nrows)
+                 if with_columns is not None:
+                     options.set_columns(with_columns)
                  reader = plc.io.parquet.ChunkedParquetReader(
-                     plc.io.SourceInfo(paths),
-                     columns=with_columns,
-                     # We handle skip_rows != 0 by reading from the
-                     # up to n_rows + skip_rows and slicing off the
-                     # first skip_rows entries.
-                     # TODO: Remove this workaround once
-                     # https://github.com/rapidsai/cudf/issues/16186
-                     # is fixed
-                     nrows=n_rows + skip_rows,
-                     skip_rows=0,
+                     options,
                      chunk_read_limit=parquet_options.get(
                          "chunk_read_limit", cls.PARQUET_DEFAULT_CHUNK_SIZE
                      ),
@@ -568,13 +579,18 @@ class Scan(IR):
              if predicate is not None and row_index is None:
                  # Can't apply filters during read if we have a row index.
                  filters = to_parquet_filter(predicate.value)
-             tbl_w_meta = plc.io.parquet.read_parquet(
-                 plc.io.SourceInfo(paths),
-                 columns=with_columns,
-                 filters=filters,
-                 nrows=n_rows,
-                 skip_rows=skip_rows,
-             )
+             options = plc.io.parquet.ParquetReaderOptions.builder(
+                 plc.io.SourceInfo(paths)
+             ).build()
+             if n_rows != -1:
+                 options.set_num_rows(n_rows)
+             if skip_rows != 0:
+                 options.set_skip_rows(skip_rows)
+             if with_columns is not None:
+                 options.set_columns(with_columns)
+             if filters is not None:
+                 options.set_filter(filters)
+             tbl_w_meta = plc.io.parquet.read_parquet(options)
              df = DataFrame.from_table(
                  tbl_w_meta.tbl,
                  # TODO: consider nested column names?
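
The parquet paths (chunked and non-chunked) move to the same builder pattern: construct a ParquetReaderOptions, then apply setters only when the caller supplied a value. A short sketch under the same caveat (hypothetical path; only setters used in the hunks above):

    import pylibcudf as plc

    options = plc.io.parquet.ParquetReaderOptions.builder(
        plc.io.SourceInfo(["part-0.parquet"])  # hypothetical path
    ).build()
    options.set_num_rows(1000)       # mirrors the n_rows != -1 branch
    options.set_skip_rows(10)        # mirrors the skip_rows != 0 branch
    options.set_columns(["a", "b"])  # mirrors the with_columns branch
    tbl_w_meta = plc.io.parquet.read_parquet(options)

    # The chunked branch passes the same options object to
    # plc.io.parquet.ChunkedParquetReader(options, chunk_read_limit=...).
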
@@ -589,10 +605,12 @@ class Scan(IR):
                  (name, typ, []) for name, typ in schema.items()
              ]
              plc_tbl_w_meta = plc.io.json.read_json(
-                 plc.io.SourceInfo(paths),
-                 lines=True,
-                 dtypes=json_schema,
-                 prune_columns=True,
+                 plc.io.json._setup_json_reader_options(
+                     plc.io.SourceInfo(paths),
+                     lines=True,
+                     dtypes=json_schema,
+                     prune_columns=True,
+                 )
              )
              # TODO: I don't think cudf-polars supports nested types in general right now
              # (but when it does, we should pass child column names from nested columns in)
@@ -609,12 +627,7 @@ class Scan(IR):
              )  # pragma: no cover; post init trips first
          if row_index is not None:
              name, offset = row_index
-             if POLARS_VERSION_GT_112:
-                 # If we sliced away some data from the start, that
-                 # shifts the row index.
-                 # But prior to 1.13, polars had this wrong, so we match behaviour
-                 # https://github.com/pola-rs/polars/issues/19607
-                 offset += skip_rows
+             offset += skip_rows
              dtype = schema[name]
              step = plc.interop.from_arrow(
                  pa.scalar(1, type=plc.interop.to_arrow(dtype))
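
With the POLARS_VERSION_GT_112 guard gone, the row-index offset is always shifted by skip_rows; this release only targets polars versions with the corrected slicing behaviour from pola-rs/polars#19607. A hedged illustration of the semantics being matched (hypothetical file name):

    import polars as pl

    # Slicing away the first 5 rows of a scan that attaches a row index:
    # the slice offset is pushed into the scan as skip_rows, so the "idx"
    # column of the result should start at 5, matching polars >= 1.13.
    q = pl.scan_csv("data.csv", row_index_name="idx").slice(5, 10)
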
@@ -683,27 +696,27 @@ class DataFrameScan(IR):
      This typically arises from ``q.collect().lazy()``
      """

-     __slots__ = ("df", "projection", "predicate")
-     _non_child = ("schema", "df", "projection", "predicate")
+     __slots__ = ("config_options", "df", "projection")
+     _non_child = ("schema", "df", "projection", "config_options")
      df: Any
      """Polars LazyFrame object."""
      projection: tuple[str, ...] | None
      """List of columns to project out."""
-     predicate: expr.NamedExpr | None
-     """Mask to apply."""
+     config_options: dict[str, Any]
+     """GPU-specific configuration options"""

      def __init__(
          self,
          schema: Schema,
          df: Any,
          projection: Sequence[str] | None,
-         predicate: expr.NamedExpr | None,
+         config_options: dict[str, Any],
      ):
          self.schema = schema
          self.df = df
          self.projection = tuple(projection) if projection is not None else None
-         self.predicate = predicate
-         self._non_child_args = (schema, df, self.projection, predicate)
+         self.config_options = config_options
+         self._non_child_args = (schema, df, self.projection)
          self.children = ()

      def get_hashable(self) -> Hashable:
@@ -714,7 +727,13 @@ class DataFrameScan(IR):
          not stable across runs, or repeat instances of the same equal dataframes.
          """
          schema_hash = tuple(self.schema.items())
-         return (type(self), schema_hash, id(self.df), self.projection, self.predicate)
+         return (
+             type(self),
+             schema_hash,
+             id(self.df),
+             self.projection,
+             json.dumps(self.config_options),
+         )

      @classmethod
      def do_evaluate(
@@ -722,7 +741,6 @@ class DataFrameScan(IR):
          schema: Schema,
          df: Any,
          projection: tuple[str, ...] | None,
-         predicate: expr.NamedExpr | None,
      ) -> DataFrame:
          """Evaluate and return a dataframe."""
          pdf = pl.DataFrame._from_pydf(df)
@@ -733,11 +751,7 @@ class DataFrameScan(IR):
              c.obj.type() == dtype
              for c, dtype in zip(df.columns, schema.values(), strict=True)
          )
-         if predicate is not None:
-             (mask,) = broadcast(predicate.evaluate(df), target_length=df.num_rows)
-             return df.filter(mask)
-         else:
-             return df
+         return df


  class Select(IR):
@@ -814,11 +828,11 @@ class GroupBy(IR):
      """Perform a groupby."""

      __slots__ = (
+         "agg_infos",
          "agg_requests",
          "keys",
          "maintain_order",
          "options",
-         "agg_infos",
      )
      _non_child = ("schema", "keys", "agg_requests", "maintain_order", "options")
      keys: tuple[expr.NamedExpr, ...]
@@ -988,10 +1002,30 @@ class GroupBy(IR):
  class ConditionalJoin(IR):
      """A conditional inner join of two dataframes on a predicate."""

-     __slots__ = ("predicate", "options", "ast_predicate")
+     __slots__ = ("ast_predicate", "options", "predicate")
      _non_child = ("schema", "predicate", "options")
      predicate: expr.Expr
-     options: tuple
+     """Expression predicate to join on"""
+     options: tuple[
+         tuple[
+             str,
+             pl_expr.Operator | Iterable[pl_expr.Operator],
+         ],
+         bool,
+         tuple[int, int] | None,
+         str,
+         bool,
+         Literal["none", "left", "right", "left_right", "right_left"],
+     ]
+     """
+     tuple of options:
+     - predicates: tuple of ir join type (eg. ie_join) and (In)Equality conditions
+     - join_nulls: do nulls compare equal?
+     - slice: optional slice to perform after joining.
+     - suffix: string suffix for right columns if names match
+     - coalesce: should key columns be coalesced (only makes sense for outer joins)
+     - maintain_order: which DataFrame row order to preserve, if any
+     """

      def __init__(
          self, schema: Schema, predicate: expr.Expr, options: tuple, left: IR, right: IR
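
ConditionalJoin's options tuple now carries a precise type and docstring, plus the new maintain_order field. At the polars level, this node handles inequality joins; a sketch of a query that should lower to it (made-up frames, assuming polars with join_where and the GPU engine installed):

    import polars as pl

    east = pl.LazyFrame({"id": [1, 2], "dur": [5, 10]})
    west = pl.LazyFrame({"t_id": [10, 20], "time": [7, 9]})
    # join_where is an inner join on a predicate rather than equality
    # keys, which is the case ConditionalJoin covers.
    q = east.join_where(west, pl.col("dur") < pl.col("time"))
    print(q.collect(engine="gpu"))
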
@@ -1001,15 +1035,16 @@ class ConditionalJoin(IR):
          self.options = options
          self.children = (left, right)
          self.ast_predicate = to_ast(predicate)
-         _, join_nulls, zlice, suffix, coalesce = self.options
+         _, join_nulls, zlice, suffix, coalesce, maintain_order = self.options
          # Preconditions from polars
          assert not join_nulls
          assert not coalesce
+         assert maintain_order == "none"
          if self.ast_predicate is None:
              raise NotImplementedError(
                  f"Conditional join with predicate {predicate}"
              )  # pragma: no cover; polars never delivers expressions we can't handle
-         self._non_child_args = (self.ast_predicate, zlice, suffix)
+         self._non_child_args = (self.ast_predicate, zlice, suffix, maintain_order)

      @classmethod
      def do_evaluate(
@@ -1017,6 +1052,7 @@ class ConditionalJoin(IR):
          predicate: plc.expressions.Expression,
          zlice: tuple[int, int] | None,
          suffix: str,
+         maintain_order: Literal["none", "left", "right", "left_right", "right_left"],
          left: DataFrame,
          right: DataFrame,
      ) -> DataFrame:
@@ -1048,18 +1084,19 @@ class ConditionalJoin(IR):
  class Join(IR):
      """A join of two dataframes."""

-     __slots__ = ("left_on", "right_on", "options")
+     __slots__ = ("left_on", "options", "right_on")
      _non_child = ("schema", "left_on", "right_on", "options")
      left_on: tuple[expr.NamedExpr, ...]
      """List of expressions used as keys in the left frame."""
      right_on: tuple[expr.NamedExpr, ...]
      """List of expressions used as keys in the right frame."""
      options: tuple[
-         Literal["inner", "left", "right", "full", "semi", "anti", "cross"],
+         Literal["Inner", "Left", "Right", "Full", "Semi", "Anti", "Cross"],
          bool,
          tuple[int, int] | None,
          str,
          bool,
+         Literal["none", "left", "right", "left_right", "right_left"],
      ]
      """
      tuple of options:
@@ -1068,6 +1105,7 @@ class Join(IR):
      - slice: optional slice to perform after joining.
      - suffix: string suffix for right columns if names match
      - coalesce: should key columns be coalesced (only makes sense for outer joins)
+     - maintain_order: which DataFrame row order to preserve, if any
      """

      def __init__(
@@ -1085,50 +1123,48 @@ class Join(IR):
          self.options = options
          self.children = (left, right)
          self._non_child_args = (self.left_on, self.right_on, self.options)
-         if any(
-             isinstance(e.value, expr.Literal)
-             for e in itertools.chain(self.left_on, self.right_on)
-         ):
-             raise NotImplementedError("Join with literal as join key.")
+         # TODO: Implement maintain_order
+         if options[5] != "none":
+             raise NotImplementedError("maintain_order not implemented yet")

      @staticmethod
      @cache
      def _joiners(
-         how: Literal["inner", "left", "right", "full", "semi", "anti"],
+         how: Literal["Inner", "Left", "Right", "Full", "Semi", "Anti"],
      ) -> tuple[
          Callable, plc.copying.OutOfBoundsPolicy, plc.copying.OutOfBoundsPolicy | None
      ]:
-         if how == "inner":
+         if how == "Inner":
              return (
                  plc.join.inner_join,
                  plc.copying.OutOfBoundsPolicy.DONT_CHECK,
                  plc.copying.OutOfBoundsPolicy.DONT_CHECK,
              )
-         elif how == "left" or how == "right":
+         elif how == "Left" or how == "Right":
              return (
                  plc.join.left_join,
                  plc.copying.OutOfBoundsPolicy.DONT_CHECK,
                  plc.copying.OutOfBoundsPolicy.NULLIFY,
              )
-         elif how == "full":
+         elif how == "Full":
              return (
                  plc.join.full_join,
                  plc.copying.OutOfBoundsPolicy.NULLIFY,
                  plc.copying.OutOfBoundsPolicy.NULLIFY,
              )
-         elif how == "semi":
+         elif how == "Semi":
              return (
                  plc.join.left_semi_join,
                  plc.copying.OutOfBoundsPolicy.DONT_CHECK,
                  None,
              )
-         elif how == "anti":
+         elif how == "Anti":
              return (
                  plc.join.left_anti_join,
                  plc.copying.OutOfBoundsPolicy.DONT_CHECK,
                  None,
              )
-         assert_never(how)
+         assert_never(how)  # pragma: no cover

      @staticmethod
      def _reorder_maps(
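
Two behavioural changes sit in this hunk: literal join keys are no longer rejected, and the new maintain_order option (options[5]) raises NotImplementedError, which cudf-polars normally surfaces as a fallback to the CPU engine rather than a hard failure. A hedged illustration (assumes a polars version with maintain_order support and the GPU engine installed):

    import polars as pl

    left = pl.LazyFrame({"a": [1, 2, 3], "x": [10, 20, 30]})
    right = pl.LazyFrame({"a": [3, 2, 1], "y": [1, 2, 3]})
    # maintain_order != "none" is not implemented on GPU in 25.2.0, so
    # this query should fall back to the CPU engine instead of erroring.
    q = left.join(right, on="a", maintain_order="left")
    print(q.collect(engine="gpu"))
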
@@ -1189,18 +1225,19 @@ class Join(IR):
          left_on_exprs: Sequence[expr.NamedExpr],
          right_on_exprs: Sequence[expr.NamedExpr],
          options: tuple[
-             Literal["inner", "left", "right", "full", "semi", "anti", "cross"],
+             Literal["Inner", "Left", "Right", "Full", "Semi", "Anti", "Cross"],
              bool,
              tuple[int, int] | None,
              str,
              bool,
+             Literal["none", "left", "right", "left_right", "right_left"],
          ],
          left: DataFrame,
          right: DataFrame,
      ) -> DataFrame:
          """Evaluate and return a dataframe."""
-         how, join_nulls, zlice, suffix, coalesce = options
-         if how == "cross":
+         how, join_nulls, zlice, suffix, coalesce, _ = options
+         if how == "Cross":
              # Separate implementation, since cross_join returns the
              # result, not the gather maps
              columns = plc.join.cross_join(left.table, right.table).columns()
@@ -1237,25 +1274,32 @@ class Join(IR):
              table = plc.copying.gather(left.table, lg, left_policy)
              result = DataFrame.from_table(table, left.column_names)
          else:
-             if how == "right":
+             if how == "Right":
                  # Right join is a left join with the tables swapped
                  left, right = right, left
                  left_on, right_on = right_on, left_on
              lg, rg = join_fn(left_on.table, right_on.table, null_equality)
-             if how == "left" or how == "right":
+             if how == "Left" or how == "Right":
                  # Order of left table is preserved
                  lg, rg = cls._reorder_maps(
                      left.num_rows, lg, left_policy, right.num_rows, rg, right_policy
                  )
-             if coalesce and how == "inner":
-                 right = right.discard_columns(right_on.column_names_set)
+             if coalesce:
+                 if how == "Full":
+                     # In this case, keys must be column references,
+                     # possibly with dtype casting. We should use them in
+                     # preference to the columns from the original tables.
+                     left = left.with_columns(left_on.columns, replace_only=True)
+                     right = right.with_columns(right_on.columns, replace_only=True)
+                 else:
+                     right = right.discard_columns(right_on.column_names_set)
              left = DataFrame.from_table(
                  plc.copying.gather(left.table, lg, left_policy), left.column_names
              )
              right = DataFrame.from_table(
                  plc.copying.gather(right.table, rg, right_policy), right.column_names
              )
-             if coalesce and how != "inner":
+             if coalesce and how == "Full":
                  left = left.with_columns(
                      (
                          Column(
1271
1315
  replace_only=True,
1272
1316
  )
1273
1317
  right = right.discard_columns(right_on.column_names_set)
1274
- if how == "right":
1318
+ if how == "Right":
1275
1319
  # Undo the swap for right join before gluing together.
1276
1320
  left, right = right, left
1277
1321
  right = right.rename_columns(
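
The coalesce handling now distinguishes Full joins, where the (possibly cast) key expressions replace the original columns before null-coalescing, from the other coalescing joins, which simply drop the right key columns. At the polars API level this is the case exercised by (made-up frames):

    import polars as pl

    left = pl.LazyFrame({"a": [1, 2]})
    right = pl.LazyFrame({"a": [2, 3]})
    # With how="full" and coalesce=True, the left and right "a" keys are
    # merged into a single column; that is the branch rewritten above.
    print(
        left.join(right, on="a", how="full", coalesce=True)
        .collect(engine="gpu")
    )
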
@@ -1316,7 +1360,9 @@ class HStack(IR):
          """Evaluate and return a dataframe."""
          columns = [c.evaluate(df) for c in exprs]
          if should_broadcast:
-             columns = broadcast(*columns, target_length=df.num_rows)
+             columns = broadcast(
+                 *columns, target_length=df.num_rows if df.num_columns != 0 else None
+             )
          else:
              # Polars ensures this is true, but let's make sure nothing
              # went wrong. In this case, the parent node is a
@@ -1332,7 +1378,7 @@ class HStack(IR):
  class Distinct(IR):
      """Produce a new dataframe with distinct rows."""

-     __slots__ = ("keep", "subset", "zlice", "stable")
+     __slots__ = ("keep", "stable", "subset", "zlice")
      _non_child = ("schema", "keep", "subset", "zlice", "stable")
      keep: plc.stream_compaction.DuplicateKeepOption
      """Which distinct value to keep."""
@@ -1419,7 +1465,7 @@ class Distinct(IR):
  class Sort(IR):
      """Sort a dataframe."""

-     __slots__ = ("by", "order", "null_order", "stable", "zlice")
+     __slots__ = ("by", "null_order", "order", "stable", "zlice")
      _non_child = ("schema", "by", "order", "null_order", "stable", "zlice")
      by: tuple[expr.NamedExpr, ...]
      """Sort keys."""
@@ -1500,7 +1546,7 @@ class Sort(IR):
  class Slice(IR):
      """Slice a dataframe."""

-     __slots__ = ("offset", "length")
+     __slots__ = ("length", "offset")
      _non_child = ("schema", "offset", "length")
      offset: int
      """Start of the slice."""
@@ -1599,13 +1645,15 @@ class MapFunction(IR):
                  # polars requires that all to-explode columns have the
                  # same sub-shapes
                  raise NotImplementedError("Explode with more than one column")
+             self.options = (tuple(to_explode),)
          elif self.name == "rename":
-             old, new, _ = self.options
+             old, new, strict = self.options
              # TODO: perhaps polars should validate renaming in the IR?
              if len(new) != len(set(new)) or (
                  set(new) & (set(df.schema.keys()) - set(old))
              ):
                  raise NotImplementedError("Duplicate new names in rename.")
+             self.options = (tuple(old), tuple(new), strict)
          elif self.name == "unpivot":
              indices, pivotees, variable_name, value_name = self.options
              value_name = "value" if value_name is None else value_name
@@ -1623,13 +1671,15 @@ class MapFunction(IR):
              self.options = (
                  tuple(indices),
                  tuple(pivotees),
-                 (variable_name, schema[variable_name]),
-                 (value_name, schema[value_name]),
+                 variable_name,
+                 value_name,
              )
-         self._non_child_args = (name, self.options)
+         self._non_child_args = (schema, name, self.options)

      @classmethod
-     def do_evaluate(cls, name: str, options: Any, df: DataFrame) -> DataFrame:
+     def do_evaluate(
+         cls, schema: Schema, name: str, options: Any, df: DataFrame
+     ) -> DataFrame:
          """Evaluate and return a dataframe."""
          if name == "rechunk":
              # No-op in our data model
@@ -1651,8 +1701,8 @@ class MapFunction(IR):
              (
                  indices,
                  pivotees,
-                 (variable_name, variable_dtype),
-                 (value_name, value_dtype),
+                 variable_name,
+                 value_name,
              ) = options
              npiv = len(pivotees)
              index_columns = [
@@ -1669,7 +1719,7 @@ class MapFunction(IR):
                  plc.interop.from_arrow(
                      pa.array(
                          pivotees,
-                         type=plc.interop.to_arrow(variable_dtype),
+                         type=plc.interop.to_arrow(schema[variable_name]),
                      ),
                  )
              ]
@@ -1677,7 +1727,10 @@ class MapFunction(IR):
                  df.num_rows,
              ).columns()
              value_column = plc.concatenate.concatenate(
-                 [df.column_map[pivotee].astype(value_dtype).obj for pivotee in pivotees]
+                 [
+                     df.column_map[pivotee].astype(schema[value_name]).obj
+                     for pivotee in pivotees
+                 ]
              )
              return DataFrame(
                  [
@@ -1704,8 +1757,6 @@ class Union(IR):
          self._non_child_args = (zlice,)
          self.children = children
          schema = self.children[0].schema
-         if not all(s.schema == schema for s in self.children[1:]):
-             raise NotImplementedError("Schema mismatch")

      @classmethod
      def do_evaluate(cls, zlice: tuple[int, int] | None, *dfs: DataFrame) -> DataFrame:
cudf_polars/dsl/to_ast.py CHANGED
@@ -8,8 +8,6 @@ from __future__ import annotations
  from functools import partial, reduce, singledispatch
  from typing import TYPE_CHECKING, TypeAlias

- from polars.polars import _expr_nodes as pl_expr
-
  import pylibcudf as plc
  from pylibcudf import expressions as plc_expr

@@ -185,7 +183,7 @@ def _(node: expr.BinOp, self: Transformer) -> plc_expr.Expression:

  @_to_ast.register
  def _(node: expr.BooleanFunction, self: Transformer) -> plc_expr.Expression:
-     if node.name == pl_expr.BooleanFunction.IsIn:
+     if node.name is expr.BooleanFunction.Name.IsIn:
          needles, haystack = node.children
          if isinstance(haystack, expr.LiteralColumn) and len(haystack.value) < 16:
              # 16 is an arbitrary limit
@@ -204,14 +202,14 @@ def _(node: expr.BooleanFunction, self: Transformer) -> plc_expr.Expression:
          raise NotImplementedError(
              f"Parquet filters don't support {node.name} on columns"
          )
-     if node.name == pl_expr.BooleanFunction.IsNull:
+     if node.name is expr.BooleanFunction.Name.IsNull:
          return plc_expr.Operation(plc_expr.ASTOperator.IS_NULL, self(node.children[0]))
-     elif node.name == pl_expr.BooleanFunction.IsNotNull:
+     elif node.name is expr.BooleanFunction.Name.IsNotNull:
          return plc_expr.Operation(
              plc_expr.ASTOperator.NOT,
              plc_expr.Operation(plc_expr.ASTOperator.IS_NULL, self(node.children[0])),
          )
-     elif node.name == pl_expr.BooleanFunction.Not:
+     elif node.name is expr.BooleanFunction.Name.Not:
          return plc_expr.Operation(plc_expr.ASTOperator.NOT, self(node.children[0]))
      raise NotImplementedError(f"AST conversion does not support {node.name}")
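
to_ast now compares against cudf-polars' own expr.BooleanFunction.Name enum with "is", instead of the private polars.polars._expr_nodes enum, removing a runtime dependency on polars internals. Identity comparison works because Python enum members are singletons; a self-contained sketch of the pattern (the Name class here is a stand-in, not the real one):

    from enum import Enum, auto

    class Name(Enum):  # stand-in for expr.BooleanFunction.Name
        IsIn = auto()
        IsNull = auto()
        IsNotNull = auto()
        Not = auto()

    node_name = Name.IsNull
    # Enum members are singletons, so identity comparison is exact and
    # cheap; this mirrors the node.name is expr.BooleanFunction.Name.X
    # checks in the handlers above.
    if node_name is Name.IsNull:
        print("emit IS_NULL AST operation")
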