cudf-polars-cu13 25.10.0__py3-none-any.whl → 25.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cudf_polars/GIT_COMMIT +1 -1
- cudf_polars/VERSION +1 -1
- cudf_polars/callback.py +32 -8
- cudf_polars/containers/column.py +94 -59
- cudf_polars/containers/dataframe.py +123 -34
- cudf_polars/containers/datatype.py +134 -13
- cudf_polars/dsl/expr.py +0 -2
- cudf_polars/dsl/expressions/aggregation.py +80 -28
- cudf_polars/dsl/expressions/binaryop.py +34 -14
- cudf_polars/dsl/expressions/boolean.py +110 -37
- cudf_polars/dsl/expressions/datetime.py +59 -30
- cudf_polars/dsl/expressions/literal.py +11 -5
- cudf_polars/dsl/expressions/rolling.py +460 -119
- cudf_polars/dsl/expressions/selection.py +9 -8
- cudf_polars/dsl/expressions/slicing.py +1 -1
- cudf_polars/dsl/expressions/string.py +235 -102
- cudf_polars/dsl/expressions/struct.py +19 -7
- cudf_polars/dsl/expressions/ternary.py +9 -3
- cudf_polars/dsl/expressions/unary.py +117 -58
- cudf_polars/dsl/ir.py +923 -290
- cudf_polars/dsl/to_ast.py +30 -13
- cudf_polars/dsl/tracing.py +194 -0
- cudf_polars/dsl/translate.py +294 -97
- cudf_polars/dsl/utils/aggregations.py +34 -26
- cudf_polars/dsl/utils/reshape.py +14 -2
- cudf_polars/dsl/utils/rolling.py +12 -8
- cudf_polars/dsl/utils/windows.py +35 -20
- cudf_polars/experimental/base.py +45 -2
- cudf_polars/experimental/benchmarks/pdsds.py +12 -126
- cudf_polars/experimental/benchmarks/pdsh.py +791 -1
- cudf_polars/experimental/benchmarks/utils.py +515 -39
- cudf_polars/experimental/dask_registers.py +47 -20
- cudf_polars/experimental/dispatch.py +9 -3
- cudf_polars/experimental/explain.py +15 -2
- cudf_polars/experimental/expressions.py +22 -10
- cudf_polars/experimental/groupby.py +23 -4
- cudf_polars/experimental/io.py +93 -83
- cudf_polars/experimental/join.py +39 -22
- cudf_polars/experimental/parallel.py +60 -14
- cudf_polars/experimental/rapidsmpf/__init__.py +8 -0
- cudf_polars/experimental/rapidsmpf/core.py +361 -0
- cudf_polars/experimental/rapidsmpf/dispatch.py +150 -0
- cudf_polars/experimental/rapidsmpf/io.py +604 -0
- cudf_polars/experimental/rapidsmpf/join.py +237 -0
- cudf_polars/experimental/rapidsmpf/lower.py +74 -0
- cudf_polars/experimental/rapidsmpf/nodes.py +494 -0
- cudf_polars/experimental/rapidsmpf/repartition.py +151 -0
- cudf_polars/experimental/rapidsmpf/shuffle.py +277 -0
- cudf_polars/experimental/rapidsmpf/union.py +96 -0
- cudf_polars/experimental/rapidsmpf/utils.py +162 -0
- cudf_polars/experimental/repartition.py +9 -2
- cudf_polars/experimental/select.py +177 -14
- cudf_polars/experimental/shuffle.py +28 -8
- cudf_polars/experimental/sort.py +92 -25
- cudf_polars/experimental/statistics.py +24 -5
- cudf_polars/experimental/utils.py +25 -7
- cudf_polars/testing/asserts.py +13 -8
- cudf_polars/testing/io.py +2 -1
- cudf_polars/testing/plugin.py +88 -15
- cudf_polars/typing/__init__.py +86 -32
- cudf_polars/utils/config.py +406 -58
- cudf_polars/utils/cuda_stream.py +70 -0
- cudf_polars/utils/versions.py +3 -2
- cudf_polars_cu13-25.12.0.dist-info/METADATA +182 -0
- cudf_polars_cu13-25.12.0.dist-info/RECORD +104 -0
- cudf_polars_cu13-25.10.0.dist-info/METADATA +0 -136
- cudf_polars_cu13-25.10.0.dist-info/RECORD +0 -92
- {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-25.12.0.dist-info}/WHEEL +0 -0
- {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-25.12.0.dist-info}/licenses/LICENSE +0 -0
- {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-25.12.0.dist-info}/top_level.txt +0 -0
cudf_polars/dsl/ir.py
CHANGED
@@ -17,27 +17,40 @@ import itertools
 import json
 import random
 import time
+from dataclasses import dataclass
 from functools import cache
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, ClassVar
+from typing import TYPE_CHECKING, Any, ClassVar, overload
 
 from typing_extensions import assert_never
 
 import polars as pl
 
 import pylibcudf as plc
+from pylibcudf import expressions as plc_expr
 
 import cudf_polars.dsl.expr as expr
 from cudf_polars.containers import Column, DataFrame, DataType
+from cudf_polars.containers.dataframe import NamedColumn
 from cudf_polars.dsl.expressions import rolling, unary
 from cudf_polars.dsl.expressions.base import ExecutionContext
 from cudf_polars.dsl.nodebase import Node
 from cudf_polars.dsl.to_ast import to_ast, to_parquet_filter
-from cudf_polars.dsl.tracing import nvtx_annotate_cudf_polars
+from cudf_polars.dsl.tracing import log_do_evaluate, nvtx_annotate_cudf_polars
 from cudf_polars.dsl.utils.reshape import broadcast
-from cudf_polars.dsl.utils.windows import range_window_bounds
+from cudf_polars.dsl.utils.windows import (
+    offsets_to_windows,
+    range_window_bounds,
+)
 from cudf_polars.utils import dtypes
-from cudf_polars.utils.
+from cudf_polars.utils.config import CUDAStreamPolicy
+from cudf_polars.utils.cuda_stream import (
+    get_cuda_stream,
+    get_joined_cuda_stream,
+    get_new_cuda_stream,
+    join_cuda_streams,
+)
+from cudf_polars.utils.versions import POLARS_VERSION_LT_131, POLARS_VERSION_LT_134
 
 if TYPE_CHECKING:
     from collections.abc import Callable, Hashable, Iterable, Sequence
@@ -45,14 +58,15 @@ if TYPE_CHECKING:
 
     from typing_extensions import Self
 
-    from polars
+    from polars import polars  # type: ignore[attr-defined]
+
+    from rmm.pylibrmm.stream import Stream
 
     from cudf_polars.containers.dataframe import NamedColumn
     from cudf_polars.typing import CSECache, ClosedInterval, Schema, Slice as Zlice
-    from cudf_polars.utils.config import ParquetOptions
+    from cudf_polars.utils.config import ConfigOptions, ParquetOptions
    from cudf_polars.utils.timer import Timer
 
-
 __all__ = [
     "IR",
     "Cache",
@@ -65,6 +79,7 @@ __all__ = [
     "GroupBy",
     "HConcat",
     "HStack",
+    "IRExecutionContext",
     "Join",
     "MapFunction",
     "MergeSorted",
@@ -81,6 +96,53 @@ __all__ = [
 ]
 
 
+@dataclass(frozen=True)
+class IRExecutionContext:
+    """
+    Runtime context for IR node execution.
+
+    This dataclass holds runtime information and configuration needed
+    during the evaluation of IR nodes.
+
+    Parameters
+    ----------
+    get_cuda_stream
+        A zero-argument callable that returns a CUDA stream.
+    """
+
+    get_cuda_stream: Callable[[], Stream]
+
+    @classmethod
+    def from_config_options(cls, config_options: ConfigOptions) -> IRExecutionContext:
+        """Create an IRExecutionContext from ConfigOptions."""
+        match config_options.cuda_stream_policy:
+            case CUDAStreamPolicy.DEFAULT:
+                return cls(get_cuda_stream=get_cuda_stream)
+            case CUDAStreamPolicy.NEW:
+                return cls(get_cuda_stream=get_new_cuda_stream)
+            case _:  # pragma: no cover
+                raise ValueError(
+                    f"Invalid CUDA stream policy: {config_options.cuda_stream_policy}"
+                )
+
+
+_BINOPS = {
+    plc.binaryop.BinaryOperator.EQUAL,
+    plc.binaryop.BinaryOperator.NOT_EQUAL,
+    plc.binaryop.BinaryOperator.LESS,
+    plc.binaryop.BinaryOperator.LESS_EQUAL,
+    plc.binaryop.BinaryOperator.GREATER,
+    plc.binaryop.BinaryOperator.GREATER_EQUAL,
+    # TODO: Handle other binary operations as needed
+}
+
+
+_DECIMAL_TYPES = {plc.TypeId.DECIMAL32, plc.TypeId.DECIMAL64, plc.TypeId.DECIMAL128}
+
+
+_FLOAT_TYPES = {plc.TypeId.FLOAT32, plc.TypeId.FLOAT64}
+
+
 class IR(Node["IR"]):
     """Abstract plan node, representing an unevaluated dataframe."""
 
@@ -134,7 +196,9 @@ class IR(Node["IR"]):
         translation phase should fail earlier.
         """
 
-    def evaluate(self, *, cache: CSECache, timer: Timer | None) -> DataFrame:
+    def evaluate(
+        self, *, cache: CSECache, timer: Timer | None, context: IRExecutionContext
+    ) -> DataFrame:
         """
         Evaluate the node (recursively) and return a dataframe.
 
@@ -146,6 +210,8 @@ class IR(Node["IR"]):
         timer
             If not None, a Timer object to record timings for the
             evaluation of the node.
+        context
+            The execution context for the node.
 
         Notes
         -----
@@ -164,16 +230,19 @@ class IR(Node["IR"]):
            If evaluation fails. Ideally this should not occur, since the
            translation phase should fail earlier.
        """
-        children = [child.evaluate(cache=cache, timer=timer) for child in self.children]
+        children = [
+            child.evaluate(cache=cache, timer=timer, context=context)
+            for child in self.children
+        ]
        if timer is not None:
            start = time.monotonic_ns()
-            result = self.do_evaluate(*self._non_child_args, *children)
+            result = self.do_evaluate(*self._non_child_args, *children, context=context)
            end = time.monotonic_ns()
            # TODO: Set better names on each class object.
            timer.store(start, end, type(self).__name__)
            return result
        else:
-            return self.do_evaluate(*self._non_child_args, *children)
+            return self.do_evaluate(*self._non_child_args, *children, context=context)
 
 
 class ErrorNode(IR):
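The hunks above thread a new `IRExecutionContext` through every `evaluate`/`do_evaluate` call. Below is a minimal sketch of the pattern, with toy `ExecContext`/`Node` classes and a string standing in for the CUDA stream (none of these stand-in names are from cudf-polars):

```python
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class ExecContext:
    # Zero-argument factory returning the stream to run kernels on.
    get_stream: Callable[[], object]


class Node:
    def __init__(self, *children: "Node") -> None:
        self.children = children

    def evaluate(self, *, context: ExecContext) -> object:
        # Recurse first, passing the same context down the plan ...
        inputs = [child.evaluate(context=context) for child in self.children]
        # ... then run this node's own work on a stream from the context.
        return self.do_evaluate(inputs, context=context)

    def do_evaluate(self, inputs: list, *, context: ExecContext) -> object:
        stream = context.get_stream()
        return (type(self).__name__, stream, inputs)


root = Node(Node(), Node(Node()))
print(root.evaluate(context=ExecContext(get_stream=lambda: "default-stream")))
```

Keeping the context a frozen dataclass holding a zero-argument factory lets each node acquire a stream at evaluation time without caring which policy (default or per-task) produced it.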
@@ -212,29 +281,93 @@ class PythonScan(IR):
        raise NotImplementedError("PythonScan not implemented")
 
 
+_DECIMAL_IDS = {plc.TypeId.DECIMAL32, plc.TypeId.DECIMAL64, plc.TypeId.DECIMAL128}
+
+_COMPARISON_BINOPS = {
+    plc.binaryop.BinaryOperator.EQUAL,
+    plc.binaryop.BinaryOperator.NOT_EQUAL,
+    plc.binaryop.BinaryOperator.LESS,
+    plc.binaryop.BinaryOperator.LESS_EQUAL,
+    plc.binaryop.BinaryOperator.GREATER,
+    plc.binaryop.BinaryOperator.GREATER_EQUAL,
+}
+
+
+def _parquet_physical_types(
+    schema: Schema, paths: list[str], columns: list[str] | None, stream: Stream
+) -> dict[str, plc.DataType]:
+    # TODO: Read the physical types as cudf::data_type's using
+    # read_parquet_metadata or another parquet API
+    options = plc.io.parquet.ParquetReaderOptions.builder(
+        plc.io.SourceInfo(paths)
+    ).build()
+    if columns is not None:
+        options.set_columns(columns)
+    options.set_num_rows(0)
+    df = plc.io.parquet.read_parquet(options, stream=stream)
+    return dict(zip(schema.keys(), [c.type() for c in df.tbl.columns()], strict=True))
+
+
+def _cast_literal_to_decimal(
+    side: expr.Expr, lit: expr.Literal, phys_type_map: dict[str, plc.DataType]
+) -> expr.Expr:
+    if isinstance(side, expr.Cast):
+        col = side.children[0]
+        assert isinstance(col, expr.Col)
+        name = col.name
+    else:
+        assert isinstance(side, expr.Col)
+        name = side.name
+    if (type_ := phys_type_map[name]).id() in _DECIMAL_IDS:
+        scale = abs(type_.scale())
+        return expr.Cast(side.dtype, expr.Cast(DataType(pl.Decimal(38, scale)), lit))
+    return lit
+
+
+def _cast_literals_to_physical_types(
+    node: expr.Expr, phys_type_map: dict[str, plc.DataType]
+) -> expr.Expr:
+    if isinstance(node, expr.BinOp):
+        left, right = node.children
+        left = _cast_literals_to_physical_types(left, phys_type_map)
+        right = _cast_literals_to_physical_types(right, phys_type_map)
+        if node.op in _COMPARISON_BINOPS:
+            if isinstance(left, (expr.Col, expr.Cast)) and isinstance(
+                right, expr.Literal
+            ):
+                right = _cast_literal_to_decimal(left, right, phys_type_map)
+            elif isinstance(right, (expr.Col, expr.Cast)) and isinstance(
+                left, expr.Literal
+            ):
+                left = _cast_literal_to_decimal(right, left, phys_type_map)
+
+        return node.reconstruct([left, right])
+    return node
+
+
 def _align_parquet_schema(df: DataFrame, schema: Schema) -> DataFrame:
     # TODO: Alternatively set the schema of the parquet reader to decimal128
-    plc_decimals_ids = {
-        plc.TypeId.DECIMAL32,
-        plc.TypeId.DECIMAL64,
-        plc.TypeId.DECIMAL128,
-    }
     cast_list = []
 
     for name, col in df.column_map.items():
         src = col.obj.type()
-        dst = schema[name].plc
+        dst = schema[name].plc_type
+
         if (
-            src.id() in plc_decimals_ids
-            and dst.id() in plc_decimals_ids
-            and ((src.id() != dst.id()) or (src.scale != dst.scale))
+            plc.traits.is_fixed_point(src)
+            and plc.traits.is_fixed_point(dst)
+            and ((src.id() != dst.id()) or (src.scale() != dst.scale()))
         ):
             cast_list.append(
-                Column(plc.unary.cast(col.obj, dst), name=name, dtype=schema[name])
+                Column(
+                    plc.unary.cast(col.obj, dst, stream=df.stream),
+                    name=name,
+                    dtype=schema[name],
+                )
             )
 
     if cast_list:
-        df = df.with_columns(cast_list)
+        df = df.with_columns(cast_list, stream=df.stream)
 
     return df
 
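The helpers above rewrite literals in pushed-down parquet predicates so comparisons happen against the file's physical decimal type. A toy model of the same tree rewrite, with simplified `Col`/`Lit`/`BinOp` classes (assumptions for illustration, not cudf-polars types):

```python
from dataclasses import dataclass


@dataclass
class Col:
    name: str


@dataclass
class Lit:
    value: float
    scale: int | None = None  # decimal scale attached by the rewrite


@dataclass
class BinOp:
    op: str
    left: object
    right: object


COMPARISONS = {"==", "!=", "<", "<=", ">", ">="}


def cast_literals(node: object, physical_scales: dict[str, int]) -> object:
    # Recurse into binary nodes; when a bare literal is compared against a
    # column, re-type the literal with that column's physical decimal scale.
    if isinstance(node, BinOp):
        left = cast_literals(node.left, physical_scales)
        right = cast_literals(node.right, physical_scales)
        if node.op in COMPARISONS:
            if isinstance(left, Col) and isinstance(right, Lit):
                right = Lit(right.value, physical_scales.get(left.name))
            elif isinstance(right, Col) and isinstance(left, Lit):
                left = Lit(left.value, physical_scales.get(right.name))
        return BinOp(node.op, left, right)
    return node


pred = BinOp("and", BinOp("<", Col("price"), Lit(9.99)), BinOp(">", Lit(0.0), Col("fee")))
print(cast_literals(pred, {"price": 2, "fee": 4}))
```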
@@ -460,13 +593,24 @@ class Scan(IR):
        Each path is repeated according to the number of rows read from it.
        """
        (filepaths,) = plc.filling.repeat(
-            plc.Table([plc.Column.from_arrow(pl.Series(values=map(str, paths)))]),
+            plc.Table(
+                [
+                    plc.Column.from_arrow(
+                        pl.Series(values=map(str, paths)),
+                        stream=df.stream,
+                    )
+                ]
+            ),
            plc.Column.from_arrow(
-                pl.Series(values=rows_per_path, dtype=pl.datatypes.Int32())
+                pl.Series(values=rows_per_path, dtype=pl.datatypes.Int32()),
+                stream=df.stream,
            ),
+            stream=df.stream,
        ).columns()
        dtype = DataType(pl.String())
-        return df.with_columns([Column(filepaths, name=name, dtype=dtype)])
+        return df.with_columns(
+            [Column(filepaths, name=name, dtype=dtype)], stream=df.stream
+        )
 
    def fast_count(self) -> int:  # pragma: no cover
        """Get the number of rows in a Parquet Scan."""
@@ -479,6 +623,7 @@ class Scan(IR):
        return max(total_rows, 0)
 
    @classmethod
+    @log_do_evaluate
    @nvtx_annotate_cudf_polars(message="Scan")
    def do_evaluate(
        cls,
@@ -493,8 +638,11 @@ class Scan(IR):
        include_file_paths: str | None,
        predicate: expr.NamedExpr | None,
        parquet_options: ParquetOptions,
+        *,
+        context: IRExecutionContext,
    ) -> DataFrame:
        """Evaluate and return a dataframe."""
+        stream = context.get_cuda_stream()
        if typ == "csv":
 
            def read_csv_header(
@@ -551,6 +699,7 @@ class Scan(IR):
                    plc.io.csv.CsvReaderOptions.builder(plc.io.SourceInfo([path]))
                    .nrows(n_rows)
                    .skiprows(skiprows + skip_rows)
+                    .skip_blank_lines(skip_blank_lines=False)
                    .lineterminator(str(eol))
                    .quotechar(str(quote))
                    .decimal(decimal)
@@ -567,13 +716,15 @@ class Scan(IR):
                    column_names = read_csv_header(path, str(sep))
                    options.set_names(column_names)
                options.set_header(header)
-                options.set_dtypes({name: dtype.plc for name, dtype in schema.items()})
+                options.set_dtypes(
+                    {name: dtype.plc_type for name, dtype in schema.items()}
+                )
                if usecols is not None:
                    options.set_use_cols_names([str(name) for name in usecols])
                options.set_na_values(null_values)
                if comment is not None:
                    options.set_comment(comment)
-                tbl_w_meta = plc.io.csv.read_csv(options)
+                tbl_w_meta = plc.io.csv.read_csv(options, stream=stream)
                pieces.append(tbl_w_meta)
                if include_file_paths is not None:
                    seen_paths.append(p)
@@ -589,9 +740,10 @@ class Scan(IR):
                strict=True,
            )
            df = DataFrame.from_table(
-                plc.concatenate.concatenate(list(tables)),
+                plc.concatenate.concatenate(list(tables), stream=stream),
                colnames,
                [schema[colname] for colname in colnames],
+                stream=stream,
            )
            if include_file_paths is not None:
                df = Scan.add_file_paths(
@@ -604,42 +756,50 @@ class Scan(IR):
            filters = None
            if predicate is not None and row_index is None:
                # Can't apply filters during read if we have a row index.
-                filters = to_parquet_filter(predicate.value)
-            options = plc.io.parquet.ParquetReaderOptions.builder(
+                filters = to_parquet_filter(
+                    _cast_literals_to_physical_types(
+                        predicate.value,
+                        _parquet_physical_types(
+                            schema, paths, with_columns or list(schema.keys()), stream
+                        ),
+                    ),
+                    stream=stream,
+                )
+            parquet_reader_options = plc.io.parquet.ParquetReaderOptions.builder(
                plc.io.SourceInfo(paths)
            ).build()
            if with_columns is not None:
-                options.set_columns(with_columns)
+                parquet_reader_options.set_columns(with_columns)
            if filters is not None:
-                options.set_filter(filters)
+                parquet_reader_options.set_filter(filters)
            if n_rows != -1:
-                options.set_num_rows(n_rows)
+                parquet_reader_options.set_num_rows(n_rows)
            if skip_rows != 0:
-                options.set_skip_rows(skip_rows)
+                parquet_reader_options.set_skip_rows(skip_rows)
            if parquet_options.chunked:
                reader = plc.io.parquet.ChunkedParquetReader(
-                    options,
+                    parquet_reader_options,
                    chunk_read_limit=parquet_options.chunk_read_limit,
                    pass_read_limit=parquet_options.pass_read_limit,
+                    stream=stream,
                )
                chunk = reader.read_chunk()
-                tbl = chunk.tbl
                # TODO: Nested column names
                names = chunk.column_names(include_children=False)
-                concatenated_columns = tbl.columns()
+                concatenated_columns = chunk.tbl.columns()
                while reader.has_next():
-                    chunk = reader.read_chunk()
-                    tbl = chunk.tbl
-                    for i in range(tbl.num_columns()):
+                    columns = reader.read_chunk().tbl.columns()
+                    # Discard columns while concatenating to reduce memory footprint.
+                    # Reverse order to avoid O(n^2) list popping cost.
+                    for i in range(len(concatenated_columns) - 1, -1, -1):
                        concatenated_columns[i] = plc.concatenate.concatenate(
-                            [concatenated_columns[i], tbl._columns[i]]
+                            [concatenated_columns[i], columns.pop()], stream=stream
                        )
-                        # Drop residual columns to save memory
-                        tbl._columns[i] = None
                df = DataFrame.from_table(
                    plc.Table(concatenated_columns),
                    names=names,
                    dtypes=[schema[name] for name in names],
+                    stream=stream,
                )
                df = _align_parquet_schema(df, schema)
                if include_file_paths is not None:
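The chunked-reader loop above concatenates each new chunk into the accumulated columns while walking the column list backwards and popping. A plain-list sketch of why that order matters:

```python
# Lists stand in for device columns. pop() from the end is O(1), so each
# chunk column is released as soon as it has been concatenated, and there is
# no O(n^2) cost from repeatedly popping the front of the list.
def concat(a: list, b: list) -> list:
    return a + b


concatenated = [[1], [2], [3]]      # columns accumulated so far
chunk_columns = [[10], [20], [30]]  # columns of the newly read chunk

for i in range(len(concatenated) - 1, -1, -1):
    concatenated[i] = concat(concatenated[i], chunk_columns.pop())

print(concatenated)   # [[1, 10], [2, 20], [3, 30]]
print(chunk_columns)  # [] -- every chunk column released along the way
```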
@@ -647,13 +807,16 @@ class Scan(IR):
                    include_file_paths, paths, chunk.num_rows_per_source, df
                )
            else:
-                tbl_w_meta = plc.io.parquet.read_parquet(options)
+                tbl_w_meta = plc.io.parquet.read_parquet(
+                    parquet_reader_options, stream=stream
+                )
                # TODO: consider nested column names?
                col_names = tbl_w_meta.column_names(include_children=False)
                df = DataFrame.from_table(
                    tbl_w_meta.tbl,
                    col_names,
                    [schema[name] for name in col_names],
+                    stream=stream,
                )
                df = _align_parquet_schema(df, schema)
                if include_file_paths is not None:
@@ -665,16 +828,16 @@ class Scan(IR):
            return df
        elif typ == "ndjson":
            json_schema: list[plc.io.json.NameAndType] = [
-                (name, typ.plc, []) for name, typ in schema.items()
+                (name, typ.plc_type, []) for name, typ in schema.items()
            ]
-            plc_tbl_w_meta = plc.io.json.read_json(
-                plc.io.json.JsonReaderOptions.builder(plc.io.SourceInfo(paths))
-                .lines(val=True)
-                .dtypes(json_schema)
-                .prune_columns(val=True)
-
-                .build()
+            json_reader_options = (
+                plc.io.json.JsonReaderOptions.builder(plc.io.SourceInfo(paths))
+                .lines(val=True)
+                .dtypes(json_schema)
+                .prune_columns(val=True)
+                .build()
            )
+            plc_tbl_w_meta = plc.io.json.read_json(json_reader_options, stream=stream)
            # TODO: I don't think cudf-polars supports nested types in general right now
            # (but when it does, we should pass child column names from nested columns in)
            col_names = plc_tbl_w_meta.column_names(include_children=False)
@@ -682,6 +845,7 @@ class Scan(IR):
                plc_tbl_w_meta.tbl,
                col_names,
                [schema[name] for name in col_names],
+                stream=stream,
            )
            col_order = list(schema.keys())
        if row_index is not None:
@@ -695,26 +859,28 @@ class Scan(IR):
            name, offset = row_index
            offset += skip_rows
            dtype = schema[name]
-            step = plc.Scalar.from_py(1, dtype.plc)
-            init = plc.Scalar.from_py(offset, dtype.plc)
+            step = plc.Scalar.from_py(1, dtype.plc_type, stream=stream)
+            init = plc.Scalar.from_py(offset, dtype.plc_type, stream=stream)
            index_col = Column(
-                plc.filling.sequence(df.num_rows, init, step),
+                plc.filling.sequence(df.num_rows, init, step, stream=stream),
                is_sorted=plc.types.Sorted.YES,
                order=plc.types.Order.ASCENDING,
                null_order=plc.types.NullOrder.AFTER,
                name=name,
                dtype=dtype,
            )
-            df = DataFrame([index_col, *df.columns])
+            df = DataFrame([index_col, *df.columns], stream=df.stream)
            if next(iter(schema)) != name:
                df = df.select(schema)
        assert all(
-            c.obj.type() == schema[name].plc for name, c in df.column_map.items()
+            c.obj.type() == schema[name].plc_type for name, c in df.column_map.items()
        )
        if predicate is None:
            return df
        else:
-            (mask,) = broadcast(predicate.evaluate(df), target_length=df.num_rows)
+            (mask,) = broadcast(
+                predicate.evaluate(df), target_length=df.num_rows, stream=df.stream
+            )
            return df.filter(mask)
 
 
@@ -775,7 +941,8 @@ class Sink(IR):
        child_schema = df.schema.values()
        if kind == "Csv":
            if not all(
-                plc.io.csv.is_supported_write_csv(dtype.plc) for dtype in child_schema
+                plc.io.csv.is_supported_write_csv(dtype.plc_type)
+                for dtype in child_schema
            ):
                # Nested types are unsupported in polars and libcudf
                raise NotImplementedError(
@@ -826,7 +993,8 @@ class Sink(IR):
            kind == "Json"
        ):  # pragma: no cover; options are validated on the polars side
            if not all(
-                plc.io.json.is_supported_write_json(dtype.plc) for dtype in child_schema
+                plc.io.json.is_supported_write_json(dtype.plc_type)
+                for dtype in child_schema
            ):
                # Nested types are unsupported in polars and libcudf
                raise NotImplementedError(
@@ -863,7 +1031,7 @@ class Sink(IR):
    ) -> None:
        """Write CSV data to a sink."""
        serialize = options["serialize_options"]
-        writer_options = (
+        csv_writer_options = (
            plc.io.csv.CsvWriterOptions.builder(target, df.table)
            .include_header(options["include_header"])
            .names(df.column_names if options["include_header"] else [])
@@ -872,7 +1040,7 @@ class Sink(IR):
            .inter_column_delimiter(chr(serialize["separator"]))
            .build()
        )
-        plc.io.csv.write_csv(writer_options)
+        plc.io.csv.write_csv(csv_writer_options, stream=df.stream)
 
    @classmethod
    def _write_json(cls, target: plc.io.SinkInfo, df: DataFrame) -> None:
@@ -889,7 +1057,7 @@ class Sink(IR):
            .utf8_escaped(val=False)
            .build()
        )
-        plc.io.json.write_json(options)
+        plc.io.json.write_json(options, stream=df.stream)
 
    @staticmethod
    def _make_parquet_metadata(df: DataFrame) -> plc.io.types.TableInputMetadata:
@@ -899,6 +1067,20 @@ class Sink(IR):
            metadata.column_metadata[i].set_name(name)
        return metadata
 
+    @overload
+    @staticmethod
+    def _apply_parquet_writer_options(
+        builder: plc.io.parquet.ChunkedParquetWriterOptionsBuilder,
+        options: dict[str, Any],
+    ) -> plc.io.parquet.ChunkedParquetWriterOptionsBuilder: ...
+
+    @overload
+    @staticmethod
+    def _apply_parquet_writer_options(
+        builder: plc.io.parquet.ParquetWriterOptionsBuilder,
+        options: dict[str, Any],
+    ) -> plc.io.parquet.ParquetWriterOptionsBuilder: ...
+
    @staticmethod
    def _apply_parquet_writer_options(
        builder: plc.io.parquet.ChunkedParquetWriterOptionsBuilder
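The two `@overload` declarations added above let the type checker track that `_apply_parquet_writer_options` returns the same builder type it was given, while a single runtime implementation serves both. A self-contained sketch of the idiom, with stand-in builder classes:

```python
from typing import Any, overload


class ChunkedBuilder:
    def option(self, **kwargs: Any) -> "ChunkedBuilder":
        return self


class PlainBuilder:
    def option(self, **kwargs: Any) -> "PlainBuilder":
        return self


@overload
def apply_options(builder: ChunkedBuilder, options: dict[str, Any]) -> ChunkedBuilder: ...
@overload
def apply_options(builder: PlainBuilder, options: dict[str, Any]) -> PlainBuilder: ...


def apply_options(builder, options):
    # Single runtime implementation; the overloads above are type-only.
    return builder.option(**options)


chunked = apply_options(ChunkedBuilder(), {"compression": "snappy"})  # ChunkedBuilder
plain = apply_options(PlainBuilder(), {"compression": "snappy"})      # PlainBuilder
```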
@@ -944,12 +1126,16 @@ class Sink(IR):
            and parquet_options.n_output_chunks != 1
            and df.table.num_rows() != 0
        ):
-            builder = plc.io.parquet.ChunkedParquetWriterOptions.builder(
+            chunked_builder = plc.io.parquet.ChunkedParquetWriterOptions.builder(
                target
            ).metadata(metadata)
-            builder = cls._apply_parquet_writer_options(builder, options)
-            writer_options = builder.build()
-            writer = plc.io.parquet.ChunkedParquetWriter.from_options(writer_options)
+            chunked_builder = cls._apply_parquet_writer_options(
+                chunked_builder, options
+            )
+            chunked_writer_options = chunked_builder.build()
+            writer = plc.io.parquet.ChunkedParquetWriter.from_options(
+                chunked_writer_options, stream=df.stream
+            )
 
            # TODO: Can be based on a heuristic that estimates chunk size
            # from the input table size and available GPU memory.
@@ -957,6 +1143,7 @@ class Sink(IR):
            table_chunks = plc.copying.split(
                df.table,
                [i * df.table.num_rows() // num_chunks for i in range(1, num_chunks)],
+                stream=df.stream,
            )
            for chunk in table_chunks:
                writer.write(chunk)
@@ -968,9 +1155,10 @@ class Sink(IR):
            ).metadata(metadata)
            builder = cls._apply_parquet_writer_options(builder, options)
            writer_options = builder.build()
-            plc.io.parquet.write_parquet(writer_options)
+            plc.io.parquet.write_parquet(writer_options, stream=df.stream)
 
    @classmethod
+    @log_do_evaluate
    @nvtx_annotate_cudf_polars(message="Sink")
    def do_evaluate(
        cls,
@@ -980,6 +1168,8 @@ class Sink(IR):
        parquet_options: ParquetOptions,
        options: dict[str, Any],
        df: DataFrame,
+        *,
+        context: IRExecutionContext,
    ) -> DataFrame:
        """Write the dataframe to a file."""
        target = plc.io.SinkInfo([path])
@@ -993,7 +1183,7 @@ class Sink(IR):
        elif kind == "Json":
            cls._write_json(target, df)
 
-        return DataFrame([])
+        return DataFrame([], stream=df.stream)
 
 
 class Cache(IR):
@@ -1030,16 +1220,24 @@ class Cache(IR):
        return False
 
    @classmethod
+    @log_do_evaluate
    @nvtx_annotate_cudf_polars(message="Cache")
    def do_evaluate(
-        cls, key: int, refcount: int | None, df: DataFrame
+        cls,
+        key: int,
+        refcount: int | None,
+        df: DataFrame,
+        *,
+        context: IRExecutionContext,
    ) -> DataFrame:  # pragma: no cover; basic evaluation never calls this
        """Evaluate and return a dataframe."""
        # Our value has already been computed for us, so let's just
        # return it.
        return df
 
-    def evaluate(self, *, cache: CSECache, timer: Timer | None) -> DataFrame:
+    def evaluate(
+        self, *, cache: CSECache, timer: Timer | None, context: IRExecutionContext
+    ) -> DataFrame:
        """Evaluate and return a dataframe."""
        # We must override the recursion scheme because we don't want
        # to recurse if we're in the cache.
@@ -1047,7 +1245,7 @@ class Cache(IR):
            (result, hits) = cache[self.key]
        except KeyError:
            (value,) = self.children
-            result = value.evaluate(cache=cache, timer=timer)
+            result = value.evaluate(cache=cache, timer=timer, context=context)
            cache[self.key] = (result, 0)
            return result
        else:
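`Cache` keeps its override of `evaluate` so a shared subplan is computed once per query and later references are cache hits. A simplified sketch of the lookup-or-compute protocol; the hit counting here is a loose stand-in for the real refcounting:

```python
from typing import Callable

cache: dict[int, tuple[object, int]] = {}


def evaluate_cached(key: int, compute: Callable[[], object]) -> object:
    # Only recurse into the child plan on a cache miss.
    try:
        result, hits = cache[key]
    except KeyError:
        result = compute()
        cache[key] = (result, 0)
        return result
    else:
        cache[key] = (result, hits + 1)
        return result


print(evaluate_cached(1, lambda: "expensive result"))  # computed once
print(evaluate_cached(1, lambda: "never called"))      # served from cache
```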
@@ -1110,19 +1308,22 @@ class DataFrameScan(IR):
        )
 
    @classmethod
+    @log_do_evaluate
    @nvtx_annotate_cudf_polars(message="DataFrameScan")
    def do_evaluate(
        cls,
        schema: Schema,
        df: Any,
        projection: tuple[str, ...] | None,
+        *,
+        context: IRExecutionContext,
    ) -> DataFrame:
        """Evaluate and return a dataframe."""
        if projection is not None:
            df = df.select(projection)
-        df = DataFrame.from_polars(df)
+        df = DataFrame.from_polars(df, stream=context.get_cuda_stream())
        assert all(
-            c.obj.type() == dtype.plc
+            c.obj.type() == dtype.plc_type
            for c, dtype in zip(df.columns, schema.values(), strict=True)
        )
        return df
@@ -1169,21 +1370,26 @@ class Select(IR):
        return False
 
    @classmethod
+    @log_do_evaluate
    @nvtx_annotate_cudf_polars(message="Select")
    def do_evaluate(
        cls,
        exprs: tuple[expr.NamedExpr, ...],
        should_broadcast: bool,  # noqa: FBT001
        df: DataFrame,
+        *,
+        context: IRExecutionContext,
    ) -> DataFrame:
        """Evaluate and return a dataframe."""
        # Handle any broadcasting
        columns = [e.evaluate(df) for e in exprs]
        if should_broadcast:
-            columns = broadcast(*columns)
-        return DataFrame(columns)
+            columns = broadcast(*columns, stream=df.stream)
+        return DataFrame(columns, stream=df.stream)
 
-    def evaluate(self, *, cache: CSECache, timer: Timer | None) -> DataFrame:
+    def evaluate(
+        self, *, cache: CSECache, timer: Timer | None, context: IRExecutionContext
+    ) -> DataFrame:
        """
        Evaluate the Select node with special handling for fast count queries.
 
@@ -1195,6 +1401,8 @@ class Select(IR):
        timer
            If not None, a Timer object to record timings for the
            evaluation of the node.
+        context
+            The execution context for the node.
 
        Returns
        -------
@@ -1214,21 +1422,23 @@ class Select(IR):
            and Select._is_len_expr(self.exprs)
            and self.children[0].typ == "parquet"
            and self.children[0].predicate is None
-        ):
-            scan = self.children[0]
-            effective_rows = scan.fast_count()
-            dtype = DataType(pl.UInt32())
+        ):  # pragma: no cover
+            stream = context.get_cuda_stream()
+            scan = self.children[0]
+            effective_rows = scan.fast_count()
+            dtype = DataType(pl.UInt32())
            col = Column(
                plc.Column.from_scalar(
-                    plc.Scalar.from_py(effective_rows, dtype.plc),
+                    plc.Scalar.from_py(effective_rows, dtype.plc_type, stream=stream),
                    1,
+                    stream=stream,
                ),
                name=self.exprs[0].name or "len",
                dtype=dtype,
-            )
-            return DataFrame([col])
+            )
+            return DataFrame([col], stream=stream)
 
-        return super().evaluate(cache=cache, timer=timer)
+        return super().evaluate(cache=cache, timer=timer, context=context)
 
 
 class Reduce(IR):
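For reference, the query shape this fast-count path targets looks roughly like the following; the file path is hypothetical, and GPU execution assumes the cudf-polars engine is installed:

```python
import polars as pl

# A bare length over a parquet scan, with no predicate and no row index:
# per the hunk above, the row count can then come from parquet metadata
# (Scan.fast_count) instead of a full read.
q = pl.scan_parquet("data.parquet").select(pl.len())
# result = q.collect(engine="gpu")
```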
@@ -1252,16 +1462,19 @@ class Reduce(IR):
        self._non_child_args = (self.exprs,)
 
    @classmethod
+    @log_do_evaluate
    @nvtx_annotate_cudf_polars(message="Reduce")
    def do_evaluate(
        cls,
        exprs: tuple[expr.NamedExpr, ...],
        df: DataFrame,
+        *,
+        context: IRExecutionContext,
    ) -> DataFrame:  # pragma: no cover; not exposed by polars yet
        """Evaluate and return a dataframe."""
-        columns = broadcast(*(e.evaluate(df) for e in exprs))
+        columns = broadcast(*(e.evaluate(df) for e in exprs), stream=df.stream)
        assert all(column.size == 1 for column in columns)
-        return DataFrame(columns)
+        return DataFrame(columns, stream=df.stream)
 
 
 class Rolling(IR):
@@ -1270,17 +1483,19 @@ class Rolling(IR):
    __slots__ = (
        "agg_requests",
        "closed_window",
-        "following",
+        "following_ordinal",
        "index",
+        "index_dtype",
        "keys",
-        "preceding",
+        "preceding_ordinal",
        "zlice",
    )
    _non_child = (
        "schema",
        "index",
-        "preceding",
-        "following",
+        "index_dtype",
+        "preceding_ordinal",
+        "following_ordinal",
        "closed_window",
        "keys",
        "agg_requests",
@@ -1288,10 +1503,12 @@ class Rolling(IR):
    )
    index: expr.NamedExpr
    """Column being rolled over."""
-    preceding: plc.Scalar
-    """Preceding window extent defining start of window."""
-    following: plc.Scalar
-    """Following window extent defining end of window."""
+    index_dtype: plc.DataType
+    """Datatype of the index column."""
+    preceding_ordinal: int
+    """Preceding window extent defining start of window as a host integer."""
+    following_ordinal: int
+    """Following window extent defining end of window as a host integer."""
    closed_window: ClosedInterval
    """Treatment of window endpoints."""
    keys: tuple[expr.NamedExpr, ...]
@@ -1305,8 +1522,9 @@ class Rolling(IR):
        self,
        schema: Schema,
        index: expr.NamedExpr,
-        preceding: plc.Scalar,
-        following: plc.Scalar,
+        index_dtype: plc.DataType,
+        preceding_ordinal: int,
+        following_ordinal: int,
        closed_window: ClosedInterval,
        keys: Sequence[expr.NamedExpr],
        agg_requests: Sequence[expr.NamedExpr],
@@ -1315,14 +1533,15 @@ class Rolling(IR):
    ):
        self.schema = schema
        self.index = index
-        self.preceding = preceding
-        self.following = following
+        self.index_dtype = index_dtype
+        self.preceding_ordinal = preceding_ordinal
+        self.following_ordinal = following_ordinal
        self.closed_window = closed_window
        self.keys = tuple(keys)
        self.agg_requests = tuple(agg_requests)
        if not all(
            plc.rolling.is_valid_rolling_aggregation(
-                agg.value.dtype.plc, agg.value.agg_request
+                agg.value.dtype.plc_type, agg.value.agg_request
            )
            for agg in self.agg_requests
        ):
@@ -1339,8 +1558,9 @@ class Rolling(IR):
        self.children = (df,)
        self._non_child_args = (
            index,
-            preceding,
-            following,
+            index_dtype,
+            preceding_ordinal,
+            following_ordinal,
            closed_window,
            keys,
            agg_requests,
@@ -1348,31 +1568,46 @@ class Rolling(IR):
        )
 
    @classmethod
+    @log_do_evaluate
    @nvtx_annotate_cudf_polars(message="Rolling")
    def do_evaluate(
        cls,
        index: expr.NamedExpr,
-        preceding: plc.Scalar,
-        following: plc.Scalar,
+        index_dtype: plc.DataType,
+        preceding_ordinal: int,
+        following_ordinal: int,
        closed_window: ClosedInterval,
        keys_in: Sequence[expr.NamedExpr],
        aggs: Sequence[expr.NamedExpr],
        zlice: Zlice | None,
        df: DataFrame,
+        *,
+        context: IRExecutionContext,
    ) -> DataFrame:
        """Evaluate and return a dataframe."""
-        keys = broadcast(*(k.evaluate(df) for k in keys_in), target_length=df.num_rows)
+        keys = broadcast(
+            *(k.evaluate(df) for k in keys_in),
+            target_length=df.num_rows,
+            stream=df.stream,
+        )
        orderby = index.evaluate(df)
        # Polars casts integral orderby to int64, but only for calculating window bounds
        if (
            plc.traits.is_integral(orderby.obj.type())
            and orderby.obj.type().id() != plc.TypeId.INT64
        ):
-            orderby_obj = plc.unary.cast(orderby.obj, plc.DataType(plc.TypeId.INT64))
+            orderby_obj = plc.unary.cast(
+                orderby.obj, plc.DataType(plc.TypeId.INT64), stream=df.stream
+            )
        else:
            orderby_obj = orderby.obj
+
+        preceding_scalar, following_scalar = offsets_to_windows(
+            index_dtype, preceding_ordinal, following_ordinal, stream=df.stream
+        )
+
        preceding_window, following_window = range_window_bounds(
-            preceding, following, closed_window
+            preceding_scalar, following_scalar, closed_window
        )
        if orderby.obj.null_count() != 0:
            raise RuntimeError(
@@ -1383,12 +1618,17 @@ class Rolling(IR):
            table = plc.Table([*(k.obj for k in keys), orderby_obj])
            n = table.num_columns()
            if not plc.sorting.is_sorted(
-                table, [plc.types.Order.ASCENDING] * n, [plc.types.NullOrder.BEFORE] * n
+                table,
+                [plc.types.Order.ASCENDING] * n,
+                [plc.types.NullOrder.BEFORE] * n,
+                stream=df.stream,
            ):
                raise RuntimeError("Input for grouped rolling is not sorted")
        else:
            if not orderby.check_sorted(
-                order=plc.types.Order.ASCENDING, null_order=plc.types.NullOrder.BEFORE
+                order=plc.types.Order.ASCENDING,
+                null_order=plc.types.NullOrder.BEFORE,
+                stream=df.stream,
            ):
                raise RuntimeError(
                    f"Index column '{index.name}' in rolling is not sorted, please sort first"
@@ -1401,6 +1641,7 @@ class Rolling(IR):
            preceding_window,
            following_window,
            [rolling.to_request(request.value, orderby, df) for request in aggs],
+            stream=df.stream,
        )
        return DataFrame(
            itertools.chain(
@@ -1410,7 +1651,8 @@ class Rolling(IR):
                Column(col, name=request.name, dtype=request.value.dtype)
                for col, request in zip(values.columns(), aggs, strict=True)
            ),
-            )
+            ),
+            stream=df.stream,
        ).slice(zlice)
 
 
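The Rolling refactor above replaces device-scalar window bounds with plain host integers (`preceding_ordinal`/`following_ordinal`) plus the index dtype, deferring scalar creation to evaluation time. A toy sketch of the shape of that change (simplified types, not the real classes):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RollingPlan:
    index_dtype: str        # stand-in for plc.DataType
    preceding_ordinal: int  # window start offset, host-side
    following_ordinal: int  # window end offset, host-side

    def evaluate(self) -> tuple:
        # Device scalars would be created here, on the evaluation stream,
        # in the spirit of offsets_to_windows(...). Host ints are cheap to
        # hash, compare, and serialize; device scalars are not.
        preceding = ("scalar", self.index_dtype, self.preceding_ordinal)
        following = ("scalar", self.index_dtype, self.following_ordinal)
        return preceding, following


plan = RollingPlan("int64", -3, 0)
print(plan.evaluate())
```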
@@ -1472,6 +1714,7 @@ class GroupBy(IR):
        )
 
    @classmethod
+    @log_do_evaluate
    @nvtx_annotate_cudf_polars(message="GroupBy")
    def do_evaluate(
        cls,
@@ -1481,9 +1724,15 @@ class GroupBy(IR):
        maintain_order: bool,  # noqa: FBT001
        zlice: Zlice | None,
        df: DataFrame,
+        *,
+        context: IRExecutionContext,
    ) -> DataFrame:
        """Evaluate and return a dataframe."""
-        keys = broadcast(*(k.evaluate(df) for k in keys_in), target_length=df.num_rows)
+        keys = broadcast(
+            *(k.evaluate(df) for k in keys_in),
+            target_length=df.num_rows,
+            stream=df.stream,
+        )
        sorted = (
            plc.types.Sorted.YES
            if all(k.is_sorted for k in keys)
@@ -1515,7 +1764,7 @@ class GroupBy(IR):
            col = value.evaluate(df, context=ExecutionContext.GROUPBY).obj
            requests.append(plc.groupby.GroupByRequest(col, [value.agg_request]))
            names.append(name)
-        group_keys, raw_tables = grouper.aggregate(requests)
+        group_keys, raw_tables = grouper.aggregate(requests, stream=df.stream)
        results = [
            Column(column, name=name, dtype=schema[name])
            for name, column, request in zip(
@@ -1529,7 +1778,7 @@ class GroupBy(IR):
            Column(grouped_key, name=key.name, dtype=key.dtype)
            for key, grouped_key in zip(keys, group_keys.columns(), strict=True)
        ]
-        broadcasted = broadcast(*result_keys, *results)
+        broadcasted = broadcast(*result_keys, *results, stream=df.stream)
        # Handle order preservation of groups
        if maintain_order and not sorted:
            # The order we want
@@ -1539,6 +1788,7 @@ class GroupBy(IR):
                plc.stream_compaction.DuplicateKeepOption.KEEP_FIRST,
                plc.types.NullEquality.EQUAL,
                plc.types.NanEquality.ALL_EQUAL,
+                stream=df.stream,
            )
            # The order we have
            have = plc.Table([key.obj for key in broadcasted[: len(keys)]])
@@ -1546,7 +1796,7 @@ class GroupBy(IR):
            # We know an inner join is OK because by construction
            # want and have are permutations of each other.
            left_order, right_order = plc.join.inner_join(
-                want, have, plc.types.NullEquality.EQUAL
+                want, have, plc.types.NullEquality.EQUAL, stream=df.stream
            )
            # Now left_order is an arbitrary permutation of the ordering we
            # want, and right_order is a matching permutation of the ordering
@@ -1559,11 +1809,13 @@ class GroupBy(IR):
                plc.Table([left_order]),
                [plc.types.Order.ASCENDING],
                [plc.types.NullOrder.AFTER],
+                stream=df.stream,
            ).columns()
            ordered_table = plc.copying.gather(
                plc.Table([col.obj for col in broadcasted]),
                right_order,
                plc.copying.OutOfBoundsPolicy.DONT_CHECK,
+                stream=df.stream,
            )
            broadcasted = [
                Column(reordered, name=old.name, dtype=old.dtype)
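The `maintain_order` handling above joins the key order we want against the order the groupby actually produced, recovering a permutation that restores first-appearance order. The same idea with plain Python lists:

```python
# `want` is the key order as it first appeared in the input; `have` is the
# (arbitrary) order the hash groupby produced; `results` is aligned with
# `have`. Matching the two orders yields the permutation to apply.
want = ["b", "a", "c"]
have = ["a", "c", "b"]
results = [10, 30, 20]

position_in_have = {key: i for i, key in enumerate(have)}
permutation = [position_in_have[key] for key in want]

ordered = [results[i] for i in permutation]
print(list(zip(want, ordered)))  # [('b', 20), ('a', 10), ('c', 30)]
```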
@@ -1571,7 +1823,126 @@ class GroupBy(IR):
                ordered_table.columns(), broadcasted, strict=True
            )
        ]
-        return DataFrame(broadcasted).slice(zlice)
+        return DataFrame(broadcasted, stream=df.stream).slice(zlice)
+
+
+def _strip_predicate_casts(node: expr.Expr) -> expr.Expr:
+    if isinstance(node, expr.Cast):
+        (child,) = node.children
+        child = _strip_predicate_casts(child)
+
+        src = child.dtype
+        dst = node.dtype
+
+        if plc.traits.is_fixed_point(src.plc_type) or plc.traits.is_fixed_point(
+            dst.plc_type
+        ):
+            return child
+
+        if (
+            not POLARS_VERSION_LT_134
+            and isinstance(child, expr.ColRef)
+            and (
+                (
+                    plc.traits.is_floating_point(src.plc_type)
+                    and plc.traits.is_floating_point(dst.plc_type)
+                )
+                or (
+                    plc.traits.is_integral(src.plc_type)
+                    and plc.traits.is_integral(dst.plc_type)
+                    and src.plc_type.id() == dst.plc_type.id()
+                )
+            )
+        ):
+            return child
+
+    if not node.children:
+        return node
+    return node.reconstruct([_strip_predicate_casts(child) for child in node.children])
+
+
+def _add_cast(
+    target: DataType,
+    side: expr.ColRef,
+    left_casts: dict[str, DataType],
+    right_casts: dict[str, DataType],
+) -> None:
+    (col,) = side.children
+    assert isinstance(col, expr.Col)
+    casts = (
+        left_casts if side.table_ref == plc_expr.TableReference.LEFT else right_casts
+    )
+    casts[col.name] = target
+
+
+def _align_decimal_binop_types(
+    left_expr: expr.ColRef,
+    right_expr: expr.ColRef,
+    left_casts: dict[str, DataType],
+    right_casts: dict[str, DataType],
+) -> None:
+    left_type, right_type = left_expr.dtype, right_expr.dtype
+
+    if plc.traits.is_fixed_point(left_type.plc_type) and plc.traits.is_fixed_point(
+        right_type.plc_type
+    ):
+        target = DataType.common_decimal_dtype(left_type, right_type)
+
+        if left_type.id() != target.id() or left_type.scale() != target.scale():
+            _add_cast(target, left_expr, left_casts, right_casts)
+
+        if right_type.id() != target.id() or right_type.scale() != target.scale():
+            _add_cast(target, right_expr, left_casts, right_casts)
+
+    elif (
+        plc.traits.is_fixed_point(left_type.plc_type)
+        and plc.traits.is_floating_point(right_type.plc_type)
+    ) or (
+        plc.traits.is_fixed_point(right_type.plc_type)
+        and plc.traits.is_floating_point(left_type.plc_type)
+    ):
+        is_decimal_left = plc.traits.is_fixed_point(left_type.plc_type)
+        decimal_expr, float_expr = (
+            (left_expr, right_expr) if is_decimal_left else (right_expr, left_expr)
+        )
+        _add_cast(decimal_expr.dtype, float_expr, left_casts, right_casts)
+
+
+def _collect_decimal_binop_casts(
+    predicate: expr.Expr,
+) -> tuple[dict[str, DataType], dict[str, DataType]]:
+    left_casts: dict[str, DataType] = {}
+    right_casts: dict[str, DataType] = {}
+
+    def _walk(node: expr.Expr) -> None:
+        if isinstance(node, expr.BinOp) and node.op in _BINOPS:
+            left_expr, right_expr = node.children
+            if isinstance(left_expr, expr.ColRef) and isinstance(
+                right_expr, expr.ColRef
+            ):
+                _align_decimal_binop_types(
+                    left_expr, right_expr, left_casts, right_casts
+                )
+        for child in node.children:
+            _walk(child)
+
+    _walk(predicate)
+    return left_casts, right_casts
+
+
+def _apply_casts(df: DataFrame, casts: dict[str, DataType]) -> DataFrame:
+    if not casts:
+        return df
+
+    columns = []
+    for col in df.columns:
+        target = casts.get(col.name)
+        if target is None:
+            columns.append(Column(col.obj, dtype=col.dtype, name=col.name))
+        else:
+            casted = col.astype(target, stream=df.stream)
+            columns.append(Column(casted.obj, dtype=casted.dtype, name=col.name))
+    return DataFrame(columns, stream=df.stream)
 
 
 class ConditionalJoin(IR):
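`_align_decimal_binop_types` casts both sides of a decimal comparison to a common type via `DataType.common_decimal_dtype`. The exact promotion rule is not shown in this diff; the sketch below uses a plausible max-scale/max-integer-digits rule purely for illustration:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Dec:
    precision: int
    scale: int


def common_decimal(a: Dec, b: Dec) -> Dec:
    # Pick a scale and precision wide enough to represent either operand
    # exactly, so the comparison loses no digits on either side.
    scale = max(a.scale, b.scale)
    integer_digits = max(a.precision - a.scale, b.precision - b.scale)
    return Dec(precision=integer_digits + scale, scale=scale)


print(common_decimal(Dec(9, 2), Dec(18, 6)))  # Dec(precision=18, scale=6)
```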
@@ -1585,7 +1956,14 @@ class ConditionalJoin(IR):
 
        def __init__(self, predicate: expr.Expr):
            self.predicate = predicate
-            self.ast = to_ast(predicate)
+            stream = get_cuda_stream()
+            ast_result = to_ast(predicate, stream=stream)
+            stream.synchronize()
+            if ast_result is None:
+                raise NotImplementedError(
+                    f"Conditional join with predicate {predicate}"
+                )  # pragma: no cover; polars never delivers expressions we can't handle
+            self.ast = ast_result
 
        def __reduce__(self) -> tuple[Any, ...]:
            """Pickle a Predicate object."""
@@ -1598,8 +1976,9 @@ class ConditionalJoin(IR):
    options: tuple[
        tuple[
            str,
-
-        ]
+            polars._expr_nodes.Operator | Iterable[polars._expr_nodes.Operator],
+        ]
+        | None,
        bool,
        Zlice | None,
        str,
@@ -1620,7 +1999,14 @@ class ConditionalJoin(IR):
        self, schema: Schema, predicate: expr.Expr, options: tuple, left: IR, right: IR
    ) -> None:
        self.schema = schema
+        predicate = _strip_predicate_casts(predicate)
        self.predicate = predicate
+        # options[0] is a tuple[str, Operator, ...]
+        # The Operator class can't be pickled, but we don't use it anyway so
+        # just throw that away
+        if options[0] is not None:
+            options = (None, *options[1:])
+
        self.options = options
        self.children = (left, right)
        predicate_wrapper = self.Predicate(predicate)
@@ -1629,51 +2015,70 @@ class ConditionalJoin(IR):
        assert not nulls_equal
        assert not coalesce
        assert maintain_order == "none"
-        if predicate_wrapper.ast is None:
-            raise NotImplementedError(
-                f"Conditional join with predicate {predicate}"
-            )  # pragma: no cover; polars never delivers expressions we can't handle
-        self._non_child_args = (predicate_wrapper, zlice, suffix, maintain_order)
+        self._non_child_args = (predicate_wrapper, options)
 
    @classmethod
+    @log_do_evaluate
    @nvtx_annotate_cudf_polars(message="ConditionalJoin")
    def do_evaluate(
        cls,
        predicate_wrapper: Predicate,
-        zlice: Zlice | None,
-        suffix: str,
-        maintain_order: Literal["none", "left", "right", "left_right", "right_left"],
+        options: tuple,
        left: DataFrame,
        right: DataFrame,
+        *,
+        context: IRExecutionContext,
    ) -> DataFrame:
        """Evaluate and return a dataframe."""
+        stream = get_joined_cuda_stream(
+            context.get_cuda_stream,
+            upstreams=(
+                left.stream,
+                right.stream,
+            ),
+        )
+        left_casts, right_casts = _collect_decimal_binop_casts(
+            predicate_wrapper.predicate
+        )
+        _, _, zlice, suffix, _, _ = options
+
        lg, rg = plc.join.conditional_inner_join(
-            left.table,
-            right.table,
+            _apply_casts(left, left_casts).table,
+            _apply_casts(right, right_casts).table,
            predicate_wrapper.ast,
+            stream=stream,
        )
-        left = DataFrame.from_table(
+        left_result = DataFrame.from_table(
            plc.copying.gather(
-                left.table, lg, plc.copying.OutOfBoundsPolicy.DONT_CHECK
+                left.table, lg, plc.copying.OutOfBoundsPolicy.DONT_CHECK, stream=stream
            ),
            left.column_names,
            left.dtypes,
+            stream=stream,
        )
-        right = DataFrame.from_table(
+        right_result = DataFrame.from_table(
            plc.copying.gather(
-                right.table, rg, plc.copying.OutOfBoundsPolicy.DONT_CHECK
+                right.table, rg, plc.copying.OutOfBoundsPolicy.DONT_CHECK, stream=stream
            ),
            right.column_names,
            right.dtypes,
+            stream=stream,
        )
-        right = right.rename_columns(
+        right_result = right_result.rename_columns(
            {
                name: f"{name}{suffix}"
                for name in right.column_names
                if name in left.column_names_set
            }
        )
-        result = left.with_columns(right.columns)
+        result = left_result.with_columns(right_result.columns, stream=stream)
+
+        # Join the original streams back into the result stream to ensure that the
+        # deallocations (on the original streams) happen after the result is ready
+        join_cuda_streams(
+            downstreams=(left.stream, right.stream), upstreams=(result.stream,)
+        )
+
        return result.slice(zlice)
 
 
@@ -1704,6 +2109,19 @@ class Join(IR):
        - maintain_order: which DataFrame row order to preserve, if any
    """
 
+    SWAPPED_ORDER: ClassVar[
+        dict[
+            Literal["none", "left", "right", "left_right", "right_left"],
+            Literal["none", "left", "right", "left_right", "right_left"],
+        ]
+    ] = {
+        "none": "none",
+        "left": "right",
+        "right": "left",
+        "left_right": "right_left",
+        "right_left": "left_right",
+    }
+
    def __init__(
        self,
        schema: Schema,
@@ -1719,9 +2137,6 @@ class Join(IR):
        self.options = options
        self.children = (left, right)
        self._non_child_args = (self.left_on, self.right_on, self.options)
-        # TODO: Implement maintain_order
-        if options[5] != "none":
-            raise NotImplementedError("maintain_order not implemented yet")
 
    @staticmethod
    @cache
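A sketch of how a table like `Join.SWAPPED_ORDER` gets used; the `plan_join` helper and the swap heuristic below are hypothetical, illustrating only the mirroring of `maintain_order` when a join's inputs are exchanged:

```python
SWAPPED_ORDER = {
    "none": "none",
    "left": "right",
    "right": "left",
    "left_right": "right_left",
    "right_left": "left_right",
}


def plan_join(left_rows: int, right_rows: int, maintain_order: str) -> tuple:
    # If an implementation swaps its inputs (say, to build the hash table on
    # the smaller side), the requested order must be mirrored to keep the
    # user-visible semantics unchanged.
    if right_rows < left_rows:
        return ("join", "swapped-inputs", SWAPPED_ORDER[maintain_order])
    return ("join", "as-given", maintain_order)


print(plan_join(1_000_000, 10, maintain_order="left"))  # mirrored to 'right'
```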
@@ -1770,6 +2185,9 @@ class Join(IR):
         right_rows: int,
         rg: plc.Column,
         right_policy: plc.copying.OutOfBoundsPolicy,
+        *,
+        left_primary: bool = True,
+        stream: Stream,
     ) -> list[plc.Column]:
         """
         Reorder gather maps to satisfy polars join order restrictions.
@@ -1788,30 +2206,70 @@
             Right gather map
         right_policy
             Nullify policy for right map
+        left_primary
+            Whether to preserve the left input row order first, and which
+            input stream to use for the primary sort.
+            Defaults to True.
+        stream
+            CUDA stream used for device memory operations and kernel launches.

         Returns
         -------
-        list
+        list[plc.Column]
+            Reordered left and right gather maps.

         Notes
         -----
-
-        left
-
+        When ``left_primary`` is True, the pair of gather maps is stably sorted by
+        the original row order of the left side, breaking ties by the right side.
+        And vice versa when ``left_primary`` is False.
         """
-        init = plc.Scalar.from_py(0, plc.types.SIZE_TYPE)
-        step = plc.Scalar.from_py(1, plc.types.SIZE_TYPE)
-
-
-
-
-
+        init = plc.Scalar.from_py(0, plc.types.SIZE_TYPE, stream=stream)
+        step = plc.Scalar.from_py(1, plc.types.SIZE_TYPE, stream=stream)
+
+        (left_order_col,) = plc.copying.gather(
+            plc.Table(
+                [
+                    plc.filling.sequence(
+                        left_rows,
+                        init,
+                        step,
+                        stream=stream,
+                    )
+                ]
+            ),
+            lg,
+            left_policy,
+            stream=stream,
+        ).columns()
+        (right_order_col,) = plc.copying.gather(
+            plc.Table(
+                [
+                    plc.filling.sequence(
+                        right_rows,
+                        init,
+                        step,
+                        stream=stream,
+                    )
+                ]
+            ),
+            rg,
+            right_policy,
+            stream=stream,
+        ).columns()
+
+        keys = (
+            plc.Table([left_order_col, right_order_col])
+            if left_primary
+            else plc.Table([right_order_col, left_order_col])
         )
+
         return plc.sorting.stable_sort_by_key(
             plc.Table([lg, rg]),
-
+            keys,
             [plc.types.Order.ASCENDING, plc.types.Order.ASCENDING],
             [plc.types.NullOrder.AFTER, plc.types.NullOrder.AFTER],
+            stream=stream,
         ).columns()

     @staticmethod
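`_reorder_maps` now reorders for either side: it gathers a `0..n-1` sequence through each gather map (out-of-bounds entries, produced for unmatched rows, become nulls), builds sort keys with the primary side first, and stably sorts both maps with nulls last. A pure-Python model of those semantics on plain lists (a sketch only; the real code runs on device columns via pylibcudf):

    def reorder_maps(left_rows, lg, right_rows, rg, *, left_primary=True):
        # Original row number for each gather-map entry; out-of-bounds
        # entries (unmatched rows) become None, like nulls on device.
        def order(rows, index):
            return index if 0 <= index < rows else None

        left_order = [order(left_rows, i) for i in lg]
        right_order = [order(right_rows, i) for i in rg]
        keys = (
            list(zip(left_order, right_order))
            if left_primary
            else list(zip(right_order, left_order))
        )

        def nulls_last(key):
            # Mirror NullOrder.AFTER: None sorts after every row number.
            return tuple((value is None, value or 0) for value in key)

        pairs = sorted(zip(keys, zip(lg, rg)), key=lambda item: nulls_last(item[0]))
        reordered = [pair for _, pair in pairs]
        return [i for i, _ in reordered], [j for _, j in reordered]

    # A full join emitted one unmatched right row (left index -1); preserving
    # the left order sorts matched rows by left row number, unmatched last.
    assert reorder_maps(3, [2, 0, -1], 2, [1, 0, 0]) == ([0, 2, -1], [0, 1, 0])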
@@ -1822,31 +2280,35 @@
         left: bool = True,
         empty: bool = False,
         rename: Callable[[str], str] = lambda name: name,
+        stream: Stream,
     ) -> list[Column]:
         if empty:
             return [
                 Column(
-                    plc.column_factories.make_empty_column(
+                    plc.column_factories.make_empty_column(
+                        col.dtype.plc_type, stream=stream
+                    ),
                     col.dtype,
                     name=rename(col.name),
                 )
                 for col in template
             ]

-
+        result = [
             Column(new, col.dtype, name=rename(col.name))
             for new, col in zip(columns, template, strict=True)
         ]

         if left:
-
+            result = [
                 col.sorted_like(orig)
-                for col, orig in zip(
+                for col, orig in zip(result, template, strict=True)
             ]

-        return
+        return result

     @classmethod
+    @log_do_evaluate
     @nvtx_annotate_cudf_polars(message="Join")
     def do_evaluate(
         cls,
@@ -1862,14 +2324,21 @@
         ],
         left: DataFrame,
         right: DataFrame,
+        *,
+        context: IRExecutionContext,
     ) -> DataFrame:
         """Evaluate and return a dataframe."""
-
+        stream = get_joined_cuda_stream(
+            context.get_cuda_stream, upstreams=(left.stream, right.stream)
+        )
+        how, nulls_equal, zlice, suffix, coalesce, maintain_order = options
         if how == "Cross":
             # Separate implementation, since cross_join returns the
             # result, not the gather maps
             if right.num_rows == 0:
-                left_cols = Join._build_columns(
+                left_cols = Join._build_columns(
+                    [], left.columns, empty=True, stream=stream
+                )
                 right_cols = Join._build_columns(
                     [],
                     right.columns,
@@ -1878,96 +2347,145 @@
                     rename=lambda name: name
                     if name not in left.column_names_set
                     else f"{name}{suffix}",
+                    stream=stream,
+                )
+                result = DataFrame([*left_cols, *right_cols], stream=stream)
+            else:
+                columns = plc.join.cross_join(
+                    left.table, right.table, stream=stream
+                ).columns()
+                left_cols = Join._build_columns(
+                    columns[: left.num_columns], left.columns, stream=stream
+                )
+                right_cols = Join._build_columns(
+                    columns[left.num_columns :],
+                    right.columns,
+                    rename=lambda name: name
+                    if name not in left.column_names_set
+                    else f"{name}{suffix}",
+                    left=False,
+                    stream=stream,
+                )
+                result = DataFrame([*left_cols, *right_cols], stream=stream).slice(
+                    zlice
                 )
-            return DataFrame([*left_cols, *right_cols])

-        columns = plc.join.cross_join(left.table, right.table).columns()
-        left_cols = Join._build_columns(
-            columns[: left.num_columns],
-            left.columns,
-        )
-        right_cols = Join._build_columns(
-            columns[left.num_columns :],
-            right.columns,
-            rename=lambda name: name
-            if name not in left.column_names_set
-            else f"{name}{suffix}",
-            left=False,
-        )
-        return DataFrame([*left_cols, *right_cols]).slice(zlice)
-        # TODO: Waiting on clarity based on https://github.com/pola-rs/polars/issues/17184
-        left_on = DataFrame(broadcast(*(e.evaluate(left) for e in left_on_exprs)))
-        right_on = DataFrame(broadcast(*(e.evaluate(right) for e in right_on_exprs)))
-        null_equality = (
-            plc.types.NullEquality.EQUAL
-            if nulls_equal
-            else plc.types.NullEquality.UNEQUAL
-        )
-        join_fn, left_policy, right_policy = cls._joiners(how)
-        if right_policy is None:
-            # Semi join
-            lg = join_fn(left_on.table, right_on.table, null_equality)
-            table = plc.copying.gather(left.table, lg, left_policy)
-            result = DataFrame.from_table(table, left.column_names, left.dtypes)
         else:
-
-
-
-
-
-            if how == "Left" or how == "Right":
-                # Order of left table is preserved
-                lg, rg = cls._reorder_maps(
-                    left.num_rows, lg, left_policy, right.num_rows, rg, right_policy
-                )
-            if coalesce:
-                if how == "Full":
-                    # In this case, keys must be column references,
-                    # possibly with dtype casting. We should use them in
-                    # preference to the columns from the original tables.
-                    left = left.with_columns(left_on.columns, replace_only=True)
-                    right = right.with_columns(right_on.columns, replace_only=True)
-                else:
-                    right = right.discard_columns(right_on.column_names_set)
-            left = DataFrame.from_table(
-                plc.copying.gather(left.table, lg, left_policy),
-                left.column_names,
-                left.dtypes,
+            # how != "Cross"
+            # TODO: Waiting on clarity based on https://github.com/pola-rs/polars/issues/17184
+            left_on = DataFrame(
+                broadcast(*(e.evaluate(left) for e in left_on_exprs), stream=stream),
+                stream=stream,
             )
-
-
-
-                right.dtypes,
+            right_on = DataFrame(
+                broadcast(*(e.evaluate(right) for e in right_on_exprs), stream=stream),
+                stream=stream,
             )
-
-
-
-
-
-
-
+            null_equality = (
+                plc.types.NullEquality.EQUAL
+                if nulls_equal
+                else plc.types.NullEquality.UNEQUAL
+            )
+            join_fn, left_policy, right_policy = cls._joiners(how)
+            if right_policy is None:
+                # Semi join
+                lg = join_fn(left_on.table, right_on.table, null_equality, stream)
+                table = plc.copying.gather(left.table, lg, left_policy, stream=stream)
+                result = DataFrame.from_table(
+                    table, left.column_names, left.dtypes, stream=stream
+                )
+            else:
+                if how == "Right":
+                    # Right join is a left join with the tables swapped
+                    left, right = right, left
+                    left_on, right_on = right_on, left_on
+                    maintain_order = Join.SWAPPED_ORDER[maintain_order]
+
+                lg, rg = join_fn(
+                    left_on.table, right_on.table, null_equality, stream=stream
+                )
+                if (
+                    how in ("Inner", "Left", "Right", "Full")
+                    and maintain_order != "none"
+                ):
+                    lg, rg = cls._reorder_maps(
+                        left.num_rows,
+                        lg,
+                        left_policy,
+                        right.num_rows,
+                        rg,
+                        right_policy,
+                        left_primary=maintain_order.startswith("left"),
+                        stream=stream,
+                    )
+                if coalesce:
+                    if how == "Full":
+                        # In this case, keys must be column references,
+                        # possibly with dtype casting. We should use them in
+                        # preference to the columns from the original tables.
+
+                        # We need to specify `stream` here. We know that `{left,right}_on`
+                        # is valid on `stream`, which is ordered after `{left,right}.stream`.
+                        left = left.with_columns(
+                            left_on.columns, replace_only=True, stream=stream
                        )
-
-
-                        right.select_columns(right_on.column_names_set),
-                        strict=True,
+                        right = right.with_columns(
+                            right_on.columns, replace_only=True, stream=stream
                        )
-
-
+                    else:
+                        right = right.discard_columns(right_on.column_names_set)
+                left = DataFrame.from_table(
+                    plc.copying.gather(left.table, lg, left_policy, stream=stream),
+                    left.column_names,
+                    left.dtypes,
+                    stream=stream,
                )
-        right =
-
-
-
-
-
-
-
-
-
-
-
-
+                right = DataFrame.from_table(
+                    plc.copying.gather(right.table, rg, right_policy, stream=stream),
+                    right.column_names,
+                    right.dtypes,
+                    stream=stream,
+                )
+                if coalesce and how == "Full":
+                    left = left.with_columns(
+                        (
+                            Column(
+                                plc.replace.replace_nulls(
+                                    left_col.obj, right_col.obj, stream=stream
+                                ),
+                                name=left_col.name,
+                                dtype=left_col.dtype,
+                            )
+                            for left_col, right_col in zip(
+                                left.select_columns(left_on.column_names_set),
+                                right.select_columns(right_on.column_names_set),
+                                strict=True,
+                            )
+                        ),
+                        replace_only=True,
+                        stream=stream,
+                    )
+                    right = right.discard_columns(right_on.column_names_set)
+                if how == "Right":
+                    # Undo the swap for right join before gluing together.
+                    left, right = right, left
+                right = right.rename_columns(
+                    {
+                        name: f"{name}{suffix}"
+                        for name in right.column_names
+                        if name in left.column_names_set
+                    }
+                )
+                result = left.with_columns(right.columns, stream=stream)
+            result = result.slice(zlice)
+
+        # Join the original streams back into the result stream to ensure that the
+        # deallocations (on the original streams) happen after the result is ready
+        join_cuda_streams(
+            downstreams=(left.stream, right.stream), upstreams=(result.stream,)
+        )
+
+        return result


 class HStack(IR):
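When `coalesce` is set on a full join, the gathered left key columns are backfilled from the gathered right key columns (the `plc.replace.replace_nulls` block above) and the right copies are then dropped. The per-row rule, restated on plain lists (a semantic sketch, not the device call):

    def coalesce_keys(left_key, right_key):
        # Keep the left value where present; fall back to the right value.
        return [
            left if left is not None else right
            for left, right in zip(left_key, right_key, strict=True)
        ]

    # Unmatched left rows keep their key; unmatched right rows supply theirs.
    assert coalesce_keys([1, None, 3], [None, 2, 30]) == [1, 2, 3]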
@@ -1992,18 +2510,23 @@ class HStack(IR):
         self.children = (df,)

     @classmethod
+    @log_do_evaluate
     @nvtx_annotate_cudf_polars(message="HStack")
     def do_evaluate(
         cls,
         exprs: Sequence[expr.NamedExpr],
         should_broadcast: bool,  # noqa: FBT001
         df: DataFrame,
+        *,
+        context: IRExecutionContext,
     ) -> DataFrame:
         """Evaluate and return a dataframe."""
         columns = [c.evaluate(df) for c in exprs]
         if should_broadcast:
             columns = broadcast(
-                *columns,
+                *columns,
+                target_length=df.num_rows if df.num_columns != 0 else None,
+                stream=df.stream,
             )
         else:
             # Polars ensures this is true, but let's make sure nothing
@@ -2014,7 +2537,7 @@ class HStack(IR):
             # never be turned into a pylibcudf Table with all columns
             # by the Select, which is why this is safe.
             assert all(e.name.startswith("__POLARS_CSER_0x") for e in exprs)
-        return df.with_columns(columns)
+        return df.with_columns(columns, stream=df.stream)


 class Distinct(IR):
@@ -2057,6 +2580,7 @@ class Distinct(IR):
     }

     @classmethod
+    @log_do_evaluate
     @nvtx_annotate_cudf_polars(message="Distinct")
     def do_evaluate(
         cls,
@@ -2065,6 +2589,8 @@ class Distinct(IR):
         zlice: Zlice | None,
         stable: bool,  # noqa: FBT001
         df: DataFrame,
+        *,
+        context: IRExecutionContext,
     ) -> DataFrame:
         """Evaluate and return a dataframe."""
         if subset is None:
@@ -2079,6 +2605,7 @@ class Distinct(IR):
                 indices,
                 keep,
                 plc.types.NullEquality.EQUAL,
+                stream=df.stream,
             )
         else:
             distinct = (
@@ -2092,13 +2619,15 @@ class Distinct(IR):
                 keep,
                 plc.types.NullEquality.EQUAL,
                 plc.types.NanEquality.ALL_EQUAL,
+                df.stream,
             )
         # TODO: Is this sortedness setting correct
         result = DataFrame(
             [
                 Column(new, name=old.name, dtype=old.dtype).sorted_like(old)
                 for new, old in zip(table.columns(), df.columns, strict=True)
-            ]
+            ],
+            stream=df.stream,
         )
         if keys_sorted or stable:
             result = result.sorted_like(df)
@@ -2147,6 +2676,7 @@ class Sort(IR):
         self.children = (df,)

     @classmethod
+    @log_do_evaluate
     @nvtx_annotate_cudf_polars(message="Sort")
     def do_evaluate(
         cls,
@@ -2156,17 +2686,24 @@ class Sort(IR):
         stable: bool,  # noqa: FBT001
         zlice: Zlice | None,
         df: DataFrame,
+        *,
+        context: IRExecutionContext,
     ) -> DataFrame:
         """Evaluate and return a dataframe."""
-        sort_keys = broadcast(
+        sort_keys = broadcast(
+            *(k.evaluate(df) for k in by), target_length=df.num_rows, stream=df.stream
+        )
         do_sort = plc.sorting.stable_sort_by_key if stable else plc.sorting.sort_by_key
         table = do_sort(
             df.table,
             plc.Table([k.obj for k in sort_keys]),
             list(order),
             list(null_order),
+            stream=df.stream,
+        )
+        result = DataFrame.from_table(
+            table, df.column_names, df.dtypes, stream=df.stream
         )
-        result = DataFrame.from_table(table, df.column_names, df.dtypes)
         first_key = sort_keys[0]
         name = by[0].name
         first_key_in_result = (
@@ -2197,8 +2734,11 @@ class Slice(IR):
         self.children = (df,)

     @classmethod
+    @log_do_evaluate
     @nvtx_annotate_cudf_polars(message="Slice")
-    def do_evaluate(
+    def do_evaluate(
+        cls, offset: int, length: int, df: DataFrame, *, context: IRExecutionContext
+    ) -> DataFrame:
         """Evaluate and return a dataframe."""
         return df.slice((offset, length))

@@ -2218,10 +2758,15 @@ class Filter(IR):
         self.children = (df,)

     @classmethod
+    @log_do_evaluate
     @nvtx_annotate_cudf_polars(message="Filter")
-    def do_evaluate(
+    def do_evaluate(
+        cls, mask_expr: expr.NamedExpr, df: DataFrame, *, context: IRExecutionContext
+    ) -> DataFrame:
         """Evaluate and return a dataframe."""
-        (mask,) = broadcast(
+        (mask,) = broadcast(
+            mask_expr.evaluate(df), target_length=df.num_rows, stream=df.stream
+        )
         return df.filter(mask)


@@ -2237,14 +2782,19 @@ class Projection(IR):
         self.children = (df,)

     @classmethod
+    @log_do_evaluate
     @nvtx_annotate_cudf_polars(message="Projection")
-    def do_evaluate(
+    def do_evaluate(
+        cls, schema: Schema, df: DataFrame, *, context: IRExecutionContext
+    ) -> DataFrame:
         """Evaluate and return a dataframe."""
         # This can reorder things.
         columns = broadcast(
-            *(df.column_map[name] for name in schema),
+            *(df.column_map[name] for name in schema),
+            target_length=df.num_rows,
+            stream=df.stream,
         )
-        return DataFrame(columns)
+        return DataFrame(columns, stream=df.stream)


 class MergeSorted(IR):
@@ -2270,24 +2820,40 @@ class MergeSorted(IR):
         self._non_child_args = (key,)

     @classmethod
+    @log_do_evaluate
     @nvtx_annotate_cudf_polars(message="MergeSorted")
-    def do_evaluate(
+    def do_evaluate(
+        cls, key: str, *dfs: DataFrame, context: IRExecutionContext
+    ) -> DataFrame:
         """Evaluate and return a dataframe."""
+        stream = get_joined_cuda_stream(
+            context.get_cuda_stream, upstreams=[df.stream for df in dfs]
+        )
         left, right = dfs
         right = right.discard_columns(right.column_names_set - left.column_names_set)
         on_col_left = left.select_columns({key})[0]
         on_col_right = right.select_columns({key})[0]
-
+        result = DataFrame.from_table(
             plc.merge.merge(
                 [right.table, left.table],
                 [left.column_names.index(key), right.column_names.index(key)],
                 [on_col_left.order, on_col_right.order],
                 [on_col_left.null_order, on_col_right.null_order],
+                stream=stream,
             ),
             left.column_names,
             left.dtypes,
+            stream=stream,
+        )
+
+        # Join the original streams back into the result stream to ensure that the
+        # deallocations (on the original streams) happen after the result is ready
+        join_cuda_streams(
+            downstreams=[df.stream for df in dfs], upstreams=(result.stream,)
         )

+        return result
+

 class MapFunction(IR):
     """Apply some function to a dataframe."""
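`MergeSorted` relies on `plc.merge.merge` producing a single sorted table from two inputs that are already sorted on the key, analogous to a stable k-way merge. The same contract expressed with the standard library (an analogy only, not the device code path):

    import heapq

    left_key, right_key = [1, 3, 5], [2, 3, 4]
    # heapq.merge assumes each input is pre-sorted, exactly as MergeSorted
    # assumes its child frames are sorted on `key`.
    assert list(heapq.merge(left_key, right_key)) == [1, 2, 3, 3, 4, 5]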
@@ -2347,7 +2913,7 @@ class MapFunction(IR):
         index = frozenset(indices)
         pivotees = [name for name in df.schema if name not in index]
         if not all(
-            dtypes.can_cast(df.schema[p].
+            dtypes.can_cast(df.schema[p].plc_type, self.schema[value_name].plc_type)
             for p in pivotees
         ):
             raise NotImplementedError(
@@ -2390,9 +2956,16 @@ class MapFunction(IR):
         )

     @classmethod
+    @log_do_evaluate
     @nvtx_annotate_cudf_polars(message="MapFunction")
     def do_evaluate(
-        cls,
+        cls,
+        schema: Schema,
+        name: str,
+        options: Any,
+        df: DataFrame,
+        *,
+        context: IRExecutionContext,
     ) -> DataFrame:
         """Evaluate and return a dataframe."""
         if name == "rechunk":
@@ -2409,7 +2982,10 @@ class MapFunction(IR):
             index = df.column_names.index(to_explode)
             subset = df.column_names_set - {to_explode}
             return DataFrame.from_table(
-                plc.lists.explode_outer(df.table, index
+                plc.lists.explode_outer(df.table, index, stream=df.stream),
+                df.column_names,
+                df.dtypes,
+                stream=df.stream,
             ).sorted_like(df, subset=subset)
         elif name == "unpivot":
             (
@@ -2423,7 +2999,7 @@ class MapFunction(IR):
             index_columns = [
                 Column(tiled, name=name, dtype=old.dtype)
                 for tiled, name, old in zip(
-                    plc.reshape.tile(selected.table, npiv).columns(),
+                    plc.reshape.tile(selected.table, npiv, stream=df.stream).columns(),
                     indices,
                     selected.columns,
                     strict=True,
@@ -2434,18 +3010,23 @@ class MapFunction(IR):
                    [
                        plc.Column.from_arrow(
                            pl.Series(
-                                values=pivotees, dtype=schema[variable_name].
-                            )
+                                values=pivotees, dtype=schema[variable_name].polars_type
+                            ),
+                            stream=df.stream,
                        )
                    ]
                ),
                df.num_rows,
+                stream=df.stream,
            ).columns()
            value_column = plc.concatenate.concatenate(
                [
-                    df.column_map[pivotee]
+                    df.column_map[pivotee]
+                    .astype(schema[value_name], stream=df.stream)
+                    .obj
                    for pivotee in pivotees
-                ]
+                ],
+                stream=df.stream,
            )
            return DataFrame(
                [
@@ -2454,22 +3035,23 @@ class MapFunction(IR):
                        variable_column, name=variable_name, dtype=schema[variable_name]
                    ),
                    Column(value_column, name=value_name, dtype=schema[value_name]),
-                ]
+                ],
+                stream=df.stream,
            )
        elif name == "row_index":
            col_name, offset = options
            dtype = schema[col_name]
-            step = plc.Scalar.from_py(1, dtype.
-            init = plc.Scalar.from_py(offset, dtype.
+            step = plc.Scalar.from_py(1, dtype.plc_type, stream=df.stream)
+            init = plc.Scalar.from_py(offset, dtype.plc_type, stream=df.stream)
            index_col = Column(
-                plc.filling.sequence(df.num_rows, init, step),
+                plc.filling.sequence(df.num_rows, init, step, stream=df.stream),
                is_sorted=plc.types.Sorted.YES,
                order=plc.types.Order.ASCENDING,
                null_order=plc.types.NullOrder.AFTER,
                name=col_name,
                dtype=dtype,
            )
-            return DataFrame([index_col, *df.columns])
+            return DataFrame([index_col, *df.columns], stream=df.stream)
        else:
            raise AssertionError("Should never be reached")  # pragma: no cover

@@ -2490,16 +3072,33 @@ class Union(IR):
         schema = self.children[0].schema

     @classmethod
+    @log_do_evaluate
     @nvtx_annotate_cudf_polars(message="Union")
-    def do_evaluate(
+    def do_evaluate(
+        cls, zlice: Zlice | None, *dfs: DataFrame, context: IRExecutionContext
+    ) -> DataFrame:
         """Evaluate and return a dataframe."""
+        stream = get_joined_cuda_stream(
+            context.get_cuda_stream, upstreams=[df.stream for df in dfs]
+        )
+
         # TODO: only evaluate what we need if we have a slice?
-
-        plc.concatenate.concatenate([df.table for df in dfs]),
+        result = DataFrame.from_table(
+            plc.concatenate.concatenate([df.table for df in dfs], stream=stream),
             dfs[0].column_names,
             dfs[0].dtypes,
+            stream=stream,
         ).slice(zlice)

+        # now join the original streams *back* to the new result stream
+        # to ensure that the deallocations (on the original streams)
+        # happen after the result is ready
+        join_cuda_streams(
+            downstreams=[df.stream for df in dfs], upstreams=(result.stream,)
+        )
+
+        return result
+

 class HConcat(IR):
     """Concatenate dataframes horizontally."""
@@ -2519,7 +3118,9 @@ class HConcat(IR):
         self.children = children

     @staticmethod
-    def _extend_with_nulls(
+    def _extend_with_nulls(
+        table: plc.Table, *, nrows: int, stream: Stream
+    ) -> plc.Table:
         """
         Extend a table with nulls.

@@ -2529,6 +3130,8 @@ class HConcat(IR):
             Table to extend
         nrows
             Number of additional rows
+        stream
+            CUDA stream used for device memory operations and kernel launches

         Returns
         -------
@@ -2539,46 +3142,69 @@ class HConcat(IR):
                 table,
                 plc.Table(
                     [
-                        plc.Column.all_null_like(column, nrows)
+                        plc.Column.all_null_like(column, nrows, stream=stream)
                         for column in table.columns()
                     ]
                 ),
-            ]
+            ],
+            stream=stream,
         )

     @classmethod
+    @log_do_evaluate
     @nvtx_annotate_cudf_polars(message="HConcat")
     def do_evaluate(
         cls,
         should_broadcast: bool,  # noqa: FBT001
         *dfs: DataFrame,
+        context: IRExecutionContext,
     ) -> DataFrame:
         """Evaluate and return a dataframe."""
+        stream = get_joined_cuda_stream(
+            context.get_cuda_stream, upstreams=[df.stream for df in dfs]
+        )
+
         # Special should_broadcast case.
         # Used to recombine decomposed expressions
         if should_broadcast:
-
-            broadcast(
+            result = DataFrame(
+                broadcast(
+                    *itertools.chain.from_iterable(df.columns for df in dfs),
+                    stream=stream,
+                ),
+                stream=stream,
            )
-
-
-
-
-
-
-
-
-
-
-
-
-
+        else:
+            max_rows = max(df.num_rows for df in dfs)
+            # Horizontal concatenation extends shorter tables with nulls
+            result = DataFrame(
+                itertools.chain.from_iterable(
+                    df.columns
+                    for df in (
+                        df
+                        if df.num_rows == max_rows
+                        else DataFrame.from_table(
+                            cls._extend_with_nulls(
+                                df.table, nrows=max_rows - df.num_rows, stream=stream
+                            ),
+                            df.column_names,
+                            df.dtypes,
+                            stream=stream,
+                        )
+                        for df in dfs
                    )
-
-
+                ),
+                stream=stream,
            )
+
+        # Join the original streams back into the result stream to ensure that the
+        # deallocations (on the original streams) happen after the result is ready
+        join_cuda_streams(
+            downstreams=[df.stream for df in dfs], upstreams=(result.stream,)
        )

+        return result
+

 class Empty(IR):
     """Represents an empty DataFrame with a known schema."""
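`HConcat` keeps all frames at a common height by padding every shorter table with `max_rows - num_rows` trailing nulls per column before concatenating horizontally. The behaviour of `_extend_with_nulls`, modelled on Python lists (a sketch; the real implementation builds all-null device columns):

    def extend_with_nulls(columns, *, nrows):
        # Append `nrows` nulls to every column of the table.
        return [column + [None] * nrows for column in columns]

    tables = [[[1, 2, 3]], [[10]]]  # two single-column tables, heights 3 and 1
    max_rows = max(len(table[0]) for table in tables)
    padded = [
        table
        if len(table[0]) == max_rows
        else extend_with_nulls(table, nrows=max_rows - len(table[0]))
        for table in tables
    ]
    assert padded[1] == [[10, None, None]]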
@@ -2592,16 +3218,23 @@ class Empty(IR):
         self.children = ()

     @classmethod
+    @log_do_evaluate
     @nvtx_annotate_cudf_polars(message="Empty")
-    def do_evaluate(
+    def do_evaluate(
+        cls, schema: Schema, *, context: IRExecutionContext
+    ) -> DataFrame:  # pragma: no cover
         """Evaluate and return a dataframe."""
+        stream = context.get_cuda_stream()
         return DataFrame(
             [
                 Column(
-                    plc.column_factories.make_empty_column(
+                    plc.column_factories.make_empty_column(
+                        dtype.plc_type, stream=stream
+                    ),
                     dtype=dtype,
                     name=name,
                 )
                 for name, dtype in schema.items()
-            ]
+            ],
+            stream=stream,
         )