cudf-polars-cu13 25.10.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. cudf_polars/GIT_COMMIT +1 -0
  2. cudf_polars/VERSION +1 -0
  3. cudf_polars/__init__.py +28 -0
  4. cudf_polars/_version.py +21 -0
  5. cudf_polars/callback.py +318 -0
  6. cudf_polars/containers/__init__.py +13 -0
  7. cudf_polars/containers/column.py +495 -0
  8. cudf_polars/containers/dataframe.py +361 -0
  9. cudf_polars/containers/datatype.py +137 -0
  10. cudf_polars/dsl/__init__.py +8 -0
  11. cudf_polars/dsl/expr.py +66 -0
  12. cudf_polars/dsl/expressions/__init__.py +8 -0
  13. cudf_polars/dsl/expressions/aggregation.py +226 -0
  14. cudf_polars/dsl/expressions/base.py +272 -0
  15. cudf_polars/dsl/expressions/binaryop.py +120 -0
  16. cudf_polars/dsl/expressions/boolean.py +326 -0
  17. cudf_polars/dsl/expressions/datetime.py +271 -0
  18. cudf_polars/dsl/expressions/literal.py +97 -0
  19. cudf_polars/dsl/expressions/rolling.py +643 -0
  20. cudf_polars/dsl/expressions/selection.py +74 -0
  21. cudf_polars/dsl/expressions/slicing.py +46 -0
  22. cudf_polars/dsl/expressions/sorting.py +85 -0
  23. cudf_polars/dsl/expressions/string.py +1002 -0
  24. cudf_polars/dsl/expressions/struct.py +137 -0
  25. cudf_polars/dsl/expressions/ternary.py +49 -0
  26. cudf_polars/dsl/expressions/unary.py +517 -0
  27. cudf_polars/dsl/ir.py +2607 -0
  28. cudf_polars/dsl/nodebase.py +164 -0
  29. cudf_polars/dsl/to_ast.py +359 -0
  30. cudf_polars/dsl/tracing.py +16 -0
  31. cudf_polars/dsl/translate.py +939 -0
  32. cudf_polars/dsl/traversal.py +224 -0
  33. cudf_polars/dsl/utils/__init__.py +8 -0
  34. cudf_polars/dsl/utils/aggregations.py +481 -0
  35. cudf_polars/dsl/utils/groupby.py +98 -0
  36. cudf_polars/dsl/utils/naming.py +34 -0
  37. cudf_polars/dsl/utils/replace.py +61 -0
  38. cudf_polars/dsl/utils/reshape.py +74 -0
  39. cudf_polars/dsl/utils/rolling.py +121 -0
  40. cudf_polars/dsl/utils/windows.py +192 -0
  41. cudf_polars/experimental/__init__.py +8 -0
  42. cudf_polars/experimental/base.py +386 -0
  43. cudf_polars/experimental/benchmarks/__init__.py +4 -0
  44. cudf_polars/experimental/benchmarks/pdsds.py +220 -0
  45. cudf_polars/experimental/benchmarks/pdsds_queries/__init__.py +4 -0
  46. cudf_polars/experimental/benchmarks/pdsds_queries/q1.py +88 -0
  47. cudf_polars/experimental/benchmarks/pdsds_queries/q10.py +225 -0
  48. cudf_polars/experimental/benchmarks/pdsds_queries/q2.py +244 -0
  49. cudf_polars/experimental/benchmarks/pdsds_queries/q3.py +65 -0
  50. cudf_polars/experimental/benchmarks/pdsds_queries/q4.py +359 -0
  51. cudf_polars/experimental/benchmarks/pdsds_queries/q5.py +462 -0
  52. cudf_polars/experimental/benchmarks/pdsds_queries/q6.py +92 -0
  53. cudf_polars/experimental/benchmarks/pdsds_queries/q7.py +79 -0
  54. cudf_polars/experimental/benchmarks/pdsds_queries/q8.py +524 -0
  55. cudf_polars/experimental/benchmarks/pdsds_queries/q9.py +137 -0
  56. cudf_polars/experimental/benchmarks/pdsh.py +814 -0
  57. cudf_polars/experimental/benchmarks/utils.py +832 -0
  58. cudf_polars/experimental/dask_registers.py +200 -0
  59. cudf_polars/experimental/dispatch.py +156 -0
  60. cudf_polars/experimental/distinct.py +197 -0
  61. cudf_polars/experimental/explain.py +157 -0
  62. cudf_polars/experimental/expressions.py +590 -0
  63. cudf_polars/experimental/groupby.py +327 -0
  64. cudf_polars/experimental/io.py +943 -0
  65. cudf_polars/experimental/join.py +391 -0
  66. cudf_polars/experimental/parallel.py +423 -0
  67. cudf_polars/experimental/repartition.py +69 -0
  68. cudf_polars/experimental/scheduler.py +155 -0
  69. cudf_polars/experimental/select.py +188 -0
  70. cudf_polars/experimental/shuffle.py +354 -0
  71. cudf_polars/experimental/sort.py +609 -0
  72. cudf_polars/experimental/spilling.py +151 -0
  73. cudf_polars/experimental/statistics.py +795 -0
  74. cudf_polars/experimental/utils.py +169 -0
  75. cudf_polars/py.typed +0 -0
  76. cudf_polars/testing/__init__.py +8 -0
  77. cudf_polars/testing/asserts.py +448 -0
  78. cudf_polars/testing/io.py +122 -0
  79. cudf_polars/testing/plugin.py +236 -0
  80. cudf_polars/typing/__init__.py +219 -0
  81. cudf_polars/utils/__init__.py +8 -0
  82. cudf_polars/utils/config.py +741 -0
  83. cudf_polars/utils/conversion.py +40 -0
  84. cudf_polars/utils/dtypes.py +118 -0
  85. cudf_polars/utils/sorting.py +53 -0
  86. cudf_polars/utils/timer.py +39 -0
  87. cudf_polars/utils/versions.py +27 -0
  88. cudf_polars_cu13-25.10.0.dist-info/METADATA +136 -0
  89. cudf_polars_cu13-25.10.0.dist-info/RECORD +92 -0
  90. cudf_polars_cu13-25.10.0.dist-info/WHEEL +5 -0
  91. cudf_polars_cu13-25.10.0.dist-info/licenses/LICENSE +201 -0
  92. cudf_polars_cu13-25.10.0.dist-info/top_level.txt +1 -0
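
The hunks below add the new experimental multi-partition scheduling and benchmarking modules. For orientation only, cudf-polars is normally not imported directly; a minimal sketch (assuming this wheel is installed in a CUDA 13 environment alongside a compatible polars release) routes a query through it by selecting the GPU engine at collect time:

    import polars as pl

    lf = pl.LazyFrame({"a": [1, 2, 1], "b": [4.0, 5.0, 6.0]})
    # Selecting the GPU engine hands the optimized plan to cudf_polars
    # (see callback.py in the file list above).
    result = lf.group_by("a").agg(pl.col("b").sum()).collect(
        engine=pl.GPUEngine(raise_on_fail=True)
    )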
cudf_polars/experimental/base.py
@@ -0,0 +1,386 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+"""Multi-partition base classes."""
+
+from __future__ import annotations
+
+import dataclasses
+from collections import defaultdict
+from functools import cached_property
+from typing import TYPE_CHECKING, Any, Generic, NamedTuple, TypeVar
+
+if TYPE_CHECKING:
+    from collections.abc import Generator, Iterator, MutableMapping
+
+    from cudf_polars.dsl.expr import NamedExpr
+    from cudf_polars.dsl.ir import IR
+    from cudf_polars.dsl.nodebase import Node
+
+
+class PartitionInfo:
+    """Partitioning information."""
+
+    __slots__ = ("count", "partitioned_on")
+    count: int
+    """Partition count."""
+    partitioned_on: tuple[NamedExpr, ...]
+    """Columns the data is hash-partitioned on."""
+
+    def __init__(
+        self,
+        count: int,
+        partitioned_on: tuple[NamedExpr, ...] = (),
+    ):
+        self.count = count
+        self.partitioned_on = partitioned_on
+
+    def keys(self, node: Node) -> Iterator[tuple[str, int]]:
+        """Return the partitioned keys for a given node."""
+        name = get_key_name(node)
+        yield from ((name, i) for i in range(self.count))
+
+    def __rich_repr__(self) -> Generator[Any, None, None]:
+        """Formatting for rich.pretty.pprint."""
+        yield "count", self.count
+        yield "partitioned_on", self.partitioned_on
+
+
+def get_key_name(node: Node) -> str:
+    """Generate the key name for a Node."""
+    return f"{type(node).__name__.lower()}-{hash(node)}"
+
+
+T = TypeVar("T")
+
+
+@dataclasses.dataclass
+class ColumnStat(Generic[T]):
+    """
+    Generic column-statistic.
+
+    Parameters
+    ----------
+    value
+        Statistic value. Value will be None
+        if the statistic is unknown.
+    exact
+        Whether the statistic is known exactly.
+    """
+
+    value: T | None = None
+    exact: bool = False
+
+
+@dataclasses.dataclass
+class UniqueStats:
+    """
+    Sampled unique-value statistics.
+
+    Parameters
+    ----------
+    count
+        Unique-value count.
+    fraction
+        Unique-value fraction. This corresponds to the total
+        number of unique values (count) divided by the total
+        number of rows.
+
+    Notes
+    -----
+    This class is used to track unique-value column statistics
+    that have been sampled from a data source.
+    """
+
+    count: ColumnStat[int] = dataclasses.field(default_factory=ColumnStat[int])
+    fraction: ColumnStat[float] = dataclasses.field(default_factory=ColumnStat[float])
+
+
+class DataSourceInfo:
+    """
+    Table data source information.
+
+    Notes
+    -----
+    This class should be sub-classed for specific
+    data source types (e.g. Parquet, DataFrame, etc.).
+    The required properties/methods enable lazy
+    sampling of the underlying datasource.
+    """
+
+    _unique_stats_columns: set[str]
+
+    @property
+    def row_count(self) -> ColumnStat[int]:  # pragma: no cover
+        """Data source row-count estimate."""
+        raise NotImplementedError("Sub-class must implement row_count.")
+
+    def unique_stats(self, column: str) -> UniqueStats:  # pragma: no cover
+        """Return unique-value statistics for a column."""
+        raise NotImplementedError("Sub-class must implement unique_stats.")
+
+    def storage_size(self, column: str) -> ColumnStat[int]:
+        """Return the average column size for a single file."""
+        return ColumnStat[int]()
+
+    @property
+    def unique_stats_columns(self) -> set[str]:
+        """Return the set of columns needing unique-value information."""
+        return self._unique_stats_columns
+
+    def add_unique_stats_column(self, column: str) -> None:
+        """Add a column needing unique-value information."""
+        self._unique_stats_columns.add(column)
+
+
+class DataSourcePair(NamedTuple):
+    """Pair of table-source and column-name information."""
+
+    table_source: DataSourceInfo
+    column_name: str
+
+
+class ColumnSourceInfo:
+    """
+    Source column information.
+
+    Parameters
+    ----------
+    table_source_pairs
+        Sequence of DataSourcePair objects.
+        Union operations will result in multiple elements.
+
+    Notes
+    -----
+    This is a thin wrapper around DataSourceInfo that provides
+    direct access to column-specific information.
+    """
+
+    __slots__ = (
+        "implied_unique_count",
+        "table_source_pairs",
+    )
+    table_source_pairs: list[DataSourcePair]
+    implied_unique_count: ColumnStat[int]
+    """Unique-value count implied by join heuristics."""
+
+    def __init__(self, *table_source_pairs: DataSourcePair) -> None:
+        self.table_source_pairs = list(table_source_pairs)
+        self.implied_unique_count = ColumnStat[int](None)
+
+    @property
+    def is_unique_stats_column(self) -> bool:
+        """Return whether this column requires unique-value information."""
+        return any(
+            pair.column_name in pair.table_source.unique_stats_columns
+            for pair in self.table_source_pairs
+        )
+
+    @property
+    def row_count(self) -> ColumnStat[int]:
+        """Data source row-count estimate."""
+        return ColumnStat[int](
+            # Use sum of table-source row-count estimates.
+            value=sum(
+                value
+                for pair in self.table_source_pairs
+                if (value := pair.table_source.row_count.value) is not None
+            )
+            or None,
+            # Row-count may be exact if there is only one table source.
+            exact=len(self.table_source_pairs) == 1
+            and self.table_source_pairs[0].table_source.row_count.exact,
+        )
+
+    def unique_stats(self, *, force: bool = False) -> UniqueStats:
+        """
+        Return unique-value statistics for a column.
+
+        Parameters
+        ----------
+        force
+            If True, return unique-value statistics even if the column
+            wasn't marked as needing unique-value information.
+        """
+        if (force or self.is_unique_stats_column) and len(self.table_source_pairs) == 1:
+            # Single table source.
+            # TODO: Handle multiple table sources if/when necessary.
+            # We may never need to do this if the source unique-value
+            # statistics are only "used" by the Scan/DataFrameScan nodes.
+            table_source, column_name = self.table_source_pairs[0]
+            return table_source.unique_stats(column_name)
+        else:
+            # Avoid sampling unique-stats if this column
+            # wasn't marked as "needing" unique-stats.
+            return UniqueStats()
+
+    @property
+    def storage_size(self) -> ColumnStat[int]:
+        """Return the average column size for a single file."""
+        # We don't need to handle concatenated statistics for ``storage_size``.
+        # Just return the storage size of the first table source.
+        if self.table_source_pairs:
+            table_source, column_name = self.table_source_pairs[0]
+            return table_source.storage_size(column_name)
+        else:  # pragma: no cover; We never call this for empty table sources.
+            return ColumnStat[int]()
+
+    def add_unique_stats_column(self, column: str | None = None) -> None:
+        """Add a column needing unique-value information."""
+        # We must call add_unique_stats_column for ALL table sources.
+        for table_source, column_name in self.table_source_pairs:
+            table_source.add_unique_stats_column(column or column_name)
+
+
+class ColumnStats:
+    """
+    Column statistics.
+
+    Parameters
+    ----------
+    name
+        Column name.
+    children
+        Child ColumnStats objects.
+    source_info
+        Column source information.
+    unique_count
+        Unique-value count.
+    """
+
+    __slots__ = ("children", "name", "source_info", "unique_count")
+
+    name: str
+    children: tuple[ColumnStats, ...]
+    source_info: ColumnSourceInfo
+    unique_count: ColumnStat[int]
+
+    def __init__(
+        self,
+        name: str,
+        *,
+        children: tuple[ColumnStats, ...] = (),
+        source_info: ColumnSourceInfo | None = None,
+        unique_count: ColumnStat[int] | None = None,
+    ) -> None:
+        self.name = name
+        self.children = children
+        self.source_info = source_info or ColumnSourceInfo()
+        self.unique_count = unique_count or ColumnStat[int](None)
+
+    def new_parent(
+        self,
+        *,
+        name: str | None = None,
+    ) -> ColumnStats:
+        """
+        Initialize a new parent ColumnStats object.
+
+        Parameters
+        ----------
+        name
+            The new column name.
+
+        Returns
+        -------
+        A new ColumnStats object.
+
+        Notes
+        -----
+        This API preserves the original DataSourceInfo reference.
+        """
+        return ColumnStats(
+            name=name or self.name,
+            children=(self,),
+            # Want to reference the same DataSourceInfo
+            source_info=self.source_info,
+        )
+
+
+class JoinKey:
+    """
+    Join-key information.
+
+    Parameters
+    ----------
+    column_stats
+        Column statistics for the join key.
+
+    Notes
+    -----
+    This class is used to track join-key information.
+    It is used to track the columns being joined on
+    and the estimated unique-value count for the join key.
+    """
+
+    column_stats: tuple[ColumnStats, ...]
+    implied_unique_count: int | None
+    """Estimated unique-value count from join heuristics."""
+
+    def __init__(self, *column_stats: ColumnStats) -> None:
+        self.column_stats = column_stats
+        self.implied_unique_count = None
+
+    @cached_property
+    def source_row_count(self) -> int | None:
+        """
+        Return the estimated row-count of the source columns.
+
+        Notes
+        -----
+        This is the maximum row-count estimate of the source columns.
+        """
+        return max(
+            (
+                cs.source_info.row_count.value
+                for cs in self.column_stats
+                if cs.source_info.row_count.value is not None
+            ),
+            default=None,
+        )
+
+
+class JoinInfo:
+    """
+    Join information.
+
+    Notes
+    -----
+    This class is used to track mappings between joined-on
+    columns and joined-on keys (groups of columns). We need
+    these mappings to calculate equivalence sets and make
+    join-based unique-count and row-count estimates.
+    """
+
+    __slots__ = ("column_map", "join_map", "key_map")
+
+    column_map: MutableMapping[ColumnStats, set[ColumnStats]]
+    """Mapping between joined columns."""
+    key_map: MutableMapping[JoinKey, set[JoinKey]]
+    """Mapping between joined keys (groups of columns)."""
+    join_map: dict[IR, list[JoinKey]]
+    """Mapping between IR nodes and associated join keys."""
+
+    def __init__(self) -> None:
+        self.column_map: MutableMapping[ColumnStats, set[ColumnStats]] = defaultdict(
+            set[ColumnStats]
+        )
+        self.key_map: MutableMapping[JoinKey, set[JoinKey]] = defaultdict(set[JoinKey])
+        self.join_map: dict[IR, list[JoinKey]] = {}
+
+
+class StatsCollector:
+    """Column statistics collector."""
+
+    __slots__ = ("column_stats", "join_info", "row_count")
+
+    row_count: dict[IR, ColumnStat[int]]
+    """Estimated row count for each IR node."""
+    column_stats: dict[IR, dict[str, ColumnStats]]
+    """Column statistics for each IR node."""
+    join_info: JoinInfo
+    """Join information."""
+
+    def __init__(self) -> None:
+        self.row_count: dict[IR, ColumnStat[int]] = {}
+        self.column_stats: dict[IR, dict[str, ColumnStats]] = {}
+        self.join_info = JoinInfo()
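
To illustrate how these statistics containers compose (a hedged sketch only; the planner code that actually builds them lives in experimental/statistics.py, which is not reproduced here), one might write:

    from cudf_polars.experimental.base import ColumnStat, ColumnStats

    # Hypothetical column with an exactly-known unique count.
    stats = ColumnStats("c_customer_sk", unique_count=ColumnStat[int](100_000, exact=True))

    # Renaming via new_parent() keeps a reference to the same ColumnSourceInfo,
    # so source-level row-count and unique-value estimates are shared.
    renamed = stats.new_parent(name="ctr_customer_sk")
    assert renamed.source_info is stats.source_info
    assert renamed.children == (stats,)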
cudf_polars/experimental/benchmarks/__init__.py
@@ -0,0 +1,4 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Experimental benchmarks."""
cudf_polars/experimental/benchmarks/pdsds.py
@@ -0,0 +1,220 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Experimental PDS-DS benchmarks.
+
+Based on https://github.com/pola-rs/polars-benchmark.
+
+WARNING: This is an experimental (and unofficial)
+benchmark script. It is not intended for public use
+and may be modified or removed at any time.
+"""
+
+from __future__ import annotations
+
+import contextlib
+import importlib
+import os
+import time
+from collections import defaultdict
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+import polars as pl
+
+with contextlib.suppress(ImportError):
+    from cudf_polars.experimental.benchmarks.utils import (
+        Record,
+        RunConfig,
+        get_executor_options,
+        parse_args,
+        run_polars,
+    )
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+    from types import ModuleType
+    from typing import Any
+
+# Without this setting, the first IO task to run
+# on each worker takes ~15 sec extra
+os.environ["KVIKIO_COMPAT_MODE"] = os.environ.get("KVIKIO_COMPAT_MODE", "on")
+os.environ["KVIKIO_NTHREADS"] = os.environ.get("KVIKIO_NTHREADS", "8")
+
+
+def valid_query(name: str) -> bool:
+    """Return True for valid query names, e.g. 'q9', 'q65', etc."""
+    if not name.startswith("q"):
+        return False
+    try:
+        q_num = int(name[1:])
+    except ValueError:
+        return False
+    else:
+        return 1 <= q_num <= 99
+
+
+class PDSDSQueriesMeta(type):
+    """Metaclass used for query lookup."""
+
+    def __getattr__(cls, name: str):  # type: ignore
+        """Query lookup."""
+        if valid_query(name):
+            q_num = int(name[1:])
+            module: ModuleType = importlib.import_module(
+                f"cudf_polars.experimental.benchmarks.pdsds_queries.q{q_num}"
+            )
+            return getattr(module, cls.q_impl)
+        raise AttributeError(f"{name} is not a valid query name")
+
+
+class PDSDSQueries(metaclass=PDSDSQueriesMeta):
+    """Base class for query loading."""
+
+    q_impl: str
+    name: str = "pdsds"
+
+
+class PDSDSPolarsQueries(PDSDSQueries):
+    """Polars Queries."""
+
+    q_impl = "polars_impl"
+
+
+class PDSDSDuckDBQueries(PDSDSQueries):
+    """DuckDB Queries."""
+
+    q_impl = "duckdb_impl"
+
+
+def execute_duckdb_query(query: str, dataset_path: Path) -> pl.DataFrame:
+    """Execute a query with DuckDB."""
+    import duckdb
+
+    conn = duckdb.connect()
+
+    statements = [
+        f"CREATE VIEW {table.stem} as SELECT * FROM read_parquet('{table.absolute()}');"
+        for table in Path(dataset_path).glob("*.parquet")
+    ]
+    statements.append(query)
+    return conn.execute("\n".join(statements)).pl()
+
+
+def run_duckdb(benchmark: Any, options: Sequence[str] | None = None) -> None:
+    """Run the benchmark with DuckDB."""
+    args = parse_args(options, num_queries=99)
+    vars(args).update({"query_set": benchmark.name})
+    run_config = RunConfig.from_args(args)
+    records: defaultdict[int, list[Record]] = defaultdict(list)
+
+    for q_id in run_config.queries:
+        try:
+            duckdb_query = getattr(PDSDSDuckDBQueries, f"q{q_id}")(run_config)
+        except AttributeError as err:
+            raise NotImplementedError(f"Query {q_id} not implemented.") from err
+
+        print(f"DuckDB Executing: {q_id}")
+        records[q_id] = []
+
+        for i in range(args.iterations):
+            t0 = time.time()
+
+            result = execute_duckdb_query(duckdb_query, run_config.dataset_path)
+
+            t1 = time.time()
+            record = Record(query=q_id, duration=t1 - t0)
+            if args.print_results:
+                print(result)
+
+            print(f"Query {q_id} - Iteration {i} finished in {record.duration:0.4f}s")
+            records[q_id].append(record)
+
+
+def run_validate(benchmark: Any, options: Sequence[str] | None = None) -> None:
+    """Validate Polars CPU vs DuckDB or Polars GPU."""
+    from polars.testing import assert_frame_equal
+
+    args = parse_args(options, num_queries=99)
+    vars(args).update({"query_set": benchmark.name})
+    run_config = RunConfig.from_args(args)
+
+    baseline = args.baseline
+    if baseline not in {"duckdb", "cpu"}:
+        raise ValueError("Baseline must be one of: 'duckdb', 'cpu'")
+
+    failures: list[int] = []
+
+    engine: pl.GPUEngine | None = None
+    if run_config.executor != "cpu":
+        engine = pl.GPUEngine(
+            raise_on_fail=True,
+            executor=run_config.executor,
+            executor_options=get_executor_options(run_config, PDSDSPolarsQueries),
+        )
+
+    for q_id in run_config.queries:
+        print(f"\nValidating Query {q_id}")
+        try:
+            polars_query = getattr(PDSDSPolarsQueries, f"q{q_id}")(run_config)
+            duckdb_query = getattr(PDSDSDuckDBQueries, f"q{q_id}")(run_config)
+        except AttributeError as err:
+            raise NotImplementedError(f"Query {q_id} not implemented.") from err
+
+        if baseline == "duckdb":
+            base_result = execute_duckdb_query(duckdb_query, run_config.dataset_path)
+        elif baseline == "cpu":
+            base_result = polars_query.collect(new_streaming=True)
+
+        if run_config.executor == "cpu":
+            test_result = polars_query.collect(new_streaming=True)
+        else:
+            try:
+                test_result = polars_query.collect(engine=engine)
+            except Exception as e:
+                failures.append(q_id)
+                print(f"❌ Query {q_id} failed validation: GPU execution failed.\n{e}")
+                continue
+
+        try:
+            assert_frame_equal(
+                base_result,
+                test_result,
+                check_dtypes=True,
+                check_column_order=False,
+            )
+            print(f"✅ Query {q_id} passed validation.")
+        except AssertionError as e:
+            failures.append(q_id)
+            print(f"❌ Query {q_id} failed validation:\n{e}")
+            if args.print_results:
+                print("Baseline Result:\n", base_result)
+                print("Test Result:\n", test_result)
+
+    if failures:
+        print("\nValidation Summary:")
+        print("===================")
+        print(f"{len(failures)} query(s) failed: {failures}")
+    else:
+        print("\nAll queries passed validation.")
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Run PDS-DS benchmarks.")
+    parser.add_argument(
+        "--engine",
+        choices=["polars", "duckdb", "validate"],
+        default="polars",
+        help="Which engine to use for executing the benchmarks or to validate results.",
+    )
+    args, extra_args = parser.parse_known_args()
+
+    if args.engine == "polars":
+        run_polars(PDSDSPolarsQueries, extra_args, num_queries=99)
+    elif args.engine == "duckdb":
+        run_duckdb(PDSDSDuckDBQueries, extra_args)
+    elif args.engine == "validate":
+        run_validate(PDSDSQueries, extra_args)
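
The metaclass above resolves attributes q1 through q99 lazily, importing pdsds_queries.q<N> on first access and returning that module's polars_impl or duckdb_impl. A small hedged sketch of the lookup (constructing a RunConfig depends on parse_args in benchmarks/utils.py, which is not shown in this hunk):

    from cudf_polars.experimental.benchmarks.pdsds import (
        PDSDSDuckDBQueries,
        PDSDSPolarsQueries,
    )

    # Attribute access triggers importlib.import_module(...pdsds_queries.q1).
    build_lazyframe = PDSDSPolarsQueries.q1  # callable taking a RunConfig, returns pl.LazyFrame
    build_sql = PDSDSDuckDBQueries.q1        # callable taking a RunConfig, returns SQL text

From the command line, the module is run as python -m cudf_polars.experimental.benchmarks.pdsds --engine {polars,duckdb,validate}, with any remaining arguments forwarded to parse_args in benchmarks/utils.py.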
cudf_polars/experimental/benchmarks/pdsds_queries/__init__.py
@@ -0,0 +1,4 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
+"""DuckDB and Polars queries."""
cudf_polars/experimental/benchmarks/pdsds_queries/q1.py
@@ -0,0 +1,88 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Query 1."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import polars as pl
+
+from cudf_polars.experimental.benchmarks.utils import get_data
+
+if TYPE_CHECKING:
+    from cudf_polars.experimental.benchmarks.utils import RunConfig
+
+
+def duckdb_impl(run_config: RunConfig) -> str:
+    """Query 1."""
+    return """
+    WITH customer_total_return
+         AS (SELECT sr_customer_sk AS ctr_customer_sk,
+                    sr_store_sk AS ctr_store_sk,
+                    Sum(sr_return_amt) AS ctr_total_return
+             FROM store_returns,
+                  date_dim
+             WHERE sr_returned_date_sk = d_date_sk
+                   AND d_year = 2001
+             GROUP BY sr_customer_sk,
+                      sr_store_sk)
+    SELECT c_customer_id
+    FROM customer_total_return ctr1,
+         store,
+         customer
+    WHERE ctr1.ctr_total_return > (SELECT Avg(ctr_total_return) * 1.2
+                                   FROM customer_total_return ctr2
+                                   WHERE ctr1.ctr_store_sk = ctr2.ctr_store_sk)
+          AND s_store_sk = ctr1.ctr_store_sk
+          AND s_state = 'TN'
+          AND ctr1.ctr_customer_sk = c_customer_sk
+    ORDER BY c_customer_id
+    LIMIT 100;
+    """
+
+
+def polars_impl(run_config: RunConfig) -> pl.LazyFrame:
+    """Query 1."""
+    store_returns = get_data(
+        run_config.dataset_path, "store_returns", run_config.suffix
+    )
+    date_dim = get_data(run_config.dataset_path, "date_dim", run_config.suffix)
+    store = get_data(run_config.dataset_path, "store", run_config.suffix)
+    customer = get_data(run_config.dataset_path, "customer", run_config.suffix)
+
+    # Step 1: Create customer_total_return CTE equivalent
+    customer_total_return = (
+        store_returns.join(
+            date_dim, left_on="sr_returned_date_sk", right_on="d_date_sk"
+        )
+        .filter(pl.col("d_year") == 2001)
+        .group_by(["sr_customer_sk", "sr_store_sk"])
+        .agg(pl.col("sr_return_amt").sum().alias("ctr_total_return"))
+        .rename(
+            {
+                "sr_customer_sk": "ctr_customer_sk",
+                "sr_store_sk": "ctr_store_sk",
+            }
+        )
+    )
+
+    # Step 2: Calculate average return per store for the subquery
+    store_avg_returns = customer_total_return.group_by("ctr_store_sk").agg(
+        [(pl.col("ctr_total_return").mean() * 1.2).alias("avg_return_threshold")]
+    )
+
+    # Step 3: Join everything together and apply filters
+    return (
+        customer_total_return.join(
+            store_avg_returns, left_on="ctr_store_sk", right_on="ctr_store_sk"
+        )
+        .filter(pl.col("ctr_total_return") > pl.col("avg_return_threshold"))
+        .join(store, left_on="ctr_store_sk", right_on="s_store_sk")
+        .filter(pl.col("s_state") == "TN")
+        .join(customer, left_on="ctr_customer_sk", right_on="c_customer_sk")
+        .select(["c_customer_id"])
+        .sort("c_customer_id")
+        .limit(100)
+    )
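
For completeness, a hedged sketch of how a single query implementation could be driven outside the benchmark harness (assuming a RunConfig built by benchmarks/utils.py, which this diff does not reproduce):

    import polars as pl

    from cudf_polars.experimental.benchmarks.pdsds_queries.q1 import polars_impl

    def collect_q1(run_config, *, use_gpu: bool = True) -> pl.DataFrame:
        """Collect Query 1 either through cudf-polars (GPU) or on the CPU."""
        lf = polars_impl(run_config)
        if use_gpu:
            # Mirrors run_validate() in pdsds.py: execute via the Polars GPU engine.
            return lf.collect(engine=pl.GPUEngine(raise_on_fail=True))
        return lf.collect()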