cudf-polars-cu13 25.10.0__py3-none-any.whl → 26.2.0__py3-none-any.whl
This diff compares the contents of two package versions publicly released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- cudf_polars/GIT_COMMIT +1 -1
- cudf_polars/VERSION +1 -1
- cudf_polars/callback.py +60 -15
- cudf_polars/containers/column.py +137 -77
- cudf_polars/containers/dataframe.py +123 -34
- cudf_polars/containers/datatype.py +134 -13
- cudf_polars/dsl/expr.py +0 -2
- cudf_polars/dsl/expressions/aggregation.py +80 -28
- cudf_polars/dsl/expressions/binaryop.py +34 -14
- cudf_polars/dsl/expressions/boolean.py +110 -37
- cudf_polars/dsl/expressions/datetime.py +59 -30
- cudf_polars/dsl/expressions/literal.py +11 -5
- cudf_polars/dsl/expressions/rolling.py +460 -119
- cudf_polars/dsl/expressions/selection.py +9 -8
- cudf_polars/dsl/expressions/slicing.py +1 -1
- cudf_polars/dsl/expressions/string.py +256 -114
- cudf_polars/dsl/expressions/struct.py +19 -7
- cudf_polars/dsl/expressions/ternary.py +33 -3
- cudf_polars/dsl/expressions/unary.py +126 -64
- cudf_polars/dsl/ir.py +1053 -350
- cudf_polars/dsl/to_ast.py +30 -13
- cudf_polars/dsl/tracing.py +194 -0
- cudf_polars/dsl/translate.py +307 -107
- cudf_polars/dsl/utils/aggregations.py +43 -30
- cudf_polars/dsl/utils/reshape.py +14 -2
- cudf_polars/dsl/utils/rolling.py +12 -8
- cudf_polars/dsl/utils/windows.py +35 -20
- cudf_polars/experimental/base.py +55 -2
- cudf_polars/experimental/benchmarks/pdsds.py +12 -126
- cudf_polars/experimental/benchmarks/pdsh.py +792 -2
- cudf_polars/experimental/benchmarks/utils.py +596 -39
- cudf_polars/experimental/dask_registers.py +47 -20
- cudf_polars/experimental/dispatch.py +9 -3
- cudf_polars/experimental/distinct.py +2 -0
- cudf_polars/experimental/explain.py +15 -2
- cudf_polars/experimental/expressions.py +30 -15
- cudf_polars/experimental/groupby.py +25 -4
- cudf_polars/experimental/io.py +156 -124
- cudf_polars/experimental/join.py +53 -23
- cudf_polars/experimental/parallel.py +68 -19
- cudf_polars/experimental/rapidsmpf/__init__.py +8 -0
- cudf_polars/experimental/rapidsmpf/collectives/__init__.py +9 -0
- cudf_polars/experimental/rapidsmpf/collectives/allgather.py +90 -0
- cudf_polars/experimental/rapidsmpf/collectives/common.py +96 -0
- cudf_polars/experimental/rapidsmpf/collectives/shuffle.py +253 -0
- cudf_polars/experimental/rapidsmpf/core.py +488 -0
- cudf_polars/experimental/rapidsmpf/dask.py +172 -0
- cudf_polars/experimental/rapidsmpf/dispatch.py +153 -0
- cudf_polars/experimental/rapidsmpf/io.py +696 -0
- cudf_polars/experimental/rapidsmpf/join.py +322 -0
- cudf_polars/experimental/rapidsmpf/lower.py +74 -0
- cudf_polars/experimental/rapidsmpf/nodes.py +735 -0
- cudf_polars/experimental/rapidsmpf/repartition.py +216 -0
- cudf_polars/experimental/rapidsmpf/union.py +115 -0
- cudf_polars/experimental/rapidsmpf/utils.py +374 -0
- cudf_polars/experimental/repartition.py +9 -2
- cudf_polars/experimental/select.py +177 -14
- cudf_polars/experimental/shuffle.py +46 -12
- cudf_polars/experimental/sort.py +100 -26
- cudf_polars/experimental/spilling.py +1 -1
- cudf_polars/experimental/statistics.py +24 -5
- cudf_polars/experimental/utils.py +25 -7
- cudf_polars/testing/asserts.py +13 -8
- cudf_polars/testing/io.py +2 -1
- cudf_polars/testing/plugin.py +93 -17
- cudf_polars/typing/__init__.py +86 -32
- cudf_polars/utils/config.py +473 -58
- cudf_polars/utils/cuda_stream.py +70 -0
- cudf_polars/utils/versions.py +5 -4
- cudf_polars_cu13-26.2.0.dist-info/METADATA +181 -0
- cudf_polars_cu13-26.2.0.dist-info/RECORD +108 -0
- {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/WHEEL +1 -1
- cudf_polars_cu13-25.10.0.dist-info/METADATA +0 -136
- cudf_polars_cu13-25.10.0.dist-info/RECORD +0 -92
- {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/licenses/LICENSE +0 -0
- {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/top_level.txt +0 -0
The hunks below correspond to cudf_polars/experimental/benchmarks/utils.py (the +596 -39 entry above):

@@ -8,27 +8,43 @@ from __future__ import annotations
 import argparse
 import dataclasses
 import importlib
+import io
+import itertools
 import json
+import logging
 import os
 import statistics
 import sys
 import textwrap
 import time
 import traceback
+import warnings
 from collections import defaultdict
 from datetime import datetime, timezone
+from pathlib import Path
 from typing import TYPE_CHECKING, Any, Literal, assert_never
 
 import nvtx
 
 import polars as pl
 
+import rmm.statistics
+
+try:
+    import duckdb
+
+    duckdb_err = None
+except ImportError as e:
+    duckdb = None
+    duckdb_err = e
+
 try:
     import pynvml
 except ImportError:
     pynvml = None
 
 try:
+    from cudf_polars.dsl.ir import IRExecutionContext
     from cudf_polars.dsl.translate import Translator
     from cudf_polars.experimental.explain import explain_query
     from cudf_polars.experimental.parallel import evaluate_streaming
@@ -41,7 +57,17 @@ except ImportError:
 
 if TYPE_CHECKING:
     from collections.abc import Callable, Sequence
-
+
+
+try:
+    import structlog
+    import structlog.contextvars
+    import structlog.processors
+    import structlog.stdlib
+except ImportError:
+    _HAS_STRUCTLOG = False
+else:
+    _HAS_STRUCTLOG = True
 
 
 ExecutorType = Literal["in-memory", "streaming", "cpu"]
@@ -52,8 +78,28 @@ class Record:
     """Results for a single run of a single PDS-H query."""
 
     query: int
+    iteration: int
     duration: float
     shuffle_stats: dict[str, dict[str, int | float]] | None = None
+    traces: list[dict[str, Any]] | None = None
+
+    @classmethod
+    def new(
+        cls,
+        query: int,
+        iteration: int,
+        duration: float,
+        shuffle_stats: dict[str, dict[str, int | float]] | None = None,
+        traces: list[dict[str, Any]] | None = None,
+    ) -> Record:
+        """Create a Record from plain data."""
+        return cls(
+            query=query,
+            iteration=iteration,
+            duration=duration,
+            shuffle_stats=shuffle_stats,
+            traces=traces,
+        )
 
 
 @dataclasses.dataclass
@@ -181,7 +227,10 @@ class RunConfig:
     queries: list[int]
     suffix: str
     executor: ExecutorType
-    scheduler: str
+    runtime: str
+    stream_policy: str | None
+    cluster: str
+    scheduler: str  # Deprecated, kept for backward compatibility
     n_workers: int
     versions: PackageVersions = dataclasses.field(
         default_factory=PackageVersions.collect
@@ -205,7 +254,10 @@ class RunConfig:
     rapidsmpf_spill: bool
     spill_device: float
     query_set: str
+    collect_traces: bool = False
     stats_planning: bool
+    max_io_threads: int
+    native_parquet: bool
 
     def __post_init__(self) -> None:  # noqa: D105
         if self.gather_shuffle_stats and self.shuffle != "rapidsmpf":
@@ -217,10 +269,38 @@ class RunConfig:
     def from_args(cls, args: argparse.Namespace) -> RunConfig:
         """Create a RunConfig from command line arguments."""
         executor: ExecutorType = args.executor
+        cluster = args.cluster
         scheduler = args.scheduler
+        runtime = args.runtime
+        stream_policy = args.stream_policy
+
+        # Handle "auto" stream policy
+        if stream_policy == "auto":
+            stream_policy = None
 
+        # Deal with deprecated scheduler argument
+        # and non-streaming executors
         if executor == "in-memory" or executor == "cpu":
+            cluster = None
             scheduler = None
+        elif scheduler is not None:
+            if cluster is not None:
+                raise ValueError(
+                    "Cannot specify both -s/--scheduler and -c/--cluster. "
+                    "Please use -c/--cluster only."
+                )
+            else:
+                warnings.warn(
+                    "The -s/--scheduler argument is deprecated. Use -c/--cluster instead.",
+                    FutureWarning,
+                    stacklevel=2,
+                )
+                cluster = "single" if scheduler == "synchronous" else "distributed"
+        elif cluster is not None:
+            scheduler = "synchronous" if cluster == "single" else "distributed"
+        else:
+            cluster = "single"
+            scheduler = "synchronous"
 
         path = args.path
         name = args.query_set
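
For reference: a minimal sketch, not from the package, distilling how from_args resolves the deprecated --scheduler flag against the new --cluster flag once the executor is "streaming". The function name is illustrative.

    def resolve(cluster: str | None, scheduler: str | None) -> tuple[str, str]:
        # Deprecated path: --scheduler alone maps onto a cluster type
        # (the real code also emits a FutureWarning here).
        if scheduler is not None:
            if cluster is not None:
                raise ValueError("cannot specify both --scheduler and --cluster")
            return ("single" if scheduler == "synchronous" else "distributed", scheduler)
        # New path: --cluster alone derives the legacy scheduler value.
        if cluster is not None:
            return (cluster, "synchronous" if cluster == "single" else "distributed")
        # Neither flag given: single-process defaults.
        return ("single", "synchronous")

    assert resolve(None, "synchronous") == ("single", "synchronous")
    assert resolve("distributed", None) == ("distributed", "distributed")
    assert resolve(None, None) == ("single", "synchronous")
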
@@ -240,12 +320,25 @@ class RunConfig:
             scale_factor = _infer_scale_factor(name, path, args.suffix)
         if path is None:
             path = f"{args.root}/scale-{scale_factor}"
+
+        scale_factor = float(scale_factor)
         try:
-            scale_factor = int(scale_factor)
+            scale_factor_int = int(scale_factor)
         except ValueError:
-            scale_factor = float(scale_factor)
-
-        if "pdsh" in name and args.scale is not None:
+            pass
+        else:
+            if scale_factor_int == scale_factor:
+                scale_factor = scale_factor_int
+
+        skip_scale_factor_inference = (
+            "LIBCUDF_IO_REROUTE_LOCAL_DIR_PATTERN" in os.environ
+        ) and ("LIBCUDF_IO_REROUTE_REMOTE_DIR_PATTERN" in os.environ)
+
+        if (
+            "pdsh" in name
+            and args.scale is not None
+            and skip_scale_factor_inference is False
+        ):
             # Validate the user-supplied scale factor
             sf_inf = _infer_scale_factor(name, path, args.suffix)
             rel_error = abs((scale_factor - sf_inf) / sf_inf)
@@ -258,7 +351,10 @@ class RunConfig:
         return cls(
             queries=args.query,
             executor=executor,
+            cluster=cluster,
             scheduler=scheduler,
+            runtime=runtime,
+            stream_policy=stream_policy,
             n_workers=args.n_workers,
             shuffle=args.shuffle,
             gather_shuffle_stats=args.rapidsmpf_dask_statistics,
@@ -297,13 +396,17 @@ class RunConfig:
         print(f"path: {self.dataset_path}")
         print(f"scale_factor: {self.scale_factor}")
         print(f"executor: {self.executor}")
+        print(f"stream_policy: {self.stream_policy}")
         if self.executor == "streaming":
-            print(f"scheduler: {self.scheduler}")
+            print(f"runtime: {self.runtime}")
+            print(f"cluster: {self.cluster}")
             print(f"blocksize: {self.blocksize}")
             print(f"shuffle_method: {self.shuffle}")
             print(f"broadcast_join_limit: {self.broadcast_join_limit}")
             print(f"stats_planning: {self.stats_planning}")
-            if self.scheduler == "distributed":
+            if self.runtime == "rapidsmpf":
+                print(f"native_parquet: {self.native_parquet}")
+            if self.cluster == "distributed":
                 print(f"n_workers: {self.n_workers}")
                 print(f"threads: {self.threads}")
                 print(f"rmm_async: {self.rmm_async}")
@@ -338,20 +441,31 @@ def get_executor_options(
     """Generate executor_options for GPUEngine."""
     executor_options: dict[str, Any] = {}
 
-    if run_config.executor == "streaming":
-        …
+    if run_config.executor == "streaming":
+        if run_config.blocksize:
+            executor_options["target_partition_size"] = run_config.blocksize
+        if run_config.max_rows_per_partition:
+            executor_options["max_rows_per_partition"] = (
+                run_config.max_rows_per_partition
+            )
+        if run_config.shuffle:
+            executor_options["shuffle_method"] = run_config.shuffle
+        if run_config.broadcast_join_limit:
+            executor_options["broadcast_join_limit"] = run_config.broadcast_join_limit
+        if run_config.rapidsmpf_spill:
+            executor_options["rapidsmpf_spill"] = run_config.rapidsmpf_spill
+        if run_config.cluster == "distributed":
+            executor_options["cluster"] = "distributed"
+        executor_options["stats_planning"] = {
+            "use_reduction_planning": run_config.stats_planning,
+            "use_sampling": (
+                # Always allow row-group sampling for rapidsmpf runtime
+                run_config.stats_planning or run_config.runtime == "rapidsmpf"
+            ),
+        }
+        executor_options["client_device_threshold"] = run_config.spill_device
+        executor_options["runtime"] = run_config.runtime
+        executor_options["max_io_threads"] = run_config.max_io_threads
 
     if (
         benchmark
@@ -390,7 +504,7 @@ def print_query_plan(
     if args.explain_logical:
         print(f"\nQuery {q_id} - Logical plan\n")
         print(explain_query(q, engine, physical=False))
-    if args.explain:
+    if args.explain and run_config.executor == "streaming":
         print(f"\nQuery {q_id} - Physical plan\n")
         print(explain_query(q, engine))
     else:
@@ -399,9 +513,9 @@ def print_query_plan(
         )
 
 
-def initialize_dask_cluster(run_config: RunConfig, args: argparse.Namespace):  # type: ignore
+def initialize_dask_cluster(run_config: RunConfig, args: argparse.Namespace):  # type: ignore[no-untyped-def]
     """Initialize a Dask distributed cluster."""
-    if run_config.scheduler != "distributed":
+    if run_config.cluster != "distributed":
         return None
 
     from dask_cuda import LocalCUDACluster
@@ -437,6 +551,10 @@ def initialize_dask_cluster(run_config: RunConfig, args: argparse.Namespace): #
                 }
             ),
         )
+        # Setting this globally makes the peak statistics not meaningful
+        # across queries / iterations. But doing it per query isn't worth
+        # the effort right now.
+        client.run(rmm.statistics.enable_statistics)
     except ImportError as err:
         if run_config.shuffle == "rapidsmpf":
             raise ImportError(
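
For context on the client.run(rmm.statistics.enable_statistics) call added above: rmm.statistics.enable_statistics() wraps the current RMM memory resource in a statistics-tracking adaptor, and Dask's client.run executes the function once on every worker. A minimal single-process sketch (needs a GPU; the field names follow rmm.statistics.Statistics):

    import rmm
    import rmm.statistics

    rmm.statistics.enable_statistics()       # wrap the current memory resource
    buf = rmm.DeviceBuffer(size=1024)        # one tracked device allocation
    stats = rmm.statistics.get_statistics()  # snapshot of the tracked counters
    print(stats.current_bytes, stats.peak_bytes)
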
@@ -468,10 +586,18 @@ def execute_query(
     if args.debug:
         translator = Translator(q._ldf.visit(), engine)
         ir = translator.translate_ir()
+        context = IRExecutionContext.from_config_options(
+            translator.config_options
+        )
         if run_config.executor == "in-memory":
-            return ir.evaluate(cache={}, timer=None).to_polars()
+            return ir.evaluate(
+                cache={}, timer=None, context=context
+            ).to_polars()
         elif run_config.executor == "streaming":
-            return evaluate_streaming(ir, translator.config_options)
+            return evaluate_streaming(
+                ir,
+                translator.config_options,
+            )
         assert_never(run_config.executor)
     else:
         return q.collect(engine=engine)
@@ -558,22 +684,51 @@ def parse_args(
             - streaming : Partitioned evaluation (default)
             - cpu : Use Polars CPU engine"""),
     )
+    parser.add_argument(
+        "-c",
+        "--cluster",
+        default=None,
+        type=str,
+        choices=["single", "distributed"],
+        help=textwrap.dedent("""\
+            Cluster type to use with the 'streaming' executor.
+            - single : Run locally in a single process
+            - distributed : Use Dask for multi-GPU execution"""),
+    )
     parser.add_argument(
         "-s",
         "--scheduler",
-        default="synchronous",
+        default=None,
        type=str,
         choices=["synchronous", "distributed"],
         help=textwrap.dedent("""\
+            *Deprecated*: Use --cluster instead.
+
             Scheduler type to use with the 'streaming' executor.
             - synchronous : Run locally in a single process
             - distributed : Use Dask for multi-GPU execution"""),
     )
+    parser.add_argument(
+        "--runtime",
+        type=str,
+        choices=["tasks", "rapidsmpf"],
+        default="tasks",
+        help="Runtime to use for the streaming executor (tasks or rapidsmpf).",
+    )
+    parser.add_argument(
+        "--stream-policy",
+        type=str,
+        choices=["auto", "default", "new", "pool"],
+        default="auto",
+        help=textwrap.dedent("""\
+            CUDA stream policy (auto, default, new, pool).
+            Default: auto (use the default policy for the runtime)"""),
+    )
     parser.add_argument(
         "--n-workers",
         default=1,
         type=int,
-        help="Number of Dask-CUDA workers (requires 'distributed' scheduler).",
+        help="Number of Dask-CUDA workers (requires 'distributed' cluster).",
     )
     parser.add_argument(
         "--blocksize",
@@ -627,11 +782,12 @@ def parse_args(
     )
     parser.add_argument(
         "--rmm-pool-size",
-        default=0.5,
+        default=None,
         type=float,
         help=textwrap.dedent("""\
             Fraction of total GPU memory to allocate for RMM pool.
-            Default: 0.5 (50%% of GPU memory)"""),
+            Default: 0.5 (50%% of GPU memory) when --no-rmm-async,
+            None when --rmm-async"""),
     )
     parser.add_argument(
         "--rmm-release-threshold",
@@ -646,7 +802,7 @@ def parse_args(
         "--rmm-async",
         action=argparse.BooleanOptionalAction,
         default=False,
-        help="Use RMM async memory resource.",
+        help="Use RMM async memory resource. Note: only affects distributed cluster!",
     )
     parser.add_argument(
         "--rapidsmpf-oom-protection",
@@ -721,13 +877,40 @@ def parse_args(
         default="duckdb",
         help="Which engine to use as the baseline for validation.",
     )
+
+    parser.add_argument(
+        "--collect-traces",
+        action=argparse.BooleanOptionalAction,
+        default=False,
+        help="Collect data tracing cudf-polars execution.",
+    )
+
     parser.add_argument(
         "--stats-planning",
         action=argparse.BooleanOptionalAction,
         default=False,
         help="Enable statistics planning.",
     )
-    return parser.parse_args(args)
+    parser.add_argument(
+        "--max-io-threads",
+        default=2,
+        type=int,
+        help="Maximum number of IO threads for rapidsmpf runtime.",
+    )
+    parser.add_argument(
+        "--native-parquet",
+        action=argparse.BooleanOptionalAction,
+        default=True,
+        help="Use C++ read_parquet nodes for the rapidsmpf runtime.",
+    )
+
+    parsed_args = parser.parse_args(args)
+
+    if parsed_args.rmm_pool_size is None and not parsed_args.rmm_async:
+        # The default rmm pool size depends on the rmm_async flag
+        parsed_args.rmm_pool_size = 0.5
+
+    return parsed_args
 
 
 def run_polars(
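
The post-parse fix-up above implements a conditional default: --rmm-pool-size stays None under --rmm-async and becomes 0.5 otherwise. A self-contained sketch of the same standard-library pattern, leaving the argparse default as None and resolving it after parsing:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--rmm-async", action=argparse.BooleanOptionalAction, default=False
    )
    parser.add_argument("--rmm-pool-size", type=float, default=None)

    ns = parser.parse_args(["--no-rmm-async"])
    if ns.rmm_pool_size is None and not ns.rmm_async:
        ns.rmm_pool_size = 0.5  # pool default only applies without the async MR
    print(ns.rmm_pool_size)  # 0.5
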
@@ -742,17 +925,28 @@ def run_polars(
     validation_failures: list[int] = []
     query_failures: list[tuple[int, int]] = []
 
-    client = initialize_dask_cluster(run_config, args)
+    client = initialize_dask_cluster(run_config, args)
 
     records: defaultdict[int, list[Record]] = defaultdict(list)
     engine: pl.GPUEngine | None = None
 
     if run_config.executor != "cpu":
         executor_options = get_executor_options(run_config, benchmark=benchmark)
+        if run_config.runtime == "rapidsmpf":
+            parquet_options = {
+                "use_rapidsmpf_native": run_config.native_parquet,
+            }
+        else:
+            parquet_options = {}
         engine = pl.GPUEngine(
             raise_on_fail=True,
+            memory_resource=rmm.mr.CudaAsyncMemoryResource()
+            if run_config.rmm_async
+            else None,
+            cuda_stream_policy=run_config.stream_policy,
             executor=run_config.executor,
             executor_options=executor_options,
+            parquet_options=parquet_options,
         )
 
     for q_id in run_config.queries:
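
A hedged usage sketch of the engine construction above, with literal values in place of the run-config lookups. The cuda_stream_policy and parquet_options keywords are passed here exactly as the diff passes them; whether a given polars release accepts them is version-dependent, and running this needs a GPU:

    import polars as pl
    import rmm

    engine = pl.GPUEngine(
        raise_on_fail=True,
        memory_resource=rmm.mr.CudaAsyncMemoryResource(),  # only under --rmm-async
        cuda_stream_policy="pool",  # one of the --stream-policy choices
        executor="streaming",
        executor_options={"runtime": "rapidsmpf", "max_io_threads": 2},
        parquet_options={"use_rapidsmpf_native": True},
    )
    df = pl.LazyFrame({"a": [1, 2, 3]}).select(pl.col("a").sum()).collect(engine=engine)
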
@@ -764,8 +958,12 @@
         print_query_plan(q_id, q, args, run_config, engine)
 
         records[q_id] = []
-
         for i in range(args.iterations):
+            if _HAS_STRUCTLOG and run_config.collect_traces:
+                setup_logging(q_id, i)
+                if client is not None:
+                    client.run(setup_logging, q_id, i)
+
             t0 = time.monotonic()
 
             try:
@@ -781,8 +979,8 @@
                         gather_shuffle_statistics,
                     )
 
-                    shuffle_stats = gather_shuffle_statistics(client)
-                    clear_shuffle_statistics(client)
+                    shuffle_stats = gather_shuffle_statistics(client)
+                    clear_shuffle_statistics(client)
                 else:
                     shuffle_stats = None
 
@@ -800,15 +998,65 @@
                 print(f"❌ Query {q_id} failed validation!\n{e}")
 
             t1 = time.monotonic()
-            record = Record(query=q_id, duration=t1 - t0, shuffle_stats=shuffle_stats)
+            record = Record(
+                query=q_id, iteration=i, duration=t1 - t0, shuffle_stats=shuffle_stats
+            )
             if args.print_results:
                 print(result)
 
-            print(f"Query {q_id} - Iteration {i} finished in {record.duration:0.4f}s")
+            print(
+                f"Query {q_id} - Iteration {i} finished in {record.duration:0.4f}s",
+                flush=True,
+            )
             records[q_id].append(record)
 
     run_config = dataclasses.replace(run_config, records=dict(records))
 
+    # consolidate logs
+    if _HAS_STRUCTLOG and run_config.collect_traces:
+
+        def gather_logs() -> str:
+            logger = logging.getLogger()
+            return logger.handlers[0].stream.getvalue()  # type: ignore[attr-defined]
+
+        if client is not None:
+            all_logs = "\n".join(client.run(gather_logs).values())
+        else:
+            all_logs = gather_logs()
+
+        parsed_logs = [json.loads(log) for log in all_logs.splitlines() if log]
+        # Some other log records can end up in here. Filter those out.
+        parsed_logs = [log for log in parsed_logs if log["event"] == "Execute IR"]
+        # Now we want to augment the existing Records with the trace data.
+
+        def group_key(x: dict) -> int:
+            return x["query_id"]
+
+        def sort_key(x: dict) -> tuple[int, int]:
+            return x["query_id"], x["iteration"]
+
+        grouped = itertools.groupby(
+            sorted(parsed_logs, key=sort_key),
+            key=group_key,
+        )
+
+        for query_id, run_logs_group in grouped:
+            run_logs = list(run_logs_group)
+            by_iteration = [
+                list(x)
+                for _, x in itertools.groupby(run_logs, key=lambda x: x["iteration"])
+            ]
+            run_records = run_config.records[query_id]
+            assert len(by_iteration) == len(run_records)  # same number of iterations
+            all_traces = [list(iteration) for iteration in by_iteration]
+
+            new_records = [
+                dataclasses.replace(record, traces=traces)
+                for record, traces in zip(run_records, all_traces, strict=True)
+            ]
+
+            run_config.records[query_id] = new_records
+
     if args.summarize:
         run_config.summarize()
 
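
One detail of the trace-consolidation block above is easy to miss: itertools.groupby only groups adjacent items, which is why the parsed logs are sorted by (query_id, iteration) before both levels of grouping. A toy illustration with made-up records:

    import itertools

    logs = [
        {"query_id": 1, "iteration": 1},
        {"query_id": 2, "iteration": 0},
        {"query_id": 1, "iteration": 0},
    ]
    logs.sort(key=lambda x: (x["query_id"], x["iteration"]))
    for qid, group in itertools.groupby(logs, key=lambda x: x["query_id"]):
        print(qid, [g["iteration"] for g in group])
    # 1 [0, 1]
    # 2 [0]
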
@@ -830,3 +1078,312 @@ def run_polars(
 
     if query_failures or validation_failures:
         sys.exit(1)
+
+
+def setup_logging(query_id: int, iteration: int) -> None:  # noqa: D103
+    import cudf_polars.dsl.tracing
+
+    if not cudf_polars.dsl.tracing.LOG_TRACES:
+        msg = (
+            "Tracing requested via --collect-traces, but tracking is not enabled. "
+            "Verify that 'CUDF_POLARS_LOG_TRACES' is set and structlog is installed."
+        )
+        raise RuntimeError(msg)
+
+    if _HAS_STRUCTLOG:
+        # structlog uses contextvars to propagate context down to where log records
+        # are emitted. Ideally, we'd just set the contextvars here using
+        # structlog.bind_contextvars; for the distributed cluster we would need
+        # to use something like client.run to set the contextvars on the worker.
+        # However, there's an unfortunate conflict between structlog's use of
+        # context vars and how Dask Workers actually execute tasks, such that
+        # the contextvars set via `client.run` aren't visible to the actual
+        # tasks.
+        #
+        # So instead we make a new logger each time we need a new context,
+        # i.e. for each query/iteration pair.
+
+        def make_injector(
+            query_id: int, iteration: int
+        ) -> Callable[[logging.Logger, str, dict[str, Any]], dict[str, Any]]:
+            def inject(
+                logger: Any, method_name: Any, event_dict: Any
+            ) -> dict[str, Any]:
+                event_dict["query_id"] = query_id
+                event_dict["iteration"] = iteration
+                return event_dict
+
+            return inject
+
+        shared_processors = [
+            structlog.contextvars.merge_contextvars,
+            make_injector(query_id, iteration),
+            structlog.processors.add_log_level,
+            structlog.processors.CallsiteParameterAdder(
+                parameters=[
+                    structlog.processors.CallsiteParameter.PROCESS,
+                    structlog.processors.CallsiteParameter.THREAD,
+                ],
+            ),
+            structlog.processors.StackInfoRenderer(),
+            structlog.dev.set_exc_info,
+            structlog.processors.TimeStamper(fmt="%Y-%m-%d %H:%M:%S.%f", utc=False),
+        ]
+
+        # For logging to a file
+        json_renderer = structlog.processors.JSONRenderer()
+
+        stream = io.StringIO()
+        json_file_handler = logging.StreamHandler(stream)
+        json_file_handler.setFormatter(
+            structlog.stdlib.ProcessorFormatter(
+                processor=json_renderer,
+                foreign_pre_chain=shared_processors,
+            )
+        )
+
+        logging.basicConfig(level=logging.INFO, handlers=[json_file_handler])
+
+        structlog.configure(
+            processors=[
+                *shared_processors,
+                structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
+            ],
+            logger_factory=structlog.stdlib.LoggerFactory(),
+            wrapper_class=structlog.make_filtering_bound_logger(logging.INFO),
+            cache_logger_on_first_use=True,
+        )
+
+
+PDSDS_TABLE_NAMES: list[str] = [
+    "call_center",
+    "catalog_page",
+    "catalog_returns",
+    "catalog_sales",
+    "customer",
+    "customer_address",
+    "customer_demographics",
+    "date_dim",
+    "household_demographics",
+    "income_band",
+    "inventory",
+    "item",
+    "promotion",
+    "reason",
+    "ship_mode",
+    "store",
+    "store_returns",
+    "store_sales",
+    "time_dim",
+    "warehouse",
+    "web_page",
+    "web_returns",
+    "web_sales",
+    "web_site",
+]
+
+PDSH_TABLE_NAMES: list[str] = [
+    "customer",
+    "lineitem",
+    "nation",
+    "orders",
+    "part",
+    "partsupp",
+    "region",
+    "supplier",
+]
+
+
+def print_duckdb_plan(
+    q_id: int,
+    sql: str,
+    dataset_path: Path,
+    suffix: str,
+    query_set: str,
+    args: argparse.Namespace,
+) -> None:
+    """Print DuckDB query plan using EXPLAIN."""
+    if duckdb is None:
+        raise ImportError(duckdb_err)
+
+    if query_set == "pdsds":
+        tbl_names = PDSDS_TABLE_NAMES
+    else:
+        tbl_names = PDSH_TABLE_NAMES
+
+    with duckdb.connect() as conn:
+        for name in tbl_names:
+            pattern = (Path(dataset_path) / name).as_posix() + suffix
+            conn.execute(
+                f"CREATE OR REPLACE VIEW {name} AS "
+                f"SELECT * FROM parquet_scan('{pattern}');"
+            )
+
+        if args.explain_logical and args.explain:
+            conn.execute("PRAGMA explain_output = 'all';")
+        elif args.explain_logical:
+            conn.execute("PRAGMA explain_output = 'optimized_only';")
+        else:
+            conn.execute("PRAGMA explain_output = 'physical_only';")
+
+        print(f"\nDuckDB Query {q_id} - Plan\n")
+
+        plan_rows = conn.execute(f"EXPLAIN {sql}").fetchall()
+        for _, line in plan_rows:
+            print(line)
+
+
+def execute_duckdb_query(
+    query: str,
+    dataset_path: Path,
+    *,
+    suffix: str = ".parquet",
+    query_set: str = "pdsh",
+) -> pl.DataFrame:
+    """Execute a query with DuckDB."""
+    if duckdb is None:
+        raise ImportError(duckdb_err)
+    if query_set == "pdsds":
+        tbl_names = PDSDS_TABLE_NAMES
+    else:
+        tbl_names = PDSH_TABLE_NAMES
+    with duckdb.connect() as conn:
+        for name in tbl_names:
+            pattern = (Path(dataset_path) / name).as_posix() + suffix
+            conn.execute(
+                f"CREATE OR REPLACE VIEW {name} AS "
+                f"SELECT * FROM parquet_scan('{pattern}');"
+            )
+        return conn.execute(query).pl()
+
+
+def run_duckdb(
+    duckdb_queries_cls: Any, options: Sequence[str] | None = None, *, num_queries: int
+) -> None:
+    """Run the benchmark with DuckDB."""
+    args = parse_args(options, num_queries=num_queries)
+    vars(args).update({"query_set": duckdb_queries_cls.name})
+    run_config = RunConfig.from_args(args)
+    records: defaultdict[int, list[Record]] = defaultdict(list)
+
+    for q_id in run_config.queries:
+        try:
+            get_q = getattr(duckdb_queries_cls, f"q{q_id}")
+        except AttributeError as err:
+            raise NotImplementedError(f"Query {q_id} not implemented.") from err
+
+        sql = get_q(run_config)
+
+        if args.explain or args.explain_logical:
+            print_duckdb_plan(
+                q_id=q_id,
+                sql=sql,
+                dataset_path=run_config.dataset_path,
+                suffix=run_config.suffix,
+                query_set=duckdb_queries_cls.name,
+                args=args,
+            )
+
+        print(f"DuckDB Executing: {q_id}")
+        records[q_id] = []
+
+        for i in range(args.iterations):
+            t0 = time.time()
+            result = execute_duckdb_query(
+                sql,
+                run_config.dataset_path,
+                suffix=run_config.suffix,
+                query_set=duckdb_queries_cls.name,
+            )
+            t1 = time.time()
+            record = Record(query=q_id, iteration=i, duration=t1 - t0)
+            if args.print_results:
+                print(result)
+            print(f"Query {q_id} - Iteration {i} finished in {record.duration:0.4f}s")
+            records[q_id].append(record)
+
+    run_config = dataclasses.replace(run_config, records=dict(records))
+    if args.summarize:
+        run_config.summarize()
+
+
+def run_validate(
+    polars_queries_cls: Any,
+    duckdb_queries_cls: Any,
+    options: Sequence[str] | None = None,
+    *,
+    num_queries: int,
+    check_dtypes: bool,
+    check_column_order: bool,
+) -> None:
+    """Validate Polars CPU/GPU vs DuckDB."""
+    from polars.testing import assert_frame_equal
+
+    args = parse_args(options, num_queries=num_queries)
+    vars(args).update({"query_set": polars_queries_cls.name})
+    run_config = RunConfig.from_args(args)
+
+    baseline = args.baseline
+    if baseline not in {"duckdb", "cpu"}:
+        raise ValueError("Baseline must be one of: 'duckdb', 'cpu'")
+
+    failures: list[int] = []
+
+    engine: pl.GPUEngine | None = None
+    if run_config.executor != "cpu":
+        engine = pl.GPUEngine(
+            raise_on_fail=True,
+            executor=run_config.executor,
+            executor_options=get_executor_options(run_config, polars_queries_cls),
+        )
+
+    for q_id in run_config.queries:
+        print(f"\nValidating Query {q_id}")
+        try:
+            get_pl = getattr(polars_queries_cls, f"q{q_id}")
+            get_ddb = getattr(duckdb_queries_cls, f"q{q_id}")
+        except AttributeError as err:
+            raise NotImplementedError(f"Query {q_id} not implemented.") from err
+
+        polars_query = get_pl(run_config)
+        if baseline == "duckdb":
+            base_sql = get_ddb(run_config)
+            base_result = execute_duckdb_query(
+                base_sql,
+                run_config.dataset_path,
+                query_set=duckdb_queries_cls.name,
+            )
+        else:
+            base_result = polars_query.collect(engine="streaming")
+
+        if run_config.executor == "cpu":
+            test_result = polars_query.collect(engine="streaming")
+        else:
+            try:
+                test_result = polars_query.collect(engine=engine)
+            except Exception as e:
+                failures.append(q_id)
+                print(f"❌ Query {q_id} failed validation: GPU execution failed.\n{e}")
+                continue
+
+        try:
+            assert_frame_equal(
+                base_result,
+                test_result,
+                check_dtypes=check_dtypes,
+                check_column_order=check_column_order,
+            )
+            print(f"✅ Query {q_id} passed validation.")
+        except AssertionError as e:
+            failures.append(q_id)
+            print(f"❌ Query {q_id} failed validation:\n{e}")
+            if args.print_results:
+                print("Baseline Result:\n", base_result)
+                print("Test Result:\n", test_result)
+
+    if failures:
+        print("\nValidation Summary:")
+        print("===================")
+        print(f"{len(failures)} query(s) failed: {failures}")
+    else:
+        print("\nAll queries passed validation.")
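
A hedged usage sketch of execute_duckdb_query as defined above. The dataset layout (one <table><suffix> parquet source per table under the dataset directory) follows from the parquet_scan view setup; the path here is hypothetical, and duckdb must be installed:

    from pathlib import Path

    df = execute_duckdb_query(
        "SELECT count(*) AS n FROM lineitem",
        Path("/data/pdsh/scale-1"),  # hypothetical dataset directory
        suffix=".parquet",
        query_set="pdsh",
    )
    print(df)  # one-row polars DataFrame
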