cudf-polars-cu12 25.4.0-py3-none-any.whl → 25.8.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cudf_polars/VERSION +1 -1
- cudf_polars/callback.py +55 -61
- cudf_polars/containers/__init__.py +4 -2
- cudf_polars/containers/column.py +123 -40
- cudf_polars/containers/dataframe.py +70 -35
- cudf_polars/containers/datatype.py +135 -0
- cudf_polars/dsl/expr.py +2 -0
- cudf_polars/dsl/expressions/aggregation.py +51 -71
- cudf_polars/dsl/expressions/base.py +45 -77
- cudf_polars/dsl/expressions/binaryop.py +29 -44
- cudf_polars/dsl/expressions/boolean.py +64 -71
- cudf_polars/dsl/expressions/datetime.py +70 -34
- cudf_polars/dsl/expressions/literal.py +45 -33
- cudf_polars/dsl/expressions/rolling.py +133 -10
- cudf_polars/dsl/expressions/selection.py +13 -31
- cudf_polars/dsl/expressions/slicing.py +6 -13
- cudf_polars/dsl/expressions/sorting.py +9 -21
- cudf_polars/dsl/expressions/string.py +470 -84
- cudf_polars/dsl/expressions/struct.py +138 -0
- cudf_polars/dsl/expressions/ternary.py +9 -13
- cudf_polars/dsl/expressions/unary.py +151 -90
- cudf_polars/dsl/ir.py +798 -331
- cudf_polars/dsl/nodebase.py +11 -4
- cudf_polars/dsl/to_ast.py +61 -20
- cudf_polars/dsl/tracing.py +16 -0
- cudf_polars/dsl/translate.py +279 -167
- cudf_polars/dsl/traversal.py +64 -15
- cudf_polars/dsl/utils/__init__.py +8 -0
- cudf_polars/dsl/utils/aggregations.py +301 -0
- cudf_polars/dsl/utils/groupby.py +93 -0
- cudf_polars/dsl/utils/naming.py +34 -0
- cudf_polars/dsl/utils/replace.py +61 -0
- cudf_polars/dsl/utils/reshape.py +74 -0
- cudf_polars/dsl/utils/rolling.py +115 -0
- cudf_polars/dsl/utils/windows.py +186 -0
- cudf_polars/experimental/base.py +112 -8
- cudf_polars/experimental/benchmarks/__init__.py +4 -0
- cudf_polars/experimental/benchmarks/pdsds.py +216 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/__init__.py +4 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q1.py +88 -0
- cudf_polars/experimental/benchmarks/pdsh.py +812 -0
- cudf_polars/experimental/benchmarks/utils.py +725 -0
- cudf_polars/experimental/dask_registers.py +200 -0
- cudf_polars/experimental/dispatch.py +22 -7
- cudf_polars/experimental/distinct.py +194 -0
- cudf_polars/experimental/explain.py +127 -0
- cudf_polars/experimental/expressions.py +547 -0
- cudf_polars/experimental/groupby.py +174 -196
- cudf_polars/experimental/io.py +626 -51
- cudf_polars/experimental/join.py +104 -33
- cudf_polars/experimental/parallel.py +219 -133
- cudf_polars/experimental/repartition.py +69 -0
- cudf_polars/experimental/scheduler.py +155 -0
- cudf_polars/experimental/select.py +132 -7
- cudf_polars/experimental/shuffle.py +126 -18
- cudf_polars/experimental/sort.py +45 -0
- cudf_polars/experimental/spilling.py +151 -0
- cudf_polars/experimental/utils.py +112 -0
- cudf_polars/testing/asserts.py +213 -14
- cudf_polars/testing/io.py +72 -0
- cudf_polars/testing/plugin.py +77 -67
- cudf_polars/typing/__init__.py +63 -22
- cudf_polars/utils/config.py +584 -117
- cudf_polars/utils/dtypes.py +4 -117
- cudf_polars/utils/timer.py +1 -1
- cudf_polars/utils/versions.py +7 -5
- {cudf_polars_cu12-25.4.0.dist-info → cudf_polars_cu12-25.8.0.dist-info}/METADATA +13 -18
- cudf_polars_cu12-25.8.0.dist-info/RECORD +81 -0
- {cudf_polars_cu12-25.4.0.dist-info → cudf_polars_cu12-25.8.0.dist-info}/WHEEL +1 -1
- cudf_polars/experimental/dask_serialize.py +0 -73
- cudf_polars_cu12-25.4.0.dist-info/RECORD +0 -55
- {cudf_polars_cu12-25.4.0.dist-info → cudf_polars_cu12-25.8.0.dist-info}/licenses/LICENSE +0 -0
- {cudf_polars_cu12-25.4.0.dist-info → cudf_polars_cu12-25.8.0.dist-info}/top_level.txt +0 -0
cudf_polars/containers/dataframe.py
CHANGED

```diff
@@ -8,26 +8,55 @@ from __future__ import annotations
 from functools import cached_property
 from typing import TYPE_CHECKING, cast
 
-import pyarrow as pa
-
 import polars as pl
 
 import pylibcudf as plc
 
-from cudf_polars.containers import Column
-from cudf_polars.utils import conversion
+from cudf_polars.containers import Column, DataType
+from cudf_polars.utils import conversion
 
 if TYPE_CHECKING:
     from collections.abc import Iterable, Mapping, Sequence, Set
 
-    from typing_extensions import Self
+    from typing_extensions import Any, CapsuleType, Self
 
-    from cudf_polars.typing import ColumnOptions, DataFrameHeader, Slice
+    from cudf_polars.typing import ColumnOptions, DataFrameHeader, PolarsDataType, Slice
 
 
 __all__: list[str] = ["DataFrame"]
 
 
+def _create_polars_column_metadata(
+    name: str, dtype: PolarsDataType
+) -> plc.interop.ColumnMetadata:
+    """Create ColumnMetadata preserving pl.Struct field names."""
+    if isinstance(dtype, pl.Struct):
+        children_meta = [
+            _create_polars_column_metadata(field.name, field.dtype)
+            for field in dtype.fields
+        ]
+    else:
+        children_meta = []
+    timezone = dtype.time_zone if isinstance(dtype, pl.Datetime) else None
+    return plc.interop.ColumnMetadata(
+        name=name, timezone=timezone or "", children_meta=children_meta
+    )
+
+
+# This is also defined in pylibcudf.interop
+class _ObjectWithArrowMetadata:
+    def __init__(
+        self, obj: plc.Table, metadata: list[plc.interop.ColumnMetadata]
+    ) -> None:
+        self.obj = obj
+        self.metadata = metadata
+
+    def __arrow_c_array__(
+        self, requested_schema: None = None
+    ) -> tuple[CapsuleType, CapsuleType]:
+        return self.obj._to_schema(self.metadata), self.obj._to_host_array()
+
+
 # Pacify the type checker. DataFrame init asserts that all the columns
 # have a string name, so let's narrow the type.
 class NamedColumn(Column):
```
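The new helper recurses through struct fields and carries datetime time zones on each node, so nested metadata survives the device-to-host handoff. A hedged sketch of what it yields for a hypothetical nested column (illustrative names; this only runs inside the module above, since the helper is private):

```python
import polars as pl

# Hypothetical nested dtype: a struct holding a zoned datetime and an inner struct.
dtype = pl.Struct(
    {"ts": pl.Datetime("us", "UTC"), "inner": pl.Struct({"a": pl.Int8})}
)
meta = _create_polars_column_metadata("s", dtype)
# meta.name == "s"; meta.children_meta has one ColumnMetadata per field:
#   - name="ts", timezone="UTC"  (picked up from the pl.Datetime)
#   - name="inner", whose own children_meta carries name="a"
```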
```diff
@@ -46,6 +75,7 @@ class DataFrame:
         if any(c.name is None for c in columns):
             raise ValueError("All columns must have a name")
         self.columns = [cast(NamedColumn, c) for c in columns]
+        self.dtypes = [c.dtype for c in self.columns]
         self.column_map = {c.name: c for c in self.columns}
         self.table = plc.Table([c.obj for c in self.columns])
 
@@ -62,11 +92,12 @@ class DataFrame:
         # To guarantee we produce correct names, we therefore
         # serialise with names we control and rename with that map.
         name_map = {f"column_{i}": name for i, name in enumerate(self.column_map)}
-
-
-
-
-
+        metadata = [
+            _create_polars_column_metadata(name, dtype.polars)
+            for name, dtype in zip(name_map, self.dtypes, strict=True)
+        ]
+        table_with_metadata = _ObjectWithArrowMetadata(self.table, metadata)
+        df = pl.DataFrame(table_with_metadata)
         return df.rename(name_map).with_columns(
             pl.col(c.name).set_sorted(descending=c.order == plc.types.Order.DESCENDING)
             if c.is_sorted
@@ -108,21 +139,18 @@ class DataFrame:
         -------
         New dataframe representing the input.
         """
-
-        schema = table.schema
-        for i, field in enumerate(schema):
-            schema = schema.set(
-                i, pa.field(field.name, dtypes.downcast_arrow_lists(field.type))
-            )
-        # No-op if the schema is unchanged.
-        d_table = plc.interop.from_arrow(table.cast(schema))
+        plc_table = plc.Table.from_arrow(df)
         return cls(
-            Column(
-            for
+            Column(d_col, name=name, dtype=DataType(h_col.dtype)).copy_metadata(h_col)
+            for d_col, h_col, name in zip(
+                plc_table.columns(), df.iter_columns(), df.columns, strict=True
+            )
         )
 
     @classmethod
-    def from_table(cls, table: plc.Table, names: Sequence[str]) -> Self:
+    def from_table(
+        cls, table: plc.Table, names: Sequence[str], dtypes: Sequence[DataType]
+    ) -> Self:
         """
         Create from a pylibcudf table.
 
@@ -132,6 +160,8 @@ class DataFrame:
             Pylibcudf table to obtain columns from
         names
             Names for the columns
+        dtypes
+            Dtypes for the columns
 
         Returns
         -------
@@ -146,7 +176,8 @@ class DataFrame:
         if table.num_columns() != len(names):
             raise ValueError("Mismatching name and table length.")
         return cls(
-            Column(c, name=name) for c, name in zip(table.columns(), names, strict=True)
+            Column(c, name=name, dtype=dtype)
+            for c, name, dtype in zip(table.columns(), names, dtypes, strict=True)
         )
 
     @classmethod
@@ -173,7 +204,7 @@ class DataFrame:
             packed_metadata, packed_gpu_data
         )
         return cls(
-            Column(c, **kw)
+            Column(c, **Column.deserialize_ctor_kwargs(kw))
             for c, kw in zip(table.columns(), header["columns_kwargs"], strict=True)
         )
 
@@ -202,13 +233,7 @@ class DataFrame:
 
         # Keyword arguments for `Column.__init__`.
         columns_kwargs: list[ColumnOptions] = [
-            {
-                "is_sorted": col.is_sorted,
-                "order": col.order,
-                "null_order": col.null_order,
-                "name": col.name,
-            }
-            for col in self.columns
+            col.serialize_ctor_kwargs() for col in self.columns
         ]
         header: DataFrameHeader = {
             "columns_kwargs": columns_kwargs,
@@ -246,7 +271,9 @@ class DataFrame:
             for c, other in zip(self.columns, like.columns, strict=True)
         )
 
-    def with_columns(
+    def with_columns(
+        self, columns: Iterable[Column], *, replace_only: bool = False
+    ) -> Self:
         """
         Return a new dataframe with extra columns.
 
@@ -275,7 +302,7 @@ class DataFrame:
         """Drop columns by name."""
         return type(self)(column for column in self.columns if column.name not in names)
 
-    def select(self, names: Sequence[str]) -> Self:
+    def select(self, names: Sequence[str] | Mapping[str, Any]) -> Self:
         """Select columns by name returning DataFrame."""
         try:
             return type(self)(self.column_map[name] for name in names)
@@ -293,7 +320,11 @@ class DataFrame:
     def filter(self, mask: Column) -> Self:
         """Return a filtered table given a mask."""
         table = plc.stream_compaction.apply_boolean_mask(self.table, mask.obj)
-        return type(self).from_table(table, self.column_names).sorted_like(self)
+        return (
+            type(self)
+            .from_table(table, self.column_names, self.dtypes)
+            .sorted_like(self)
+        )
 
     def slice(self, zlice: Slice | None) -> Self:
         """
@@ -314,4 +345,8 @@ class DataFrame:
         (table,) = plc.copying.slice(
             self.table, conversion.from_polars_slice(zlice, num_rows=self.num_rows)
         )
-        return type(self).from_table(table, self.column_names).sorted_like(self)
+        return (
+            type(self)
+            .from_table(table, self.column_names, self.dtypes)
+            .sorted_like(self)
+        )
```
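With this change, `to_polars` hands the device table to polars through the Arrow PyCapsule interface (`__arrow_c_array__`) rather than materialising a `pyarrow.Table` first, which is why the `pyarrow` import disappears. A minimal host-side sketch of the same protocol, using a `pyarrow.RecordBatch` as a stand-in for the `pylibcudf.Table` (the wrapper class and names here are illustrative):

```python
import polars as pl
import pyarrow as pa


class ArrowCArrayWrapper:
    """Anything exposing __arrow_c_array__ can seed pl.DataFrame directly."""

    def __init__(self, batch: pa.RecordBatch) -> None:
        self.batch = batch

    def __arrow_c_array__(self, requested_schema=None):
        # Delegate to pyarrow's capsule export: (schema capsule, array capsule).
        return self.batch.__arrow_c_array__(requested_schema)


batch = pa.RecordBatch.from_pydict({"column_0": [1, 2, 3]})
df = pl.DataFrame(ArrowCArrayWrapper(batch))
print(df.rename({"column_0": "x"}))  # names applied afterwards, as in to_polars
```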
cudf_polars/containers/datatype.py
ADDED

```diff
@@ -0,0 +1,135 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
+"""A datatype, preserving polars metadata."""
+
+from __future__ import annotations
+
+from functools import cache
+
+from typing_extensions import assert_never
+
+import polars as pl
+
+import pylibcudf as plc
+
+__all__ = ["DataType"]
+
+
+@cache
+def _from_polars(dtype: pl.DataType) -> plc.DataType:
+    """
+    Convert a polars datatype to a pylibcudf one.
+
+    Parameters
+    ----------
+    dtype
+        Polars dtype to convert
+
+    Returns
+    -------
+    Matching pylibcudf DataType object.
+
+    Raises
+    ------
+    NotImplementedError
+        For unsupported conversions.
+    """
+    if isinstance(dtype, pl.Boolean):
+        return plc.DataType(plc.TypeId.BOOL8)
+    elif isinstance(dtype, pl.Int8):
+        return plc.DataType(plc.TypeId.INT8)
+    elif isinstance(dtype, pl.Int16):
+        return plc.DataType(plc.TypeId.INT16)
+    elif isinstance(dtype, pl.Int32):
+        return plc.DataType(plc.TypeId.INT32)
+    elif isinstance(dtype, pl.Int64):
+        return plc.DataType(plc.TypeId.INT64)
+    if isinstance(dtype, pl.UInt8):
+        return plc.DataType(plc.TypeId.UINT8)
+    elif isinstance(dtype, pl.UInt16):
+        return plc.DataType(plc.TypeId.UINT16)
+    elif isinstance(dtype, pl.UInt32):
+        return plc.DataType(plc.TypeId.UINT32)
+    elif isinstance(dtype, pl.UInt64):
+        return plc.DataType(plc.TypeId.UINT64)
+    elif isinstance(dtype, pl.Float32):
+        return plc.DataType(plc.TypeId.FLOAT32)
+    elif isinstance(dtype, pl.Float64):
+        return plc.DataType(plc.TypeId.FLOAT64)
+    elif isinstance(dtype, pl.Date):
+        return plc.DataType(plc.TypeId.TIMESTAMP_DAYS)
+    elif isinstance(dtype, pl.Time):
+        raise NotImplementedError("Time of day dtype not implemented")
+    elif isinstance(dtype, pl.Datetime):
+        if dtype.time_unit == "ms":
+            return plc.DataType(plc.TypeId.TIMESTAMP_MILLISECONDS)
+        elif dtype.time_unit == "us":
+            return plc.DataType(plc.TypeId.TIMESTAMP_MICROSECONDS)
+        elif dtype.time_unit == "ns":
+            return plc.DataType(plc.TypeId.TIMESTAMP_NANOSECONDS)
+        assert dtype.time_unit is not None  # pragma: no cover
+        assert_never(dtype.time_unit)
+    elif isinstance(dtype, pl.Duration):
+        if dtype.time_unit == "ms":
+            return plc.DataType(plc.TypeId.DURATION_MILLISECONDS)
+        elif dtype.time_unit == "us":
+            return plc.DataType(plc.TypeId.DURATION_MICROSECONDS)
+        elif dtype.time_unit == "ns":
+            return plc.DataType(plc.TypeId.DURATION_NANOSECONDS)
+        assert dtype.time_unit is not None  # pragma: no cover
+        assert_never(dtype.time_unit)
+    elif isinstance(dtype, pl.String):
+        return plc.DataType(plc.TypeId.STRING)
+    elif isinstance(dtype, pl.Null):
+        # TODO: Hopefully
+        return plc.DataType(plc.TypeId.EMPTY)
+    elif isinstance(dtype, pl.List):
+        # Recurse to catch unsupported inner types
+        _ = _from_polars(dtype.inner)
+        return plc.DataType(plc.TypeId.LIST)
+    elif isinstance(dtype, pl.Struct):
+        # Recurse to catch unsupported field types
+        for field in dtype.fields:
+            _ = _from_polars(field.dtype)
+        return plc.DataType(plc.TypeId.STRUCT)
+    else:
+        raise NotImplementedError(f"{dtype=} conversion not supported")
+
+
+class DataType:
+    """A datatype, preserving polars metadata."""
+
+    polars: pl.datatypes.DataType
+    plc: plc.DataType
+
+    def __init__(self, polars_dtype: pl.DataType) -> None:
+        self.polars = polars_dtype
+        self.plc = _from_polars(polars_dtype)
+
+    def id(self) -> plc.TypeId:
+        """The pylibcudf.TypeId of this DataType."""
+        return self.plc.id()
+
+    @property
+    def children(self) -> list[DataType]:
+        """The children types of this DataType."""
+        if self.plc.id() == plc.TypeId.STRUCT:
+            return [DataType(field.dtype) for field in self.polars.fields]
+        elif self.plc.id() == plc.TypeId.LIST:
+            return [DataType(self.polars.inner)]
+        return []
+
+    def __eq__(self, other: object) -> bool:
+        """Equality of DataTypes."""
+        if not isinstance(other, DataType):
+            return False
+        return self.polars == other.polars
+
+    def __hash__(self) -> int:
+        """Hash of the DataType."""
+        return hash(self.polars)
+
+    def __repr__(self) -> str:
+        """Representation of the DataType."""
+        return f"<DataType(polars={self.polars}, plc={self.id()!r})>"
```
cudf_polars/dsl/expr.py
CHANGED

```diff
@@ -33,6 +33,7 @@ from cudf_polars.dsl.expressions.selection import Filter, Gather
 from cudf_polars.dsl.expressions.slicing import Slice
 from cudf_polars.dsl.expressions.sorting import Sort, SortBy
 from cudf_polars.dsl.expressions.string import StringFunction
+from cudf_polars.dsl.expressions.struct import StructFunction
 from cudf_polars.dsl.expressions.ternary import Ternary
 from cudf_polars.dsl.expressions.unary import Cast, Len, UnaryFunction
 
@@ -58,6 +59,7 @@ __all__ = [
     "Sort",
     "SortBy",
     "StringFunction",
+    "StructFunction",
     "TemporalFunction",
     "Ternary",
     "UnaryFunction",
```
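Registering `StructFunction` here is what makes struct expressions translatable; the node itself is defined in the new `cudf_polars/dsl/expressions/struct.py` (+138 lines, see the file list). A hedged sketch of the kind of polars query this enables, with exact operation coverage depending on the 25.8 implementation:

```python
import polars as pl

lf = pl.LazyFrame({"s": [{"a": 1, "b": "x"}, {"a": 2, "b": "y"}]})
# Struct field access is translated to a StructFunction expression node
# when the query is offloaded to the GPU engine.
out = lf.select(pl.col("s").struct.field("a")).collect(engine="gpu")
```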
cudf_polars/dsl/expressions/aggregation.py
CHANGED

```diff
@@ -9,23 +9,14 @@ from __future__ import annotations
 from functools import partial
 from typing import TYPE_CHECKING, Any, ClassVar
 
-import pyarrow as pa
-
 import pylibcudf as plc
 
 from cudf_polars.containers import Column
-from cudf_polars.dsl.expressions.base import (
-    AggInfo,
-    ExecutionContext,
-    Expr,
-)
+from cudf_polars.dsl.expressions.base import ExecutionContext, Expr
 from cudf_polars.dsl.expressions.literal import Literal
-from cudf_polars.dsl.expressions.unary import UnaryFunction
 
 if TYPE_CHECKING:
-    from collections.abc import Mapping
-
-    from cudf_polars.containers import DataFrame
+    from cudf_polars.containers import DataFrame, DataType
 
 __all__ = ["Agg"]
 
@@ -35,7 +26,7 @@ class Agg(Expr):
     _non_child = ("dtype", "name", "options")
 
     def __init__(
-        self, dtype: plc.DataType, name: str, options: Any, *children: Expr
+        self, dtype: DataType, name: str, options: Any, *children: Expr
     ) -> None:
         self.dtype = dtype
         self.name = name
@@ -75,11 +66,15 @@ class Agg(Expr):
                 else plc.types.NullPolicy.INCLUDE
             )
         elif name == "quantile":
-            _, quantile = self.children
+            child, quantile = self.children
             if not isinstance(quantile, Literal):
                 raise NotImplementedError("Only support literal quantile values")
+            if options == "equiprobable":
+                raise NotImplementedError("Quantile with equiprobable interpolation")
+            if plc.traits.is_duration(child.dtype.plc):
+                raise NotImplementedError("Quantile with duration data type")
             req = plc.aggregation.quantile(
-                quantiles=[quantile.value
+                quantiles=[quantile.value], interp=Agg.interp_mapping[options]
             )
         else:
             raise NotImplementedError(
@@ -91,7 +86,9 @@ class Agg(Expr):
             op = partial(self._reduce, request=req)
         elif name in {"min", "max"}:
             op = partial(op, propagate_nans=options)
-        elif name in {"count", "sum", "first", "last"}:
+        elif name == "count":
+            op = partial(op, include_nulls=options)
+        elif name in {"sum", "first", "last"}:
             pass
         else:
             raise NotImplementedError(
```
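`count` moves out of the pass-through set and gains an `include_nulls` option, mirroring the flag polars carries on its count aggregation. A sketch of the two user-facing spellings this distinguishes:

```python
import polars as pl

df = pl.DataFrame({"g": [1, 1, 2], "x": [1.0, None, 3.0]})
df.group_by("g").agg(
    pl.col("x").count().alias("non_null"),  # include_nulls=False -> 1 for g=1
    pl.col("x").len().alias("with_null"),   # include_nulls=True  -> 2 for g=1
)
```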
```diff
@@ -124,71 +121,52 @@ class Agg(Expr):
         "linear": plc.types.Interpolation.LINEAR,
     }
 
-    def collect_agg(self, *, depth: int) -> AggInfo:
-
-        if depth >= 1:
-            raise NotImplementedError(
-                "Nested aggregations in groupby"
-            )  # pragma: no cover; check_agg trips first
-        if (isminmax := self.name in {"min", "max"}) and self.options:
-            raise NotImplementedError("Nan propagation in groupby for min/max")
-        (child,) = self.children
-        ((expr, _, _),) = child.collect_agg(depth=depth + 1).requests
-        request = self.request
-        # These are handled specially here because we don't set up the
-        # request for the whole-frame agg because we can avoid a
-        # reduce for these.
+    @property
+    def agg_request(self) -> plc.aggregation.Aggregation:  # noqa: D102
         if self.name == "first":
-            request = plc.aggregation.nth_element(
+            return plc.aggregation.nth_element(
                 0, null_handling=plc.types.NullPolicy.INCLUDE
             )
         elif self.name == "last":
-            request = plc.aggregation.nth_element(
+            return plc.aggregation.nth_element(
                 -1, null_handling=plc.types.NullPolicy.INCLUDE
             )
-
-
-
-            )  # pragma: no cover; __init__ trips first
-        if isminmax and plc.traits.is_floating_point(self.dtype):
-            assert expr is not None
-            # Ignore nans in these groupby aggs, do this by masking
-            # nans in the input
-            expr = UnaryFunction(self.dtype, "mask_nans", (), expr)
-        return AggInfo([(expr, request, self)])
+        else:
+            assert self.request is not None, "Init should have raised"
+            return self.request
 
     def _reduce(
         self, column: Column, *, request: plc.aggregation.Aggregation
     ) -> Column:
         return Column(
             plc.Column.from_scalar(
-                plc.reduce.reduce(column.obj, request, self.dtype),
+                plc.reduce.reduce(column.obj, request, self.dtype.plc),
                 1,
-            )
+            ),
+            name=column.name,
+            dtype=self.dtype,
         )
 
-    def _count(self, column: Column) -> Column:
+    def _count(self, column: Column, *, include_nulls: bool) -> Column:
+        null_count = column.null_count if not include_nulls else 0
         return Column(
             plc.Column.from_scalar(
-                plc.interop.from_arrow(
-                    pa.scalar(
-                        column.size - column.null_count,
-                        type=plc.interop.to_arrow(self.dtype),
-                    ),
-                ),
+                plc.Scalar.from_py(column.size - null_count, self.dtype.plc),
                 1,
-            )
+            ),
+            name=column.name,
+            dtype=self.dtype,
        )
 
     def _sum(self, column: Column) -> Column:
         if column.size == 0 or column.null_count == column.size:
             return Column(
                 plc.Column.from_scalar(
-                    plc.interop.from_arrow(
-                        pa.scalar(0, type=plc.interop.to_arrow(self.dtype))
-                    ),
+                    plc.Scalar.from_py(0, self.dtype.plc),
                     1,
-                )
+                ),
+                name=column.name,
+                dtype=self.dtype,
             )
         return self._reduce(column, request=plc.aggregation.sum())
 
@@ -196,11 +174,11 @@ class Agg(Expr):
         if propagate_nans and column.nan_count > 0:
             return Column(
                 plc.Column.from_scalar(
-                    plc.interop.from_arrow(
-                        pa.scalar(float("nan"), type=plc.interop.to_arrow(self.dtype))
-                    ),
+                    plc.Scalar.from_py(float("nan"), self.dtype.plc),
                     1,
-                )
+                ),
+                name=column.name,
+                dtype=self.dtype,
             )
         if column.nan_count > 0:
             column = column.mask_nans()
@@ -210,29 +188,31 @@ class Agg(Expr):
         if propagate_nans and column.nan_count > 0:
             return Column(
                 plc.Column.from_scalar(
-                    plc.interop.from_arrow(
-                        pa.scalar(float("nan"), type=plc.interop.to_arrow(self.dtype))
-                    ),
+                    plc.Scalar.from_py(float("nan"), self.dtype.plc),
                     1,
-                )
+                ),
+                name=column.name,
+                dtype=self.dtype,
             )
         if column.nan_count > 0:
             column = column.mask_nans()
         return self._reduce(column, request=plc.aggregation.max())
 
     def _first(self, column: Column) -> Column:
-        return Column(
+        return Column(
+            plc.copying.slice(column.obj, [0, 1])[0], name=column.name, dtype=self.dtype
+        )
 
     def _last(self, column: Column) -> Column:
         n = column.size
-        return Column(
+        return Column(
+            plc.copying.slice(column.obj, [n - 1, n])[0],
+            name=column.name,
+            dtype=self.dtype,
+        )
 
     def do_evaluate(
-        self,
-        df: DataFrame,
-        *,
-        context: ExecutionContext = ExecutionContext.FRAME,
-        mapping: Mapping[Expr, Column] | None = None,
+        self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
     ) -> Column:
         """Evaluate this expression given a dataframe for context."""
         if context is not ExecutionContext.FRAME:
@@ -243,4 +223,4 @@ class Agg(Expr):
         # Aggregations like quantiles may have additional children that were
         # preprocessed into pylibcudf requests.
         child = self.children[0]
-        return self.op(child.evaluate(df, context=context, mapping=mapping))
+        return self.op(child.evaluate(df, context=context))
```