cudf-polars-cu13 25.10.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. cudf_polars/GIT_COMMIT +1 -0
  2. cudf_polars/VERSION +1 -0
  3. cudf_polars/__init__.py +28 -0
  4. cudf_polars/_version.py +21 -0
  5. cudf_polars/callback.py +318 -0
  6. cudf_polars/containers/__init__.py +13 -0
  7. cudf_polars/containers/column.py +495 -0
  8. cudf_polars/containers/dataframe.py +361 -0
  9. cudf_polars/containers/datatype.py +137 -0
  10. cudf_polars/dsl/__init__.py +8 -0
  11. cudf_polars/dsl/expr.py +66 -0
  12. cudf_polars/dsl/expressions/__init__.py +8 -0
  13. cudf_polars/dsl/expressions/aggregation.py +226 -0
  14. cudf_polars/dsl/expressions/base.py +272 -0
  15. cudf_polars/dsl/expressions/binaryop.py +120 -0
  16. cudf_polars/dsl/expressions/boolean.py +326 -0
  17. cudf_polars/dsl/expressions/datetime.py +271 -0
  18. cudf_polars/dsl/expressions/literal.py +97 -0
  19. cudf_polars/dsl/expressions/rolling.py +643 -0
  20. cudf_polars/dsl/expressions/selection.py +74 -0
  21. cudf_polars/dsl/expressions/slicing.py +46 -0
  22. cudf_polars/dsl/expressions/sorting.py +85 -0
  23. cudf_polars/dsl/expressions/string.py +1002 -0
  24. cudf_polars/dsl/expressions/struct.py +137 -0
  25. cudf_polars/dsl/expressions/ternary.py +49 -0
  26. cudf_polars/dsl/expressions/unary.py +517 -0
  27. cudf_polars/dsl/ir.py +2607 -0
  28. cudf_polars/dsl/nodebase.py +164 -0
  29. cudf_polars/dsl/to_ast.py +359 -0
  30. cudf_polars/dsl/tracing.py +16 -0
  31. cudf_polars/dsl/translate.py +939 -0
  32. cudf_polars/dsl/traversal.py +224 -0
  33. cudf_polars/dsl/utils/__init__.py +8 -0
  34. cudf_polars/dsl/utils/aggregations.py +481 -0
  35. cudf_polars/dsl/utils/groupby.py +98 -0
  36. cudf_polars/dsl/utils/naming.py +34 -0
  37. cudf_polars/dsl/utils/replace.py +61 -0
  38. cudf_polars/dsl/utils/reshape.py +74 -0
  39. cudf_polars/dsl/utils/rolling.py +121 -0
  40. cudf_polars/dsl/utils/windows.py +192 -0
  41. cudf_polars/experimental/__init__.py +8 -0
  42. cudf_polars/experimental/base.py +386 -0
  43. cudf_polars/experimental/benchmarks/__init__.py +4 -0
  44. cudf_polars/experimental/benchmarks/pdsds.py +220 -0
  45. cudf_polars/experimental/benchmarks/pdsds_queries/__init__.py +4 -0
  46. cudf_polars/experimental/benchmarks/pdsds_queries/q1.py +88 -0
  47. cudf_polars/experimental/benchmarks/pdsds_queries/q10.py +225 -0
  48. cudf_polars/experimental/benchmarks/pdsds_queries/q2.py +244 -0
  49. cudf_polars/experimental/benchmarks/pdsds_queries/q3.py +65 -0
  50. cudf_polars/experimental/benchmarks/pdsds_queries/q4.py +359 -0
  51. cudf_polars/experimental/benchmarks/pdsds_queries/q5.py +462 -0
  52. cudf_polars/experimental/benchmarks/pdsds_queries/q6.py +92 -0
  53. cudf_polars/experimental/benchmarks/pdsds_queries/q7.py +79 -0
  54. cudf_polars/experimental/benchmarks/pdsds_queries/q8.py +524 -0
  55. cudf_polars/experimental/benchmarks/pdsds_queries/q9.py +137 -0
  56. cudf_polars/experimental/benchmarks/pdsh.py +814 -0
  57. cudf_polars/experimental/benchmarks/utils.py +832 -0
  58. cudf_polars/experimental/dask_registers.py +200 -0
  59. cudf_polars/experimental/dispatch.py +156 -0
  60. cudf_polars/experimental/distinct.py +197 -0
  61. cudf_polars/experimental/explain.py +157 -0
  62. cudf_polars/experimental/expressions.py +590 -0
  63. cudf_polars/experimental/groupby.py +327 -0
  64. cudf_polars/experimental/io.py +943 -0
  65. cudf_polars/experimental/join.py +391 -0
  66. cudf_polars/experimental/parallel.py +423 -0
  67. cudf_polars/experimental/repartition.py +69 -0
  68. cudf_polars/experimental/scheduler.py +155 -0
  69. cudf_polars/experimental/select.py +188 -0
  70. cudf_polars/experimental/shuffle.py +354 -0
  71. cudf_polars/experimental/sort.py +609 -0
  72. cudf_polars/experimental/spilling.py +151 -0
  73. cudf_polars/experimental/statistics.py +795 -0
  74. cudf_polars/experimental/utils.py +169 -0
  75. cudf_polars/py.typed +0 -0
  76. cudf_polars/testing/__init__.py +8 -0
  77. cudf_polars/testing/asserts.py +448 -0
  78. cudf_polars/testing/io.py +122 -0
  79. cudf_polars/testing/plugin.py +236 -0
  80. cudf_polars/typing/__init__.py +219 -0
  81. cudf_polars/utils/__init__.py +8 -0
  82. cudf_polars/utils/config.py +741 -0
  83. cudf_polars/utils/conversion.py +40 -0
  84. cudf_polars/utils/dtypes.py +118 -0
  85. cudf_polars/utils/sorting.py +53 -0
  86. cudf_polars/utils/timer.py +39 -0
  87. cudf_polars/utils/versions.py +27 -0
  88. cudf_polars_cu13-25.10.0.dist-info/METADATA +136 -0
  89. cudf_polars_cu13-25.10.0.dist-info/RECORD +92 -0
  90. cudf_polars_cu13-25.10.0.dist-info/WHEEL +5 -0
  91. cudf_polars_cu13-25.10.0.dist-info/licenses/LICENSE +201 -0
  92. cudf_polars_cu13-25.10.0.dist-info/top_level.txt +1 -0
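The remainder of this page shows the full contents of one file from the list above: cudf_polars/experimental/benchmarks/utils.py, the shared harness behind the PDS-H and PDS-DS benchmark runners. For context before reading it, here is a minimal sketch of how the packaged library is normally exercised; cudf-polars implements the Polars GPU engine, so user code drives it through pl.GPUEngine rather than importing the package directly (the frame contents below are illustrative only, not part of the wheel):

import polars as pl

lf = pl.LazyFrame({"a": [1, 2, 1], "b": [4.0, 5.0, 6.0]})

# Route the collect through cudf-polars. raise_on_fail=True surfaces any
# unsupported operation instead of silently falling back to the CPU engine.
result = lf.group_by("a").agg(pl.col("b").sum()).collect(
    engine=pl.GPUEngine(device=0, raise_on_fail=True)
)
print(result)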
cudf_polars/experimental/benchmarks/utils.py
@@ -0,0 +1,832 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Utility functions/classes for running the PDS-H and PDS-DS benchmarks."""
+
+from __future__ import annotations
+
+import argparse
+import dataclasses
+import importlib
+import json
+import os
+import statistics
+import sys
+import textwrap
+import time
+import traceback
+from collections import defaultdict
+from datetime import datetime, timezone
+from typing import TYPE_CHECKING, Any, Literal, assert_never
+
+import nvtx
+
+import polars as pl
+
+try:
+    import pynvml
+except ImportError:
+    pynvml = None
+
+try:
+    from cudf_polars.dsl.translate import Translator
+    from cudf_polars.experimental.explain import explain_query
+    from cudf_polars.experimental.parallel import evaluate_streaming
+    from cudf_polars.testing.asserts import assert_gpu_result_equal
+    from cudf_polars.utils.config import ConfigOptions
+
+    CUDF_POLARS_AVAILABLE = True
+except ImportError:
+    CUDF_POLARS_AVAILABLE = False
+
+if TYPE_CHECKING:
+    from collections.abc import Callable, Sequence
+    from pathlib import Path
+
+
+ExecutorType = Literal["in-memory", "streaming", "cpu"]
+
+
+@dataclasses.dataclass
+class Record:
+    """Results for a single run of a single PDS-H query."""
+
+    query: int
+    duration: float
+    shuffle_stats: dict[str, dict[str, int | float]] | None = None
+
+
+@dataclasses.dataclass
+class VersionInfo:
+    """Information about the commit of the software used to run the query."""
+
+    version: str
+    commit: str
+
+
+@dataclasses.dataclass
+class PackageVersions:
+    """Information about the versions of the software used to run the query."""
+
+    cudf_polars: str | VersionInfo
+    polars: str
+    python: str
+    rapidsmpf: str | VersionInfo | None
+
+    @classmethod
+    def collect(cls) -> PackageVersions:
+        """Collect the versions of the software used to run the query."""
+        packages = [
+            "cudf_polars",
+            "polars",
+            "rapidsmpf",
+        ]
+        versions: dict[str, str | VersionInfo | None] = {}
+        for name in packages:
+            try:
+                package = importlib.import_module(name)
+            except (AttributeError, ImportError):  # noqa: PERF203
+                versions[name] = None
+            else:
+                if name in ("cudf_polars", "rapidsmpf"):
+                    versions[name] = VersionInfo(
+                        version=package.__version__,
+                        commit=package.__git_commit__,
+                    )
+                else:
+                    versions[name] = package.__version__
+
+        versions["python"] = ".".join(str(v) for v in sys.version_info[:3])
+        # we manually ensure that only cudf-polars and rapidsmpf have a VersionInfo
+        return cls(**versions)  # type: ignore[arg-type]
+
+
+@dataclasses.dataclass
+class GPUInfo:
+    """Information about a specific GPU."""
+
+    name: str
+    index: int
+    free_memory: int | None
+    used_memory: int | None
+    total_memory: int | None
+
+    @classmethod
+    def from_index(cls, index: int) -> GPUInfo:
+        """Create a GPUInfo from an index."""
+        pynvml.nvmlInit()
+        handle = pynvml.nvmlDeviceGetHandleByIndex(index)
+        try:
+            memory = pynvml.nvmlDeviceGetMemoryInfo(handle)
+            return cls(
+                name=pynvml.nvmlDeviceGetName(handle),
+                index=index,
+                free_memory=memory.free,
+                used_memory=memory.used,
+                total_memory=memory.total,
+            )
+        except pynvml.NVMLError_NotSupported:
+            # Happens on systems without traditional GPU memory (e.g., Grace Hopper),
+            # where nvmlDeviceGetMemoryInfo is not supported.
+            # See: https://github.com/rapidsai/cudf/issues/19427
+            return cls(
+                name=pynvml.nvmlDeviceGetName(handle),
+                index=index,
+                free_memory=None,
+                used_memory=None,
+                total_memory=None,
+            )
+
+
+@dataclasses.dataclass
+class HardwareInfo:
+    """Information about the hardware used to run the query."""
+
+    gpus: list[GPUInfo]
+    # TODO: ucx
+
+    @classmethod
+    def collect(cls) -> HardwareInfo:
+        """Collect the hardware information."""
+        if pynvml is not None:
+            pynvml.nvmlInit()
+            gpus = [GPUInfo.from_index(i) for i in range(pynvml.nvmlDeviceGetCount())]
+        else:
+            # No GPUs -- probably running in CPU mode
+            gpus = []
+        return cls(gpus=gpus)
+
+
+def _infer_scale_factor(name: str, path: str | Path, suffix: str) -> int | float:
+    if "pdsh" in name:
+        supplier = get_data(path, "supplier", suffix)
+        num_rows = supplier.select(pl.len()).collect().item(0, 0)
+        return num_rows / 10_000
+
+    elif "pdsds" in name:
+        # TODO: Keep a map of SF-row_count because of nonlinear scaling
+        # See: https://www.tpc.org/TPC_Documents_Current_Versions/pdf/TPC-DS_v4.0.0.pdf pg.46
+        customer = get_data(path, "promotion", suffix)
+        num_rows = customer.select(pl.len()).collect().item(0, 0)
+        return num_rows / 300
+
+    else:
+        raise ValueError(f"Invalid benchmark script name: '{name}'.")
+
+
+@dataclasses.dataclass(kw_only=True)
+class RunConfig:
+    """Results for a PDS-H or PDS-DS query run."""
+
+    queries: list[int]
+    suffix: str
+    executor: ExecutorType
+    scheduler: str
+    n_workers: int
+    versions: PackageVersions = dataclasses.field(
+        default_factory=PackageVersions.collect
+    )
+    records: dict[int, list[Record]] = dataclasses.field(default_factory=dict)
+    dataset_path: Path
+    scale_factor: int | float
+    shuffle: Literal["rapidsmpf", "tasks"] | None = None
+    gather_shuffle_stats: bool = False
+    broadcast_join_limit: int | None = None
+    blocksize: int | None = None
+    max_rows_per_partition: int | None = None
+    threads: int
+    iterations: int
+    timestamp: str = dataclasses.field(
+        default_factory=lambda: datetime.now(timezone.utc).isoformat()
+    )
+    hardware: HardwareInfo = dataclasses.field(default_factory=HardwareInfo.collect)
+    rmm_async: bool
+    rapidsmpf_oom_protection: bool
+    rapidsmpf_spill: bool
+    spill_device: float
+    query_set: str
+    stats_planning: bool
+
+    def __post_init__(self) -> None:  # noqa: D105
+        if self.gather_shuffle_stats and self.shuffle != "rapidsmpf":
+            raise ValueError(
+                "gather_shuffle_stats is only supported when shuffle='rapidsmpf'."
+            )
+
+    @classmethod
+    def from_args(cls, args: argparse.Namespace) -> RunConfig:
+        """Create a RunConfig from command line arguments."""
+        executor: ExecutorType = args.executor
+        scheduler = args.scheduler
+
+        if executor == "in-memory" or executor == "cpu":
+            scheduler = None
+
+        path = args.path
+        name = args.query_set
+        scale_factor = args.scale
+
+        if scale_factor is None:
+            if "pdsds" in name:
+                raise ValueError(
+                    "--scale is required for PDS-DS benchmarks.\n"
+                    "TODO: This will be inferred once we maintain a map of scale factors to row counts."
+                )
+            if path is None:
+                raise ValueError(
+                    "Must specify --root and --scale if --path is not specified."
+                )
+            # For PDS-H, infer scale factor based on row count
+            scale_factor = _infer_scale_factor(name, path, args.suffix)
+        if path is None:
+            path = f"{args.root}/scale-{scale_factor}"
+        try:
+            scale_factor = int(scale_factor)
+        except ValueError:
+            scale_factor = float(scale_factor)
+
+        if "pdsh" in name and args.scale is not None:
+            # Validate the user-supplied scale factor
+            sf_inf = _infer_scale_factor(name, path, args.suffix)
+            rel_error = abs((scale_factor - sf_inf) / sf_inf)
+            if rel_error > 0.01:
+                raise ValueError(
+                    f"Specified scale factor is {args.scale}, "
+                    f"but the inferred scale factor is {sf_inf}."
+                )
+
+        return cls(
+            queries=args.query,
+            executor=executor,
+            scheduler=scheduler,
+            n_workers=args.n_workers,
+            shuffle=args.shuffle,
+            gather_shuffle_stats=args.rapidsmpf_dask_statistics,
+            broadcast_join_limit=args.broadcast_join_limit,
+            dataset_path=path,
+            scale_factor=scale_factor,
+            blocksize=args.blocksize,
+            threads=args.threads,
+            iterations=args.iterations,
+            suffix=args.suffix,
+            rmm_async=args.rmm_async,
+            rapidsmpf_oom_protection=args.rapidsmpf_oom_protection,
+            spill_device=args.spill_device,
+            rapidsmpf_spill=args.rapidsmpf_spill,
+            max_rows_per_partition=args.max_rows_per_partition,
+            query_set=args.query_set,
+            stats_planning=args.stats_planning,
+        )
+
+    def serialize(self, engine: pl.GPUEngine | None) -> dict:
+        """Serialize the run config to a dictionary."""
+        result = dataclasses.asdict(self)
+
+        if engine is not None:
+            config_options = ConfigOptions.from_polars_engine(engine)
+            result["config_options"] = dataclasses.asdict(config_options)
+        return result
+
+    def summarize(self) -> None:
+        """Print a summary of the results."""
+        print("Iteration Summary")
+        print("=======================================")
+
+        for query, records in self.records.items():
+            print(f"query: {query}")
+            print(f"path: {self.dataset_path}")
+            print(f"scale_factor: {self.scale_factor}")
+            print(f"executor: {self.executor}")
+            if self.executor == "streaming":
+                print(f"scheduler: {self.scheduler}")
+                print(f"blocksize: {self.blocksize}")
+                print(f"shuffle_method: {self.shuffle}")
+                print(f"broadcast_join_limit: {self.broadcast_join_limit}")
+                print(f"stats_planning: {self.stats_planning}")
+                if self.scheduler == "distributed":
+                    print(f"n_workers: {self.n_workers}")
+                    print(f"threads: {self.threads}")
+                    print(f"rmm_async: {self.rmm_async}")
+                    print(f"rapidsmpf_oom_protection: {self.rapidsmpf_oom_protection}")
+                    print(f"spill_device: {self.spill_device}")
+                    print(f"rapidsmpf_spill: {self.rapidsmpf_spill}")
+            if len(records) > 0:
+                print(f"iterations: {self.iterations}")
+                print("---------------------------------------")
+                print(f"min time : {min(record.duration for record in records):0.4f}")
+                print(f"max time : {max(record.duration for record in records):0.4f}")
+                print(
+                    f"mean time: {statistics.mean(record.duration for record in records):0.4f}"
+                )
+            print("=======================================")
+        total_mean_time = sum(
+            statistics.mean(record.duration for record in records)
+            for records in self.records.values()
+            if records
+        )
+        print(f"Total mean time across all queries: {total_mean_time:.4f} seconds")
+
+
+def get_data(path: str | Path, table_name: str, suffix: str = "") -> pl.LazyFrame:
+    """Get table from dataset."""
+    return pl.scan_parquet(f"{path}/{table_name}{suffix}")
+
+
+def get_executor_options(
+    run_config: RunConfig, benchmark: Any = None
+) -> dict[str, Any]:
+    """Generate executor_options for GPUEngine."""
+    executor_options: dict[str, Any] = {}
+
+    if run_config.blocksize:
+        executor_options["target_partition_size"] = run_config.blocksize
+    if run_config.max_rows_per_partition:
+        executor_options["max_rows_per_partition"] = run_config.max_rows_per_partition
+    if run_config.shuffle:
+        executor_options["shuffle_method"] = run_config.shuffle
+    if run_config.broadcast_join_limit:
+        executor_options["broadcast_join_limit"] = run_config.broadcast_join_limit
+    if run_config.rapidsmpf_spill:
+        executor_options["rapidsmpf_spill"] = run_config.rapidsmpf_spill
+    if run_config.scheduler == "distributed":
+        executor_options["scheduler"] = "distributed"
+    if run_config.stats_planning:
+        executor_options["stats_planning"] = {"use_reduction_planning": True}
+
+    if (
+        benchmark
+        and benchmark.__name__ == "PDSHQueries"
+        and run_config.executor == "streaming"
+        # Only use the unique_fraction config if stats_planning is disabled
+        and not run_config.stats_planning
+    ):
+        executor_options["unique_fraction"] = {
+            "c_custkey": 0.05,
+            "l_orderkey": 1.0,
+            "l_partkey": 0.1,
+            "o_custkey": 0.25,
+        }
+
+    return executor_options
+
+
+def print_query_plan(
+    q_id: int,
+    q: pl.LazyFrame,
+    args: argparse.Namespace,
+    run_config: RunConfig,
+    engine: None | pl.GPUEngine = None,
+) -> None:
+    """Print the query plan."""
+    if run_config.executor == "cpu":
+        if args.explain_logical:
+            print(f"\nQuery {q_id} - Logical plan\n")
+            print(q.explain())
+        if args.explain:
+            print(f"\nQuery {q_id} - Physical plan\n")
+            print(q.show_graph(engine="streaming", plan_stage="physical"))
+    elif CUDF_POLARS_AVAILABLE:
+        assert isinstance(engine, pl.GPUEngine)
+        if args.explain_logical:
+            print(f"\nQuery {q_id} - Logical plan\n")
+            print(explain_query(q, engine, physical=False))
+        if args.explain:
+            print(f"\nQuery {q_id} - Physical plan\n")
+            print(explain_query(q, engine))
+    else:
+        raise RuntimeError(
+            "Cannot provide the logical or physical plan because cudf_polars is not installed."
+        )
+
+
+def initialize_dask_cluster(run_config: RunConfig, args: argparse.Namespace):  # type: ignore
+    """Initialize a Dask distributed cluster."""
+    if run_config.scheduler != "distributed":
+        return None
+
+    from dask_cuda import LocalCUDACluster
+    from distributed import Client
+
+    kwargs = {
+        "n_workers": run_config.n_workers,
+        "dashboard_address": ":8585",
+        "protocol": args.protocol,
+        "rmm_pool_size": args.rmm_pool_size,
+        "rmm_async": args.rmm_async,
+        "rmm_release_threshold": args.rmm_release_threshold,
+        "threads_per_worker": run_config.threads,
+    }
+
+    # Avoid UVM in distributed cluster
+    client = Client(LocalCUDACluster(**kwargs))
+    client.wait_for_workers(run_config.n_workers)
+
+    if run_config.shuffle != "tasks":
+        try:
+            from rapidsmpf.config import Options
+            from rapidsmpf.integrations.dask import bootstrap_dask_cluster
+
+            bootstrap_dask_cluster(
+                client,
+                options=Options(
+                    {
+                        "dask_spill_device": str(run_config.spill_device),
+                        "dask_statistics": str(args.rapidsmpf_dask_statistics),
+                        "dask_print_statistics": str(args.rapidsmpf_print_statistics),
+                        "oom_protection": str(args.rapidsmpf_oom_protection),
+                    }
+                ),
+            )
+        except ImportError as err:
+            if run_config.shuffle == "rapidsmpf":
+                raise ImportError(
+                    "rapidsmpf is required for shuffle='rapidsmpf' but is not installed."
+                ) from err
+
+    return client
+
+
+def execute_query(
+    q_id: int,
+    i: int,
+    q: pl.LazyFrame,
+    run_config: RunConfig,
+    args: argparse.Namespace,
+    engine: None | pl.GPUEngine = None,
+) -> pl.DataFrame:
+    """Execute a query with NVTX annotation."""
+    with nvtx.annotate(
+        message=f"Query {q_id} - Iteration {i}",
+        domain="cudf_polars",
+        color="green",
+    ):
+        if run_config.executor == "cpu":
+            return q.collect(engine="streaming")
+
+        elif CUDF_POLARS_AVAILABLE:
+            assert isinstance(engine, pl.GPUEngine)
+            if args.debug:
+                translator = Translator(q._ldf.visit(), engine)
+                ir = translator.translate_ir()
+                if run_config.executor == "in-memory":
+                    return ir.evaluate(cache={}, timer=None).to_polars()
+                elif run_config.executor == "streaming":
+                    return evaluate_streaming(ir, translator.config_options).to_polars()
+                assert_never(run_config.executor)
+            else:
+                return q.collect(engine=engine)
+
+        else:
+            raise RuntimeError("The requested engine is not supported.")
+
+
+def _query_type(num_queries: int) -> Callable[[str | int], list[int]]:
+    def parse(query: str | int) -> list[int]:
+        if isinstance(query, int):
+            return [query]
+        if query == "all":
+            return list(range(1, num_queries + 1))
+
+        result: set[int] = set()
+        for part in query.split(","):
+            if "-" in part:
+                start, end = part.split("-")
+                result.update(range(int(start), int(end) + 1))
+            else:
+                result.add(int(part))
+        return sorted(result)
+
+    return parse
+
+
+def parse_args(
+    args: Sequence[str] | None = None, num_queries: int = 22
+) -> argparse.Namespace:
+    """Parse command line arguments."""
+    parser = argparse.ArgumentParser(
+        prog="Cudf-Polars PDS-H Benchmarks",
+        description="Experimental streaming-executor benchmarks.",
+        formatter_class=argparse.RawTextHelpFormatter,
+    )
+    parser.add_argument(
+        "query",
+        type=_query_type(num_queries),
+        help=textwrap.dedent("""\
+            Query to run. One of the following:
+            - A single number (e.g. 11)
+            - A comma-separated list of query numbers (e.g. 1,3,7)
+            - A range of query numbers (e.g. 1-11,23-34)
+            - The string 'all' to run all queries (1 through 22)"""),
+    )
+    parser.add_argument(
+        "--path",
+        type=str,
+        default=os.environ.get("PDSH_DATASET_PATH"),
+        help=textwrap.dedent("""\
+            Path to the root directory of the PDS-H dataset.
+            Defaults to the PDSH_DATASET_PATH environment variable."""),
+    )
+    parser.add_argument(
+        "--root",
+        type=str,
+        default=os.environ.get("PDSH_DATASET_ROOT"),
+        help="Root PDS-H dataset directory (ignored if --path is used).",
+    )
+    parser.add_argument(
+        "--scale",
+        type=str,
+        default=None,
+        help="Dataset scale factor.",
+    )
+    parser.add_argument(
+        "--suffix",
+        type=str,
+        default=".parquet",
+        help=textwrap.dedent("""\
+            File suffix for input table files.
+            Default: .parquet"""),
+    )
+    parser.add_argument(
+        "-e",
+        "--executor",
+        default="streaming",
+        type=str,
+        choices=["in-memory", "streaming", "cpu"],
+        help=textwrap.dedent("""\
+            Query executor backend:
+            - in-memory : Evaluate query in GPU memory
+            - streaming : Partitioned evaluation (default)
+            - cpu       : Use Polars CPU engine"""),
+    )
+    parser.add_argument(
+        "-s",
+        "--scheduler",
+        default="synchronous",
+        type=str,
+        choices=["synchronous", "distributed"],
+        help=textwrap.dedent("""\
+            Scheduler type to use with the 'streaming' executor.
+            - synchronous : Run locally in a single process
+            - distributed : Use Dask for multi-GPU execution"""),
+    )
+    parser.add_argument(
+        "--n-workers",
+        default=1,
+        type=int,
+        help="Number of Dask-CUDA workers (requires 'distributed' scheduler).",
+    )
+    parser.add_argument(
+        "--blocksize",
+        default=None,
+        type=int,
+        help="Target partition size, in bytes, for IO tasks.",
+    )
+    parser.add_argument(
+        "--max-rows-per-partition",
+        default=None,
+        type=int,
+        help="The maximum number of rows to process per partition.",
+    )
+    parser.add_argument(
+        "--iterations",
+        default=1,
+        type=int,
+        help="Number of times to run the same query.",
+    )
+    parser.add_argument(
+        "--debug",
+        default=False,
+        action="store_true",
+        help="Debug run.",
+    )
+    parser.add_argument(
+        "--protocol",
+        default="ucx",
+        type=str,
+        choices=["ucx"],
+        help="Communication protocol to use for Dask: ucx (uses ucxx)",
+    )
+    parser.add_argument(
+        "--shuffle",
+        default=None,
+        type=str,
+        choices=[None, "rapidsmpf", "tasks"],
+        help="Shuffle method to use for distributed execution.",
+    )
+    parser.add_argument(
+        "--broadcast-join-limit",
+        default=None,
+        type=int,
+        help="Set an explicit `broadcast_join_limit` option.",
+    )
+    parser.add_argument(
+        "--threads",
+        default=1,
+        type=int,
+        help="Number of threads to use on each GPU.",
+    )
+    parser.add_argument(
+        "--rmm-pool-size",
+        default=0.5,
+        type=float,
+        help=textwrap.dedent("""\
+            Fraction of total GPU memory to allocate for RMM pool.
+            Default: 0.5 (50%% of GPU memory)"""),
+    )
+    parser.add_argument(
+        "--rmm-release-threshold",
+        default=None,
+        type=float,
+        help=textwrap.dedent("""\
+            Passed to dask_cuda.LocalCUDACluster to control the release
+            threshold for RMM pool memory.
+            Default: None (no release threshold)"""),
+    )
+    parser.add_argument(
+        "--rmm-async",
+        action=argparse.BooleanOptionalAction,
+        default=False,
+        help="Use RMM async memory resource.",
+    )
+    parser.add_argument(
+        "--rapidsmpf-oom-protection",
+        action=argparse.BooleanOptionalAction,
+        default=False,
+        help="Use rapidsmpf CUDA managed memory-based OOM protection.",
+    )
+    parser.add_argument(
+        "--rapidsmpf-dask-statistics",
+        action=argparse.BooleanOptionalAction,
+        default=False,
+        help="Collect rapidsmpf shuffle statistics. The output will be stored in the 'shuffle_stats' field of each record.",
+    )
+    parser.add_argument(
+        "--rapidsmpf-print-statistics",
+        action=argparse.BooleanOptionalAction,
+        default=False,
+        help="Print rapidsmpf shuffle statistics on each Dask worker upon completion.",
+    )
+    parser.add_argument(
+        "--rapidsmpf-spill",
+        action=argparse.BooleanOptionalAction,
+        default=False,
+        help="Use rapidsmpf for general spilling.",
+    )
+    parser.add_argument(
+        "--spill-device",
+        default=0.5,
+        type=float,
+        help="Rapidsmpf device spill threshold.",
+    )
+    parser.add_argument(
+        "-o",
+        "--output",
+        type=argparse.FileType("at"),
+        default="pdsh_results.jsonl",
+        help="Output file path.",
+    )
+    parser.add_argument(
+        "--summarize",
+        action=argparse.BooleanOptionalAction,
+        help="Summarize the results.",
+        default=True,
+    )
+    parser.add_argument(
+        "--print-results",
+        action=argparse.BooleanOptionalAction,
+        help="Print the query results",
+        default=True,
+    )
+    parser.add_argument(
+        "--explain",
+        action=argparse.BooleanOptionalAction,
+        help="Print an outline of the physical plan",
+        default=False,
+    )
+    parser.add_argument(
+        "--explain-logical",
+        action=argparse.BooleanOptionalAction,
+        help="Print an outline of the logical plan",
+        default=False,
+    )
+    parser.add_argument(
+        "--validate",
+        action=argparse.BooleanOptionalAction,
+        default=False,
+        help="Validate the result against CPU execution.",
+    )
+    parser.add_argument(
+        "--baseline",
+        choices=["duckdb", "cpu"],
+        default="duckdb",
+        help="Which engine to use as the baseline for validation.",
+    )
+    parser.add_argument(
+        "--stats-planning",
+        action=argparse.BooleanOptionalAction,
+        default=False,
+        help="Enable statistics planning.",
+    )
+    return parser.parse_args(args)
+
+
+def run_polars(
+    benchmark: Any,
+    options: Sequence[str] | None = None,
+    num_queries: int = 22,
+) -> None:
+    """Run the queries using the given benchmark and executor options."""
+    args = parse_args(options, num_queries=num_queries)
+    vars(args).update({"query_set": benchmark.name})
+    run_config = RunConfig.from_args(args)
+    validation_failures: list[int] = []
+    query_failures: list[tuple[int, int]] = []
+
+    client = initialize_dask_cluster(run_config, args)  # type: ignore
+
+    records: defaultdict[int, list[Record]] = defaultdict(list)
+    engine: pl.GPUEngine | None = None
+
+    if run_config.executor != "cpu":
+        executor_options = get_executor_options(run_config, benchmark=benchmark)
+        engine = pl.GPUEngine(
+            raise_on_fail=True,
+            executor=run_config.executor,
+            executor_options=executor_options,
+        )
+
+    for q_id in run_config.queries:
+        try:
+            q = getattr(benchmark, f"q{q_id}")(run_config)
+        except AttributeError as err:
+            raise NotImplementedError(f"Query {q_id} not implemented.") from err
+
+        print_query_plan(q_id, q, args, run_config, engine)
+
+        records[q_id] = []
+
+        for i in range(args.iterations):
+            t0 = time.monotonic()
+
+            try:
+                result = execute_query(q_id, i, q, run_config, args, engine)
+            except Exception:
+                print(f"❌ query={q_id} iteration={i} failed!")
+                print(traceback.format_exc())
+                query_failures.append((q_id, i))
+                continue
+            if run_config.shuffle == "rapidsmpf" and run_config.gather_shuffle_stats:
+                from rapidsmpf.integrations.dask.shuffler import (
+                    clear_shuffle_statistics,
+                    gather_shuffle_statistics,
+                )
+
+                shuffle_stats = gather_shuffle_statistics(client)  # type: ignore[arg-type]
+                clear_shuffle_statistics(client)  # type: ignore[arg-type]
+            else:
+                shuffle_stats = None
+
+            if args.validate and run_config.executor != "cpu":
+                try:
+                    assert_gpu_result_equal(
+                        q,
+                        engine=engine,
+                        executor=run_config.executor,
+                        check_exact=False,
+                    )
+                    print(f"✅ Query {q_id} passed validation!")
+                except AssertionError as e:
+                    validation_failures.append(q_id)
+                    print(f"❌ Query {q_id} failed validation!\n{e}")
+
+            t1 = time.monotonic()
+            record = Record(query=q_id, duration=t1 - t0, shuffle_stats=shuffle_stats)
+            if args.print_results:
+                print(result)
+
+            print(f"Query {q_id} - Iteration {i} finished in {record.duration:0.4f}s")
+            records[q_id].append(record)
+
+    run_config = dataclasses.replace(run_config, records=dict(records))
+
+    if args.summarize:
+        run_config.summarize()
+
+    if client is not None:
+        client.close(timeout=60)
+
+    if args.validate and run_config.executor != "cpu":
+        print("\nValidation Summary")
+        print("==================")
+        if validation_failures:
+            print(
+                f"{len(validation_failures)} queries failed validation: {sorted(set(validation_failures))}"
+            )
+        else:
+            print("All validated queries passed.")
+
+    args.output.write(json.dumps(run_config.serialize(engine=engine)))
+    args.output.write("\n")
+
+    if query_failures or validation_failures:
+        sys.exit(1)
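
The harness above is driven by a query-set module. As a hedged sketch (run_polars and parse_args are taken from the diff itself, but the import path for the PDS-H query class inside pdsh.py is an assumption based on the file list), a programmatic run looks roughly like this, with results appended to pdsh_results.jsonl by default:

# Hypothetical driver script. The location of PDSHQueries inside
# cudf_polars/experimental/benchmarks/pdsh.py is assumed, not shown above.
from cudf_polars.experimental.benchmarks.pdsh import PDSHQueries
from cudf_polars.experimental.benchmarks.utils import run_polars

# Run queries 1-3 twice each with the streaming executor; for PDS-H datasets
# the scale factor is inferred from the supplier table's row count.
run_polars(
    PDSHQueries,
    options=[
        "1-3",
        "--path", "/data/pdsh/scale-10",
        "--executor", "streaming",
        "--iterations", "2",
    ],
)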