cudf_polars_cu13-25.10.0-py3-none-any.whl
This diff shows the contents of publicly available package versions as released to their respective public registries. It is provided for informational purposes only.
- cudf_polars/GIT_COMMIT +1 -0
- cudf_polars/VERSION +1 -0
- cudf_polars/__init__.py +28 -0
- cudf_polars/_version.py +21 -0
- cudf_polars/callback.py +318 -0
- cudf_polars/containers/__init__.py +13 -0
- cudf_polars/containers/column.py +495 -0
- cudf_polars/containers/dataframe.py +361 -0
- cudf_polars/containers/datatype.py +137 -0
- cudf_polars/dsl/__init__.py +8 -0
- cudf_polars/dsl/expr.py +66 -0
- cudf_polars/dsl/expressions/__init__.py +8 -0
- cudf_polars/dsl/expressions/aggregation.py +226 -0
- cudf_polars/dsl/expressions/base.py +272 -0
- cudf_polars/dsl/expressions/binaryop.py +120 -0
- cudf_polars/dsl/expressions/boolean.py +326 -0
- cudf_polars/dsl/expressions/datetime.py +271 -0
- cudf_polars/dsl/expressions/literal.py +97 -0
- cudf_polars/dsl/expressions/rolling.py +643 -0
- cudf_polars/dsl/expressions/selection.py +74 -0
- cudf_polars/dsl/expressions/slicing.py +46 -0
- cudf_polars/dsl/expressions/sorting.py +85 -0
- cudf_polars/dsl/expressions/string.py +1002 -0
- cudf_polars/dsl/expressions/struct.py +137 -0
- cudf_polars/dsl/expressions/ternary.py +49 -0
- cudf_polars/dsl/expressions/unary.py +517 -0
- cudf_polars/dsl/ir.py +2607 -0
- cudf_polars/dsl/nodebase.py +164 -0
- cudf_polars/dsl/to_ast.py +359 -0
- cudf_polars/dsl/tracing.py +16 -0
- cudf_polars/dsl/translate.py +939 -0
- cudf_polars/dsl/traversal.py +224 -0
- cudf_polars/dsl/utils/__init__.py +8 -0
- cudf_polars/dsl/utils/aggregations.py +481 -0
- cudf_polars/dsl/utils/groupby.py +98 -0
- cudf_polars/dsl/utils/naming.py +34 -0
- cudf_polars/dsl/utils/replace.py +61 -0
- cudf_polars/dsl/utils/reshape.py +74 -0
- cudf_polars/dsl/utils/rolling.py +121 -0
- cudf_polars/dsl/utils/windows.py +192 -0
- cudf_polars/experimental/__init__.py +8 -0
- cudf_polars/experimental/base.py +386 -0
- cudf_polars/experimental/benchmarks/__init__.py +4 -0
- cudf_polars/experimental/benchmarks/pdsds.py +220 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/__init__.py +4 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q1.py +88 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q10.py +225 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q2.py +244 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q3.py +65 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q4.py +359 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q5.py +462 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q6.py +92 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q7.py +79 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q8.py +524 -0
- cudf_polars/experimental/benchmarks/pdsds_queries/q9.py +137 -0
- cudf_polars/experimental/benchmarks/pdsh.py +814 -0
- cudf_polars/experimental/benchmarks/utils.py +832 -0
- cudf_polars/experimental/dask_registers.py +200 -0
- cudf_polars/experimental/dispatch.py +156 -0
- cudf_polars/experimental/distinct.py +197 -0
- cudf_polars/experimental/explain.py +157 -0
- cudf_polars/experimental/expressions.py +590 -0
- cudf_polars/experimental/groupby.py +327 -0
- cudf_polars/experimental/io.py +943 -0
- cudf_polars/experimental/join.py +391 -0
- cudf_polars/experimental/parallel.py +423 -0
- cudf_polars/experimental/repartition.py +69 -0
- cudf_polars/experimental/scheduler.py +155 -0
- cudf_polars/experimental/select.py +188 -0
- cudf_polars/experimental/shuffle.py +354 -0
- cudf_polars/experimental/sort.py +609 -0
- cudf_polars/experimental/spilling.py +151 -0
- cudf_polars/experimental/statistics.py +795 -0
- cudf_polars/experimental/utils.py +169 -0
- cudf_polars/py.typed +0 -0
- cudf_polars/testing/__init__.py +8 -0
- cudf_polars/testing/asserts.py +448 -0
- cudf_polars/testing/io.py +122 -0
- cudf_polars/testing/plugin.py +236 -0
- cudf_polars/typing/__init__.py +219 -0
- cudf_polars/utils/__init__.py +8 -0
- cudf_polars/utils/config.py +741 -0
- cudf_polars/utils/conversion.py +40 -0
- cudf_polars/utils/dtypes.py +118 -0
- cudf_polars/utils/sorting.py +53 -0
- cudf_polars/utils/timer.py +39 -0
- cudf_polars/utils/versions.py +27 -0
- cudf_polars_cu13-25.10.0.dist-info/METADATA +136 -0
- cudf_polars_cu13-25.10.0.dist-info/RECORD +92 -0
- cudf_polars_cu13-25.10.0.dist-info/WHEEL +5 -0
- cudf_polars_cu13-25.10.0.dist-info/licenses/LICENSE +201 -0
- cudf_polars_cu13-25.10.0.dist-info/top_level.txt +1 -0
cudf_polars/experimental/base.py
@@ -0,0 +1,386 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+"""Multi-partition base classes."""
+
+from __future__ import annotations
+
+import dataclasses
+from collections import defaultdict
+from functools import cached_property
+from typing import TYPE_CHECKING, Any, Generic, NamedTuple, TypeVar
+
+if TYPE_CHECKING:
+    from collections.abc import Generator, Iterator, MutableMapping
+
+    from cudf_polars.dsl.expr import NamedExpr
+    from cudf_polars.dsl.ir import IR
+    from cudf_polars.dsl.nodebase import Node
+
+
+class PartitionInfo:
+    """Partitioning information."""
+
+    __slots__ = ("count", "partitioned_on")
+    count: int
+    """Partition count."""
+    partitioned_on: tuple[NamedExpr, ...]
+    """Columns the data is hash-partitioned on."""
+
+    def __init__(
+        self,
+        count: int,
+        partitioned_on: tuple[NamedExpr, ...] = (),
+    ):
+        self.count = count
+        self.partitioned_on = partitioned_on
+
+    def keys(self, node: Node) -> Iterator[tuple[str, int]]:
+        """Return the partitioned keys for a given node."""
+        name = get_key_name(node)
+        yield from ((name, i) for i in range(self.count))
+
+    def __rich_repr__(self) -> Generator[Any, None, None]:
+        """Formatting for rich.pretty.pprint."""
+        yield "count", self.count
+        yield "partitioned_on", self.partitioned_on
+
+
+def get_key_name(node: Node) -> str:
+    """Generate the key name for a Node."""
+    return f"{type(node).__name__.lower()}-{hash(node)}"
+
+
+T = TypeVar("T")
+
+
+@dataclasses.dataclass
+class ColumnStat(Generic[T]):
+    """
+    Generic column-statistic.
+
+    Parameters
+    ----------
+    value
+        Statistics value. Value will be None
+        if the statistics is unknown.
+    exact
+        Whether the statistics is known exactly.
+    """
+
+    value: T | None = None
+    exact: bool = False
+
+
+@dataclasses.dataclass
+class UniqueStats:
+    """
+    Sampled unique-value statistics.
+
+    Parameters
+    ----------
+    count
+        Unique-value count.
+    fraction
+        Unique-value fraction. This corresponds to the total
+        number of unique values (count) divided by the total
+        number of rows.
+
+    Notes
+    -----
+    This class is used to track unique-value column statistics
+    that have been sampled from a data source.
+    """
+
+    count: ColumnStat[int] = dataclasses.field(default_factory=ColumnStat[int])
+    fraction: ColumnStat[float] = dataclasses.field(default_factory=ColumnStat[float])
+
+
+class DataSourceInfo:
+    """
+    Table data source information.
+
+    Notes
+    -----
+    This class should be sub-classed for specific
+    data source types (e.g. Parquet, DataFrame, etc.).
+    The required properties/methods enable lazy
+    sampling of the underlying datasource.
+    """
+
+    _unique_stats_columns: set[str]
+
+    @property
+    def row_count(self) -> ColumnStat[int]:  # pragma: no cover
+        """Data source row-count estimate."""
+        raise NotImplementedError("Sub-class must implement row_count.")
+
+    def unique_stats(self, column: str) -> UniqueStats:  # pragma: no cover
+        """Return unique-value statistics for a column."""
+        raise NotImplementedError("Sub-class must implement unique_stats.")
+
+    def storage_size(self, column: str) -> ColumnStat[int]:
+        """Return the average column size for a single file."""
+        return ColumnStat[int]()
+
+    @property
+    def unique_stats_columns(self) -> set[str]:
+        """Return the set of columns needing unique-value information."""
+        return self._unique_stats_columns
+
+    def add_unique_stats_column(self, column: str) -> None:
+        """Add a column needing unique-value information."""
+        self._unique_stats_columns.add(column)
+
+
+class DataSourcePair(NamedTuple):
+    """Pair of table-source and column-name information."""
+
+    table_source: DataSourceInfo
+    column_name: str
+
+
+class ColumnSourceInfo:
+    """
+    Source column information.
+
+    Parameters
+    ----------
+    table_source_pairs
+        Sequence of DataSourcePair objects.
+        Union operations will result in multiple elements.
+
+    Notes
+    -----
+    This is a thin wrapper around DataSourceInfo that provides
+    direct access to column-specific information.
+    """
+
+    __slots__ = (
+        "implied_unique_count",
+        "table_source_pairs",
+    )
+    table_source_pairs: list[DataSourcePair]
+    implied_unique_count: ColumnStat[int]
+    """Unique-value count implied by join heuristics."""
+
+    def __init__(self, *table_source_pairs: DataSourcePair) -> None:
+        self.table_source_pairs = list(table_source_pairs)
+        self.implied_unique_count = ColumnStat[int](None)
+
+    @property
+    def is_unique_stats_column(self) -> bool:
+        """Return whether this column requires unique-value information."""
+        return any(
+            pair.column_name in pair.table_source.unique_stats_columns
+            for pair in self.table_source_pairs
+        )
+
+    @property
+    def row_count(self) -> ColumnStat[int]:
+        """Data source row-count estimate."""
+        return ColumnStat[int](
+            # Use sum of table-source row-count estimates.
+            value=sum(
+                value
+                for pair in self.table_source_pairs
+                if (value := pair.table_source.row_count.value) is not None
+            )
+            or None,
+            # Row-count may be exact if there is only one table source.
+            exact=len(self.table_source_pairs) == 1
+            and self.table_source_pairs[0].table_source.row_count.exact,
+        )
+
+    def unique_stats(self, *, force: bool = False) -> UniqueStats:
+        """
+        Return unique-value statistics for a column.
+
+        Parameters
+        ----------
+        force
+            If True, return unique-value statistics even if the column
+            wasn't marked as needing unique-value information.
+        """
+        if (force or self.is_unique_stats_column) and len(self.table_source_pairs) == 1:
+            # Single table source.
+            # TODO: Handle multiple tables sources if/when necessary.
+            # We may never need to do this if the source unique-value
+            # statistics are only "used" by the Scan/DataFrameScan nodes.
+            table_source, column_name = self.table_source_pairs[0]
+            return table_source.unique_stats(column_name)
+        else:
+            # Avoid sampling unique-stats if this column
+            # wasn't marked as "needing" unique-stats.
+            return UniqueStats()
+
+    @property
+    def storage_size(self) -> ColumnStat[int]:
+        """Return the average column size for a single file."""
+        # We don't need to handle concatenated statistics for ``storage_size``.
+        # Just return the storage size of the first table source.
+        if self.table_source_pairs:
+            table_source, column_name = self.table_source_pairs[0]
+            return table_source.storage_size(column_name)
+        else:  # pragma: no cover; We never call this for empty table sources.
+            return ColumnStat[int]()
+
+    def add_unique_stats_column(self, column: str | None = None) -> None:
+        """Add a column needing unique-value information."""
+        # We must call add_unique_stats_column for ALL table sources.
+        for table_source, column_name in self.table_source_pairs:
+            table_source.add_unique_stats_column(column or column_name)
+
+
+class ColumnStats:
+    """
+    Column statistics.
+
+    Parameters
+    ----------
+    name
+        Column name.
+    children
+        Child ColumnStats objects.
+    source_info
+        Column source information.
+    unique_count
+        Unique-value count.
+    """
+
+    __slots__ = ("children", "name", "source_info", "unique_count")
+
+    name: str
+    children: tuple[ColumnStats, ...]
+    source_info: ColumnSourceInfo
+    unique_count: ColumnStat[int]
+
+    def __init__(
+        self,
+        name: str,
+        *,
+        children: tuple[ColumnStats, ...] = (),
+        source_info: ColumnSourceInfo | None = None,
+        unique_count: ColumnStat[int] | None = None,
+    ) -> None:
+        self.name = name
+        self.children = children
+        self.source_info = source_info or ColumnSourceInfo()
+        self.unique_count = unique_count or ColumnStat[int](None)
+
+    def new_parent(
+        self,
+        *,
+        name: str | None = None,
+    ) -> ColumnStats:
+        """
+        Initialize a new parent ColumnStats object.
+
+        Parameters
+        ----------
+        name
+            The new column name.
+
+        Returns
+        -------
+        A new ColumnStats object.
+
+        Notes
+        -----
+        This API preserves the original DataSourceInfo reference.
+        """
+        return ColumnStats(
+            name=name or self.name,
+            children=(self,),
+            # Want to reference the same DataSourceInfo
+            source_info=self.source_info,
+        )
+
+
+class JoinKey:
+    """
+    Join-key information.
+
+    Parameters
+    ----------
+    column_stats
+        Column statistics for the join key.
+
+    Notes
+    -----
+    This class is used to track join-key information.
+    It is used to track the columns being joined on
+    and the estimated unique-value count for the join key.
+    """
+
+    column_stats: tuple[ColumnStats, ...]
+    implied_unique_count: int | None
+    """Estimated unique-value count from join heuristics."""
+
+    def __init__(self, *column_stats: ColumnStats) -> None:
+        self.column_stats = column_stats
+        self.implied_unique_count = None
+
+    @cached_property
+    def source_row_count(self) -> int | None:
+        """
+        Return the estimated row-count of the source columns.
+
+        Notes
+        -----
+        This is the maximum row-count estimate of the source columns.
+        """
+        return max(
+            (
+                cs.source_info.row_count.value
+                for cs in self.column_stats
+                if cs.source_info.row_count.value is not None
+            ),
+            default=None,
+        )
+
+
+class JoinInfo:
+    """
+    Join information.
+
+    Notes
+    -----
+    This class is used to track mappings between joined-on
+    columns and joined-on keys (groups of columns). We need
+    these mappings to calculate equivalence sets and make
+    join-based unique-count and row-count estimates.
+    """
+
+    __slots__ = ("column_map", "join_map", "key_map")
+
+    column_map: MutableMapping[ColumnStats, set[ColumnStats]]
+    """Mapping between joined columns."""
+    key_map: MutableMapping[JoinKey, set[JoinKey]]
+    """Mapping between joined keys (groups of columns)."""
+    join_map: dict[IR, list[JoinKey]]
+    """Mapping between IR nodes and associated join keys."""
+
+    def __init__(self) -> None:
+        self.column_map: MutableMapping[ColumnStats, set[ColumnStats]] = defaultdict(
+            set[ColumnStats]
+        )
+        self.key_map: MutableMapping[JoinKey, set[JoinKey]] = defaultdict(set[JoinKey])
+        self.join_map: dict[IR, list[JoinKey]] = {}
+
+
+class StatsCollector:
+    """Column statistics collector."""
+
+    __slots__ = ("column_stats", "join_info", "row_count")
+
+    row_count: dict[IR, ColumnStat[int]]
+    """Estimated row count for each IR node."""
+    column_stats: dict[IR, dict[str, ColumnStats]]
+    """Column statistics for each IR node."""
+    join_info: JoinInfo
+    """Join information."""
+
+    def __init__(self) -> None:
+        self.row_count: dict[IR, ColumnStat[int]] = {}
+        self.column_stats: dict[IR, dict[str, ColumnStats]] = {}
+        self.join_info = JoinInfo()
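These containers are normally built and populated by the experimental planner (see cudf_polars/experimental/statistics.py in the listing above), but their composition can be sketched directly. The snippet below is a minimal illustration and is not part of the package: the ListSourceInfo subclass is hypothetical and only stands in for a real Parquet or DataFrame source.

from cudf_polars.experimental.base import (
    ColumnSourceInfo,
    ColumnStat,
    ColumnStats,
    DataSourceInfo,
    DataSourcePair,
    UniqueStats,
)


class ListSourceInfo(DataSourceInfo):
    """Hypothetical toy source backed by an in-memory list of rows."""

    def __init__(self, rows: list[dict]) -> None:
        self._rows = rows
        self._unique_stats_columns: set[str] = set()

    @property
    def row_count(self) -> ColumnStat[int]:
        # Exact, because we can count the rows directly.
        return ColumnStat[int](len(self._rows), exact=True)

    def unique_stats(self, column: str) -> UniqueStats:
        values = {row[column] for row in self._rows}
        return UniqueStats(
            count=ColumnStat[int](len(values), exact=True),
            fraction=ColumnStat[float](len(values) / len(self._rows), exact=True),
        )


source = ListSourceInfo([{"x": 1}, {"x": 1}, {"x": 2}])
stats = ColumnStats("x", source_info=ColumnSourceInfo(DataSourcePair(source, "x")))
stats.source_info.add_unique_stats_column()       # mark "x" as needing unique stats
print(stats.source_info.row_count)                # ColumnStat(value=3, exact=True)
print(stats.source_info.unique_stats().fraction)  # ColumnStat(value=0.666..., exact=True)

A real source would sample the underlying files lazily; the point here is only how DataSourcePair, ColumnSourceInfo, and ColumnStats fit together.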
cudf_polars/experimental/benchmarks/pdsds.py
@@ -0,0 +1,220 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Experimental PDS-DS benchmarks.
+
+Based on https://github.com/pola-rs/polars-benchmark.
+
+WARNING: This is an experimental (and unofficial)
+benchmark script. It is not intended for public use
+and may be modified or removed at any time.
+"""
+
+from __future__ import annotations
+
+import contextlib
+import importlib
+import os
+import time
+from collections import defaultdict
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+import polars as pl
+
+with contextlib.suppress(ImportError):
+    from cudf_polars.experimental.benchmarks.utils import (
+        Record,
+        RunConfig,
+        get_executor_options,
+        parse_args,
+        run_polars,
+    )
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+    from types import ModuleType
+    from typing import Any
+
+# Without this setting, the first IO task to run
+# on each worker takes ~15 sec extra
+os.environ["KVIKIO_COMPAT_MODE"] = os.environ.get("KVIKIO_COMPAT_MODE", "on")
+os.environ["KVIKIO_NTHREADS"] = os.environ.get("KVIKIO_NTHREADS", "8")
+
+
+def valid_query(name: str) -> bool:
+    """Return True for valid query names eg. 'q9', 'q65', etc."""
+    if not name.startswith("q"):
+        return False
+    try:
+        q_num = int(name[1:])
+    except ValueError:
+        return False
+    else:
+        return 1 <= q_num <= 99
+
+
+class PDSDSQueriesMeta(type):
+    """Metaclass used for query lookup."""
+
+    def __getattr__(cls, name: str):  # type: ignore
+        """Query lookup."""
+        if valid_query(name):
+            q_num = int(name[1:])
+            module: ModuleType = importlib.import_module(
+                f"cudf_polars.experimental.benchmarks.pdsds_queries.q{q_num}"
+            )
+            return getattr(module, cls.q_impl)
+        raise AttributeError(f"{name} is not a valid query name")
+
+
+class PDSDSQueries(metaclass=PDSDSQueriesMeta):
+    """Base class for query loading."""
+
+    q_impl: str
+    name: str = "pdsds"
+
+
+class PDSDSPolarsQueries(PDSDSQueries):
+    """Polars Queries."""
+
+    q_impl = "polars_impl"
+
+
+class PDSDSDuckDBQueries(PDSDSQueries):
+    """DuckDB Queries."""
+
+    q_impl = "duckdb_impl"
+
+
+def execute_duckdb_query(query: str, dataset_path: Path) -> pl.DataFrame:
+    """Execute a query with DuckDB."""
+    import duckdb
+
+    conn = duckdb.connect()
+
+    statements = [
+        f"CREATE VIEW {table.stem} as SELECT * FROM read_parquet('{table.absolute()}');"
+        for table in Path(dataset_path).glob("*.parquet")
+    ]
+    statements.append(query)
+    return conn.execute("\n".join(statements)).pl()
+
+
+def run_duckdb(benchmark: Any, options: Sequence[str] | None = None) -> None:
+    """Run the benchmark with DuckDB."""
+    args = parse_args(options, num_queries=99)
+    vars(args).update({"query_set": benchmark.name})
+    run_config = RunConfig.from_args(args)
+    records: defaultdict[int, list[Record]] = defaultdict(list)
+
+    for q_id in run_config.queries:
+        try:
+            duckdb_query = getattr(PDSDSDuckDBQueries, f"q{q_id}")(run_config)
+        except AttributeError as err:
+            raise NotImplementedError(f"Query {q_id} not implemented.") from err
+
+        print(f"DuckDB Executing: {q_id}")
+        records[q_id] = []
+
+        for i in range(args.iterations):
+            t0 = time.time()
+
+            result = execute_duckdb_query(duckdb_query, run_config.dataset_path)
+
+            t1 = time.time()
+            record = Record(query=q_id, duration=t1 - t0)
+            if args.print_results:
+                print(result)
+
+            print(f"Query {q_id} - Iteration {i} finished in {record.duration:0.4f}s")
+            records[q_id].append(record)
+
+
+def run_validate(benchmark: Any, options: Sequence[str] | None = None) -> None:
+    """Validate Polars CPU vs DuckDB or Polars GPU."""
+    from polars.testing import assert_frame_equal
+
+    args = parse_args(options, num_queries=99)
+    vars(args).update({"query_set": benchmark.name})
+    run_config = RunConfig.from_args(args)
+
+    baseline = args.baseline
+    if baseline not in {"duckdb", "cpu"}:
+        raise ValueError("Baseline must be one of: 'duckdb', 'cpu'")
+
+    failures: list[int] = []
+
+    engine: pl.GPUEngine | None = None
+    if run_config.executor != "cpu":
+        engine = pl.GPUEngine(
+            raise_on_fail=True,
+            executor=run_config.executor,
+            executor_options=get_executor_options(run_config, PDSDSPolarsQueries),
+        )
+
+    for q_id in run_config.queries:
+        print(f"\nValidating Query {q_id}")
+        try:
+            polars_query = getattr(PDSDSPolarsQueries, f"q{q_id}")(run_config)
+            duckdb_query = getattr(PDSDSDuckDBQueries, f"q{q_id}")(run_config)
+        except AttributeError as err:
+            raise NotImplementedError(f"Query {q_id} not implemented.") from err
+
+        if baseline == "duckdb":
+            base_result = execute_duckdb_query(duckdb_query, run_config.dataset_path)
+        elif baseline == "cpu":
+            base_result = polars_query.collect(new_streaming=True)
+
+        if run_config.executor == "cpu":
+            test_result = polars_query.collect(new_streaming=True)
+        else:
+            try:
+                test_result = polars_query.collect(engine=engine)
+            except Exception as e:
+                failures.append(q_id)
+                print(f"❌ Query {q_id} failed validation: GPU execution failed.\n{e}")
+                continue
+
+        try:
+            assert_frame_equal(
+                base_result,
+                test_result,
+                check_dtypes=True,
+                check_column_order=False,
+            )
+            print(f"✅ Query {q_id} passed validation.")
+        except AssertionError as e:
+            failures.append(q_id)
+            print(f"❌ Query {q_id} failed validation:\n{e}")
+            if args.print_results:
+                print("Baseline Result:\n", base_result)
+                print("Test Result:\n", test_result)
+
+    if failures:
+        print("\nValidation Summary:")
+        print("===================")
+        print(f"{len(failures)} query(s) failed: {failures}")
+    else:
+        print("\nAll queries passed validation.")
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Run PDS-DS benchmarks.")
+    parser.add_argument(
+        "--engine",
+        choices=["polars", "duckdb", "validate"],
+        default="polars",
+        help="Which engine to use for executing the benchmarks or to validate results.",
+    )
+    args, extra_args = parser.parse_known_args()

+    if args.engine == "polars":
+        run_polars(PDSDSPolarsQueries, extra_args, num_queries=99)
+    elif args.engine == "duckdb":
+        run_duckdb(PDSDSDuckDBQueries, extra_args)
+    elif args.engine == "validate":
+        run_validate(PDSDSQueries, extra_args)
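The metaclass lookup above means q1 through q99 are resolved lazily: attribute access on a query class imports the matching module under pdsds_queries and returns its polars_impl or duckdb_impl function, exactly as run_duckdb and run_validate do. Below is a minimal sketch of that flow, assuming a RunConfig already built from parsed arguments; the remaining CLI options (dataset path, query selection, iteration count, validation baseline) are defined by parse_args in benchmarks/utils.py, which is not shown here.

import polars as pl

from cudf_polars.experimental.benchmarks.pdsds import (
    PDSDSDuckDBQueries,
    PDSDSPolarsQueries,
    execute_duckdb_query,
)


def collect_q1(run_config):
    # Attribute access triggers PDSDSQueriesMeta.__getattr__, which imports
    # pdsds_queries.q1 and returns its polars_impl / duckdb_impl function.
    lazy_frame = PDSDSPolarsQueries.q1(run_config)
    sql = PDSDSDuckDBQueries.q1(run_config)

    gpu_result = lazy_frame.collect(engine=pl.GPUEngine(raise_on_fail=True))
    duckdb_result = execute_duckdb_query(sql, run_config.dataset_path)
    return gpu_result, duckdb_result

From the command line the same paths are exercised with `python -m cudf_polars.experimental.benchmarks.pdsds --engine polars|duckdb|validate`, with any extra options forwarded to parse_args via parse_known_args.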
cudf_polars/experimental/benchmarks/pdsds_queries/q1.py
@@ -0,0 +1,88 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Query 1."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import polars as pl
+
+from cudf_polars.experimental.benchmarks.utils import get_data
+
+if TYPE_CHECKING:
+    from cudf_polars.experimental.benchmarks.utils import RunConfig
+
+
+def duckdb_impl(run_config: RunConfig) -> str:
+    """Query 1."""
+    return """
+    WITH customer_total_return
+         AS (SELECT sr_customer_sk AS ctr_customer_sk,
+                    sr_store_sk AS ctr_store_sk,
+                    Sum(sr_return_amt) AS ctr_total_return
+             FROM store_returns,
+                  date_dim
+             WHERE sr_returned_date_sk = d_date_sk
+                   AND d_year = 2001
+             GROUP BY sr_customer_sk,
+                      sr_store_sk)
+    SELECT c_customer_id
+    FROM customer_total_return ctr1,
+         store,
+         customer
+    WHERE ctr1.ctr_total_return > (SELECT Avg(ctr_total_return) * 1.2
+                                   FROM customer_total_return ctr2
+                                   WHERE ctr1.ctr_store_sk = ctr2.ctr_store_sk)
+          AND s_store_sk = ctr1.ctr_store_sk
+          AND s_state = 'TN'
+          AND ctr1.ctr_customer_sk = c_customer_sk
+    ORDER BY c_customer_id
+    LIMIT 100;
+    """
+
+
+def polars_impl(run_config: RunConfig) -> pl.LazyFrame:
+    """Query 1."""
+    store_returns = get_data(
+        run_config.dataset_path, "store_returns", run_config.suffix
+    )
+    date_dim = get_data(run_config.dataset_path, "date_dim", run_config.suffix)
+    store = get_data(run_config.dataset_path, "store", run_config.suffix)
+    customer = get_data(run_config.dataset_path, "customer", run_config.suffix)
+
+    # Step 1: Create customer_total_return CTE equivalent
+    customer_total_return = (
+        store_returns.join(
+            date_dim, left_on="sr_returned_date_sk", right_on="d_date_sk"
+        )
+        .filter(pl.col("d_year") == 2001)
+        .group_by(["sr_customer_sk", "sr_store_sk"])
+        .agg(pl.col("sr_return_amt").sum().alias("ctr_total_return"))
+        .rename(
+            {
+                "sr_customer_sk": "ctr_customer_sk",
+                "sr_store_sk": "ctr_store_sk",
+            }
+        )
+    )
+
+    # Step 2: Calculate average return per store for the subquery
+    store_avg_returns = customer_total_return.group_by("ctr_store_sk").agg(
+        [(pl.col("ctr_total_return").mean() * 1.2).alias("avg_return_threshold")]
+    )
+
+    # Step 3: Join everything together and apply filters
+    return (
+        customer_total_return.join(
+            store_avg_returns, left_on="ctr_store_sk", right_on="ctr_store_sk"
+        )
+        .filter(pl.col("ctr_total_return") > pl.col("avg_return_threshold"))
+        .join(store, left_on="ctr_store_sk", right_on="s_store_sk")
+        .filter(pl.col("s_state") == "TN")
+        .join(customer, left_on="ctr_customer_sk", right_on="c_customer_sk")
+        .select(["c_customer_id"])
+        .sort("c_customer_id")
+        .limit(100)
+    )
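polars_impl rewrites the correlated subquery in the SQL above as a per-store aggregate that is joined back onto customer_total_return. The following is a tiny self-contained sketch of that rewrite on made-up data; the literal values are illustrative only, not benchmark data.

import polars as pl

ctr = pl.LazyFrame(
    {
        "ctr_store_sk": [1, 1, 2, 2],
        "ctr_customer_sk": [10, 11, 12, 13],
        "ctr_total_return": [100.0, 300.0, 50.0, 40.0],
    }
)

# Per-store threshold: Avg(ctr_total_return) * 1.2, as in the SQL subquery.
threshold = ctr.group_by("ctr_store_sk").agg(
    (pl.col("ctr_total_return").mean() * 1.2).alias("avg_return_threshold")
)

# Keep customers whose total return exceeds their store's threshold.
result = (
    ctr.join(threshold, on="ctr_store_sk")
    .filter(pl.col("ctr_total_return") > pl.col("avg_return_threshold"))
    .collect()
)
print(result)

Store 1's threshold is 240, so only the 300.0 return survives; both of store 2's returns fall below its threshold of 54.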