cudf-polars-cu12 24.12.0__py3-none-any.whl → 25.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. cudf_polars/VERSION +1 -1
  2. cudf_polars/__init__.py +1 -1
  3. cudf_polars/callback.py +28 -3
  4. cudf_polars/containers/__init__.py +1 -1
  5. cudf_polars/dsl/expr.py +16 -16
  6. cudf_polars/dsl/expressions/aggregation.py +21 -4
  7. cudf_polars/dsl/expressions/base.py +7 -2
  8. cudf_polars/dsl/expressions/binaryop.py +1 -0
  9. cudf_polars/dsl/expressions/boolean.py +65 -22
  10. cudf_polars/dsl/expressions/datetime.py +82 -20
  11. cudf_polars/dsl/expressions/literal.py +2 -0
  12. cudf_polars/dsl/expressions/rolling.py +3 -1
  13. cudf_polars/dsl/expressions/selection.py +3 -1
  14. cudf_polars/dsl/expressions/sorting.py +2 -0
  15. cudf_polars/dsl/expressions/string.py +118 -39
  16. cudf_polars/dsl/expressions/ternary.py +1 -0
  17. cudf_polars/dsl/expressions/unary.py +11 -1
  18. cudf_polars/dsl/ir.py +173 -122
  19. cudf_polars/dsl/to_ast.py +4 -6
  20. cudf_polars/dsl/translate.py +53 -21
  21. cudf_polars/dsl/traversal.py +10 -10
  22. cudf_polars/experimental/base.py +43 -0
  23. cudf_polars/experimental/dispatch.py +84 -0
  24. cudf_polars/experimental/io.py +325 -0
  25. cudf_polars/experimental/parallel.py +253 -0
  26. cudf_polars/experimental/select.py +36 -0
  27. cudf_polars/testing/asserts.py +14 -5
  28. cudf_polars/testing/plugin.py +64 -4
  29. cudf_polars/typing/__init__.py +5 -5
  30. cudf_polars/utils/dtypes.py +9 -7
  31. cudf_polars/utils/versions.py +4 -7
  32. {cudf_polars_cu12-24.12.0.dist-info → cudf_polars_cu12-25.2.1.dist-info}/METADATA +6 -6
  33. cudf_polars_cu12-25.2.1.dist-info/RECORD +48 -0
  34. {cudf_polars_cu12-24.12.0.dist-info → cudf_polars_cu12-25.2.1.dist-info}/WHEEL +1 -1
  35. cudf_polars_cu12-24.12.0.dist-info/RECORD +0 -43
  36. {cudf_polars_cu12-24.12.0.dist-info → cudf_polars_cu12-25.2.1.dist-info}/LICENSE +0 -0
  37. {cudf_polars_cu12-24.12.0.dist-info → cudf_polars_cu12-25.2.1.dist-info}/top_level.txt +0 -0
cudf_polars/VERSION CHANGED
@@ -1 +1 @@
1
- 24.12.00
1
+ 25.02.01
cudf_polars/__init__.py CHANGED
@@ -21,8 +21,8 @@ _ensure_polars_version()
21
21
  del _ensure_polars_version
22
22
 
23
23
  __all__: list[str] = [
24
- "execute_with_cudf",
25
24
  "Translator",
26
25
  "__git_commit__",
27
26
  "__version__",
27
+ "execute_with_cudf",
28
28
  ]
cudf_polars/callback.py CHANGED
@@ -9,7 +9,7 @@ import contextlib
9
9
  import os
10
10
  import warnings
11
11
  from functools import cache, partial
12
- from typing import TYPE_CHECKING
12
+ from typing import TYPE_CHECKING, Literal
13
13
 
14
14
  import nvtx
15
15
 
@@ -181,6 +181,7 @@ def _callback(
181
181
  *,
182
182
  device: int | None,
183
183
  memory_resource: int | None,
184
+ executor: Literal["pylibcudf", "dask-experimental"] | None,
184
185
  ) -> pl.DataFrame:
185
186
  assert with_columns is None
186
187
  assert pyarrow_predicate is None
@@ -191,7 +192,14 @@ def _callback(
191
192
  set_device(device),
192
193
  set_memory_resource(memory_resource),
193
194
  ):
