cudf-polars-cu13 25.10.0 (cudf_polars_cu13-25.10.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. cudf_polars/GIT_COMMIT +1 -0
  2. cudf_polars/VERSION +1 -0
  3. cudf_polars/__init__.py +28 -0
  4. cudf_polars/_version.py +21 -0
  5. cudf_polars/callback.py +318 -0
  6. cudf_polars/containers/__init__.py +13 -0
  7. cudf_polars/containers/column.py +495 -0
  8. cudf_polars/containers/dataframe.py +361 -0
  9. cudf_polars/containers/datatype.py +137 -0
  10. cudf_polars/dsl/__init__.py +8 -0
  11. cudf_polars/dsl/expr.py +66 -0
  12. cudf_polars/dsl/expressions/__init__.py +8 -0
  13. cudf_polars/dsl/expressions/aggregation.py +226 -0
  14. cudf_polars/dsl/expressions/base.py +272 -0
  15. cudf_polars/dsl/expressions/binaryop.py +120 -0
  16. cudf_polars/dsl/expressions/boolean.py +326 -0
  17. cudf_polars/dsl/expressions/datetime.py +271 -0
  18. cudf_polars/dsl/expressions/literal.py +97 -0
  19. cudf_polars/dsl/expressions/rolling.py +643 -0
  20. cudf_polars/dsl/expressions/selection.py +74 -0
  21. cudf_polars/dsl/expressions/slicing.py +46 -0
  22. cudf_polars/dsl/expressions/sorting.py +85 -0
  23. cudf_polars/dsl/expressions/string.py +1002 -0
  24. cudf_polars/dsl/expressions/struct.py +137 -0
  25. cudf_polars/dsl/expressions/ternary.py +49 -0
  26. cudf_polars/dsl/expressions/unary.py +517 -0
  27. cudf_polars/dsl/ir.py +2607 -0
  28. cudf_polars/dsl/nodebase.py +164 -0
  29. cudf_polars/dsl/to_ast.py +359 -0
  30. cudf_polars/dsl/tracing.py +16 -0
  31. cudf_polars/dsl/translate.py +939 -0
  32. cudf_polars/dsl/traversal.py +224 -0
  33. cudf_polars/dsl/utils/__init__.py +8 -0
  34. cudf_polars/dsl/utils/aggregations.py +481 -0
  35. cudf_polars/dsl/utils/groupby.py +98 -0
  36. cudf_polars/dsl/utils/naming.py +34 -0
  37. cudf_polars/dsl/utils/replace.py +61 -0
  38. cudf_polars/dsl/utils/reshape.py +74 -0
  39. cudf_polars/dsl/utils/rolling.py +121 -0
  40. cudf_polars/dsl/utils/windows.py +192 -0
  41. cudf_polars/experimental/__init__.py +8 -0
  42. cudf_polars/experimental/base.py +386 -0
  43. cudf_polars/experimental/benchmarks/__init__.py +4 -0
  44. cudf_polars/experimental/benchmarks/pdsds.py +220 -0
  45. cudf_polars/experimental/benchmarks/pdsds_queries/__init__.py +4 -0
  46. cudf_polars/experimental/benchmarks/pdsds_queries/q1.py +88 -0
  47. cudf_polars/experimental/benchmarks/pdsds_queries/q10.py +225 -0
  48. cudf_polars/experimental/benchmarks/pdsds_queries/q2.py +244 -0
  49. cudf_polars/experimental/benchmarks/pdsds_queries/q3.py +65 -0
  50. cudf_polars/experimental/benchmarks/pdsds_queries/q4.py +359 -0
  51. cudf_polars/experimental/benchmarks/pdsds_queries/q5.py +462 -0
  52. cudf_polars/experimental/benchmarks/pdsds_queries/q6.py +92 -0
  53. cudf_polars/experimental/benchmarks/pdsds_queries/q7.py +79 -0
  54. cudf_polars/experimental/benchmarks/pdsds_queries/q8.py +524 -0
  55. cudf_polars/experimental/benchmarks/pdsds_queries/q9.py +137 -0
  56. cudf_polars/experimental/benchmarks/pdsh.py +814 -0
  57. cudf_polars/experimental/benchmarks/utils.py +832 -0
  58. cudf_polars/experimental/dask_registers.py +200 -0
  59. cudf_polars/experimental/dispatch.py +156 -0
  60. cudf_polars/experimental/distinct.py +197 -0
  61. cudf_polars/experimental/explain.py +157 -0
  62. cudf_polars/experimental/expressions.py +590 -0
  63. cudf_polars/experimental/groupby.py +327 -0
  64. cudf_polars/experimental/io.py +943 -0
  65. cudf_polars/experimental/join.py +391 -0
  66. cudf_polars/experimental/parallel.py +423 -0
  67. cudf_polars/experimental/repartition.py +69 -0
  68. cudf_polars/experimental/scheduler.py +155 -0
  69. cudf_polars/experimental/select.py +188 -0
  70. cudf_polars/experimental/shuffle.py +354 -0
  71. cudf_polars/experimental/sort.py +609 -0
  72. cudf_polars/experimental/spilling.py +151 -0
  73. cudf_polars/experimental/statistics.py +795 -0
  74. cudf_polars/experimental/utils.py +169 -0
  75. cudf_polars/py.typed +0 -0
  76. cudf_polars/testing/__init__.py +8 -0
  77. cudf_polars/testing/asserts.py +448 -0
  78. cudf_polars/testing/io.py +122 -0
  79. cudf_polars/testing/plugin.py +236 -0
  80. cudf_polars/typing/__init__.py +219 -0
  81. cudf_polars/utils/__init__.py +8 -0
  82. cudf_polars/utils/config.py +741 -0
  83. cudf_polars/utils/conversion.py +40 -0
  84. cudf_polars/utils/dtypes.py +118 -0
  85. cudf_polars/utils/sorting.py +53 -0
  86. cudf_polars/utils/timer.py +39 -0
  87. cudf_polars/utils/versions.py +27 -0
  88. cudf_polars_cu13-25.10.0.dist-info/METADATA +136 -0
  89. cudf_polars_cu13-25.10.0.dist-info/RECORD +92 -0
  90. cudf_polars_cu13-25.10.0.dist-info/WHEEL +5 -0
  91. cudf_polars_cu13-25.10.0.dist-info/licenses/LICENSE +201 -0
  92. cudf_polars_cu13-25.10.0.dist-info/top_level.txt +1 -0
cudf_polars/dsl/expressions/aggregation.py
@@ -0,0 +1,226 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+ # TODO: remove need for this
+ # ruff: noqa: D101
+ """DSL nodes for aggregations."""
+
+ from __future__ import annotations
+
+ from functools import partial
+ from typing import TYPE_CHECKING, Any, ClassVar
+
+ import pylibcudf as plc
+
+ from cudf_polars.containers import Column
+ from cudf_polars.dsl.expressions.base import ExecutionContext, Expr
+ from cudf_polars.dsl.expressions.literal import Literal
+
+ if TYPE_CHECKING:
+     from cudf_polars.containers import DataFrame, DataType
+
+ __all__ = ["Agg"]
+
+
+ class Agg(Expr):
+     __slots__ = ("name", "op", "options", "request")
+     _non_child = ("dtype", "name", "options")
+
+     def __init__(
+         self, dtype: DataType, name: str, options: Any, *children: Expr
+     ) -> None:
+         self.dtype = dtype
+         self.name = name
+         self.options = options
+         self.is_pointwise = False
+         self.children = children
+         if name not in Agg._SUPPORTED:
+             raise NotImplementedError(
+                 f"Unsupported aggregation {name=}"
+             )  # pragma: no cover; all valid aggs are supported
+         # TODO: nan handling in groupby case
+         if name == "min":
+             req = plc.aggregation.min()
+         elif name == "max":
+             req = plc.aggregation.max()
+         elif name == "median":
+             req = plc.aggregation.median()
+         elif name == "n_unique":
+             # TODO: datatype of result
+             req = plc.aggregation.nunique(null_handling=plc.types.NullPolicy.INCLUDE)
+         elif name == "first" or name == "last":
+             req = None
+         elif name == "mean":
+             req = plc.aggregation.mean()
+         elif name == "sum":
+             req = plc.aggregation.sum()
+         elif name == "std":
+             # TODO: handle nans
+             req = plc.aggregation.std(ddof=options)
+         elif name == "var":
+             # TODO: handle nans
+             req = plc.aggregation.variance(ddof=options)
+         elif name == "count":
+             req = plc.aggregation.count(
+                 null_handling=plc.types.NullPolicy.EXCLUDE
+                 if not options
+                 else plc.types.NullPolicy.INCLUDE
+             )
+         elif name == "quantile":
+             child, quantile = self.children
+             if not isinstance(quantile, Literal):
+                 raise NotImplementedError("Only support literal quantile values")
+             if options == "equiprobable":
+                 raise NotImplementedError("Quantile with equiprobable interpolation")
+             if plc.traits.is_duration(child.dtype.plc):
+                 raise NotImplementedError("Quantile with duration data type")
+             req = plc.aggregation.quantile(
+                 quantiles=[quantile.value], interp=Agg.interp_mapping[options]
+             )
+         else:
+             raise NotImplementedError(
+                 f"Unreachable, {name=} is incorrectly listed in _SUPPORTED"
+             )  # pragma: no cover
+         self.request = req
+         op = getattr(self, f"_{name}", None)
+         if op is None:
+             op = partial(self._reduce, request=req)
+         elif name in {"min", "max"}:
+             op = partial(op, propagate_nans=options)
+         elif name == "count":
+             op = partial(op, include_nulls=options)
+         elif name in {"sum", "first", "last"}:
+             pass
+         else:
+             raise NotImplementedError(
+                 f"Unreachable, supported agg {name=} has no implementation"
+             )  # pragma: no cover
+         self.op = op
+
+     _SUPPORTED: ClassVar[frozenset[str]] = frozenset(
+         [
+             "min",
+             "max",
+             "median",
+             "n_unique",
+             "first",
+             "last",
+             "mean",
+             "sum",
+             "count",
+             "std",
+             "var",
+             "quantile",
+         ]
+     )
+
+     interp_mapping: ClassVar[dict[str, plc.types.Interpolation]] = {
+         "nearest": plc.types.Interpolation.NEAREST,
+         "higher": plc.types.Interpolation.HIGHER,
+         "lower": plc.types.Interpolation.LOWER,
+         "midpoint": plc.types.Interpolation.MIDPOINT,
+         "linear": plc.types.Interpolation.LINEAR,
+     }
+
+     @property
+     def agg_request(self) -> plc.aggregation.Aggregation:  # noqa: D102
+         if self.name == "first":
+             return plc.aggregation.nth_element(
+                 0, null_handling=plc.types.NullPolicy.INCLUDE
+             )
+         elif self.name == "last":
+             return plc.aggregation.nth_element(
+                 -1, null_handling=plc.types.NullPolicy.INCLUDE
+             )
+         else:
+             assert self.request is not None, "Init should have raised"
+             return self.request
+
+     def _reduce(
+         self, column: Column, *, request: plc.aggregation.Aggregation
+     ) -> Column:
+         return Column(
+             plc.Column.from_scalar(
+                 plc.reduce.reduce(column.obj, request, self.dtype.plc),
+                 1,
+             ),
+             name=column.name,
+             dtype=self.dtype,
+         )
+
+     def _count(self, column: Column, *, include_nulls: bool) -> Column:
+         null_count = column.null_count if not include_nulls else 0
+         return Column(
+             plc.Column.from_scalar(
+                 plc.Scalar.from_py(column.size - null_count, self.dtype.plc),
+                 1,
+             ),
+             name=column.name,
+             dtype=self.dtype,
+         )
+
+     def _sum(self, column: Column) -> Column:
+         if column.size == 0 or column.null_count == column.size:
+             return Column(
+                 plc.Column.from_scalar(
+                     plc.Scalar.from_py(0, self.dtype.plc),
+                     1,
+                 ),
+                 name=column.name,
+                 dtype=self.dtype,
+             )
+         return self._reduce(column, request=plc.aggregation.sum())
+
+     def _min(self, column: Column, *, propagate_nans: bool) -> Column:
+         if propagate_nans and column.nan_count > 0:
+             return Column(
+                 plc.Column.from_scalar(
+                     plc.Scalar.from_py(float("nan"), self.dtype.plc),
+                     1,
+                 ),
+                 name=column.name,
+                 dtype=self.dtype,
+             )
+         if column.nan_count > 0:
+             column = column.mask_nans()
+         return self._reduce(column, request=plc.aggregation.min())
+
+     def _max(self, column: Column, *, propagate_nans: bool) -> Column:
+         if propagate_nans and column.nan_count > 0:
+             return Column(
+                 plc.Column.from_scalar(
+                     plc.Scalar.from_py(float("nan"), self.dtype.plc),
+                     1,
+                 ),
+                 name=column.name,
+                 dtype=self.dtype,
+             )
+         if column.nan_count > 0:
+             column = column.mask_nans()
+         return self._reduce(column, request=plc.aggregation.max())
+
+     def _first(self, column: Column) -> Column:
+         return Column(
+             plc.copying.slice(column.obj, [0, 1])[0], name=column.name, dtype=self.dtype
+         )
+
+     def _last(self, column: Column) -> Column:
+         n = column.size
+         return Column(
+             plc.copying.slice(column.obj, [n - 1, n])[0],
+             name=column.name,
+             dtype=self.dtype,
+         )
+
+     def do_evaluate(
+         self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
+     ) -> Column:
+         """Evaluate this expression given a dataframe for context."""
+         if context is not ExecutionContext.FRAME:
+             raise NotImplementedError(
+                 f"Agg in context {context}"
+             )  # pragma: no cover; unreachable
+
+         # Aggregations like quantiles may have additional children that were
+         # preprocessed into pylibcudf requests.
+         child = self.children[0]
+         return self.op(child.evaluate(df, context=context))
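
Note: the Agg constructor above resolves its implementation once, at construction time. getattr(self, f"_{name}", None) finds a specialised method such as _sum or _count, anything without one falls back to the generic reduction via partial(self._reduce, request=req), and do_evaluate then reduces to a single self.op(...) call. A minimal, self-contained sketch of that dispatch idiom in plain Python (the class and method names here are illustrative only, not the real cudf_polars API):

from functools import partial

class MiniAgg:
    def __init__(self, name, options=None):
        # Prefer a hand-written implementation (_count, ...); otherwise
        # fall back to a generic reduction path keyed by the agg name.
        op = getattr(self, f"_{name}", None)
        if op is None:
            op = partial(self._reduce, request=name)
        elif name == "count":
            # Bind constructor options into the callable up front so
            # evaluation later is a plain call with no branching.
            op = partial(op, include_nulls=bool(options))
        self.op = op

    def _reduce(self, values, *, request):
        return f"reduce({request}) over {len(values)} rows"

    def _count(self, values, *, include_nulls):
        return len(values) if include_nulls else sum(v is not None for v in values)

agg = MiniAgg("count", options=False)
print(agg.op([1, None, 3]))       # -> 2
print(MiniAgg("median").op([1]))  # -> reduce(median) over 1 rows

Binding options with functools.partial at construction keeps the per-evaluation hot path free of any dispatch on the aggregation name.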
cudf_polars/dsl/expressions/base.py
@@ -0,0 +1,272 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+ # TODO: remove need for this
+ # ruff: noqa: D101
+ """Base and common classes for expression DSL nodes."""
+
+ from __future__ import annotations
+
+ import enum
+ from enum import IntEnum
+ from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple
+
+ import pylibcudf as plc
+
+ from cudf_polars.containers import Column
+ from cudf_polars.dsl.nodebase import Node
+
+ if TYPE_CHECKING:
+     from typing_extensions import Self
+
+     from cudf_polars.containers import Column, DataFrame, DataType
+
+ __all__ = ["AggInfo", "Col", "ColRef", "ExecutionContext", "Expr", "NamedExpr"]
+
+
+ class AggInfo(NamedTuple):
+     requests: list[tuple[Expr | None, plc.aggregation.Aggregation, Expr]]
+
+
+ class ExecutionContext(IntEnum):
+     FRAME = enum.auto()
+     GROUPBY = enum.auto()
+     ROLLING = enum.auto()
+     # Follows GROUPBY semantics but useful
+     # to differentiate from GROUPBY so we can
+     # implement agg/per-row ops independently
+     WINDOW = enum.auto()
+
+
+ class Expr(Node["Expr"]):
+     """An abstract expression object."""
+
+     __slots__ = ("dtype", "is_pointwise")
+     dtype: DataType
+     """Data type of the expression."""
+     is_pointwise: bool
+     """Whether this expression acts pointwise on its inputs."""
+     # This annotation is needed because of https://github.com/python/mypy/issues/17981
+     _non_child: ClassVar[tuple[str, ...]] = ("dtype",)
+     """Names of non-child data (not Exprs) for reconstruction."""
+
+     def do_evaluate(
+         self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
+     ) -> Column:
+         """
+         Evaluate this expression given a dataframe for context.
+
+         Parameters
+         ----------
+         df
+             DataFrame that will provide columns.
+         context
+             What context are we performing this evaluation in?
+
+         Notes
+         -----
+         Do not call this function directly, but rather :meth:`evaluate`.
+
+         Returns
+         -------
+         Column representing the evaluation of the expression.
+
+         Raises
+         ------
+         NotImplementedError
+             If we couldn't evaluate the expression. Ideally all these
+             are returned during translation to the IR, but for now we
+             are not perfect.
+         """
+         raise NotImplementedError(
+             f"Evaluation of expression {type(self).__name__}"
+         )  # pragma: no cover; translation of unimplemented nodes trips first
+
+     def evaluate(
+         self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
+     ) -> Column:
+         """
+         Evaluate this expression given a dataframe for context.
+
+         Parameters
+         ----------
+         df
+             DataFrame that will provide columns.
+         context
+             What context are we performing this evaluation in?
+
+         Notes
+         -----
+         Individual subclasses should implement :meth:`do_evaluate`,
+         this method provides logic to handle lookups in the
+         substitution mapping.
+
+         Returns
+         -------
+         Column representing the evaluation of the expression.
+
+         Raises
+         ------
+         NotImplementedError
+             If we couldn't evaluate the expression. Ideally all these
+             are returned during translation to the IR, but for now we
+             are not perfect.
+         """
+         return self.do_evaluate(df, context=context)
+
+     @property
+     def agg_request(self) -> plc.aggregation.Aggregation:
+         """
+         The aggregation for this expression in a grouped aggregation.
+
+         Returns
+         -------
+         Aggregation request. Default is to collect the expression.
+
+         Notes
+         -----
+         This presumes that the IR translation has decomposed groupby
+         reductions only into cases we can handle.
+
+         Raises
+         ------
+         NotImplementedError
+             If requesting an aggregation from an unexpected expression.
+         """
+         return plc.aggregation.collect_list()
+
+
+ class ErrorExpr(Expr):
+     __slots__ = ("error",)
+     _non_child = ("dtype", "error")
+     error: str
+
+     def __init__(self, dtype: DataType, error: str) -> None:
+         self.dtype = dtype
+         self.error = error
+         self.children = ()
+         self.is_pointwise = False
+
+
+ class NamedExpr:
+     # NamedExpr does not inherit from Expr since it does not appear
+     # when evaluating expressions themselves, only when constructing
+     # named return values in dataframe (IR) nodes.
+     __slots__ = ("name", "value")
+     value: Expr
+     name: str
+
+     def __init__(self, name: str, value: Expr) -> None:
+         self.name = name
+         self.value = value
+
+     def __hash__(self) -> int:
+         """Hash of the expression."""
+         return hash((type(self), self.name, self.value))
+
+     def __repr__(self) -> str:
+         """Repr of the expression."""
+         return f"NamedExpr({self.name}, {self.value})"
+
+     def __eq__(self, other: Any) -> bool:
+         """Equality of two expressions."""
+         return (
+             type(self) is type(other)
+             and self.name == other.name
+             and self.value == other.value
+         )
+
+     def __ne__(self, other: Any) -> bool:
+         """Inequality of expressions."""
+         return not self.__eq__(other)
+
+     def evaluate(
+         self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
+     ) -> Column:
+         """
+         Evaluate this expression given a dataframe for context.
+
+         Parameters
+         ----------
+         df
+             DataFrame providing context
+         context
+             Execution context
+
+         Returns
+         -------
+         Evaluated Column with name attached.
+
+         See Also
+         --------
+         :meth:`Expr.evaluate` for details, this function just adds the
+         name to a column produced from an expression.
+         """
+         return self.value.evaluate(df, context=context).rename(self.name)
+
+     def reconstruct(self, expr: Expr) -> Self:
+         """
+         Rebuild with a new `Expr` value.
+
+         Parameters
+         ----------
+         expr
+             New `Expr` value
+
+         Returns
+         -------
+         New `NamedExpr` with `expr` as the underlying expression.
+         The name of the original `NamedExpr` is preserved.
+         """
+         if expr is self.value:
+             return self
+         return type(self)(self.name, expr)
+
+
+ class Col(Expr):
+     __slots__ = ("name",)
+     _non_child = ("dtype", "name")
+     name: str
+
+     def __init__(self, dtype: DataType, name: str) -> None:
+         self.dtype = dtype
+         self.name = name
+         self.is_pointwise = True
+         self.children = ()
+
+     def do_evaluate(
+         self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
+     ) -> Column:
+         """Evaluate this expression given a dataframe for context."""
+         # Deliberately remove the name here so that we guarantee
+         # evaluation of the IR produces names.
+         return df.column_map[self.name].rename(None)
+
+
+ class ColRef(Expr):
+     __slots__ = ("index", "table_ref")
+     _non_child = ("dtype", "index", "table_ref")
+     index: int
+     table_ref: plc.expressions.TableReference
+
+     def __init__(
+         self,
+         dtype: DataType,
+         index: int,
+         table_ref: plc.expressions.TableReference,
+         column: Expr,
+     ) -> None:
+         if not isinstance(column, Col):
+             raise TypeError("Column reference should only apply to columns")
+         self.dtype = dtype
+         self.index = index
+         self.table_ref = table_ref
+         self.is_pointwise = True
+         self.children = (column,)
+
+     def do_evaluate(
+         self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
+     ) -> Column:
+         """Evaluate this expression given a dataframe for context."""
+         raise NotImplementedError(
+             "Only expect this node as part of an expression translated to libcudf AST."
+         )
cudf_polars/dsl/expressions/binaryop.py
@@ -0,0 +1,120 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+ # TODO: remove need for this
+ # ruff: noqa: D101
+ """BinaryOp DSL nodes."""
+
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, ClassVar
+
+ from polars.polars import _expr_nodes as pl_expr
+
+ import pylibcudf as plc
+
+ from cudf_polars.containers import Column
+ from cudf_polars.dsl.expressions.base import ExecutionContext, Expr
+
+ if TYPE_CHECKING:
+     from cudf_polars.containers import DataFrame, DataType
+
+ __all__ = ["BinOp"]
+
+
+ class BinOp(Expr):
+     __slots__ = ("op",)
+     _non_child = ("dtype", "op")
+
+     def __init__(
+         self,
+         dtype: DataType,
+         op: plc.binaryop.BinaryOperator,
+         left: Expr,
+         right: Expr,
+     ) -> None:
+         self.dtype = dtype
+         if plc.traits.is_boolean(self.dtype.plc):
+             # For boolean output types, bitand and bitor implement
+             # boolean logic, so translate. bitxor also does, but the
+             # default behaviour is correct.
+             op = BinOp._BOOL_KLEENE_MAPPING.get(op, op)
+         self.op = op
+         self.children = (left, right)
+         self.is_pointwise = True
+         if not plc.binaryop.is_supported_operation(
+             self.dtype.plc, left.dtype.plc, right.dtype.plc, op
+         ):
+             raise NotImplementedError(
+                 f"Operation {op.name} not supported "
+                 f"for types {left.dtype.id().name} and {right.dtype.id().name} "
+                 f"with output type {self.dtype.id().name}"
+             )
+
+     _BOOL_KLEENE_MAPPING: ClassVar[
+         dict[plc.binaryop.BinaryOperator, plc.binaryop.BinaryOperator]
+     ] = {
+         plc.binaryop.BinaryOperator.BITWISE_AND: plc.binaryop.BinaryOperator.NULL_LOGICAL_AND,
+         plc.binaryop.BinaryOperator.BITWISE_OR: plc.binaryop.BinaryOperator.NULL_LOGICAL_OR,
+         plc.binaryop.BinaryOperator.LOGICAL_AND: plc.binaryop.BinaryOperator.NULL_LOGICAL_AND,
+         plc.binaryop.BinaryOperator.LOGICAL_OR: plc.binaryop.BinaryOperator.NULL_LOGICAL_OR,
+     }
+
+     _MAPPING: ClassVar[dict[pl_expr.Operator, plc.binaryop.BinaryOperator]] = {
+         pl_expr.Operator.Eq: plc.binaryop.BinaryOperator.EQUAL,
+         pl_expr.Operator.EqValidity: plc.binaryop.BinaryOperator.NULL_EQUALS,
+         pl_expr.Operator.NotEq: plc.binaryop.BinaryOperator.NOT_EQUAL,
+         pl_expr.Operator.NotEqValidity: plc.binaryop.BinaryOperator.NULL_NOT_EQUALS,
+         pl_expr.Operator.Lt: plc.binaryop.BinaryOperator.LESS,
+         pl_expr.Operator.LtEq: plc.binaryop.BinaryOperator.LESS_EQUAL,
+         pl_expr.Operator.Gt: plc.binaryop.BinaryOperator.GREATER,
+         pl_expr.Operator.GtEq: plc.binaryop.BinaryOperator.GREATER_EQUAL,
+         pl_expr.Operator.Plus: plc.binaryop.BinaryOperator.ADD,
+         pl_expr.Operator.Minus: plc.binaryop.BinaryOperator.SUB,
+         pl_expr.Operator.Multiply: plc.binaryop.BinaryOperator.MUL,
+         pl_expr.Operator.Divide: plc.binaryop.BinaryOperator.DIV,
+         pl_expr.Operator.TrueDivide: plc.binaryop.BinaryOperator.TRUE_DIV,
+         pl_expr.Operator.FloorDivide: plc.binaryop.BinaryOperator.FLOOR_DIV,
+         pl_expr.Operator.Modulus: plc.binaryop.BinaryOperator.PYMOD,
+         pl_expr.Operator.And: plc.binaryop.BinaryOperator.BITWISE_AND,
+         pl_expr.Operator.Or: plc.binaryop.BinaryOperator.BITWISE_OR,
+         pl_expr.Operator.Xor: plc.binaryop.BinaryOperator.BITWISE_XOR,
+         pl_expr.Operator.LogicalAnd: plc.binaryop.BinaryOperator.LOGICAL_AND,
+         pl_expr.Operator.LogicalOr: plc.binaryop.BinaryOperator.LOGICAL_OR,
+     }
+
+     def do_evaluate(
+         self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
+     ) -> Column:
+         """Evaluate this expression given a dataframe for context."""
+         left, right = (child.evaluate(df, context=context) for child in self.children)
+         lop = left.obj
+         rop = right.obj
+         if left.size != right.size:
+             if left.is_scalar:
+                 lop = left.obj_scalar
+             elif right.is_scalar:
+                 rop = right.obj_scalar
+         if plc.traits.is_integral_not_bool(self.dtype.plc) and self.op in {
+             plc.binaryop.BinaryOperator.FLOOR_DIV,
+             plc.binaryop.BinaryOperator.PYMOD,
+         }:
+             if right.obj.size() == 1 and right.obj.to_scalar().to_py() == 0:
+                 return Column(
+                     plc.Column.all_null_like(left.obj, left.obj.size()),
+                     dtype=self.dtype,
+                 )
+
+             if right.obj.size() > 1:
+                 rop = plc.replace.find_and_replace_all(
+                     right.obj,
+                     plc.Column.from_scalar(
+                         plc.Scalar.from_py(0, dtype=self.dtype.plc), 1
+                     ),
+                     plc.Column.from_scalar(
+                         plc.Scalar.from_py(None, dtype=self.dtype.plc), 1
+                     ),
+                 )
+         return Column(
+             plc.binaryop.binary_operation(lop, rop, self.op, self.dtype.plc),
+             dtype=self.dtype,
+         )
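
Note: _BOOL_KLEENE_MAPPING above exists because Polars gives boolean & and | Kleene (three-valued) semantics, which libcudf exposes as NULL_LOGICAL_AND/NULL_LOGICAL_OR rather than the plain null-propagating operators. A small CPU-side Polars example of the semantics this node has to match (comments show the expected Kleene results; running the lazy form with collect(engine="gpu") would exercise BinOp, hardware and a cudf-polars install permitting):

import polars as pl

df = pl.DataFrame({"a": [True, False, None], "b": [None, None, None]})
out = df.select(
    (pl.col("a") & pl.col("b")).alias("and_"),
    (pl.col("a") | pl.col("b")).alias("or_"),
)
# Kleene logic: True & null -> null, False & null -> False,
#               True | null -> True, False | null -> null.
print(out)

# Lazy/GPU form of the same query (assumes a working GPU engine):
# df.lazy().select((pl.col("a") & pl.col("b"))).collect(engine="gpu")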