cudf-polars-cu12 25.2.2__py3-none-any.whl → 25.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cudf_polars/VERSION +1 -1
- cudf_polars/callback.py +82 -65
- cudf_polars/containers/column.py +138 -7
- cudf_polars/containers/dataframe.py +26 -39
- cudf_polars/dsl/expr.py +3 -1
- cudf_polars/dsl/expressions/aggregation.py +27 -63
- cudf_polars/dsl/expressions/base.py +40 -72
- cudf_polars/dsl/expressions/binaryop.py +5 -41
- cudf_polars/dsl/expressions/boolean.py +25 -53
- cudf_polars/dsl/expressions/datetime.py +97 -17
- cudf_polars/dsl/expressions/literal.py +27 -33
- cudf_polars/dsl/expressions/rolling.py +110 -9
- cudf_polars/dsl/expressions/selection.py +8 -26
- cudf_polars/dsl/expressions/slicing.py +47 -0
- cudf_polars/dsl/expressions/sorting.py +5 -18
- cudf_polars/dsl/expressions/string.py +33 -36
- cudf_polars/dsl/expressions/ternary.py +3 -10
- cudf_polars/dsl/expressions/unary.py +35 -75
- cudf_polars/dsl/ir.py +749 -212
- cudf_polars/dsl/nodebase.py +8 -1
- cudf_polars/dsl/to_ast.py +5 -3
- cudf_polars/dsl/translate.py +319 -171
- cudf_polars/dsl/utils/__init__.py +8 -0
- cudf_polars/dsl/utils/aggregations.py +292 -0
- cudf_polars/dsl/utils/groupby.py +97 -0
- cudf_polars/dsl/utils/naming.py +34 -0
- cudf_polars/dsl/utils/replace.py +46 -0
- cudf_polars/dsl/utils/rolling.py +113 -0
- cudf_polars/dsl/utils/windows.py +186 -0
- cudf_polars/experimental/base.py +17 -19
- cudf_polars/experimental/benchmarks/__init__.py +4 -0
- cudf_polars/experimental/benchmarks/pdsh.py +1279 -0
- cudf_polars/experimental/dask_registers.py +196 -0
- cudf_polars/experimental/distinct.py +174 -0
- cudf_polars/experimental/explain.py +127 -0
- cudf_polars/experimental/expressions.py +521 -0
- cudf_polars/experimental/groupby.py +288 -0
- cudf_polars/experimental/io.py +58 -29
- cudf_polars/experimental/join.py +353 -0
- cudf_polars/experimental/parallel.py +166 -93
- cudf_polars/experimental/repartition.py +69 -0
- cudf_polars/experimental/scheduler.py +155 -0
- cudf_polars/experimental/select.py +92 -7
- cudf_polars/experimental/shuffle.py +294 -0
- cudf_polars/experimental/sort.py +45 -0
- cudf_polars/experimental/spilling.py +151 -0
- cudf_polars/experimental/utils.py +100 -0
- cudf_polars/testing/asserts.py +146 -6
- cudf_polars/testing/io.py +72 -0
- cudf_polars/testing/plugin.py +78 -76
- cudf_polars/typing/__init__.py +59 -6
- cudf_polars/utils/config.py +353 -0
- cudf_polars/utils/conversion.py +40 -0
- cudf_polars/utils/dtypes.py +22 -5
- cudf_polars/utils/timer.py +39 -0
- cudf_polars/utils/versions.py +5 -4
- {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info}/METADATA +10 -7
- cudf_polars_cu12-25.6.0.dist-info/RECORD +73 -0
- {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info}/WHEEL +1 -1
- cudf_polars/experimental/dask_serialize.py +0 -59
- cudf_polars_cu12-25.2.2.dist-info/RECORD +0 -48
- {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info/licenses}/LICENSE +0 -0
- {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info}/top_level.txt +0 -0
cudf_polars/dsl/ir.py
CHANGED
@@ -15,6 +15,8 @@ from __future__ import annotations
 
 import itertools
 import json
+import random
+import time
 from functools import cache
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, ClassVar
@@ -28,17 +30,25 @@ import pylibcudf as plc
 
 import cudf_polars.dsl.expr as expr
 from cudf_polars.containers import Column, DataFrame
+from cudf_polars.dsl.expressions import rolling
+from cudf_polars.dsl.expressions.base import ExecutionContext
 from cudf_polars.dsl.nodebase import Node
 from cudf_polars.dsl.to_ast import to_ast, to_parquet_filter
+from cudf_polars.dsl.utils.windows import range_window_bounds
 from cudf_polars.utils import dtypes
+from cudf_polars.utils.versions import POLARS_VERSION_LT_128
 
 if TYPE_CHECKING:
-    from collections.abc import Callable, Hashable, Iterable, MutableMapping, Sequence
+    from collections.abc import Callable, Hashable, Iterable, Sequence
     from typing import Literal
 
+    from typing_extensions import Self
+
     from polars.polars import _expr_nodes as pl_expr
 
-    from cudf_polars.typing import Schema
+    from cudf_polars.typing import CSECache, ClosedInterval, Schema, Slice as Zlice
+    from cudf_polars.utils.config import ConfigOptions
+    from cudf_polars.utils.timer import Timer
 
 
 __all__ = [
@@ -47,6 +57,7 @@ __all__ = [
     "ConditionalJoin",
     "DataFrameScan",
     "Distinct",
+    "Empty",
     "ErrorNode",
     "Filter",
     "GroupBy",
@@ -54,10 +65,14 @@ __all__ = [
     "HStack",
     "Join",
     "MapFunction",
+    "MergeSorted",
     "Projection",
     "PythonScan",
+    "Reduce",
+    "Rolling",
     "Scan",
     "Select",
+    "Sink",
     "Slice",
     "Sort",
     "Union",
@@ -100,7 +115,7 @@ def broadcast(*columns: Column, target_length: int | None = None) -> list[Column
     """
     if len(columns) == 0:
         return []
-    lengths: set[int] = {column.obj.size() for column in columns}
+    lengths: set[int] = {column.size for column in columns}
     if lengths == {1}:
         if target_length is None:
             return list(columns)
@@ -116,7 +131,7 @@ def broadcast(*columns: Column, target_length: int | None = None) -> list[Column
     )
     return [
         column
-        if column.obj.size() != 1
+        if column.size != 1
         else Column(
             plc.Column.from_scalar(column.obj_scalar, nrows),
             is_sorted=plc.types.Sorted.YES,
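
Note: `broadcast` implements polars' length-coercion rule on the GPU: length-1 columns are replicated out to the common length via `plc.Column.from_scalar`, and any other length must match exactly. A minimal pure-Python sketch of just the length rule (the `broadcast_lengths` helper is hypothetical, for illustration only):

    # Sketch of the length rule used by broadcast(); the real code
    # replicates length-1 columns on the GPU rather than in Python.
    def broadcast_lengths(lengths: set[int], target: int | None) -> int:
        if lengths == {1}:
            return 1 if target is None else target
        (nrows,) = lengths - {1}  # exactly one non-unit length may remain
        if target is not None and nrows != target:
            raise RuntimeError("Cannot broadcast to the requested length")
        return nrows

    assert broadcast_lengths({1, 5}, None) == 5
    assert broadcast_lengths({1}, 3) == 3
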
@@ -181,7 +196,7 @@ class IR(Node["IR"]):
         translation phase should fail earlier.
         """
 
-    def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
+    def evaluate(self, *, cache: CSECache, timer: Timer | None) -> DataFrame:
         """
         Evaluate the node (recursively) and return a dataframe.
 
@@ -190,6 +205,9 @@
         cache
             Mapping from cached node ids to constructed DataFrames.
             Used to implement evaluation of the `Cache` node.
+        timer
+            If not None, a Timer object to record timings for the
+            evaluation of the node.
 
         Notes
         -----
@@ -208,10 +226,16 @@
         If evaluation fails. Ideally this should not occur, since the
         translation phase should fail earlier.
         """
-        return self.do_evaluate(
-            *self._non_child_args,
-            *(child.evaluate(cache=cache) for child in self.children),
-        )
+        children = [child.evaluate(cache=cache, timer=timer) for child in self.children]
+        if timer is not None:
+            start = time.monotonic_ns()
+            result = self.do_evaluate(*self._non_child_args, *children)
+            end = time.monotonic_ns()
+            # TODO: Set better names on each class object.
+            timer.store(start, end, type(self).__name__)
+            return result
+        else:
+            return self.do_evaluate(*self._non_child_args, *children)
 
 
 class ErrorNode(IR):
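
Note: the reworked `IR.evaluate` threads an optional `Timer` through the recursion and brackets each node's `do_evaluate` with `time.monotonic_ns()`. A toy sketch of the protocol those call sites assume (the real class lives in `cudf_polars/utils/timer.py`; `ToyTimer` is a hypothetical stand-in):

    import time

    class ToyTimer:
        # Mirrors the store(start_ns, end_ns, name) call used above.
        def __init__(self) -> None:
            self.records: list[tuple[int, int, str]] = []

        def store(self, start: int, end: int, name: str) -> None:
            self.records.append((start, end, name))

    timer = ToyTimer()
    start = time.monotonic_ns()
    _ = sum(range(1_000))  # stand-in for a node's do_evaluate work
    timer.store(start, time.monotonic_ns(), "Select")
    print(timer.records)
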
@@ -256,6 +280,7 @@ class Scan(IR):
     __slots__ = (
         "cloud_options",
         "config_options",
+        "include_file_paths",
         "n_rows",
         "paths",
         "predicate",
@@ -276,6 +301,7 @@
         "skip_rows",
         "n_rows",
         "row_index",
+        "include_file_paths",
         "predicate",
     )
     typ: str
@@ -284,7 +310,7 @@
     """Reader-specific options, as dictionary."""
     cloud_options: dict[str, Any] | None
     """Cloud-related authentication options, currently ignored."""
-    config_options: dict[str, Any]
+    config_options: ConfigOptions
     """GPU-specific configuration options"""
     paths: list[str]
     """List of paths to read from."""
@@ -296,6 +322,8 @@
     """Number of rows to read after skipping."""
     row_index: tuple[str, int] | None
     """If not None add an integer index column of the given name."""
+    include_file_paths: str | None
+    """Include the path of the source file(s) as a column with this name."""
     predicate: expr.NamedExpr | None
     """Mask to apply to the read dataframe."""
 
@@ -308,12 +336,13 @@
         typ: str,
         reader_options: dict[str, Any],
         cloud_options: dict[str, Any] | None,
-        config_options: dict[str, Any],
+        config_options: ConfigOptions,
         paths: list[str],
         with_columns: list[str] | None,
         skip_rows: int,
         n_rows: int,
         row_index: tuple[str, int] | None,
+        include_file_paths: str | None,
         predicate: expr.NamedExpr | None,
     ):
         self.schema = schema
@@ -326,6 +355,7 @@
         self.skip_rows = skip_rows
         self.n_rows = n_rows
         self.row_index = row_index
+        self.include_file_paths = include_file_paths
         self.predicate = predicate
         self._non_child_args = (
             schema,
@@ -337,6 +367,7 @@
             skip_rows,
             n_rows,
             row_index,
+            include_file_paths,
             predicate,
         )
         self.children = ()
@@ -350,7 +381,9 @@
             # TODO: polars has this implemented for parquet,
             # maybe we can do this too?
             raise NotImplementedError("slice pushdown for negative slices")
-        if self.typ in {"csv"} and self.skip_rows != 0:
+        if (
+            POLARS_VERSION_LT_128 and self.typ in {"csv"} and self.skip_rows != 0
+        ):  # pragma: no cover
             # This comes from slice pushdown, but that
             # optimization doesn't happen right now
             raise NotImplementedError("skipping rows in CSV reader")
@@ -360,7 +393,7 @@
             raise NotImplementedError(
                 "Read from cloud storage"
             )  # pragma: no cover; no test yet
-        if any(p.startswith("https://") for p in self.paths):
+        if any(str(p).startswith("https:/") for p in self.paths):
             raise NotImplementedError("Read from https")
         if self.typ == "csv":
             if self.reader_options["skip_rows_after_header"] != 0:
@@ -379,9 +412,18 @@
                     "Multi-character comment prefix not supported for CSV reader"
                 )
             if not self.reader_options["has_header"]:
-                # Need to do some file introspection to get the number
-                # of columns so that column projection works right.
-                raise NotImplementedError("Reading CSV without header")
+                # TODO: To support reading headerless CSV files without requiring new
+                # column names, we would need to do file introspection to infer the number
+                # of columns so column projection works right.
+                reader_schema = self.reader_options.get("schema")
+                if not (
+                    reader_schema
+                    and isinstance(schema, dict)
+                    and "fields" in reader_schema
+                ):
+                    raise NotImplementedError(
+                        "Reading CSV without header requires user-provided column names via new_columns"
+                    )
         elif self.typ == "ndjson":
             # TODO: consider handling the low memory option here
             # (maybe use chunked JSON reader)
@@ -389,6 +431,9 @@
                 raise NotImplementedError(
                     "ignore_errors is not supported in the JSON reader"
                 )
+            if include_file_paths is not None:
+                # TODO: Need to populate num_rows_per_source in read_json in libcudf
+                raise NotImplementedError("Including file paths in a json scan.")
         elif (
             self.typ == "parquet"
             and self.row_index is not None
@@ -413,31 +458,60 @@
             self.typ,
             json.dumps(self.reader_options),
             json.dumps(self.cloud_options),
-            json.dumps(self.config_options),
+            self.config_options,
             tuple(self.paths),
             tuple(self.with_columns) if self.with_columns is not None else None,
             self.skip_rows,
             self.n_rows,
             self.row_index,
+            self.include_file_paths,
             self.predicate,
         )
 
+    @staticmethod
+    def add_file_paths(
+        name: str, paths: list[str], rows_per_path: list[int], df: DataFrame
+    ) -> DataFrame:
+        """
+        Add a Column of file paths to the DataFrame.
+
+        Each path is repeated according to the number of rows read from it.
+        """
+        (filepaths,) = plc.filling.repeat(
+            # TODO: Remove call from_arrow when we support python list to Column
+            plc.Table([plc.interop.from_arrow(pa.array(map(str, paths)))]),
+            plc.interop.from_arrow(pa.array(rows_per_path, type=pa.int32())),
+        ).columns()
+        return df.with_columns([Column(filepaths, name=name)])
+
     @classmethod
     def do_evaluate(
         cls,
         schema: Schema,
         typ: str,
         reader_options: dict[str, Any],
-        config_options: dict[str, Any],
+        config_options: ConfigOptions,
         paths: list[str],
         with_columns: list[str] | None,
         skip_rows: int,
         n_rows: int,
         row_index: tuple[str, int] | None,
+        include_file_paths: str | None,
        predicate: expr.NamedExpr | None,
-    ):
+    ) -> DataFrame:
         """Evaluate and return a dataframe."""
         if typ == "csv":
+
+            def read_csv_header(
+                path: Path | str, sep: str
+            ) -> list[str]:  # pragma: no cover
+                with Path(path).open() as f:
+                    for line in f:
+                        stripped = line.strip()
+                        if stripped:
+                            return stripped.split(sep)
+                return []
+
             parse_options = reader_options["parse_options"]
             sep = chr(parse_options["separator"])
             quote = chr(parse_options["quote_char"])
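
Note: these hunks wire polars' `include_file_paths` scan argument through to the GPU engine for CSV and parquet scans (the NDJSON path still raises, as validated in `__init__` above). A usage sketch, assuming a polars build with the GPU engine installed:

    import polars as pl

    pl.DataFrame({"a": [1, 2]}).write_parquet("part-0.parquet")
    pl.DataFrame({"a": [3]}).write_parquet("part-1.parquet")

    lf = pl.scan_parquet(
        ["part-0.parquet", "part-1.parquet"],
        include_file_paths="source_file",  # new column of originating paths
    )
    print(lf.collect(engine="gpu"))  # one "source_file" value per row
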
@@ -449,8 +523,8 @@
                 # file provides column names
                 column_names = None
                 usecols = with_columns
-
-            header = 0
+            has_header = reader_options["has_header"]
+            header = 0 if has_header else -1
 
             # polars defaults to no null recognition
             null_values = [""]
@@ -470,6 +544,7 @@
 
             # polars skips blank lines at the beginning of the file
             pieces = []
+            seen_paths = []
             read_partial = n_rows != -1
             for p in paths:
                 skiprows = reader_options["skip_rows"]
@@ -480,7 +555,9 @@
                 options = (
                     plc.io.csv.CsvReaderOptions.builder(plc.io.SourceInfo([path]))
                     .nrows(n_rows)
-                    .skiprows(skiprows)
+                    .skiprows(
+                        skiprows if POLARS_VERSION_LT_128 else skiprows + skip_rows
+                    )  # pragma: no cover
                     .lineterminator(str(eol))
                     .quotechar(str(quote))
                     .decimal(decimal)
@@ -491,6 +568,13 @@
                 options.set_delimiter(str(sep))
                 if column_names is not None:
                     options.set_names([str(name) for name in column_names])
+                else:
+                    if (
+                        not POLARS_VERSION_LT_128 and header > -1 and skip_rows > header
+                    ):  # pragma: no cover
+                        # We need to read the header otherwise we would skip it
+                        column_names = read_csv_header(path, str(sep))
+                        options.set_names(column_names)
                 options.set_header(header)
                 options.set_dtypes(schema)
                 if usecols is not None:
@@ -500,6 +584,8 @@
                     options.set_comment(comment)
                 tbl_w_meta = plc.io.csv.read_csv(options)
                 pieces.append(tbl_w_meta)
+                if include_file_paths is not None:
+                    seen_paths.append(p)
                 if read_partial:
                     n_rows -= tbl_w_meta.tbl.num_rows()
                     if n_rows <= 0:
@@ -515,12 +601,26 @@
                 plc.concatenate.concatenate(list(tables)),
                 colnames[0],
             )
+            if include_file_paths is not None:
+                df = Scan.add_file_paths(
+                    include_file_paths,
+                    seen_paths,
+                    [t.num_rows() for t in tables],
+                    df,
+                )
         elif typ == "parquet":
-            parquet_options = config_options.get("parquet_options", {})
-            if parquet_options.get("chunked", True):
-                options = plc.io.parquet.ParquetReaderOptions.builder(
-                    plc.io.SourceInfo(paths)
-                ).build()
+            filters = None
+            if predicate is not None and row_index is None:
+                # Can't apply filters during read if we have a row index.
+                filters = to_parquet_filter(predicate.value)
+            options = plc.io.parquet.ParquetReaderOptions.builder(
+                plc.io.SourceInfo(paths)
+            ).build()
+            if with_columns is not None:
+                options.set_columns(with_columns)
+            if filters is not None:
+                options.set_filter(filters)
+            if config_options.parquet_options.chunked:
                 # We handle skip_rows != 0 by reading from the
                 # up to n_rows + skip_rows and slicing off the
                 # first skip_rows entries.
@@ -530,21 +630,15 @@
                 nrows = n_rows + skip_rows
                 if nrows > -1:
                     options.set_num_rows(nrows)
-                if with_columns is not None:
-                    options.set_columns(with_columns)
                 reader = plc.io.parquet.ChunkedParquetReader(
                     options,
-                    chunk_read_limit=parquet_options.get(
-                        "chunk_read_limit", cls.PARQUET_DEFAULT_CHUNK_SIZE
-                    ),
-                    pass_read_limit=parquet_options.get(
-                        "pass_read_limit", cls.PARQUET_DEFAULT_PASS_LIMIT
-                    ),
+                    chunk_read_limit=config_options.parquet_options.chunk_read_limit,
+                    pass_read_limit=config_options.parquet_options.pass_read_limit,
                 )
-                chk = reader.read_chunk()
+                chunk = reader.read_chunk()
                 rows_left_to_skip = skip_rows
 
-                def slice_skip(tbl: plc.Table):
+                def slice_skip(tbl: plc.Table) -> plc.Table:
                     nonlocal rows_left_to_skip
                     if rows_left_to_skip > 0:
                         table_rows = tbl.num_rows()
@@ -556,12 +650,13 @@
                         rows_left_to_skip -= chunk_skip
                     return tbl
 
-                tbl = slice_skip(chk.tbl)
+                tbl = slice_skip(chunk.tbl)
                 # TODO: Nested column names
-                names = chk.column_names(include_children=False)
+                names = chunk.column_names(include_children=False)
                 concatenated_columns = tbl.columns()
                 while reader.has_next():
-                    tbl = slice_skip(reader.read_chunk().tbl)
+                    chunk = reader.read_chunk()
+                    tbl = slice_skip(chunk.tbl)
 
                     for i in range(tbl.num_columns()):
                         concatenated_columns[i] = plc.concatenate.concatenate(
@@ -574,31 +669,28 @@
                     plc.Table(concatenated_columns),
                     names=names,
                 )
+                if include_file_paths is not None:
+                    df = Scan.add_file_paths(
+                        include_file_paths, paths, chunk.num_rows_per_source, df
+                    )
             else:
-                filters = None
-                if predicate is not None and row_index is None:
-                    # Can't apply filters during read if we have a row index.
-                    filters = to_parquet_filter(predicate.value)
-                options = plc.io.parquet.ParquetReaderOptions.builder(
-                    plc.io.SourceInfo(paths)
-                ).build()
                 if n_rows != -1:
                     options.set_num_rows(n_rows)
                 if skip_rows != 0:
                     options.set_skip_rows(skip_rows)
-                if with_columns is not None:
-                    options.set_columns(with_columns)
-                if filters is not None:
-                    options.set_filter(filters)
                 tbl_w_meta = plc.io.parquet.read_parquet(options)
                 df = DataFrame.from_table(
                     tbl_w_meta.tbl,
                     # TODO: consider nested column names?
                     tbl_w_meta.column_names(include_children=False),
                 )
-                if filters is not None:
-                    # Mask must have been applied.
-                    return df
+                if include_file_paths is not None:
+                    df = Scan.add_file_paths(
+                        include_file_paths, paths, tbl_w_meta.num_rows_per_source, df
+                    )
+            if filters is not None:
+                # Mask must have been applied.
+                return df
 
         elif typ == "ndjson":
             json_schema: list[plc.io.json.NameAndType] = [
@@ -629,20 +721,18 @@
             name, offset = row_index
             offset += skip_rows
             dtype = schema[name]
-            step = plc.interop.from_arrow(
-                pa.scalar(1, type=plc.interop.to_arrow(dtype))
-            )
-            init = plc.interop.from_arrow(
-                pa.scalar(offset, type=plc.interop.to_arrow(dtype))
-            )
-            index = Column(
+            step = plc.Scalar.from_py(1, dtype)
+            init = plc.Scalar.from_py(offset, dtype)
+            index_col = Column(
                 plc.filling.sequence(df.num_rows, init, step),
                 is_sorted=plc.types.Sorted.YES,
                 order=plc.types.Order.ASCENDING,
                 null_order=plc.types.NullOrder.AFTER,
                 name=name,
             )
-            df = DataFrame([index, *df.columns])
+            df = DataFrame([index_col, *df.columns])
+            if next(iter(schema)) != name:
+                df = df.select(schema)
         assert all(c.obj.type() == schema[name] for name, c in df.column_map.items())
         if predicate is None:
             return df
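
Note: the row-index path now builds its init/step scalars with `plc.Scalar.from_py` rather than round-tripping through pyarrow. A standalone sketch of constructing such a column (assumes pylibcudf 25.04 or newer and a CUDA-capable environment; the values are illustrative):

    import pylibcudf as plc

    dtype = plc.DataType(plc.TypeId.INT64)
    init = plc.Scalar.from_py(10, dtype)  # starting offset
    step = plc.Scalar.from_py(1, dtype)
    col = plc.filling.sequence(5, init, step)  # device column [10, 11, 12, 13, 14]
    print(col.size())  # 5
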
@@ -651,6 +741,193 @@
             return df.filter(mask)
 
 
+class Sink(IR):
+    """Sink a dataframe to a file."""
+
+    __slots__ = ("cloud_options", "kind", "options", "path")
+    _non_child = ("schema", "kind", "path", "options", "cloud_options")
+
+    kind: str
+    path: str
+    options: dict[str, Any]
+
+    def __init__(
+        self,
+        schema: Schema,
+        kind: str,
+        path: str,
+        options: dict[str, Any],
+        cloud_options: dict[str, Any],
+        df: IR,
+    ):
+        self.schema = schema
+        self.kind = kind
+        self.path = path
+        self.options = options
+        self.cloud_options = cloud_options
+        self.children = (df,)
+        self._non_child_args = (schema, kind, path, options)
+        if self.cloud_options is not None and any(
+            self.cloud_options.get(k) is not None
+            for k in ("config", "credential_provider")
+        ):
+            raise NotImplementedError(
+                "Write to cloud storage"
+            )  # pragma: no cover; no test yet
+        sync_on_close = options.get("sync_on_close")
+        if sync_on_close not in {"None", None}:
+            raise NotImplementedError(
+                f"sync_on_close='{sync_on_close}' is not supported."
+            )  # pragma: no cover; no test yet
+        child_schema = df.schema.values()
+        if kind == "Csv":
+            if not all(
+                plc.io.csv.is_supported_write_csv(dtype) for dtype in child_schema
+            ):
+                # Nested types are unsupported in polars and libcudf
+                raise NotImplementedError(
+                    "Contains unsupported types for CSV writing"
+                )  # pragma: no cover
+            serialize = options["serialize_options"]
+            if options["include_bom"]:
+                raise NotImplementedError("include_bom is not supported.")
+            for key in (
+                "date_format",
+                "time_format",
+                "datetime_format",
+                "float_scientific",
+                "float_precision",
+            ):
+                if serialize[key] is not None:
+                    raise NotImplementedError(f"{key} is not supported.")
+            if serialize["quote_style"] != "Necessary":
+                raise NotImplementedError("Only quote_style='Necessary' is supported.")
+            if chr(serialize["quote_char"]) != '"':
+                raise NotImplementedError("Only quote_char='\"' is supported.")
+        elif kind == "Parquet":
+            compression = options["compression"]
+            if isinstance(compression, dict):
+                if len(compression) != 1:
+                    raise NotImplementedError(
+                        "Compression dict with more than one entry."
+                    )  # pragma: no cover
+                compression, compression_level = next(iter(compression.items()))
+                options["compression"] = compression
+                if compression_level is not None:
+                    raise NotImplementedError(
+                        "Setting compression_level is not supported."
+                    )
+            if compression == "Lz4Raw":
+                compression = "Lz4"
+                options["compression"] = compression
+            if (
+                compression != "Uncompressed"
+                and not plc.io.parquet.is_supported_write_parquet(
+                    getattr(plc.io.types.CompressionType, compression.upper())
+                )
+            ):
+                raise NotImplementedError(
+                    f"Compression type '{compression}' is not supported."
+                )
+        elif (
+            kind == "Json"
+        ):  # pragma: no cover; options are validated on the polars side
+            if not all(
+                plc.io.json.is_supported_write_json(dtype) for dtype in child_schema
+            ):
+                # Nested types are unsupported in polars and libcudf
+                raise NotImplementedError(
+                    "Contains unsupported types for JSON writing"
+                )  # pragma: no cover
+            shared_writer_options = {"sync_on_close", "maintain_order", "mkdir"}
+            if set(options) - shared_writer_options:
+                raise NotImplementedError("Unsupported options passed JSON writer.")
+        else:
+            raise NotImplementedError(
+                f"Unhandled sink kind: {kind}"
+            )  # pragma: no cover
+
+    def get_hashable(self) -> Hashable:
+        """
+        Hashable representation of the node.
+
+        The option dictionary is serialised for hashing purposes.
+        """
+        schema_hash = tuple(self.schema.items())  # pragma: no cover
+        return (
+            type(self),
+            schema_hash,
+            self.kind,
+            self.path,
+            json.dumps(self.options),
+            json.dumps(self.cloud_options),
+        )  # pragma: no cover
+
+    @classmethod
+    def do_evaluate(
+        cls,
+        schema: Schema,
+        kind: str,
+        path: str,
+        options: dict[str, Any],
+        df: DataFrame,
+    ) -> DataFrame:
+        """Write the dataframe to a file."""
+        target = plc.io.SinkInfo([path])
+
+        if options.get("mkdir", False):
+            Path(path).parent.mkdir(parents=True, exist_ok=True)
+        if kind == "Csv":
+            serialize = options["serialize_options"]
+            options = (
+                plc.io.csv.CsvWriterOptions.builder(target, df.table)
+                .include_header(options["include_header"])
+                .names(df.column_names if options["include_header"] else [])
+                .na_rep(serialize["null"])
+                .line_terminator(serialize["line_terminator"])
+                .inter_column_delimiter(chr(serialize["separator"]))
+                .build()
+            )
+            plc.io.csv.write_csv(options)
+
+        elif kind == "Parquet":
+            metadata = plc.io.types.TableInputMetadata(df.table)
+            for i, name in enumerate(df.column_names):
+                metadata.column_metadata[i].set_name(name)
+
+            builder = plc.io.parquet.ParquetWriterOptions.builder(target, df.table)
+            compression = options["compression"]
+            if compression != "Uncompressed":
+                builder.compression(
+                    getattr(plc.io.types.CompressionType, compression.upper())
+                )
+
+            writer_options = builder.metadata(metadata).build()
+            if options["data_page_size"] is not None:
+                writer_options.set_max_page_size_bytes(options["data_page_size"])
+            if options["row_group_size"] is not None:
+                writer_options.set_row_group_size_rows(options["row_group_size"])
+
+            plc.io.parquet.write_parquet(writer_options)
+
+        elif kind == "Json":
+            metadata = plc.io.TableWithMetadata(
+                df.table, [(col, []) for col in df.column_names]
+            )
+            options = (
+                plc.io.json.JsonWriterOptions.builder(target, df.table)
+                .lines(val=True)
+                .na_rep("null")
+                .include_nulls(val=True)
+                .metadata(metadata)
+                .utf8_escaped(val=False)
+                .build()
+            )
+            plc.io.json.write_json(options)
+
+        return DataFrame([])
+
+
 class Cache(IR):
     """
     Return a cached plan node.
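
Note: `Sink` is new in 25.6 and backs polars sinks on the GPU engine for CSV, Parquet, and line-delimited JSON. A usage sketch, assuming a polars version whose sink methods accept an `engine` argument (the output path is illustrative):

    import polars as pl

    lf = pl.LazyFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})
    # Streams the result straight to disk; unsupported sink options raise
    # NotImplementedError during translation and trigger CPU fallback.
    lf.sink_parquet("out.parquet", engine="gpu")
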
@@ -658,35 +935,59 @@
     Used for CSE at the plan level.
     """
 
-    __slots__ = ("key",)
-    _non_child = ("schema", "key")
+    __slots__ = ("key", "refcount")
+    _non_child = ("schema", "key", "refcount")
     key: int
     """The cache key."""
+    refcount: int
+    """The number of cache hits."""
 
-    def __init__(self, schema: Schema, key: int, value: IR):
+    def __init__(self, schema: Schema, key: int, refcount: int, value: IR):
         self.schema = schema
         self.key = key
+        self.refcount = refcount
         self.children = (value,)
-        self._non_child_args = (key,)
+        self._non_child_args = (key, refcount)
+
+    def get_hashable(self) -> Hashable:  # noqa: D102
+        # Polars arranges that the keys are unique across all cache
+        # nodes that reference the same child, so we don't need to
+        # hash the child.
+        return (type(self), self.key, self.refcount)
+
+    def is_equal(self, other: Self) -> bool:  # noqa: D102
+        if self.key == other.key and self.refcount == other.refcount:
+            self.children = other.children
+            return True
+        return False
 
     @classmethod
     def do_evaluate(
-        cls, key: int, df: DataFrame
+        cls, key: int, refcount: int, df: DataFrame
     ) -> DataFrame:  # pragma: no cover; basic evaluation never calls this
         """Evaluate and return a dataframe."""
         # Our value has already been computed for us, so let's just
         # return it.
         return df
 
-    def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
+    def evaluate(self, *, cache: CSECache, timer: Timer | None) -> DataFrame:
         """Evaluate and return a dataframe."""
         # We must override the recursion scheme because we don't want
         # to recurse if we're in the cache.
         try:
-            return cache[self.key]
+            (result, hits) = cache[self.key]
         except KeyError:
             (value,) = self.children
-            return cache.setdefault(self.key, value.evaluate(cache=cache))
+            result = value.evaluate(cache=cache, timer=timer)
+            cache[self.key] = (result, 0)
+            return result
+        else:
+            hits += 1
+            if hits == self.refcount:
+                del cache[self.key]
+            else:
+                cache[self.key] = (result, hits)
+            return result
 
 
 class DataFrameScan(IR):
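
Note: `Cache` now carries a `refcount` (the expected number of cache hits), so the cached DataFrame can be evicted, and its GPU memory released, as soon as the last consumer has read it. A pure-Python toy of the same eviction rule (the `read` helper is hypothetical):

    cache: dict[int, tuple[str, int]] = {}

    def read(key: int, refcount: int, compute) -> str:
        try:
            result, hits = cache[key]
        except KeyError:
            result = compute()
            cache[key] = (result, 0)  # first evaluation, no hits yet
            return result
        hits += 1
        if hits == refcount:
            del cache[key]  # last expected consumer: free the entry
        else:
            cache[key] = (result, hits)
        return result

    for _ in range(3):  # one miss plus refcount=2 hits
        assert read(0, 2, lambda: "df") == "df"
    assert 0 not in cache
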
@@ -696,13 +997,13 @@
     This typically arises from ``q.collect().lazy()``
     """
 
-    __slots__ = ("config_options", "df", "projection")
+    __slots__ = ("_id_for_hash", "config_options", "df", "projection")
     _non_child = ("schema", "df", "projection", "config_options")
     df: Any
-    """Polars DataFrame object."""
+    """Polars internal PyDataFrame object."""
     projection: tuple[str, ...] | None
     """List of columns to project out."""
-    config_options: dict[str, Any]
+    config_options: ConfigOptions
     """GPU-specific configuration options"""
 
     def __init__(
@@ -710,29 +1011,35 @@
         schema: Schema,
         df: Any,
         projection: Sequence[str] | None,
-        config_options: dict[str, Any],
+        config_options: ConfigOptions,
     ):
         self.schema = schema
         self.df = df
         self.projection = tuple(projection) if projection is not None else None
         self.config_options = config_options
-        self._non_child_args = (schema, df, self.projection)
+        self._non_child_args = (
+            schema,
+            pl.DataFrame._from_pydf(df),
+            self.projection,
+        )
         self.children = ()
+        self._id_for_hash = random.randint(0, 2**64 - 1)
 
     def get_hashable(self) -> Hashable:
         """
         Hashable representation of the node.
 
-        The (heavy) dataframe object is hashed as its id, so this is
-        not stable across runs, or repeat instances of equal dataframes.
+        The (heavy) dataframe object is not hashed. No two instances of
+        ``DataFrameScan`` will have the same hash, even if they have the
+        same schema, projection, and config options, and data.
         """
         schema_hash = tuple(self.schema.items())
         return (
             type(self),
             schema_hash,
-            id(self.df),
+            self._id_for_hash,
             self.projection,
-            json.dumps(self.config_options),
+            self.config_options,
         )
 
     @classmethod
@@ -743,10 +1050,9 @@
         projection: tuple[str, ...] | None,
     ) -> DataFrame:
         """Evaluate and return a dataframe."""
-        pdf = pl.DataFrame._from_pydf(df)
         if projection is not None:
-            pdf = pdf.select(projection)
-        df = DataFrame.from_polars(pdf)
+            df = df.select(projection)
+        df = DataFrame.from_polars(df)
         assert all(
             c.obj.type() == dtype
             for c, dtype in zip(df.columns, schema.values(), strict=True)
@@ -820,29 +1126,191 @@
     ) -> DataFrame:  # pragma: no cover; not exposed by polars yet
         """Evaluate and return a dataframe."""
         columns = broadcast(*(e.evaluate(df) for e in exprs))
-        assert all(column.obj.size() == 1 for column in columns)
+        assert all(column.size == 1 for column in columns)
         return DataFrame(columns)
 
 
+class Rolling(IR):
+    """Perform a (possibly grouped) rolling aggregation."""
+
+    __slots__ = (
+        "agg_requests",
+        "closed_window",
+        "following",
+        "index",
+        "keys",
+        "preceding",
+        "zlice",
+    )
+    _non_child = (
+        "schema",
+        "index",
+        "preceding",
+        "following",
+        "closed_window",
+        "keys",
+        "agg_requests",
+        "zlice",
+    )
+    index: expr.NamedExpr
+    """Column being rolled over."""
+    preceding: plc.Scalar
+    """Preceding window extent defining start of window."""
+    following: plc.Scalar
+    """Following window extent defining end of window."""
+    closed_window: ClosedInterval
+    """Treatment of window endpoints."""
+    keys: tuple[expr.NamedExpr, ...]
+    """Grouping keys."""
+    agg_requests: tuple[expr.NamedExpr, ...]
+    """Aggregation expressions."""
+    zlice: Zlice | None
+    """Optional slice"""
+
+    def __init__(
+        self,
+        schema: Schema,
+        index: expr.NamedExpr,
+        preceding: plc.Scalar,
+        following: plc.Scalar,
+        closed_window: ClosedInterval,
+        keys: Sequence[expr.NamedExpr],
+        agg_requests: Sequence[expr.NamedExpr],
+        zlice: Zlice | None,
+        df: IR,
+    ):
+        self.schema = schema
+        self.index = index
+        self.preceding = preceding
+        self.following = following
+        self.closed_window = closed_window
+        self.keys = tuple(keys)
+        self.agg_requests = tuple(agg_requests)
+        if not all(
+            plc.rolling.is_valid_rolling_aggregation(
+                agg.value.dtype, agg.value.agg_request
+            )
+            for agg in self.agg_requests
+        ):
+            raise NotImplementedError("Unsupported rolling aggregation")
+        if any(
+            agg.value.agg_request.kind() == plc.aggregation.Kind.COLLECT_LIST
+            for agg in self.agg_requests
+        ):
+            raise NotImplementedError(
+                "Incorrect handling of empty groups for list collection"
+            )
+
+        self.zlice = zlice
+        self.children = (df,)
+        self._non_child_args = (
+            index,
+            preceding,
+            following,
+            closed_window,
+            keys,
+            agg_requests,
+            zlice,
+        )
+
+    @classmethod
+    def do_evaluate(
+        cls,
+        index: expr.NamedExpr,
+        preceding: plc.Scalar,
+        following: plc.Scalar,
+        closed_window: ClosedInterval,
+        keys_in: Sequence[expr.NamedExpr],
+        aggs: Sequence[expr.NamedExpr],
+        zlice: Zlice | None,
+        df: DataFrame,
+    ) -> DataFrame:
+        """Evaluate and return a dataframe."""
+        keys = broadcast(*(k.evaluate(df) for k in keys_in), target_length=df.num_rows)
+        orderby = index.evaluate(df)
+        # Polars casts integral orderby to int64, but only for calculating window bounds
+        if (
+            plc.traits.is_integral(orderby.obj.type())
+            and orderby.obj.type().id() != plc.TypeId.INT64
+        ):
+            orderby_obj = plc.unary.cast(orderby.obj, plc.DataType(plc.TypeId.INT64))
+        else:
+            orderby_obj = orderby.obj
+        preceding_window, following_window = range_window_bounds(
+            preceding, following, closed_window
+        )
+        if orderby.obj.null_count() != 0:
+            raise RuntimeError(
+                f"Index column '{index.name}' in rolling may not contain nulls"
+            )
+        if len(keys_in) > 0:
+            # Must always check sortedness
+            table = plc.Table([*(k.obj for k in keys), orderby_obj])
+            n = table.num_columns()
+            if not plc.sorting.is_sorted(
+                table, [plc.types.Order.ASCENDING] * n, [plc.types.NullOrder.BEFORE] * n
+            ):
+                raise RuntimeError("Input for grouped rolling is not sorted")
+        else:
+            if not orderby.check_sorted(
+                order=plc.types.Order.ASCENDING, null_order=plc.types.NullOrder.BEFORE
+            ):
+                raise RuntimeError(
+                    f"Index column '{index.name}' in rolling is not sorted, please sort first"
+                )
+        values = plc.rolling.grouped_range_rolling_window(
+            plc.Table([k.obj for k in keys]),
+            orderby_obj,
+            plc.types.Order.ASCENDING,  # Polars requires ascending orderby.
+            plc.types.NullOrder.BEFORE,  # Doesn't matter, polars doesn't allow nulls in orderby
+            preceding_window,
+            following_window,
+            [rolling.to_request(request.value, orderby, df) for request in aggs],
+        )
+        return DataFrame(
+            itertools.chain(
+                keys,
+                [orderby],
+                (
+                    Column(col, name=name)
+                    for col, name in zip(
+                        values.columns(),
+                        (request.name for request in aggs),
+                        strict=True,
+                    )
+                ),
+            )
+        ).slice(zlice)
+
+
 class GroupBy(IR):
     """Perform a groupby."""
 
     __slots__ = (
-        "agg_infos",
         "agg_requests",
+        "config_options",
         "keys",
         "maintain_order",
-        "options",
+        "zlice",
+    )
+    _non_child = (
+        "schema",
+        "keys",
+        "agg_requests",
+        "maintain_order",
+        "zlice",
+        "config_options",
     )
-    _non_child = ("schema", "keys", "agg_requests", "maintain_order", "options")
     keys: tuple[expr.NamedExpr, ...]
     """Grouping keys."""
     agg_requests: tuple[expr.NamedExpr, ...]
     """Aggregation expressions."""
     maintain_order: bool
     """Preserve order in groupby."""
-    options: Any
-    """Arbitrary options."""
+    zlice: Zlice | None
+    """Optional slice to apply after grouping."""
+    config_options: ConfigOptions
+    """GPU-specific configuration options"""
 
     def __init__(
         self,
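
Note: the new `Rolling` node implements polars' `LazyFrame.rolling` on the GPU; as the checks in `do_evaluate` enforce, the index column must be pre-sorted and free of nulls. A usage sketch (assumes the GPU engine is installed):

    import polars as pl

    lf = pl.LazyFrame({"t": [1, 2, 4, 7], "x": [1.0, 2.0, 3.0, 4.0]}).sort("t")
    out = (
        lf.rolling(index_column="t", period="3i")  # 3-unit integer window
        .agg(pl.col("x").sum())
        .collect(engine="gpu")
    )
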
@@ -850,70 +1318,33 @@
         keys: Sequence[expr.NamedExpr],
         agg_requests: Sequence[expr.NamedExpr],
         maintain_order: bool,  # noqa: FBT001
-        options: Any,
+        zlice: Zlice | None,
+        config_options: ConfigOptions,
         df: IR,
     ):
         self.schema = schema
         self.keys = tuple(keys)
         self.agg_requests = tuple(agg_requests)
         self.maintain_order = maintain_order
-        self.options = options
+        self.zlice = zlice
+        self.config_options = config_options
         self.children = (df,)
-        if self.options.rolling:
-            raise NotImplementedError(
-                "rolling window/groupby"
-            )  # pragma: no cover; rollingwindow constructor has already raised
-        if self.options.dynamic:
-            raise NotImplementedError("dynamic group by")
-        if any(GroupBy.check_agg(a.value) > 1 for a in self.agg_requests):
-            raise NotImplementedError("Nested aggregations in groupby")
-        self.agg_infos = [req.collect_agg(depth=0) for req in self.agg_requests]
         self._non_child_args = (
             self.keys,
             self.agg_requests,
             maintain_order,
-            options,
-            self.agg_infos,
+            self.zlice,
         )
 
-    @staticmethod
-    def check_agg(agg: expr.Expr) -> int:
-        """
-        Determine if we can handle an aggregation expression.
-
-        Parameters
-        ----------
-        agg
-            Expression to check
-
-        Returns
-        -------
-        depth of nesting
-
-        Raises
-        ------
-        NotImplementedError
-            For unsupported expression nodes.
-        """
-        if isinstance(agg, (expr.BinOp, expr.Cast, expr.UnaryFunction)):
-            return max(GroupBy.check_agg(child) for child in agg.children)
-        elif isinstance(agg, expr.Agg):
-            return 1 + max(GroupBy.check_agg(child) for child in agg.children)
-        elif isinstance(agg, (expr.Len, expr.Col, expr.Literal, expr.LiteralColumn)):
-            return 0
-        else:
-            raise NotImplementedError(f"No handler for {agg=}")
-
     @classmethod
     def do_evaluate(
         cls,
         keys_in: Sequence[expr.NamedExpr],
         agg_requests: Sequence[expr.NamedExpr],
         maintain_order: bool,  # noqa: FBT001
-        options: Any,
-        agg_infos: Sequence[expr.AggInfo],
+        zlice: Zlice | None,
         df: DataFrame,
-    ):
+    ) -> DataFrame:
         """Evaluate and return a dataframe."""
         keys = broadcast(*(k.evaluate(df) for k in keys_in), target_length=df.num_rows)
         sorted = (
@@ -928,32 +1359,38 @@
             column_order=[k.order for k in keys],
             null_precedence=[k.null_order for k in keys],
         )
-        # TODO: uniquify
         requests = []
-        replacements: list[expr.Expr] = []
-        for info in agg_infos:
-            for pre_eval, req, rep in info.requests:
-                if pre_eval is None:
-                    # A count aggregation, doesn't touch the column,
-                    # but we need to have one. Rather than evaluating
-                    # one, just use one of the key columns.
-                    col = keys[0].obj
+        names = []
+        for request in agg_requests:
+            name = request.name
+            value = request.value
+            if isinstance(value, expr.Len):
+                # A count aggregation, we need a column so use a key column
+                col = keys[0].obj
+            elif isinstance(value, expr.Agg):
+                if value.name == "quantile":
+                    child = value.children[0]
                 else:
-                    col = pre_eval.evaluate(df).obj
-                requests.append(plc.groupby.GroupByRequest(col, [req]))
-                replacements.append(rep)
+                    (child,) = value.children
+                col = child.evaluate(df, context=ExecutionContext.GROUPBY).obj
+            else:
+                # Anything else, we pre-evaluate
+                col = value.evaluate(df, context=ExecutionContext.GROUPBY).obj
+            requests.append(plc.groupby.GroupByRequest(col, [value.agg_request]))
+            names.append(name)
         group_keys, raw_tables = grouper.aggregate(requests)
-        raw_columns: list[Column] = [
-            Column(column, name=f"tmp{i}")
-            for i, column in enumerate(
-                itertools.chain.from_iterable(t.columns() for t in raw_tables)
-            )
-        ]
-        mapping = dict(zip(replacements, raw_columns, strict=True))
+        results = [
+            Column(column, name=name)
+            for name, column in zip(
+                names,
+                itertools.chain.from_iterable(t.columns() for t in raw_tables),
+                strict=True,
+            )
+        ]
         result_keys = [
             Column(grouped_key, name=key.name)
             for key, grouped_key in zip(keys, group_keys.columns(), strict=True)
         ]
-        result_subs = DataFrame(raw_columns)
-        results = [req.evaluate(result_subs, mapping=mapping) for req in agg_requests]
         broadcasted = broadcast(*result_keys, *results)
         # Handle order preservation of groups
         if maintain_order and not sorted:
@@ -996,12 +1433,26 @@
                 ordered_table.columns(), broadcasted, strict=True
             )
         ]
-        return DataFrame(broadcasted).slice(options.slice)
+        return DataFrame(broadcasted).slice(zlice)
 
 
 class ConditionalJoin(IR):
     """A conditional inner join of two dataframes on a predicate."""
 
+    class Predicate:
+        """Serializable wrapper for a predicate expression."""
+
+        predicate: expr.Expr
+        ast: plc.expressions.Expression
+
+        def __init__(self, predicate: expr.Expr):
+            self.predicate = predicate
+            self.ast = to_ast(predicate)
+
+        def __reduce__(self) -> tuple[Any, ...]:
+            """Pickle a Predicate object."""
+            return (type(self), (self.predicate,))
+
     __slots__ = ("ast_predicate", "options", "predicate")
     _non_child = ("schema", "predicate", "options")
     predicate: expr.Expr
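
Note: `ConditionalJoin` backs polars' `join_where`, and the new nested `Predicate` wrapper keeps the original expression alongside its translated AST so the node can be pickled; `__reduce__` rebuilds the AST on unpickle (useful, e.g., for the experimental multi-GPU executor added elsewhere in this release). A usage sketch (assumes the GPU engine is installed):

    import polars as pl

    left = pl.LazyFrame({"a": [1, 2, 3]})
    right = pl.LazyFrame({"b": [2, 3, 4]})
    # join_where lowers to the ConditionalJoin IR node on the GPU engine
    out = left.join_where(right, pl.col("a") < pl.col("b")).collect(engine="gpu")
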
@@ -1012,7 +1463,7 @@
             pl_expr.Operator | Iterable[pl_expr.Operator],
         ],
         bool,
-        tuple[int, int] | None,
+        Zlice | None,
         str,
         bool,
         Literal["none", "left", "right", "left_right", "right_left"],
@@ -1020,7 +1471,7 @@
     """
     tuple of options:
       - predicates: tuple of ir join type (eg. ie_join) and (In)Equality conditions
-      - join_nulls: do nulls compare equal?
+      - nulls_equal: do nulls compare equal?
      - slice: optional slice to perform after joining.
      - suffix: string suffix for right columns if names match
      - coalesce: should key columns be coalesced (only makes sense for outer joins)
@@ -1034,30 +1485,34 @@
         self.predicate = predicate
         self.options = options
         self.children = (left, right)
-        self.ast_predicate = to_ast(predicate)
-        _, join_nulls, zlice, suffix, coalesce, maintain_order = self.options
+        predicate_wrapper = self.Predicate(predicate)
+        _, nulls_equal, zlice, suffix, coalesce, maintain_order = self.options
         # Preconditions from polars
-        assert not join_nulls
+        assert not nulls_equal
         assert not coalesce
         assert maintain_order == "none"
-        if self.ast_predicate is None:
+        if predicate_wrapper.ast is None:
             raise NotImplementedError(
                 f"Conditional join with predicate {predicate}"
             )  # pragma: no cover; polars never delivers expressions we can't handle
-        self._non_child_args = (self.ast_predicate, zlice, suffix, maintain_order)
+        self._non_child_args = (predicate_wrapper, zlice, suffix, maintain_order)
 
     @classmethod
     def do_evaluate(
         cls,
-        predicate: plc.expressions.Expression,
-        zlice: tuple[int, int] | None,
+        predicate_wrapper: Predicate,
+        zlice: Zlice | None,
         suffix: str,
         maintain_order: Literal["none", "left", "right", "left_right", "right_left"],
         left: DataFrame,
         right: DataFrame,
     ) -> DataFrame:
         """Evaluate and return a dataframe."""
-        lg, rg = plc.join.conditional_inner_join(left.table, right.table, predicate)
+        lg, rg = plc.join.conditional_inner_join(
+            left.table,
+            right.table,
+            predicate_wrapper.ast,
+        )
         left = DataFrame.from_table(
             plc.copying.gather(
                 left.table, lg, plc.copying.OutOfBoundsPolicy.DONT_CHECK
@@ -1084,8 +1539,8 @@
 class Join(IR):
     """A join of two dataframes."""
 
-    __slots__ = ("left_on", "options", "right_on")
-    _non_child = ("schema", "left_on", "right_on", "options")
+    __slots__ = ("config_options", "left_on", "options", "right_on")
+    _non_child = ("schema", "left_on", "right_on", "options", "config_options")
     left_on: tuple[expr.NamedExpr, ...]
     """List of expressions used as keys in the left frame."""
     right_on: tuple[expr.NamedExpr, ...]
@@ -1093,7 +1548,7 @@
     options: tuple[
         Literal["Inner", "Left", "Right", "Full", "Semi", "Anti", "Cross"],
         bool,
-        tuple[int, int] | None,
+        Zlice | None,
         str,
         bool,
         Literal["none", "left", "right", "left_right", "right_left"],
@@ -1101,12 +1556,14 @@
     """
     tuple of options:
       - how: join type
-      - join_nulls: do nulls compare equal?
+      - nulls_equal: do nulls compare equal?
       - slice: optional slice to perform after joining.
       - suffix: string suffix for right columns if names match
      - coalesce: should key columns be coalesced (only makes sense for outer joins)
      - maintain_order: which DataFrame row order to preserve, if any
     """
+    config_options: ConfigOptions
+    """GPU-specific configuration options"""
 
     def __init__(
         self,
@@ -1114,6 +1571,7 @@
         left_on: Sequence[expr.NamedExpr],
         right_on: Sequence[expr.NamedExpr],
         options: Any,
+        config_options: ConfigOptions,
         left: IR,
         right: IR,
     ):
@@ -1121,6 +1579,7 @@
         self.left_on = tuple(left_on)
         self.right_on = tuple(right_on)
         self.options = options
+        self.config_options = config_options
         self.children = (left, right)
         self._non_child_args = (self.left_on, self.right_on, self.options)
         # TODO: Implement maintain_order
@@ -1203,9 +1662,8 @@
         left keys, and is stable wrt the right keys. For all other
         joins, there is no order obligation.
         """
-        dt = plc.interop.to_arrow(plc.types.SIZE_TYPE)
-        init = plc.interop.from_arrow(pa.scalar(0, type=dt))
-        step = plc.interop.from_arrow(pa.scalar(1, type=dt))
+        init = plc.Scalar.from_py(0, plc.types.SIZE_TYPE)
+        step = plc.Scalar.from_py(1, plc.types.SIZE_TYPE)
         left_order = plc.copying.gather(
             plc.Table([plc.filling.sequence(left_rows, init, step)]), lg, left_policy
         )
@@ -1227,7 +1685,7 @@
     options: tuple[
         Literal["Inner", "Left", "Right", "Full", "Semi", "Anti", "Cross"],
         bool,
-        tuple[int, int] | None,
+        Zlice | None,
         str,
         bool,
         Literal["none", "left", "right", "left_right", "right_left"],
@@ -1236,7 +1694,7 @@
         right: DataFrame,
     ) -> DataFrame:
         """Evaluate and return a dataframe."""
-        how, join_nulls, zlice, suffix, coalesce, _ = options
+        how, nulls_equal, zlice, suffix, coalesce, _ = options
         if how == "Cross":
             # Separate implementation, since cross_join returns the
             # result, not the gather maps
@@ -1264,7 +1722,7 @@
         right_on = DataFrame(broadcast(*(e.evaluate(right) for e in right_on_exprs)))
         null_equality = (
             plc.types.NullEquality.EQUAL
-            if join_nulls
+            if nulls_equal
             else plc.types.NullEquality.UNEQUAL
         )
         join_fn, left_policy, right_policy = cls._joiners(how)
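
Note: the renames in these hunks track polars' rename of the join option `join_nulls` to `nulls_equal`; when enabled, null keys compare equal and therefore match each other. A usage sketch (on older polars versions the keyword is still `join_nulls`):

    import polars as pl

    left = pl.LazyFrame({"k": [1, None], "v": ["a", "b"]})
    right = pl.LazyFrame({"k": [None], "w": ["c"]})
    out = (
        left.join(right, on="k", how="inner", nulls_equal=True)
        .collect(engine="gpu")  # the None keys match, yielding one row
    )
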
@@ -1385,7 +1843,7 @@ class Distinct(IR):
|
|
|
1385
1843
|
subset: frozenset[str] | None
|
|
1386
1844
|
"""Which columns should be used to define distinctness. If None,
|
|
1387
1845
|
then all columns are used."""
|
|
1388
|
-
zlice:
|
|
1846
|
+
zlice: Zlice | None
|
|
1389
1847
|
"""Optional slice to apply to the result."""
|
|
1390
1848
|
stable: bool
|
|
1391
1849
|
"""Should the result maintain ordering."""
|
|
@@ -1395,7 +1853,7 @@ class Distinct(IR):
|
|
|
1395
1853
|
schema: Schema,
|
|
1396
1854
|
keep: plc.stream_compaction.DuplicateKeepOption,
|
|
1397
1855
|
subset: frozenset[str] | None,
|
|
1398
|
-
zlice:
|
|
1856
|
+
zlice: Zlice | None,
|
|
1399
1857
|
stable: bool, # noqa: FBT001
|
|
1400
1858
|
df: IR,
|
|
1401
1859
|
):
|
|
@@ -1419,10 +1877,10 @@ class Distinct(IR):
         cls,
         keep: plc.stream_compaction.DuplicateKeepOption,
         subset: frozenset[str] | None,
-        zlice: tuple[int, int] | None,
+        zlice: Zlice | None,
         stable: bool,  # noqa: FBT001
         df: DataFrame,
-    ):
+    ) -> DataFrame:
         """Evaluate and return a dataframe."""
         if subset is None:
             indices = list(range(df.num_columns))
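Zlice is the typing alias now used everywhere for the optional (offset, length)-style slice attached to nodes such as Distinct, Sort and Union. A sketch of the kind of query that can reach Distinct with a non-None slice, assuming the optimizer keeps the slice attached to the distinct node rather than materialising the full result first:

import polars as pl

lf = pl.LazyFrame({"a": [3, 1, 3, 2, 1, 2]})

# unique() followed by slice() — the slice is applied after duplicates
# are dropped.
out = lf.unique(maintain_order=True).slice(1, 2)
print(out.collect(engine="gpu"))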
@@ -1475,7 +1933,7 @@ class Sort(IR):
     """Null sorting location for each sort key."""
     stable: bool
     """Should the sort be stable?"""
-    zlice: tuple[int, int] | None
+    zlice: Zlice | None
     """Optional slice to apply to the result."""

     def __init__(
@@ -1485,7 +1943,7 @@ class Sort(IR):
         order: Sequence[plc.types.Order],
         null_order: Sequence[plc.types.NullOrder],
         stable: bool,  # noqa: FBT001
-        zlice: tuple[int, int] | None,
+        zlice: Zlice | None,
         df: IR,
     ):
         self.schema = schema
@@ -1510,17 +1968,11 @@ class Sort(IR):
         order: Sequence[plc.types.Order],
         null_order: Sequence[plc.types.NullOrder],
         stable: bool,  # noqa: FBT001
-        zlice: tuple[int, int] | None,
+        zlice: Zlice | None,
         df: DataFrame,
     ) -> DataFrame:
         """Evaluate and return a dataframe."""
         sort_keys = broadcast(*(k.evaluate(df) for k in by), target_length=df.num_rows)
-        # TODO: More robust identification here.
-        keys_in_result = {
-            k.name: i
-            for i, k in enumerate(sort_keys)
-            if k.name in df.column_map and k.obj is df.column_map[k.name].obj
-        }
         do_sort = plc.sorting.stable_sort_by_key if stable else plc.sorting.sort_by_key
         table = do_sort(
             df.table,
@@ -1528,19 +1980,17 @@ class Sort(IR):
             list(order),
             list(null_order),
         )
-        columns: list[Column] = []
-        for name, c in zip(df.column_map, table.columns(), strict=True):
-            column = Column(c, name=name)
-            # If a sort key is in the result table, set the sortedness property
-            if name in keys_in_result:
-                i = keys_in_result[name]
-                column = column.set_sorted(
-                    is_sorted=plc.types.Sorted.YES,
-                    order=order[i],
-                    null_order=null_order[i],
-                )
-            columns.append(column)
-        return DataFrame(columns).slice(zlice)
+        result = DataFrame.from_table(table, df.column_names)
+        first_key = sort_keys[0]
+        name = by[0].name
+        first_key_in_result = (
+            name in df.column_map and first_key.obj is df.column_map[name].obj
+        )
+        if first_key_in_result:
+            result.column_map[name].set_sorted(
+                is_sorted=plc.types.Sorted.YES, order=order[0], null_order=null_order[0]
+            )
+        return result.slice(zlice)


 class Slice(IR):
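The rewritten Sort.do_evaluate wraps the sorted table straight into a DataFrame and only tags the first sort key with sortedness metadata. A minimal pylibcudf sketch of the underlying sort call it dispatches to:

import pyarrow as pa
import pylibcudf as plc

tbl = plc.interop.from_arrow(
    pa.table({"key": [3, 1, 2], "val": [30.0, 10.0, 20.0]})
)

# Sort the whole table by its first column, ascending with nulls last; this
# is the stable variant picked when stable=True.
ordered = plc.sorting.stable_sort_by_key(
    tbl,
    plc.Table([tbl.columns()[0]]),
    [plc.types.Order.ASCENDING],
    [plc.types.NullOrder.AFTER],
)
print(ordered.num_rows())  # 3 rows, now ordered by "key"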
@@ -1608,6 +2058,42 @@ class Projection(IR):
         return DataFrame(columns)


+class MergeSorted(IR):
+    """Merge sorted operation."""
+
+    __slots__ = ("key",)
+    _non_child = ("schema", "key")
+    key: str
+    """Key that is sorted."""
+
+    def __init__(self, schema: Schema, key: str, left: IR, right: IR):
+        assert isinstance(left, Sort)
+        assert isinstance(right, Sort)
+        assert left.order == right.order
+        assert len(left.schema.keys()) <= len(right.schema.keys())
+        self.schema = schema
+        self.key = key
+        self.children = (left, right)
+        self._non_child_args = (key,)
+
+    @classmethod
+    def do_evaluate(cls, key: str, *dfs: DataFrame) -> DataFrame:
+        """Evaluate and return a dataframe."""
+        left, right = dfs
+        right = right.discard_columns(right.column_names_set - left.column_names_set)
+        on_col_left = left.select_columns({key})[0]
+        on_col_right = right.select_columns({key})[0]
+        return DataFrame.from_table(
+            plc.merge.merge(
+                [right.table, left.table],
+                [left.column_names.index(key), right.column_names.index(key)],
+                [on_col_left.order, on_col_right.order],
+                [on_col_left.null_order, on_col_right.null_order],
+            ),
+            left.column_names,
+        )
+
+
 class MapFunction(IR):
     """Apply some function to a dataframe."""

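At the polars API level the new node corresponds to merge_sorted, which interleaves two frames that are already sorted on the key; the isinstance asserts above require both inputs to come from explicit sorts. A usage sketch:

import polars as pl

a = pl.LazyFrame({"t": [1, 3, 5], "x": ["a", "c", "e"]}).sort("t")
b = pl.LazyFrame({"t": [2, 4, 6], "x": ["b", "d", "f"]}).sort("t")

# Both inputs are sorted on "t", so the merged result is also sorted on "t".
print(a.merge_sorted(b, key="t").collect(engine="gpu"))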
@@ -1621,13 +2107,10 @@ class MapFunction(IR):
     _NAMES: ClassVar[frozenset[str]] = frozenset(
         [
             "rechunk",
-            # libcudf merge is not stable wrt order of inputs, since
-            # it uses a priority queue to manage the tables it produces.
-            # See: https://github.com/rapidsai/cudf/issues/16010
-            # "merge_sorted",
             "rename",
             "explode",
             "unpivot",
+            "row_index",
         ]
     )

@@ -1636,8 +2119,12 @@ class MapFunction(IR):
         self.name = name
         self.options = options
         self.children = (df,)
-        if self.name not in MapFunction._NAMES:  # pragma: no cover; need more polars rust functions
-            raise NotImplementedError(f"Unhandled map function {self.name}")
+        if (
+            self.name not in MapFunction._NAMES
+        ):  # pragma: no cover; need more polars rust functions
+            raise NotImplementedError(
+                f"Unhandled map function {self.name}"
+            )  # pragma: no cover
         if self.name == "explode":
             (to_explode,) = self.options
             if len(to_explode) > 1:
@@ -1674,6 +2161,9 @@ class MapFunction(IR):
                 variable_name,
                 value_name,
             )
+        elif self.name == "row_index":
+            col_name, offset = options
+            self.options = (col_name, offset)
         self._non_child_args = (schema, name, self.options)

     @classmethod
@@ -1739,6 +2229,19 @@ class MapFunction(IR):
                     Column(value_column, name=value_name),
                 ]
             )
+        elif name == "row_index":
+            col_name, offset = options
+            dtype = schema[col_name]
+            step = plc.Scalar.from_py(1, dtype)
+            init = plc.Scalar.from_py(offset, dtype)
+            index_col = Column(
+                plc.filling.sequence(df.num_rows, init, step),
+                is_sorted=plc.types.Sorted.YES,
+                order=plc.types.Order.ASCENDING,
+                null_order=plc.types.NullOrder.AFTER,
+                name=col_name,
+            )
+            return DataFrame([index_col, *df.columns])
         else:
             raise AssertionError("Should never be reached")  # pragma: no cover

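The new row_index branch is the engine-side counterpart of polars' with_row_index: an ascending sequence column built with plc.filling.sequence, starting at the requested offset in the dtype recorded in the schema, and pre-marked as sorted. A small usage sketch:

import polars as pl

lf = pl.LazyFrame({"a": ["x", "y", "z"]})

# "idx" is materialised as a sequence column starting at the offset,
# as in the row_index branch above.
print(lf.with_row_index("idx", offset=10).collect(engine="gpu"))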
@@ -1748,10 +2251,10 @@ class Union(IR):

     __slots__ = ("zlice",)
     _non_child = ("schema", "zlice")
-    zlice: tuple[int, int] | None
+    zlice: Zlice | None
     """Optional slice to apply to the result."""

-    def __init__(self, schema: Schema, zlice: tuple[int, int] | None, *children: IR):
+    def __init__(self, schema: Schema, zlice: Zlice | None, *children: IR):
         self.schema = schema
         self.zlice = zlice
         self._non_child_args = (zlice,)
@@ -1759,7 +2262,7 @@ class Union(IR):
         schema = self.children[0].schema

     @classmethod
-    def do_evaluate(cls, zlice: tuple[int, int] | None, *dfs: DataFrame) -> DataFrame:
+    def do_evaluate(cls, zlice: Zlice | None, *dfs: DataFrame) -> DataFrame:
         """Evaluate and return a dataframe."""
         # TODO: only evaluate what we need if we have a slice?
         return DataFrame.from_table(
@@ -1771,12 +2274,18 @@ class Union(IR):
 class HConcat(IR):
     """Concatenate dataframes horizontally."""

-    __slots__ = ()
-    _non_child = ("schema",)
+    __slots__ = ("should_broadcast",)
+    _non_child = ("schema", "should_broadcast")

-    def __init__(self, schema: Schema, *children: IR):
+    def __init__(
+        self,
+        schema: Schema,
+        should_broadcast: bool,  # noqa: FBT001
+        *children: IR,
+    ):
         self.schema = schema
-        self._non_child_args = ()
+        self.should_broadcast = should_broadcast
+        self._non_child_args = (should_broadcast,)
         self.children = children

     @staticmethod
@@ -1808,8 +2317,19 @@ class HConcat(IR):
         )

     @classmethod
-    def do_evaluate(cls, *dfs: DataFrame) -> DataFrame:
+    def do_evaluate(
+        cls,
+        should_broadcast: bool,  # noqa: FBT001
+        *dfs: DataFrame,
+    ) -> DataFrame:
         """Evaluate and return a dataframe."""
+        # Special should_broadcast case.
+        # Used to recombine decomposed expressions
+        if should_broadcast:
+            return DataFrame(
+                broadcast(*itertools.chain.from_iterable(df.columns for df in dfs))
+            )
+
         max_rows = max(df.num_rows for df in dfs)
         # Horizontal concatenation extends shorter tables with nulls
         return DataFrame(
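The should_broadcast path relies on the usual broadcast rule for combining columns of different lengths when decomposed expressions are recombined. A hypothetical pure-Python sketch of that rule (broadcast_lengths is illustrative only, not part of the package):

# Illustration only: length-1 columns may be repeated up to the common
# length; any other length mismatch is an error.
def broadcast_lengths(lengths: list[int]) -> int:
    target = max(lengths)
    if any(n not in (1, target) for n in lengths):
        raise ValueError(f"cannot broadcast lengths {lengths} to {target}")
    return target


print(broadcast_lengths([1, 4, 4]))  # -> 4
print(broadcast_lengths([1, 1]))     # -> 1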
@@ -1826,3 +2346,20 @@ class HConcat(IR):
                 )
             )
         )
+
+
+class Empty(IR):
+    """Represents an empty DataFrame."""
+
+    __slots__ = ()
+    _non_child = ()
+
+    def __init__(self) -> None:
+        self.schema = {}
+        self._non_child_args = ()
+        self.children = ()
+
+    @classmethod
+    def do_evaluate(cls) -> DataFrame:  # pragma: no cover
+        """Evaluate and return a dataframe."""
+        return DataFrame([])