194
- return ir.evaluate(cache={}).to_polars()
195
+ if executor is None or executor == "pylibcudf":
196
+ return ir.evaluate(cache={}).to_polars()
197
+ elif executor == "dask-experimental":
198
+ from cudf_polars.experimental.parallel import evaluate_dask
199
+
200
+ return evaluate_dask(ir).to_polars()
201
+ else:
202
+ raise ValueError(f"Unknown executor '{executor}'")
195
203
 
196
204
 
197
205
  def validate_config_options(config: dict) -> None:
@@ -208,7 +216,10 @@ def validate_config_options(config: dict) -> None:
208
216
  ValueError
209
217
  If the configuration contains unsupported options.
210
218
  """
211
- if unsupported := (config.keys() - {"raise_on_fail", "parquet_options"}):
219
+ if unsupported := (
220
+ config.keys()
221
+ - {"raise_on_fail", "parquet_options", "executor", "executor_options"}
222
+ ):
212
223
  raise ValueError(
213
224
  f"Engine configuration contains unsupported settings: {unsupported}"
214
225
  )
@@ -216,6 +227,18 @@ def validate_config_options(config: dict) -> None:
216
227
  config.get("parquet_options", {})
217
228
  )
218
229
 
230
+ # Validate executor_options
231
+ executor = config.get("executor", "pylibcudf")
232
+ if executor == "dask-experimental":
233
+ unsupported = config.get("executor_options", {}).keys() - {
234
+ "max_rows_per_partition",
235
+ "parquet_blocksize",
236
+ }
237
+ else:
238
+ unsupported = config.get("executor_options", {}).keys()
239
+ if unsupported:
240
+ raise ValueError(f"Unsupported executor_options for {executor}: {unsupported}")
241
+
219
242
 
220
243
  def execute_with_cudf(nt: NodeTraverser, *, config: GPUEngine) -> None:
221
244
  """
@@ -243,6 +266,7 @@ def execute_with_cudf(nt: NodeTraverser, *, config: GPUEngine) -> None:
243
266
  device = config.device
244
267
  memory_resource = config.memory_resource
245
268
  raise_on_fail = config.config.get("raise_on_fail", False)
269
+ executor = config.config.get("executor", None)
246
270
  validate_config_options(config.config)
247
271
 
248
272
  with nvtx.annotate(message="ConvertIR", domain="cudf_polars"):
@@ -272,5 +296,6 @@ def execute_with_cudf(nt: NodeTraverser, *, config: GPUEngine) -> None:
272
296
  ir,
273
297
  device=device,
274
298
  memory_resource=memory_resource,
299
+ executor=executor,
275
300
  )
276
301
  )
@@ -5,7 +5,7 @@
5
5
 
6
6
  from __future__ import annotations
7
7
 
8
- __all__: list[str] = ["DataFrame", "Column"]
8
+ __all__: list[str] = ["Column", "DataFrame"]
9
9
 
10
10
  from cudf_polars.containers.column import Column
11
11
  from cudf_polars.containers.dataframe import DataFrame
cudf_polars/dsl/expr.py CHANGED
@@ -36,27 +36,27 @@ from cudf_polars.dsl.expressions.ternary import Ternary
36
36
  from cudf_polars.dsl.expressions.unary import Cast, Len, UnaryFunction
37
37
 
38
38
  __all__ = [
39
- "Expr",
39
+ "Agg",
40
+ "AggInfo",
41
+ "BinOp",
42
+ "BooleanFunction",
43
+ "Cast",
44
+ "Col",
45
+ "ColRef",
40
46
  "ErrorExpr",
41
- "NamedExpr",
47
+ "Expr",
48
+ "Filter",
49
+ "Gather",
50
+ "GroupedRollingWindow",
51
+ "Len",
42
52
  "Literal",
43
53
  "LiteralColumn",
44
- "Len",
45
- "Col",
46
- "ColRef",
47
- "BooleanFunction",
48
- "StringFunction",
49
- "TemporalFunction",
54
+ "NamedExpr",
55
+ "RollingWindow",
50
56
  "Sort",
51
57
  "SortBy",
52
- "Gather",
53
- "Filter",
54
- "RollingWindow",
55
- "GroupedRollingWindow",
56
- "Cast",
57
- "Agg",
58
- "AggInfo",
58
+ "StringFunction",
59
+ "TemporalFunction",
59
60
  "Ternary",
60
- "BinOp",
61
61
  "UnaryFunction",
62
62
  ]
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  # TODO: remove need for this
4
4
  # ruff: noqa: D101
