cudf-polars-cu12 25.2.2__py3-none-any.whl → 25.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. cudf_polars/VERSION +1 -1
  2. cudf_polars/callback.py +82 -65
  3. cudf_polars/containers/column.py +138 -7
  4. cudf_polars/containers/dataframe.py +26 -39
  5. cudf_polars/dsl/expr.py +3 -1
  6. cudf_polars/dsl/expressions/aggregation.py +27 -63
  7. cudf_polars/dsl/expressions/base.py +40 -72
  8. cudf_polars/dsl/expressions/binaryop.py +5 -41
  9. cudf_polars/dsl/expressions/boolean.py +25 -53
  10. cudf_polars/dsl/expressions/datetime.py +97 -17
  11. cudf_polars/dsl/expressions/literal.py +27 -33
  12. cudf_polars/dsl/expressions/rolling.py +110 -9
  13. cudf_polars/dsl/expressions/selection.py +8 -26
  14. cudf_polars/dsl/expressions/slicing.py +47 -0
  15. cudf_polars/dsl/expressions/sorting.py +5 -18
  16. cudf_polars/dsl/expressions/string.py +33 -36
  17. cudf_polars/dsl/expressions/ternary.py +3 -10
  18. cudf_polars/dsl/expressions/unary.py +35 -75
  19. cudf_polars/dsl/ir.py +749 -212
  20. cudf_polars/dsl/nodebase.py +8 -1
  21. cudf_polars/dsl/to_ast.py +5 -3
  22. cudf_polars/dsl/translate.py +319 -171
  23. cudf_polars/dsl/utils/__init__.py +8 -0
  24. cudf_polars/dsl/utils/aggregations.py +292 -0
  25. cudf_polars/dsl/utils/groupby.py +97 -0
  26. cudf_polars/dsl/utils/naming.py +34 -0
  27. cudf_polars/dsl/utils/replace.py +46 -0
  28. cudf_polars/dsl/utils/rolling.py +113 -0
  29. cudf_polars/dsl/utils/windows.py +186 -0
  30. cudf_polars/experimental/base.py +17 -19
  31. cudf_polars/experimental/benchmarks/__init__.py +4 -0
  32. cudf_polars/experimental/benchmarks/pdsh.py +1279 -0
  33. cudf_polars/experimental/dask_registers.py +196 -0
  34. cudf_polars/experimental/distinct.py +174 -0
  35. cudf_polars/experimental/explain.py +127 -0
  36. cudf_polars/experimental/expressions.py +521 -0
  37. cudf_polars/experimental/groupby.py +288 -0
  38. cudf_polars/experimental/io.py +58 -29
  39. cudf_polars/experimental/join.py +353 -0
  40. cudf_polars/experimental/parallel.py +166 -93
  41. cudf_polars/experimental/repartition.py +69 -0
  42. cudf_polars/experimental/scheduler.py +155 -0
  43. cudf_polars/experimental/select.py +92 -7
  44. cudf_polars/experimental/shuffle.py +294 -0
  45. cudf_polars/experimental/sort.py +45 -0
  46. cudf_polars/experimental/spilling.py +151 -0
  47. cudf_polars/experimental/utils.py +100 -0
  48. cudf_polars/testing/asserts.py +146 -6
  49. cudf_polars/testing/io.py +72 -0
  50. cudf_polars/testing/plugin.py +78 -76
  51. cudf_polars/typing/__init__.py +59 -6
  52. cudf_polars/utils/config.py +353 -0
  53. cudf_polars/utils/conversion.py +40 -0
  54. cudf_polars/utils/dtypes.py +22 -5
  55. cudf_polars/utils/timer.py +39 -0
  56. cudf_polars/utils/versions.py +5 -4
  57. {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info}/METADATA +10 -7
  58. cudf_polars_cu12-25.6.0.dist-info/RECORD +73 -0
  59. {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info}/WHEEL +1 -1
  60. cudf_polars/experimental/dask_serialize.py +0 -59
  61. cudf_polars_cu12-25.2.2.dist-info/RECORD +0 -48
  62. {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info/licenses}/LICENSE +0 -0
  63. {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1279 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """
5
+ Experimental PDS-H benchmarks.
6
+
7
+ Based on https://github.com/pola-rs/polars-benchmark.
8
+
9
+ WARNING: This is an experimental (and unofficial)
10
+ benchmark script. It is not intended for public use
11
+ and may be modified or removed at any time.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import argparse
17
+ import dataclasses
18
+ import importlib
19
+ import json
20
+ import os
21
+ import sys
22
+ import time
23
+ from collections import defaultdict
24
+ from datetime import date, datetime, timezone
25
+ from typing import TYPE_CHECKING, Any
26
+
27
+ import numpy as np
28
+
29
+ import polars as pl
30
+
31
+ try:
32
+ import pynvml
33
+ except ImportError:
34
+ pynvml = None
35
+
36
+ try:
37
+ from cudf_polars.dsl.translate import Translator
38
+ from cudf_polars.experimental.explain import explain_query
39
+ from cudf_polars.experimental.parallel import evaluate_streaming
40
+
41
+ CUDF_POLARS_AVAILABLE = True
42
+ except ImportError:
43
+ CUDF_POLARS_AVAILABLE = False
44
+
45
+ if TYPE_CHECKING:
46
+ import pathlib
47
+
48
+
49
# Without this setting, the first IO task to run
# on each worker takes ~15 sec extra
# Note: existing environment values always win — these are setdefault-style
# assignments, applied only when the variable is not already set.
os.environ["KVIKIO_COMPAT_MODE"] = os.environ.get("KVIKIO_COMPAT_MODE", "on")
# Presumably the KvikIO thread-pool size; default of 8 applied only when unset.
os.environ["KVIKIO_NTHREADS"] = os.environ.get("KVIKIO_NTHREADS", "8")
53
+
54
+
55
@dataclasses.dataclass
class Record:
    """Results for a single run of a single PDS-H query."""

    # PDS-H query number this record belongs to.
    query: int
    # Elapsed time of the run; presumably seconds — set by the benchmark
    # driver, which is not visible in this block.
    duration: float
61
+
62
+
63
@dataclasses.dataclass
class PackageVersions:
    """Information about the versions of the software used to run the query."""

    cudf_polars: str
    polars: str
    python: str
    rapidsmpf: str | None

    @classmethod
    def collect(cls) -> PackageVersions:
        """Collect the versions of the software used to run the query."""

        def _version_of(pkg: str) -> str | None:
            # A package may be absent (ImportError) or expose no
            # __version__ attribute (AttributeError); both map to None.
            try:
                return importlib.import_module(pkg).__version__
            except (AttributeError, ImportError):
                return None

        info = {
            pkg: _version_of(pkg)
            for pkg in ("cudf_polars", "polars", "rapidsmpf")
        }
        info["python"] = ".".join(str(part) for part in sys.version_info[:3])
        return cls(**info)
89
+
90
+
91
@dataclasses.dataclass
class GPUInfo:
    """Information about a specific GPU."""

    name: str
    index: int
    free_memory: int
    used_memory: int
    total_memory: int

    @classmethod
    def from_index(cls, index: int) -> GPUInfo:
        """Create a GPUInfo from an index."""
        # NVML init is presumably refcounted/idempotent, so calling it once
        # per queried device is harmless — TODO confirm against pynvml docs.
        pynvml.nvmlInit()
        dev = pynvml.nvmlDeviceGetHandleByIndex(index)
        mem = pynvml.nvmlDeviceGetMemoryInfo(dev)
        return cls(
            name=pynvml.nvmlDeviceGetName(dev),
            index=index,
            free_memory=mem.free,
            used_memory=mem.used,
            total_memory=mem.total,
        )
114
+
115
+
116
@dataclasses.dataclass
class HardwareInfo:
    """Information about the hardware used to run the query."""

    gpus: list[GPUInfo]
    # TODO: ucx

    @classmethod
    def collect(cls) -> HardwareInfo:
        """Collect the hardware information."""
        if pynvml is None:
            # No GPUs -- probably running in CPU mode
            return cls(gpus=[])
        pynvml.nvmlInit()
        device_count = pynvml.nvmlDeviceGetCount()
        return cls(gpus=[GPUInfo.from_index(i) for i in range(device_count)])
133
+
134
+
135
@dataclasses.dataclass(kw_only=True)
class RunConfig:
    """Results for a PDS-H query run."""

    # Query numbers to execute.
    queries: list[int]
    # Filename suffix appended to table names when scanning parquet data.
    suffix: str
    # Execution backend name (e.g. "streaming", "in-memory", "cpu").
    executor: str
    # Scheduler name; forced to None for in-memory/CPU executors.
    scheduler: str
    n_workers: int
    versions: PackageVersions = dataclasses.field(
        default_factory=PackageVersions.collect
    )
    # Per-query timing records, keyed by query number.
    records: dict[int, list[Record]] = dataclasses.field(default_factory=dict)
    dataset_path: pathlib.Path
    shuffle: str | None = None
    broadcast_join_limit: int | None = None
    blocksize: int | None = None
    threads: int
    iterations: int
    # ISO-8601 UTC timestamp of when this config was created.
    timestamp: str = dataclasses.field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat()
    )
    hardware: HardwareInfo = dataclasses.field(default_factory=HardwareInfo.collect)
    rapidsmpf_spill: bool
    spill_device: float

    @classmethod
    def from_args(cls, args: argparse.Namespace) -> RunConfig:
        """Create a RunConfig from command line arguments."""
        executor = args.executor
        scheduler = args.scheduler

        # These executors run without a scheduler.
        if executor in ("in-memory", "cpu"):
            scheduler = None

        return cls(
            queries=args.query,
            executor=executor,
            scheduler=scheduler,
            n_workers=args.n_workers,
            shuffle=args.shuffle,
            broadcast_join_limit=args.broadcast_join_limit,
            dataset_path=args.path,
            blocksize=args.blocksize,
            threads=args.threads,
            iterations=args.iterations,
            suffix=args.suffix,
            spill_device=args.spill_device,
            rapidsmpf_spill=args.rapidsmpf_spill,
        )

    def serialize(self) -> dict:
        """Serialize the run config to a dictionary."""
        return dataclasses.asdict(self)

    def summarize(self) -> None:
        """Print a summary of the results."""
        print("Iteration Summary")
        print("=======================================")

        for query, records in self.records.items():
            print(f"query: {query}")
            print(f"path: {self.dataset_path}")
            print(f"executor: {self.executor}")
            if self.executor == "streaming":
                print(f"scheduler: {self.scheduler}")
                print(f"blocksize: {self.blocksize}")
                print(f"shuffle_method: {self.shuffle}")
                print(f"broadcast_join_limit: {self.broadcast_join_limit}")
                if self.scheduler == "distributed":
                    print(f"n_workers: {self.n_workers}")
                print(f"threads: {self.threads}")
                print(f"spill_device: {self.spill_device}")
                print(f"rapidsmpf_spill: {self.rapidsmpf_spill}")
            if records:
                # Materialize the durations once instead of rebuilding the
                # list for each of min/max/mean.
                durations = [record.duration for record in records]
                print(f"iterations: {self.iterations}")
                print("---------------------------------------")
                print(f"min time : {min(durations):0.4f}")
                print(f"max time : {max(durations):0.4f}")
                print(f"mean time: {np.mean(durations):0.4f}")
            print("=======================================")
218
+
219
+
220
def get_data(
    path: str | pathlib.Path, table_name: str, suffix: str = ""
) -> pl.LazyFrame:
    """Get table from dataset."""
    # Lazily scan the parquet file(s) for this table; `suffix` lets callers
    # select e.g. a partitioned/glob variant of the table files.
    source = f"{path}/{table_name}{suffix}"
    return pl.scan_parquet(source)
225
+
226
+
227
+ class PDSHQueries:
228
+ """PDS-H query definitions."""
229
+
230
+ @staticmethod
231
+ def q0(run_config: RunConfig) -> pl.LazyFrame:
232
+ """Query 0."""
233
+ return pl.LazyFrame()
234
+
235
    @staticmethod
    def q1(run_config: RunConfig) -> pl.LazyFrame:
        """Query 1: per (returnflag, linestatus) pricing aggregates over shipped lineitems."""
        lineitem = get_data(run_config.dataset_path, "lineitem", run_config.suffix)

        # Ship-date cutoff for the scan.
        var1 = date(1998, 9, 2)

        return (
            lineitem.filter(pl.col("l_shipdate") <= var1)
            .group_by("l_returnflag", "l_linestatus")
            .agg(
                pl.sum("l_quantity").alias("sum_qty"),
                pl.sum("l_extendedprice").alias("sum_base_price"),
                # Discounted price: extendedprice * (1 - discount).
                (pl.col("l_extendedprice") * (1.0 - pl.col("l_discount")))
                .sum()
                .alias("sum_disc_price"),
                # Charge: discounted price with tax applied.
                (
                    pl.col("l_extendedprice")
                    * (1.0 - pl.col("l_discount"))
                    * (1.0 + pl.col("l_tax"))
                )
                .sum()
                .alias("sum_charge"),
                pl.mean("l_quantity").alias("avg_qty"),
                pl.mean("l_extendedprice").alias("avg_price"),
                pl.mean("l_discount").alias("avg_disc"),
                pl.len().alias("count_order"),
            )
            .sort("l_returnflag", "l_linestatus")
        )
265
+
266
    @staticmethod
    def q2(run_config: RunConfig) -> pl.LazyFrame:
        """Query 2: minimum-cost European BRASS suppliers for size-15 parts."""
        nation = get_data(run_config.dataset_path, "nation", run_config.suffix)
        part = get_data(run_config.dataset_path, "part", run_config.suffix)
        partsupp = get_data(run_config.dataset_path, "partsupp", run_config.suffix)
        region = get_data(run_config.dataset_path, "region", run_config.suffix)
        supplier = get_data(run_config.dataset_path, "supplier", run_config.suffix)

        var1 = 15
        var2 = "BRASS"
        var3 = "EUROPE"

        # Candidate part/supplier rows restricted to the target size, type
        # suffix, and region.
        q1 = (
            part.join(partsupp, left_on="p_partkey", right_on="ps_partkey")
            .join(supplier, left_on="ps_suppkey", right_on="s_suppkey")
            .join(nation, left_on="s_nationkey", right_on="n_nationkey")
            .join(region, left_on="n_regionkey", right_on="r_regionkey")
            .filter(pl.col("p_size") == var1)
            .filter(pl.col("p_type").str.ends_with(var2))
            .filter(pl.col("r_name") == var3)
        )

        # Find the minimum supply cost per part, then self-join back to
        # recover the full rows that achieve that minimum.
        return (
            q1.group_by("p_partkey")
            .agg(pl.min("ps_supplycost"))
            .join(q1, on=["p_partkey", "ps_supplycost"])
            .select(
                "s_acctbal",
                "s_name",
                "n_name",
                "p_partkey",
                "p_mfgr",
                "s_address",
                "s_phone",
                "s_comment",
            )
            .sort(
                by=["s_acctbal", "n_name", "s_name", "p_partkey"],
                descending=[True, False, False, False],
            )
            .head(100)
        )
309
+
310
    @staticmethod
    def q3(run_config: RunConfig) -> pl.LazyFrame:
        """Query 3: top-10 unshipped orders by revenue for one market segment."""
        customer = get_data(run_config.dataset_path, "customer", run_config.suffix)
        lineitem = get_data(run_config.dataset_path, "lineitem", run_config.suffix)
        orders = get_data(run_config.dataset_path, "orders", run_config.suffix)

        var1 = "BUILDING"
        var2 = date(1995, 3, 15)

        return (
            customer.filter(pl.col("c_mktsegment") == var1)
            .join(orders, left_on="c_custkey", right_on="o_custkey")
            .join(lineitem, left_on="o_orderkey", right_on="l_orderkey")
            # Orders placed before the cutoff whose items shipped after it.
            .filter(pl.col("o_orderdate") < var2)
            .filter(pl.col("l_shipdate") > var2)
            .with_columns(
                (pl.col("l_extendedprice") * (1 - pl.col("l_discount"))).alias(
                    "revenue"
                )
            )
            .group_by("o_orderkey", "o_orderdate", "o_shippriority")
            .agg(pl.sum("revenue"))
            .select(
                pl.col("o_orderkey").alias("l_orderkey"),
                "revenue",
                "o_orderdate",
                "o_shippriority",
            )
            .sort(by=["revenue", "o_orderdate"], descending=[True, False])
            .head(10)
        )
342
+
343
    @staticmethod
    def q4(run_config: RunConfig) -> pl.LazyFrame:
        """Query 4: count orders per priority with at least one late lineitem."""
        lineitem = get_data(run_config.dataset_path, "lineitem", run_config.suffix)
        orders = get_data(run_config.dataset_path, "orders", run_config.suffix)

        var1 = date(1993, 7, 1)
        var2 = date(1993, 10, 1)

        return (
            # SQL exists translates to semi join in Polars API
            orders.join(
                (lineitem.filter(pl.col("l_commitdate") < pl.col("l_receiptdate"))),
                left_on="o_orderkey",
                right_on="l_orderkey",
                how="semi",
            )
            # Quarter window: [var1, var2).
            .filter(pl.col("o_orderdate").is_between(var1, var2, closed="left"))
            .group_by("o_orderpriority")
            .agg(pl.len().alias("order_count"))
            .sort("o_orderpriority")
        )
365
+
366
    @staticmethod
    def q5(run_config: RunConfig) -> pl.LazyFrame:
        """Query 5: revenue per nation for local suppliers in one region/year."""
        path = run_config.dataset_path
        suffix = run_config.suffix
        customer = get_data(path, "customer", suffix)
        lineitem = get_data(path, "lineitem", suffix)
        nation = get_data(path, "nation", suffix)
        orders = get_data(path, "orders", suffix)
        region = get_data(path, "region", suffix)
        supplier = get_data(path, "supplier", suffix)

        var1 = "ASIA"
        var2 = date(1994, 1, 1)
        var3 = date(1995, 1, 1)

        return (
            region.join(nation, left_on="r_regionkey", right_on="n_regionkey")
            .join(customer, left_on="n_nationkey", right_on="c_nationkey")
            .join(orders, left_on="c_custkey", right_on="o_custkey")
            .join(lineitem, left_on="o_orderkey", right_on="l_orderkey")
            # Two-key join: supplier must be in the same nation as the customer.
            .join(
                supplier,
                left_on=["l_suppkey", "n_nationkey"],
                right_on=["s_suppkey", "s_nationkey"],
            )
            .filter(pl.col("r_name") == var1)
            .filter(pl.col("o_orderdate").is_between(var2, var3, closed="left"))
            .with_columns(
                (pl.col("l_extendedprice") * (1 - pl.col("l_discount"))).alias(
                    "revenue"
                )
            )
            .group_by("n_name")
            .agg(pl.sum("revenue"))
            .sort(by="revenue", descending=True)
        )
403
+
404
    @staticmethod
    def q6(run_config: RunConfig) -> pl.LazyFrame:
        """Query 6: total revenue from discounted small-quantity lineitems in one year."""
        path = run_config.dataset_path
        suffix = run_config.suffix
        lineitem = get_data(path, "lineitem", suffix)

        var1 = date(1994, 1, 1)
        var2 = date(1995, 1, 1)
        var3 = 0.05
        var4 = 0.07
        var5 = 24

        return (
            # Ship date in [var1, var2); discount in [var3, var4] (default
            # closed="both"); quantity strictly below var5.
            lineitem.filter(pl.col("l_shipdate").is_between(var1, var2, closed="left"))
            .filter(pl.col("l_discount").is_between(var3, var4))
            .filter(pl.col("l_quantity") < var5)
            .with_columns(
                (pl.col("l_extendedprice") * pl.col("l_discount")).alias("revenue")
            )
            .select(pl.sum("revenue"))
        )
426
+
427
    @staticmethod
    def q7(run_config: RunConfig) -> pl.LazyFrame:
        """Query 7: shipping volume between two nations, per direction and year."""
        customer = get_data(run_config.dataset_path, "customer", run_config.suffix)
        lineitem = get_data(run_config.dataset_path, "lineitem", run_config.suffix)
        nation = get_data(run_config.dataset_path, "nation", run_config.suffix)
        orders = get_data(run_config.dataset_path, "orders", run_config.suffix)
        supplier = get_data(run_config.dataset_path, "supplier", run_config.suffix)

        var1 = "FRANCE"
        var2 = "GERMANY"
        var3 = date(1995, 1, 1)
        var4 = date(1996, 12, 31)

        n1 = nation.filter(pl.col("n_name") == var1)
        n2 = nation.filter(pl.col("n_name") == var2)

        # Direction 1: customer in n1, supplier in n2.
        q1 = (
            customer.join(n1, left_on="c_nationkey", right_on="n_nationkey")
            .join(orders, left_on="c_custkey", right_on="o_custkey")
            .rename({"n_name": "cust_nation"})
            .join(lineitem, left_on="o_orderkey", right_on="l_orderkey")
            .join(supplier, left_on="l_suppkey", right_on="s_suppkey")
            .join(n2, left_on="s_nationkey", right_on="n_nationkey")
            .rename({"n_name": "supp_nation"})
        )

        # Direction 2: the mirror image (customer in n2, supplier in n1).
        q2 = (
            customer.join(n2, left_on="c_nationkey", right_on="n_nationkey")
            .join(orders, left_on="c_custkey", right_on="o_custkey")
            .rename({"n_name": "cust_nation"})
            .join(lineitem, left_on="o_orderkey", right_on="l_orderkey")
            .join(supplier, left_on="l_suppkey", right_on="s_suppkey")
            .join(n1, left_on="s_nationkey", right_on="n_nationkey")
            .rename({"n_name": "supp_nation"})
        )

        return (
            pl.concat([q1, q2])
            .filter(pl.col("l_shipdate").is_between(var3, var4))
            .with_columns(
                (pl.col("l_extendedprice") * (1 - pl.col("l_discount"))).alias(
                    "volume"
                ),
                pl.col("l_shipdate").dt.year().alias("l_year"),
            )
            .group_by("supp_nation", "cust_nation", "l_year")
            .agg(pl.sum("volume").alias("revenue"))
            .sort(by=["supp_nation", "cust_nation", "l_year"])
        )
477
+
478
    @staticmethod
    def q8(run_config: RunConfig) -> pl.LazyFrame:
        """Query 8: one nation's yearly market share of a part type in a region."""
        customer = get_data(run_config.dataset_path, "customer", run_config.suffix)
        lineitem = get_data(run_config.dataset_path, "lineitem", run_config.suffix)
        nation = get_data(run_config.dataset_path, "nation", run_config.suffix)
        orders = get_data(run_config.dataset_path, "orders", run_config.suffix)
        part = get_data(run_config.dataset_path, "part", run_config.suffix)
        region = get_data(run_config.dataset_path, "region", run_config.suffix)
        supplier = get_data(run_config.dataset_path, "supplier", run_config.suffix)

        var1 = "BRAZIL"
        var2 = "AMERICA"
        var3 = "ECONOMY ANODIZED STEEL"
        var4 = date(1995, 1, 1)
        var5 = date(1996, 12, 31)

        # Two projections of nation: n1 links customers to regions,
        # n2 provides the supplier's nation name.
        n1 = nation.select("n_nationkey", "n_regionkey")
        n2 = nation.select("n_nationkey", "n_name")

        return (
            part.join(lineitem, left_on="p_partkey", right_on="l_partkey")
            .join(supplier, left_on="l_suppkey", right_on="s_suppkey")
            .join(orders, left_on="l_orderkey", right_on="o_orderkey")
            .join(customer, left_on="o_custkey", right_on="c_custkey")
            .join(n1, left_on="c_nationkey", right_on="n_nationkey")
            .join(region, left_on="n_regionkey", right_on="r_regionkey")
            .filter(pl.col("r_name") == var2)
            .join(n2, left_on="s_nationkey", right_on="n_nationkey")
            .filter(pl.col("o_orderdate").is_between(var4, var5))
            .filter(pl.col("p_type") == var3)
            .select(
                pl.col("o_orderdate").dt.year().alias("o_year"),
                (pl.col("l_extendedprice") * (1 - pl.col("l_discount"))).alias(
                    "volume"
                ),
                pl.col("n_name").alias("nation"),
            )
            # Volume attributable to var1 only; 0 for every other nation.
            .with_columns(
                pl.when(pl.col("nation") == var1)
                .then(pl.col("volume"))
                .otherwise(0)
                .alias("_tmp")
            )
            .group_by("o_year")
            .agg((pl.sum("_tmp") / pl.sum("volume")).round(2).alias("mkt_share"))
            .sort("o_year")
        )
526
+
527
    @staticmethod
    def q9(run_config: RunConfig) -> pl.LazyFrame:
        """Query 9: profit per nation and year on parts whose name contains 'green'."""
        path = run_config.dataset_path
        suffix = run_config.suffix
        lineitem = get_data(path, "lineitem", suffix)
        nation = get_data(path, "nation", suffix)
        orders = get_data(path, "orders", suffix)
        part = get_data(path, "part", suffix)
        partsupp = get_data(path, "partsupp", suffix)
        supplier = get_data(path, "supplier", suffix)

        return (
            part.join(partsupp, left_on="p_partkey", right_on="ps_partkey")
            .join(supplier, left_on="ps_suppkey", right_on="s_suppkey")
            .join(
                lineitem,
                left_on=["p_partkey", "ps_suppkey"],
                right_on=["l_partkey", "l_suppkey"],
            )
            .join(orders, left_on="l_orderkey", right_on="o_orderkey")
            .join(nation, left_on="s_nationkey", right_on="n_nationkey")
            .filter(pl.col("p_name").str.contains("green"))
            .select(
                pl.col("n_name").alias("nation"),
                pl.col("o_orderdate").dt.year().alias("o_year"),
                # Profit = discounted revenue minus supply cost.
                (
                    pl.col("l_extendedprice") * (1 - pl.col("l_discount"))
                    - pl.col("ps_supplycost") * pl.col("l_quantity")
                ).alias("amount"),
            )
            .group_by("nation", "o_year")
            .agg(pl.sum("amount").round(2).alias("sum_profit"))
            .sort(by=["nation", "o_year"], descending=[False, True])
        )
562
+
563
    @staticmethod
    def q10(run_config: RunConfig) -> pl.LazyFrame:
        """Query 10: top-20 customers by revenue lost to returned items in one quarter."""
        path = run_config.dataset_path
        suffix = run_config.suffix
        customer = get_data(path, "customer", suffix)
        lineitem = get_data(path, "lineitem", suffix)
        nation = get_data(path, "nation", suffix)
        orders = get_data(path, "orders", suffix)

        var1 = date(1993, 10, 1)
        var2 = date(1994, 1, 1)

        return (
            customer.join(orders, left_on="c_custkey", right_on="o_custkey")
            .join(lineitem, left_on="o_orderkey", right_on="l_orderkey")
            .join(nation, left_on="c_nationkey", right_on="n_nationkey")
            .filter(pl.col("o_orderdate").is_between(var1, var2, closed="left"))
            # Returned items only.
            .filter(pl.col("l_returnflag") == "R")
            .group_by(
                "c_custkey",
                "c_name",
                "c_acctbal",
                "c_phone",
                "n_name",
                "c_address",
                "c_comment",
            )
            .agg(
                (pl.col("l_extendedprice") * (1 - pl.col("l_discount")))
                .sum()
                .round(2)
                .alias("revenue")
            )
            .select(
                "c_custkey",
                "c_name",
                "revenue",
                "c_acctbal",
                "n_name",
                "c_address",
                "c_phone",
                "c_comment",
            )
            .sort(by="revenue", descending=True)
            .head(20)
        )
610
+
611
    @staticmethod
    def q11(run_config: RunConfig) -> pl.LazyFrame:
        """Query 11: partsupp stock value per part exceeding a fraction of one nation's total."""
        nation = get_data(run_config.dataset_path, "nation", run_config.suffix)
        partsupp = get_data(run_config.dataset_path, "partsupp", run_config.suffix)
        supplier = get_data(run_config.dataset_path, "supplier", run_config.suffix)

        var1 = "GERMANY"
        var2 = 0.0001

        # All partsupp rows for suppliers in the target nation.
        q1 = (
            partsupp.join(supplier, left_on="ps_suppkey", right_on="s_suppkey")
            .join(nation, left_on="s_nationkey", right_on="n_nationkey")
            .filter(pl.col("n_name") == var1)
        )
        # Scalar threshold: var2 times the total stock value of that nation.
        q2 = q1.select(
            (pl.col("ps_supplycost") * pl.col("ps_availqty"))
            .sum()
            .round(2)
            .alias("tmp")
            * var2
        )

        return (
            q1.group_by("ps_partkey")
            .agg(
                (pl.col("ps_supplycost") * pl.col("ps_availqty"))
                .sum()
                .round(2)
                .alias("value")
            )
            # Cross join broadcasts the single-row threshold to every part.
            .join(q2, how="cross")
            .filter(pl.col("value") > pl.col("tmp"))
            .select("ps_partkey", "value")
            .sort("value", descending=True)
        )
647
+
648
    @staticmethod
    def q12(run_config: RunConfig) -> pl.LazyFrame:
        """Query 12: late-delivery counts per ship mode, split by order priority."""
        lineitem = get_data(run_config.dataset_path, "lineitem", run_config.suffix)
        orders = get_data(run_config.dataset_path, "orders", run_config.suffix)

        var1 = "MAIL"
        var2 = "SHIP"
        var3 = date(1994, 1, 1)
        var4 = date(1995, 1, 1)

        return (
            orders.join(lineitem, left_on="o_orderkey", right_on="l_orderkey")
            .filter(pl.col("l_shipmode").is_in([var1, var2]))
            # Items that were committed before receipt but shipped before commit.
            .filter(pl.col("l_commitdate") < pl.col("l_receiptdate"))
            .filter(pl.col("l_shipdate") < pl.col("l_commitdate"))
            .filter(pl.col("l_receiptdate").is_between(var3, var4, closed="left"))
            .with_columns(
                # 1/0 indicator columns so the counts can be taken as sums.
                pl.when(pl.col("o_orderpriority").is_in(["1-URGENT", "2-HIGH"]))
                .then(1)
                .otherwise(0)
                .alias("high_line_count"),
                pl.when(pl.col("o_orderpriority").is_in(["1-URGENT", "2-HIGH"]).not_())
                .then(1)
                .otherwise(0)
                .alias("low_line_count"),
            )
            .group_by("l_shipmode")
            .agg(pl.col("high_line_count").sum(), pl.col("low_line_count").sum())
            .sort("l_shipmode")
        )
679
+
680
    @staticmethod
    def q13(run_config: RunConfig) -> pl.LazyFrame:
        """Query 13: distribution of customers by number of (non-special-request) orders."""
        customer = get_data(run_config.dataset_path, "customer", run_config.suffix)
        orders = get_data(run_config.dataset_path, "orders", run_config.suffix)

        var1 = "special"
        var2 = "requests"

        # Drop orders whose comment matches "special.*requests".
        orders = orders.filter(
            pl.col("o_comment").str.contains(f"{var1}.*{var2}").not_()
        )
        return (
            # Left join keeps customers with zero remaining orders;
            # count() ignores the resulting nulls so they count as 0.
            customer.join(orders, left_on="c_custkey", right_on="o_custkey", how="left")
            .group_by("c_custkey")
            .agg(pl.col("o_orderkey").count().alias("c_count"))
            .group_by("c_count")
            .len()
            .select(pl.col("c_count"), pl.col("len").alias("custdist"))
            .sort(by=["custdist", "c_count"], descending=[True, True])
        )
701
+
702
    @staticmethod
    def q14(run_config: RunConfig) -> pl.LazyFrame:
        """Query 14: percentage of one month's revenue coming from promo parts."""
        lineitem = get_data(run_config.dataset_path, "lineitem", run_config.suffix)
        part = get_data(run_config.dataset_path, "part", run_config.suffix)

        var1 = date(1995, 9, 1)
        var2 = date(1995, 10, 1)

        return (
            lineitem.join(part, left_on="l_partkey", right_on="p_partkey")
            .filter(pl.col("l_shipdate").is_between(var1, var2, closed="left"))
            .select(
                (
                    100.00
                    # NOTE(review): "PROMO*" is a regex ('O' repeated zero or
                    # more times), not a glob; kept as-is to match the
                    # upstream polars-benchmark definition.
                    * pl.when(pl.col("p_type").str.contains("PROMO*"))
                    .then(pl.col("l_extendedprice") * (1 - pl.col("l_discount")))
                    .otherwise(0)
                    .sum()
                    / (pl.col("l_extendedprice") * (1 - pl.col("l_discount"))).sum()
                )
                .round(2)
                .alias("promo_revenue")
            )
        )
727
+
728
    @staticmethod
    def q15(run_config: RunConfig) -> pl.LazyFrame:
        """Query 15: supplier(s) with the maximum revenue in one quarter."""
        lineitem = get_data(run_config.dataset_path, "lineitem", run_config.suffix)
        supplier = get_data(run_config.dataset_path, "supplier", run_config.suffix)

        var1 = date(1996, 1, 1)
        var2 = date(1996, 4, 1)

        # Revenue per supplier over the quarter [var1, var2).
        revenue = (
            lineitem.filter(pl.col("l_shipdate").is_between(var1, var2, closed="left"))
            .group_by("l_suppkey")
            .agg(
                (pl.col("l_extendedprice") * (1 - pl.col("l_discount")))
                .sum()
                .alias("total_revenue")
            )
            .select(pl.col("l_suppkey").alias("supplier_no"), pl.col("total_revenue"))
        )

        return (
            supplier.join(revenue, left_on="s_suppkey", right_on="supplier_no")
            # Keep only the supplier(s) matching the global maximum.
            .filter(pl.col("total_revenue") == pl.col("total_revenue").max())
            .with_columns(pl.col("total_revenue").round(2))
            .select("s_suppkey", "s_name", "s_address", "s_phone", "total_revenue")
            .sort("s_suppkey")
        )
755
+
756
    @staticmethod
    def q16(run_config: RunConfig) -> pl.LazyFrame:
        """Query 16: supplier counts per brand/type/size, excluding complaint suppliers."""
        part = get_data(run_config.dataset_path, "part", run_config.suffix)
        partsupp = get_data(run_config.dataset_path, "partsupp", run_config.suffix)
        supplier = get_data(run_config.dataset_path, "supplier", run_config.suffix)

        var1 = "Brand#45"

        # Suppliers with customer complaints; the duplicated key column lets
        # the anti-join below be expressed as a left join + null filter.
        supplier = supplier.filter(
            pl.col("s_comment").str.contains(".*Customer.*Complaints.*")
        ).select(pl.col("s_suppkey"), pl.col("s_suppkey").alias("ps_suppkey"))

        return (
            part.join(partsupp, left_on="p_partkey", right_on="ps_partkey")
            .filter(pl.col("p_brand") != var1)
            # NOTE(review): "MEDIUM POLISHED*" is a regex ('D' repeated), not
            # a glob; kept as-is to match the upstream polars-benchmark.
            .filter(pl.col("p_type").str.contains("MEDIUM POLISHED*").not_())
            .filter(pl.col("p_size").is_in([49, 14, 23, 45, 19, 3, 36, 9]))
            .join(supplier, left_on="ps_suppkey", right_on="s_suppkey", how="left")
            # Null right key == no complaint-supplier match (anti-join).
            .filter(pl.col("ps_suppkey_right").is_null())
            .group_by("p_brand", "p_type", "p_size")
            .agg(pl.col("ps_suppkey").n_unique().alias("supplier_cnt"))
            .sort(
                by=["supplier_cnt", "p_brand", "p_type", "p_size"],
                descending=[True, False, False, False],
            )
        )
783
+
784
    @staticmethod
    def q17(run_config: RunConfig) -> pl.LazyFrame:
        """Query 17: average yearly revenue from small-quantity orders of one brand/container."""
        lineitem = get_data(run_config.dataset_path, "lineitem", run_config.suffix)
        part = get_data(run_config.dataset_path, "part", run_config.suffix)

        var1 = "Brand#23"
        var2 = "MED BOX"

        q1 = (
            part.filter(pl.col("p_brand") == var1)
            .filter(pl.col("p_container") == var2)
            .join(lineitem, how="left", left_on="p_partkey", right_on="l_partkey")
        )

        return (
            # Per-part threshold: 20% of the mean ordered quantity.
            q1.group_by("p_partkey")
            .agg((0.2 * pl.col("l_quantity").mean()).alias("avg_quantity"))
            .select(pl.col("p_partkey").alias("key"), pl.col("avg_quantity"))
            .join(q1, left_on="key", right_on="p_partkey")
            .filter(pl.col("l_quantity") < pl.col("avg_quantity"))
            .select(
                (pl.col("l_extendedprice").sum() / 7.0).round(2).alias("avg_yearly")
            )
        )
809
+
810
    @staticmethod
    def q18(run_config: RunConfig) -> pl.LazyFrame:
        """Query 18: top-100 large-volume orders (total quantity above a threshold)."""
        path = run_config.dataset_path
        suffix = run_config.suffix
        customer = get_data(path, "customer", suffix)
        lineitem = get_data(path, "lineitem", suffix)
        orders = get_data(path, "orders", suffix)

        var1 = 300

        # Orders whose total lineitem quantity exceeds the threshold.
        q1 = (
            lineitem.group_by("l_orderkey")
            .agg(pl.col("l_quantity").sum().alias("sum_quantity"))
            .filter(pl.col("sum_quantity") > var1)
        )

        return (
            # Semi join keeps qualifying orders without adding q1's columns.
            orders.join(q1, left_on="o_orderkey", right_on="l_orderkey", how="semi")
            .join(lineitem, left_on="o_orderkey", right_on="l_orderkey")
            .join(customer, left_on="o_custkey", right_on="c_custkey")
            .group_by(
                "c_name", "o_custkey", "o_orderkey", "o_orderdate", "o_totalprice"
            )
            .agg(pl.col("l_quantity").sum().alias("col6"))
            .select(
                pl.col("c_name"),
                pl.col("o_custkey").alias("c_custkey"),
                pl.col("o_orderkey"),
                pl.col("o_orderdate").alias("o_orderdat"),
                pl.col("o_totalprice"),
                pl.col("col6"),
            )
            .sort(by=["o_totalprice", "o_orderdat"], descending=[True, False])
            .head(100)
        )
846
+
847
    @staticmethod
    def q19(run_config: RunConfig) -> pl.LazyFrame:
        """Query 19: revenue from air-shipped parts matching three brand/container/size tiers."""
        lineitem = get_data(run_config.dataset_path, "lineitem", run_config.suffix)
        part = get_data(run_config.dataset_path, "part", run_config.suffix)

        return (
            part.join(lineitem, left_on="p_partkey", right_on="l_partkey")
            .filter(pl.col("l_shipmode").is_in(["AIR", "AIR REG"]))
            .filter(pl.col("l_shipinstruct") == "DELIVER IN PERSON")
            # Three disjunctive tiers, each constraining brand, container,
            # quantity range, and part size.
            .filter(
                (
                    (pl.col("p_brand") == "Brand#12")
                    & pl.col("p_container").is_in(
                        ["SM CASE", "SM BOX", "SM PACK", "SM PKG"]
                    )
                    & (pl.col("l_quantity").is_between(1, 11))
                    & (pl.col("p_size").is_between(1, 5))
                )
                | (
                    (pl.col("p_brand") == "Brand#23")
                    & pl.col("p_container").is_in(
                        ["MED BAG", "MED BOX", "MED PKG", "MED PACK"]
                    )
                    & (pl.col("l_quantity").is_between(10, 20))
                    & (pl.col("p_size").is_between(1, 10))
                )
                | (
                    (pl.col("p_brand") == "Brand#34")
                    & pl.col("p_container").is_in(
                        ["LG CASE", "LG BOX", "LG PACK", "LG PKG"]
                    )
                    & (pl.col("l_quantity").is_between(20, 30))
                    & (pl.col("p_size").is_between(1, 15))
                )
            )
            .select(
                (pl.col("l_extendedprice") * (1 - pl.col("l_discount")))
                .sum()
                .round(2)
                .alias("revenue")
            )
        )
890
+
891
    @staticmethod
    def q20(run_config: RunConfig) -> pl.LazyFrame:
        """Query 20: potential part promotion.

        Lists (sorted by name) suppliers in CANADA that stock "forest"-prefixed
        parts with availability exceeding half of the quantity they shipped
        during 1994.
        """
        lineitem = get_data(run_config.dataset_path, "lineitem", run_config.suffix)
        nation = get_data(run_config.dataset_path, "nation", run_config.suffix)
        part = get_data(run_config.dataset_path, "part", run_config.suffix)
        partsupp = get_data(run_config.dataset_path, "partsupp", run_config.suffix)
        supplier = get_data(run_config.dataset_path, "supplier", run_config.suffix)

        # Query parameters: ship-date window [var1, var2), nation, part-name prefix.
        var1 = date(1994, 1, 1)
        var2 = date(1995, 1, 1)
        var3 = "CANADA"
        var4 = "forest"

        # Half of the shipped quantity per (part, supplier) within the window;
        # closed="left" keeps var1 and excludes var2.
        q1 = (
            lineitem.filter(pl.col("l_shipdate").is_between(var1, var2, closed="left"))
            .group_by("l_partkey", "l_suppkey")
            .agg((pl.col("l_quantity").sum() * 0.5).alias("sum_quantity"))
        )
        # Suppliers located in the target nation.
        q2 = nation.filter(pl.col("n_name") == var3)
        q3 = supplier.join(q2, left_on="s_nationkey", right_on="n_nationkey")

        return (
            part.filter(pl.col("p_name").str.starts_with(var4))
            .select(pl.col("p_partkey").unique())
            .join(partsupp, left_on="p_partkey", right_on="ps_partkey")
            .join(
                q1,
                left_on=["ps_suppkey", "p_partkey"],
                right_on=["l_suppkey", "l_partkey"],
            )
            # Keep only part/supplier pairs with surplus availability.
            .filter(pl.col("ps_availqty") > pl.col("sum_quantity"))
            .select(pl.col("ps_suppkey").unique())
            .join(q3, left_on="ps_suppkey", right_on="s_suppkey")
            .select("s_name", "s_address")
            .sort("s_name")
        )
928
+
929
    @staticmethod
    def q21(run_config: RunConfig) -> pl.LazyFrame:
        """Query 21: suppliers who kept orders waiting.

        Ranks SAUDI ARABIA suppliers by the number of multi-supplier orders
        (status 'F') in which they were the only supplier to deliver late.
        """
        lineitem = get_data(run_config.dataset_path, "lineitem", run_config.suffix)
        nation = get_data(run_config.dataset_path, "nation", run_config.suffix)
        orders = get_data(run_config.dataset_path, "orders", run_config.suffix)
        supplier = get_data(run_config.dataset_path, "supplier", run_config.suffix)

        var1 = "SAUDI ARABIA"

        # Orders involving more than one supplier, joined to only their
        # *late* line items (receipt after commit date).
        q1 = (
            lineitem.group_by("l_orderkey")
            .agg(pl.col("l_suppkey").len().alias("n_supp_by_order"))
            .filter(pl.col("n_supp_by_order") > 1)
            .join(
                lineitem.filter(pl.col("l_receiptdate") > pl.col("l_commitdate")),
                on="l_orderkey",
            )
        )

        return (
            # Re-grouping q1 counts *late* suppliers per order; the alias
            # deliberately reuses "n_supp_by_order" so the filter below
            # selects orders with exactly one late supplier.
            q1.group_by("l_orderkey")
            .agg(pl.col("l_suppkey").len().alias("n_supp_by_order"))
            .join(q1, on="l_orderkey")
            .join(supplier, left_on="l_suppkey", right_on="s_suppkey")
            .join(nation, left_on="s_nationkey", right_on="n_nationkey")
            .join(orders, left_on="l_orderkey", right_on="o_orderkey")
            .filter(pl.col("n_supp_by_order") == 1)
            .filter(pl.col("n_name") == var1)
            .filter(pl.col("o_orderstatus") == "F")
            .group_by("s_name")
            .agg(pl.len().alias("numwait"))
            .sort(by=["numwait", "s_name"], descending=[True, False])
            .head(100)
        )
964
+
965
+ @staticmethod
966
+ def q22(run_config: RunConfig) -> pl.LazyFrame:
967
+ """Query 22."""
968
+ customer = get_data(run_config.dataset_path, "customer", run_config.suffix)
969
+ orders = get_data(run_config.dataset_path, "orders", run_config.suffix)
970
+
971
+ q1 = (
972
+ customer.with_columns(pl.col("c_phone").str.slice(0, 2).alias("cntrycode"))
973
+ .filter(pl.col("cntrycode").str.contains("13|31|23|29|30|18|17"))
974
+ .select("c_acctbal", "c_custkey", "cntrycode")
975
+ )
976
+
977
+ q2 = q1.filter(pl.col("c_acctbal") > 0.0).select(
978
+ pl.col("c_acctbal").mean().alias("avg_acctbal")
979
+ )
980
+
981
+ q3 = orders.select(pl.col("o_custkey").unique()).with_columns(
982
+ pl.col("o_custkey").alias("c_custkey")
983
+ )
984
+
985
+ return (
986
+ q1.join(q3, on="c_custkey", how="left")
987
+ .filter(pl.col("o_custkey").is_null())
988
+ .join(q2, how="cross")
989
+ .filter(pl.col("c_acctbal") > pl.col("avg_acctbal"))
990
+ .group_by("cntrycode")
991
+ .agg(
992
+ pl.col("c_acctbal").count().alias("numcust"),
993
+ pl.col("c_acctbal").sum().round(2).alias("totacctbal"),
994
+ )
995
+ .sort("cntrycode")
996
+ )
997
+
998
+
999
+ def _query_type(query: int | str) -> list[int]:
1000
+ if isinstance(query, int):
1001
+ return [query]
1002
+ elif query == "all":
1003
+ return list(range(1, 23))
1004
+ else:
1005
+ return [int(q) for q in query.split(",")]
1006
+
1007
+
1008
# Command-line interface for the PDS-H benchmark runner.
# NOTE: arguments are parsed at import time; `run(args)` in the
# `__main__` guard below relies on this module-level `args`.
parser = argparse.ArgumentParser(
    prog="Cudf-Polars PDS-H Benchmarks",
    description="Experimental streaming-executor benchmarks.",
)
parser.add_argument(
    "query",
    type=_query_type,
    help="Query number.",
)
parser.add_argument(
    "--path",
    type=str,
    default=os.environ.get("PDSH_DATASET_PATH"),
    help="Root PDS-H dataset directory path.",
)
parser.add_argument(
    "--suffix",
    type=str,
    default=".parquet",
    help="Table file suffix.",
)
parser.add_argument(
    "-e",
    "--executor",
    default="streaming",
    type=str,
    choices=["in-memory", "streaming", "cpu"],
    help="Executor.",
)
parser.add_argument(
    "-s",
    "--scheduler",
    default="synchronous",
    type=str,
    choices=["synchronous", "distributed"],
    help="Scheduler to use with the 'streaming' executor.",
)
parser.add_argument(
    "--n-workers",
    default=1,
    type=int,
    help="Number of Dask-CUDA workers (requires 'distributed' scheduler).",
)
parser.add_argument(
    "--blocksize",
    default=None,
    type=int,
    help="Approx. partition size.",
)
parser.add_argument(
    "--iterations",
    default=1,
    type=int,
    help="Number of times to run the same query.",
)
parser.add_argument(
    "--debug",
    default=False,
    action="store_true",
    help="Debug run.",
)
parser.add_argument(
    "--shuffle",
    default=None,
    type=str,
    choices=[None, "rapidsmpf", "tasks"],
    help="Shuffle method to use for distributed execution.",
)
parser.add_argument(
    "--broadcast-join-limit",
    default=None,
    type=int,
    help="Set an explicit `broadcast_join_limit` option.",
)
parser.add_argument(
    "--threads",
    default=1,
    type=int,
    help="Number of threads to use on each GPU.",
)
parser.add_argument(
    "--rmm-pool-size",
    default=0.5,
    type=float,
    help="RMM pool size (fractional).",
)
parser.add_argument(
    "--rmm-async",
    action=argparse.BooleanOptionalAction,
    default=False,
    help="Use RMM async memory resource.",
)
parser.add_argument(
    "--rapidsmpf-spill",
    action=argparse.BooleanOptionalAction,
    default=False,
    help="Use rapidsmpf for general spilling.",
)
parser.add_argument(
    "--spill-device",
    default=0.5,
    type=float,
    # Fixed typo in help text: "Rapdsimpf" -> "Rapidsmpf".
    help="Rapidsmpf device spill threshold.",
)
parser.add_argument(
    "-o",
    "--output",
    # "at" = append text mode, so repeated runs accumulate JSONL records.
    type=argparse.FileType("at"),
    default="pdsh_results.jsonl",
    help="Output file path.",
)
parser.add_argument(
    "--summarize",
    action=argparse.BooleanOptionalAction,
    help="Summarize the results.",
    default=True,
)
parser.add_argument(
    "--print-results",
    action=argparse.BooleanOptionalAction,
    help="Print the query results",
    default=True,
)
parser.add_argument(
    "--explain",
    action=argparse.BooleanOptionalAction,
    help="Print an outline of the physical plan",
    default=False,
)
parser.add_argument(
    "--explain-logical",
    action=argparse.BooleanOptionalAction,
    help="Print an outline of the logical plan",
    default=False,
)
args = parser.parse_args()
1144
+
1145
+
1146
def run(args: argparse.Namespace) -> None:
    """Run the benchmark.

    Builds a :class:`RunConfig` from parsed CLI ``args``, optionally spins up
    a distributed Dask-CUDA cluster, constructs a GPU engine (unless the
    ``cpu`` executor is selected), runs each requested query for the
    configured number of iterations, and appends a JSON record of the run to
    ``args.output``.
    """
    client = None
    run_config = RunConfig.from_args(args)

    if run_config.scheduler == "distributed":
        # Imported lazily so single-GPU runs do not require dask_cuda.
        from dask_cuda import LocalCUDACluster
        from distributed import Client

        kwargs = {
            "n_workers": run_config.n_workers,
            "dashboard_address": ":8585",
            "protocol": "ucxx",
            "rmm_pool_size": args.rmm_pool_size,
            "rmm_async": args.rmm_async,
            "threads_per_worker": run_config.threads,
        }

        # Avoid UVM in distributed cluster
        client = Client(LocalCUDACluster(**kwargs))
        client.wait_for_workers(run_config.n_workers)
        if run_config.shuffle != "tasks":
            # Best-effort rapidsmpf bootstrap: only a hard failure when the
            # user explicitly asked for the "rapidsmpf" shuffle method.
            try:
                from rapidsmpf.integrations.dask import bootstrap_dask_cluster

                bootstrap_dask_cluster(client, spill_device=run_config.spill_device)
            except ImportError as err:
                if run_config.shuffle == "rapidsmpf":
                    raise ImportError from err

    # Per-query timing records, keyed by query number.
    records: defaultdict[int, list[Record]] = defaultdict(list)
    engine: pl.GPUEngine | None = None

    if run_config.executor == "cpu":
        engine = None
    else:
        executor_options: dict[str, Any] = {}
        if run_config.executor == "streaming":
            # Hand-tuned cardinality hints for specific PDS-H key columns.
            executor_options = {
                "cardinality_factor": {
                    "c_custkey": 0.05,  # Q10
                    "l_orderkey": 1.0,  # Q18
                    "l_partkey": 0.1,  # Q20
                    "o_custkey": 0.25,  # Q22
                },
            }
        # Only forward options the user actually set (truthy / non-None).
        if run_config.blocksize:
            executor_options["target_partition_size"] = run_config.blocksize
        if run_config.shuffle:
            executor_options["shuffle_method"] = run_config.shuffle
        if run_config.broadcast_join_limit:
            executor_options["broadcast_join_limit"] = (
                run_config.broadcast_join_limit
            )
        if run_config.rapidsmpf_spill:
            executor_options["rapidsmpf_spill"] = run_config.rapidsmpf_spill
        if run_config.scheduler == "distributed":
            executor_options["scheduler"] = "distributed"

        engine = pl.GPUEngine(
            raise_on_fail=True,
            executor=run_config.executor,
            executor_options=executor_options,
        )

    for q_id in run_config.queries:
        try:
            # Look up the query builder (PDSHQueries.q1 ... q22) dynamically.
            q = getattr(PDSHQueries, f"q{q_id}")(run_config)
        except AttributeError as err:
            raise NotImplementedError(f"Query {q_id} not implemented.") from err

        if run_config.executor == "cpu":
            # CPU path can only show polars' own logical plan.
            if args.explain_logical:
                print(f"\nQuery {q_id} - Logical plan\n")
                print(q.explain())
        elif CUDF_POLARS_AVAILABLE:
            assert isinstance(engine, pl.GPUEngine)
            # NOTE(review): when both --explain-logical and --explain are
            # set, only the logical plan is printed (elif chain).
            if args.explain_logical:
                print(f"\nQuery {q_id} - Logical plan\n")
                print(explain_query(q, engine, physical=False))
            elif args.explain:
                print(f"\nQuery {q_id} - Physical plan\n")
                print(explain_query(q, engine))
        else:
            raise RuntimeError(
                "Cannot provide the logical or physical plan because cudf_polars is not installed."
            )

        # Reset any previous iterations for this query id.
        records[q_id] = []

        for _ in range(args.iterations):
            t0 = time.monotonic()

            if run_config.executor == "cpu":
                result = q.collect(new_streaming=True)
            elif CUDF_POLARS_AVAILABLE:
                assert isinstance(engine, pl.GPUEngine)
                if args.debug:
                    # Debug path: translate and evaluate the IR directly
                    # instead of going through polars' collect machinery.
                    translator = Translator(q._ldf.visit(), engine)
                    ir = translator.translate_ir()
                    if run_config.executor == "in-memory":
                        result = ir.evaluate(cache={}, timer=None).to_polars()
                    elif run_config.executor == "streaming":
                        result = evaluate_streaming(
                            ir, translator.config_options
                        ).to_polars()
                else:
                    result = q.collect(engine=engine)
            else:
                raise RuntimeError(
                    "Cannot provide debug information because cudf_polars is not installed."
                )

            t1 = time.monotonic()
            record = Record(query=q_id, duration=t1 - t0)
            if args.print_results:
                print(result)
            print(f"Ran query={q_id} in {record.duration:0.4f}s", flush=True)
            records[q_id].append(record)

    # Attach the collected timings to an updated (immutable) run config.
    run_config = dataclasses.replace(run_config, records=dict(records))

    if args.summarize:
        run_config.summarize()

    if client is not None:
        client.close(timeout=60)

    # Append one JSONL record describing this run.
    args.output.write(json.dumps(run_config.serialize()))
    args.output.write("\n")
1276
+
1277
+
1278
if __name__ == "__main__":
    # `args` is parsed at import time by the module-level argparse setup.
    run(args)