cudf-polars-cu13 25.10.0__py3-none-any.whl → 26.2.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported public registry; it is provided for informational purposes only.
- cudf_polars/GIT_COMMIT +1 -1
- cudf_polars/VERSION +1 -1
- cudf_polars/callback.py +60 -15
- cudf_polars/containers/column.py +137 -77
- cudf_polars/containers/dataframe.py +123 -34
- cudf_polars/containers/datatype.py +134 -13
- cudf_polars/dsl/expr.py +0 -2
- cudf_polars/dsl/expressions/aggregation.py +80 -28
- cudf_polars/dsl/expressions/binaryop.py +34 -14
- cudf_polars/dsl/expressions/boolean.py +110 -37
- cudf_polars/dsl/expressions/datetime.py +59 -30
- cudf_polars/dsl/expressions/literal.py +11 -5
- cudf_polars/dsl/expressions/rolling.py +460 -119
- cudf_polars/dsl/expressions/selection.py +9 -8
- cudf_polars/dsl/expressions/slicing.py +1 -1
- cudf_polars/dsl/expressions/string.py +256 -114
- cudf_polars/dsl/expressions/struct.py +19 -7
- cudf_polars/dsl/expressions/ternary.py +33 -3
- cudf_polars/dsl/expressions/unary.py +126 -64
- cudf_polars/dsl/ir.py +1053 -350
- cudf_polars/dsl/to_ast.py +30 -13
- cudf_polars/dsl/tracing.py +194 -0
- cudf_polars/dsl/translate.py +307 -107
- cudf_polars/dsl/utils/aggregations.py +43 -30
- cudf_polars/dsl/utils/reshape.py +14 -2
- cudf_polars/dsl/utils/rolling.py +12 -8
- cudf_polars/dsl/utils/windows.py +35 -20
- cudf_polars/experimental/base.py +55 -2
- cudf_polars/experimental/benchmarks/pdsds.py +12 -126
- cudf_polars/experimental/benchmarks/pdsh.py +792 -2
- cudf_polars/experimental/benchmarks/utils.py +596 -39
- cudf_polars/experimental/dask_registers.py +47 -20
- cudf_polars/experimental/dispatch.py +9 -3
- cudf_polars/experimental/distinct.py +2 -0
- cudf_polars/experimental/explain.py +15 -2
- cudf_polars/experimental/expressions.py +30 -15
- cudf_polars/experimental/groupby.py +25 -4
- cudf_polars/experimental/io.py +156 -124
- cudf_polars/experimental/join.py +53 -23
- cudf_polars/experimental/parallel.py +68 -19
- cudf_polars/experimental/rapidsmpf/__init__.py +8 -0
- cudf_polars/experimental/rapidsmpf/collectives/__init__.py +9 -0
- cudf_polars/experimental/rapidsmpf/collectives/allgather.py +90 -0
- cudf_polars/experimental/rapidsmpf/collectives/common.py +96 -0
- cudf_polars/experimental/rapidsmpf/collectives/shuffle.py +253 -0
- cudf_polars/experimental/rapidsmpf/core.py +488 -0
- cudf_polars/experimental/rapidsmpf/dask.py +172 -0
- cudf_polars/experimental/rapidsmpf/dispatch.py +153 -0
- cudf_polars/experimental/rapidsmpf/io.py +696 -0
- cudf_polars/experimental/rapidsmpf/join.py +322 -0
- cudf_polars/experimental/rapidsmpf/lower.py +74 -0
- cudf_polars/experimental/rapidsmpf/nodes.py +735 -0
- cudf_polars/experimental/rapidsmpf/repartition.py +216 -0
- cudf_polars/experimental/rapidsmpf/union.py +115 -0
- cudf_polars/experimental/rapidsmpf/utils.py +374 -0
- cudf_polars/experimental/repartition.py +9 -2
- cudf_polars/experimental/select.py +177 -14
- cudf_polars/experimental/shuffle.py +46 -12
- cudf_polars/experimental/sort.py +100 -26
- cudf_polars/experimental/spilling.py +1 -1
- cudf_polars/experimental/statistics.py +24 -5
- cudf_polars/experimental/utils.py +25 -7
- cudf_polars/testing/asserts.py +13 -8
- cudf_polars/testing/io.py +2 -1
- cudf_polars/testing/plugin.py +93 -17
- cudf_polars/typing/__init__.py +86 -32
- cudf_polars/utils/config.py +473 -58
- cudf_polars/utils/cuda_stream.py +70 -0
- cudf_polars/utils/versions.py +5 -4
- cudf_polars_cu13-26.2.0.dist-info/METADATA +181 -0
- cudf_polars_cu13-26.2.0.dist-info/RECORD +108 -0
- {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/WHEEL +1 -1
- cudf_polars_cu13-25.10.0.dist-info/METADATA +0 -136
- cudf_polars_cu13-25.10.0.dist-info/RECORD +0 -92
- {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/licenses/LICENSE +0 -0
- {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/top_level.txt +0 -0
cudf_polars/dsl/utils/aggregations.py CHANGED

@@ -6,6 +6,7 @@
 from __future__ import annotations
 
 import itertools
+from decimal import Decimal
 from functools import partial
 from typing import TYPE_CHECKING, Any
 
@@ -16,7 +17,7 @@ import pylibcudf as plc
 from cudf_polars.containers import DataType
 from cudf_polars.dsl import expr, ir
 from cudf_polars.dsl.expressions.base import ExecutionContext
-from cudf_polars.utils.versions import POLARS_VERSION_LT_1323
+from cudf_polars.utils.versions import POLARS_VERSION_LT_134, POLARS_VERSION_LT_1323
 
 if TYPE_CHECKING:
     from collections.abc import Callable, Generator, Iterable, Sequence
@@ -45,6 +46,11 @@ def replace_nulls(col: expr.Expr, value: Any, *, is_top: bool) -> expr.Expr:
     """
     if not is_top:
         return col
+    if isinstance(value, int) and value == 0:
+        dtype = col.dtype.plc_type
+        value = (
+            Decimal(0).scaleb(dtype.scale()) if plc.traits.is_fixed_point(dtype) else 0
+        )
     return expr.UnaryFunction(
         col.dtype, "fill_null", (), col, expr.Literal(col.dtype, value)
     )
@@ -91,17 +97,25 @@ def decompose_single_agg(
     name = named_expr.name
     if isinstance(agg, expr.UnaryFunction) and agg.name in {
         "rank",
+        "fill_null_with_strategy",
+        "cum_sum",
     }:
         if context != ExecutionContext.WINDOW:
             raise NotImplementedError(
                 f"{agg.name} is not supported in groupby or rolling context"
             )
+        if agg.name == "fill_null_with_strategy" and (
+            strategy := agg.options[0]
+        ) not in {"forward", "backward"}:
+            raise NotImplementedError(
+                f"fill_null({strategy=}) not supported in a groupy or rolling context"
+            )
         # Ensure Polars semantics for dtype:
         # - average -> Float64
         # - min/max/dense/ordinal -> IDX_DTYPE (UInt32/UInt64)
         post_col: expr.Expr = expr.Col(agg.dtype, name)
         if agg.name == "rank":
-            post_col = expr.Cast(agg.dtype, post_col)
+            post_col = expr.Cast(agg.dtype, True, post_col)  # noqa: FBT003
 
         return [(named_expr, True)], named_expr.reconstruct(post_col)
     if isinstance(agg, expr.UnaryFunction) and agg.name == "null_count":
@@ -117,10 +131,10 @@ def decompose_single_agg(
         sum_name = next(name_generator)
         sum_agg = expr.NamedExpr(
             sum_name,
-            expr.Agg(u32, "sum", (), expr.Cast(u32, is_null_bool)),
+            expr.Agg(u32, "sum", (), context, expr.Cast(u32, True, is_null_bool)),  # noqa: FBT003
         )
         return [(sum_agg, True)], named_expr.reconstruct(
-            expr.Cast(u32, expr.Col(u32, sum_name))
+            expr.Cast(u32, True, expr.Col(u32, sum_name))  # noqa: FBT003
         )
     if isinstance(agg, expr.Col):
         # TODO: collect_list produces null for empty group in libcudf, empty list in polars.
@@ -146,15 +160,6 @@ def decompose_single_agg(
         return [(named_expr, True)], named_expr.reconstruct(expr.Col(agg.dtype, name))
     if isinstance(agg, (expr.Literal, expr.LiteralColumn)):
         return [], named_expr
-    if (
-        is_top
-        and isinstance(agg, expr.UnaryFunction)
-        and agg.name == "fill_null_with_strategy"
-    ):
-        strategy, _ = agg.options
-        raise NotImplementedError(
-            f"fill_null_with_strategy({strategy!r}) is not supported in groupby aggregations"
-        )
     if isinstance(agg, expr.Agg):
         if agg.name == "quantile":
             # Second child the requested quantile (which is asserted
@@ -163,7 +168,7 @@ def decompose_single_agg(
         else:
             (child,) = agg.children
         needs_masking = agg.name in {"min", "max"} and plc.traits.is_floating_point(
-            child.dtype.plc
+            child.dtype.plc_type
         )
         if needs_masking and agg.options:
             # pl.col("a").nan_max or nan_min
@@ -177,7 +182,7 @@ def decompose_single_agg(
             if any(has_agg for _, has_agg in aggs):
                 raise NotImplementedError("Nested aggs in groupby not supported")
 
-        child_dtype = child.dtype.plc
+        child_dtype = child.dtype.plc_type
         req = agg.agg_request
         is_median = agg.name == "median"
         is_quantile = agg.name == "quantile"
@@ -186,18 +191,22 @@ def decompose_single_agg(
         # mean/median on decimal: Polars returns float -> pre-cast
         decimal_unsupported = False
         if plc.traits.is_fixed_point(child_dtype):
-            if is_quantile:
+            cast_for_quantile = is_quantile and not POLARS_VERSION_LT_134
+            cast_for_mean_or_median = (
+                agg.name in {"mean", "median"}
+            ) and plc.traits.is_floating_point(agg.dtype.plc_type)
+
+            if cast_for_quantile or cast_for_mean_or_median:
+                child = expr.Cast(
+                    agg.dtype
+                    if plc.traits.is_floating_point(agg.dtype.plc_type)
+                    else DataType(pl.Float64()),
+                    True,  # noqa: FBT003
+                    child,
+                )
+                child_dtype = child.dtype.plc_type
+            elif is_quantile and POLARS_VERSION_LT_134:  # pragma: no cover
                 decimal_unsupported = True
-            elif agg.name in {"mean", "median"}:
-                tid = agg.dtype.plc.id()
-                if tid in {plc.TypeId.FLOAT32, plc.TypeId.FLOAT64}:
-                    cast_to = (
-                        DataType(pl.Float64)
-                        if tid == plc.TypeId.FLOAT64
-                        else DataType(pl.Float32)
-                    )
-                    child = expr.Cast(cast_to, child)
-                    child_dtype = child.dtype.plc
 
         is_group_quantile_supported = plc.traits.is_integral(
             child_dtype
@@ -221,9 +230,13 @@ def decompose_single_agg(
 
         if agg.name == "sum":
             col = (
-                expr.Cast(agg.dtype, expr.Col(DataType(pl.datatypes.Int64()), name))
+                expr.Cast(
+                    agg.dtype,
+                    True,  # noqa: FBT003
+                    expr.Col(DataType(pl.datatypes.Int64()), name),
+                )
                 if (
-                    plc.traits.is_integral(agg.dtype.plc)
+                    plc.traits.is_integral(agg.dtype.plc_type)
                     and agg.dtype.id() != plc.TypeId.INT64
                 )
                 else expr.Col(agg.dtype, name)
@@ -272,9 +285,9 @@ def decompose_single_agg(
             post_agg_col: expr.Expr = expr.Col(
                 DataType(pl.Float64()), name
             )  # libcudf promotes to float64
-            if agg.dtype.plc.id() == plc.TypeId.FLOAT32:
+            if agg.dtype.plc_type.id() == plc.TypeId.FLOAT32:
                 # Cast back to float32 to match Polars
-                post_agg_col = expr.Cast(agg.dtype, post_agg_col)
+                post_agg_col = expr.Cast(agg.dtype, True, post_agg_col)  # noqa: FBT003
             return [(named_expr, True)], named_expr.reconstruct(post_agg_col)
         else:
             return [(named_expr, True)], named_expr.reconstruct(
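A note on the `replace_nulls` change above: when a fixed-point (decimal) column is filled with a plain integer zero, the zero literal has to carry the column's scale, which is what `Decimal(0).scaleb(...)` produces. A minimal standalone sketch of that construction; the scale value below is hypothetical and not taken from the diff:

    from decimal import Decimal

    # libcudf fixed-point types expose their scale as an exponent; a column with
    # two fractional digits would typically report a scale of -2.
    scale = -2  # hypothetical column scale
    zero = Decimal(0).scaleb(scale)
    print(zero)  # 0.00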
cudf_polars/dsl/utils/reshape.py CHANGED

@@ -4,12 +4,19 @@
 
 from __future__ import annotations
 
+from typing import TYPE_CHECKING
+
 import pylibcudf as plc
 
 from cudf_polars.containers import Column
 
+if TYPE_CHECKING:
+    from rmm.pylibrmm.stream import Stream
+
 
-def broadcast(*columns: Column, target_length: int | None = None) -> list[Column]:
+def broadcast(
+    *columns: Column, target_length: int | None = None, stream: Stream
+) -> list[Column]:
     """
     Broadcast a sequence of columns to a common length.
 
@@ -20,6 +27,9 @@ def broadcast(*columns: Column, target_length: int | None = None) -> list[Column]:
     target_length
         Optional length to broadcast to. If not provided, uses the
        non-unit length of existing columns.
+    stream
+        CUDA stream used for device memory operations and kernel launches
+        on this dataframe.
 
     Returns
     -------
@@ -63,7 +73,9 @@ def broadcast(*columns: Column, target_length: int | None = None) -> list[Column]:
         column
         if column.size != 1
         else Column(
-            plc.Column.from_scalar(column.obj_scalar, nrows),
+            plc.Column.from_scalar(
+                column.obj_scalar(stream=stream), nrows, stream=stream
+            ),
             is_sorted=plc.types.Sorted.YES,
             order=plc.types.Order.ASCENDING,
             null_order=plc.types.NullOrder.BEFORE,
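The `broadcast` change above is part of a wider move in this release toward passing an explicit CUDA stream into pylibcudf calls. A minimal sketch of that calling pattern, using only the pylibcudf calls that appear in this diff; treating `rmm.pylibrmm.stream.DEFAULT_STREAM` as the stream object is an assumption for illustration:

    import pylibcudf as plc
    from rmm.pylibrmm.stream import DEFAULT_STREAM  # assumed export; any Stream works

    stream = DEFAULT_STREAM
    dtype = plc.DataType(plc.TypeId.INT64)

    # Build a device scalar and replicate it into a column, keeping all work on one stream.
    scalar = plc.Scalar.from_py(42, dtype, stream=stream)
    column = plc.Column.from_scalar(scalar, 1000, stream=stream)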
cudf_polars/dsl/utils/rolling.py CHANGED

@@ -13,7 +13,7 @@ from cudf_polars.dsl import expr, ir
 from cudf_polars.dsl.expressions.base import ExecutionContext
 from cudf_polars.dsl.utils.aggregations import apply_pre_evaluation
 from cudf_polars.dsl.utils.naming import unique_names
-from cudf_polars.dsl.utils.windows import
+from cudf_polars.dsl.utils.windows import duration_to_int
 
 if TYPE_CHECKING:
     from collections.abc import Sequence
@@ -74,10 +74,13 @@ def rewrite_rolling(
     index_name = options.rolling.index_column
     index_dtype = schema[index_name]
     index_col = expr.Col(index_dtype, index_name)
-    if plc.traits.is_integral(index_dtype.plc) and index_dtype.id() != plc.TypeId.INT64:
+    if (
+        plc.traits.is_integral(index_dtype.plc_type)
+        and index_dtype.id() != plc.TypeId.INT64
+    ):
         plc_index_dtype = plc.DataType(plc.TypeId.INT64)
     else:
-        plc_index_dtype = index_dtype.plc
+        plc_index_dtype = index_dtype.plc_type
     index = expr.NamedExpr(index_name, index_col)
     temp_prefix = "_" * max(map(len, schema))
     if len(aggs) > 0:
@@ -92,9 +95,9 @@ def rewrite_rolling(
     else:
         rolling_schema = schema
         apply_post_evaluation = lambda inp: inp  # noqa: E731
-
-
-
+    preceding_ordinal = duration_to_int(plc_index_dtype, *options.rolling.offset)
+    following_ordinal = duration_to_int(plc_index_dtype, *options.rolling.period)
+
     if (n := len(keys)) > 0:
         # Grouped rolling in polars sorts the output by the groups.
         inp = ir.Sort(
@@ -110,8 +113,9 @@ def rewrite_rolling(
         ir.Rolling(
             rolling_schema,
             index,
-
-
+            plc_index_dtype,
+            preceding_ordinal,
+            following_ordinal,
             options.rolling.closed_window,
             keys,
             aggs,
cudf_polars/dsl/utils/windows.py CHANGED

@@ -12,7 +12,9 @@ import polars as pl
 import pylibcudf as plc
 
 if TYPE_CHECKING:
-    from cudf_polars.typing import ClosedInterval
+    from rmm.pylibrmm.stream import Stream
+
+    from cudf_polars.typing import ClosedInterval
 
 
 __all__ = [
@@ -75,7 +77,7 @@ def duration_to_int(
     return -value if negative else value
 
 
-def duration_to_scalar(dtype: plc.DataType, value: int) -> plc.Scalar:
+def duration_to_scalar(dtype: plc.DataType, value: int, stream: Stream) -> plc.Scalar:
     """
     Convert a raw polars duration value to a pylibcudf scalar.
 
@@ -86,6 +88,9 @@ def duration_to_scalar(dtype: plc.DataType, value: int) -> plc.Scalar:
     value
         The raw value as in integer. If `dtype` represents a timestamp
         type, this should be in nanoseconds.
+    stream
+        CUDA stream used for device memory operations and kernel launches
+        on this dataframe. The returned scalar will be valid on this stream.
 
     Returns
     -------
@@ -99,20 +104,28 @@ def duration_to_scalar(dtype: plc.DataType, value: int) -> plc.Scalar:
     """
     tid = dtype.id()
     if tid == plc.TypeId.INT64:
-        return plc.Scalar.from_py(value, dtype)
+        return plc.Scalar.from_py(value, dtype, stream=stream)
     elif tid == plc.TypeId.TIMESTAMP_NANOSECONDS:
-        return plc.Scalar.from_py(value, plc.DataType(plc.TypeId.DURATION_NANOSECONDS))
+        return plc.Scalar.from_py(
+            value, plc.DataType(plc.TypeId.DURATION_NANOSECONDS), stream=stream
+        )
     elif tid == plc.TypeId.TIMESTAMP_MICROSECONDS:
         return plc.Scalar.from_py(
-            value // 1000, plc.DataType(plc.TypeId.DURATION_MICROSECONDS)
+            value // 1000,
+            plc.DataType(plc.TypeId.DURATION_MICROSECONDS),
+            stream=stream,
         )
     elif tid == plc.TypeId.TIMESTAMP_MILLISECONDS:
         return plc.Scalar.from_py(
-            value // 1_000_000, plc.DataType(plc.TypeId.DURATION_MILLISECONDS)
+            value // 1_000_000,
+            plc.DataType(plc.TypeId.DURATION_MILLISECONDS),
+            stream=stream,
         )
     elif tid == plc.TypeId.TIMESTAMP_DAYS:
         return plc.Scalar.from_py(
-            value // 86_400_000_000_000, plc.DataType(plc.TypeId.DURATION_DAYS)
+            value // 86_400_000_000_000,
+            plc.DataType(plc.TypeId.DURATION_DAYS),
+            stream=stream,
         )
     else:
         raise NotImplementedError(
@@ -122,8 +135,9 @@ def duration_to_scalar(dtype: plc.DataType, value: int) -> plc.Scalar:
 
 def offsets_to_windows(
     dtype: plc.DataType,
-
-
+    offset_i: int,
+    period_i: int,
+    stream: Stream,
 ) -> tuple[plc.Scalar, plc.Scalar]:
     """
     Convert polars offset/period pair to preceding/following windows.
@@ -132,21 +146,22 @@ def offsets_to_windows(
     ----------
     dtype
         Datatype of column defining windows
-
-
-
-
+    offset_i
+        Integer ordinal representing the offset of the window.
+        See :func:`duration_to_int` for more details.
+    period_i
+        Integer ordinal representing the period of the window.
+        See :func:`duration_to_int` for more details.
+    stream
+        CUDA stream used for device memory operations and kernel launches
 
     Returns
     -------
-    tuple of preceding and following windows as
+    tuple of preceding and following windows as host integers.
     """
-
-
-
-    # Libcudf uses current_row - preceding, ..., current_row + following
-    return duration_to_scalar(dtype, -offset_i), duration_to_scalar(
-        dtype, offset_i + period_i
+    return (
+        duration_to_scalar(dtype, -offset_i, stream=stream),
+        duration_to_scalar(dtype, offset_i + period_i, stream=stream),
     )
 
 
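As a quick worked check of the unit arithmetic in `duration_to_scalar` above: polars hands the routine a raw duration in nanoseconds, which is floor-divided down to the resolution implied by the timestamp column. A small plain-Python illustration:

    # One day expressed in nanoseconds, as polars supplies it.
    one_day_ns = 24 * 60 * 60 * 1_000_000_000  # 86_400_000_000_000

    assert one_day_ns // 1000 == 86_400_000_000  # microsecond resolution
    assert one_day_ns // 1_000_000 == 86_400_000  # millisecond resolution
    assert one_day_ns // 86_400_000_000_000 == 1  # whole days for TIMESTAMP_DAYS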
cudf_polars/experimental/base.py CHANGED

@@ -5,7 +5,9 @@
 from __future__ import annotations
 
 import dataclasses
+import enum
 from collections import defaultdict
+from enum import IntEnum
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, Generic, NamedTuple, TypeVar
 
@@ -20,19 +22,24 @@ if TYPE_CHECKING:
 class PartitionInfo:
     """Partitioning information."""
 
-    __slots__ = ("count", "partitioned_on")
+    __slots__ = ("count", "io_plan", "partitioned_on")
     count: int
     """Partition count."""
     partitioned_on: tuple[NamedExpr, ...]
     """Columns the data is hash-partitioned on."""
+    io_plan: IOPartitionPlan | None
+    """IO partitioning plan (Scan nodes only)."""
 
     def __init__(
         self,
         count: int,
+        *,
         partitioned_on: tuple[NamedExpr, ...] = (),
+        io_plan: IOPartitionPlan | None = None,
     ):
         self.count = count
         self.partitioned_on = partitioned_on
+        self.io_plan = io_plan
 
     def keys(self, node: Node) -> Iterator[tuple[str, int]]:
         """Return the partitioned keys for a given node."""
@@ -108,13 +115,17 @@ class DataSourceInfo:
     """
 
     _unique_stats_columns: set[str]
+    _read_columns: set[str]
 
     @property
     def row_count(self) -> ColumnStat[int]:  # pragma: no cover
         """Data source row-count estimate."""
         raise NotImplementedError("Sub-class must implement row_count.")
 
-    def unique_stats(self, column: str) -> UniqueStats:  # pragma: no cover
+    def unique_stats(
+        self,
+        column: str,
+    ) -> UniqueStats:  # pragma: no cover
         """Return unique-value statistics for a column."""
         raise NotImplementedError("Sub-class must implement unique_stats.")
 
@@ -131,6 +142,10 @@ class DataSourceInfo:
         """Add a column needing unique-value information."""
         self._unique_stats_columns.add(column)
 
+    def add_read_column(self, column: str) -> None:
+        """Add a column needing to be read."""
+        self._read_columns.add(column)
+
 
 class DataSourcePair(NamedTuple):
     """Pair of table-source and column-name information."""
@@ -230,6 +245,11 @@ class ColumnSourceInfo:
         for table_source, column_name in self.table_source_pairs:
             table_source.add_unique_stats_column(column or column_name)
 
+    def add_read_column(self, column: str | None = None) -> None:
+        """Add a column needing to be read."""
+        for table_source, column_name in self.table_source_pairs:
+            table_source.add_read_column(column or column_name)
+
 
 class ColumnStats:
     """
@@ -384,3 +404,36 @@ class StatsCollector:
         self.row_count: dict[IR, ColumnStat[int]] = {}
         self.column_stats: dict[IR, dict[str, ColumnStats]] = {}
         self.join_info = JoinInfo()
+
+
+class IOPartitionFlavor(IntEnum):
+    """Flavor of IO partitioning."""
+
+    SINGLE_FILE = enum.auto()  # 1:1 mapping between files and partitions
+    SPLIT_FILES = enum.auto()  # Split each file into >1 partition
+    FUSED_FILES = enum.auto()  # Fuse multiple files into each partition
+    SINGLE_READ = enum.auto()  # One worker/task reads everything
+
+
+class IOPartitionPlan:
+    """
+    IO partitioning plan.
+
+    Notes
+    -----
+    The meaning of `factor` depends on the value of `flavor`:
+    - SINGLE_FILE: `factor` must be `1`.
+    - SPLIT_FILES: `factor` is the number of partitions per file.
+    - FUSED_FILES: `factor` is the number of files per partition.
+    - SINGLE_READ: `factor` is the total number of files.
+    """
+
+    __slots__ = ("factor", "flavor")
+    factor: int
+    flavor: IOPartitionFlavor
+
+    def __init__(self, factor: int, flavor: IOPartitionFlavor) -> None:
+        if flavor == IOPartitionFlavor.SINGLE_FILE and factor != 1:  # pragma: no cover
+            raise ValueError(f"Expected factor == 1 for {flavor}, got: {factor}")
+        self.factor = factor
+        self.flavor = flavor
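A hedged usage sketch of the new `IOPartitionFlavor` / `IOPartitionPlan` / `PartitionInfo` wiring added above, assuming the names are importable from cudf_polars.experimental.base exactly as defined in the diff; the counts and factor below are illustrative:

    from cudf_polars.experimental.base import (
        IOPartitionFlavor,
        IOPartitionPlan,
        PartitionInfo,
    )

    # Split each input file into 4 partitions and attach the plan to the
    # partitioning metadata of a (hypothetical) Scan node.
    plan = IOPartitionPlan(4, IOPartitionFlavor.SPLIT_FILES)
    info = PartitionInfo(count=8, io_plan=plan)

    # SINGLE_FILE plans require factor == 1; anything else raises ValueError.
    single = IOPartitionPlan(1, IOPartitionFlavor.SINGLE_FILE)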
cudf_polars/experimental/benchmarks/pdsds.py CHANGED

@@ -16,26 +16,17 @@ from __future__ import annotations
 import contextlib
 import importlib
 import os
-import time
-from collections import defaultdict
-from pathlib import Path
 from typing import TYPE_CHECKING
 
-import polars as pl
-
 with contextlib.suppress(ImportError):
     from cudf_polars.experimental.benchmarks.utils import (
-        Record,
-        RunConfig,
-        get_executor_options,
-        parse_args,
+        run_duckdb,
         run_polars,
+        run_validate,
     )
 
 if TYPE_CHECKING:
-    from collections.abc import Sequence
     from types import ModuleType
-    from typing import Any
 
 # Without this setting, the first IO task to run
 # on each worker takes ~15 sec extra
@@ -58,7 +49,7 @@ def valid_query(name: str) -> bool:
 class PDSDSQueriesMeta(type):
     """Metaclass used for query lookup."""
 
-    def __getattr__(cls, name: str):  # type: ignore
+    def __getattr__(cls, name: str):  # type: ignore[no-untyped-def]
         """Query lookup."""
         if valid_query(name):
             q_num = int(name[1:])
@@ -88,118 +79,6 @@ class PDSDSDuckDBQueries(PDSDSQueries):
     q_impl = "duckdb_impl"
 
 
-def execute_duckdb_query(query: str, dataset_path: Path) -> pl.DataFrame:
-    """Execute a query with DuckDB."""
-    import duckdb
-
-    conn = duckdb.connect()
-
-    statements = [
-        f"CREATE VIEW {table.stem} as SELECT * FROM read_parquet('{table.absolute()}');"
-        for table in Path(dataset_path).glob("*.parquet")
-    ]
-    statements.append(query)
-    return conn.execute("\n".join(statements)).pl()
-
-
-def run_duckdb(benchmark: Any, options: Sequence[str] | None = None) -> None:
-    """Run the benchmark with DuckDB."""
-    args = parse_args(options, num_queries=99)
-    vars(args).update({"query_set": benchmark.name})
-    run_config = RunConfig.from_args(args)
-    records: defaultdict[int, list[Record]] = defaultdict(list)
-
-    for q_id in run_config.queries:
-        try:
-            duckdb_query = getattr(PDSDSDuckDBQueries, f"q{q_id}")(run_config)
-        except AttributeError as err:
-            raise NotImplementedError(f"Query {q_id} not implemented.") from err
-
-        print(f"DuckDB Executing: {q_id}")
-        records[q_id] = []
-
-        for i in range(args.iterations):
-            t0 = time.time()
-
-            result = execute_duckdb_query(duckdb_query, run_config.dataset_path)
-
-            t1 = time.time()
-            record = Record(query=q_id, duration=t1 - t0)
-            if args.print_results:
-                print(result)
-
-            print(f"Query {q_id} - Iteration {i} finished in {record.duration:0.4f}s")
-            records[q_id].append(record)
-
-
-def run_validate(benchmark: Any, options: Sequence[str] | None = None) -> None:
-    """Validate Polars CPU vs DuckDB or Polars GPU."""
-    from polars.testing import assert_frame_equal
-
-    args = parse_args(options, num_queries=99)
-    vars(args).update({"query_set": benchmark.name})
-    run_config = RunConfig.from_args(args)
-
-    baseline = args.baseline
-    if baseline not in {"duckdb", "cpu"}:
-        raise ValueError("Baseline must be one of: 'duckdb', 'cpu'")
-
-    failures: list[int] = []
-
-    engine: pl.GPUEngine | None = None
-    if run_config.executor != "cpu":
-        engine = pl.GPUEngine(
-            raise_on_fail=True,
-            executor=run_config.executor,
-            executor_options=get_executor_options(run_config, PDSDSPolarsQueries),
-        )
-
-    for q_id in run_config.queries:
-        print(f"\nValidating Query {q_id}")
-        try:
-            polars_query = getattr(PDSDSPolarsQueries, f"q{q_id}")(run_config)
-            duckdb_query = getattr(PDSDSDuckDBQueries, f"q{q_id}")(run_config)
-        except AttributeError as err:
-            raise NotImplementedError(f"Query {q_id} not implemented.") from err
-
-        if baseline == "duckdb":
-            base_result = execute_duckdb_query(duckdb_query, run_config.dataset_path)
-        elif baseline == "cpu":
-            base_result = polars_query.collect(new_streaming=True)
-
-        if run_config.executor == "cpu":
-            test_result = polars_query.collect(new_streaming=True)
-        else:
-            try:
-                test_result = polars_query.collect(engine=engine)
-            except Exception as e:
-                failures.append(q_id)
-                print(f"❌ Query {q_id} failed validation: GPU execution failed.\n{e}")
-                continue
-
-        try:
-            assert_frame_equal(
-                base_result,
-                test_result,
-                check_dtypes=True,
-                check_column_order=False,
-            )
-            print(f"✅ Query {q_id} passed validation.")
-        except AssertionError as e:
-            failures.append(q_id)
-            print(f"❌ Query {q_id} failed validation:\n{e}")
-            if args.print_results:
-                print("Baseline Result:\n", base_result)
-                print("Test Result:\n", test_result)
-
-    if failures:
-        print("\nValidation Summary:")
-        print("===================")
-        print(f"{len(failures)} query(s) failed: {failures}")
-    else:
-        print("\nAll queries passed validation.")
-
-
 if __name__ == "__main__":
     import argparse
 
@@ -215,6 +94,13 @@ if __name__ == "__main__":
     if args.engine == "polars":
         run_polars(PDSDSPolarsQueries, extra_args, num_queries=99)
     elif args.engine == "duckdb":
-        run_duckdb(PDSDSDuckDBQueries, extra_args)
+        run_duckdb(PDSDSDuckDBQueries, extra_args, num_queries=99)
     elif args.engine == "validate":
-        run_validate(
+        run_validate(
+            PDSDSPolarsQueries,
+            PDSDSDuckDBQueries,
+            extra_args,
+            num_queries=99,
+            check_dtypes=True,
+            check_column_order=True,
+        )