cudf-polars-cu13 25.10.0 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cudf_polars/GIT_COMMIT +1 -0
- cudf_polars/VERSION +1 -0
- cudf_polars/__init__.py +28 -0
- cudf_polars/_version.py +21 -0
- cudf_polars/callback.py +318 -0
- cudf_polars/containers/__init__.py +13 -0
- cudf_polars/containers/column.py +495 -0
- cudf_polars/containers/dataframe.py +361 -0
- cudf_polars/containers/datatype.py +137 -0
- cudf_polars/dsl/__init__.py +8 -0
- cudf_polars/dsl/expr.py +66 -0
- cudf_polars/dsl/expressions/__init__.py +8 -0
- cudf_polars/dsl/expressions/aggregation.py +226 -0
- cudf_polars/dsl/expressions/base.py +272 -0
- cudf_polars/dsl/expressions/binaryop.py +120 -0
- cudf_polars/dsl/expressions/boolean.py +326 -0
- cudf_polars/dsl/expressions/datetime.py +271 -0
- cudf_polars/dsl/expressions/literal.py +97 -0
- cudf_polars/dsl/expressions/rolling.py +643 -0
- cudf_polars/dsl/expressions/selection.py +74 -0
- cudf_polars/dsl/expressions/slicing.py +46 -0
- cudf_polars/dsl/expressions/sorting.py +85 -0
- cudf_polars/dsl/expressions/string.py +1002 -0
- cudf_polars/dsl/expressions/struct.py +137 -0
- cudf_polars/dsl/expressions/ternary.py +49 -0
- cudf_polars/dsl/expressions/unary.py +517 -0
- cudf_polars/dsl/ir.py +2607 -0
- cudf_polars/dsl/nodebase.py +164 -0
- cudf_polars/dsl/to_ast.py +359 -0
- cudf_polars/dsl/tracing.py +16 -0
- cudf_polars/dsl/translate.py +939 -0
- cudf_polars/dsl/traversal.py +224 -0
- cudf_polars/dsl/utils/__init__.py +8 -0
- cudf_polars/dsl/utils/aggregations.py +481 -0
- cudf_polars/dsl/utils/groupby.py +98 -0
- cudf_polars/dsl/utils/naming.py +34 -0
- cudf_polars/dsl/utils/replace.py +61 -0
- cudf_polars/dsl/utils/reshape.py +74 -0
- cudf_polars/dsl/utils/rolling.py +121 -0
- cudf_polars/dsl/utils/windows.py +192 -0
- cudf_polars/experimental/__init__.py +8 -0
- cudf_polars/experimental/base.py +386 -0
- cudf_polars/experimental/benchmarks/__init__.py +4 -0
- cudf_polars/experimental/benchmarks/pdsds.py +220 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/__init__.py +4 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q1.py +88 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q10.py +225 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q2.py +244 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q3.py +65 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q4.py +359 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q5.py +462 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q6.py +92 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q7.py +79 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q8.py +524 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q9.py +137 -0
- cudf_polars/experimental/benchmarks/pdsh.py +814 -0
- cudf_polars/experimental/benchmarks/utils.py +832 -0
- cudf_polars/experimental/dask_registers.py +200 -0
- cudf_polars/experimental/dispatch.py +156 -0
- cudf_polars/experimental/distinct.py +197 -0
- cudf_polars/experimental/explain.py +157 -0
- cudf_polars/experimental/expressions.py +590 -0
- cudf_polars/experimental/groupby.py +327 -0
- cudf_polars/experimental/io.py +943 -0
- cudf_polars/experimental/join.py +391 -0
- cudf_polars/experimental/parallel.py +423 -0
- cudf_polars/experimental/repartition.py +69 -0
- cudf_polars/experimental/scheduler.py +155 -0
- cudf_polars/experimental/select.py +188 -0
- cudf_polars/experimental/shuffle.py +354 -0
- cudf_polars/experimental/sort.py +609 -0
- cudf_polars/experimental/spilling.py +151 -0
- cudf_polars/experimental/statistics.py +795 -0
- cudf_polars/experimental/utils.py +169 -0
- cudf_polars/py.typed +0 -0
- cudf_polars/testing/__init__.py +8 -0
- cudf_polars/testing/asserts.py +448 -0
- cudf_polars/testing/io.py +122 -0
- cudf_polars/testing/plugin.py +236 -0
- cudf_polars/typing/__init__.py +219 -0
- cudf_polars/utils/__init__.py +8 -0
- cudf_polars/utils/config.py +741 -0
- cudf_polars/utils/conversion.py +40 -0
- cudf_polars/utils/dtypes.py +118 -0
- cudf_polars/utils/sorting.py +53 -0
- cudf_polars/utils/timer.py +39 -0
- cudf_polars/utils/versions.py +27 -0
- cudf_polars_cu13-25.10.0.dist-info/METADATA +136 -0
- cudf_polars_cu13-25.10.0.dist-info/RECORD +92 -0
- cudf_polars_cu13-25.10.0.dist-info/WHEEL +5 -0
- cudf_polars_cu13-25.10.0.dist-info/licenses/LICENSE +201 -0
- cudf_polars_cu13-25.10.0.dist-info/top_level.txt +1 -0
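For context, this package is not normally imported directly; polars loads it when a lazy query is collected with the GPU engine, and cudf-polars then translates the polars logical plan into the IR nodes defined in cudf_polars/dsl/ir.py below. A minimal usage sketch, assuming a local Parquet file named data.parquet (the file name and query are illustrative, not taken from the package):

# Sketch only: exercises cudf-polars through polars' documented GPU engine entry point.
import polars as pl

q = (
    pl.scan_parquet("data.parquet")   # hypothetical input file
    .filter(pl.col("value") > 0)
    .group_by("key")
    .agg(pl.col("value").sum())
)

# engine="gpu" asks polars to hand the plan to cudf-polars; if the plan (or the
# installed driver) is unsupported, polars falls back to its CPU engine.
result = q.collect(engine="gpu")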
cudf_polars/dsl/ir.py
ADDED
|
@@ -0,0 +1,2607 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""
|
|
4
|
+
DSL nodes for the LogicalPlan of polars.
|
|
5
|
+
|
|
6
|
+
An IR node is either a source, normal, or a sink. Respectively they
|
|
7
|
+
can be considered as functions:
|
|
8
|
+
|
|
9
|
+
- source: `IO () -> DataFrame`
|
|
10
|
+
- normal: `DataFrame -> DataFrame`
|
|
11
|
+
- sink: `DataFrame -> IO ()`
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import itertools
|
|
17
|
+
import json
|
|
18
|
+
import random
|
|
19
|
+
import time
|
|
20
|
+
from functools import cache
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
from typing import TYPE_CHECKING, Any, ClassVar
|
|
23
|
+
|
|
24
|
+
from typing_extensions import assert_never
|
|
25
|
+
|
|
26
|
+
import polars as pl
|
|
27
|
+
|
|
28
|
+
import pylibcudf as plc
|
|
29
|
+
|
|
30
|
+
import cudf_polars.dsl.expr as expr
|
|
31
|
+
from cudf_polars.containers import Column, DataFrame, DataType
|
|
32
|
+
from cudf_polars.dsl.expressions import rolling, unary
|
|
33
|
+
from cudf_polars.dsl.expressions.base import ExecutionContext
|
|
34
|
+
from cudf_polars.dsl.nodebase import Node
|
|
35
|
+
from cudf_polars.dsl.to_ast import to_ast, to_parquet_filter
|
|
36
|
+
from cudf_polars.dsl.tracing import nvtx_annotate_cudf_polars
|
|
37
|
+
from cudf_polars.dsl.utils.reshape import broadcast
|
|
38
|
+
from cudf_polars.dsl.utils.windows import range_window_bounds
|
|
39
|
+
from cudf_polars.utils import dtypes
|
|
40
|
+
from cudf_polars.utils.versions import POLARS_VERSION_LT_131
|
|
41
|
+
|
|
42
|
+
if TYPE_CHECKING:
|
|
43
|
+
from collections.abc import Callable, Hashable, Iterable, Sequence
|
|
44
|
+
from typing import Literal
|
|
45
|
+
|
|
46
|
+
from typing_extensions import Self
|
|
47
|
+
|
|
48
|
+
from polars.polars import _expr_nodes as pl_expr
|
|
49
|
+
|
|
50
|
+
from cudf_polars.containers.dataframe import NamedColumn
|
|
51
|
+
from cudf_polars.typing import CSECache, ClosedInterval, Schema, Slice as Zlice
|
|
52
|
+
from cudf_polars.utils.config import ParquetOptions
|
|
53
|
+
from cudf_polars.utils.timer import Timer
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
__all__ = [
|
|
57
|
+
"IR",
|
|
58
|
+
"Cache",
|
|
59
|
+
"ConditionalJoin",
|
|
60
|
+
"DataFrameScan",
|
|
61
|
+
"Distinct",
|
|
62
|
+
"Empty",
|
|
63
|
+
"ErrorNode",
|
|
64
|
+
"Filter",
|
|
65
|
+
"GroupBy",
|
|
66
|
+
"HConcat",
|
|
67
|
+
"HStack",
|
|
68
|
+
"Join",
|
|
69
|
+
"MapFunction",
|
|
70
|
+
"MergeSorted",
|
|
71
|
+
"Projection",
|
|
72
|
+
"PythonScan",
|
|
73
|
+
"Reduce",
|
|
74
|
+
"Rolling",
|
|
75
|
+
"Scan",
|
|
76
|
+
"Select",
|
|
77
|
+
"Sink",
|
|
78
|
+
"Slice",
|
|
79
|
+
"Sort",
|
|
80
|
+
"Union",
|
|
81
|
+
]
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class IR(Node["IR"]):
|
|
85
|
+
"""Abstract plan node, representing an unevaluated dataframe."""
|
|
86
|
+
|
|
87
|
+
__slots__ = ("_non_child_args", "schema")
|
|
88
|
+
# This annotation is needed because of https://github.com/python/mypy/issues/17981
|
|
89
|
+
_non_child: ClassVar[tuple[str, ...]] = ("schema",)
|
|
90
|
+
# Concrete classes should set this up with the arguments that will
|
|
91
|
+
# be passed to do_evaluate.
|
|
92
|
+
_non_child_args: tuple[Any, ...]
|
|
93
|
+
schema: Schema
|
|
94
|
+
"""Mapping from column names to their data types."""
|
|
95
|
+
|
|
96
|
+
def get_hashable(self) -> Hashable:
|
|
97
|
+
"""
|
|
98
|
+
Hashable representation of node, treating schema dictionary.
|
|
99
|
+
|
|
100
|
+
Since the schema is a dictionary, even though it is morally
|
|
101
|
+
immutable, it is not hashable. We therefore convert it to
|
|
102
|
+
tuples for hashing purposes.
|
|
103
|
+
"""
|
|
104
|
+
# Schema is the first constructor argument
|
|
105
|
+
args = self._ctor_arguments(self.children)[1:]
|
|
106
|
+
schema_hash = tuple(self.schema.items())
|
|
107
|
+
return (type(self), schema_hash, args)
|
|
108
|
+
|
|
109
|
+
# Hacky to avoid type-checking issues, just advertise the
|
|
110
|
+
# signature. Both mypy and pyright complain if we have an abstract
|
|
111
|
+
# method that takes arbitrary *args, but the subclasses have
|
|
112
|
+
# tighter signatures. This complaint is correct because the
|
|
113
|
+
# subclass is not Liskov-substitutable for the superclass.
|
|
114
|
+
# However, we know do_evaluate will only be called with the
|
|
115
|
+
# correct arguments by "construction".
|
|
116
|
+
do_evaluate: Callable[..., DataFrame]
|
|
117
|
+
"""
|
|
118
|
+
Evaluate the node (given its evaluated children), and return a dataframe.
|
|
119
|
+
|
|
120
|
+
Parameters
|
|
121
|
+
----------
|
|
122
|
+
args
|
|
123
|
+
Non child arguments followed by any evaluated dataframe inputs.
|
|
124
|
+
|
|
125
|
+
Returns
|
|
126
|
+
-------
|
|
127
|
+
DataFrame (on device) representing the evaluation of this plan
|
|
128
|
+
node.
|
|
129
|
+
|
|
130
|
+
Raises
|
|
131
|
+
------
|
|
132
|
+
NotImplementedError
|
|
133
|
+
If evaluation fails. Ideally this should not occur, since the
|
|
134
|
+
translation phase should fail earlier.
|
|
135
|
+
"""
|
|
136
|
+
|
|
137
|
+
def evaluate(self, *, cache: CSECache, timer: Timer | None) -> DataFrame:
|
|
138
|
+
"""
|
|
139
|
+
Evaluate the node (recursively) and return a dataframe.
|
|
140
|
+
|
|
141
|
+
Parameters
|
|
142
|
+
----------
|
|
143
|
+
cache
|
|
144
|
+
Mapping from cached node ids to constructed DataFrames.
|
|
145
|
+
Used to implement evaluation of the `Cache` node.
|
|
146
|
+
timer
|
|
147
|
+
If not None, a Timer object to record timings for the
|
|
148
|
+
evaluation of the node.
|
|
149
|
+
|
|
150
|
+
Notes
|
|
151
|
+
-----
|
|
152
|
+
Prefer not to override this method. Instead implement
|
|
153
|
+
:meth:`do_evaluate` which doesn't encode a recursion scheme
|
|
154
|
+
and just assumes already evaluated inputs.
|
|
155
|
+
|
|
156
|
+
Returns
|
|
157
|
+
-------
|
|
158
|
+
DataFrame (on device) representing the evaluation of this plan
|
|
159
|
+
node (and its children).
|
|
160
|
+
|
|
161
|
+
Raises
|
|
162
|
+
------
|
|
163
|
+
NotImplementedError
|
|
164
|
+
If evaluation fails. Ideally this should not occur, since the
|
|
165
|
+
translation phase should fail earlier.
|
|
166
|
+
"""
|
|
167
|
+
children = [child.evaluate(cache=cache, timer=timer) for child in self.children]
|
|
168
|
+
if timer is not None:
|
|
169
|
+
start = time.monotonic_ns()
|
|
170
|
+
result = self.do_evaluate(*self._non_child_args, *children)
|
|
171
|
+
end = time.monotonic_ns()
|
|
172
|
+
# TODO: Set better names on each class object.
|
|
173
|
+
timer.store(start, end, type(self).__name__)
|
|
174
|
+
return result
|
|
175
|
+
else:
|
|
176
|
+
return self.do_evaluate(*self._non_child_args, *children)
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
class ErrorNode(IR):
|
|
180
|
+
"""Represents an error translating the IR."""
|
|
181
|
+
|
|
182
|
+
__slots__ = ("error",)
|
|
183
|
+
_non_child = (
|
|
184
|
+
"schema",
|
|
185
|
+
"error",
|
|
186
|
+
)
|
|
187
|
+
error: str
|
|
188
|
+
"""The error."""
|
|
189
|
+
|
|
190
|
+
def __init__(self, schema: Schema, error: str):
|
|
191
|
+
self.schema = schema
|
|
192
|
+
self.error = error
|
|
193
|
+
self.children = ()
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
class PythonScan(IR):
|
|
197
|
+
"""Representation of input from a python function."""
|
|
198
|
+
|
|
199
|
+
__slots__ = ("options", "predicate")
|
|
200
|
+
_non_child = ("schema", "options", "predicate")
|
|
201
|
+
options: Any
|
|
202
|
+
"""Arbitrary options."""
|
|
203
|
+
predicate: expr.NamedExpr | None
|
|
204
|
+
"""Filter to apply to the constructed dataframe before returning it."""
|
|
205
|
+
|
|
206
|
+
def __init__(self, schema: Schema, options: Any, predicate: expr.NamedExpr | None):
|
|
207
|
+
self.schema = schema
|
|
208
|
+
self.options = options
|
|
209
|
+
self.predicate = predicate
|
|
210
|
+
self._non_child_args = (schema, options, predicate)
|
|
211
|
+
self.children = ()
|
|
212
|
+
raise NotImplementedError("PythonScan not implemented")
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def _align_parquet_schema(df: DataFrame, schema: Schema) -> DataFrame:
|
|
216
|
+
# TODO: Alternatively set the schema of the parquet reader to decimal128
|
|
217
|
+
plc_decimals_ids = {
|
|
218
|
+
plc.TypeId.DECIMAL32,
|
|
219
|
+
plc.TypeId.DECIMAL64,
|
|
220
|
+
plc.TypeId.DECIMAL128,
|
|
221
|
+
}
|
|
222
|
+
cast_list = []
|
|
223
|
+
|
|
224
|
+
for name, col in df.column_map.items():
|
|
225
|
+
src = col.obj.type()
|
|
226
|
+
dst = schema[name].plc
|
|
227
|
+
if (
|
|
228
|
+
src.id() in plc_decimals_ids
|
|
229
|
+
and dst.id() in plc_decimals_ids
|
|
230
|
+
and ((src.id() != dst.id()) or (src.scale != dst.scale))
|
|
231
|
+
):
|
|
232
|
+
cast_list.append(
|
|
233
|
+
Column(plc.unary.cast(col.obj, dst), name=name, dtype=schema[name])
|
|
234
|
+
)
|
|
235
|
+
|
|
236
|
+
if cast_list:
|
|
237
|
+
df = df.with_columns(cast_list)
|
|
238
|
+
|
|
239
|
+
return df
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
class Scan(IR):
|
|
243
|
+
"""Input from files."""
|
|
244
|
+
|
|
245
|
+
__slots__ = (
|
|
246
|
+
"cloud_options",
|
|
247
|
+
"include_file_paths",
|
|
248
|
+
"n_rows",
|
|
249
|
+
"parquet_options",
|
|
250
|
+
"paths",
|
|
251
|
+
"predicate",
|
|
252
|
+
"reader_options",
|
|
253
|
+
"row_index",
|
|
254
|
+
"skip_rows",
|
|
255
|
+
"typ",
|
|
256
|
+
"with_columns",
|
|
257
|
+
)
|
|
258
|
+
_non_child = (
|
|
259
|
+
"schema",
|
|
260
|
+
"typ",
|
|
261
|
+
"reader_options",
|
|
262
|
+
"cloud_options",
|
|
263
|
+
"paths",
|
|
264
|
+
"with_columns",
|
|
265
|
+
"skip_rows",
|
|
266
|
+
"n_rows",
|
|
267
|
+
"row_index",
|
|
268
|
+
"include_file_paths",
|
|
269
|
+
"predicate",
|
|
270
|
+
"parquet_options",
|
|
271
|
+
)
|
|
272
|
+
typ: str
|
|
273
|
+
"""What type of file are we reading? Parquet, CSV, etc..."""
|
|
274
|
+
reader_options: dict[str, Any]
|
|
275
|
+
"""Reader-specific options, as dictionary."""
|
|
276
|
+
cloud_options: dict[str, Any] | None
|
|
277
|
+
"""Cloud-related authentication options, currently ignored."""
|
|
278
|
+
paths: list[str]
|
|
279
|
+
"""List of paths to read from."""
|
|
280
|
+
with_columns: list[str] | None
|
|
281
|
+
"""Projected columns to return."""
|
|
282
|
+
skip_rows: int
|
|
283
|
+
"""Rows to skip at the start when reading."""
|
|
284
|
+
n_rows: int
|
|
285
|
+
"""Number of rows to read after skipping."""
|
|
286
|
+
row_index: tuple[str, int] | None
|
|
287
|
+
"""If not None add an integer index column of the given name."""
|
|
288
|
+
include_file_paths: str | None
|
|
289
|
+
"""Include the path of the source file(s) as a column with this name."""
|
|
290
|
+
predicate: expr.NamedExpr | None
|
|
291
|
+
"""Mask to apply to the read dataframe."""
|
|
292
|
+
parquet_options: ParquetOptions
|
|
293
|
+
"""Parquet-specific options."""
|
|
294
|
+
|
|
295
|
+
PARQUET_DEFAULT_CHUNK_SIZE: int = 0 # unlimited
|
|
296
|
+
PARQUET_DEFAULT_PASS_LIMIT: int = 16 * 1024**3 # 16GiB
|
|
297
|
+
|
|
298
|
+
def __init__(
|
|
299
|
+
self,
|
|
300
|
+
schema: Schema,
|
|
301
|
+
typ: str,
|
|
302
|
+
reader_options: dict[str, Any],
|
|
303
|
+
cloud_options: dict[str, Any] | None,
|
|
304
|
+
paths: list[str],
|
|
305
|
+
with_columns: list[str] | None,
|
|
306
|
+
skip_rows: int,
|
|
307
|
+
n_rows: int,
|
|
308
|
+
row_index: tuple[str, int] | None,
|
|
309
|
+
include_file_paths: str | None,
|
|
310
|
+
predicate: expr.NamedExpr | None,
|
|
311
|
+
parquet_options: ParquetOptions,
|
|
312
|
+
):
|
|
313
|
+
self.schema = schema
|
|
314
|
+
self.typ = typ
|
|
315
|
+
self.reader_options = reader_options
|
|
316
|
+
self.cloud_options = cloud_options
|
|
317
|
+
self.paths = paths
|
|
318
|
+
self.with_columns = with_columns
|
|
319
|
+
self.skip_rows = skip_rows
|
|
320
|
+
self.n_rows = n_rows
|
|
321
|
+
self.row_index = row_index
|
|
322
|
+
self.include_file_paths = include_file_paths
|
|
323
|
+
self.predicate = predicate
|
|
324
|
+
self._non_child_args = (
|
|
325
|
+
schema,
|
|
326
|
+
typ,
|
|
327
|
+
reader_options,
|
|
328
|
+
paths,
|
|
329
|
+
with_columns,
|
|
330
|
+
skip_rows,
|
|
331
|
+
n_rows,
|
|
332
|
+
row_index,
|
|
333
|
+
include_file_paths,
|
|
334
|
+
predicate,
|
|
335
|
+
parquet_options,
|
|
336
|
+
)
|
|
337
|
+
self.children = ()
|
|
338
|
+
self.parquet_options = parquet_options
|
|
339
|
+
if self.typ not in ("csv", "parquet", "ndjson"): # pragma: no cover
|
|
340
|
+
# This line is unhittable ATM since IPC/Anonymous scan raise
|
|
341
|
+
# on the polars side
|
|
342
|
+
raise NotImplementedError(f"Unhandled scan type: {self.typ}")
|
|
343
|
+
if self.typ == "ndjson" and (self.n_rows != -1 or self.skip_rows != 0):
|
|
344
|
+
raise NotImplementedError("row limit in scan for json reader")
|
|
345
|
+
if self.skip_rows < 0:
|
|
346
|
+
# TODO: polars has this implemented for parquet,
|
|
347
|
+
# maybe we can do this too?
|
|
348
|
+
raise NotImplementedError("slice pushdown for negative slices")
|
|
349
|
+
if self.cloud_options is not None and any(
|
|
350
|
+
self.cloud_options.get(k) is not None for k in ("aws", "azure", "gcp")
|
|
351
|
+
):
|
|
352
|
+
raise NotImplementedError(
|
|
353
|
+
"Read from cloud storage"
|
|
354
|
+
) # pragma: no cover; no test yet
|
|
355
|
+
if (
|
|
356
|
+
any(str(p).startswith("https:/") for p in self.paths)
|
|
357
|
+
and POLARS_VERSION_LT_131
|
|
358
|
+
): # pragma: no cover; polars passed us the wrong URI
|
|
359
|
+
# https://github.com/pola-rs/polars/issues/22766
|
|
360
|
+
raise NotImplementedError("Read from https")
|
|
361
|
+
if any(
|
|
362
|
+
str(p).startswith("file:/" if POLARS_VERSION_LT_131 else "file://")
|
|
363
|
+
for p in self.paths
|
|
364
|
+
):
|
|
365
|
+
raise NotImplementedError("Read from file URI")
|
|
366
|
+
if self.typ == "csv":
|
|
367
|
+
if any(
|
|
368
|
+
plc.io.SourceInfo._is_remote_uri(p) for p in self.paths
|
|
369
|
+
): # pragma: no cover; no test yet
|
|
370
|
+
# This works fine when the file has no leading blank lines,
|
|
371
|
+
# but currently we do some file introspection
|
|
372
|
+
# to skip blanks before parsing the header.
|
|
373
|
+
# For remote files we cannot determine if leading blank lines
|
|
374
|
+
# exist, so we're punting on CSV support.
|
|
375
|
+
# TODO: Once the CSV reader supports skipping leading
|
|
376
|
+
# blank lines natively, we can remove this guard.
|
|
377
|
+
raise NotImplementedError(
|
|
378
|
+
"Reading CSV from remote is not yet supported"
|
|
379
|
+
)
|
|
380
|
+
|
|
381
|
+
if self.reader_options["skip_rows_after_header"] != 0:
|
|
382
|
+
raise NotImplementedError("Skipping rows after header in CSV reader")
|
|
383
|
+
parse_options = self.reader_options["parse_options"]
|
|
384
|
+
if (
|
|
385
|
+
null_values := parse_options["null_values"]
|
|
386
|
+
) is not None and "Named" in null_values:
|
|
387
|
+
raise NotImplementedError(
|
|
388
|
+
"Per column null value specification not supported for CSV reader"
|
|
389
|
+
)
|
|
390
|
+
if (
|
|
391
|
+
comment := parse_options["comment_prefix"]
|
|
392
|
+
) is not None and "Multi" in comment:
|
|
393
|
+
raise NotImplementedError(
|
|
394
|
+
"Multi-character comment prefix not supported for CSV reader"
|
|
395
|
+
)
|
|
396
|
+
if not self.reader_options["has_header"]:
|
|
397
|
+
# TODO: To support reading headerless CSV files without requiring new
|
|
398
|
+
# column names, we would need to do file introspection to infer the number
|
|
399
|
+
# of columns so column projection works right.
|
|
400
|
+
reader_schema = self.reader_options.get("schema")
|
|
401
|
+
if not (
|
|
402
|
+
reader_schema
|
|
403
|
+
and isinstance(schema, dict)
|
|
404
|
+
and "fields" in reader_schema
|
|
405
|
+
):
|
|
406
|
+
raise NotImplementedError(
|
|
407
|
+
"Reading CSV without header requires user-provided column names via new_columns"
|
|
408
|
+
)
|
|
409
|
+
elif self.typ == "ndjson":
|
|
410
|
+
# TODO: consider handling the low memory option here
|
|
411
|
+
# (maybe use chunked JSON reader)
|
|
412
|
+
if self.reader_options["ignore_errors"]:
|
|
413
|
+
raise NotImplementedError(
|
|
414
|
+
"ignore_errors is not supported in the JSON reader"
|
|
415
|
+
)
|
|
416
|
+
if include_file_paths is not None:
|
|
417
|
+
# TODO: Need to populate num_rows_per_source in read_json in libcudf
|
|
418
|
+
raise NotImplementedError("Including file paths in a json scan.")
|
|
419
|
+
elif (
|
|
420
|
+
self.typ == "parquet"
|
|
421
|
+
and self.row_index is not None
|
|
422
|
+
and self.with_columns is not None
|
|
423
|
+
and len(self.with_columns) == 0
|
|
424
|
+
):
|
|
425
|
+
raise NotImplementedError(
|
|
426
|
+
"Reading only parquet metadata to produce row index."
|
|
427
|
+
)
|
|
428
|
+
|
|
429
|
+
def get_hashable(self) -> Hashable:
|
|
430
|
+
"""
|
|
431
|
+
Hashable representation of the node.
|
|
432
|
+
|
|
433
|
+
The options dictionaries are serialised for hashing purposes
|
|
434
|
+
as json strings.
|
|
435
|
+
"""
|
|
436
|
+
schema_hash = tuple(self.schema.items())
|
|
437
|
+
return (
|
|
438
|
+
type(self),
|
|
439
|
+
schema_hash,
|
|
440
|
+
self.typ,
|
|
441
|
+
json.dumps(self.reader_options),
|
|
442
|
+
json.dumps(self.cloud_options),
|
|
443
|
+
tuple(self.paths),
|
|
444
|
+
tuple(self.with_columns) if self.with_columns is not None else None,
|
|
445
|
+
self.skip_rows,
|
|
446
|
+
self.n_rows,
|
|
447
|
+
self.row_index,
|
|
448
|
+
self.include_file_paths,
|
|
449
|
+
self.predicate,
|
|
450
|
+
self.parquet_options,
|
|
451
|
+
)
|
|
452
|
+
|
|
453
|
+
@staticmethod
|
|
454
|
+
def add_file_paths(
|
|
455
|
+
name: str, paths: list[str], rows_per_path: list[int], df: DataFrame
|
|
456
|
+
) -> DataFrame:
|
|
457
|
+
"""
|
|
458
|
+
Add a Column of file paths to the DataFrame.
|
|
459
|
+
|
|
460
|
+
Each path is repeated according to the number of rows read from it.
|
|
461
|
+
"""
|
|
462
|
+
(filepaths,) = plc.filling.repeat(
|
|
463
|
+
plc.Table([plc.Column.from_arrow(pl.Series(values=map(str, paths)))]),
|
|
464
|
+
plc.Column.from_arrow(
|
|
465
|
+
pl.Series(values=rows_per_path, dtype=pl.datatypes.Int32())
|
|
466
|
+
),
|
|
467
|
+
).columns()
|
|
468
|
+
dtype = DataType(pl.String())
|
|
469
|
+
return df.with_columns([Column(filepaths, name=name, dtype=dtype)])
|
|
470
|
+
|
|
471
|
+
def fast_count(self) -> int: # pragma: no cover
|
|
472
|
+
"""Get the number of rows in a Parquet Scan."""
|
|
473
|
+
meta = plc.io.parquet_metadata.read_parquet_metadata(
|
|
474
|
+
plc.io.SourceInfo(self.paths)
|
|
475
|
+
)
|
|
476
|
+
total_rows = meta.num_rows() - self.skip_rows
|
|
477
|
+
if self.n_rows != -1:
|
|
478
|
+
total_rows = min(total_rows, self.n_rows)
|
|
479
|
+
return max(total_rows, 0)
|
|
480
|
+
|
|
481
|
+
@classmethod
|
|
482
|
+
@nvtx_annotate_cudf_polars(message="Scan")
|
|
483
|
+
def do_evaluate(
|
|
484
|
+
cls,
|
|
485
|
+
schema: Schema,
|
|
486
|
+
typ: str,
|
|
487
|
+
reader_options: dict[str, Any],
|
|
488
|
+
paths: list[str],
|
|
489
|
+
with_columns: list[str] | None,
|
|
490
|
+
skip_rows: int,
|
|
491
|
+
n_rows: int,
|
|
492
|
+
row_index: tuple[str, int] | None,
|
|
493
|
+
include_file_paths: str | None,
|
|
494
|
+
predicate: expr.NamedExpr | None,
|
|
495
|
+
parquet_options: ParquetOptions,
|
|
496
|
+
) -> DataFrame:
|
|
497
|
+
"""Evaluate and return a dataframe."""
|
|
498
|
+
if typ == "csv":
|
|
499
|
+
|
|
500
|
+
def read_csv_header(
|
|
501
|
+
path: Path | str, sep: str
|
|
502
|
+
) -> list[str]: # pragma: no cover
|
|
503
|
+
with Path(path).open() as f:
|
|
504
|
+
for line in f:
|
|
505
|
+
stripped = line.strip()
|
|
506
|
+
if stripped:
|
|
507
|
+
return stripped.split(sep)
|
|
508
|
+
return []
|
|
509
|
+
|
|
510
|
+
parse_options = reader_options["parse_options"]
|
|
511
|
+
sep = chr(parse_options["separator"])
|
|
512
|
+
quote = chr(parse_options["quote_char"])
|
|
513
|
+
eol = chr(parse_options["eol_char"])
|
|
514
|
+
if reader_options["schema"] is not None:
|
|
515
|
+
# Reader schema provides names
|
|
516
|
+
column_names = list(reader_options["schema"]["fields"].keys())
|
|
517
|
+
else:
|
|
518
|
+
# file provides column names
|
|
519
|
+
column_names = None
|
|
520
|
+
usecols = with_columns
|
|
521
|
+
has_header = reader_options["has_header"]
|
|
522
|
+
header = 0 if has_header else -1
|
|
523
|
+
|
|
524
|
+
# polars defaults to no null recognition
|
|
525
|
+
null_values = [""]
|
|
526
|
+
if parse_options["null_values"] is not None:
|
|
527
|
+
((typ, nulls),) = parse_options["null_values"].items()
|
|
528
|
+
if typ == "AllColumnsSingle":
|
|
529
|
+
# Single value
|
|
530
|
+
null_values.append(nulls)
|
|
531
|
+
else:
|
|
532
|
+
# List of values
|
|
533
|
+
null_values.extend(nulls)
|
|
534
|
+
if parse_options["comment_prefix"] is not None:
|
|
535
|
+
comment = chr(parse_options["comment_prefix"]["Single"])
|
|
536
|
+
else:
|
|
537
|
+
comment = None
|
|
538
|
+
decimal = "," if parse_options["decimal_comma"] else "."
|
|
539
|
+
|
|
540
|
+
# polars skips blank lines at the beginning of the file
|
|
541
|
+
pieces = []
|
|
542
|
+
seen_paths = []
|
|
543
|
+
read_partial = n_rows != -1
|
|
544
|
+
for p in paths:
|
|
545
|
+
skiprows = reader_options["skip_rows"]
|
|
546
|
+
path = Path(p)
|
|
547
|
+
with path.open() as f:
|
|
548
|
+
while f.readline() == "\n":
|
|
549
|
+
skiprows += 1
|
|
550
|
+
options = (
|
|
551
|
+
plc.io.csv.CsvReaderOptions.builder(plc.io.SourceInfo([path]))
|
|
552
|
+
.nrows(n_rows)
|
|
553
|
+
.skiprows(skiprows + skip_rows)
|
|
554
|
+
.lineterminator(str(eol))
|
|
555
|
+
.quotechar(str(quote))
|
|
556
|
+
.decimal(decimal)
|
|
557
|
+
.keep_default_na(keep_default_na=False)
|
|
558
|
+
.na_filter(na_filter=True)
|
|
559
|
+
.delimiter(str(sep))
|
|
560
|
+
.build()
|
|
561
|
+
)
|
|
562
|
+
if column_names is not None:
|
|
563
|
+
options.set_names([str(name) for name in column_names])
|
|
564
|
+
else:
|
|
565
|
+
if header > -1 and skip_rows > header: # pragma: no cover
|
|
566
|
+
# We need to read the header otherwise we would skip it
|
|
567
|
+
column_names = read_csv_header(path, str(sep))
|
|
568
|
+
options.set_names(column_names)
|
|
569
|
+
options.set_header(header)
|
|
570
|
+
options.set_dtypes({name: dtype.plc for name, dtype in schema.items()})
|
|
571
|
+
if usecols is not None:
|
|
572
|
+
options.set_use_cols_names([str(name) for name in usecols])
|
|
573
|
+
options.set_na_values(null_values)
|
|
574
|
+
if comment is not None:
|
|
575
|
+
options.set_comment(comment)
|
|
576
|
+
tbl_w_meta = plc.io.csv.read_csv(options)
|
|
577
|
+
pieces.append(tbl_w_meta)
|
|
578
|
+
if include_file_paths is not None:
|
|
579
|
+
seen_paths.append(p)
|
|
580
|
+
if read_partial:
|
|
581
|
+
n_rows -= tbl_w_meta.tbl.num_rows()
|
|
582
|
+
if n_rows <= 0:
|
|
583
|
+
break
|
|
584
|
+
tables, (colnames, *_) = zip(
|
|
585
|
+
*(
|
|
586
|
+
(piece.tbl, piece.column_names(include_children=False))
|
|
587
|
+
for piece in pieces
|
|
588
|
+
),
|
|
589
|
+
strict=True,
|
|
590
|
+
)
|
|
591
|
+
df = DataFrame.from_table(
|
|
592
|
+
plc.concatenate.concatenate(list(tables)),
|
|
593
|
+
colnames,
|
|
594
|
+
[schema[colname] for colname in colnames],
|
|
595
|
+
)
|
|
596
|
+
if include_file_paths is not None:
|
|
597
|
+
df = Scan.add_file_paths(
|
|
598
|
+
include_file_paths,
|
|
599
|
+
seen_paths,
|
|
600
|
+
[t.num_rows() for t in tables],
|
|
601
|
+
df,
|
|
602
|
+
)
|
|
603
|
+
elif typ == "parquet":
|
|
604
|
+
filters = None
|
|
605
|
+
if predicate is not None and row_index is None:
|
|
606
|
+
# Can't apply filters during read if we have a row index.
|
|
607
|
+
filters = to_parquet_filter(predicate.value)
|
|
608
|
+
options = plc.io.parquet.ParquetReaderOptions.builder(
|
|
609
|
+
plc.io.SourceInfo(paths)
|
|
610
|
+
).build()
|
|
611
|
+
if with_columns is not None:
|
|
612
|
+
options.set_columns(with_columns)
|
|
613
|
+
if filters is not None:
|
|
614
|
+
options.set_filter(filters)
|
|
615
|
+
if n_rows != -1:
|
|
616
|
+
options.set_num_rows(n_rows)
|
|
617
|
+
if skip_rows != 0:
|
|
618
|
+
options.set_skip_rows(skip_rows)
|
|
619
|
+
if parquet_options.chunked:
|
|
620
|
+
reader = plc.io.parquet.ChunkedParquetReader(
|
|
621
|
+
options,
|
|
622
|
+
chunk_read_limit=parquet_options.chunk_read_limit,
|
|
623
|
+
pass_read_limit=parquet_options.pass_read_limit,
|
|
624
|
+
)
|
|
625
|
+
chunk = reader.read_chunk()
|
|
626
|
+
tbl = chunk.tbl
|
|
627
|
+
# TODO: Nested column names
|
|
628
|
+
names = chunk.column_names(include_children=False)
|
|
629
|
+
concatenated_columns = tbl.columns()
|
|
630
|
+
while reader.has_next():
|
|
631
|
+
chunk = reader.read_chunk()
|
|
632
|
+
tbl = chunk.tbl
|
|
633
|
+
for i in range(tbl.num_columns()):
|
|
634
|
+
concatenated_columns[i] = plc.concatenate.concatenate(
|
|
635
|
+
[concatenated_columns[i], tbl._columns[i]]
|
|
636
|
+
)
|
|
637
|
+
# Drop residual columns to save memory
|
|
638
|
+
tbl._columns[i] = None
|
|
639
|
+
df = DataFrame.from_table(
|
|
640
|
+
plc.Table(concatenated_columns),
|
|
641
|
+
names=names,
|
|
642
|
+
dtypes=[schema[name] for name in names],
|
|
643
|
+
)
|
|
644
|
+
df = _align_parquet_schema(df, schema)
|
|
645
|
+
if include_file_paths is not None:
|
|
646
|
+
df = Scan.add_file_paths(
|
|
647
|
+
include_file_paths, paths, chunk.num_rows_per_source, df
|
|
648
|
+
)
|
|
649
|
+
else:
|
|
650
|
+
tbl_w_meta = plc.io.parquet.read_parquet(options)
|
|
651
|
+
# TODO: consider nested column names?
|
|
652
|
+
col_names = tbl_w_meta.column_names(include_children=False)
|
|
653
|
+
df = DataFrame.from_table(
|
|
654
|
+
tbl_w_meta.tbl,
|
|
655
|
+
col_names,
|
|
656
|
+
[schema[name] for name in col_names],
|
|
657
|
+
)
|
|
658
|
+
df = _align_parquet_schema(df, schema)
|
|
659
|
+
if include_file_paths is not None:
|
|
660
|
+
df = Scan.add_file_paths(
|
|
661
|
+
include_file_paths, paths, tbl_w_meta.num_rows_per_source, df
|
|
662
|
+
)
|
|
663
|
+
if filters is not None:
|
|
664
|
+
# Mask must have been applied.
|
|
665
|
+
return df
|
|
666
|
+
elif typ == "ndjson":
|
|
667
|
+
json_schema: list[plc.io.json.NameAndType] = [
|
|
668
|
+
(name, typ.plc, []) for name, typ in schema.items()
|
|
669
|
+
]
|
|
670
|
+
plc_tbl_w_meta = plc.io.json.read_json(
|
|
671
|
+
plc.io.json._setup_json_reader_options(
|
|
672
|
+
plc.io.SourceInfo(paths),
|
|
673
|
+
lines=True,
|
|
674
|
+
dtypes=json_schema,
|
|
675
|
+
prune_columns=True,
|
|
676
|
+
)
|
|
677
|
+
)
|
|
678
|
+
# TODO: I don't think cudf-polars supports nested types in general right now
|
|
679
|
+
# (but when it does, we should pass child column names from nested columns in)
|
|
680
|
+
col_names = plc_tbl_w_meta.column_names(include_children=False)
|
|
681
|
+
df = DataFrame.from_table(
|
|
682
|
+
plc_tbl_w_meta.tbl,
|
|
683
|
+
col_names,
|
|
684
|
+
[schema[name] for name in col_names],
|
|
685
|
+
)
|
|
686
|
+
col_order = list(schema.keys())
|
|
687
|
+
if row_index is not None:
|
|
688
|
+
col_order.remove(row_index[0])
|
|
689
|
+
df = df.select(col_order)
|
|
690
|
+
else:
|
|
691
|
+
raise NotImplementedError(
|
|
692
|
+
f"Unhandled scan type: {typ}"
|
|
693
|
+
) # pragma: no cover; post init trips first
|
|
694
|
+
if row_index is not None:
|
|
695
|
+
name, offset = row_index
|
|
696
|
+
offset += skip_rows
|
|
697
|
+
dtype = schema[name]
|
|
698
|
+
step = plc.Scalar.from_py(1, dtype.plc)
|
|
699
|
+
init = plc.Scalar.from_py(offset, dtype.plc)
|
|
700
|
+
index_col = Column(
|
|
701
|
+
plc.filling.sequence(df.num_rows, init, step),
|
|
702
|
+
is_sorted=plc.types.Sorted.YES,
|
|
703
|
+
order=plc.types.Order.ASCENDING,
|
|
704
|
+
null_order=plc.types.NullOrder.AFTER,
|
|
705
|
+
name=name,
|
|
706
|
+
dtype=dtype,
|
|
707
|
+
)
|
|
708
|
+
df = DataFrame([index_col, *df.columns])
|
|
709
|
+
if next(iter(schema)) != name:
|
|
710
|
+
df = df.select(schema)
|
|
711
|
+
assert all(
|
|
712
|
+
c.obj.type() == schema[name].plc for name, c in df.column_map.items()
|
|
713
|
+
)
|
|
714
|
+
if predicate is None:
|
|
715
|
+
return df
|
|
716
|
+
else:
|
|
717
|
+
(mask,) = broadcast(predicate.evaluate(df), target_length=df.num_rows)
|
|
718
|
+
return df.filter(mask)
|
|
719
|
+
|
|
720
|
+
|
|
721
|
+
class Sink(IR):
|
|
722
|
+
"""Sink a dataframe to a file."""
|
|
723
|
+
|
|
724
|
+
__slots__ = ("cloud_options", "kind", "options", "parquet_options", "path")
|
|
725
|
+
_non_child = (
|
|
726
|
+
"schema",
|
|
727
|
+
"kind",
|
|
728
|
+
"path",
|
|
729
|
+
"parquet_options",
|
|
730
|
+
"options",
|
|
731
|
+
"cloud_options",
|
|
732
|
+
)
|
|
733
|
+
|
|
734
|
+
kind: str
|
|
735
|
+
"""The type of file to write to. Eg. Parquet, CSV, etc."""
|
|
736
|
+
path: str
|
|
737
|
+
"""The path to write to"""
|
|
738
|
+
parquet_options: ParquetOptions
|
|
739
|
+
"""GPU-specific configuration options"""
|
|
740
|
+
cloud_options: dict[str, Any] | None
|
|
741
|
+
"""Cloud-related authentication options, currently ignored."""
|
|
742
|
+
options: dict[str, Any]
|
|
743
|
+
"""Sink options from Polars"""
|
|
744
|
+
|
|
745
|
+
def __init__(
|
|
746
|
+
self,
|
|
747
|
+
schema: Schema,
|
|
748
|
+
kind: str,
|
|
749
|
+
path: str,
|
|
750
|
+
parquet_options: ParquetOptions,
|
|
751
|
+
options: dict[str, Any],
|
|
752
|
+
cloud_options: dict[str, Any],
|
|
753
|
+
df: IR,
|
|
754
|
+
):
|
|
755
|
+
self.schema = schema
|
|
756
|
+
self.kind = kind
|
|
757
|
+
self.path = path
|
|
758
|
+
self.parquet_options = parquet_options
|
|
759
|
+
self.options = options
|
|
760
|
+
self.cloud_options = cloud_options
|
|
761
|
+
self.children = (df,)
|
|
762
|
+
self._non_child_args = (schema, kind, path, parquet_options, options)
|
|
763
|
+
if self.cloud_options is not None and any(
|
|
764
|
+
self.cloud_options.get(k) is not None
|
|
765
|
+
for k in ("config", "credential_provider")
|
|
766
|
+
):
|
|
767
|
+
raise NotImplementedError(
|
|
768
|
+
"Write to cloud storage"
|
|
769
|
+
) # pragma: no cover; no test yet
|
|
770
|
+
sync_on_close = options.get("sync_on_close")
|
|
771
|
+
if sync_on_close not in {"None", None}:
|
|
772
|
+
raise NotImplementedError(
|
|
773
|
+
f"sync_on_close='{sync_on_close}' is not supported."
|
|
774
|
+
) # pragma: no cover; no test yet
|
|
775
|
+
child_schema = df.schema.values()
|
|
776
|
+
if kind == "Csv":
|
|
777
|
+
if not all(
|
|
778
|
+
plc.io.csv.is_supported_write_csv(dtype.plc) for dtype in child_schema
|
|
779
|
+
):
|
|
780
|
+
# Nested types are unsupported in polars and libcudf
|
|
781
|
+
raise NotImplementedError(
|
|
782
|
+
"Contains unsupported types for CSV writing"
|
|
783
|
+
) # pragma: no cover
|
|
784
|
+
serialize = options["serialize_options"]
|
|
785
|
+
if options["include_bom"]:
|
|
786
|
+
raise NotImplementedError("include_bom is not supported.")
|
|
787
|
+
for key in (
|
|
788
|
+
"date_format",
|
|
789
|
+
"time_format",
|
|
790
|
+
"datetime_format",
|
|
791
|
+
"float_scientific",
|
|
792
|
+
"float_precision",
|
|
793
|
+
):
|
|
794
|
+
if serialize[key] is not None:
|
|
795
|
+
raise NotImplementedError(f"{key} is not supported.")
|
|
796
|
+
if serialize["quote_style"] != "Necessary":
|
|
797
|
+
raise NotImplementedError("Only quote_style='Necessary' is supported.")
|
|
798
|
+
if chr(serialize["quote_char"]) != '"':
|
|
799
|
+
raise NotImplementedError("Only quote_char='\"' is supported.")
|
|
800
|
+
elif kind == "Parquet":
|
|
801
|
+
compression = options["compression"]
|
|
802
|
+
if isinstance(compression, dict):
|
|
803
|
+
if len(compression) != 1:
|
|
804
|
+
raise NotImplementedError(
|
|
805
|
+
"Compression dict with more than one entry."
|
|
806
|
+
) # pragma: no cover
|
|
807
|
+
compression, compression_level = next(iter(compression.items()))
|
|
808
|
+
options["compression"] = compression
|
|
809
|
+
if compression_level is not None:
|
|
810
|
+
raise NotImplementedError(
|
|
811
|
+
"Setting compression_level is not supported."
|
|
812
|
+
)
|
|
813
|
+
if compression == "Lz4Raw":
|
|
814
|
+
compression = "Lz4"
|
|
815
|
+
options["compression"] = compression
|
|
816
|
+
if (
|
|
817
|
+
compression != "Uncompressed"
|
|
818
|
+
and not plc.io.parquet.is_supported_write_parquet(
|
|
819
|
+
getattr(plc.io.types.CompressionType, compression.upper())
|
|
820
|
+
)
|
|
821
|
+
):
|
|
822
|
+
raise NotImplementedError(
|
|
823
|
+
f"Compression type '{compression}' is not supported."
|
|
824
|
+
)
|
|
825
|
+
elif (
|
|
826
|
+
kind == "Json"
|
|
827
|
+
): # pragma: no cover; options are validated on the polars side
|
|
828
|
+
if not all(
|
|
829
|
+
plc.io.json.is_supported_write_json(dtype.plc) for dtype in child_schema
|
|
830
|
+
):
|
|
831
|
+
# Nested types are unsupported in polars and libcudf
|
|
832
|
+
raise NotImplementedError(
|
|
833
|
+
"Contains unsupported types for JSON writing"
|
|
834
|
+
) # pragma: no cover
|
|
835
|
+
shared_writer_options = {"sync_on_close", "maintain_order", "mkdir"}
|
|
836
|
+
if set(options) - shared_writer_options:
|
|
837
|
+
raise NotImplementedError("Unsupported options passed JSON writer.")
|
|
838
|
+
else:
|
|
839
|
+
raise NotImplementedError(
|
|
840
|
+
f"Unhandled sink kind: {kind}"
|
|
841
|
+
) # pragma: no cover
|
|
842
|
+
|
|
843
|
+
def get_hashable(self) -> Hashable:
|
|
844
|
+
"""
|
|
845
|
+
Hashable representation of the node.
|
|
846
|
+
|
|
847
|
+
The option dictionary is serialised for hashing purposes.
|
|
848
|
+
"""
|
|
849
|
+
schema_hash = tuple(self.schema.items()) # pragma: no cover
|
|
850
|
+
return (
|
|
851
|
+
type(self),
|
|
852
|
+
schema_hash,
|
|
853
|
+
self.kind,
|
|
854
|
+
self.path,
|
|
855
|
+
self.parquet_options,
|
|
856
|
+
json.dumps(self.options),
|
|
857
|
+
json.dumps(self.cloud_options),
|
|
858
|
+
) # pragma: no cover
|
|
859
|
+
|
|
860
|
+
@classmethod
|
|
861
|
+
def _write_csv(
|
|
862
|
+
cls, target: plc.io.SinkInfo, options: dict[str, Any], df: DataFrame
|
|
863
|
+
) -> None:
|
|
864
|
+
"""Write CSV data to a sink."""
|
|
865
|
+
serialize = options["serialize_options"]
|
|
866
|
+
options = (
|
|
867
|
+
plc.io.csv.CsvWriterOptions.builder(target, df.table)
|
|
868
|
+
.include_header(options["include_header"])
|
|
869
|
+
.names(df.column_names if options["include_header"] else [])
|
|
870
|
+
.na_rep(serialize["null"])
|
|
871
|
+
.line_terminator(serialize["line_terminator"])
|
|
872
|
+
.inter_column_delimiter(chr(serialize["separator"]))
|
|
873
|
+
.build()
|
|
874
|
+
)
|
|
875
|
+
plc.io.csv.write_csv(options)
|
|
876
|
+
|
|
877
|
+
@classmethod
|
|
878
|
+
def _write_json(cls, target: plc.io.SinkInfo, df: DataFrame) -> None:
|
|
879
|
+
"""Write Json data to a sink."""
|
|
880
|
+
metadata = plc.io.TableWithMetadata(
|
|
881
|
+
df.table, [(col, []) for col in df.column_names]
|
|
882
|
+
)
|
|
883
|
+
options = (
|
|
884
|
+
plc.io.json.JsonWriterOptions.builder(target, df.table)
|
|
885
|
+
.lines(val=True)
|
|
886
|
+
.na_rep("null")
|
|
887
|
+
.include_nulls(val=True)
|
|
888
|
+
.metadata(metadata)
|
|
889
|
+
.utf8_escaped(val=False)
|
|
890
|
+
.build()
|
|
891
|
+
)
|
|
892
|
+
plc.io.json.write_json(options)
|
|
893
|
+
|
|
894
|
+
@staticmethod
|
|
895
|
+
def _make_parquet_metadata(df: DataFrame) -> plc.io.types.TableInputMetadata:
|
|
896
|
+
"""Create TableInputMetadata and set column names."""
|
|
897
|
+
metadata = plc.io.types.TableInputMetadata(df.table)
|
|
898
|
+
for i, name in enumerate(df.column_names):
|
|
899
|
+
metadata.column_metadata[i].set_name(name)
|
|
900
|
+
return metadata
|
|
901
|
+
|
|
902
|
+
@staticmethod
|
|
903
|
+
def _apply_parquet_writer_options(
|
|
904
|
+
builder: plc.io.parquet.ChunkedParquetWriterOptionsBuilder
|
|
905
|
+
| plc.io.parquet.ParquetWriterOptionsBuilder,
|
|
906
|
+
options: dict[str, Any],
|
|
907
|
+
) -> (
|
|
908
|
+
plc.io.parquet.ChunkedParquetWriterOptionsBuilder
|
|
909
|
+
| plc.io.parquet.ParquetWriterOptionsBuilder
|
|
910
|
+
):
|
|
911
|
+
"""Apply writer options to the builder."""
|
|
912
|
+
compression = options.get("compression")
|
|
913
|
+
if compression and compression != "Uncompressed":
|
|
914
|
+
compression_type = getattr(
|
|
915
|
+
plc.io.types.CompressionType, compression.upper()
|
|
916
|
+
)
|
|
917
|
+
builder = builder.compression(compression_type)
|
|
918
|
+
|
|
919
|
+
if (data_page_size := options.get("data_page_size")) is not None:
|
|
920
|
+
builder = builder.max_page_size_bytes(data_page_size)
|
|
921
|
+
|
|
922
|
+
if (row_group_size := options.get("row_group_size")) is not None:
|
|
923
|
+
builder = builder.row_group_size_rows(row_group_size)
|
|
924
|
+
|
|
925
|
+
return builder
|
|
926
|
+
|
|
927
|
+
@classmethod
|
|
928
|
+
def _write_parquet(
|
|
929
|
+
cls,
|
|
930
|
+
target: plc.io.SinkInfo,
|
|
931
|
+
parquet_options: ParquetOptions,
|
|
932
|
+
options: dict[str, Any],
|
|
933
|
+
df: DataFrame,
|
|
934
|
+
) -> None:
|
|
935
|
+
metadata: plc.io.types.TableInputMetadata = cls._make_parquet_metadata(df)
|
|
936
|
+
|
|
937
|
+
builder: (
|
|
938
|
+
plc.io.parquet.ChunkedParquetWriterOptionsBuilder
|
|
939
|
+
| plc.io.parquet.ParquetWriterOptionsBuilder
|
|
940
|
+
)
|
|
941
|
+
|
|
942
|
+
if (
|
|
943
|
+
parquet_options.chunked
|
|
944
|
+
and parquet_options.n_output_chunks != 1
|
|
945
|
+
and df.table.num_rows() != 0
|
|
946
|
+
):
|
|
947
|
+
builder = plc.io.parquet.ChunkedParquetWriterOptions.builder(
|
|
948
|
+
target
|
|
949
|
+
).metadata(metadata)
|
|
950
|
+
builder = cls._apply_parquet_writer_options(builder, options)
|
|
951
|
+
writer_options = builder.build()
|
|
952
|
+
writer = plc.io.parquet.ChunkedParquetWriter.from_options(writer_options)
|
|
953
|
+
|
|
954
|
+
# TODO: Can be based on a heuristic that estimates chunk size
|
|
955
|
+
# from the input table size and available GPU memory.
|
|
956
|
+
num_chunks = parquet_options.n_output_chunks
|
|
957
|
+
table_chunks = plc.copying.split(
|
|
958
|
+
df.table,
|
|
959
|
+
[i * df.table.num_rows() // num_chunks for i in range(1, num_chunks)],
|
|
960
|
+
)
|
|
961
|
+
for chunk in table_chunks:
|
|
962
|
+
writer.write(chunk)
|
|
963
|
+
writer.close([])
|
|
964
|
+
|
|
965
|
+
else:
|
|
966
|
+
builder = plc.io.parquet.ParquetWriterOptions.builder(
|
|
967
|
+
target, df.table
|
|
968
|
+
).metadata(metadata)
|
|
969
|
+
builder = cls._apply_parquet_writer_options(builder, options)
|
|
970
|
+
writer_options = builder.build()
|
|
971
|
+
plc.io.parquet.write_parquet(writer_options)
|
|
972
|
+
|
|
973
|
+
@classmethod
|
|
974
|
+
@nvtx_annotate_cudf_polars(message="Sink")
|
|
975
|
+
def do_evaluate(
|
|
976
|
+
cls,
|
|
977
|
+
schema: Schema,
|
|
978
|
+
kind: str,
|
|
979
|
+
path: str,
|
|
980
|
+
parquet_options: ParquetOptions,
|
|
981
|
+
options: dict[str, Any],
|
|
982
|
+
df: DataFrame,
|
|
983
|
+
) -> DataFrame:
|
|
984
|
+
"""Write the dataframe to a file."""
|
|
985
|
+
target = plc.io.SinkInfo([path])
|
|
986
|
+
|
|
987
|
+
if options.get("mkdir", False):
|
|
988
|
+
Path(path).parent.mkdir(parents=True, exist_ok=True)
|
|
989
|
+
if kind == "Csv":
|
|
990
|
+
cls._write_csv(target, options, df)
|
|
991
|
+
elif kind == "Parquet":
|
|
992
|
+
cls._write_parquet(target, parquet_options, options, df)
|
|
993
|
+
elif kind == "Json":
|
|
994
|
+
cls._write_json(target, df)
|
|
995
|
+
|
|
996
|
+
return DataFrame([])
|
|
997
|
+
|
|
998
|
+
|
|
999
|
+
class Cache(IR):
|
|
1000
|
+
"""
|
|
1001
|
+
Return a cached plan node.
|
|
1002
|
+
|
|
1003
|
+
Used for CSE at the plan level.
|
|
1004
|
+
"""
|
|
1005
|
+
|
|
1006
|
+
__slots__ = ("key", "refcount")
|
|
1007
|
+
_non_child = ("schema", "key", "refcount")
|
|
1008
|
+
key: int
|
|
1009
|
+
"""The cache key."""
|
|
1010
|
+
refcount: int | None
|
|
1011
|
+
"""The number of cache hits."""
|
|
1012
|
+
|
|
1013
|
+
def __init__(self, schema: Schema, key: int, refcount: int | None, value: IR):
|
|
1014
|
+
self.schema = schema
|
|
1015
|
+
self.key = key
|
|
1016
|
+
self.refcount = refcount
|
|
1017
|
+
self.children = (value,)
|
|
1018
|
+
self._non_child_args = (key, refcount)
|
|
1019
|
+
|
|
1020
|
+
def get_hashable(self) -> Hashable: # noqa: D102
|
|
1021
|
+
# Polars arranges that the keys are unique across all cache
|
|
1022
|
+
# nodes that reference the same child, so we don't need to
|
|
1023
|
+
# hash the child.
|
|
1024
|
+
return (type(self), self.key, self.refcount)
|
|
1025
|
+
|
|
1026
|
+
def is_equal(self, other: Self) -> bool: # noqa: D102
|
|
1027
|
+
if self.key == other.key and self.refcount == other.refcount:
|
|
1028
|
+
self.children = other.children
|
|
1029
|
+
return True
|
|
1030
|
+
return False
|
|
1031
|
+
|
|
1032
|
+
@classmethod
|
|
1033
|
+
@nvtx_annotate_cudf_polars(message="Cache")
|
|
1034
|
+
def do_evaluate(
|
|
1035
|
+
cls, key: int, refcount: int | None, df: DataFrame
|
|
1036
|
+
) -> DataFrame: # pragma: no cover; basic evaluation never calls this
|
|
1037
|
+
"""Evaluate and return a dataframe."""
|
|
1038
|
+
# Our value has already been computed for us, so let's just
|
|
1039
|
+
# return it.
|
|
1040
|
+
return df
|
|
1041
|
+
|
|
1042
|
+
def evaluate(self, *, cache: CSECache, timer: Timer | None) -> DataFrame:
|
|
1043
|
+
"""Evaluate and return a dataframe."""
|
|
1044
|
+
# We must override the recursion scheme because we don't want
|
|
1045
|
+
# to recurse if we're in the cache.
|
|
1046
|
+
try:
|
|
1047
|
+
(result, hits) = cache[self.key]
|
|
1048
|
+
except KeyError:
|
|
1049
|
+
(value,) = self.children
|
|
1050
|
+
result = value.evaluate(cache=cache, timer=timer)
|
|
1051
|
+
cache[self.key] = (result, 0)
|
|
1052
|
+
return result
|
|
1053
|
+
else:
|
|
1054
|
+
if self.refcount is None:
|
|
1055
|
+
return result
|
|
1056
|
+
|
|
1057
|
+
hits += 1 # pragma: no cover
|
|
1058
|
+
if hits == self.refcount: # pragma: no cover
|
|
1059
|
+
del cache[self.key]
|
|
1060
|
+
else: # pragma: no cover
|
|
1061
|
+
cache[self.key] = (result, hits)
|
|
1062
|
+
return result # pragma: no cover
|
|
1063
|
+
|
|
1064
|
+
|
|
1065
|
+
class DataFrameScan(IR):
|
|
1066
|
+
"""
|
|
1067
|
+
Input from an existing polars DataFrame.
|
|
1068
|
+
|
|
1069
|
+
This typically arises from ``q.collect().lazy()``
|
|
1070
|
+
"""
|
|
1071
|
+
|
|
1072
|
+
__slots__ = ("_id_for_hash", "df", "projection")
|
|
1073
|
+
_non_child = ("schema", "df", "projection")
|
|
1074
|
+
df: Any
|
|
1075
|
+
"""Polars internal PyDataFrame object."""
|
|
1076
|
+
projection: tuple[str, ...] | None
|
|
1077
|
+
"""List of columns to project out."""
|
|
1078
|
+
|
|
1079
|
+
def __init__(
|
|
1080
|
+
self,
|
|
1081
|
+
schema: Schema,
|
|
1082
|
+
df: Any,
|
|
1083
|
+
projection: Sequence[str] | None,
|
|
1084
|
+
):
|
|
1085
|
+
self.schema = schema
|
|
1086
|
+
self.df = df
|
|
1087
|
+
self.projection = tuple(projection) if projection is not None else None
|
|
1088
|
+
self._non_child_args = (
|
|
1089
|
+
schema,
|
|
1090
|
+
pl.DataFrame._from_pydf(df),
|
|
1091
|
+
self.projection,
|
|
1092
|
+
)
|
|
1093
|
+
self.children = ()
|
|
1094
|
+
self._id_for_hash = random.randint(0, 2**64 - 1)
|
|
1095
|
+
|
|
1096
|
+
def get_hashable(self) -> Hashable:
|
|
1097
|
+
"""
|
|
1098
|
+
Hashable representation of the node.
|
|
1099
|
+
|
|
1100
|
+
The (heavy) dataframe object is not hashed. No two instances of
|
|
1101
|
+
``DataFrameScan`` will have the same hash, even if they have the
|
|
1102
|
+
same schema, projection, and config options, and data.
|
|
1103
|
+
"""
|
|
1104
|
+
schema_hash = tuple(self.schema.items())
|
|
1105
|
+
return (
|
|
1106
|
+
type(self),
|
|
1107
|
+
schema_hash,
|
|
1108
|
+
self._id_for_hash,
|
|
1109
|
+
self.projection,
|
|
1110
|
+
)
|
|
1111
|
+
|
|
1112
|
+
@classmethod
|
|
1113
|
+
@nvtx_annotate_cudf_polars(message="DataFrameScan")
|
|
1114
|
+
def do_evaluate(
|
|
1115
|
+
cls,
|
|
1116
|
+
schema: Schema,
|
|
1117
|
+
df: Any,
|
|
1118
|
+
projection: tuple[str, ...] | None,
|
|
1119
|
+
) -> DataFrame:
|
|
1120
|
+
"""Evaluate and return a dataframe."""
|
|
1121
|
+
if projection is not None:
|
|
1122
|
+
df = df.select(projection)
|
|
1123
|
+
df = DataFrame.from_polars(df)
|
|
1124
|
+
assert all(
|
|
1125
|
+
c.obj.type() == dtype.plc
|
|
1126
|
+
for c, dtype in zip(df.columns, schema.values(), strict=True)
|
|
1127
|
+
)
|
|
1128
|
+
return df
|
|
1129
|
+
|
|
1130
|
+
|
|
1131
|
+
class Select(IR):
|
|
1132
|
+
"""Produce a new dataframe selecting given expressions from an input."""
|
|
1133
|
+
|
|
1134
|
+
__slots__ = ("exprs", "should_broadcast")
|
|
1135
|
+
_non_child = ("schema", "exprs", "should_broadcast")
|
|
1136
|
+
exprs: tuple[expr.NamedExpr, ...]
|
|
1137
|
+
"""List of expressions to evaluate to form the new dataframe."""
|
|
1138
|
+
should_broadcast: bool
|
|
1139
|
+
"""Should columns be broadcast?"""
|
|
1140
|
+
|
|
1141
|
+
def __init__(
|
|
1142
|
+
self,
|
|
1143
|
+
schema: Schema,
|
|
1144
|
+
exprs: Sequence[expr.NamedExpr],
|
|
1145
|
+
should_broadcast: bool, # noqa: FBT001
|
|
1146
|
+
df: IR,
|
|
1147
|
+
):
|
|
1148
|
+
self.schema = schema
|
|
1149
|
+
self.exprs = tuple(exprs)
|
|
1150
|
+
self.should_broadcast = should_broadcast
|
|
1151
|
+
self.children = (df,)
|
|
1152
|
+
self._non_child_args = (self.exprs, should_broadcast)
|
|
1153
|
+
if (
|
|
1154
|
+
Select._is_len_expr(self.exprs)
|
|
1155
|
+
and isinstance(df, Scan)
|
|
1156
|
+
and df.typ != "parquet"
|
|
1157
|
+
): # pragma: no cover
|
|
1158
|
+
raise NotImplementedError(f"Unsupported scan type: {df.typ}")
|
|
1159
|
+
|
|
1160
|
+
@staticmethod
|
|
1161
|
+
def _is_len_expr(exprs: tuple[expr.NamedExpr, ...]) -> bool: # pragma: no cover
|
|
1162
|
+
if len(exprs) == 1:
|
|
1163
|
+
expr0 = exprs[0].value
|
|
1164
|
+
return (
|
|
1165
|
+
isinstance(expr0, expr.Cast)
|
|
1166
|
+
and len(expr0.children) == 1
|
|
1167
|
+
and isinstance(expr0.children[0], expr.Len)
|
|
1168
|
+
)
|
|
1169
|
+
return False
|
|
1170
|
+
|
|
1171
|
+
@classmethod
|
|
1172
|
+
@nvtx_annotate_cudf_polars(message="Select")
|
|
1173
|
+
def do_evaluate(
|
|
1174
|
+
cls,
|
|
1175
|
+
exprs: tuple[expr.NamedExpr, ...],
|
|
1176
|
+
should_broadcast: bool, # noqa: FBT001
|
|
1177
|
+
df: DataFrame,
|
|
1178
|
+
) -> DataFrame:
|
|
1179
|
+
"""Evaluate and return a dataframe."""
|
|
1180
|
+
# Handle any broadcasting
|
|
1181
|
+
columns = [e.evaluate(df) for e in exprs]
|
|
1182
|
+
if should_broadcast:
|
|
1183
|
+
columns = broadcast(*columns)
|
|
1184
|
+
return DataFrame(columns)
|
|
1185
|
+
|
|
1186
|
+
def evaluate(self, *, cache: CSECache, timer: Timer | None) -> DataFrame:
|
|
1187
|
+
"""
|
|
1188
|
+
Evaluate the Select node with special handling for fast count queries.
|
|
1189
|
+
|
|
1190
|
+
Parameters
|
|
1191
|
+
----------
|
|
1192
|
+
cache
|
|
1193
|
+
Mapping from cached node ids to constructed DataFrames.
|
|
1194
|
+
Used to implement evaluation of the `Cache` node.
|
|
1195
|
+
timer
|
|
1196
|
+
If not None, a Timer object to record timings for the
|
|
1197
|
+
evaluation of the node.
|
|
1198
|
+
|
|
1199
|
+
Returns
|
|
1200
|
+
-------
|
|
1201
|
+
DataFrame
|
|
1202
|
+
Result of evaluating this Select node. If the expression is a
|
|
1203
|
+
count over a parquet scan, returns a constant row count directly
|
|
1204
|
+
without evaluating the scan.
|
|
1205
|
+
|
|
1206
|
+
Raises
|
|
1207
|
+
------
|
|
1208
|
+
NotImplementedError
|
|
1209
|
+
If evaluation fails. Ideally this should not occur, since the
|
|
1210
|
+
translation phase should fail earlier.
|
|
1211
|
+
"""
|
|
1212
|
+
if (
|
|
1213
|
+
isinstance(self.children[0], Scan)
|
|
1214
|
+
and Select._is_len_expr(self.exprs)
|
|
1215
|
+
and self.children[0].typ == "parquet"
|
|
1216
|
+
and self.children[0].predicate is None
|
|
1217
|
+
):
|
|
1218
|
+
scan = self.children[0] # pragma: no cover
|
|
1219
|
+
effective_rows = scan.fast_count() # pragma: no cover
|
|
1220
|
+
dtype = DataType(pl.UInt32()) # pragma: no cover
|
|
1221
|
+
col = Column(
|
|
1222
|
+
plc.Column.from_scalar(
|
|
1223
|
+
plc.Scalar.from_py(effective_rows, dtype.plc),
|
|
1224
|
+
1,
|
|
1225
|
+
),
|
|
1226
|
+
name=self.exprs[0].name or "len",
|
|
1227
|
+
dtype=dtype,
|
|
1228
|
+
) # pragma: no cover
|
|
1229
|
+
return DataFrame([col]) # pragma: no cover
|
|
1230
|
+
|
|
1231
|
+
return super().evaluate(cache=cache, timer=timer)
|
|
1232
|
+
|
|
1233
|
+
|
|
1234
|
+
class Reduce(IR):
|
|
1235
|
+
"""
|
|
1236
|
+
Produce a new dataframe selecting given expressions from an input.
|
|
1237
|
+
|
|
1238
|
+
This is a special case of :class:`Select` where all outputs are a single row.
|
|
1239
|
+
"""
|
|
1240
|
+
|
|
1241
|
+
__slots__ = ("exprs",)
|
|
1242
|
+
_non_child = ("schema", "exprs")
|
|
1243
|
+
exprs: tuple[expr.NamedExpr, ...]
|
|
1244
|
+
"""List of expressions to evaluate to form the new dataframe."""
|
|
1245
|
+
|
|
1246
|
+
def __init__(
|
|
1247
|
+
self, schema: Schema, exprs: Sequence[expr.NamedExpr], df: IR
|
|
1248
|
+
): # pragma: no cover; polars doesn't emit this node yet
|
|
1249
|
+
self.schema = schema
|
|
1250
|
+
self.exprs = tuple(exprs)
|
|
1251
|
+
self.children = (df,)
|
|
1252
|
+
self._non_child_args = (self.exprs,)
|
|
1253
|
+
|
|
1254
|
+
@classmethod
|
|
1255
|
+
@nvtx_annotate_cudf_polars(message="Reduce")
|
|
1256
|
+
def do_evaluate(
|
|
1257
|
+
cls,
|
|
1258
|
+
exprs: tuple[expr.NamedExpr, ...],
|
|
1259
|
+
df: DataFrame,
|
|
1260
|
+
) -> DataFrame: # pragma: no cover; not exposed by polars yet
|
|
1261
|
+
"""Evaluate and return a dataframe."""
|
|
1262
|
+
columns = broadcast(*(e.evaluate(df) for e in exprs))
|
|
1263
|
+
assert all(column.size == 1 for column in columns)
|
|
1264
|
+
return DataFrame(columns)
|
|
1265
|
+
|
|
1266
|
+
|
|
1267
|
+
class Rolling(IR):
    """Perform a (possibly grouped) rolling aggregation."""

    __slots__ = (
        "agg_requests",
        "closed_window",
        "following",
        "index",
        "keys",
        "preceding",
        "zlice",
    )
    _non_child = (
        "schema",
        "index",
        "preceding",
        "following",
        "closed_window",
        "keys",
        "agg_requests",
        "zlice",
    )
    index: expr.NamedExpr
    """Column being rolled over."""
    preceding: plc.Scalar
    """Preceding window extent defining start of window."""
    following: plc.Scalar
    """Following window extent defining end of window."""
    closed_window: ClosedInterval
    """Treatment of window endpoints."""
    keys: tuple[expr.NamedExpr, ...]
    """Grouping keys."""
    agg_requests: tuple[expr.NamedExpr, ...]
    """Aggregation expressions."""
    zlice: Zlice | None
    """Optional slice"""

    def __init__(
        self,
        schema: Schema,
        index: expr.NamedExpr,
        preceding: plc.Scalar,
        following: plc.Scalar,
        closed_window: ClosedInterval,
        keys: Sequence[expr.NamedExpr],
        agg_requests: Sequence[expr.NamedExpr],
        zlice: Zlice | None,
        df: IR,
    ):
        self.schema = schema
        self.index = index
        self.preceding = preceding
        self.following = following
        self.closed_window = closed_window
        self.keys = tuple(keys)
        self.agg_requests = tuple(agg_requests)
        if not all(
            plc.rolling.is_valid_rolling_aggregation(
                agg.value.dtype.plc, agg.value.agg_request
            )
            for agg in self.agg_requests
        ):
            raise NotImplementedError("Unsupported rolling aggregation")
        if any(
            agg.value.agg_request.kind() == plc.aggregation.Kind.COLLECT_LIST
            for agg in self.agg_requests
        ):
            raise NotImplementedError(
                "Incorrect handling of empty groups for list collection"
            )

        self.zlice = zlice
        self.children = (df,)
        self._non_child_args = (
            index,
            preceding,
            following,
            closed_window,
            keys,
            agg_requests,
            zlice,
        )

    @classmethod
    @nvtx_annotate_cudf_polars(message="Rolling")
    def do_evaluate(
        cls,
        index: expr.NamedExpr,
        preceding: plc.Scalar,
        following: plc.Scalar,
        closed_window: ClosedInterval,
        keys_in: Sequence[expr.NamedExpr],
        aggs: Sequence[expr.NamedExpr],
        zlice: Zlice | None,
        df: DataFrame,
    ) -> DataFrame:
        """Evaluate and return a dataframe."""
        keys = broadcast(*(k.evaluate(df) for k in keys_in), target_length=df.num_rows)
        orderby = index.evaluate(df)
        # Polars casts integral orderby to int64, but only for calculating window bounds
        if (
            plc.traits.is_integral(orderby.obj.type())
            and orderby.obj.type().id() != plc.TypeId.INT64
        ):
            orderby_obj = plc.unary.cast(orderby.obj, plc.DataType(plc.TypeId.INT64))
        else:
            orderby_obj = orderby.obj
        preceding_window, following_window = range_window_bounds(
            preceding, following, closed_window
        )
        if orderby.obj.null_count() != 0:
            raise RuntimeError(
                f"Index column '{index.name}' in rolling may not contain nulls"
            )
        if len(keys_in) > 0:
            # Must always check sortedness
            table = plc.Table([*(k.obj for k in keys), orderby_obj])
            n = table.num_columns()
            if not plc.sorting.is_sorted(
                table, [plc.types.Order.ASCENDING] * n, [plc.types.NullOrder.BEFORE] * n
            ):
                raise RuntimeError("Input for grouped rolling is not sorted")
        else:
            if not orderby.check_sorted(
                order=plc.types.Order.ASCENDING, null_order=plc.types.NullOrder.BEFORE
            ):
                raise RuntimeError(
                    f"Index column '{index.name}' in rolling is not sorted, please sort first"
                )
        values = plc.rolling.grouped_range_rolling_window(
            plc.Table([k.obj for k in keys]),
            orderby_obj,
            plc.types.Order.ASCENDING,  # Polars requires ascending orderby.
            plc.types.NullOrder.BEFORE,  # Doesn't matter, polars doesn't allow nulls in orderby
            preceding_window,
            following_window,
            [rolling.to_request(request.value, orderby, df) for request in aggs],
        )
        return DataFrame(
            itertools.chain(
                keys,
                [orderby],
                (
                    Column(col, name=request.name, dtype=request.value.dtype)
                    for col, request in zip(values.columns(), aggs, strict=True)
                ),
            )
        ).slice(zlice)

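A quick orientation example, not from the package source: the Rolling node above evaluates polars rolling aggregations over a sorted index column. The sketch below assumes a recent polars where LazyFrame.rolling accepts a group_by argument and where collect(engine="gpu") routes the plan through cudf-polars; unsupported cases are expected to fall back to the CPU engine.

import polars as pl

lf = pl.LazyFrame(
    {
        "g": ["a", "a", "b", "b"],  # grouping keys, pre-sorted
        "t": [1, 2, 1, 3],          # integer orderby column, sorted within groups
        "x": [1.0, 2.0, 3.0, 4.0],
    }
)
out = (
    lf.rolling(index_column="t", period="2i", group_by="g")
    .agg(pl.col("x").sum().alias("x_sum"))
    .collect(engine="gpu")  # assumption: GPU engine available; otherwise omit
)
print(out)
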
class GroupBy(IR):
    """Perform a groupby."""

    __slots__ = (
        "agg_requests",
        "keys",
        "maintain_order",
        "zlice",
    )
    _non_child = (
        "schema",
        "keys",
        "agg_requests",
        "maintain_order",
        "zlice",
    )
    keys: tuple[expr.NamedExpr, ...]
    """Grouping keys."""
    agg_requests: tuple[expr.NamedExpr, ...]
    """Aggregation expressions."""
    maintain_order: bool
    """Preserve order in groupby."""
    zlice: Zlice | None
    """Optional slice to apply after grouping."""

    def __init__(
        self,
        schema: Schema,
        keys: Sequence[expr.NamedExpr],
        agg_requests: Sequence[expr.NamedExpr],
        maintain_order: bool,  # noqa: FBT001
        zlice: Zlice | None,
        df: IR,
    ):
        self.schema = schema
        self.keys = tuple(keys)
        for request in agg_requests:
            expr = request.value
            if isinstance(expr, unary.UnaryFunction) and expr.name == "value_counts":
                raise NotImplementedError("value_counts is not supported in groupby")
            if any(
                isinstance(child, unary.UnaryFunction) and child.name == "value_counts"
                for child in expr.children
            ):
                raise NotImplementedError("value_counts is not supported in groupby")
        self.agg_requests = tuple(agg_requests)
        self.maintain_order = maintain_order
        self.zlice = zlice
        self.children = (df,)
        self._non_child_args = (
            schema,
            self.keys,
            self.agg_requests,
            maintain_order,
            self.zlice,
        )

    @classmethod
    @nvtx_annotate_cudf_polars(message="GroupBy")
    def do_evaluate(
        cls,
        schema: Schema,
        keys_in: Sequence[expr.NamedExpr],
        agg_requests: Sequence[expr.NamedExpr],
        maintain_order: bool,  # noqa: FBT001
        zlice: Zlice | None,
        df: DataFrame,
    ) -> DataFrame:
        """Evaluate and return a dataframe."""
        keys = broadcast(*(k.evaluate(df) for k in keys_in), target_length=df.num_rows)
        sorted = (
            plc.types.Sorted.YES
            if all(k.is_sorted for k in keys)
            else plc.types.Sorted.NO
        )
        grouper = plc.groupby.GroupBy(
            plc.Table([k.obj for k in keys]),
            null_handling=plc.types.NullPolicy.INCLUDE,
            keys_are_sorted=sorted,
            column_order=[k.order for k in keys],
            null_precedence=[k.null_order for k in keys],
        )
        requests = []
        names = []
        for request in agg_requests:
            name = request.name
            value = request.value
            if isinstance(value, expr.Len):
                # A count aggregation, we need a column so use a key column
                col = keys[0].obj
            elif isinstance(value, expr.Agg):
                if value.name == "quantile":
                    child = value.children[0]
                else:
                    (child,) = value.children
                col = child.evaluate(df, context=ExecutionContext.GROUPBY).obj
            else:
                # Anything else, we pre-evaluate
                col = value.evaluate(df, context=ExecutionContext.GROUPBY).obj
            requests.append(plc.groupby.GroupByRequest(col, [value.agg_request]))
            names.append(name)
        group_keys, raw_tables = grouper.aggregate(requests)
        results = [
            Column(column, name=name, dtype=schema[name])
            for name, column, request in zip(
                names,
                itertools.chain.from_iterable(t.columns() for t in raw_tables),
                agg_requests,
                strict=True,
            )
        ]
        result_keys = [
            Column(grouped_key, name=key.name, dtype=key.dtype)
            for key, grouped_key in zip(keys, group_keys.columns(), strict=True)
        ]
        broadcasted = broadcast(*result_keys, *results)
        # Handle order preservation of groups
        if maintain_order and not sorted:
            # The order we want
            want = plc.stream_compaction.stable_distinct(
                plc.Table([k.obj for k in keys]),
                list(range(group_keys.num_columns())),
                plc.stream_compaction.DuplicateKeepOption.KEEP_FIRST,
                plc.types.NullEquality.EQUAL,
                plc.types.NanEquality.ALL_EQUAL,
            )
            # The order we have
            have = plc.Table([key.obj for key in broadcasted[: len(keys)]])

            # We know an inner join is OK because by construction
            # want and have are permutations of each other.
            left_order, right_order = plc.join.inner_join(
                want, have, plc.types.NullEquality.EQUAL
            )
            # Now left_order is an arbitrary permutation of the ordering we
            # want, and right_order is a matching permutation of the ordering
            # we have. To get to the original ordering, we need
            # left_order == iota(nrows), with right_order permuted
            # appropriately. This can be obtained by sorting
            # right_order by left_order.
            (right_order,) = plc.sorting.sort_by_key(
                plc.Table([right_order]),
                plc.Table([left_order]),
                [plc.types.Order.ASCENDING],
                [plc.types.NullOrder.AFTER],
            ).columns()
            ordered_table = plc.copying.gather(
                plc.Table([col.obj for col in broadcasted]),
                right_order,
                plc.copying.OutOfBoundsPolicy.DONT_CHECK,
            )
            broadcasted = [
                Column(reordered, name=old.name, dtype=old.dtype)
                for reordered, old in zip(
                    ordered_table.columns(), broadcasted, strict=True
                )
            ]
        return DataFrame(broadcasted).slice(zlice)

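The maintain_order branch above recovers first-occurrence group order by inner-joining the wanted key order against the produced key order, then sorting right_order by left_order. The following pure-Python sketch of that trick uses a hypothetical restore_order helper on plain lists; it is illustration only, not package code.

def restore_order(want: list, have: list) -> list[int]:
    """Gather map that reorders `have` into `want` (both are permutations)."""
    position_in_want = {key: i for i, key in enumerate(want)}
    # Pair each row of `have` with its target position and sort by it: the
    # list analogue of sorting right_order by left_order in the code above.
    return [i for _, i in sorted((position_in_want[k], i) for i, k in enumerate(have))]


want = ["b", "a", "c"]  # first-occurrence order of the group keys
have = ["a", "b", "c"]  # order the hash groupby happened to produce
gather = restore_order(want, have)
assert [have[i] for i in gather] == want
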
class ConditionalJoin(IR):
    """A conditional inner join of two dataframes on a predicate."""

    class Predicate:
        """Serializable wrapper for a predicate expression."""

        predicate: expr.Expr
        ast: plc.expressions.Expression

        def __init__(self, predicate: expr.Expr):
            self.predicate = predicate
            self.ast = to_ast(predicate)

        def __reduce__(self) -> tuple[Any, ...]:
            """Pickle a Predicate object."""
            return (type(self), (self.predicate,))

    __slots__ = ("ast_predicate", "options", "predicate")
    _non_child = ("schema", "predicate", "options")
    predicate: expr.Expr
    """Expression predicate to join on"""
    options: tuple[
        tuple[
            str,
            pl_expr.Operator | Iterable[pl_expr.Operator],
        ],
        bool,
        Zlice | None,
        str,
        bool,
        Literal["none", "left", "right", "left_right", "right_left"],
    ]
    """
    tuple of options:
      - predicates: tuple of ir join type (eg. ie_join) and (In)Equality conditions
      - nulls_equal: do nulls compare equal?
      - slice: optional slice to perform after joining.
      - suffix: string suffix for right columns if names match
      - coalesce: should key columns be coalesced (only makes sense for outer joins)
      - maintain_order: which DataFrame row order to preserve, if any
    """

    def __init__(
        self, schema: Schema, predicate: expr.Expr, options: tuple, left: IR, right: IR
    ) -> None:
        self.schema = schema
        self.predicate = predicate
        self.options = options
        self.children = (left, right)
        predicate_wrapper = self.Predicate(predicate)
        _, nulls_equal, zlice, suffix, coalesce, maintain_order = self.options
        # Preconditions from polars
        assert not nulls_equal
        assert not coalesce
        assert maintain_order == "none"
        if predicate_wrapper.ast is None:
            raise NotImplementedError(
                f"Conditional join with predicate {predicate}"
            )  # pragma: no cover; polars never delivers expressions we can't handle
        self._non_child_args = (predicate_wrapper, zlice, suffix, maintain_order)

    @classmethod
    @nvtx_annotate_cudf_polars(message="ConditionalJoin")
    def do_evaluate(
        cls,
        predicate_wrapper: Predicate,
        zlice: Zlice | None,
        suffix: str,
        maintain_order: Literal["none", "left", "right", "left_right", "right_left"],
        left: DataFrame,
        right: DataFrame,
    ) -> DataFrame:
        """Evaluate and return a dataframe."""
        lg, rg = plc.join.conditional_inner_join(
            left.table,
            right.table,
            predicate_wrapper.ast,
        )
        left = DataFrame.from_table(
            plc.copying.gather(
                left.table, lg, plc.copying.OutOfBoundsPolicy.DONT_CHECK
            ),
            left.column_names,
            left.dtypes,
        )
        right = DataFrame.from_table(
            plc.copying.gather(
                right.table, rg, plc.copying.OutOfBoundsPolicy.DONT_CHECK
            ),
            right.column_names,
            right.dtypes,
        )
        right = right.rename_columns(
            {
                name: f"{name}{suffix}"
                for name in right.column_names
                if name in left.column_names_set
            }
        )
        result = left.with_columns(right.columns)
        return result.slice(zlice)

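ConditionalJoin.Predicate above pickles only the original expression and rebuilds the pylibcudf AST when unpickled, via __reduce__. Below is a standalone sketch of the same pattern with a hypothetical Wrapper class; the compiled regex merely stands in for a derived member that should be rebuilt on load.

import pickle
import re


class Wrapper:
    """Hypothetical stand-in: keeps a constructor arg plus a derived member."""

    def __init__(self, pattern: str):
        self.pattern = pattern
        self.compiled = re.compile(pattern)  # derived; rebuilt on unpickling

    def __reduce__(self):
        # Serialise only the constructor argument, exactly like Predicate.
        return (type(self), (self.pattern,))


roundtripped = pickle.loads(pickle.dumps(Wrapper("ab+")))
assert roundtripped.compiled.match("abb") is not None
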
class Join(IR):
    """A join of two dataframes."""

    __slots__ = ("left_on", "options", "right_on")
    _non_child = ("schema", "left_on", "right_on", "options")
    left_on: tuple[expr.NamedExpr, ...]
    """List of expressions used as keys in the left frame."""
    right_on: tuple[expr.NamedExpr, ...]
    """List of expressions used as keys in the right frame."""
    options: tuple[
        Literal["Inner", "Left", "Right", "Full", "Semi", "Anti", "Cross"],
        bool,
        Zlice | None,
        str,
        bool,
        Literal["none", "left", "right", "left_right", "right_left"],
    ]
    """
    tuple of options:
      - how: join type
      - nulls_equal: do nulls compare equal?
      - slice: optional slice to perform after joining.
      - suffix: string suffix for right columns if names match
      - coalesce: should key columns be coalesced (only makes sense for outer joins)
      - maintain_order: which DataFrame row order to preserve, if any
    """

    def __init__(
        self,
        schema: Schema,
        left_on: Sequence[expr.NamedExpr],
        right_on: Sequence[expr.NamedExpr],
        options: Any,
        left: IR,
        right: IR,
    ):
        self.schema = schema
        self.left_on = tuple(left_on)
        self.right_on = tuple(right_on)
        self.options = options
        self.children = (left, right)
        self._non_child_args = (self.left_on, self.right_on, self.options)
        # TODO: Implement maintain_order
        if options[5] != "none":
            raise NotImplementedError("maintain_order not implemented yet")

    @staticmethod
    @cache
    def _joiners(
        how: Literal["Inner", "Left", "Right", "Full", "Semi", "Anti"],
    ) -> tuple[
        Callable, plc.copying.OutOfBoundsPolicy, plc.copying.OutOfBoundsPolicy | None
    ]:
        if how == "Inner":
            return (
                plc.join.inner_join,
                plc.copying.OutOfBoundsPolicy.DONT_CHECK,
                plc.copying.OutOfBoundsPolicy.DONT_CHECK,
            )
        elif how == "Left" or how == "Right":
            return (
                plc.join.left_join,
                plc.copying.OutOfBoundsPolicy.DONT_CHECK,
                plc.copying.OutOfBoundsPolicy.NULLIFY,
            )
        elif how == "Full":
            return (
                plc.join.full_join,
                plc.copying.OutOfBoundsPolicy.NULLIFY,
                plc.copying.OutOfBoundsPolicy.NULLIFY,
            )
        elif how == "Semi":
            return (
                plc.join.left_semi_join,
                plc.copying.OutOfBoundsPolicy.DONT_CHECK,
                None,
            )
        elif how == "Anti":
            return (
                plc.join.left_anti_join,
                plc.copying.OutOfBoundsPolicy.DONT_CHECK,
                None,
            )
        assert_never(how)  # pragma: no cover

    @staticmethod
    def _reorder_maps(
        left_rows: int,
        lg: plc.Column,
        left_policy: plc.copying.OutOfBoundsPolicy,
        right_rows: int,
        rg: plc.Column,
        right_policy: plc.copying.OutOfBoundsPolicy,
    ) -> list[plc.Column]:
        """
        Reorder gather maps to satisfy polars join order restrictions.

        Parameters
        ----------
        left_rows
            Number of rows in left table
        lg
            Left gather map
        left_policy
            Nullify policy for left map
        right_rows
            Number of rows in right table
        rg
            Right gather map
        right_policy
            Nullify policy for right map

        Returns
        -------
        list of reordered left and right gather maps.

        Notes
        -----
        For a left join, the polars result preserves the order of the
        left keys, and is stable wrt the right keys. For all other
        joins, there is no order obligation.
        """
        init = plc.Scalar.from_py(0, plc.types.SIZE_TYPE)
        step = plc.Scalar.from_py(1, plc.types.SIZE_TYPE)
        left_order = plc.copying.gather(
            plc.Table([plc.filling.sequence(left_rows, init, step)]), lg, left_policy
        )
        right_order = plc.copying.gather(
            plc.Table([plc.filling.sequence(right_rows, init, step)]), rg, right_policy
        )
        return plc.sorting.stable_sort_by_key(
            plc.Table([lg, rg]),
            plc.Table([*left_order.columns(), *right_order.columns()]),
            [plc.types.Order.ASCENDING, plc.types.Order.ASCENDING],
            [plc.types.NullOrder.AFTER, plc.types.NullOrder.AFTER],
        ).columns()

    @staticmethod
    def _build_columns(
        columns: Iterable[plc.Column],
        template: Iterable[NamedColumn],
        *,
        left: bool = True,
        empty: bool = False,
        rename: Callable[[str], str] = lambda name: name,
    ) -> list[Column]:
        if empty:
            return [
                Column(
                    plc.column_factories.make_empty_column(col.dtype.plc),
                    col.dtype,
                    name=rename(col.name),
                )
                for col in template
            ]

        columns = [
            Column(new, col.dtype, name=rename(col.name))
            for new, col in zip(columns, template, strict=True)
        ]

        if left:
            columns = [
                col.sorted_like(orig)
                for col, orig in zip(columns, template, strict=True)
            ]

        return columns

    @classmethod
    @nvtx_annotate_cudf_polars(message="Join")
    def do_evaluate(
        cls,
        left_on_exprs: Sequence[expr.NamedExpr],
        right_on_exprs: Sequence[expr.NamedExpr],
        options: tuple[
            Literal["Inner", "Left", "Right", "Full", "Semi", "Anti", "Cross"],
            bool,
            Zlice | None,
            str,
            bool,
            Literal["none", "left", "right", "left_right", "right_left"],
        ],
        left: DataFrame,
        right: DataFrame,
    ) -> DataFrame:
        """Evaluate and return a dataframe."""
        how, nulls_equal, zlice, suffix, coalesce, _ = options
        if how == "Cross":
            # Separate implementation, since cross_join returns the
            # result, not the gather maps
            if right.num_rows == 0:
                left_cols = Join._build_columns([], left.columns, empty=True)
                right_cols = Join._build_columns(
                    [],
                    right.columns,
                    left=False,
                    empty=True,
                    rename=lambda name: name
                    if name not in left.column_names_set
                    else f"{name}{suffix}",
                )
                return DataFrame([*left_cols, *right_cols])

            columns = plc.join.cross_join(left.table, right.table).columns()
            left_cols = Join._build_columns(
                columns[: left.num_columns],
                left.columns,
            )
            right_cols = Join._build_columns(
                columns[left.num_columns :],
                right.columns,
                rename=lambda name: name
                if name not in left.column_names_set
                else f"{name}{suffix}",
                left=False,
            )
            return DataFrame([*left_cols, *right_cols]).slice(zlice)
        # TODO: Waiting on clarity based on https://github.com/pola-rs/polars/issues/17184
        left_on = DataFrame(broadcast(*(e.evaluate(left) for e in left_on_exprs)))
        right_on = DataFrame(broadcast(*(e.evaluate(right) for e in right_on_exprs)))
        null_equality = (
            plc.types.NullEquality.EQUAL
            if nulls_equal
            else plc.types.NullEquality.UNEQUAL
        )
        join_fn, left_policy, right_policy = cls._joiners(how)
        if right_policy is None:
            # Semi join
            lg = join_fn(left_on.table, right_on.table, null_equality)
            table = plc.copying.gather(left.table, lg, left_policy)
            result = DataFrame.from_table(table, left.column_names, left.dtypes)
        else:
            if how == "Right":
                # Right join is a left join with the tables swapped
                left, right = right, left
                left_on, right_on = right_on, left_on
            lg, rg = join_fn(left_on.table, right_on.table, null_equality)
            if how == "Left" or how == "Right":
                # Order of left table is preserved
                lg, rg = cls._reorder_maps(
                    left.num_rows, lg, left_policy, right.num_rows, rg, right_policy
                )
            if coalesce:
                if how == "Full":
                    # In this case, keys must be column references,
                    # possibly with dtype casting. We should use them in
                    # preference to the columns from the original tables.
                    left = left.with_columns(left_on.columns, replace_only=True)
                    right = right.with_columns(right_on.columns, replace_only=True)
                else:
                    right = right.discard_columns(right_on.column_names_set)
            left = DataFrame.from_table(
                plc.copying.gather(left.table, lg, left_policy),
                left.column_names,
                left.dtypes,
            )
            right = DataFrame.from_table(
                plc.copying.gather(right.table, rg, right_policy),
                right.column_names,
                right.dtypes,
            )
            if coalesce and how == "Full":
                left = left.with_columns(
                    (
                        Column(
                            plc.replace.replace_nulls(left_col.obj, right_col.obj),
                            name=left_col.name,
                            dtype=left_col.dtype,
                        )
                        for left_col, right_col in zip(
                            left.select_columns(left_on.column_names_set),
                            right.select_columns(right_on.column_names_set),
                            strict=True,
                        )
                    ),
                    replace_only=True,
                )
                right = right.discard_columns(right_on.column_names_set)
            if how == "Right":
                # Undo the swap for right join before gluing together.
                left, right = right, left
            right = right.rename_columns(
                {
                    name: f"{name}{suffix}"
                    for name in right.column_names
                    if name in left.column_names_set
                }
            )
            result = left.with_columns(right.columns)
        return result.slice(zlice)

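Join._reorder_maps above restores the row ordering polars requires for left joins by stable-sorting the gather maps on the original left and right row positions. The list-based sketch below (a hypothetical reorder_maps helper, illustration only) shows the idea; it ignores the NULLIFY case in which a gather map can contain nulls.

def reorder_maps(lg: list[int], rg: list[int]) -> tuple[list[int], list[int]]:
    """Stable-sort the joined (left, right) row pairs by their original positions."""
    pairs = sorted(zip(lg, rg), key=lambda p: (p[0], p[1]))
    return [left for left, _ in pairs], [right for _, right in pairs]


# Gather maps as a hash join might return them, in arbitrary order.
lg, rg = [2, 0, 2, 1], [0, 1, 3, 2]
lg2, rg2 = reorder_maps(lg, rg)
assert lg2 == [0, 1, 2, 2]  # left row order preserved
assert rg2 == [1, 2, 0, 3]  # ties broken by right row position
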
class HStack(IR):
    """Add new columns to a dataframe."""

    __slots__ = ("columns", "should_broadcast")
    _non_child = ("schema", "columns", "should_broadcast")
    should_broadcast: bool
    """Should the resulting evaluated columns be broadcast to the same length."""

    def __init__(
        self,
        schema: Schema,
        columns: Sequence[expr.NamedExpr],
        should_broadcast: bool,  # noqa: FBT001
        df: IR,
    ):
        self.schema = schema
        self.columns = tuple(columns)
        self.should_broadcast = should_broadcast
        self._non_child_args = (self.columns, self.should_broadcast)
        self.children = (df,)

    @classmethod
    @nvtx_annotate_cudf_polars(message="HStack")
    def do_evaluate(
        cls,
        exprs: Sequence[expr.NamedExpr],
        should_broadcast: bool,  # noqa: FBT001
        df: DataFrame,
    ) -> DataFrame:
        """Evaluate and return a dataframe."""
        columns = [c.evaluate(df) for c in exprs]
        if should_broadcast:
            columns = broadcast(
                *columns, target_length=df.num_rows if df.num_columns != 0 else None
            )
        else:
            # Polars ensures this is true, but let's make sure nothing
            # went wrong. In this case, the parent node is guaranteed
            # to be a Select, which will take care of making sure that
            # everything is the same length. The result table that
            # might have mismatching column lengths will never be
            # turned into a pylibcudf Table with all columns by the
            # Select, which is why this is safe.
            assert all(e.name.startswith("__POLARS_CSER_0x") for e in exprs)
        return df.with_columns(columns)

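HStack is roughly what a polars with_columns call lowers to, with scalar expressions broadcast to the frame's length. A hedged usage sketch follows; engine="gpu" assumes a polars build with GPU support and this package installed, otherwise drop the argument.

import polars as pl

lf = pl.LazyFrame({"a": [1, 2, 3]})
out = lf.with_columns(
    b=pl.col("a") * 2,  # same length as the input frame
    c=pl.lit(10),       # length-1 literal, broadcast to the frame's length
).collect(engine="gpu")  # assumption: GPU engine available; otherwise omit
print(out)
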
class Distinct(IR):
    """Produce a new dataframe with distinct rows."""

    __slots__ = ("keep", "stable", "subset", "zlice")
    _non_child = ("schema", "keep", "subset", "zlice", "stable")
    keep: plc.stream_compaction.DuplicateKeepOption
    """Which distinct value to keep."""
    subset: frozenset[str] | None
    """Which columns should be used to define distinctness. If None,
    then all columns are used."""
    zlice: Zlice | None
    """Optional slice to apply to the result."""
    stable: bool
    """Should the result maintain ordering."""

    def __init__(
        self,
        schema: Schema,
        keep: plc.stream_compaction.DuplicateKeepOption,
        subset: frozenset[str] | None,
        zlice: Zlice | None,
        stable: bool,  # noqa: FBT001
        df: IR,
    ):
        self.schema = schema
        self.keep = keep
        self.subset = subset
        self.zlice = zlice
        self.stable = stable
        self._non_child_args = (keep, subset, zlice, stable)
        self.children = (df,)

    _KEEP_MAP: ClassVar[dict[str, plc.stream_compaction.DuplicateKeepOption]] = {
        "first": plc.stream_compaction.DuplicateKeepOption.KEEP_FIRST,
        "last": plc.stream_compaction.DuplicateKeepOption.KEEP_LAST,
        "none": plc.stream_compaction.DuplicateKeepOption.KEEP_NONE,
        "any": plc.stream_compaction.DuplicateKeepOption.KEEP_ANY,
    }

    @classmethod
    @nvtx_annotate_cudf_polars(message="Distinct")
    def do_evaluate(
        cls,
        keep: plc.stream_compaction.DuplicateKeepOption,
        subset: frozenset[str] | None,
        zlice: Zlice | None,
        stable: bool,  # noqa: FBT001
        df: DataFrame,
    ) -> DataFrame:
        """Evaluate and return a dataframe."""
        if subset is None:
            indices = list(range(df.num_columns))
            keys_sorted = all(c.is_sorted for c in df.column_map.values())
        else:
            indices = [i for i, k in enumerate(df.column_names) if k in subset]
            keys_sorted = all(df.column_map[name].is_sorted for name in subset)
        if keys_sorted:
            table = plc.stream_compaction.unique(
                df.table,
                indices,
                keep,
                plc.types.NullEquality.EQUAL,
            )
        else:
            distinct = (
                plc.stream_compaction.stable_distinct
                if stable
                else plc.stream_compaction.distinct
            )
            table = distinct(
                df.table,
                indices,
                keep,
                plc.types.NullEquality.EQUAL,
                plc.types.NanEquality.ALL_EQUAL,
            )
        # TODO: Is this sortedness setting correct
        result = DataFrame(
            [
                Column(new, name=old.name, dtype=old.dtype).sorted_like(old)
                for new, old in zip(table.columns(), df.columns, strict=True)
            ]
        )
        if keys_sorted or stable:
            result = result.sorted_like(df)
        return result.slice(zlice)

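The keep, subset and stable fields of Distinct line up with the arguments of polars unique(); _KEEP_MAP above is the translation table for keep. A hedged sketch of a query that exercises them (engine="gpu" is again an assumption about the installed build):

import polars as pl

lf = pl.LazyFrame({"k": [1, 1, 2], "v": ["x", "y", "z"]})
out = (
    lf.unique(subset=["k"], keep="first", maintain_order=True)
    .collect(engine="gpu")  # assumption: GPU engine available; otherwise omit
)
print(out)
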
class Sort(IR):
    """Sort a dataframe."""

    __slots__ = ("by", "null_order", "order", "stable", "zlice")
    _non_child = ("schema", "by", "order", "null_order", "stable", "zlice")
    by: tuple[expr.NamedExpr, ...]
    """Sort keys."""
    order: tuple[plc.types.Order, ...]
    """Sort order for each sort key."""
    null_order: tuple[plc.types.NullOrder, ...]
    """Null sorting location for each sort key."""
    stable: bool
    """Should the sort be stable?"""
    zlice: Zlice | None
    """Optional slice to apply to the result."""

    def __init__(
        self,
        schema: Schema,
        by: Sequence[expr.NamedExpr],
        order: Sequence[plc.types.Order],
        null_order: Sequence[plc.types.NullOrder],
        stable: bool,  # noqa: FBT001
        zlice: Zlice | None,
        df: IR,
    ):
        self.schema = schema
        self.by = tuple(by)
        self.order = tuple(order)
        self.null_order = tuple(null_order)
        self.stable = stable
        self.zlice = zlice
        self._non_child_args = (
            self.by,
            self.order,
            self.null_order,
            self.stable,
            self.zlice,
        )
        self.children = (df,)

    @classmethod
    @nvtx_annotate_cudf_polars(message="Sort")
    def do_evaluate(
        cls,
        by: Sequence[expr.NamedExpr],
        order: Sequence[plc.types.Order],
        null_order: Sequence[plc.types.NullOrder],
        stable: bool,  # noqa: FBT001
        zlice: Zlice | None,
        df: DataFrame,
    ) -> DataFrame:
        """Evaluate and return a dataframe."""
        sort_keys = broadcast(*(k.evaluate(df) for k in by), target_length=df.num_rows)
        do_sort = plc.sorting.stable_sort_by_key if stable else plc.sorting.sort_by_key
        table = do_sort(
            df.table,
            plc.Table([k.obj for k in sort_keys]),
            list(order),
            list(null_order),
        )
        result = DataFrame.from_table(table, df.column_names, df.dtypes)
        first_key = sort_keys[0]
        name = by[0].name
        first_key_in_result = (
            name in df.column_map and first_key.obj is df.column_map[name].obj
        )
        if first_key_in_result:
            result.column_map[name].set_sorted(
                is_sorted=plc.types.Sorted.YES, order=order[0], null_order=null_order[0]
            )
        return result.slice(zlice)

class Slice(IR):
    """Slice a dataframe."""

    __slots__ = ("length", "offset")
    _non_child = ("schema", "offset", "length")
    offset: int
    """Start of the slice."""
    length: int | None
    """Length of the slice."""

    def __init__(self, schema: Schema, offset: int, length: int | None, df: IR):
        self.schema = schema
        self.offset = offset
        self.length = length
        self._non_child_args = (offset, length)
        self.children = (df,)

    @classmethod
    @nvtx_annotate_cudf_polars(message="Slice")
    def do_evaluate(cls, offset: int, length: int, df: DataFrame) -> DataFrame:
        """Evaluate and return a dataframe."""
        return df.slice((offset, length))

class Filter(IR):
    """Filter a dataframe with a boolean mask."""

    __slots__ = ("mask",)
    _non_child = ("schema", "mask")
    mask: expr.NamedExpr
    """Expression to produce the filter mask."""

    def __init__(self, schema: Schema, mask: expr.NamedExpr, df: IR):
        self.schema = schema
        self.mask = mask
        self._non_child_args = (mask,)
        self.children = (df,)

    @classmethod
    @nvtx_annotate_cudf_polars(message="Filter")
    def do_evaluate(cls, mask_expr: expr.NamedExpr, df: DataFrame) -> DataFrame:
        """Evaluate and return a dataframe."""
        (mask,) = broadcast(mask_expr.evaluate(df), target_length=df.num_rows)
        return df.filter(mask)

class Projection(IR):
    """Select a subset of columns from a dataframe."""

    __slots__ = ()
    _non_child = ("schema",)

    def __init__(self, schema: Schema, df: IR):
        self.schema = schema
        self._non_child_args = (schema,)
        self.children = (df,)

    @classmethod
    @nvtx_annotate_cudf_polars(message="Projection")
    def do_evaluate(cls, schema: Schema, df: DataFrame) -> DataFrame:
        """Evaluate and return a dataframe."""
        # This can reorder things.
        columns = broadcast(
            *(df.column_map[name] for name in schema), target_length=df.num_rows
        )
        return DataFrame(columns)

class MergeSorted(IR):
    """Merge sorted operation."""

    __slots__ = ("key",)
    _non_child = ("schema", "key")
    key: str
    """Key that is sorted."""

    def __init__(self, schema: Schema, key: str, left: IR, right: IR):
        # Children must be Sort or Repartition(Sort).
        # The Repartition(Sort) case happens during fallback.
        left_sort_child = left if isinstance(left, Sort) else left.children[0]
        right_sort_child = right if isinstance(right, Sort) else right.children[0]
        assert isinstance(left_sort_child, Sort)
        assert isinstance(right_sort_child, Sort)
        assert left_sort_child.order == right_sort_child.order
        assert len(left.schema.keys()) <= len(right.schema.keys())
        self.schema = schema
        self.key = key
        self.children = (left, right)
        self._non_child_args = (key,)

    @classmethod
    @nvtx_annotate_cudf_polars(message="MergeSorted")
    def do_evaluate(cls, key: str, *dfs: DataFrame) -> DataFrame:
        """Evaluate and return a dataframe."""
        left, right = dfs
        right = right.discard_columns(right.column_names_set - left.column_names_set)
        on_col_left = left.select_columns({key})[0]
        on_col_right = right.select_columns({key})[0]
        return DataFrame.from_table(
            plc.merge.merge(
                [right.table, left.table],
                [left.column_names.index(key), right.column_names.index(key)],
                [on_col_left.order, on_col_right.order],
                [on_col_left.null_order, on_col_right.null_order],
            ),
            left.column_names,
            left.dtypes,
        )

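MergeSorted evaluates polars merge_sorted, whose inputs must both already be sorted on the key column, which is what the Sort-children assertions above enforce. A hedged usage sketch (engine="gpu" assumes GPU support is installed):

import polars as pl

left = pl.LazyFrame({"key": [1, 3, 5], "v": ["a", "b", "c"]}).sort("key")
right = pl.LazyFrame({"key": [2, 4, 6], "v": ["d", "e", "f"]}).sort("key")
out = (
    left.merge_sorted(right, key="key")
    .collect(engine="gpu")  # assumption: GPU engine available; otherwise omit
)
print(out)
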
class MapFunction(IR):
    """Apply some function to a dataframe."""

    __slots__ = ("name", "options")
    _non_child = ("schema", "name", "options")
    name: str
    """Name of the function to apply"""
    options: Any
    """Arbitrary name-specific options"""

    _NAMES: ClassVar[frozenset[str]] = frozenset(
        [
            "rechunk",
            "rename",
            "explode",
            "unpivot",
            "row_index",
            "fast_count",
        ]
    )

    def __init__(self, schema: Schema, name: str, options: Any, df: IR):
        self.schema = schema
        self.name = name
        self.options = options
        self.children = (df,)
        if (
            self.name not in MapFunction._NAMES
        ):  # pragma: no cover; need more polars rust functions
            raise NotImplementedError(
                f"Unhandled map function {self.name}"
            )  # pragma: no cover
        if self.name == "explode":
            (to_explode,) = self.options
            if len(to_explode) > 1:
                # TODO: straightforward, but need to error check
                # polars requires that all to-explode columns have the
                # same sub-shapes
                raise NotImplementedError("Explode with more than one column")
            self.options = (tuple(to_explode),)
        elif POLARS_VERSION_LT_131 and self.name == "rename":  # pragma: no cover
            # As of 1.31, polars validates renaming in the IR
            old, new, strict = self.options
            if len(new) != len(set(new)) or (
                set(new) & (set(df.schema.keys()) - set(old))
            ):
                raise NotImplementedError(
                    "Duplicate new names in rename."
                )  # pragma: no cover
            self.options = (tuple(old), tuple(new), strict)
        elif self.name == "unpivot":
            indices, pivotees, variable_name, value_name = self.options
            value_name = "value" if value_name is None else value_name
            variable_name = "variable" if variable_name is None else variable_name
            if len(pivotees) == 0:
                index = frozenset(indices)
                pivotees = [name for name in df.schema if name not in index]
            if not all(
                dtypes.can_cast(df.schema[p].plc, self.schema[value_name].plc)
                for p in pivotees
            ):
                raise NotImplementedError(
                    "Unpivot cannot cast all input columns to "
                    f"{self.schema[value_name].id()}"
                )  # pragma: no cover
            self.options = (
                tuple(indices),
                tuple(pivotees),
                variable_name,
                value_name,
            )
        elif self.name == "row_index":
            col_name, offset = options
            self.options = (col_name, offset)
        elif self.name == "fast_count":
            # TODO: Remove this once all scan types support projections
            # using Select + Len. Currently, CSV is the only format that
            # uses the legacy MapFunction(FastCount) path because it is
            # faster than the new-streaming path for large files.
            # See https://github.com/pola-rs/polars/pull/22363#issue-3010224808
            raise NotImplementedError(
                "Fast count unsupported for CSV scans"
            )  # pragma: no cover
        self._non_child_args = (schema, name, self.options)

    def get_hashable(self) -> Hashable:
        """
        Hashable representation of the node.

        The options dictionaries are serialised for hashing purposes
        as json strings.
        """
        return (
            type(self),
            self.name,
            json.dumps(self.options),
            tuple(self.schema.items()),
            self._ctor_arguments(self.children)[1:],
        )

    @classmethod
    @nvtx_annotate_cudf_polars(message="MapFunction")
    def do_evaluate(
        cls, schema: Schema, name: str, options: Any, df: DataFrame
    ) -> DataFrame:
        """Evaluate and return a dataframe."""
        if name == "rechunk":
            # No-op in our data model
            # Don't think this appears in a plan tree from python
            return df  # pragma: no cover
        elif POLARS_VERSION_LT_131 and name == "rename":  # pragma: no cover
            # final tag is "swapping" which is useful for the
            # optimiser (it blocks some pushdown operations)
            old, new, _ = options
            return df.rename_columns(dict(zip(old, new, strict=True)))
        elif name == "explode":
            ((to_explode,),) = options
            index = df.column_names.index(to_explode)
            subset = df.column_names_set - {to_explode}
            return DataFrame.from_table(
                plc.lists.explode_outer(df.table, index), df.column_names, df.dtypes
            ).sorted_like(df, subset=subset)
        elif name == "unpivot":
            (
                indices,
                pivotees,
                variable_name,
                value_name,
            ) = options
            npiv = len(pivotees)
            selected = df.select(indices)
            index_columns = [
                Column(tiled, name=name, dtype=old.dtype)
                for tiled, name, old in zip(
                    plc.reshape.tile(selected.table, npiv).columns(),
                    indices,
                    selected.columns,
                    strict=True,
                )
            ]
            (variable_column,) = plc.filling.repeat(
                plc.Table(
                    [
                        plc.Column.from_arrow(
                            pl.Series(
                                values=pivotees, dtype=schema[variable_name].polars
                            )
                        )
                    ]
                ),
                df.num_rows,
            ).columns()
            value_column = plc.concatenate.concatenate(
                [
                    df.column_map[pivotee].astype(schema[value_name]).obj
                    for pivotee in pivotees
                ]
            )
            return DataFrame(
                [
                    *index_columns,
                    Column(
                        variable_column, name=variable_name, dtype=schema[variable_name]
                    ),
                    Column(value_column, name=value_name, dtype=schema[value_name]),
                ]
            )
        elif name == "row_index":
            col_name, offset = options
            dtype = schema[col_name]
            step = plc.Scalar.from_py(1, dtype.plc)
            init = plc.Scalar.from_py(offset, dtype.plc)
            index_col = Column(
                plc.filling.sequence(df.num_rows, init, step),
                is_sorted=plc.types.Sorted.YES,
                order=plc.types.Order.ASCENDING,
                null_order=plc.types.NullOrder.AFTER,
                name=col_name,
                dtype=dtype,
            )
            return DataFrame([index_col, *df.columns])
        else:
            raise AssertionError("Should never be reached")  # pragma: no cover

class Union(IR):
    """Concatenate dataframes vertically."""

    __slots__ = ("zlice",)
    _non_child = ("schema", "zlice")
    zlice: Zlice | None
    """Optional slice to apply to the result."""

    def __init__(self, schema: Schema, zlice: Zlice | None, *children: IR):
        self.schema = schema
        self.zlice = zlice
        self._non_child_args = (zlice,)
        self.children = children
        schema = self.children[0].schema

    @classmethod
    @nvtx_annotate_cudf_polars(message="Union")
    def do_evaluate(cls, zlice: Zlice | None, *dfs: DataFrame) -> DataFrame:
        """Evaluate and return a dataframe."""
        # TODO: only evaluate what we need if we have a slice?
        return DataFrame.from_table(
            plc.concatenate.concatenate([df.table for df in dfs]),
            dfs[0].column_names,
            dfs[0].dtypes,
        ).slice(zlice)

class HConcat(IR):
    """Concatenate dataframes horizontally."""

    __slots__ = ("should_broadcast",)
    _non_child = ("schema", "should_broadcast")

    def __init__(
        self,
        schema: Schema,
        should_broadcast: bool,  # noqa: FBT001
        *children: IR,
    ):
        self.schema = schema
        self.should_broadcast = should_broadcast
        self._non_child_args = (should_broadcast,)
        self.children = children

    @staticmethod
    def _extend_with_nulls(table: plc.Table, *, nrows: int) -> plc.Table:
        """
        Extend a table with nulls.

        Parameters
        ----------
        table
            Table to extend
        nrows
            Number of additional rows

        Returns
        -------
        New pylibcudf table.
        """
        return plc.concatenate.concatenate(
            [
                table,
                plc.Table(
                    [
                        plc.Column.all_null_like(column, nrows)
                        for column in table.columns()
                    ]
                ),
            ]
        )

    @classmethod
    @nvtx_annotate_cudf_polars(message="HConcat")
    def do_evaluate(
        cls,
        should_broadcast: bool,  # noqa: FBT001
        *dfs: DataFrame,
    ) -> DataFrame:
        """Evaluate and return a dataframe."""
        # Special should_broadcast case.
        # Used to recombine decomposed expressions
        if should_broadcast:
            return DataFrame(
                broadcast(*itertools.chain.from_iterable(df.columns for df in dfs))
            )

        max_rows = max(df.num_rows for df in dfs)
        # Horizontal concatenation extends shorter tables with nulls
        return DataFrame(
            itertools.chain.from_iterable(
                df.columns
                for df in (
                    df
                    if df.num_rows == max_rows
                    else DataFrame.from_table(
                        cls._extend_with_nulls(df.table, nrows=max_rows - df.num_rows),
                        df.column_names,
                        df.dtypes,
                    )
                    for df in dfs
                )
            )
        )

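HConcat._extend_with_nulls above pads shorter inputs so that a horizontal concatenation lines up row counts. The list-based sketch below (a hypothetical hconcat helper, with None standing in for nulls) states the same rule outside of pylibcudf.

def hconcat(*frames: dict[str, list]) -> dict[str, list]:
    """Horizontally concatenate column dicts, null-padding shorter frames."""
    max_rows = max(len(col) for frame in frames for col in frame.values())
    out: dict[str, list] = {}
    for frame in frames:
        for name, col in frame.items():
            out[name] = col + [None] * (max_rows - len(col))
    return out


assert hconcat({"a": [1, 2, 3]}, {"b": ["x"]}) == {
    "a": [1, 2, 3],
    "b": ["x", None, None],
}
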
class Empty(IR):
    """Represents an empty DataFrame with a known schema."""

    __slots__ = ("schema",)
    _non_child = ("schema",)

    def __init__(self, schema: Schema):
        self.schema = schema
        self._non_child_args = (schema,)
        self.children = ()

    @classmethod
    @nvtx_annotate_cudf_polars(message="Empty")
    def do_evaluate(cls, schema: Schema) -> DataFrame:  # pragma: no cover
        """Evaluate and return a dataframe."""
        return DataFrame(
            [
                Column(
                    plc.column_factories.make_empty_column(dtype.plc),
                    dtype=dtype,
                    name=name,
                )
                for name, dtype in schema.items()
            ]
        )