cudf-polars-cu12: 25.4.0-py3-none-any.whl → 25.8.0-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (73)
  1. cudf_polars/VERSION +1 -1
  2. cudf_polars/callback.py +55 -61
  3. cudf_polars/containers/__init__.py +4 -2
  4. cudf_polars/containers/column.py +123 -40
  5. cudf_polars/containers/dataframe.py +70 -35
  6. cudf_polars/containers/datatype.py +135 -0
  7. cudf_polars/dsl/expr.py +2 -0
  8. cudf_polars/dsl/expressions/aggregation.py +51 -71
  9. cudf_polars/dsl/expressions/base.py +45 -77
  10. cudf_polars/dsl/expressions/binaryop.py +29 -44
  11. cudf_polars/dsl/expressions/boolean.py +64 -71
  12. cudf_polars/dsl/expressions/datetime.py +70 -34
  13. cudf_polars/dsl/expressions/literal.py +45 -33
  14. cudf_polars/dsl/expressions/rolling.py +133 -10
  15. cudf_polars/dsl/expressions/selection.py +13 -31
  16. cudf_polars/dsl/expressions/slicing.py +6 -13
  17. cudf_polars/dsl/expressions/sorting.py +9 -21
  18. cudf_polars/dsl/expressions/string.py +470 -84
  19. cudf_polars/dsl/expressions/struct.py +138 -0
  20. cudf_polars/dsl/expressions/ternary.py +9 -13
  21. cudf_polars/dsl/expressions/unary.py +151 -90
  22. cudf_polars/dsl/ir.py +798 -331
  23. cudf_polars/dsl/nodebase.py +11 -4
  24. cudf_polars/dsl/to_ast.py +61 -20
  25. cudf_polars/dsl/tracing.py +16 -0
  26. cudf_polars/dsl/translate.py +279 -167
  27. cudf_polars/dsl/traversal.py +64 -15
  28. cudf_polars/dsl/utils/__init__.py +8 -0
  29. cudf_polars/dsl/utils/aggregations.py +301 -0
  30. cudf_polars/dsl/utils/groupby.py +93 -0
  31. cudf_polars/dsl/utils/naming.py +34 -0
  32. cudf_polars/dsl/utils/replace.py +61 -0
  33. cudf_polars/dsl/utils/reshape.py +74 -0
  34. cudf_polars/dsl/utils/rolling.py +115 -0
  35. cudf_polars/dsl/utils/windows.py +186 -0
  36. cudf_polars/experimental/base.py +112 -8
  37. cudf_polars/experimental/benchmarks/__init__.py +4 -0
  38. cudf_polars/experimental/benchmarks/pdsds.py +216 -0
  39. cudf_polars/experimental/benchmarks/pdsds_queries/__init__.py +4 -0
  40. cudf_polars/experimental/benchmarks/pdsds_queries/q1.py +88 -0
  41. cudf_polars/experimental/benchmarks/pdsh.py +812 -0
  42. cudf_polars/experimental/benchmarks/utils.py +725 -0
  43. cudf_polars/experimental/dask_registers.py +200 -0
  44. cudf_polars/experimental/dispatch.py +22 -7
  45. cudf_polars/experimental/distinct.py +194 -0
  46. cudf_polars/experimental/explain.py +127 -0
  47. cudf_polars/experimental/expressions.py +547 -0
  48. cudf_polars/experimental/groupby.py +174 -196
  49. cudf_polars/experimental/io.py +626 -51
  50. cudf_polars/experimental/join.py +104 -33
  51. cudf_polars/experimental/parallel.py +219 -133
  52. cudf_polars/experimental/repartition.py +69 -0
  53. cudf_polars/experimental/scheduler.py +155 -0
  54. cudf_polars/experimental/select.py +132 -7
  55. cudf_polars/experimental/shuffle.py +126 -18
  56. cudf_polars/experimental/sort.py +45 -0
  57. cudf_polars/experimental/spilling.py +151 -0
  58. cudf_polars/experimental/utils.py +112 -0
  59. cudf_polars/testing/asserts.py +213 -14
  60. cudf_polars/testing/io.py +72 -0
  61. cudf_polars/testing/plugin.py +77 -67
  62. cudf_polars/typing/__init__.py +63 -22
  63. cudf_polars/utils/config.py +584 -117
  64. cudf_polars/utils/dtypes.py +4 -117
  65. cudf_polars/utils/timer.py +1 -1
  66. cudf_polars/utils/versions.py +7 -5
  67. {cudf_polars_cu12-25.4.0.dist-info → cudf_polars_cu12-25.8.0.dist-info}/METADATA +13 -18
  68. cudf_polars_cu12-25.8.0.dist-info/RECORD +81 -0
  69. {cudf_polars_cu12-25.4.0.dist-info → cudf_polars_cu12-25.8.0.dist-info}/WHEEL +1 -1
  70. cudf_polars/experimental/dask_serialize.py +0 -73
  71. cudf_polars_cu12-25.4.0.dist-info/RECORD +0 -55
  72. {cudf_polars_cu12-25.4.0.dist-info → cudf_polars_cu12-25.8.0.dist-info}/licenses/LICENSE +0 -0
  73. {cudf_polars_cu12-25.4.0.dist-info → cudf_polars_cu12-25.8.0.dist-info}/top_level.txt +0 -0
cudf_polars/containers/dataframe.py CHANGED
@@ -8,26 +8,55 @@ from __future__ import annotations
 from functools import cached_property
 from typing import TYPE_CHECKING, cast
 
-import pyarrow as pa
-
 import polars as pl
 
 import pylibcudf as plc
 
-from cudf_polars.containers import Column
-from cudf_polars.utils import conversion, dtypes
+from cudf_polars.containers import Column, DataType
+from cudf_polars.utils import conversion
 
 if TYPE_CHECKING:
     from collections.abc import Iterable, Mapping, Sequence, Set
 
-    from typing_extensions import Self
+    from typing_extensions import Any, CapsuleType, Self
 
-    from cudf_polars.typing import ColumnOptions, DataFrameHeader, Slice
+    from cudf_polars.typing import ColumnOptions, DataFrameHeader, PolarsDataType, Slice
 
 
 __all__: list[str] = ["DataFrame"]
 
 
+def _create_polars_column_metadata(
+    name: str, dtype: PolarsDataType
+) -> plc.interop.ColumnMetadata:
+    """Create ColumnMetadata preserving pl.Struct field names."""
+    if isinstance(dtype, pl.Struct):
+        children_meta = [
+            _create_polars_column_metadata(field.name, field.dtype)
+            for field in dtype.fields
+        ]
+    else:
+        children_meta = []
+    timezone = dtype.time_zone if isinstance(dtype, pl.Datetime) else None
+    return plc.interop.ColumnMetadata(
+        name=name, timezone=timezone or "", children_meta=children_meta
+    )
+
+
+# This is also defined in pylibcudf.interop
+class _ObjectWithArrowMetadata:
+    def __init__(
+        self, obj: plc.Table, metadata: list[plc.interop.ColumnMetadata]
+    ) -> None:
+        self.obj = obj
+        self.metadata = metadata
+
+    def __arrow_c_array__(
+        self, requested_schema: None = None
+    ) -> tuple[CapsuleType, CapsuleType]:
+        return self.obj._to_schema(self.metadata), self.obj._to_host_array()
+
+
 # Pacify the type checker. DataFrame init asserts that all the columns
 # have a string name, so let's narrow the type.
 class NamedColumn(Column):
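
The `_ObjectWithArrowMetadata` wrapper added above works because `pl.DataFrame` accepts any object implementing the Arrow PyCapsule interface, i.e. exposing `__arrow_c_array__`. A minimal sketch of that protocol with pyarrow standing in for pylibcudf (the `ArrowCarrier` class is illustrative, not part of this package):

```python
import polars as pl
import pyarrow as pa


class ArrowCarrier:
    """Illustrative carrier exposing the Arrow PyCapsule protocol."""

    def __init__(self, batch: pa.RecordBatch) -> None:
        self.batch = batch

    def __arrow_c_array__(self, requested_schema=None):
        # Must return (schema_capsule, array_capsule); here we simply
        # delegate to pyarrow's own capsule export.
        return self.batch.__arrow_c_array__(requested_schema)


batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3]})
df = pl.DataFrame(ArrowCarrier(batch))  # consumed via the capsules
```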
@@ -46,6 +75,7 @@ class DataFrame:
         if any(c.name is None for c in columns):
             raise ValueError("All columns must have a name")
         self.columns = [cast(NamedColumn, c) for c in columns]
+        self.dtypes = [c.dtype for c in self.columns]
         self.column_map = {c.name: c for c in self.columns}
         self.table = plc.Table([c.obj for c in self.columns])
 
@@ -62,11 +92,12 @@ class DataFrame:
         # To guarantee we produce correct names, we therefore
         # serialise with names we control and rename with that map.
         name_map = {f"column_{i}": name for i, name in enumerate(self.column_map)}
-        table = plc.interop.to_arrow(
-            self.table,
-            [plc.interop.ColumnMetadata(name=name) for name in name_map],
-        )
-        df: pl.DataFrame = pl.from_arrow(table)
+        metadata = [
+            _create_polars_column_metadata(name, dtype.polars)
+            for name, dtype in zip(name_map, self.dtypes, strict=True)
+        ]
+        table_with_metadata = _ObjectWithArrowMetadata(self.table, metadata)
+        df = pl.DataFrame(table_with_metadata)
         return df.rename(name_map).with_columns(
             pl.col(c.name).set_sorted(descending=c.order == plc.types.Order.DESCENDING)
             if c.is_sorted
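
For context on why `to_polars` now builds recursive metadata: struct field names and datetime time zones are polars-level metadata that the old flat, names-only conversion could not carry. A small polars-only illustration of the metadata that has to survive the round trip:

```python
from datetime import datetime

import polars as pl

df = pl.DataFrame(
    {"s": [{"x": 1, "y": "a"}]},
    schema={"s": pl.Struct({"x": pl.Int64, "y": pl.String})},
)
# Struct field names live on the dtype, not in the data, so a
# conversion carrying only top-level column names would drop them.
assert df.schema["s"].fields[0].name == "x"

t = pl.Series("t", [datetime(2020, 1, 1)]).dt.replace_time_zone("UTC")
assert t.dtype.time_zone == "UTC"  # likewise carried on the dtype
```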
@@ -108,21 +139,18 @@ class DataFrame:
         -------
         New dataframe representing the input.
         """
-        table = df.to_arrow()
-        schema = table.schema
-        for i, field in enumerate(schema):
-            schema = schema.set(
-                i, pa.field(field.name, dtypes.downcast_arrow_lists(field.type))
-            )
-        # No-op if the schema is unchanged.
-        d_table = plc.interop.from_arrow(table.cast(schema))
+        plc_table = plc.Table.from_arrow(df)
         return cls(
-            Column(column).copy_metadata(h_col)
-            for column, h_col in zip(d_table.columns(), df.iter_columns(), strict=True)
+            Column(d_col, name=name, dtype=DataType(h_col.dtype)).copy_metadata(h_col)
+            for d_col, h_col, name in zip(
+                plc_table.columns(), df.iter_columns(), df.columns, strict=True
+            )
         )
 
     @classmethod
-    def from_table(cls, table: plc.Table, names: Sequence[str]) -> Self:
+    def from_table(
+        cls, table: plc.Table, names: Sequence[str], dtypes: Sequence[DataType]
+    ) -> Self:
         """
         Create from a pylibcudf table.
 
@@ -132,6 +160,8 @@ class DataFrame:
             Pylibcudf table to obtain columns from
         names
             Names for the columns
+        dtypes
+            Dtypes for the columns
 
         Returns
         -------
@@ -146,7 +176,8 @@ class DataFrame:
         if table.num_columns() != len(names):
             raise ValueError("Mismatching name and table length.")
         return cls(
-            Column(c, name=name) for c, name in zip(table.columns(), names, strict=True)
+            Column(c, name=name, dtype=dtype)
+            for c, name, dtype in zip(table.columns(), names, dtypes, strict=True)
         )
 
     @classmethod
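
`from_table` now threads dtypes through alongside names. A hedged sketch of a call site (the three-row Int64 column is fabricated for illustration; assumes this wheel and a matching pylibcudf are installed):

```python
import polars as pl
import pylibcudf as plc

from cudf_polars.containers import DataFrame, DataType

dtype = DataType(pl.Int64())
scalar = plc.Scalar.from_py(1, dtype.plc)
col = plc.Column.from_scalar(scalar, 3)  # a 3-row column of 1s
df = DataFrame.from_table(plc.Table([col]), ["a"], [dtype])
```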
@@ -173,7 +204,7 @@ class DataFrame:
             packed_metadata, packed_gpu_data
         )
         return cls(
-            Column(c, **kw)
+            Column(c, **Column.deserialize_ctor_kwargs(kw))
             for c, kw in zip(table.columns(), header["columns_kwargs"], strict=True)
         )
 
@@ -202,13 +233,7 @@ class DataFrame:
 
         # Keyword arguments for `Column.__init__`.
         columns_kwargs: list[ColumnOptions] = [
-            {
-                "is_sorted": col.is_sorted,
-                "order": col.order,
-                "null_order": col.null_order,
-                "name": col.name,
-            }
-            for col in self.columns
+            col.serialize_ctor_kwargs() for col in self.columns
         ]
         header: DataFrameHeader = {
             "columns_kwargs": columns_kwargs,
@@ -246,7 +271,9 @@ class DataFrame:
             for c, other in zip(self.columns, like.columns, strict=True)
         )
 
-    def with_columns(self, columns: Iterable[Column], *, replace_only=False) -> Self:
+    def with_columns(
+        self, columns: Iterable[Column], *, replace_only: bool = False
+    ) -> Self:
         """
         Return a new dataframe with extra columns.
 
@@ -275,7 +302,7 @@ class DataFrame:
         """Drop columns by name."""
         return type(self)(column for column in self.columns if column.name not in names)
 
-    def select(self, names: Sequence[str]) -> Self:
+    def select(self, names: Sequence[str] | Mapping[str, Any]) -> Self:
         """Select columns by name returning DataFrame."""
         try:
             return type(self)(self.column_map[name] for name in names)
@@ -293,7 +320,11 @@ class DataFrame:
     def filter(self, mask: Column) -> Self:
         """Return a filtered table given a mask."""
         table = plc.stream_compaction.apply_boolean_mask(self.table, mask.obj)
-        return type(self).from_table(table, self.column_names).sorted_like(self)
+        return (
+            type(self)
+            .from_table(table, self.column_names, self.dtypes)
+            .sorted_like(self)
+        )
 
     def slice(self, zlice: Slice | None) -> Self:
         """
@@ -314,4 +345,8 @@ class DataFrame:
         (table,) = plc.copying.slice(
             self.table, conversion.from_polars_slice(zlice, num_rows=self.num_rows)
         )
-        return type(self).from_table(table, self.column_names).sorted_like(self)
+        return (
+            type(self)
+            .from_table(table, self.column_names, self.dtypes)
+            .sorted_like(self)
+        )
cudf_polars/containers/datatype.py ADDED
@@ -0,0 +1,135 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
+"""A datatype, preserving polars metadata."""
+
+from __future__ import annotations
+
+from functools import cache
+
+from typing_extensions import assert_never
+
+import polars as pl
+
+import pylibcudf as plc
+
+__all__ = ["DataType"]
+
+
+@cache
+def _from_polars(dtype: pl.DataType) -> plc.DataType:
+    """
+    Convert a polars datatype to a pylibcudf one.
+
+    Parameters
+    ----------
+    dtype
+        Polars dtype to convert
+
+    Returns
+    -------
+    Matching pylibcudf DataType object.
+
+    Raises
+    ------
+    NotImplementedError
+        For unsupported conversions.
+    """
+    if isinstance(dtype, pl.Boolean):
+        return plc.DataType(plc.TypeId.BOOL8)
+    elif isinstance(dtype, pl.Int8):
+        return plc.DataType(plc.TypeId.INT8)
+    elif isinstance(dtype, pl.Int16):
+        return plc.DataType(plc.TypeId.INT16)
+    elif isinstance(dtype, pl.Int32):
+        return plc.DataType(plc.TypeId.INT32)
+    elif isinstance(dtype, pl.Int64):
+        return plc.DataType(plc.TypeId.INT64)
+    if isinstance(dtype, pl.UInt8):
+        return plc.DataType(plc.TypeId.UINT8)
+    elif isinstance(dtype, pl.UInt16):
+        return plc.DataType(plc.TypeId.UINT16)
+    elif isinstance(dtype, pl.UInt32):
+        return plc.DataType(plc.TypeId.UINT32)
+    elif isinstance(dtype, pl.UInt64):
+        return plc.DataType(plc.TypeId.UINT64)
+    elif isinstance(dtype, pl.Float32):
+        return plc.DataType(plc.TypeId.FLOAT32)
+    elif isinstance(dtype, pl.Float64):
+        return plc.DataType(plc.TypeId.FLOAT64)
+    elif isinstance(dtype, pl.Date):
+        return plc.DataType(plc.TypeId.TIMESTAMP_DAYS)
+    elif isinstance(dtype, pl.Time):
+        raise NotImplementedError("Time of day dtype not implemented")
+    elif isinstance(dtype, pl.Datetime):
+        if dtype.time_unit == "ms":
+            return plc.DataType(plc.TypeId.TIMESTAMP_MILLISECONDS)
+        elif dtype.time_unit == "us":
+            return plc.DataType(plc.TypeId.TIMESTAMP_MICROSECONDS)
+        elif dtype.time_unit == "ns":
+            return plc.DataType(plc.TypeId.TIMESTAMP_NANOSECONDS)
+        assert dtype.time_unit is not None  # pragma: no cover
+        assert_never(dtype.time_unit)
+    elif isinstance(dtype, pl.Duration):
+        if dtype.time_unit == "ms":
+            return plc.DataType(plc.TypeId.DURATION_MILLISECONDS)
+        elif dtype.time_unit == "us":
+            return plc.DataType(plc.TypeId.DURATION_MICROSECONDS)
+        elif dtype.time_unit == "ns":
+            return plc.DataType(plc.TypeId.DURATION_NANOSECONDS)
+        assert dtype.time_unit is not None  # pragma: no cover
+        assert_never(dtype.time_unit)
+    elif isinstance(dtype, pl.String):
+        return plc.DataType(plc.TypeId.STRING)
+    elif isinstance(dtype, pl.Null):
+        # TODO: Hopefully
+        return plc.DataType(plc.TypeId.EMPTY)
+    elif isinstance(dtype, pl.List):
+        # Recurse to catch unsupported inner types
+        _ = _from_polars(dtype.inner)
+        return plc.DataType(plc.TypeId.LIST)
+    elif isinstance(dtype, pl.Struct):
+        # Recurse to catch unsupported field types
+        for field in dtype.fields:
+            _ = _from_polars(field.dtype)
+        return plc.DataType(plc.TypeId.STRUCT)
+    else:
+        raise NotImplementedError(f"{dtype=} conversion not supported")
+
+
+class DataType:
+    """A datatype, preserving polars metadata."""
+
+    polars: pl.datatypes.DataType
+    plc: plc.DataType
+
+    def __init__(self, polars_dtype: pl.DataType) -> None:
+        self.polars = polars_dtype
+        self.plc = _from_polars(polars_dtype)
+
+    def id(self) -> plc.TypeId:
+        """The pylibcudf.TypeId of this DataType."""
+        return self.plc.id()
+
+    @property
+    def children(self) -> list[DataType]:
+        """The children types of this DataType."""
+        if self.plc.id() == plc.TypeId.STRUCT:
+            return [DataType(field.dtype) for field in self.polars.fields]
+        elif self.plc.id() == plc.TypeId.LIST:
+            return [DataType(self.polars.inner)]
+        return []
+
+    def __eq__(self, other: object) -> bool:
+        """Equality of DataTypes."""
+        if not isinstance(other, DataType):
+            return False
+        return self.polars == other.polars
+
+    def __hash__(self) -> int:
+        """Hash of the DataType."""
+        return hash(self.polars)
+
+    def __repr__(self) -> str:
+        """Representation of the DataType."""
+        return f"<DataType(polars={self.polars}, plc={self.id()!r})>"
cudf_polars/dsl/expr.py CHANGED
@@ -33,6 +33,7 @@ from cudf_polars.dsl.expressions.selection import Filter, Gather
 from cudf_polars.dsl.expressions.slicing import Slice
 from cudf_polars.dsl.expressions.sorting import Sort, SortBy
 from cudf_polars.dsl.expressions.string import StringFunction
+from cudf_polars.dsl.expressions.struct import StructFunction
 from cudf_polars.dsl.expressions.ternary import Ternary
 from cudf_polars.dsl.expressions.unary import Cast, Len, UnaryFunction
 
@@ -58,6 +59,7 @@ __all__ = [
     "Sort",
     "SortBy",
     "StringFunction",
+    "StructFunction",
    "TemporalFunction",
    "Ternary",
    "UnaryFunction",
cudf_polars/dsl/expressions/aggregation.py CHANGED
@@ -9,23 +9,14 @@ from __future__ import annotations
 from functools import partial
 from typing import TYPE_CHECKING, Any, ClassVar
 
-import pyarrow as pa
-
 import pylibcudf as plc
 
 from cudf_polars.containers import Column
-from cudf_polars.dsl.expressions.base import (
-    AggInfo,
-    ExecutionContext,
-    Expr,
-)
+from cudf_polars.dsl.expressions.base import ExecutionContext, Expr
 from cudf_polars.dsl.expressions.literal import Literal
-from cudf_polars.dsl.expressions.unary import UnaryFunction
 
 if TYPE_CHECKING:
-    from collections.abc import Mapping
-
-    from cudf_polars.containers import DataFrame
+    from cudf_polars.containers import DataFrame, DataType
 
 __all__ = ["Agg"]
 
@@ -35,7 +26,7 @@ class Agg(Expr):
     _non_child = ("dtype", "name", "options")
 
     def __init__(
-        self, dtype: plc.DataType, name: str, options: Any, *children: Expr
+        self, dtype: DataType, name: str, options: Any, *children: Expr
     ) -> None:
         self.dtype = dtype
         self.name = name
@@ -75,11 +66,15 @@ class Agg(Expr):
                 else plc.types.NullPolicy.INCLUDE
             )
         elif name == "quantile":
-            _, quantile = self.children
+            child, quantile = self.children
             if not isinstance(quantile, Literal):
                 raise NotImplementedError("Only support literal quantile values")
+            if options == "equiprobable":
+                raise NotImplementedError("Quantile with equiprobable interpolation")
+            if plc.traits.is_duration(child.dtype.plc):
+                raise NotImplementedError("Quantile with duration data type")
             req = plc.aggregation.quantile(
-                quantiles=[quantile.value.as_py()], interp=Agg.interp_mapping[options]
+                quantiles=[quantile.value], interp=Agg.interp_mapping[options]
             )
         else:
             raise NotImplementedError(
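
The new quantile guards surface at the polars level as CPU fallback: a literal quantile with a supported interpolation runs on the GPU, while equiprobable interpolation (or a duration column) now raises NotImplementedError during translation, which should trigger the usual fallback to the CPU engine. Sketch, assuming a GPU-enabled install:

```python
import polars as pl

lf = pl.LazyFrame({"a": [1.0, 2.0, 4.0]})

# Runs on the GPU: literal quantile, supported interpolation.
lf.select(pl.col("a").quantile(0.5, interpolation="linear")).collect(engine="gpu")

# Hits the new NotImplementedError path and falls back to the CPU engine.
lf.select(pl.col("a").quantile(0.5, interpolation="equiprobable")).collect(
    engine="gpu"
)
```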
@@ -91,7 +86,9 @@ class Agg(Expr):
             op = partial(self._reduce, request=req)
         elif name in {"min", "max"}:
             op = partial(op, propagate_nans=options)
-        elif name in {"count", "sum", "first", "last"}:
+        elif name == "count":
+            op = partial(op, include_nulls=options)
+        elif name in {"sum", "first", "last"}:
             pass
         else:
             raise NotImplementedError(
@@ -124,71 +121,52 @@ class Agg(Expr):
         "linear": plc.types.Interpolation.LINEAR,
     }
 
-    def collect_agg(self, *, depth: int) -> AggInfo:
-        """Collect information about aggregations in groupbys."""
-        if depth >= 1:
-            raise NotImplementedError(
-                "Nested aggregations in groupby"
-            )  # pragma: no cover; check_agg trips first
-        if (isminmax := self.name in {"min", "max"}) and self.options:
-            raise NotImplementedError("Nan propagation in groupby for min/max")
-        (child,) = self.children
-        ((expr, _, _),) = child.collect_agg(depth=depth + 1).requests
-        request = self.request
-        # These are handled specially here because we don't set up the
-        # request for the whole-frame agg because we can avoid a
-        # reduce for these.
+    @property
+    def agg_request(self) -> plc.aggregation.Aggregation:  # noqa: D102
         if self.name == "first":
-            request = plc.aggregation.nth_element(
+            return plc.aggregation.nth_element(
                 0, null_handling=plc.types.NullPolicy.INCLUDE
             )
         elif self.name == "last":
-            request = plc.aggregation.nth_element(
+            return plc.aggregation.nth_element(
                 -1, null_handling=plc.types.NullPolicy.INCLUDE
             )
-        if request is None:
-            raise NotImplementedError(
-                f"Aggregation {self.name} in groupby"
-            )  # pragma: no cover; __init__ trips first
-        if isminmax and plc.traits.is_floating_point(self.dtype):
-            assert expr is not None
-            # Ignore nans in these groupby aggs, do this by masking
-            # nans in the input
-            expr = UnaryFunction(self.dtype, "mask_nans", (), expr)
-        return AggInfo([(expr, request, self)])
+        else:
+            assert self.request is not None, "Init should have raised"
+            return self.request
 
     def _reduce(
         self, column: Column, *, request: plc.aggregation.Aggregation
     ) -> Column:
         return Column(
             plc.Column.from_scalar(
-                plc.reduce.reduce(column.obj, request, self.dtype),
+                plc.reduce.reduce(column.obj, request, self.dtype.plc),
                 1,
-            )
+            ),
+            name=column.name,
+            dtype=self.dtype,
         )
 
-    def _count(self, column: Column) -> Column:
+    def _count(self, column: Column, *, include_nulls: bool) -> Column:
+        null_count = column.null_count if not include_nulls else 0
         return Column(
             plc.Column.from_scalar(
-                plc.interop.from_arrow(
-                    pa.scalar(
-                        column.size - column.null_count,
-                        type=plc.interop.to_arrow(self.dtype),
-                    ),
-                ),
+                plc.Scalar.from_py(column.size - null_count, self.dtype.plc),
                 1,
-            )
+            ),
+            name=column.name,
+            dtype=self.dtype,
         )
 
     def _sum(self, column: Column) -> Column:
         if column.size == 0 or column.null_count == column.size:
             return Column(
                 plc.Column.from_scalar(
-                    plc.interop.from_arrow(
-                        pa.scalar(0, type=plc.interop.to_arrow(self.dtype))
-                    ),
+                    plc.Scalar.from_py(0, self.dtype.plc),
                     1,
-                )
+                ),
+                name=column.name,
+                dtype=self.dtype,
             )
         return self._reduce(column, request=plc.aggregation.sum())
 
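Throughout these aggregations, the old pyarrow detour (pa.scalar plus plc.interop.from_arrow) is replaced by pylibcudf's direct scalar constructor. A minimal sketch of the new pattern, assuming a matching pylibcudf install:

```python
import pylibcudf as plc

# Build a device scalar straight from a Python value...
s = plc.Scalar.from_py(0, plc.DataType(plc.TypeId.INT64))
# ...and broadcast it into a 1-row column, as _sum does for the
# empty/all-null case.
col = plc.Column.from_scalar(s, 1)
```
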
@@ -196,11 +174,11 @@ class Agg(Expr):
         if propagate_nans and column.nan_count > 0:
             return Column(
                 plc.Column.from_scalar(
-                    plc.interop.from_arrow(
-                        pa.scalar(float("nan"), type=plc.interop.to_arrow(self.dtype))
-                    ),
+                    plc.Scalar.from_py(float("nan"), self.dtype.plc),
                     1,
-                )
+                ),
+                name=column.name,
+                dtype=self.dtype,
             )
         if column.nan_count > 0:
             column = column.mask_nans()
@@ -210,29 +188,31 @@ class Agg(Expr):
         if propagate_nans and column.nan_count > 0:
             return Column(
                 plc.Column.from_scalar(
-                    plc.interop.from_arrow(
-                        pa.scalar(float("nan"), type=plc.interop.to_arrow(self.dtype))
-                    ),
+                    plc.Scalar.from_py(float("nan"), self.dtype.plc),
                     1,
-                )
+                ),
+                name=column.name,
+                dtype=self.dtype,
             )
         if column.nan_count > 0:
             column = column.mask_nans()
         return self._reduce(column, request=plc.aggregation.max())
 
     def _first(self, column: Column) -> Column:
-        return Column(plc.copying.slice(column.obj, [0, 1])[0])
+        return Column(
+            plc.copying.slice(column.obj, [0, 1])[0], name=column.name, dtype=self.dtype
+        )
 
     def _last(self, column: Column) -> Column:
         n = column.size
-        return Column(plc.copying.slice(column.obj, [n - 1, n])[0])
+        return Column(
+            plc.copying.slice(column.obj, [n - 1, n])[0],
+            name=column.name,
+            dtype=self.dtype,
+        )
 
     def do_evaluate(
-        self,
-        df: DataFrame,
-        *,
-        context: ExecutionContext = ExecutionContext.FRAME,
-        mapping: Mapping[Expr, Column] | None = None,
+        self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
     ) -> Column:
         """Evaluate this expression given a dataframe for context."""
         if context is not ExecutionContext.FRAME:
@@ -243,4 +223,4 @@ class Agg(Expr):
         # Aggregations like quantiles may have additional children that were
         # preprocessed into pylibcudf requests.
         child = self.children[0]
-        return self.op(child.evaluate(df, context=context, mapping=mapping))
+        return self.op(child.evaluate(df, context=context))
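
For reference, the include_nulls flag threaded into _count mirrors the two null-handling behaviours polars exposes; a CPU-only illustration of the semantics the GPU path has to reproduce:

```python
import polars as pl

df = pl.DataFrame({"a": [1, None, 3]})
out = df.select(
    pl.col("a").count().alias("non_null"),  # 2: nulls excluded
    pl.col("a").len().alias("with_nulls"),  # 3: nulls included
)
```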