@@ -31,7 +31,7 @@ __all__ = ["Agg"]
31
31
 
32
32
 
33
33
  class Agg(Expr):
34
- __slots__ = ("name", "options", "op", "request")
34
+ __slots__ = ("name", "op", "options", "request")
35
35
  _non_child = ("dtype", "name", "options")
36
36
 
37
37
  def __init__(
@@ -40,6 +40,7 @@ class Agg(Expr):
40
40
  self.dtype = dtype
41
41
  self.name = name
42
42
  self.options = options
43
+ self.is_pointwise = False
43
44
  self.children = children
44
45
  if name not in Agg._SUPPORTED:
45
46
  raise NotImplementedError(
@@ -68,7 +69,11 @@ class Agg(Expr):
68
69
  # TODO: handle nans
69
70
  req = plc.aggregation.variance(ddof=options)
70
71
  elif name == "count":
71
- req = plc.aggregation.count(null_handling=plc.types.NullPolicy.EXCLUDE)
72
+ req = plc.aggregation.count(
73
+ null_handling=plc.types.NullPolicy.EXCLUDE
74
+ if not options
75
+ else plc.types.NullPolicy.INCLUDE
76
+ )
72
77
  elif name == "quantile":
73
78
  _, quantile = self.children
74
79
  if not isinstance(quantile, Literal):
@@ -86,7 +91,7 @@ class Agg(Expr):
86
91
  op = partial(self._reduce, request=req)
87
92
  elif name in {"min", "max"}:
88
93
  op = partial(op, propagate_nans=options)
89
- elif name in {"count", "first", "last"}:
94
+ elif name in {"count", "sum", "first", "last"}:
90
95
  pass
91
96
  else:
92
97
  raise NotImplementedError(
@@ -175,6 +180,18 @@ class Agg(Expr):
175
180
  )
176
181
  )
177
182
 
183
+ def _sum(self, column: Column) -> Column:
184
+ if column.obj.size() == 0:
185
+ return Column(
186
+ plc.Column.from_scalar(
187
+ plc.interop.from_arrow(
188
+ pa.scalar(0, type=plc.interop.to_arrow(self.dtype))
189
+ ),
190
+ 1,
191
+ )
192
+ )
193
+ return self._reduce(column, request=plc.aggregation.sum())
194
+
178
195
  def _min(self, column: Column, *, propagate_nans: bool) -> Column:
179
196
  if propagate_nans and column.nan_count > 0:
180
197
  return Column(
@@ -20,7 +20,7 @@ if TYPE_CHECKING:
20
20
 
21
21
  from cudf_polars.containers import Column, DataFrame
22
22
 
23
- __all__ = ["Expr", "NamedExpr", "Col", "AggInfo", "ExecutionContext", "ColRef"]
23
+ __all__ = ["AggInfo", "Col", "ColRef", "ExecutionContext", "Expr", "NamedExpr"]
24
24
 
25
25
 
26
26
  class AggInfo(NamedTuple):
@@ -36,9 +36,11 @@ class ExecutionContext(IntEnum):
36
36
  class Expr(Node["Expr"]):
37
37
  """An abstract expression object."""
38
38
 
39
- __slots__ = ("dtype",)
39
+ __slots__ = ("dtype", "is_pointwise")
40
40
  dtype: plc.DataType
41
41
  """Data type of the expression."""
42
+ is_pointwise: bool
43
+ """Whether this expression acts pointwise on its inputs."""
42
44
  # This annotation is needed because of https://github.com/python/mypy/issues/17981
43
45
  _non_child: ClassVar[tuple[str, ...]] = ("dtype",)
44
46
  """Names of non-child data (not Exprs) for reconstruction."""
@@ -164,6 +166,7 @@ class ErrorExpr(Expr):
164
166
  self.dtype = dtype
165
167
  self.error = error
166
168
  self.children = ()
169
+ self.is_pointwise = True
167
170
 
168
171
 
169
172
  class NamedExpr:
@@ -243,6 +246,7 @@ class Col(Expr):
243
246
  def __init__(self, dtype: plc.DataType, name: str) -> None:
244
247
  self.dtype = dtype
245
248
  self.name = name
249
+ self.is_pointwise = True
246
250
  self.children = ()
247
251
 
248
252
  def do_evaluate(
@@ -280,6 +284,7 @@ class ColRef(Expr):
280
284
  self.dtype = dtype
281
285
  self.index = index
282
286
  self.table_ref = table_ref
287
+ self.is_pointwise = True
283
288
  self.children = (column,)
284
289
 
285
290
  def do_evaluate(
@@ -42,6 +42,7 @@ class BinOp(Expr):
42
42
  op = BinOp._BOOL_KLEENE_MAPPING.get(op, op)
43
43
  self.op = op
44
44
  self.children = (left, right)
45
+ self.is_pointwise = True
45
46
  if not plc.binaryop.is_supported_operation(
46
47
  self.dtype, left.dtype, right.dtype, op
47
48
  ):
@@ -6,13 +6,12 @@
6
6
 
7
7
  from __future__ import annotations
8
8
 
9
+ from enum import IntEnum, auto
9
10
  from functools import partial, reduce
10
11
  from typing import TYPE_CHECKING, Any, ClassVar
11
12
 
12
13
  import pyarrow as pa
13
14
 
14
- from polars.polars import _expr_nodes as pl_expr
15
-
16
15
  import pylibcudf as plc
17
16
 
18
17
  from cudf_polars.containers import Column
@@ -24,7 +23,10 @@ from cudf_polars.dsl.expressions.base import (
24
23
  if TYPE_CHECKING:
25
24
  from collections.abc import Mapping
26
25
 
26
+ from typing_extensions import Self
27
+
27
28
  import polars.type_aliases as pl_types
29
+ from polars.polars import _expr_nodes as pl_expr
28
30
 
29
31
  from cudf_polars.containers import DataFrame
30
32
 
@@ -32,13 +34,46 @@ __all__ = ["BooleanFunction"]
32
34
 
33
35
 
34
36
  class BooleanFunction(Expr):
37
+ class Name(IntEnum):
38
+ """Internal and picklable representation of polars' `BooleanFunction`."""
39
+
40
+ All = auto()
41
+ AllHorizontal = auto()
42
+ Any = auto()
43
+ AnyHorizontal = auto()
44
+ IsBetween = auto()
45
+ IsDuplicated = auto()
46
+ IsFinite = auto()
47
+ IsFirstDistinct = auto()
48
+ IsIn = auto()
49
+ IsInfinite = auto()
50
+ IsLastDistinct = auto()
51
+ IsNan = auto()
52
+ IsNotNan = auto()
53
+ IsNotNull = auto()
54
+ IsNull = auto()
55
+ IsUnique = auto()
56
+ Not = auto()
57
+
58
+ @classmethod
59
+ def from_polars(cls, obj: pl_expr.BooleanFunction) -> Self:
60
+ """Convert from polars' `BooleanFunction`."""
61
+ try:
62
+ function, name = str(obj).split(".", maxsplit=1)
63
+ except ValueError:
64
+ # Failed to unpack string
65
+ function = None
66
+ if function != "BooleanFunction":
67
+ raise ValueError("BooleanFunction required")
68
+ return getattr(cls, name)
69
+
35
70
  __slots__ = ("name", "options")
36
71
  _non_child = ("dtype", "name", "options")
37
72
 
38
73
  def __init__(
39
74
  self,
40
75
  dtype: plc.DataType,
41
- name: pl_expr.BooleanFunction,
76
+ name: BooleanFunction.Name,
42
77
  options: tuple[Any, ...],
43
78
  *children: Expr,
44
79
  ) -> None:
@@ -46,7 +81,15 @@ class BooleanFunction(Expr):
46
81
  self.options = options
47
82
  self.name = name
48
83
  self.children = children
49
- if self.name == pl_expr.BooleanFunction.IsIn and not all(
84
+ self.is_pointwise = self.name not in (
85
+ BooleanFunction.Name.All,
86
+ BooleanFunction.Name.Any,
87
+ BooleanFunction.Name.IsDuplicated,
88
+ BooleanFunction.Name.IsFirstDistinct,
89
+ BooleanFunction.Name.IsLastDistinct,
90
+ BooleanFunction.Name.IsUnique,
91
+ )
92
+ if self.name is BooleanFunction.Name.IsIn and not all(
50
93
  c.dtype == self.children[0].dtype for c in self.children
51
94
  ):
52
95
  # TODO: If polars IR doesn't put the casts in, we need to
@@ -110,12 +153,12 @@ class BooleanFunction(Expr):
110
153
  ) -> Column:
111
154
  """Evaluate this expression given a dataframe for context."""
112
155
  if self.name in (
113
- pl_expr.BooleanFunction.IsFinite,
114
- pl_expr.BooleanFunction.IsInfinite,
156
+ BooleanFunction.Name.IsFinite,
157
+ BooleanFunction.Name.IsInfinite,
115
158
  ):
116
159
  # Avoid evaluating the child if the dtype tells us it's unnecessary.
117
160
  (child,) = self.children
118
- is_finite = self.name == pl_expr.BooleanFunction.IsFinite
161
+ is_finite = self.name is BooleanFunction.Name.IsFinite
119
162
  if child.dtype.id() not in (plc.TypeId.FLOAT32, plc.TypeId.FLOAT64):
120
163
  value = plc.interop.from_arrow(
121
164
  pa.scalar(value=is_finite, type=plc.interop.to_arrow(self.dtype))
@@ -142,10 +185,10 @@ class BooleanFunction(Expr):
142
185
  ]
143
186
  # Kleene logic for Any (OR) and All (AND) if ignore_nulls is
144
187
  # False
145
- if self.name in (pl_expr.BooleanFunction.Any, pl_expr.BooleanFunction.All):
188
+ if self.name in (BooleanFunction.Name.Any, BooleanFunction.Name.All):
146
189
  (ignore_nulls,) = self.options
147
190
  (column,) = columns
148
- is_any = self.name == pl_expr.BooleanFunction.Any
191
+ is_any = self.name is BooleanFunction.Name.Any
149
192
  agg = plc.aggregation.any() if is_any else plc.aggregation.all()
150
193
  result = plc.reduce.reduce(column.obj, agg, self.dtype)
151
194
  if not ignore_nulls and column.obj.null_count() > 0:
@@ -160,32 +203,32 @@ class BooleanFunction(Expr):
160
203
  # If the input null count was non-zero, we must
161
204
  # post-process the result to insert the correct value.
162
205
  h_result = plc.interop.to_arrow(result).as_py()
163
- if is_any and not h_result or not is_any and h_result:
206
+ if (is_any and not h_result) or (not is_any and h_result):
164
207
  # Any All
165
208
  # False || Null => Null True && Null => Null
166
209
  return Column(plc.Column.all_null_like(column.obj, 1))
167
210
  return Column(plc.Column.from_scalar(result, 1))
168
- if self.name == pl_expr.BooleanFunction.IsNull:
211
+ if self.name is BooleanFunction.Name.IsNull:
169
212
  (column,) = columns
170
213
  return Column(plc.unary.is_null(column.obj))
171
- elif self.name == pl_expr.BooleanFunction.IsNotNull:
214
+ elif self.name is BooleanFunction.Name.IsNotNull:
172
215
  (column,) = columns
173
216
  return Column(plc.unary.is_valid(column.obj))
174
- elif self.name == pl_expr.BooleanFunction.IsNan:
217
+ elif self.name is BooleanFunction.Name.IsNan:
175
218
  (column,) = columns
176
219
  return Column(
177
220
  plc.unary.is_nan(column.obj).with_mask(
178
221
  column.obj.null_mask(), column.obj.null_count()
179
222
  )
180
223
  )
181
- elif self.name == pl_expr.BooleanFunction.IsNotNan:
224
+ elif self.name is BooleanFunction.Name.IsNotNan:
182
225
  (column,) = columns
183
226
  return Column(
184
227
  plc.unary.is_not_nan(column.obj).with_mask(
185
228
  column.obj.null_mask(), column.obj.null_count()
186
229
  )
187
230
  )
188
- elif self.name == pl_expr.BooleanFunction.IsFirstDistinct:
231
+ elif self.name is BooleanFunction.Name.IsFirstDistinct:
189
232
  (column,) = columns
190
233
  return self._distinct(
191
234
  column,
@@ -197,7 +240,7 @@ class BooleanFunction(Expr):
197
240
  pa.scalar(value=False, type=plc.interop.to_arrow(self.dtype))
198
241
  ),
199
242
  )
200
- elif self.name == pl_expr.BooleanFunction.IsLastDistinct:
243
+ elif self.name is BooleanFunction.Name.IsLastDistinct:
201
244
  (column,) = columns
202
245
  return self._distinct(
203
246
  column,
@@ -209,7 +252,7 @@ class BooleanFunction(Expr):
209
252
  pa.scalar(value=False, type=plc.interop.to_arrow(self.dtype))
210
253
  ),
211
254
  )
212
- elif self.name == pl_expr.BooleanFunction.IsUnique:
255
+ elif self.name is BooleanFunction.Name.IsUnique:
213
256
  (column,) = columns
214
257
  return self._distinct(
215
258
  column,
@@ -221,7 +264,7 @@ class BooleanFunction(Expr):
221
264
  pa.scalar(value=False, type=plc.interop.to_arrow(self.dtype))
222
265
  ),
223
266
  )
224
- elif self.name == pl_expr.BooleanFunction.IsDuplicated:
267
+ elif self.name is BooleanFunction.Name.IsDuplicated:
225
268
  (column,) = columns
226
269
  return self._distinct(
227
270
  column,
@@ -233,7 +276,7 @@ class BooleanFunction(Expr):
233
276
  pa.scalar(value=True, type=plc.interop.to_arrow(self.dtype))
234
277
  ),
235
278
  )
236
- elif self.name == pl_expr.BooleanFunction.AllHorizontal:
279
+ elif self.name is BooleanFunction.Name.AllHorizontal:
237
280
  return Column(
238
281
  reduce(
239
282
  partial(
@@ -244,7 +287,7 @@ class BooleanFunction(Expr):
244
287
  (c.obj for c in columns),
245
288
  )
246
289
  )
247
- elif self.name == pl_expr.BooleanFunction.AnyHorizontal:
290
+ elif self.name is BooleanFunction.Name.AnyHorizontal:
248
291
  return Column(
249
292
  reduce(
250
293
  partial(
@@ -255,10 +298,10 @@ class BooleanFunction(Expr):
255
298
  (c.obj for c in columns),
256
299
  )
257
300
  )
258
- elif self.name == pl_expr.BooleanFunction.IsIn:
301
+ elif self.name is BooleanFunction.Name.IsIn:
259
302
  needles, haystack = columns
260
303
  return Column(plc.search.contains(haystack.obj, needles.obj))
261
- elif self.name == pl_expr.BooleanFunction.Not:
304
+ elif self.name is BooleanFunction.Name.Not:
262
305
  (column,) = columns
263
306
  return Column(
264
307
  plc.unary.unary_operation(column.obj, plc.unary.UnaryOperator.NOT)
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  # TODO: remove need for this
4
4
  # ruff: noqa: D101
@@ -6,12 +6,11 @@
6
6
 
7
7
  from __future__ import annotations
8
8
 
9
+ from enum import IntEnum, auto
9
10
  from typing import TYPE_CHECKING, Any, ClassVar
10
11
 
11
12
  import pyarrow as pa
12
13
 
13
- from polars.polars import _expr_nodes as pl_expr
14
-
15
14
  import pylibcudf as plc
16
15
 
17
16
  from cudf_polars.containers import Column
@@ -20,33 +19,95 @@ from cudf_polars.dsl.expressions.base import ExecutionContext, Expr
20
19
  if TYPE_CHECKING:
21
20
  from collections.abc import Mapping
22
21
 
22
+ from typing_extensions import Self
23
+
24
+ from polars.polars import _expr_nodes as pl_expr
25
+
23
26
  from cudf_polars.containers import DataFrame
24
27
 
25
28
  __all__ = ["TemporalFunction"]
26
29
 
27
30
 
28
31
  class TemporalFunction(Expr):
32
+ class Name(IntEnum):
33
+ """Internal and picklable representation of polars' `TemporalFunction`."""
34
+
35
+ BaseUtcOffset = auto()
36
+ CastTimeUnit = auto()
37
+ Century = auto()
38
+ Combine = auto()
39
+ ConvertTimeZone = auto()
40
+ DSTOffset = auto()
41
+ Date = auto()
42
+ Datetime = auto()
43
+ DatetimeFunction = auto()
44
+ Day = auto()
45
+ Duration = auto()
46
+ Hour = auto()
47
+ IsLeapYear = auto()
48
+ IsoYear = auto()
49
+ Microsecond = auto()
50
+ Millennium = auto()
51
+ Millisecond = auto()
52
+ Minute = auto()
53
+ Month = auto()
54
+ MonthEnd = auto()
55
+ MonthStart = auto()
56
+ Nanosecond = auto()
57
+ OffsetBy = auto()
58
+ OrdinalDay = auto()
59
+ Quarter = auto()
60
+ Replace = auto()
61
+ ReplaceTimeZone = auto()
62
+ Round = auto()
63
+ Second = auto()
64
+ Time = auto()
65
+ TimeStamp = auto()
66
+ ToString = auto()
67
+ TotalDays = auto()
68
+ TotalHours = auto()
69
+ TotalMicroseconds = auto()
70
+ TotalMilliseconds = auto()
71
+ TotalMinutes = auto()
72
+ TotalNanoseconds = auto()
73
+ TotalSeconds = auto()
74
+ Truncate = auto()
75
+ Week = auto()
76
+ WeekDay = auto()
77
+ WithTimeUnit = auto()
78
+ Year = auto()
79
+
80
+ @classmethod
81
+ def from_polars(cls, obj: pl_expr.TemporalFunction) -> Self:
82
+ """Convert from polars' `TemporalFunction`."""
83
+ try:
84
+ function, name = str(obj).split(".", maxsplit=1)
85
+ except ValueError:
86
+ # Failed to unpack string
87
+ function = None
88
+ if function != "TemporalFunction":
89
+ raise ValueError("TemporalFunction required")
90
+ return getattr(cls, name)
91
+
29
92
  __slots__ = ("name", "options")
30
- _COMPONENT_MAP: ClassVar[
31
- dict[pl_expr.TemporalFunction, plc.datetime.DatetimeComponent]
32
- ] = {
33
- pl_expr.TemporalFunction.Year: plc.datetime.DatetimeComponent.YEAR,
34
- pl_expr.TemporalFunction.Month: plc.datetime.DatetimeComponent.MONTH,
35
- pl_expr.TemporalFunction.Day: plc.datetime.DatetimeComponent.DAY,
36
- pl_expr.TemporalFunction.WeekDay: plc.datetime.DatetimeComponent.WEEKDAY,
37
- pl_expr.TemporalFunction.Hour: plc.datetime.DatetimeComponent.HOUR,
38
- pl_expr.TemporalFunction.Minute: plc.datetime.DatetimeComponent.MINUTE,
39
- pl_expr.TemporalFunction.Second: plc.datetime.DatetimeComponent.SECOND,
40
- pl_expr.TemporalFunction.Millisecond: plc.datetime.DatetimeComponent.MILLISECOND,
41
- pl_expr.TemporalFunction.Microsecond: plc.datetime.DatetimeComponent.MICROSECOND,
42
- pl_expr.TemporalFunction.Nanosecond: plc.datetime.DatetimeComponent.NANOSECOND,
43
- }
44
93
  _non_child = ("dtype", "name", "options")
94
+ _COMPONENT_MAP: ClassVar[dict[Name, plc.datetime.DatetimeComponent]] = {
95
+ Name.Year: plc.datetime.DatetimeComponent.YEAR,
96
+ Name.Month: plc.datetime.DatetimeComponent.MONTH,
97
+ Name.Day: plc.datetime.DatetimeComponent.DAY,
98
+ Name.WeekDay: plc.datetime.DatetimeComponent.WEEKDAY,
99
+ Name.Hour: plc.datetime.DatetimeComponent.HOUR,
100
+ Name.Minute: plc.datetime.DatetimeComponent.MINUTE,
101
+ Name.Second: plc.datetime.DatetimeComponent.SECOND,
102
+ Name.Millisecond: plc.datetime.DatetimeComponent.MILLISECOND,
103
+ Name.Microsecond: plc.datetime.DatetimeComponent.MICROSECOND,
104
+ Name.Nanosecond: plc.datetime.DatetimeComponent.NANOSECOND,
105
+ }
45
106
 
46
107
  def __init__(
47
108
  self,
48
109
  dtype: plc.DataType,
49
- name: pl_expr.TemporalFunction,
110
+ name: TemporalFunction.Name,
50
111
  options: tuple[Any, ...],
51
112
  *children: Expr,
52
113
  ) -> None:
@@ -54,6 +115,7 @@ class TemporalFunction(Expr):
54
115
  self.options = options
55
116
  self.name = name
56
117
  self.children = children
118
+ self.is_pointwise = True
57
119
  if self.name not in self._COMPONENT_MAP:
58
120
  raise NotImplementedError(f"Temporal function {self.name}")
59
121
 
@@ -70,7 +132,7 @@ class TemporalFunction(Expr):
70
132
  for child in self.children
71
133
  ]
72
134
  (column,) = columns
73
- if self.name == pl_expr.TemporalFunction.Microsecond:
135
+ if self.name is TemporalFunction.Name.Microsecond:
74
136
  millis = plc.datetime.extract_datetime_component(
75
137
  column.obj, plc.datetime.DatetimeComponent.MILLISECOND
76
138
  )
@@ -90,7 +152,7 @@ class TemporalFunction(Expr):
90
152
  plc.types.DataType(plc.types.TypeId.INT32),
91
153
  )
92
154
  return Column(total_micros)
93
- elif self.name == pl_expr.TemporalFunction.Nanosecond:
155
+ elif self.name is TemporalFunction.Name.Nanosecond:
94
156
  millis = plc.datetime.extract_datetime_component(
95
157
  column.obj, plc.datetime.DatetimeComponent.MILLISECOND
96
158
  )
@@ -38,6 +38,7 @@ class Literal(Expr):
38
38
  assert value.type == plc.interop.to_arrow(dtype)
39
39
  self.value = value
40
40
  self.children = ()
41
+ self.is_pointwise = True
41
42
 
42
43
  def do_evaluate(
43
44
  self,
@@ -65,6 +66,7 @@ class LiteralColumn(Expr):
65
66
  data = value.to_arrow()
66
67
  self.value = data.cast(dtypes.downcast_arrow_lists(data.type))
67
68
  self.children = ()
69
+ self.is_pointwise = True
68
70
 
69
71
  def get_hashable(self) -> Hashable:
70
72
  """Compute a hash of the column."""
@@ -13,7 +13,7 @@ from cudf_polars.dsl.expressions.base import Expr
13
13
  if TYPE_CHECKING:
14
14
  import pylibcudf as plc
15
15
 
16
- __all__ = ["RollingWindow", "GroupedRollingWindow"]
16
+ __all__ = ["GroupedRollingWindow", "RollingWindow"]
17
17
 
18
18
 
19
19
  class RollingWindow(Expr):
@@ -24,6 +24,7 @@ class RollingWindow(Expr):
24
24
  self.dtype = dtype
25
25
  self.options = options
26
26
  self.children = (agg,)
27
+ self.is_pointwise = False
27
28
  raise NotImplementedError("Rolling window not implemented")
28
29
 
29
30
 
@@ -35,4 +36,5 @@ class GroupedRollingWindow(Expr):
35
36
  self.dtype = dtype
36
37
  self.options = options
37
38
  self.children = (agg, *by)
39
+ self.is_pointwise = False
38
40
  raise NotImplementedError("Grouped rolling window not implemented")