cudf-polars-cu12 24.8.0a281__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cudf_polars/VERSION +1 -0
- cudf_polars/__init__.py +22 -0
- cudf_polars/_version.py +21 -0
- cudf_polars/callback.py +66 -0
- cudf_polars/containers/__init__.py +11 -0
- cudf_polars/containers/column.py +189 -0
- cudf_polars/containers/dataframe.py +226 -0
- cudf_polars/dsl/__init__.py +8 -0
- cudf_polars/dsl/expr.py +1422 -0
- cudf_polars/dsl/ir.py +1053 -0
- cudf_polars/dsl/translate.py +535 -0
- cudf_polars/py.typed +0 -0
- cudf_polars/testing/__init__.py +8 -0
- cudf_polars/testing/asserts.py +118 -0
- cudf_polars/typing/__init__.py +106 -0
- cudf_polars/utils/__init__.py +8 -0
- cudf_polars/utils/dtypes.py +159 -0
- cudf_polars/utils/sorting.py +53 -0
- cudf_polars_cu12-24.8.0a281.dist-info/LICENSE +201 -0
- cudf_polars_cu12-24.8.0a281.dist-info/METADATA +126 -0
- cudf_polars_cu12-24.8.0a281.dist-info/RECORD +23 -0
- cudf_polars_cu12-24.8.0a281.dist-info/WHEEL +5 -0
- cudf_polars_cu12-24.8.0a281.dist-info/top_level.txt +1 -0
cudf_polars/dsl/ir.py
ADDED
@@ -0,0 +1,1053 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
"""
DSL nodes for the LogicalPlan of polars.

An IR node is either a source, normal, or a sink. Respectively they
can be considered as functions:

- source: `IO () -> DataFrame`
- normal: `DataFrame -> DataFrame`
- sink: `DataFrame -> IO ()`
"""

from __future__ import annotations

import dataclasses
import itertools
import types
from functools import cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, ClassVar

import pyarrow as pa
from typing_extensions import assert_never

import polars as pl

import cudf
import cudf._lib.pylibcudf as plc

import cudf_polars.dsl.expr as expr
from cudf_polars.containers import DataFrame, NamedColumn
from cudf_polars.utils import dtypes, sorting

if TYPE_CHECKING:
    from collections.abc import MutableMapping
    from typing import Literal

    from cudf_polars.typing import Schema


__all__ = [
    "IR",
    "PythonScan",
    "Scan",
    "Cache",
    "DataFrameScan",
    "Select",
    "GroupBy",
    "Join",
    "HStack",
    "Distinct",
    "Sort",
    "Slice",
    "Filter",
    "Projection",
    "MapFunction",
    "Union",
    "HConcat",
]

def broadcast(
    *columns: NamedColumn, target_length: int | None = None
) -> list[NamedColumn]:
    """
    Broadcast a sequence of columns to a common length.

    Parameters
    ----------
    columns
        Columns to broadcast.
    target_length
        Optional length to broadcast to. If not provided, uses the
        non-unit length of existing columns.

    Returns
    -------
    List of broadcasted columns all of the same length.

    Raises
    ------
    RuntimeError
        If broadcasting is not possible.

    Notes
    -----
    In evaluation of a set of expressions, polars type-puns length-1
    columns with scalars. When we insert these into a DataFrame
    object, we need to ensure they are of equal length. This function
    takes some columns, some of which may be length-1 and ensures that
    all length-1 columns are broadcast to the length of the others.

    Broadcasting is only possible if the set of lengths of the input
    columns is a subset of ``{1, n}`` for some (fixed) ``n``. If
    ``target_length`` is provided and not all columns are length-1
    (i.e. ``n != 1``), then ``target_length`` must be equal to ``n``.
    """
    if len(columns) == 0:
        return []
    lengths: set[int] = {column.obj.size() for column in columns}
    if lengths == {1}:
        if target_length is None:
            return list(columns)
        nrows = target_length
    else:
        try:
            (nrows,) = lengths.difference([1])
        except ValueError as e:
            raise RuntimeError("Mismatching column lengths") from e
        if target_length is not None and nrows != target_length:
            raise RuntimeError(
                f"Cannot broadcast columns of length {nrows=} to {target_length=}"
            )
    return [
        column
        if column.obj.size() != 1
        else NamedColumn(
            plc.Column.from_scalar(column.obj_scalar, nrows),
            column.name,
            is_sorted=plc.types.Sorted.YES,
            order=plc.types.Order.ASCENDING,
            null_order=plc.types.NullOrder.BEFORE,
        )
        for column in columns
    ]

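# --- Editorial note (illustration, not part of the released file) -----------
# A plain-Python sketch of the broadcasting rule documented above: the set of
# input lengths must be a subset of {1, n}; length-1 columns are then repeated
# to length n (or to ``target_length`` when every input is length 1).
def _broadcast_rule_sketch(lengths: list[int], target_length: int | None = None) -> int:
    """Return the common length implied by ``lengths``, mirroring ``broadcast``."""
    if not lengths:
        return 0  # no columns: nothing to broadcast
    distinct = set(lengths)
    if distinct == {1}:
        # All inputs are scalar-like; only an explicit target can pick the size.
        return target_length if target_length is not None else 1
    (n,) = distinct - {1}  # ValueError here means the lengths are not {1, n}
    if target_length is not None and n != target_length:
        raise RuntimeError(f"Cannot broadcast columns of length {n} to {target_length}")
    return n
# e.g. _broadcast_rule_sketch([1, 4, 4]) == 4, while [2, 3] raises ValueError.
# -----------------------------------------------------------------------------
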
@dataclasses.dataclass
class IR:
    """Abstract plan node, representing an unevaluated dataframe."""

    schema: Schema
    """Mapping from column names to their data types."""

    def __post_init__(self):
        """Validate preconditions."""
        if any(dtype.id() == plc.TypeId.EMPTY for dtype in self.schema.values()):
            raise NotImplementedError("Cannot make empty columns.")

    def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
        """
        Evaluate the node and return a dataframe.

        Parameters
        ----------
        cache
            Mapping from cached node ids to constructed DataFrames.
            Used to implement evaluation of the `Cache` node.

        Returns
        -------
        DataFrame (on device) representing the evaluation of this plan
        node.

        Raises
        ------
        NotImplementedError
            If we couldn't evaluate things. Ideally this should not occur,
            since the translation phase should pick up things that we
            cannot handle.
        """
        raise NotImplementedError(
            f"Evaluation of plan {type(self).__name__}"
        )  # pragma: no cover

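# --- Editorial note (illustration, not part of the released file) -----------
# Every concrete node below follows the same recursive shape: evaluate the
# child plan(s) with the shared ``cache`` mapping, then apply this node's own
# transformation to the resulting DataFrame(s).  A hypothetical example of
# that shape (``Head`` is not a node defined in this module):
#
#     @dataclasses.dataclass
#     class Head(IR):
#         df: IR
#         n: int
#
#         def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
#             child = self.df.evaluate(cache=cache)  # recurse into the input plan
#             return child.slice((0, self.n))        # then transform the result
# -----------------------------------------------------------------------------
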
@dataclasses.dataclass
class PythonScan(IR):
    """Representation of input from a python function."""

    options: Any
    """Arbitrary options."""
    predicate: expr.NamedExpr | None
    """Filter to apply to the constructed dataframe before returning it."""

    def __post_init__(self):
        """Validate preconditions."""
        raise NotImplementedError("PythonScan not implemented")


@dataclasses.dataclass
class Scan(IR):
    """Input from files."""

    typ: str
    """What type of file are we reading? Parquet, CSV, etc..."""
    reader_options: dict[str, Any]
    """Reader-specific options, as dictionary."""
    cloud_options: dict[str, Any] | None
    """Cloud-related authentication options, currently ignored."""
    paths: list[str]
    """List of paths to read from."""
    file_options: Any
    """Options for reading the file.

    Attributes are:
    - ``with_columns: list[str]`` of projected columns to return.
    - ``n_rows: int``: Number of rows to read.
    - ``row_index: tuple[name, offset] | None``: Add an integer index
      column with given name.
    """
    predicate: expr.NamedExpr | None
    """Mask to apply to the read dataframe."""

    def __post_init__(self) -> None:
        """Validate preconditions."""
        if self.file_options.n_rows is not None:
            raise NotImplementedError("row limit in scan")
        if self.typ not in ("csv", "parquet"):
            raise NotImplementedError(f"Unhandled scan type: {self.typ}")
        if self.cloud_options is not None and any(
            self.cloud_options[k] is not None for k in ("aws", "azure", "gcp")
        ):
            raise NotImplementedError(
                "Read from cloud storage"
            )  # pragma: no cover; no test yet
        if self.typ == "csv":
            if self.reader_options["skip_rows_after_header"] != 0:
                raise NotImplementedError("Skipping rows after header in CSV reader")
            parse_options = self.reader_options["parse_options"]
            if (
                null_values := parse_options["null_values"]
            ) is not None and "Named" in null_values:
                raise NotImplementedError(
                    "Per column null value specification not supported for CSV reader"
                )
            if (
                comment := parse_options["comment_prefix"]
            ) is not None and "Multi" in comment:
                raise NotImplementedError(
                    "Multi-character comment prefix not supported for CSV reader"
                )
            if not self.reader_options["has_header"]:
                # Need to do some file introspection to get the number
                # of columns so that column projection works right.
                raise NotImplementedError("Reading CSV without header")

    def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
        """Evaluate and return a dataframe."""
        options = self.file_options
        with_columns = options.with_columns
        row_index = options.row_index
        if self.typ == "csv":
            dtype_map = {
                name: cudf._lib.types.PYLIBCUDF_TO_SUPPORTED_NUMPY_TYPES[typ.id()]
                for name, typ in self.schema.items()
            }
            parse_options = self.reader_options["parse_options"]
            sep = chr(parse_options["separator"])
            quote = chr(parse_options["quote_char"])
            eol = chr(parse_options["eol_char"])
            if self.reader_options["schema"] is not None:
                # Reader schema provides names
                column_names = list(self.reader_options["schema"]["inner"].keys())
            else:
                # file provides column names
                column_names = None
            usecols = with_columns
            # TODO: support has_header=False
            header = 0

            # polars defaults to no null recognition
            null_values = [""]
            if parse_options["null_values"] is not None:
                ((typ, nulls),) = parse_options["null_values"].items()
                if typ == "AllColumnsSingle":
                    # Single value
                    null_values.append(nulls)
                else:
                    # List of values
                    null_values.extend(nulls)
            if parse_options["comment_prefix"] is not None:
                comment = chr(parse_options["comment_prefix"]["Single"])
            else:
                comment = None
            decimal = "," if parse_options["decimal_comma"] else "."

            # polars skips blank lines at the beginning of the file
            pieces = []
            for p in self.paths:
                skiprows = self.reader_options["skip_rows"]
                # TODO: read_csv expands globs which we should not do,
                # because polars will already have handled them.
                path = Path(p)
                with path.open() as f:
                    while f.readline() == "\n":
                        skiprows += 1
                pieces.append(
                    cudf.read_csv(
                        path,
                        sep=sep,
                        quotechar=quote,
                        lineterminator=eol,
                        names=column_names,
                        header=header,
                        usecols=usecols,
                        na_filter=True,
                        na_values=null_values,
                        keep_default_na=False,
                        skiprows=skiprows,
                        comment=comment,
                        decimal=decimal,
                        dtype=dtype_map,
                    )
                )
            df = DataFrame.from_cudf(cudf.concat(pieces))
        elif self.typ == "parquet":
            cdf = cudf.read_parquet(self.paths, columns=with_columns)
            assert isinstance(cdf, cudf.DataFrame)
            df = DataFrame.from_cudf(cdf)
        else:
            raise NotImplementedError(
                f"Unhandled scan type: {self.typ}"
            )  # pragma: no cover; post init trips first
        if row_index is not None:
            name, offset = row_index
            dtype = self.schema[name]
            step = plc.interop.from_arrow(
                pa.scalar(1, type=plc.interop.to_arrow(dtype))
            )
            init = plc.interop.from_arrow(
                pa.scalar(offset, type=plc.interop.to_arrow(dtype))
            )
            index = NamedColumn(
                plc.filling.sequence(df.num_rows, init, step),
                name,
                is_sorted=plc.types.Sorted.YES,
                order=plc.types.Order.ASCENDING,
                null_order=plc.types.NullOrder.AFTER,
            )
            df = DataFrame([index, *df.columns])
        # TODO: should be true, but not the case until we get
        # cudf-classic out of the loop for IO since it converts date32
        # to datetime.
        # assert all(
        #     c.obj.type() == dtype
        #     for c, dtype in zip(df.columns, self.schema.values())
        # )
        if self.predicate is None:
            return df
        else:
            (mask,) = broadcast(self.predicate.evaluate(df), target_length=df.num_rows)
            return df.filter(mask)

@dataclasses.dataclass
class Cache(IR):
    """
    Return a cached plan node.

    Used for CSE at the plan level.
    """

    key: int
    """The cache key."""
    value: IR
    """The unevaluated node to cache."""

    def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
        """Evaluate and return a dataframe."""
        try:
            return cache[self.key]
        except KeyError:
            return cache.setdefault(self.key, self.value.evaluate(cache=cache))

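# --- Editorial note (illustration, not part of the released file) -----------
# The memoisation above is a try/except around a plain mapping lookup: the
# first evaluation of a given key computes and stores the subplan's result,
# and later lookups return the stored DataFrame without re-evaluating.  A
# sketch with an ordinary dict and a stand-in "expensive" computation:
def _cache_sketch() -> None:
    evaluations = 0

    def expensive() -> str:
        nonlocal evaluations
        evaluations += 1
        return "result"

    cache: dict[int, str] = {}
    for _ in range(3):
        try:
            value = cache[0]
        except KeyError:
            value = cache.setdefault(0, expensive())
    assert evaluations == 1 and value == "result"
# -----------------------------------------------------------------------------
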
@dataclasses.dataclass
class DataFrameScan(IR):
    """
    Input from an existing polars DataFrame.

    This typically arises from ``q.collect().lazy()``
    """

    df: Any
    """Polars LazyFrame object."""
    projection: list[str]
    """List of columns to project out."""
    predicate: expr.NamedExpr | None
    """Mask to apply."""

    def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
        """Evaluate and return a dataframe."""
        pdf = pl.DataFrame._from_pydf(self.df)
        if self.projection is not None:
            pdf = pdf.select(self.projection)
        table = pdf.to_arrow()
        schema = table.schema
        for i, field in enumerate(schema):
            schema = schema.set(
                i, pa.field(field.name, dtypes.downcast_arrow_lists(field.type))
            )
        # No-op if the schema is unchanged.
        table = table.cast(schema)
        df = DataFrame.from_table(
            plc.interop.from_arrow(table), list(self.schema.keys())
        )
        assert all(
            c.obj.type() == dtype for c, dtype in zip(df.columns, self.schema.values())
        )
        if self.predicate is not None:
            (mask,) = broadcast(self.predicate.evaluate(df), target_length=df.num_rows)
            return df.filter(mask)
        else:
            return df


@dataclasses.dataclass
class Select(IR):
    """Produce a new dataframe selecting given expressions from an input."""

    df: IR
    """Input dataframe."""
    expr: list[expr.NamedExpr]
    """List of expressions to evaluate to form the new dataframe."""
    should_broadcast: bool
    """Should columns be broadcast?"""

    def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
        """Evaluate and return a dataframe."""
        df = self.df.evaluate(cache=cache)
        # Handle any broadcasting
        columns = [e.evaluate(df) for e in self.expr]
        if self.should_broadcast:
            columns = broadcast(*columns)
        return DataFrame(columns)


@dataclasses.dataclass
class Reduce(IR):
    """
    Produce a new dataframe selecting given expressions from an input.

    This is a special case of :class:`Select` where all outputs are a single row.
    """

    df: IR
    """Input dataframe."""
    expr: list[expr.NamedExpr]
    """List of expressions to evaluate to form the new dataframe."""

    def evaluate(
        self, *, cache: MutableMapping[int, DataFrame]
    ) -> DataFrame:  # pragma: no cover; polars doesn't emit this node yet
        """Evaluate and return a dataframe."""
        df = self.df.evaluate(cache=cache)
        columns = broadcast(*(e.evaluate(df) for e in self.expr))
        assert all(column.obj.size() == 1 for column in columns)
        return DataFrame(columns)


def placeholder_column(n: int) -> plc.Column:
    """
    Produce a placeholder pylibcudf column with NO BACKING DATA.

    Parameters
    ----------
    n
        Number of rows the column will advertise

    Returns
    -------
    pylibcudf Column that is almost unusable. DO NOT ACCESS THE DATA BUFFER.

    Notes
    -----
    This is used to avoid allocating data for count aggregations.
    """
    return plc.Column(
        plc.DataType(plc.TypeId.INT8),
        n,
        plc.gpumemoryview(
            types.SimpleNamespace(__cuda_array_interface__={"data": (1, True)})
        ),
        None,
        0,
        0,
        [],
    )


@dataclasses.dataclass
class GroupBy(IR):
    """Perform a groupby."""

    df: IR
    """Input dataframe."""
    agg_requests: list[expr.NamedExpr]
    """List of expressions to evaluate groupwise."""
    keys: list[expr.NamedExpr]
    """List of expressions forming the keys."""
    maintain_order: bool
    """Should the order of the input dataframe be maintained?"""
    options: Any
    """Options controlling style of groupby."""
    agg_infos: list[expr.AggInfo] = dataclasses.field(init=False)

    @staticmethod
    def check_agg(agg: expr.Expr) -> int:
        """
        Determine if we can handle an aggregation expression.

        Parameters
        ----------
        agg
            Expression to check

        Returns
        -------
        depth of nesting

        Raises
        ------
        NotImplementedError
            For unsupported expression nodes.
        """
        if isinstance(agg, (expr.BinOp, expr.Cast, expr.UnaryFunction)):
            return max(GroupBy.check_agg(child) for child in agg.children)
        elif isinstance(agg, expr.Agg):
            return 1 + max(GroupBy.check_agg(child) for child in agg.children)
        elif isinstance(agg, (expr.Len, expr.Col, expr.Literal)):
            return 0
        else:
            raise NotImplementedError(f"No handler for {agg=}")

    def __post_init__(self) -> None:
        """Check whether all the aggregations are implemented."""
        if self.options.rolling is None and self.maintain_order:
            raise NotImplementedError("Maintaining order in groupby")
        if self.options.rolling:
            raise NotImplementedError(
                "rolling window/groupby"
            )  # pragma: no cover; rollingwindow constructor has already raised
        if any(GroupBy.check_agg(a.value) > 1 for a in self.agg_requests):
            raise NotImplementedError("Nested aggregations in groupby")
        self.agg_infos = [req.collect_agg(depth=0) for req in self.agg_requests]

    def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
        """Evaluate and return a dataframe."""
        df = self.df.evaluate(cache=cache)
        keys = broadcast(
            *(k.evaluate(df) for k in self.keys), target_length=df.num_rows
        )
        # TODO: use sorted information, need to expose column_order
        # and null_precedence in pylibcudf groupby constructor
        # sorted = (
        #     plc.types.Sorted.YES
        #     if all(k.is_sorted for k in keys)
        #     else plc.types.Sorted.NO
        # )
        grouper = plc.groupby.GroupBy(
            plc.Table([k.obj for k in keys]),
            null_handling=plc.types.NullPolicy.INCLUDE,
        )
        # TODO: uniquify
        requests = []
        replacements: list[expr.Expr] = []
        for info in self.agg_infos:
            for pre_eval, req, rep in info.requests:
                if pre_eval is None:
                    col = placeholder_column(df.num_rows)
                else:
                    col = pre_eval.evaluate(df).obj
                requests.append(plc.groupby.GroupByRequest(col, [req]))
                replacements.append(rep)
        group_keys, raw_tables = grouper.aggregate(requests)
        # TODO: names
        raw_columns: list[NamedColumn] = []
        for i, table in enumerate(raw_tables):
            (column,) = table.columns()
            raw_columns.append(NamedColumn(column, f"tmp{i}"))
        mapping = dict(zip(replacements, raw_columns))
        result_keys = [
            NamedColumn(gk, k.name) for gk, k in zip(group_keys.columns(), keys)
        ]
        result_subs = DataFrame(raw_columns)
        results = [
            req.evaluate(result_subs, mapping=mapping) for req in self.agg_requests
        ]
        return DataFrame([*result_keys, *results]).slice(self.options.slice)

@dataclasses.dataclass
class Join(IR):
    """A join of two dataframes."""

    left: IR
    """Left frame."""
    right: IR
    """Right frame."""
    left_on: list[expr.NamedExpr]
    """List of expressions used as keys in the left frame."""
    right_on: list[expr.NamedExpr]
    """List of expressions used as keys in the right frame."""
    options: tuple[
        Literal["inner", "left", "full", "leftsemi", "leftanti", "cross"],
        bool,
        tuple[int, int] | None,
        str | None,
        bool,
    ]
    """
    tuple of options:
    - how: join type
    - join_nulls: do nulls compare equal?
    - slice: optional slice to perform after joining.
    - suffix: string suffix for right columns if names match
    - coalesce: should key columns be coalesced (only makes sense for outer joins)
    """

    def __post_init__(self) -> None:
        """Validate preconditions."""
        if any(
            isinstance(e.value, expr.Literal)
            for e in itertools.chain(self.left_on, self.right_on)
        ):
            raise NotImplementedError("Join with literal as join key.")

    @staticmethod
    @cache
    def _joiners(
        how: Literal["inner", "left", "full", "leftsemi", "leftanti"],
    ) -> tuple[
        Callable, plc.copying.OutOfBoundsPolicy, plc.copying.OutOfBoundsPolicy | None
    ]:
        if how == "inner":
            return (
                plc.join.inner_join,
                plc.copying.OutOfBoundsPolicy.DONT_CHECK,
                plc.copying.OutOfBoundsPolicy.DONT_CHECK,
            )
        elif how == "left":
            return (
                plc.join.left_join,
                plc.copying.OutOfBoundsPolicy.DONT_CHECK,
                plc.copying.OutOfBoundsPolicy.NULLIFY,
            )
        elif how == "full":
            return (
                plc.join.full_join,
                plc.copying.OutOfBoundsPolicy.NULLIFY,
                plc.copying.OutOfBoundsPolicy.NULLIFY,
            )
        elif how == "leftsemi":
            return (
                plc.join.left_semi_join,
                plc.copying.OutOfBoundsPolicy.DONT_CHECK,
                None,
            )
        elif how == "leftanti":
            return (
                plc.join.left_anti_join,
                plc.copying.OutOfBoundsPolicy.DONT_CHECK,
                None,
            )
        else:
            assert_never(how)

    def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
        """Evaluate and return a dataframe."""
        left = self.left.evaluate(cache=cache)
        right = self.right.evaluate(cache=cache)
        how, join_nulls, zlice, suffix, coalesce = self.options
        suffix = "_right" if suffix is None else suffix
        if how == "cross":
            # Separate implementation, since cross_join returns the
            # result, not the gather maps
            columns = plc.join.cross_join(left.table, right.table).columns()
            left_cols = [
                NamedColumn(new, old.name).sorted_like(old)
                for new, old in zip(columns[: left.num_columns], left.columns)
            ]
            right_cols = [
                NamedColumn(
                    new,
                    old.name
                    if old.name not in left.column_names_set
                    else f"{old.name}{suffix}",
                )
                for new, old in zip(columns[left.num_columns :], right.columns)
            ]
            return DataFrame([*left_cols, *right_cols])
        # TODO: Waiting on clarity based on https://github.com/pola-rs/polars/issues/17184
        left_on = DataFrame(broadcast(*(e.evaluate(left) for e in self.left_on)))
        right_on = DataFrame(broadcast(*(e.evaluate(right) for e in self.right_on)))
        null_equality = (
            plc.types.NullEquality.EQUAL
            if join_nulls
            else plc.types.NullEquality.UNEQUAL
        )
        join_fn, left_policy, right_policy = Join._joiners(how)
        if right_policy is None:
            # Semi join
            lg = join_fn(left_on.table, right_on.table, null_equality)
            table = plc.copying.gather(left.table, lg, left_policy)
            result = DataFrame.from_table(table, left.column_names)
        else:
            lg, rg = join_fn(left_on.table, right_on.table, null_equality)
            if coalesce and how == "inner":
                right = right.discard_columns(right_on.column_names_set)
            left = DataFrame.from_table(
                plc.copying.gather(left.table, lg, left_policy), left.column_names
            )
            right = DataFrame.from_table(
                plc.copying.gather(right.table, rg, right_policy), right.column_names
            )
            if coalesce and how != "inner":
                left = left.replace_columns(
                    *(
                        NamedColumn(
                            plc.replace.replace_nulls(left_col.obj, right_col.obj),
                            left_col.name,
                        )
                        for left_col, right_col in zip(
                            left.select_columns(left_on.column_names_set),
                            right.select_columns(right_on.column_names_set),
                        )
                    )
                )
                right = right.discard_columns(right_on.column_names_set)
            right = right.rename_columns(
                {
                    name: f"{name}{suffix}"
                    for name in right.column_names
                    if name in left.column_names_set
                }
            )
            result = left.with_columns(right.columns)
        return result.slice(zlice)

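# --- Editorial note (illustration, not part of the released file) -----------
# Plain-data summary of what ``Join._joiners`` above encodes: which pylibcudf
# join produces the gather map(s), and whether out-of-bounds indices in each
# map must be nullified (so non-matching rows come back as nulls).  ``None``
# means the join type only produces a single (left) gather map.
_JOIN_POLICY_SUMMARY: dict[str, tuple[str, str, str | None]] = {
    "inner": ("inner_join", "DONT_CHECK", "DONT_CHECK"),
    "left": ("left_join", "DONT_CHECK", "NULLIFY"),
    "full": ("full_join", "NULLIFY", "NULLIFY"),
    "leftsemi": ("left_semi_join", "DONT_CHECK", None),
    "leftanti": ("left_anti_join", "DONT_CHECK", None),
}
# -----------------------------------------------------------------------------
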
@dataclasses.dataclass
class HStack(IR):
    """Add new columns to a dataframe."""

    df: IR
    """Input dataframe."""
    columns: list[expr.NamedExpr]
    """List of expressions to produce new columns."""
    should_broadcast: bool
    """Should columns be broadcast?"""

    def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
        """Evaluate and return a dataframe."""
        df = self.df.evaluate(cache=cache)
        columns = [c.evaluate(df) for c in self.columns]
        if self.should_broadcast:
            columns = broadcast(*columns, target_length=df.num_rows)
        else:
            # Polars ensures this is true, but let's make sure nothing
            # went wrong. In this case, the parent node is a
            # guaranteed to be a Select which will take care of making
            # sure that everything is the same length. The result
            # table that might have mismatching column lengths will
            # never be turned into a pylibcudf Table with all columns
            # by the Select, which is why this is safe.
            assert all(e.name.startswith("__POLARS_CSER_0x") for e in self.columns)
        return df.with_columns(columns)


@dataclasses.dataclass
class Distinct(IR):
    """Produce a new dataframe with distinct rows."""

    df: IR
    """Input dataframe."""
    keep: plc.stream_compaction.DuplicateKeepOption
    """Which rows to keep."""
    subset: set[str] | None
    """Which columns to inspect when computing distinct rows."""
    zlice: tuple[int, int] | None
    """Optional slice to perform after compaction."""
    stable: bool
    """Should order be preserved?"""

    _KEEP_MAP: ClassVar[dict[str, plc.stream_compaction.DuplicateKeepOption]] = {
        "first": plc.stream_compaction.DuplicateKeepOption.KEEP_FIRST,
        "last": plc.stream_compaction.DuplicateKeepOption.KEEP_LAST,
        "none": plc.stream_compaction.DuplicateKeepOption.KEEP_NONE,
        "any": plc.stream_compaction.DuplicateKeepOption.KEEP_ANY,
    }

    def __init__(self, schema: Schema, df: IR, options: Any) -> None:
        self.schema = schema
        self.df = df
        (keep, subset, maintain_order, zlice) = options
        self.keep = Distinct._KEEP_MAP[keep]
        self.subset = set(subset) if subset is not None else None
        self.stable = maintain_order
        self.zlice = zlice

    def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
        """Evaluate and return a dataframe."""
        df = self.df.evaluate(cache=cache)
        if self.subset is None:
            indices = list(range(df.num_columns))
        else:
            indices = [i for i, k in enumerate(df.column_names) if k in self.subset]
        keys_sorted = all(df.columns[i].is_sorted for i in indices)
        if keys_sorted:
            table = plc.stream_compaction.unique(
                df.table,
                indices,
                self.keep,
                plc.types.NullEquality.EQUAL,
            )
        else:
            distinct = (
                plc.stream_compaction.stable_distinct
                if self.stable
                else plc.stream_compaction.distinct
            )
            table = distinct(
                df.table,
                indices,
                self.keep,
                plc.types.NullEquality.EQUAL,
                plc.types.NanEquality.ALL_EQUAL,
            )
        result = DataFrame(
            [
                NamedColumn(c, old.name).sorted_like(old)
                for c, old in zip(table.columns(), df.columns)
            ]
        )
        if keys_sorted or self.stable:
            result = result.sorted_like(df)
        return result.slice(self.zlice)


@dataclasses.dataclass
class Sort(IR):
    """Sort a dataframe."""

    df: IR
    """Input."""
    by: list[expr.NamedExpr]
    """List of expressions to produce sort keys."""
    do_sort: Callable[..., plc.Table]
    """pylibcudf sorting function."""
    zlice: tuple[int, int] | None
    """Optional slice to apply after sorting."""
    order: list[plc.types.Order]
    """Order keys should be sorted in."""
    null_order: list[plc.types.NullOrder]
    """Where nulls sort to."""

    def __init__(
        self,
        schema: Schema,
        df: IR,
        by: list[expr.NamedExpr],
        options: Any,
        zlice: tuple[int, int] | None,
    ) -> None:
        self.schema = schema
        self.df = df
        self.by = by
        self.zlice = zlice
        stable, nulls_last, descending = options
        self.order, self.null_order = sorting.sort_order(
            descending, nulls_last=nulls_last, num_keys=len(by)
        )
        self.do_sort = (
            plc.sorting.stable_sort_by_key if stable else plc.sorting.sort_by_key
        )

    def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
        """Evaluate and return a dataframe."""
        df = self.df.evaluate(cache=cache)
        sort_keys = broadcast(
            *(k.evaluate(df) for k in self.by), target_length=df.num_rows
        )
        names = {c.name: i for i, c in enumerate(df.columns)}
        # TODO: More robust identification here.
        keys_in_result = [
            i
            for k in sort_keys
            if (i := names.get(k.name)) is not None and k.obj is df.columns[i].obj
        ]
        table = self.do_sort(
            df.table,
            plc.Table([k.obj for k in sort_keys]),
            self.order,
            self.null_order,
        )
        columns = [
            NamedColumn(c, old.name) for c, old in zip(table.columns(), df.columns)
        ]
        # If a sort key is in the result table, set the sortedness property
        for k, i in enumerate(keys_in_result):
            columns[i] = columns[i].set_sorted(
                is_sorted=plc.types.Sorted.YES,
                order=self.order[k],
                null_order=self.null_order[k],
            )
        return DataFrame(columns).slice(self.zlice)


@dataclasses.dataclass
class Slice(IR):
    """Slice a dataframe."""

    df: IR
    """Input."""
    offset: int
    """Start of the slice."""
    length: int
    """Length of the slice."""

    def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
        """Evaluate and return a dataframe."""
        df = self.df.evaluate(cache=cache)
        return df.slice((self.offset, self.length))


@dataclasses.dataclass
class Filter(IR):
    """Filter a dataframe with a boolean mask."""

    df: IR
    """Input."""
    mask: expr.NamedExpr
    """Expression evaluating to a mask."""

    def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
        """Evaluate and return a dataframe."""
        df = self.df.evaluate(cache=cache)
        (mask,) = broadcast(self.mask.evaluate(df), target_length=df.num_rows)
        return df.filter(mask)


@dataclasses.dataclass
class Projection(IR):
    """Select a subset of columns from a dataframe."""

    df: IR
    """Input."""

    def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
        """Evaluate and return a dataframe."""
        df = self.df.evaluate(cache=cache)
        # This can reorder things.
        columns = broadcast(
            *df.select(list(self.schema.keys())).columns, target_length=df.num_rows
        )
        return DataFrame(columns)


@dataclasses.dataclass
class MapFunction(IR):
    """Apply some function to a dataframe."""

    df: IR
    """Input."""
    name: str
    """Function name."""
    options: Any
    """Arbitrary options, interpreted per function."""

    _NAMES: ClassVar[frozenset[str]] = frozenset(
        [
            "rechunk",
            # libcudf merge is not stable wrt order of inputs, since
            # it uses a priority queue to manage the tables it produces.
            # See: https://github.com/rapidsai/cudf/issues/16010
            # "merge_sorted",
            "rename",
            "explode",
        ]
    )

    def __post_init__(self) -> None:
        """Validate preconditions."""
        if self.name not in MapFunction._NAMES:
            raise NotImplementedError(f"Unhandled map function {self.name}")
        if self.name == "explode":
            (to_explode,) = self.options
            if len(to_explode) > 1:
                # TODO: straightforward, but need to error check
                # polars requires that all to-explode columns have the
                # same sub-shapes
                raise NotImplementedError("Explode with more than one column")
        elif self.name == "rename":
            old, new, _ = self.options
            # TODO: perhaps polars should validate renaming in the IR?
            if len(new) != len(set(new)) or (
                set(new) & (set(self.df.schema.keys() - set(old)))
            ):
                raise NotImplementedError("Duplicate new names in rename.")

    def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
        """Evaluate and return a dataframe."""
        if self.name == "rechunk":
            # No-op in our data model
            # Don't think this appears in a plan tree from python
            return self.df.evaluate(cache=cache)  # pragma: no cover
        elif self.name == "rename":
            df = self.df.evaluate(cache=cache)
            # final tag is "swapping" which is useful for the
            # optimiser (it blocks some pushdown operations)
            old, new, _ = self.options
            return df.rename_columns(dict(zip(old, new)))
        elif self.name == "explode":
            df = self.df.evaluate(cache=cache)
            ((to_explode,),) = self.options
            index = df.column_names.index(to_explode)
            subset = df.column_names_set - {to_explode}
            return DataFrame.from_table(
                plc.lists.explode_outer(df.table, index), df.column_names
            ).sorted_like(df, subset=subset)
        else:
            raise AssertionError("Should never be reached")  # pragma: no cover

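# --- Editorial note (illustration, not part of the released file) -----------
# Row-wise sketch of the ``explode`` mapping handled by MapFunction above:
# each element of the exploded list column becomes its own row while the
# other columns are repeated, which is why their sortedness can be preserved
# via the ``subset`` argument to ``sorted_like``.  Pure-Python, dict rows,
# with a hypothetical list column named "xs":
def _explode_sketch(rows: list[dict], column: str = "xs") -> list[dict]:
    out: list[dict] = []
    for row in rows:
        # An empty or missing list still yields one row (outer explode).
        values = row[column] or [None]
        out.extend({**row, column: v} for v in values)
    return out
# e.g. _explode_sketch([{"a": 1, "xs": [10, 20]}, {"a": 2, "xs": []}])
#   == [{"a": 1, "xs": 10}, {"a": 1, "xs": 20}, {"a": 2, "xs": None}]
# -----------------------------------------------------------------------------
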
@dataclasses.dataclass
class Union(IR):
    """Concatenate dataframes vertically."""

    dfs: list[IR]
    """List of inputs."""
    zlice: tuple[int, int] | None
    """Optional slice to apply after concatenation."""

    def __post_init__(self) -> None:
        """Validate preconditions."""
        schema = self.dfs[0].schema
        if not all(s.schema == schema for s in self.dfs[1:]):
            raise NotImplementedError("Schema mismatch")

    def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
        """Evaluate and return a dataframe."""
        # TODO: only evaluate what we need if we have a slice
        dfs = [df.evaluate(cache=cache) for df in self.dfs]
        return DataFrame.from_table(
            plc.concatenate.concatenate([df.table for df in dfs]), dfs[0].column_names
        ).slice(self.zlice)


@dataclasses.dataclass
class HConcat(IR):
    """Concatenate dataframes horizontally."""

    dfs: list[IR]
    """List of inputs."""

    def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
        """Evaluate and return a dataframe."""
        dfs = [df.evaluate(cache=cache) for df in self.dfs]
        return DataFrame(
            list(itertools.chain.from_iterable(df.columns for df in dfs)),
        )