cudf-polars-cu13 25.10.0__py3-none-any.whl → 26.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cudf_polars/GIT_COMMIT +1 -1
- cudf_polars/VERSION +1 -1
- cudf_polars/callback.py +60 -15
- cudf_polars/containers/column.py +137 -77
- cudf_polars/containers/dataframe.py +123 -34
- cudf_polars/containers/datatype.py +134 -13
- cudf_polars/dsl/expr.py +0 -2
- cudf_polars/dsl/expressions/aggregation.py +80 -28
- cudf_polars/dsl/expressions/binaryop.py +34 -14
- cudf_polars/dsl/expressions/boolean.py +110 -37
- cudf_polars/dsl/expressions/datetime.py +59 -30
- cudf_polars/dsl/expressions/literal.py +11 -5
- cudf_polars/dsl/expressions/rolling.py +460 -119
- cudf_polars/dsl/expressions/selection.py +9 -8
- cudf_polars/dsl/expressions/slicing.py +1 -1
- cudf_polars/dsl/expressions/string.py +256 -114
- cudf_polars/dsl/expressions/struct.py +19 -7
- cudf_polars/dsl/expressions/ternary.py +33 -3
- cudf_polars/dsl/expressions/unary.py +126 -64
- cudf_polars/dsl/ir.py +1053 -350
- cudf_polars/dsl/to_ast.py +30 -13
- cudf_polars/dsl/tracing.py +194 -0
- cudf_polars/dsl/translate.py +307 -107
- cudf_polars/dsl/utils/aggregations.py +43 -30
- cudf_polars/dsl/utils/reshape.py +14 -2
- cudf_polars/dsl/utils/rolling.py +12 -8
- cudf_polars/dsl/utils/windows.py +35 -20
- cudf_polars/experimental/base.py +55 -2
- cudf_polars/experimental/benchmarks/pdsds.py +12 -126
- cudf_polars/experimental/benchmarks/pdsh.py +792 -2
- cudf_polars/experimental/benchmarks/utils.py +596 -39
- cudf_polars/experimental/dask_registers.py +47 -20
- cudf_polars/experimental/dispatch.py +9 -3
- cudf_polars/experimental/distinct.py +2 -0
- cudf_polars/experimental/explain.py +15 -2
- cudf_polars/experimental/expressions.py +30 -15
- cudf_polars/experimental/groupby.py +25 -4
- cudf_polars/experimental/io.py +156 -124
- cudf_polars/experimental/join.py +53 -23
- cudf_polars/experimental/parallel.py +68 -19
- cudf_polars/experimental/rapidsmpf/__init__.py +8 -0
- cudf_polars/experimental/rapidsmpf/collectives/__init__.py +9 -0
- cudf_polars/experimental/rapidsmpf/collectives/allgather.py +90 -0
- cudf_polars/experimental/rapidsmpf/collectives/common.py +96 -0
- cudf_polars/experimental/rapidsmpf/collectives/shuffle.py +253 -0
- cudf_polars/experimental/rapidsmpf/core.py +488 -0
- cudf_polars/experimental/rapidsmpf/dask.py +172 -0
- cudf_polars/experimental/rapidsmpf/dispatch.py +153 -0
- cudf_polars/experimental/rapidsmpf/io.py +696 -0
- cudf_polars/experimental/rapidsmpf/join.py +322 -0
- cudf_polars/experimental/rapidsmpf/lower.py +74 -0
- cudf_polars/experimental/rapidsmpf/nodes.py +735 -0
- cudf_polars/experimental/rapidsmpf/repartition.py +216 -0
- cudf_polars/experimental/rapidsmpf/union.py +115 -0
- cudf_polars/experimental/rapidsmpf/utils.py +374 -0
- cudf_polars/experimental/repartition.py +9 -2
- cudf_polars/experimental/select.py +177 -14
- cudf_polars/experimental/shuffle.py +46 -12
- cudf_polars/experimental/sort.py +100 -26
- cudf_polars/experimental/spilling.py +1 -1
- cudf_polars/experimental/statistics.py +24 -5
- cudf_polars/experimental/utils.py +25 -7
- cudf_polars/testing/asserts.py +13 -8
- cudf_polars/testing/io.py +2 -1
- cudf_polars/testing/plugin.py +93 -17
- cudf_polars/typing/__init__.py +86 -32
- cudf_polars/utils/config.py +473 -58
- cudf_polars/utils/cuda_stream.py +70 -0
- cudf_polars/utils/versions.py +5 -4
- cudf_polars_cu13-26.2.0.dist-info/METADATA +181 -0
- cudf_polars_cu13-26.2.0.dist-info/RECORD +108 -0
- {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/WHEEL +1 -1
- cudf_polars_cu13-25.10.0.dist-info/METADATA +0 -136
- cudf_polars_cu13-25.10.0.dist-info/RECORD +0 -92
- {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/licenses/LICENSE +0 -0
- {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/top_level.txt +0 -0
cudf_polars/dsl/ir.py CHANGED
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-
+# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES.
 # SPDX-License-Identifier: Apache-2.0
 """
 DSL nodes for the LogicalPlan of polars.
@@ -13,46 +13,61 @@ can be considered as functions:
 
 from __future__ import annotations
 
+import contextlib
 import itertools
 import json
 import random
 import time
+from dataclasses import dataclass
 from functools import cache
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, ClassVar
+from typing import TYPE_CHECKING, Any, ClassVar, overload
 
 from typing_extensions import assert_never
 
 import polars as pl
 
 import pylibcudf as plc
+from pylibcudf import expressions as plc_expr
 
 import cudf_polars.dsl.expr as expr
 from cudf_polars.containers import Column, DataFrame, DataType
+from cudf_polars.containers.dataframe import NamedColumn
 from cudf_polars.dsl.expressions import rolling, unary
 from cudf_polars.dsl.expressions.base import ExecutionContext
 from cudf_polars.dsl.nodebase import Node
 from cudf_polars.dsl.to_ast import to_ast, to_parquet_filter
-from cudf_polars.dsl.tracing import nvtx_annotate_cudf_polars
+from cudf_polars.dsl.tracing import log_do_evaluate, nvtx_annotate_cudf_polars
 from cudf_polars.dsl.utils.reshape import broadcast
-from cudf_polars.dsl.utils.windows import
+from cudf_polars.dsl.utils.windows import (
+    offsets_to_windows,
+    range_window_bounds,
+)
 from cudf_polars.utils import dtypes
-from cudf_polars.utils.
+from cudf_polars.utils.config import CUDAStreamPolicy
+from cudf_polars.utils.cuda_stream import (
+    get_cuda_stream,
+    get_joined_cuda_stream,
+    get_new_cuda_stream,
+    join_cuda_streams,
+)
+from cudf_polars.utils.versions import POLARS_VERSION_LT_131, POLARS_VERSION_LT_134
 
 if TYPE_CHECKING:
-    from collections.abc import Callable, Hashable, Iterable, Sequence
+    from collections.abc import Callable, Generator, Hashable, Iterable, Sequence
     from typing import Literal
 
     from typing_extensions import Self
 
-    from polars
+    from polars import polars  # type: ignore[attr-defined]
+
+    from rmm.pylibrmm.stream import Stream
 
     from cudf_polars.containers.dataframe import NamedColumn
     from cudf_polars.typing import CSECache, ClosedInterval, Schema, Slice as Zlice
-    from cudf_polars.utils.config import ParquetOptions
+    from cudf_polars.utils.config import ConfigOptions, ParquetOptions
     from cudf_polars.utils.timer import Timer
 
-
 __all__ = [
     "IR",
     "Cache",
@@ -65,6 +80,7 @@ __all__ = [
     "GroupBy",
     "HConcat",
     "HStack",
+    "IRExecutionContext",
     "Join",
     "MapFunction",
     "MergeSorted",
@@ -81,6 +97,97 @@ __all__ = [
 ]
 
 
+@dataclass(frozen=True)
+class IRExecutionContext:
+    """
+    Runtime context for IR node execution.
+
+    This dataclass holds runtime information and configuration needed
+    during the evaluation of IR nodes.
+
+    Parameters
+    ----------
+    get_cuda_stream
+        A zero-argument callable that returns a CUDA stream.
+    """
+
+    get_cuda_stream: Callable[[], Stream]
+
+    @classmethod
+    def from_config_options(cls, config_options: ConfigOptions) -> IRExecutionContext:
+        """Create an IRExecutionContext from ConfigOptions."""
+        match config_options.cuda_stream_policy:
+            case CUDAStreamPolicy.DEFAULT:
+                return cls(get_cuda_stream=get_cuda_stream)
+            case CUDAStreamPolicy.NEW:
+                return cls(get_cuda_stream=get_new_cuda_stream)
+            case _:  # pragma: no cover
+                raise ValueError(
+                    f"Invalid CUDA stream policy: {config_options.cuda_stream_policy}"
+                )
+
+    @contextlib.contextmanager
+    def stream_ordered_after(self, *dfs: DataFrame) -> Generator[Stream, None, None]:
+        """
+        Get a joined CUDA stream with safe stream ordering for deallocation of inputs.
+
+        Parameters
+        ----------
+        dfs
+            The dataframes being provided to stream-ordered operations.
+
+        Yields
+        ------
+        A CUDA stream that is downstream of the given dataframes.
+
+        Notes
+        -----
+        This context manager provides two useful guarantees when working with
+        objects holding references to stream-ordered objects:
+
+        1. The stream yielded upon entering the context manager is *downstream* of
+           all the input dataframes. This ensures that you can safely perform
+           stream-ordered operations on any input using the yielded stream.
+        2. The stream-ordered CUDA deallocation of the inputs happens *after* the
+           context manager exits. This ensures that all stream-ordered operations
+           submitted inside the context manager can complete before the memory
+           referenced by the inputs is deallocated.
+
+        Note that this does (deliberately) disconnect the dropping of the Python
+        object (by its refcount dropping to 0) from the actual stream-ordered
+        deallocation of the CUDA memory. This is precisely what we need to ensure
+        that the inputs are valid long enough for the stream-ordered operations to
+        complete.
+        """
+        result_stream = get_joined_cuda_stream(
+            self.get_cuda_stream, upstreams=[df.stream for df in dfs]
+        )
+
+        yield result_stream
+
+        # ensure that the inputs are downstream of result_stream (so that deallocation happens after the result is ready)
+        join_cuda_streams(
+            downstreams=[df.stream for df in dfs], upstreams=[result_stream]
+        )
+
+
+_BINOPS = {
+    plc.binaryop.BinaryOperator.EQUAL,
+    plc.binaryop.BinaryOperator.NOT_EQUAL,
+    plc.binaryop.BinaryOperator.LESS,
+    plc.binaryop.BinaryOperator.LESS_EQUAL,
+    plc.binaryop.BinaryOperator.GREATER,
+    plc.binaryop.BinaryOperator.GREATER_EQUAL,
+    # TODO: Handle other binary operations as needed
+}
+
+
+_DECIMAL_TYPES = {plc.TypeId.DECIMAL32, plc.TypeId.DECIMAL64, plc.TypeId.DECIMAL128}
+
+
+_FLOAT_TYPES = {plc.TypeId.FLOAT32, plc.TypeId.FLOAT64}
+
+
 class IR(Node["IR"]):
     """Abstract plan node, representing an unevaluated dataframe."""
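The ordering contract of `stream_ordered_after` can be modelled without a GPU. The sketch below is illustrative only: `ToyStream` and `join` are hypothetical stand-ins for RMM streams and the event-based `get_joined_cuda_stream`/`join_cuda_streams` helpers this diff imports.

```python
import contextlib

class ToyStream:
    """Stand-in for a CUDA stream: records ordering edges instead of GPU work."""
    def __init__(self, name: str) -> None:
        self.name = name
        self.waits_on: list["ToyStream"] = []

def join(downstreams: list[ToyStream], upstreams: list[ToyStream]) -> None:
    # The real helpers record an event on each upstream and make each
    # downstream wait on it; here we only record the dependency edge.
    for d in downstreams:
        d.waits_on.extend(upstreams)

@contextlib.contextmanager
def stream_ordered_after(result: ToyStream, inputs: list[ToyStream]):
    join([result], inputs)   # guarantee 1: result is downstream of all inputs
    yield result
    join(inputs, [result])   # guarantee 2: input deallocation waits on result

a, b, out = ToyStream("a"), ToyStream("b"), ToyStream("out")
with stream_ordered_after(out, [a, b]) as s:
    pass  # stream-ordered work on `s` would go here
assert out.waits_on == [a, b] and a.waits_on == [out] and b.waits_on == [out]
```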
@@ -134,7 +241,9 @@ class IR(Node["IR"]):
         translation phase should fail earlier.
         """
 
-    def evaluate(
+    def evaluate(
+        self, *, cache: CSECache, timer: Timer | None, context: IRExecutionContext
+    ) -> DataFrame:
         """
         Evaluate the node (recursively) and return a dataframe.
 
@@ -146,6 +255,8 @@ class IR(Node["IR"]):
         timer
             If not None, a Timer object to record timings for the
             evaluation of the node.
+        context
+            The execution context for the node.
 
         Notes
         -----
@@ -164,16 +275,19 @@ class IR(Node["IR"]):
         If evaluation fails. Ideally this should not occur, since the
         translation phase should fail earlier.
         """
-        children = [
+        children = [
+            child.evaluate(cache=cache, timer=timer, context=context)
+            for child in self.children
+        ]
         if timer is not None:
             start = time.monotonic_ns()
-            result = self.do_evaluate(*self._non_child_args, *children)
+            result = self.do_evaluate(*self._non_child_args, *children, context=context)
             end = time.monotonic_ns()
             # TODO: Set better names on each class object.
             timer.store(start, end, type(self).__name__)
             return result
         else:
-            return self.do_evaluate(*self._non_child_args, *children)
+            return self.do_evaluate(*self._non_child_args, *children, context=context)
 
 
 class ErrorNode(IR):
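The recursion scheme above is simple but central: children evaluate first, and the single execution context is threaded unchanged through every level so that all nodes draw their streams from it. A minimal sketch (with `ToyNode` as a hypothetical stand-in for `IR`, and an integer in place of a dataframe):

```python
# Children evaluate before the parent; the context is passed through unchanged.
class ToyNode:
    def __init__(self, value, *children):
        self.value, self.children = value, children

    def evaluate(self, *, context):
        children = [c.evaluate(context=context) for c in self.children]
        return self.do_evaluate(*children, context=context)

    def do_evaluate(self, *children, context):
        # Stand-in for a dataframe operation.
        return self.value + sum(children)

assert ToyNode(1, ToyNode(2), ToyNode(3)).evaluate(context=object()) == 6
```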
@@ -212,29 +326,97 @@ class PythonScan(IR):
         raise NotImplementedError("PythonScan not implemented")
 
 
+_DECIMAL_IDS = {plc.TypeId.DECIMAL32, plc.TypeId.DECIMAL64, plc.TypeId.DECIMAL128}
+
+_COMPARISON_BINOPS = {
+    plc.binaryop.BinaryOperator.EQUAL,
+    plc.binaryop.BinaryOperator.NOT_EQUAL,
+    plc.binaryop.BinaryOperator.LESS,
+    plc.binaryop.BinaryOperator.LESS_EQUAL,
+    plc.binaryop.BinaryOperator.GREATER,
+    plc.binaryop.BinaryOperator.GREATER_EQUAL,
+}
+
+
+def _parquet_physical_types(
+    schema: Schema, paths: list[str], columns: list[str] | None, stream: Stream
+) -> dict[str, plc.DataType]:
+    # TODO: Read the physical types as cudf::data_type's using
+    # read_parquet_metadata or another parquet API
+    options = plc.io.parquet.ParquetReaderOptions.builder(
+        plc.io.SourceInfo(paths)
+    ).build()
+    if columns is not None:
+        options.set_columns(columns)
+    options.set_num_rows(0)
+    df = plc.io.parquet.read_parquet(options, stream=stream)
+    return dict(zip(schema.keys(), [c.type() for c in df.tbl.columns()], strict=True))
+
+
+def _cast_literal_to_decimal(
+    side: expr.Expr, lit: expr.Literal, phys_type_map: dict[str, plc.DataType]
+) -> expr.Expr:
+    if isinstance(side, expr.Cast):
+        col = side.children[0]
+        assert isinstance(col, expr.Col)
+        name = col.name
+    else:
+        assert isinstance(side, expr.Col)
+        name = side.name
+    if (type_ := phys_type_map[name]).id() in _DECIMAL_IDS:
+        scale = abs(type_.scale())
+        return expr.Cast(
+            side.dtype,
+            True,  # noqa: FBT003
+            expr.Cast(DataType(pl.Decimal(38, scale)), True, lit),  # noqa: FBT003
+        )
+    return lit
+
+
+def _cast_literals_to_physical_types(
+    node: expr.Expr, phys_type_map: dict[str, plc.DataType]
+) -> expr.Expr:
+    if isinstance(node, expr.BinOp):
+        left, right = node.children
+        left = _cast_literals_to_physical_types(left, phys_type_map)
+        right = _cast_literals_to_physical_types(right, phys_type_map)
+        if node.op in _COMPARISON_BINOPS:
+            if isinstance(left, (expr.Col, expr.Cast)) and isinstance(
+                right, expr.Literal
+            ):
+                right = _cast_literal_to_decimal(left, right, phys_type_map)
+            elif isinstance(right, (expr.Col, expr.Cast)) and isinstance(
+                left, expr.Literal
+            ):
+                left = _cast_literal_to_decimal(right, left, phys_type_map)
+
+        return node.reconstruct([left, right])
+    return node
+
+
 def _align_parquet_schema(df: DataFrame, schema: Schema) -> DataFrame:
     # TODO: Alternatively set the schema of the parquet reader to decimal128
-    plc_decimals_ids = {
-        plc.TypeId.DECIMAL32,
-        plc.TypeId.DECIMAL64,
-        plc.TypeId.DECIMAL128,
-    }
     cast_list = []
 
     for name, col in df.column_map.items():
         src = col.obj.type()
-        dst = schema[name].
+        dst = schema[name].plc_type
+
         if (
-
-            and
-            and ((src.id() != dst.id()) or (src.scale != dst.scale))
+            plc.traits.is_fixed_point(src)
+            and plc.traits.is_fixed_point(dst)
+            and ((src.id() != dst.id()) or (src.scale() != dst.scale()))
         ):
             cast_list.append(
-                Column(
+                Column(
+                    plc.unary.cast(col.obj, dst, stream=df.stream),
+                    name=name,
+                    dtype=schema[name],
+                )
             )
 
     if cast_list:
-        df = df.with_columns(cast_list)
+        df = df.with_columns(cast_list, stream=df.stream)
 
     return df
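`_cast_literals_to_physical_types` is a bottom-up expression-tree rewrite: children are rewritten first, then a comparison whose operands are a column (or cast column) and a literal gets the literal wrapped so it matches the column's physical decimal type. A toy analogue of the same shape, with hypothetical `Col`/`Lit`/`BinOp`/`Cast` node classes standing in for cudf_polars' `expr` nodes:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Col:
    name: str

@dataclass(frozen=True)
class Lit:
    value: object

@dataclass(frozen=True)
class Cast:
    to: str
    child: object

@dataclass(frozen=True)
class BinOp:
    op: str
    left: object
    right: object

def rewrite(node, decimal_cols: set[str]):
    if isinstance(node, BinOp):
        left = rewrite(node.left, decimal_cols)
        right = rewrite(node.right, decimal_cols)
        if isinstance(left, Col) and isinstance(right, Lit) and left.name in decimal_cols:
            right = Cast("decimal", right)   # align the literal with the column type
        elif isinstance(right, Col) and isinstance(left, Lit) and right.name in decimal_cols:
            left = Cast("decimal", left)
        return BinOp(node.op, left, right)   # reconstruct with rewritten children
    return node

tree = BinOp("<", Col("price"), Lit(10))
assert rewrite(tree, {"price"}) == BinOp("<", Col("price"), Cast("decimal", Lit(10)))
```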
@@ -460,13 +642,24 @@ class Scan(IR):
         Each path is repeated according to the number of rows read from it.
         """
         (filepaths,) = plc.filling.repeat(
-            plc.Table(
+            plc.Table(
+                [
+                    plc.Column.from_arrow(
+                        pl.Series(values=map(str, paths)),
+                        stream=df.stream,
+                    )
+                ]
+            ),
             plc.Column.from_arrow(
-                pl.Series(values=rows_per_path, dtype=pl.datatypes.Int32())
+                pl.Series(values=rows_per_path, dtype=pl.datatypes.Int32()),
+                stream=df.stream,
             ),
+            stream=df.stream,
         ).columns()
         dtype = DataType(pl.String())
-        return df.with_columns(
+        return df.with_columns(
+            [Column(filepaths, name=name, dtype=dtype)], stream=df.stream
+        )
 
     def fast_count(self) -> int:  # pragma: no cover
         """Get the number of rows in a Parquet Scan."""
@@ -479,6 +672,7 @@ class Scan(IR):
         return max(total_rows, 0)
 
     @classmethod
+    @log_do_evaluate
     @nvtx_annotate_cudf_polars(message="Scan")
     def do_evaluate(
         cls,
@@ -493,8 +687,11 @@ class Scan(IR):
         include_file_paths: str | None,
         predicate: expr.NamedExpr | None,
         parquet_options: ParquetOptions,
+        *,
+        context: IRExecutionContext,
     ) -> DataFrame:
         """Evaluate and return a dataframe."""
+        stream = context.get_cuda_stream()
         if typ == "csv":
 
             def read_csv_header(
@@ -551,6 +748,7 @@ class Scan(IR):
                     plc.io.csv.CsvReaderOptions.builder(plc.io.SourceInfo([path]))
                     .nrows(n_rows)
                     .skiprows(skiprows + skip_rows)
+                    .skip_blank_lines(skip_blank_lines=False)
                     .lineterminator(str(eol))
                     .quotechar(str(quote))
                     .decimal(decimal)
@@ -567,13 +765,15 @@ class Scan(IR):
                 column_names = read_csv_header(path, str(sep))
                 options.set_names(column_names)
                 options.set_header(header)
-                options.set_dtypes(
+                options.set_dtypes(
+                    {name: dtype.plc_type for name, dtype in schema.items()}
+                )
                 if usecols is not None:
                     options.set_use_cols_names([str(name) for name in usecols])
                 options.set_na_values(null_values)
                 if comment is not None:
                     options.set_comment(comment)
-                tbl_w_meta = plc.io.csv.read_csv(options)
+                tbl_w_meta = plc.io.csv.read_csv(options, stream=stream)
                 pieces.append(tbl_w_meta)
                 if include_file_paths is not None:
                     seen_paths.append(p)
@@ -589,9 +789,10 @@ class Scan(IR):
                 strict=True,
             )
             df = DataFrame.from_table(
-                plc.concatenate.concatenate(list(tables)),
+                plc.concatenate.concatenate(list(tables), stream=stream),
                 colnames,
                 [schema[colname] for colname in colnames],
+                stream=stream,
             )
             if include_file_paths is not None:
                 df = Scan.add_file_paths(
@@ -604,42 +805,50 @@ class Scan(IR):
             filters = None
             if predicate is not None and row_index is None:
                 # Can't apply filters during read if we have a row index.
-                filters = to_parquet_filter(
-
+                filters = to_parquet_filter(
+                    _cast_literals_to_physical_types(
+                        predicate.value,
+                        _parquet_physical_types(
+                            schema, paths, with_columns or list(schema.keys()), stream
+                        ),
+                    ),
+                    stream=stream,
+                )
+            parquet_reader_options = plc.io.parquet.ParquetReaderOptions.builder(
                 plc.io.SourceInfo(paths)
             ).build()
             if with_columns is not None:
-
+                parquet_reader_options.set_columns(with_columns)
             if filters is not None:
-
+                parquet_reader_options.set_filter(filters)
             if n_rows != -1:
-
+                parquet_reader_options.set_num_rows(n_rows)
             if skip_rows != 0:
-
+                parquet_reader_options.set_skip_rows(skip_rows)
             if parquet_options.chunked:
                 reader = plc.io.parquet.ChunkedParquetReader(
-
+                    parquet_reader_options,
                     chunk_read_limit=parquet_options.chunk_read_limit,
                     pass_read_limit=parquet_options.pass_read_limit,
+                    stream=stream,
                 )
                 chunk = reader.read_chunk()
-                tbl = chunk.tbl
                 # TODO: Nested column names
                 names = chunk.column_names(include_children=False)
-                concatenated_columns = tbl.columns()
+                concatenated_columns = chunk.tbl.columns()
                 while reader.has_next():
-
-
-
+                    columns = reader.read_chunk().tbl.columns()
+                    # Discard columns while concatenating to reduce memory footprint.
+                    # Reverse order to avoid O(n^2) list popping cost.
+                    for i in range(len(concatenated_columns) - 1, -1, -1):
                         concatenated_columns[i] = plc.concatenate.concatenate(
-                            [concatenated_columns[i],
+                            [concatenated_columns[i], columns.pop()], stream=stream
                         )
-                        # Drop residual columns to save memory
-                        tbl._columns[i] = None
                 df = DataFrame.from_table(
                     plc.Table(concatenated_columns),
                     names=names,
                     dtypes=[schema[name] for name in names],
+                    stream=stream,
                 )
                 df = _align_parquet_schema(df, schema)
                 if include_file_paths is not None:
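The "reverse order" comment in the chunked-reader loop is worth unpacking: consuming the per-chunk column list from the tail with `pop()` is O(1) per removal, whereas `pop(0)` would shift every remaining element and make the loop quadratic overall. A tiny illustration with strings standing in for columns and `+` standing in for `plc.concatenate.concatenate`:

```python
chunks = [["a1", "b1"], ["a2", "b2"]]          # columns from successive read_chunk() calls
concatenated = chunks[0]
for cols in chunks[1:]:
    # Walk indices from last to first so cols.pop() yields the matching column.
    for i in range(len(concatenated) - 1, -1, -1):
        concatenated[i] = concatenated[i] + cols.pop()  # stand-in for concatenate()
assert concatenated == ["a1a2", "b1b2"]
```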
@@ -647,13 +856,16 @@ class Scan(IR):
                         include_file_paths, paths, chunk.num_rows_per_source, df
                     )
             else:
-                tbl_w_meta = plc.io.parquet.read_parquet(
+                tbl_w_meta = plc.io.parquet.read_parquet(
+                    parquet_reader_options, stream=stream
+                )
                 # TODO: consider nested column names?
                 col_names = tbl_w_meta.column_names(include_children=False)
                 df = DataFrame.from_table(
                     tbl_w_meta.tbl,
                     col_names,
                     [schema[name] for name in col_names],
+                    stream=stream,
                 )
                 df = _align_parquet_schema(df, schema)
                 if include_file_paths is not None:
@@ -665,16 +877,16 @@ class Scan(IR):
             return df
         elif typ == "ndjson":
             json_schema: list[plc.io.json.NameAndType] = [
-                (name, typ.
+                (name, typ.plc_type, []) for name, typ in schema.items()
             ]
-
-            plc.io.json.
-
-
-
-
-            )
+            json_reader_options = (
+                plc.io.json.JsonReaderOptions.builder(plc.io.SourceInfo(paths))
+                .lines(val=True)
+                .dtypes(json_schema)
+                .prune_columns(val=True)
+                .build()
+            )
+            plc_tbl_w_meta = plc.io.json.read_json(json_reader_options, stream=stream)
             # TODO: I don't think cudf-polars supports nested types in general right now
             # (but when it does, we should pass child column names from nested columns in)
             col_names = plc_tbl_w_meta.column_names(include_children=False)
@@ -682,6 +894,7 @@ class Scan(IR):
                 plc_tbl_w_meta.tbl,
                 col_names,
                 [schema[name] for name in col_names],
+                stream=stream,
             )
             col_order = list(schema.keys())
             if row_index is not None:
@@ -695,26 +908,28 @@ class Scan(IR):
             name, offset = row_index
             offset += skip_rows
             dtype = schema[name]
-            step = plc.Scalar.from_py(1, dtype.
-            init = plc.Scalar.from_py(offset, dtype.
+            step = plc.Scalar.from_py(1, dtype.plc_type, stream=stream)
+            init = plc.Scalar.from_py(offset, dtype.plc_type, stream=stream)
             index_col = Column(
-                plc.filling.sequence(df.num_rows, init, step),
+                plc.filling.sequence(df.num_rows, init, step, stream=stream),
                 is_sorted=plc.types.Sorted.YES,
                 order=plc.types.Order.ASCENDING,
                 null_order=plc.types.NullOrder.AFTER,
                 name=name,
                 dtype=dtype,
             )
-            df = DataFrame([index_col, *df.columns])
+            df = DataFrame([index_col, *df.columns], stream=df.stream)
             if next(iter(schema)) != name:
                 df = df.select(schema)
         assert all(
-            c.obj.type() == schema[name].
+            c.obj.type() == schema[name].plc_type for name, c in df.column_map.items()
         )
         if predicate is None:
             return df
         else:
-            (mask,) = broadcast(
+            (mask,) = broadcast(
+                predicate.evaluate(df), target_length=df.num_rows, stream=df.stream
+            )
             return df.filter(mask)
 
 
@@ -775,7 +990,8 @@ class Sink(IR):
         child_schema = df.schema.values()
         if kind == "Csv":
             if not all(
-                plc.io.csv.is_supported_write_csv(dtype.
+                plc.io.csv.is_supported_write_csv(dtype.plc_type)
+                for dtype in child_schema
             ):
                 # Nested types are unsupported in polars and libcudf
                 raise NotImplementedError(
@@ -826,7 +1042,8 @@ class Sink(IR):
             kind == "Json"
         ):  # pragma: no cover; options are validated on the polars side
             if not all(
-                plc.io.json.is_supported_write_json(dtype.
+                plc.io.json.is_supported_write_json(dtype.plc_type)
+                for dtype in child_schema
             ):
                 # Nested types are unsupported in polars and libcudf
                 raise NotImplementedError(
@@ -863,7 +1080,7 @@ class Sink(IR):
     ) -> None:
         """Write CSV data to a sink."""
         serialize = options["serialize_options"]
-
+        csv_writer_options = (
             plc.io.csv.CsvWriterOptions.builder(target, df.table)
             .include_header(options["include_header"])
             .names(df.column_names if options["include_header"] else [])
@@ -872,7 +1089,7 @@ class Sink(IR):
             .inter_column_delimiter(chr(serialize["separator"]))
             .build()
         )
-        plc.io.csv.write_csv(
+        plc.io.csv.write_csv(csv_writer_options, stream=df.stream)
 
     @classmethod
     def _write_json(cls, target: plc.io.SinkInfo, df: DataFrame) -> None:
@@ -889,7 +1106,7 @@ class Sink(IR):
             .utf8_escaped(val=False)
             .build()
         )
-        plc.io.json.write_json(options)
+        plc.io.json.write_json(options, stream=df.stream)
 
     @staticmethod
     def _make_parquet_metadata(df: DataFrame) -> plc.io.types.TableInputMetadata:
@@ -899,6 +1116,20 @@ class Sink(IR):
             metadata.column_metadata[i].set_name(name)
         return metadata
 
+    @overload
+    @staticmethod
+    def _apply_parquet_writer_options(
+        builder: plc.io.parquet.ChunkedParquetWriterOptionsBuilder,
+        options: dict[str, Any],
+    ) -> plc.io.parquet.ChunkedParquetWriterOptionsBuilder: ...
+
+    @overload
+    @staticmethod
+    def _apply_parquet_writer_options(
+        builder: plc.io.parquet.ParquetWriterOptionsBuilder,
+        options: dict[str, Any],
+    ) -> plc.io.parquet.ParquetWriterOptionsBuilder: ...
+
     @staticmethod
     def _apply_parquet_writer_options(
         builder: plc.io.parquet.ChunkedParquetWriterOptionsBuilder
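The two stubs added here use `typing.overload`: one runtime implementation, two compile-time signatures, so the declared return type tracks whichever builder type was passed in. A minimal sketch of the pattern, with toy builder classes in place of the pylibcudf ones:

```python
from typing import overload

class ChunkedBuilder:
    def compression(self, c: str) -> "ChunkedBuilder":
        return self

class PlainBuilder:
    def compression(self, c: str) -> "PlainBuilder":
        return self

@overload
def apply_options(builder: ChunkedBuilder, options: dict) -> ChunkedBuilder: ...
@overload
def apply_options(builder: PlainBuilder, options: dict) -> PlainBuilder: ...
def apply_options(builder, options):
    # Single runtime implementation shared by both declared signatures.
    return builder.compression(options.get("compression", "snappy"))

b = apply_options(PlainBuilder(), {})  # a type checker infers PlainBuilder here
```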
@@ -944,12 +1175,16 @@ class Sink(IR):
             and parquet_options.n_output_chunks != 1
             and df.table.num_rows() != 0
         ):
-
+            chunked_builder = plc.io.parquet.ChunkedParquetWriterOptions.builder(
                 target
             ).metadata(metadata)
-
-
-
+            chunked_builder = cls._apply_parquet_writer_options(
+                chunked_builder, options
+            )
+            chunked_writer_options = chunked_builder.build()
+            writer = plc.io.parquet.ChunkedParquetWriter.from_options(
+                chunked_writer_options, stream=df.stream
+            )
 
             # TODO: Can be based on a heuristic that estimates chunk size
             # from the input table size and available GPU memory.
@@ -957,6 +1192,7 @@ class Sink(IR):
             table_chunks = plc.copying.split(
                 df.table,
                 [i * df.table.num_rows() // num_chunks for i in range(1, num_chunks)],
+                stream=df.stream,
             )
             for chunk in table_chunks:
                 writer.write(chunk)
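The split points passed to `plc.copying.split` are row offsets chosen so the chunks are as equal as integer division allows. A quick worked example of that arithmetic:

```python
# 10 rows split into 3 chunks: boundaries land at rows 3 and 6,
# giving chunks of rows [0:3), [3:6), and [6:10).
n, num_chunks = 10, 3
offsets = [i * n // num_chunks for i in range(1, num_chunks)]
assert offsets == [3, 6]
```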
@@ -968,9 +1204,10 @@ class Sink(IR):
             ).metadata(metadata)
             builder = cls._apply_parquet_writer_options(builder, options)
             writer_options = builder.build()
-            plc.io.parquet.write_parquet(writer_options)
+            plc.io.parquet.write_parquet(writer_options, stream=df.stream)
 
     @classmethod
+    @log_do_evaluate
     @nvtx_annotate_cudf_polars(message="Sink")
     def do_evaluate(
         cls,
@@ -980,6 +1217,8 @@ class Sink(IR):
         parquet_options: ParquetOptions,
         options: dict[str, Any],
         df: DataFrame,
+        *,
+        context: IRExecutionContext,
     ) -> DataFrame:
         """Write the dataframe to a file."""
         target = plc.io.SinkInfo([path])
@@ -993,7 +1232,7 @@ class Sink(IR):
         elif kind == "Json":
             cls._write_json(target, df)
 
-        return DataFrame([])
+        return DataFrame([], stream=df.stream)
 
 
 class Cache(IR):
@@ -1030,16 +1269,24 @@ class Cache(IR):
         return False
 
     @classmethod
+    @log_do_evaluate
     @nvtx_annotate_cudf_polars(message="Cache")
     def do_evaluate(
-        cls,
+        cls,
+        key: int,
+        refcount: int | None,
+        df: DataFrame,
+        *,
+        context: IRExecutionContext,
     ) -> DataFrame:  # pragma: no cover; basic evaluation never calls this
         """Evaluate and return a dataframe."""
         # Our value has already been computed for us, so let's just
         # return it.
         return df
 
-    def evaluate(
+    def evaluate(
+        self, *, cache: CSECache, timer: Timer | None, context: IRExecutionContext
+    ) -> DataFrame:
         """Evaluate and return a dataframe."""
         # We must override the recursion scheme because we don't want
         # to recurse if we're in the cache.
@@ -1047,7 +1294,7 @@ class Cache(IR):
             (result, hits) = cache[self.key]
         except KeyError:
             (value,) = self.children
-            result = value.evaluate(cache=cache, timer=timer)
+            result = value.evaluate(cache=cache, timer=timer, context=context)
             cache[self.key] = (result, 0)
             return result
         else:
@@ -1093,6 +1340,42 @@ class DataFrameScan(IR):
         self.children = ()
         self._id_for_hash = random.randint(0, 2**64 - 1)
 
+    @staticmethod
+    def _reconstruct(
+        schema: Schema,
+        pl_df: pl.DataFrame,
+        projection: Sequence[str] | None,
+        id_for_hash: int,
+    ) -> DataFrameScan:  # pragma: no cover
+        """
+        Reconstruct a DataFrameScan from pickled data.
+
+        Parameters
+        ----------
+        schema: Schema
+            The schema of the DataFrameScan.
+        pl_df: pl.DataFrame
+            The underlying polars DataFrame.
+        projection: Sequence[str] | None
+            The projection of the DataFrameScan.
+        id_for_hash: int
+            The id for hash of the DataFrameScan.
+
+        Returns
+        -------
+        The reconstructed DataFrameScan.
+        """
+        node = DataFrameScan(schema, pl_df._df, projection)
+        node._id_for_hash = id_for_hash
+        return node
+
+    def __reduce__(self) -> tuple[Any, ...]:  # pragma: no cover
+        """Pickle a DataFrameScan object."""
+        return (
+            self._reconstruct,
+            (*self._non_child_args, self._id_for_hash),
+        )
+
     def get_hashable(self) -> Hashable:
         """
         Hashable representation of the node.
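The `__reduce__`/`_reconstruct` pair is the standard way to make pickling round-trip state that `__init__` would otherwise regenerate (here, the random `_id_for_hash`). A minimal sketch of the same pattern with a hypothetical `Node` class:

```python
import pickle
import random

class Node:
    def __init__(self, payload):
        self.payload = payload
        self._id = random.randint(0, 2**64 - 1)  # fresh on every construction

    @staticmethod
    def _reconstruct(payload, _id):
        node = Node(payload)
        node._id = _id  # restore the pickled id instead of a fresh one
        return node

    def __reduce__(self):
        # Pickle as (callable, args): unpickling calls _reconstruct(*args).
        return (self._reconstruct, (self.payload, self._id))

n = Node("data")
m = pickle.loads(pickle.dumps(n))
assert (m.payload, m._id) == (n.payload, n._id)
```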
@@ -1109,20 +1392,34 @@ class DataFrameScan(IR):
             self.projection,
         )
 
+    def is_equal(self, other: Self) -> bool:
+        """Equality of DataFrameScan nodes."""
+        return self is other or (
+            self._id_for_hash == other._id_for_hash
+            and self.schema == other.schema
+            and self.projection == other.projection
+            and pl.DataFrame._from_pydf(self.df).equals(
+                pl.DataFrame._from_pydf(other.df)
+            )
+        )
+
     @classmethod
+    @log_do_evaluate
     @nvtx_annotate_cudf_polars(message="DataFrameScan")
     def do_evaluate(
         cls,
         schema: Schema,
         df: Any,
         projection: tuple[str, ...] | None,
+        *,
+        context: IRExecutionContext,
     ) -> DataFrame:
         """Evaluate and return a dataframe."""
         if projection is not None:
             df = df.select(projection)
-        df = DataFrame.from_polars(df)
+        df = DataFrame.from_polars(df, stream=context.get_cuda_stream())
         assert all(
-            c.obj.type() == dtype.
+            c.obj.type() == dtype.plc_type
             for c, dtype in zip(df.columns, schema.values(), strict=True)
         )
         return df
@@ -1169,21 +1466,26 @@ class Select(IR):
         return False
 
     @classmethod
+    @log_do_evaluate
     @nvtx_annotate_cudf_polars(message="Select")
     def do_evaluate(
         cls,
         exprs: tuple[expr.NamedExpr, ...],
         should_broadcast: bool,  # noqa: FBT001
         df: DataFrame,
+        *,
+        context: IRExecutionContext,
     ) -> DataFrame:
         """Evaluate and return a dataframe."""
         # Handle any broadcasting
         columns = [e.evaluate(df) for e in exprs]
         if should_broadcast:
-            columns = broadcast(*columns)
-        return DataFrame(columns)
+            columns = broadcast(*columns, stream=df.stream)
+        return DataFrame(columns, stream=df.stream)
 
-    def evaluate(
+    def evaluate(
+        self, *, cache: CSECache, timer: Timer | None, context: IRExecutionContext
+    ) -> DataFrame:
         """
         Evaluate the Select node with special handling for fast count queries.
 
@@ -1195,6 +1497,8 @@ class Select(IR):
         timer
             If not None, a Timer object to record timings for the
             evaluation of the node.
+        context
+            The execution context for the node.
 
         Returns
         -------
@@ -1214,21 +1518,23 @@ class Select(IR):
             and Select._is_len_expr(self.exprs)
             and self.children[0].typ == "parquet"
             and self.children[0].predicate is None
-        ):
-
-
-
+        ):  # pragma: no cover
+            stream = context.get_cuda_stream()
+            scan = self.children[0]
+            effective_rows = scan.fast_count()
+            dtype = DataType(pl.UInt32())
             col = Column(
                 plc.Column.from_scalar(
-                    plc.Scalar.from_py(effective_rows, dtype.
+                    plc.Scalar.from_py(effective_rows, dtype.plc_type, stream=stream),
                     1,
+                    stream=stream,
                 ),
                 name=self.exprs[0].name or "len",
                 dtype=dtype,
-            )
-            return DataFrame([col])
+            )
+            return DataFrame([col], stream=stream)
 
-        return super().evaluate(cache=cache, timer=timer)
+        return super().evaluate(cache=cache, timer=timer, context=context)
 
 
 class Reduce(IR):
@@ -1252,16 +1558,19 @@ class Reduce(IR):
         self._non_child_args = (self.exprs,)
 
     @classmethod
+    @log_do_evaluate
     @nvtx_annotate_cudf_polars(message="Reduce")
     def do_evaluate(
         cls,
         exprs: tuple[expr.NamedExpr, ...],
         df: DataFrame,
+        *,
+        context: IRExecutionContext,
     ) -> DataFrame:  # pragma: no cover; not exposed by polars yet
         """Evaluate and return a dataframe."""
-        columns = broadcast(*(e.evaluate(df) for e in exprs))
+        columns = broadcast(*(e.evaluate(df) for e in exprs), stream=df.stream)
         assert all(column.size == 1 for column in columns)
-        return DataFrame(columns)
+        return DataFrame(columns, stream=df.stream)
 
 
 class Rolling(IR):
@@ -1270,17 +1579,19 @@ class Rolling(IR):
     __slots__ = (
         "agg_requests",
         "closed_window",
-        "
+        "following_ordinal",
         "index",
+        "index_dtype",
         "keys",
-        "
+        "preceding_ordinal",
         "zlice",
     )
     _non_child = (
         "schema",
         "index",
-        "
-        "
+        "index_dtype",
+        "preceding_ordinal",
+        "following_ordinal",
         "closed_window",
         "keys",
         "agg_requests",
|
|
|
1288
1599
|
)
|
|
1289
1600
|
index: expr.NamedExpr
|
|
1290
1601
|
"""Column being rolled over."""
|
|
1291
|
-
|
|
1292
|
-
"""
|
|
1293
|
-
|
|
1294
|
-
"""
|
|
1602
|
+
index_dtype: plc.DataType
|
|
1603
|
+
"""Datatype of the index column."""
|
|
1604
|
+
preceding_ordinal: int
|
|
1605
|
+
"""Preceding window extent defining start of window as a host integer."""
|
|
1606
|
+
following_ordinal: int
|
|
1607
|
+
"""Following window extent defining end of window as a host integer."""
|
|
1295
1608
|
closed_window: ClosedInterval
|
|
1296
1609
|
"""Treatment of window endpoints."""
|
|
1297
1610
|
keys: tuple[expr.NamedExpr, ...]
|
|
@@ -1305,8 +1618,9 @@ class Rolling(IR):
         self,
         schema: Schema,
         index: expr.NamedExpr,
-
-
+        index_dtype: plc.DataType,
+        preceding_ordinal: int,
+        following_ordinal: int,
         closed_window: ClosedInterval,
         keys: Sequence[expr.NamedExpr],
         agg_requests: Sequence[expr.NamedExpr],
@@ -1315,14 +1629,15 @@ class Rolling(IR):
     ):
         self.schema = schema
         self.index = index
-        self.
-        self.
+        self.index_dtype = index_dtype
+        self.preceding_ordinal = preceding_ordinal
+        self.following_ordinal = following_ordinal
         self.closed_window = closed_window
         self.keys = tuple(keys)
         self.agg_requests = tuple(agg_requests)
         if not all(
             plc.rolling.is_valid_rolling_aggregation(
-                agg.value.dtype.
+                agg.value.dtype.plc_type, agg.value.agg_request
             )
             for agg in self.agg_requests
         ):
@@ -1339,8 +1654,9 @@ class Rolling(IR):
         self.children = (df,)
         self._non_child_args = (
             index,
-
-
+            index_dtype,
+            preceding_ordinal,
+            following_ordinal,
             closed_window,
             keys,
             agg_requests,
@@ -1348,31 +1664,46 @@ class Rolling(IR):
         )
 
     @classmethod
+    @log_do_evaluate
     @nvtx_annotate_cudf_polars(message="Rolling")
     def do_evaluate(
         cls,
         index: expr.NamedExpr,
-
-
+        index_dtype: plc.DataType,
+        preceding_ordinal: int,
+        following_ordinal: int,
         closed_window: ClosedInterval,
         keys_in: Sequence[expr.NamedExpr],
         aggs: Sequence[expr.NamedExpr],
         zlice: Zlice | None,
         df: DataFrame,
+        *,
+        context: IRExecutionContext,
     ) -> DataFrame:
         """Evaluate and return a dataframe."""
-        keys = broadcast(
+        keys = broadcast(
+            *(k.evaluate(df) for k in keys_in),
+            target_length=df.num_rows,
+            stream=df.stream,
+        )
         orderby = index.evaluate(df)
         # Polars casts integral orderby to int64, but only for calculating window bounds
         if (
             plc.traits.is_integral(orderby.obj.type())
             and orderby.obj.type().id() != plc.TypeId.INT64
         ):
-            orderby_obj = plc.unary.cast(
+            orderby_obj = plc.unary.cast(
+                orderby.obj, plc.DataType(plc.TypeId.INT64), stream=df.stream
+            )
         else:
             orderby_obj = orderby.obj
+
+        preceding_scalar, following_scalar = offsets_to_windows(
+            index_dtype, preceding_ordinal, following_ordinal, stream=df.stream
+        )
+
         preceding_window, following_window = range_window_bounds(
-
+            preceding_scalar, following_scalar, closed_window
         )
         if orderby.obj.null_count() != 0:
             raise RuntimeError(
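For intuition on what `closed_window` does to the range bounds built here: for an orderby value x and preceding/following extents p and f, the window nominally spans [x - p, x + f], and the closed setting decides which endpoints are included. A worked illustration of those interval semantics (an assumption about the general "both"/"left"/"right"/"none" convention, not the library's implementation):

```python
def in_window(y, x, p, f, closed="both"):
    lo, hi = x - p, x + f
    if closed == "both":
        return lo <= y <= hi
    if closed == "left":
        return lo <= y < hi
    if closed == "right":
        return lo < y <= hi
    return lo < y < hi  # "none"

# With x=5, p=2, f=0 the window is [3, 5]; y=3 sits on the left endpoint,
# so it is included for "both" but excluded for "right".
assert in_window(3, 5, 2, 0, "both") and not in_window(3, 5, 2, 0, "right")
```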
@@ -1383,12 +1714,17 @@ class Rolling(IR):
             table = plc.Table([*(k.obj for k in keys), orderby_obj])
             n = table.num_columns()
             if not plc.sorting.is_sorted(
-                table,
+                table,
+                [plc.types.Order.ASCENDING] * n,
+                [plc.types.NullOrder.BEFORE] * n,
+                stream=df.stream,
             ):
                 raise RuntimeError("Input for grouped rolling is not sorted")
         else:
             if not orderby.check_sorted(
-                order=plc.types.Order.ASCENDING,
+                order=plc.types.Order.ASCENDING,
+                null_order=plc.types.NullOrder.BEFORE,
+                stream=df.stream,
             ):
                 raise RuntimeError(
                     f"Index column '{index.name}' in rolling is not sorted, please sort first"
@@ -1401,6 +1737,7 @@ class Rolling(IR):
             preceding_window,
             following_window,
             [rolling.to_request(request.value, orderby, df) for request in aggs],
+            stream=df.stream,
         )
         return DataFrame(
             itertools.chain(
@@ -1410,7 +1747,8 @@ class Rolling(IR):
                 Column(col, name=request.name, dtype=request.value.dtype)
                 for col, request in zip(values.columns(), aggs, strict=True)
             ),
-        )
+            ),
+            stream=df.stream,
         ).slice(zlice)
 
 
@@ -1472,6 +1810,7 @@ class GroupBy(IR):
         )
 
     @classmethod
+    @log_do_evaluate
     @nvtx_annotate_cudf_polars(message="GroupBy")
     def do_evaluate(
         cls,
@@ -1481,9 +1820,15 @@ class GroupBy(IR):
         maintain_order: bool,  # noqa: FBT001
         zlice: Zlice | None,
         df: DataFrame,
+        *,
+        context: IRExecutionContext,
     ) -> DataFrame:
         """Evaluate and return a dataframe."""
-        keys = broadcast(
+        keys = broadcast(
+            *(k.evaluate(df) for k in keys_in),
+            target_length=df.num_rows,
+            stream=df.stream,
+        )
         sorted = (
             plc.types.Sorted.YES
             if all(k.is_sorted for k in keys)
@@ -1512,10 +1857,15 @@ class GroupBy(IR):
                 col = child.evaluate(df, context=ExecutionContext.GROUPBY).obj
             else:
                 # Anything else, we pre-evaluate
-
+                column = value.evaluate(df, context=ExecutionContext.GROUPBY)
+                if column.size != keys[0].size:
+                    column = broadcast(
+                        column, target_length=keys[0].size, stream=df.stream
+                    )[0]
+                col = column.obj
             requests.append(plc.groupby.GroupByRequest(col, [value.agg_request]))
             names.append(name)
-        group_keys, raw_tables = grouper.aggregate(requests)
+        group_keys, raw_tables = grouper.aggregate(requests, stream=df.stream)
         results = [
             Column(column, name=name, dtype=schema[name])
             for name, column, request in zip(
@@ -1529,7 +1879,7 @@ class GroupBy(IR):
             Column(grouped_key, name=key.name, dtype=key.dtype)
             for key, grouped_key in zip(keys, group_keys.columns(), strict=True)
         ]
-        broadcasted = broadcast(*result_keys, *results)
+        broadcasted = broadcast(*result_keys, *results, stream=df.stream)
         # Handle order preservation of groups
         if maintain_order and not sorted:
             # The order we want
@@ -1539,6 +1889,7 @@ class GroupBy(IR):
                 plc.stream_compaction.DuplicateKeepOption.KEEP_FIRST,
                 plc.types.NullEquality.EQUAL,
                 plc.types.NanEquality.ALL_EQUAL,
+                stream=df.stream,
             )
             # The order we have
             have = plc.Table([key.obj for key in broadcasted[: len(keys)]])
@@ -1546,7 +1897,7 @@ class GroupBy(IR):
             # We know an inner join is OK because by construction
             # want and have are permutations of each other.
             left_order, right_order = plc.join.inner_join(
-                want, have, plc.types.NullEquality.EQUAL
+                want, have, plc.types.NullEquality.EQUAL, stream=df.stream
             )
             # Now left_order is an arbitrary permutation of the ordering we
             # want, and right_order is a matching permutation of the ordering
@@ -1559,11 +1910,13 @@ class GroupBy(IR):
                 plc.Table([left_order]),
                 [plc.types.Order.ASCENDING],
                 [plc.types.NullOrder.AFTER],
+                stream=df.stream,
             ).columns()
             ordered_table = plc.copying.gather(
                 plc.Table([col.obj for col in broadcasted]),
                 right_order,
                 plc.copying.OutOfBoundsPolicy.DONT_CHECK,
+                stream=df.stream,
            )
            broadcasted = [
                Column(reordered, name=old.name, dtype=old.dtype)
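The order-restoration trick in this `maintain_order` path is neat enough to restate in miniature: `want` and `have` hold the same keys in different orders, the inner join yields matching (want_pos, have_pos) index pairs, and sorting the pairs by want_pos then gathering rows at have_pos puts the grouped output back in first-observed order. A toy version with Python lists in place of libcudf tables:

```python
want = ["b", "a", "c"]            # order we want (first-occurrence order of keys)
have = ["a", "b", "c"]            # order the hash groupby happened to produce
rows = {"a": 1, "b": 2, "c": 3}   # aggregated value per group, keyed for clarity

pairs = [(want.index(k), have.index(k)) for k in want]  # "inner join" on keys
pairs.sort(key=lambda p: p[0])                          # sort by want position
reordered = [rows[have[h]] for _, h in pairs]           # gather via have position
assert reordered == [2, 1, 3]                           # values in want order
```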
@@ -1571,7 +1924,126 @@ class GroupBy(IR):
                 ordered_table.columns(), broadcasted, strict=True
             )
         ]
-        return DataFrame(broadcasted).slice(zlice)
+        return DataFrame(broadcasted, stream=df.stream).slice(zlice)
+
+
+def _strip_predicate_casts(node: expr.Expr) -> expr.Expr:
+    if isinstance(node, expr.Cast):
+        (child,) = node.children
+        child = _strip_predicate_casts(child)
+
+        src = child.dtype
+        dst = node.dtype
+
+        if plc.traits.is_fixed_point(src.plc_type) or plc.traits.is_fixed_point(
+            dst.plc_type
+        ):
+            return child
+
+        if (
+            not POLARS_VERSION_LT_134
+            and isinstance(child, expr.ColRef)
+            and (
+                (
+                    plc.traits.is_floating_point(src.plc_type)
+                    and plc.traits.is_floating_point(dst.plc_type)
+                )
+                or (
+                    plc.traits.is_integral(src.plc_type)
+                    and plc.traits.is_integral(dst.plc_type)
+                    and src.plc_type.id() == dst.plc_type.id()
+                )
+            )
+        ):
+            return child
+
+    if not node.children:
+        return node
+    return node.reconstruct([_strip_predicate_casts(child) for child in node.children])
+
+
+def _add_cast(
+    target: DataType,
+    side: expr.ColRef,
+    left_casts: dict[str, DataType],
+    right_casts: dict[str, DataType],
+) -> None:
+    (col,) = side.children
+    assert isinstance(col, expr.Col)
+    casts = (
+        left_casts if side.table_ref == plc_expr.TableReference.LEFT else right_casts
+    )
+    casts[col.name] = target
+
+
+def _align_decimal_binop_types(
+    left_expr: expr.ColRef,
+    right_expr: expr.ColRef,
+    left_casts: dict[str, DataType],
+    right_casts: dict[str, DataType],
+) -> None:
+    left_type, right_type = left_expr.dtype, right_expr.dtype
+
+    if plc.traits.is_fixed_point(left_type.plc_type) and plc.traits.is_fixed_point(
+        right_type.plc_type
+    ):
+        target = DataType.common_decimal_dtype(left_type, right_type)
+
+        if left_type.id() != target.id() or left_type.scale() != target.scale():
+            _add_cast(target, left_expr, left_casts, right_casts)
+
+        if right_type.id() != target.id() or right_type.scale() != target.scale():
+            _add_cast(target, right_expr, left_casts, right_casts)
+
+    elif (
+        plc.traits.is_fixed_point(left_type.plc_type)
+        and plc.traits.is_floating_point(right_type.plc_type)
+    ) or (
+        plc.traits.is_fixed_point(right_type.plc_type)
+        and plc.traits.is_floating_point(left_type.plc_type)
+    ):
+        is_decimal_left = plc.traits.is_fixed_point(left_type.plc_type)
+        decimal_expr, float_expr = (
+            (left_expr, right_expr) if is_decimal_left else (right_expr, left_expr)
+        )
+        _add_cast(decimal_expr.dtype, float_expr, left_casts, right_casts)
+
+
+def _collect_decimal_binop_casts(
+    predicate: expr.Expr,
+) -> tuple[dict[str, DataType], dict[str, DataType]]:
+    left_casts: dict[str, DataType] = {}
+    right_casts: dict[str, DataType] = {}
+
+    def _walk(node: expr.Expr) -> None:
+        if isinstance(node, expr.BinOp) and node.op in _BINOPS:
+            left_expr, right_expr = node.children
+            if isinstance(left_expr, expr.ColRef) and isinstance(
+                right_expr, expr.ColRef
+            ):
+                _align_decimal_binop_types(
+                    left_expr, right_expr, left_casts, right_casts
+                )
+        for child in node.children:
+            _walk(child)
+
+    _walk(predicate)
+    return left_casts, right_casts
+
+
+def _apply_casts(df: DataFrame, casts: dict[str, DataType]) -> DataFrame:
+    if not casts:
+        return df
+
+    columns = []
+    for col in df.columns:
+        target = casts.get(col.name)
+        if target is None:
+            columns.append(Column(col.obj, dtype=col.dtype, name=col.name))
+        else:
+            casted = col.astype(target, stream=df.stream)
+            columns.append(Column(casted.obj, dtype=casted.dtype, name=col.name))
+    return DataFrame(columns, stream=df.stream)
 
 
 class ConditionalJoin(IR):
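The helpers above implement a collect-then-apply pattern: one walk over the predicate records a target type per (side, column), and each input table is then cast at most once before the join. A toy version with nested tuples as hypothetical predicates and dicts as hypothetical tables (none of these names are cudf_polars API):

```python
def collect_casts(pred, left_casts, right_casts):
    op, a, b = pred
    if op == "and":
        collect_casts(a, left_casts, right_casts)
        collect_casts(b, left_casts, right_casts)
    else:
        # Comparison of ("left"|"right", column) references: record one
        # common target type per referenced column, on the owning side.
        for side, name in (a, b):
            (left_casts if side == "left" else right_casts)[name] = "decimal128"

def apply_casts(table, casts):
    # Rewrite each column's type once, leaving untouched columns alone.
    return {name: casts.get(name, dtype) for name, dtype in table.items()}

left, right = {"price": "decimal64"}, {"cost": "decimal32"}
lc, rc = {}, {}
collect_casts(("lt", ("left", "price"), ("right", "cost")), lc, rc)
assert apply_casts(left, lc) == {"price": "decimal128"}
assert apply_casts(right, rc) == {"cost": "decimal128"}
```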
@@ -1585,7 +2057,14 @@ class ConditionalJoin(IR):
 
     def __init__(self, predicate: expr.Expr):
         self.predicate = predicate
-
+        stream = get_cuda_stream()
+        ast_result = to_ast(predicate, stream=stream)
+        stream.synchronize()
+        if ast_result is None:
+            raise NotImplementedError(
+                f"Conditional join with predicate {predicate}"
+            )  # pragma: no cover; polars never delivers expressions we can't handle
+        self.ast = ast_result
 
     def __reduce__(self) -> tuple[Any, ...]:
         """Pickle a Predicate object."""
@@ -1598,8 +2077,9 @@ class ConditionalJoin(IR):
     options: tuple[
         tuple[
             str,
-
-        ]
+            polars._expr_nodes.Operator | Iterable[polars._expr_nodes.Operator],
+        ]
+        | None,
         bool,
         Zlice | None,
         str,
@@ -1620,7 +2100,14 @@ class ConditionalJoin(IR):
         self, schema: Schema, predicate: expr.Expr, options: tuple, left: IR, right: IR
     ) -> None:
         self.schema = schema
+        predicate = _strip_predicate_casts(predicate)
         self.predicate = predicate
+        # options[0] is a tuple[str, Operator, ...]
+        # The Operator class can't be pickled, but we don't use it anyway so
+        # just throw that away
+        if options[0] is not None:
+            options = (None, *options[1:])
+
         self.options = options
         self.children = (left, right)
         predicate_wrapper = self.Predicate(predicate)
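Note: `options[0]` can contain polars `Operator` objects that cannot be pickled; since they are unused after translation, the constructor now nulls them out so the IR node stays serializable. The same sanitize-before-store pattern in isolation:

```python
import pickle


class Unpicklable:
    def __reduce__(self):
        raise TypeError("cannot pickle this")


options = (Unpicklable(), True, None, "_right", "none")
if options[0] is not None:
    options = (None, *options[1:])  # drop the unpicklable, unused member

assert pickle.loads(pickle.dumps(options)) == (None, True, None, "_right", "none")
```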
@@ -1629,51 +2116,64 @@ class ConditionalJoin(IR):
         assert not nulls_equal
         assert not coalesce
         assert maintain_order == "none"
-
-        raise NotImplementedError(
-            f"Conditional join with predicate {predicate}"
-        )  # pragma: no cover; polars never delivers expressions we can't handle
-        self._non_child_args = (predicate_wrapper, zlice, suffix, maintain_order)
+        self._non_child_args = (predicate_wrapper, options)
 
     @classmethod
+    @log_do_evaluate
     @nvtx_annotate_cudf_polars(message="ConditionalJoin")
     def do_evaluate(
         cls,
         predicate_wrapper: Predicate,
-
-        suffix: str,
-        maintain_order: Literal["none", "left", "right", "left_right", "right_left"],
+        options: tuple,
         left: DataFrame,
         right: DataFrame,
+        *,
+        context: IRExecutionContext,
     ) -> DataFrame:
         """Evaluate and return a dataframe."""
-
-
-
-
-
-
-        plc.
-            left.table,
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        with context.stream_ordered_after(left, right) as stream:
+            left_casts, right_casts = _collect_decimal_binop_casts(
+                predicate_wrapper.predicate
+            )
+            _, _, zlice, suffix, _, _ = options
+
+            lg, rg = plc.join.conditional_inner_join(
+                _apply_casts(left, left_casts).table,
+                _apply_casts(right, right_casts).table,
+                predicate_wrapper.ast,
+                stream=stream,
+            )
+            left_result = DataFrame.from_table(
+                plc.copying.gather(
+                    left.table,
+                    lg,
+                    plc.copying.OutOfBoundsPolicy.DONT_CHECK,
+                    stream=stream,
+                ),
+                left.column_names,
+                left.dtypes,
+                stream=stream,
+            )
+            right_result = DataFrame.from_table(
+                plc.copying.gather(
+                    right.table,
+                    rg,
+                    plc.copying.OutOfBoundsPolicy.DONT_CHECK,
+                    stream=stream,
+                ),
+                right.column_names,
+                right.dtypes,
+                stream=stream,
+            )
+            right_result = right_result.rename_columns(
+                {
+                    name: f"{name}{suffix}"
+                    for name in right.column_names
+                    if name in left.column_names_set
+                }
+            )
+            result = left_result.with_columns(right_result.columns, stream=stream)
+
         return result.slice(zlice)
 
 
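Note: `ConditionalJoin.do_evaluate` now runs on a stream ordered after both inputs, applies the collected decimal casts to each side, performs a conditional inner join, and gathers both tables through the resulting maps. At the polars level this IR node comes from non-equi joins; a small query that exercises it (GPU support for a given predicate depends on the installed polars/cudf-polars versions, with CPU fallback otherwise):

```python
import polars as pl

left = pl.LazyFrame({"a": [1, 2, 3], "x": [1.0, 2.0, 3.0]})
right = pl.LazyFrame({"b": [2, 3, 4], "y": [10.0, 20.0, 30.0]})

# An inequality predicate produces a conditional (non-equi) join.
q = left.join_where(right, pl.col("a") < pl.col("b"))
print(q.collect(engine="gpu"))
```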
@@ -1704,6 +2204,19 @@ class Join(IR):
     - maintain_order: which DataFrame row order to preserve, if any
     """
 
+    SWAPPED_ORDER: ClassVar[
+        dict[
+            Literal["none", "left", "right", "left_right", "right_left"],
+            Literal["none", "left", "right", "left_right", "right_left"],
+        ]
+    ] = {
+        "none": "none",
+        "left": "right",
+        "right": "left",
+        "left_right": "right_left",
+        "right_left": "left_right",
+    }
+
     def __init__(
         self,
         schema: Schema,
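Note: `SWAPPED_ORDER` exists because a right join is executed as a left join with the inputs swapped; the requested `maintain_order` must be mirrored so it still refers to the user's notion of left and right:

```python
SWAPPED_ORDER = {
    "none": "none",
    "left": "right",
    "right": "left",
    "left_right": "right_left",
    "right_left": "left_right",
}

how, maintain_order = "Right", "left"
if how == "Right":
    maintain_order = SWAPPED_ORDER[maintain_order]
# The user's "left" input is now the engine's right table, so preserving
# its order means preserving the engine's right-hand order:
print(maintain_order)  # right
```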
@@ -1719,9 +2232,6 @@ class Join(IR):
         self.options = options
         self.children = (left, right)
         self._non_child_args = (self.left_on, self.right_on, self.options)
-        # TODO: Implement maintain_order
-        if options[5] != "none":
-            raise NotImplementedError("maintain_order not implemented yet")
 
     @staticmethod
     @cache
@@ -1770,6 +2280,9 @@ class Join(IR):
         right_rows: int,
         rg: plc.Column,
         right_policy: plc.copying.OutOfBoundsPolicy,
+        *,
+        left_primary: bool = True,
+        stream: Stream,
     ) -> list[plc.Column]:
         """
         Reorder gather maps to satisfy polars join order restrictions.
@@ -1788,30 +2301,70 @@ class Join(IR):
             Right gather map
         right_policy
             Nullify policy for right map
+        left_primary
+            Whether to preserve the left input row order first, and which
+            input stream to use for the primary sort.
+            Defaults to True.
+        stream
+            CUDA stream used for device memory operations and kernel launches.
 
         Returns
         -------
-        list
+        list[plc.Column]
+            Reordered left and right gather maps.
 
         Notes
         -----
-
-        left
-
+        When ``left_primary`` is True, the pair of gather maps is stably sorted by
+        the original row order of the left side, breaking ties by the right side.
+        And vice versa when ``left_primary`` is False.
         """
-        init = plc.Scalar.from_py(0, plc.types.SIZE_TYPE)
-        step = plc.Scalar.from_py(1, plc.types.SIZE_TYPE)
-
-
-
-
-
+        init = plc.Scalar.from_py(0, plc.types.SIZE_TYPE, stream=stream)
+        step = plc.Scalar.from_py(1, plc.types.SIZE_TYPE, stream=stream)
+
+        (left_order_col,) = plc.copying.gather(
+            plc.Table(
+                [
+                    plc.filling.sequence(
+                        left_rows,
+                        init,
+                        step,
+                        stream=stream,
+                    )
+                ]
+            ),
+            lg,
+            left_policy,
+            stream=stream,
+        ).columns()
+        (right_order_col,) = plc.copying.gather(
+            plc.Table(
+                [
+                    plc.filling.sequence(
+                        right_rows,
+                        init,
+                        step,
+                        stream=stream,
+                    )
+                ]
+            ),
+            rg,
+            right_policy,
+            stream=stream,
+        ).columns()
+
+        keys = (
+            plc.Table([left_order_col, right_order_col])
+            if left_primary
+            else plc.Table([right_order_col, left_order_col])
         )
+
         return plc.sorting.stable_sort_by_key(
             plc.Table([lg, rg]),
-
+            keys,
             [plc.types.Order.ASCENDING, plc.types.Order.ASCENDING],
             [plc.types.NullOrder.AFTER, plc.types.NullOrder.AFTER],
+            stream=stream,
         ).columns()
 
     @staticmethod
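Note: `_reorder_maps` now materializes a row-order column per side (a `sequence` gathered through each map) and stably sorts the pair of gather maps by the primary side's original order, with `left_primary` choosing which side leads. A pure-Python model of the idea (the real code also routes out-of-bounds entries through the gather policies and keeps everything on one CUDA stream):

```python
def reorder_maps(lg: list[int], rg: list[int], *, left_primary: bool = True):
    """Stably reorder join gather maps so the primary side's input order
    is preserved, breaking ties by the other side."""
    key = (lambda i: (lg[i], rg[i])) if left_primary else (lambda i: (rg[i], lg[i]))
    order = sorted(range(len(lg)), key=key)  # sorted() is stable
    return [lg[i] for i in order], [rg[i] for i in order]


# A hash join may emit matches in arbitrary order:
lg, rg = [2, 0, 1, 0], [0, 1, 2, 3]
print(reorder_maps(lg, rg))  # ([0, 0, 1, 2], [1, 3, 2, 0])
```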
@@ -1822,31 +2375,35 @@ class Join(IR):
         left: bool = True,
         empty: bool = False,
         rename: Callable[[str], str] = lambda name: name,
+        stream: Stream,
     ) -> list[Column]:
         if empty:
             return [
                 Column(
-                    plc.column_factories.make_empty_column(
+                    plc.column_factories.make_empty_column(
+                        col.dtype.plc_type, stream=stream
+                    ),
                     col.dtype,
                     name=rename(col.name),
                 )
                 for col in template
             ]
 
-
+        result = [
             Column(new, col.dtype, name=rename(col.name))
             for new, col in zip(columns, template, strict=True)
         ]
 
         if left:
-
+            result = [
                 col.sorted_like(orig)
-                for col, orig in zip(
+                for col, orig in zip(result, template, strict=True)
             ]
 
-        return
+        return result
 
     @classmethod
+    @log_do_evaluate
     @nvtx_annotate_cudf_polars(message="Join")
     def do_evaluate(
         cls,
@@ -1862,112 +2419,168 @@ class Join(IR):
         ],
         left: DataFrame,
         right: DataFrame,
+        *,
+        context: IRExecutionContext,
     ) -> DataFrame:
         """Evaluate and return a dataframe."""
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        with context.stream_ordered_after(left, right) as stream:
+            how, nulls_equal, zlice, suffix, coalesce, maintain_order = options
+            if how == "Cross":
+                # Separate implementation, since cross_join returns the
+                # result, not the gather maps
+                if right.num_rows == 0:
+                    left_cols = Join._build_columns(
+                        [], left.columns, empty=True, stream=stream
+                    )
+                    right_cols = Join._build_columns(
+                        [],
+                        right.columns,
+                        left=False,
+                        empty=True,
+                        rename=lambda name: name
+                        if name not in left.column_names_set
+                        else f"{name}{suffix}",
+                        stream=stream,
+                    )
+                    result = DataFrame([*left_cols, *right_cols], stream=stream)
+                else:
+                    columns = plc.join.cross_join(
+                        left.table, right.table, stream=stream
+                    ).columns()
+                    left_cols = Join._build_columns(
+                        columns[: left.num_columns], left.columns, stream=stream
+                    )
+                    right_cols = Join._build_columns(
+                        columns[left.num_columns :],
+                        right.columns,
+                        rename=lambda name: name
+                        if name not in left.column_names_set
+                        else f"{name}{suffix}",
+                        left=False,
+                        stream=stream,
+                    )
+                    result = DataFrame([*left_cols, *right_cols], stream=stream).slice(
+                        zlice
+                    )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            if nulls_equal
-            else plc.types.NullEquality.UNEQUAL
-        )
-        join_fn, left_policy, right_policy = cls._joiners(how)
-        if right_policy is None:
-            # Semi join
-            lg = join_fn(left_on.table, right_on.table, null_equality)
-            table = plc.copying.gather(left.table, lg, left_policy)
-            result = DataFrame.from_table(table, left.column_names, left.dtypes)
-        else:
-            if how == "Right":
-                # Right join is a left join with the tables swapped
-                left, right = right, left
-                left_on, right_on = right_on, left_on
-            lg, rg = join_fn(left_on.table, right_on.table, null_equality)
-            if how == "Left" or how == "Right":
-                # Order of left table is preserved
-                lg, rg = cls._reorder_maps(
-                    left.num_rows, lg, left_policy, right.num_rows, rg, right_policy
+            else:
+                # how != "Cross"
+                # TODO: Waiting on clarity based on https://github.com/pola-rs/polars/issues/17184
+                left_on = DataFrame(
+                    broadcast(
+                        *(e.evaluate(left) for e in left_on_exprs), stream=stream
+                    ),
+                    stream=stream,
+                )
+                right_on = DataFrame(
+                    broadcast(
+                        *(e.evaluate(right) for e in right_on_exprs), stream=stream
+                    ),
+                    stream=stream,
+                )
+                null_equality = (
+                    plc.types.NullEquality.EQUAL
+                    if nulls_equal
+                    else plc.types.NullEquality.UNEQUAL
                 )
-
-        if
-        #
-
-
-
-
+                join_fn, left_policy, right_policy = cls._joiners(how)
+                if right_policy is None:
+                    # Semi join
+                    lg = join_fn(left_on.table, right_on.table, null_equality, stream)
+                    table = plc.copying.gather(
+                        left.table, lg, left_policy, stream=stream
+                    )
+                    result = DataFrame.from_table(
+                        table, left.column_names, left.dtypes, stream=stream
+                    )
                 else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                    if how == "Right":
+                        # Right join is a left join with the tables swapped
+                        left, right = right, left
+                        left_on, right_on = right_on, left_on
+                        maintain_order = Join.SWAPPED_ORDER[maintain_order]
+
+                    lg, rg = join_fn(
+                        left_on.table, right_on.table, null_equality, stream=stream
+                    )
+                    if (
+                        how in ("Inner", "Left", "Right", "Full")
+                        and maintain_order != "none"
+                    ):
+                        lg, rg = cls._reorder_maps(
+                            left.num_rows,
+                            lg,
+                            left_policy,
+                            right.num_rows,
+                            rg,
+                            right_policy,
+                            left_primary=maintain_order.startswith("left"),
+                            stream=stream,
                         )
-
-
-
-
+                    if coalesce:
+                        if how == "Full":
+                            # In this case, keys must be column references,
+                            # possibly with dtype casting. We should use them in
+                            # preference to the columns from the original tables.
+
+                            # We need to specify `stream` here. We know that `{left,right}_on`
+                            # is valid on `stream`, which is ordered after `{left,right}.stream`.
+                            left = left.with_columns(
+                                left_on.columns, replace_only=True, stream=stream
+                            )
+                            right = right.with_columns(
+                                right_on.columns, replace_only=True, stream=stream
+                            )
+                        else:
+                            right = right.discard_columns(right_on.column_names_set)
+                    left = DataFrame.from_table(
+                        plc.copying.gather(left.table, lg, left_policy, stream=stream),
+                        left.column_names,
+                        left.dtypes,
+                        stream=stream,
+                    )
+                    right = DataFrame.from_table(
+                        plc.copying.gather(
+                            right.table, rg, right_policy, stream=stream
+                        ),
+                        right.column_names,
+                        right.dtypes,
+                        stream=stream,
+                    )
+                    if coalesce and how == "Full":
+                        left = left.with_columns(
+                            (
+                                Column(
+                                    plc.replace.replace_nulls(
+                                        left_col.obj, right_col.obj, stream=stream
+                                    ),
+                                    name=left_col.name,
+                                    dtype=left_col.dtype,
+                                )
+                                for left_col, right_col in zip(
+                                    left.select_columns(left_on.column_names_set),
+                                    right.select_columns(right_on.column_names_set),
+                                    strict=True,
+                                )
+                            ),
+                            replace_only=True,
+                            stream=stream,
                         )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        return result.slice(zlice)
+                        right = right.discard_columns(right_on.column_names_set)
+                    if how == "Right":
+                        # Undo the swap for right join before gluing together.
+                        left, right = right, left
+                    right = right.rename_columns(
+                        {
+                            name: f"{name}{suffix}"
+                            for name in right.column_names
+                            if name in left.column_names_set
+                        }
+                    )
+                    result = left.with_columns(right.columns, stream=stream)
+                result = result.slice(zlice)
+
+        return result
 
 
 class HStack(IR):
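Note: with the reordering machinery above, the GPU engine now honors `maintain_order` for Inner/Left/Right/Full equi-joins instead of raising `NotImplementedError` at translation time (the check removed in the `__init__` hunk earlier). A polars-level query that can now stay on the GPU (exact argument availability depends on the installed polars version):

```python
import polars as pl

left = pl.LazyFrame({"k": [3, 1, 2], "v": ["c", "a", "b"]})
right = pl.LazyFrame({"k": [1, 2, 3], "w": [10, 20, 30]})

q = left.join(right, on="k", how="inner", maintain_order="left")
print(q.collect(engine="gpu"))  # rows come back in left-input key order: 3, 1, 2
```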
@@ -1992,18 +2605,23 @@ class HStack(IR):
         self.children = (df,)
 
     @classmethod
+    @log_do_evaluate
     @nvtx_annotate_cudf_polars(message="HStack")
     def do_evaluate(
         cls,
         exprs: Sequence[expr.NamedExpr],
         should_broadcast: bool,  # noqa: FBT001
         df: DataFrame,
+        *,
+        context: IRExecutionContext,
     ) -> DataFrame:
         """Evaluate and return a dataframe."""
         columns = [c.evaluate(df) for c in exprs]
         if should_broadcast:
             columns = broadcast(
-                *columns,
+                *columns,
+                target_length=df.num_rows if df.num_columns != 0 else None,
+                stream=df.stream,
             )
         else:
             # Polars ensures this is true, but let's make sure nothing
@@ -2014,7 +2632,7 @@ class HStack(IR):
             # never be turned into a pylibcudf Table with all columns
             # by the Select, which is why this is safe.
             assert all(e.name.startswith("__POLARS_CSER_0x") for e in exprs)
-        return df.with_columns(columns)
+        return df.with_columns(columns, stream=df.stream)
 
 
 class Distinct(IR):
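Note: `HStack` now passes an explicit `target_length` to `broadcast` (the frame height, unless the frame has no columns) plus the frame's stream. The broadcast rule itself: length-1 columns are repeated to the target length, any other mismatch is an error. A sketch of that rule, with the inferred-target behaviour below being an assumption:

```python
def broadcast_lengths(lengths: list[int], target_length: int | None) -> list[int]:
    if target_length is None:
        target_length = max(lengths, default=0)  # assumed inference rule
    for n in lengths:
        if n not in (1, target_length):
            raise ValueError(f"cannot broadcast length {n} to {target_length}")
    return [target_length] * len(lengths)


print(broadcast_lengths([1, 5, 5], target_length=5))  # [5, 5, 5]
```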
@@ -2057,6 +2675,7 @@ class Distinct(IR):
         }
 
     @classmethod
+    @log_do_evaluate
     @nvtx_annotate_cudf_polars(message="Distinct")
     def do_evaluate(
         cls,
@@ -2065,6 +2684,8 @@ class Distinct(IR):
         zlice: Zlice | None,
         stable: bool,  # noqa: FBT001
         df: DataFrame,
+        *,
+        context: IRExecutionContext,
     ) -> DataFrame:
         """Evaluate and return a dataframe."""
         if subset is None:
@@ -2079,6 +2700,7 @@ class Distinct(IR):
             indices,
             keep,
             plc.types.NullEquality.EQUAL,
+            stream=df.stream,
         )
     else:
         distinct = (
@@ -2092,13 +2714,15 @@ class Distinct(IR):
                 keep,
                 plc.types.NullEquality.EQUAL,
                 plc.types.NanEquality.ALL_EQUAL,
+                df.stream,
             )
         # TODO: Is this sortedness setting correct
         result = DataFrame(
             [
                 Column(new, name=old.name, dtype=old.dtype).sorted_like(old)
                 for new, old in zip(table.columns(), df.columns, strict=True)
-            ]
+            ],
+            stream=df.stream,
         )
         if keys_sorted or stable:
             result = result.sorted_like(df)
@@ -2147,6 +2771,7 @@ class Sort(IR):
         self.children = (df,)
 
     @classmethod
+    @log_do_evaluate
     @nvtx_annotate_cudf_polars(message="Sort")
     def do_evaluate(
         cls,
@@ -2156,17 +2781,24 @@ class Sort(IR):
         stable: bool,  # noqa: FBT001
         zlice: Zlice | None,
         df: DataFrame,
+        *,
+        context: IRExecutionContext,
     ) -> DataFrame:
         """Evaluate and return a dataframe."""
-        sort_keys = broadcast(
+        sort_keys = broadcast(
+            *(k.evaluate(df) for k in by), target_length=df.num_rows, stream=df.stream
+        )
         do_sort = plc.sorting.stable_sort_by_key if stable else plc.sorting.sort_by_key
         table = do_sort(
             df.table,
             plc.Table([k.obj for k in sort_keys]),
             list(order),
             list(null_order),
+            stream=df.stream,
+        )
+        result = DataFrame.from_table(
+            table, df.column_names, df.dtypes, stream=df.stream
         )
-        result = DataFrame.from_table(table, df.column_names, df.dtypes)
         first_key = sort_keys[0]
         name = by[0].name
         first_key_in_result = (
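Note: `Sort` chooses between `plc.sorting.stable_sort_by_key` and `plc.sorting.sort_by_key` based on the `stable` flag; stability matters because rows with equal keys must keep their input order when polars asks for a stable sort. Illustrated with Python's (always stable) `sorted`:

```python
rows = [("b", 1), ("a", 2), ("b", 3), ("a", 4)]
print(sorted(rows, key=lambda r: r[0]))
# [('a', 2), ('a', 4), ('b', 1), ('b', 3)]: ties keep their input order
```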
@@ -2197,8 +2829,11 @@ class Slice(IR):
         self.children = (df,)
 
     @classmethod
+    @log_do_evaluate
     @nvtx_annotate_cudf_polars(message="Slice")
-    def do_evaluate(
+    def do_evaluate(
+        cls, offset: int, length: int, df: DataFrame, *, context: IRExecutionContext
+    ) -> DataFrame:
         """Evaluate and return a dataframe."""
         return df.slice((offset, length))
 
@@ -2218,10 +2853,15 @@ class Filter(IR):
         self.children = (df,)
 
     @classmethod
+    @log_do_evaluate
     @nvtx_annotate_cudf_polars(message="Filter")
-    def do_evaluate(
+    def do_evaluate(
+        cls, mask_expr: expr.NamedExpr, df: DataFrame, *, context: IRExecutionContext
+    ) -> DataFrame:
         """Evaluate and return a dataframe."""
-        (mask,) = broadcast(
+        (mask,) = broadcast(
+            mask_expr.evaluate(df), target_length=df.num_rows, stream=df.stream
+        )
         return df.filter(mask)
 
 
@@ -2237,14 +2877,19 @@ class Projection(IR):
         self.children = (df,)
 
     @classmethod
+    @log_do_evaluate
     @nvtx_annotate_cudf_polars(message="Projection")
-    def do_evaluate(
+    def do_evaluate(
+        cls, schema: Schema, df: DataFrame, *, context: IRExecutionContext
+    ) -> DataFrame:
         """Evaluate and return a dataframe."""
         # This can reorder things.
         columns = broadcast(
-            *(df.column_map[name] for name in schema),
+            *(df.column_map[name] for name in schema),
+            target_length=df.num_rows,
+            stream=df.stream,
         )
-        return DataFrame(columns)
+        return DataFrame(columns, stream=df.stream)
 
 
 class MergeSorted(IR):
@@ -2270,23 +2915,31 @@ class MergeSorted(IR):
         self._non_child_args = (key,)
 
     @classmethod
+    @log_do_evaluate
     @nvtx_annotate_cudf_polars(message="MergeSorted")
-    def do_evaluate(
+    def do_evaluate(
+        cls, key: str, *dfs: DataFrame, context: IRExecutionContext
+    ) -> DataFrame:
         """Evaluate and return a dataframe."""
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        with context.stream_ordered_after(*dfs) as stream:
+            left, right = dfs
+            right = right.discard_columns(
+                right.column_names_set - left.column_names_set
+            )
+            on_col_left = left.select_columns({key})[0]
+            on_col_right = right.select_columns({key})[0]
+            return DataFrame.from_table(
+                plc.merge.merge(
+                    [right.table, left.table],
+                    [left.column_names.index(key), right.column_names.index(key)],
+                    [on_col_left.order, on_col_right.order],
+                    [on_col_left.null_order, on_col_right.null_order],
+                    stream=stream,
+                ),
+                left.column_names,
+                left.dtypes,
+                stream=stream,
+            )
 
 
 class MapFunction(IR):
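Note: `MergeSorted.do_evaluate` is now implemented with `plc.merge.merge` on a stream ordered after both inputs; because each input is already sorted on the key, a single linear merge suffices instead of a full sort. The same idea in pure Python:

```python
import heapq

left = [1, 3, 5]
right = [2, 3, 6]
print(list(heapq.merge(left, right)))  # [1, 2, 3, 3, 5, 6]
```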
@@ -2347,7 +3000,7 @@ class MapFunction(IR):
             index = frozenset(indices)
             pivotees = [name for name in df.schema if name not in index]
             if not all(
-                dtypes.can_cast(df.schema[p].
+                dtypes.can_cast(df.schema[p].plc_type, self.schema[value_name].plc_type)
                 for p in pivotees
             ):
                 raise NotImplementedError(
@@ -2362,6 +3015,8 @@ class MapFunction(IR):
             )
         elif self.name == "row_index":
             col_name, offset = options
+            if col_name in df.schema:
+                raise NotImplementedError("Duplicate row index name")
             self.options = (col_name, offset)
         elif self.name == "fast_count":
             # TODO: Remove this once all scan types support projections
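Note: the translator now refuses a `row_index` whose name already exists in the schema, raising `NotImplementedError` so the query falls back to the CPU engine rather than producing a duplicate column. At the polars level this corresponds to `with_row_index` (the duplicate case's behaviour on the CPU engine is polars', not cudf-polars'):

```python
import polars as pl

lf = pl.LazyFrame({"a": [10, 20]})
print(lf.with_row_index("idx").collect())  # fine: adds a fresh "idx" column
# lf.with_row_index("a")  # duplicate name: the GPU engine now declines this
```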
@@ -2390,9 +3045,16 @@ class MapFunction(IR):
             )
 
     @classmethod
+    @log_do_evaluate
     @nvtx_annotate_cudf_polars(message="MapFunction")
     def do_evaluate(
-        cls,
+        cls,
+        schema: Schema,
+        name: str,
+        options: Any,
+        df: DataFrame,
+        *,
+        context: IRExecutionContext,
     ) -> DataFrame:
         """Evaluate and return a dataframe."""
         if name == "rechunk":
@@ -2409,7 +3071,10 @@ class MapFunction(IR):
             index = df.column_names.index(to_explode)
             subset = df.column_names_set - {to_explode}
             return DataFrame.from_table(
-                plc.lists.explode_outer(df.table, index
+                plc.lists.explode_outer(df.table, index, stream=df.stream),
+                df.column_names,
+                df.dtypes,
+                stream=df.stream,
             ).sorted_like(df, subset=subset)
         elif name == "unpivot":
             (
@@ -2423,7 +3088,7 @@ class MapFunction(IR):
             index_columns = [
                 Column(tiled, name=name, dtype=old.dtype)
                 for tiled, name, old in zip(
-                    plc.reshape.tile(selected.table, npiv).columns(),
+                    plc.reshape.tile(selected.table, npiv, stream=df.stream).columns(),
                     indices,
                     selected.columns,
                     strict=True,
@@ -2434,18 +3099,23 @@ class MapFunction(IR):
                 [
                     plc.Column.from_arrow(
                         pl.Series(
-                            values=pivotees, dtype=schema[variable_name].
-                        )
+                            values=pivotees, dtype=schema[variable_name].polars_type
+                        ),
+                        stream=df.stream,
                     )
                 ]
             ),
             df.num_rows,
+            stream=df.stream,
         ).columns()
         value_column = plc.concatenate.concatenate(
             [
-                df.column_map[pivotee]
+                df.column_map[pivotee]
+                .astype(schema[value_name], stream=df.stream)
+                .obj
                 for pivotee in pivotees
-            ]
+            ],
+            stream=df.stream,
         )
         return DataFrame(
             [
@@ -2454,22 +3124,23 @@ class MapFunction(IR):
                 Column(
                     variable_column, name=variable_name, dtype=schema[variable_name]
                 ),
                 Column(value_column, name=value_name, dtype=schema[value_name]),
-            ]
+            ],
+            stream=df.stream,
         )
     elif name == "row_index":
         col_name, offset = options
         dtype = schema[col_name]
-        step = plc.Scalar.from_py(1, dtype.
-        init = plc.Scalar.from_py(offset, dtype.
+        step = plc.Scalar.from_py(1, dtype.plc_type, stream=df.stream)
+        init = plc.Scalar.from_py(offset, dtype.plc_type, stream=df.stream)
         index_col = Column(
-            plc.filling.sequence(df.num_rows, init, step),
+            plc.filling.sequence(df.num_rows, init, step, stream=df.stream),
             is_sorted=plc.types.Sorted.YES,
             order=plc.types.Order.ASCENDING,
             null_order=plc.types.NullOrder.AFTER,
             name=col_name,
             dtype=dtype,
         )
-        return DataFrame([index_col, *df.columns])
+        return DataFrame([index_col, *df.columns], stream=df.stream)
     else:
         raise AssertionError("Should never be reached")  # pragma: no cover
 
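Note: in the unpivot path, every pivotee column is now cast to the common value dtype via `.astype(schema[value_name], stream=df.stream)` before concatenation, and the variable column is built on the frame's stream. The polars-level reshape this implements:

```python
import polars as pl

df = pl.DataFrame({"id": [1, 2], "a": [1.0, 2.0], "b": [10, 20]})
out = df.unpivot(on=["a", "b"], index="id", variable_name="var", value_name="val")
print(out)  # "b" (int) is cast so both pivotees share the "val" column's dtype
```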
@@ -2490,15 +3161,20 @@ class Union(IR):
         schema = self.children[0].schema
 
     @classmethod
+    @log_do_evaluate
     @nvtx_annotate_cudf_polars(message="Union")
-    def do_evaluate(
+    def do_evaluate(
+        cls, zlice: Zlice | None, *dfs: DataFrame, context: IRExecutionContext
+    ) -> DataFrame:
         """Evaluate and return a dataframe."""
-
-
-
-
-
-
+        with context.stream_ordered_after(*dfs) as stream:
+            # TODO: only evaluate what we need if we have a slice?
+            return DataFrame.from_table(
+                plc.concatenate.concatenate([df.table for df in dfs], stream=stream),
+                dfs[0].column_names,
+                dfs[0].dtypes,
+                stream=stream,
+            ).slice(zlice)
 
 
 class HConcat(IR):
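Note: `Union` now concatenates its children's tables on a single stream ordered after all inputs and applies the optional `zlice` afterwards (the TODO observes that a known slice could bound how much gets evaluated). Polars-level equivalent, where a trailing `head` may be absorbed into the Union's slice:

```python
import polars as pl

lf1 = pl.LazyFrame({"x": [1, 2]})
lf2 = pl.LazyFrame({"x": [3, 4]})
print(pl.concat([lf1, lf2]).head(3).collect(engine="gpu"))  # x: 1, 2, 3
```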
@@ -2519,7 +3195,9 @@ class HConcat(IR):
         self.children = children
 
     @staticmethod
-    def _extend_with_nulls(
+    def _extend_with_nulls(
+        table: plc.Table, *, nrows: int, stream: Stream
+    ) -> plc.Table:
         """
         Extend a table with nulls.
 
@@ -2529,6 +3207,8 @@ class HConcat(IR):
             Table to extend
         nrows
             Number of additional rows
+        stream
+            CUDA stream used for device memory operations and kernel launches
 
         Returns
         -------
@@ -2539,45 +3219,61 @@ class HConcat(IR):
                 table,
                 plc.Table(
                     [
-                        plc.Column.all_null_like(column, nrows)
+                        plc.Column.all_null_like(column, nrows, stream=stream)
                         for column in table.columns()
                     ]
                 ),
-            ]
+            ],
+            stream=stream,
         )
 
     @classmethod
+    @log_do_evaluate
     @nvtx_annotate_cudf_polars(message="HConcat")
     def do_evaluate(
         cls,
         should_broadcast: bool,  # noqa: FBT001
         *dfs: DataFrame,
+        context: IRExecutionContext,
     ) -> DataFrame:
         """Evaluate and return a dataframe."""
-
-
-
-
-
-
-
-
-
-
-
-
-            for df in
-
-
-
-            df
-
-
-
+        with context.stream_ordered_after(*dfs) as stream:
+            # Special should_broadcast case.
+            # Used to recombine decomposed expressions
+            if should_broadcast:
+                result = DataFrame(
+                    broadcast(
+                        *itertools.chain.from_iterable(df.columns for df in dfs),
+                        stream=stream,
+                    ),
+                    stream=stream,
+                )
+            else:
+                max_rows = max(df.num_rows for df in dfs)
+                # Horizontal concatenation extends shorter tables with nulls
+                result = DataFrame(
+                    itertools.chain.from_iterable(
+                        df.columns
+                        for df in (
+                            df
+                            if df.num_rows == max_rows
+                            else DataFrame.from_table(
+                                cls._extend_with_nulls(
+                                    df.table,
+                                    nrows=max_rows - df.num_rows,
+                                    stream=stream,
+                                ),
+                                df.column_names,
+                                df.dtypes,
+                                stream=stream,
+                            )
+                            for df in dfs
+                        )
+                    ),
+                    stream=stream,
                 )
-
-
+
+        return result
 
 
 class Empty(IR):
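Note: in the non-broadcast path, `HConcat` pads every shorter input out to the tallest input's height with `_extend_with_nulls` before gluing the columns together, matching polars' horizontal-concatenation semantics:

```python
import polars as pl

a = pl.DataFrame({"x": [1, 2, 3]})
b = pl.DataFrame({"y": ["p"]})
print(pl.concat([a, b], how="horizontal"))  # "y" is null-padded to 3 rows
```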
@@ -2592,16 +3288,23 @@ class Empty(IR):
         self.children = ()
 
     @classmethod
+    @log_do_evaluate
     @nvtx_annotate_cudf_polars(message="Empty")
-    def do_evaluate(
+    def do_evaluate(
+        cls, schema: Schema, *, context: IRExecutionContext
+    ) -> DataFrame:  # pragma: no cover
         """Evaluate and return a dataframe."""
+        stream = context.get_cuda_stream()
         return DataFrame(
             [
                 Column(
-                    plc.column_factories.make_empty_column(
+                    plc.column_factories.make_empty_column(
+                        dtype.plc_type, stream=stream
+                    ),
                     dtype=dtype,
                     name=name,
                 )
                 for name, dtype in schema.items()
-            ]
+            ],
+            stream=stream,
         )