cudf-polars-cu13 25.10.0__py3-none-any.whl → 26.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. cudf_polars/GIT_COMMIT +1 -1
  2. cudf_polars/VERSION +1 -1
  3. cudf_polars/callback.py +60 -15
  4. cudf_polars/containers/column.py +137 -77
  5. cudf_polars/containers/dataframe.py +123 -34
  6. cudf_polars/containers/datatype.py +134 -13
  7. cudf_polars/dsl/expr.py +0 -2
  8. cudf_polars/dsl/expressions/aggregation.py +80 -28
  9. cudf_polars/dsl/expressions/binaryop.py +34 -14
  10. cudf_polars/dsl/expressions/boolean.py +110 -37
  11. cudf_polars/dsl/expressions/datetime.py +59 -30
  12. cudf_polars/dsl/expressions/literal.py +11 -5
  13. cudf_polars/dsl/expressions/rolling.py +460 -119
  14. cudf_polars/dsl/expressions/selection.py +9 -8
  15. cudf_polars/dsl/expressions/slicing.py +1 -1
  16. cudf_polars/dsl/expressions/string.py +256 -114
  17. cudf_polars/dsl/expressions/struct.py +19 -7
  18. cudf_polars/dsl/expressions/ternary.py +33 -3
  19. cudf_polars/dsl/expressions/unary.py +126 -64
  20. cudf_polars/dsl/ir.py +1053 -350
  21. cudf_polars/dsl/to_ast.py +30 -13
  22. cudf_polars/dsl/tracing.py +194 -0
  23. cudf_polars/dsl/translate.py +307 -107
  24. cudf_polars/dsl/utils/aggregations.py +43 -30
  25. cudf_polars/dsl/utils/reshape.py +14 -2
  26. cudf_polars/dsl/utils/rolling.py +12 -8
  27. cudf_polars/dsl/utils/windows.py +35 -20
  28. cudf_polars/experimental/base.py +55 -2
  29. cudf_polars/experimental/benchmarks/pdsds.py +12 -126
  30. cudf_polars/experimental/benchmarks/pdsh.py +792 -2
  31. cudf_polars/experimental/benchmarks/utils.py +596 -39
  32. cudf_polars/experimental/dask_registers.py +47 -20
  33. cudf_polars/experimental/dispatch.py +9 -3
  34. cudf_polars/experimental/distinct.py +2 -0
  35. cudf_polars/experimental/explain.py +15 -2
  36. cudf_polars/experimental/expressions.py +30 -15
  37. cudf_polars/experimental/groupby.py +25 -4
  38. cudf_polars/experimental/io.py +156 -124
  39. cudf_polars/experimental/join.py +53 -23
  40. cudf_polars/experimental/parallel.py +68 -19
  41. cudf_polars/experimental/rapidsmpf/__init__.py +8 -0
  42. cudf_polars/experimental/rapidsmpf/collectives/__init__.py +9 -0
  43. cudf_polars/experimental/rapidsmpf/collectives/allgather.py +90 -0
  44. cudf_polars/experimental/rapidsmpf/collectives/common.py +96 -0
  45. cudf_polars/experimental/rapidsmpf/collectives/shuffle.py +253 -0
  46. cudf_polars/experimental/rapidsmpf/core.py +488 -0
  47. cudf_polars/experimental/rapidsmpf/dask.py +172 -0
  48. cudf_polars/experimental/rapidsmpf/dispatch.py +153 -0
  49. cudf_polars/experimental/rapidsmpf/io.py +696 -0
  50. cudf_polars/experimental/rapidsmpf/join.py +322 -0
  51. cudf_polars/experimental/rapidsmpf/lower.py +74 -0
  52. cudf_polars/experimental/rapidsmpf/nodes.py +735 -0
  53. cudf_polars/experimental/rapidsmpf/repartition.py +216 -0
  54. cudf_polars/experimental/rapidsmpf/union.py +115 -0
  55. cudf_polars/experimental/rapidsmpf/utils.py +374 -0
  56. cudf_polars/experimental/repartition.py +9 -2
  57. cudf_polars/experimental/select.py +177 -14
  58. cudf_polars/experimental/shuffle.py +46 -12
  59. cudf_polars/experimental/sort.py +100 -26
  60. cudf_polars/experimental/spilling.py +1 -1
  61. cudf_polars/experimental/statistics.py +24 -5
  62. cudf_polars/experimental/utils.py +25 -7
  63. cudf_polars/testing/asserts.py +13 -8
  64. cudf_polars/testing/io.py +2 -1
  65. cudf_polars/testing/plugin.py +93 -17
  66. cudf_polars/typing/__init__.py +86 -32
  67. cudf_polars/utils/config.py +473 -58
  68. cudf_polars/utils/cuda_stream.py +70 -0
  69. cudf_polars/utils/versions.py +5 -4
  70. cudf_polars_cu13-26.2.0.dist-info/METADATA +181 -0
  71. cudf_polars_cu13-26.2.0.dist-info/RECORD +108 -0
  72. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/WHEEL +1 -1
  73. cudf_polars_cu13-25.10.0.dist-info/METADATA +0 -136
  74. cudf_polars_cu13-25.10.0.dist-info/RECORD +0 -92
  75. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/licenses/LICENSE +0 -0
  76. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/top_level.txt +0 -0
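
Judging by the function and class names in the hunks (Record, RunConfig, parse_args, run_polars) and the matching -39 deletion count in the listing above, the diff body reproduced below appears to come from cudf_polars/experimental/benchmarks/utils.py. It replaces the deprecated -s/--scheduler benchmark flag with -c/--cluster and adds --runtime, --stream-policy, --collect-traces, --max-io-threads, and --native-parquet. As a rough, illustrative sketch only (not code from either wheel), the engine configuration those flags feed into has the shape shown next; the option keys mirror what get_executor_options() and run_polars() set in the diff below, while the concrete values are hypothetical examples.

import polars as pl

# Sketch: how the new benchmark flags surface in the GPU engine configuration.
# Keys mirror get_executor_options()/run_polars() in the diff below; the values
# chosen here are hypothetical examples, not defaults.
engine = pl.GPUEngine(
    raise_on_fail=True,
    executor="streaming",
    executor_options={
        "cluster": "distributed",      # was executor_options["scheduler"] in 25.10.0
        "runtime": "rapidsmpf",        # new --runtime flag
        "max_io_threads": 2,           # new --max-io-threads flag
        "shuffle_method": "rapidsmpf",
    },
    parquet_options={"use_rapidsmpf_native": True},  # new --native-parquet flag
)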
@@ -8,27 +8,43 @@ from __future__ import annotations
 import argparse
 import dataclasses
 import importlib
+import io
+import itertools
 import json
+import logging
 import os
 import statistics
 import sys
 import textwrap
 import time
 import traceback
+import warnings
 from collections import defaultdict
 from datetime import datetime, timezone
+from pathlib import Path
 from typing import TYPE_CHECKING, Any, Literal, assert_never
 
 import nvtx
 
 import polars as pl
 
+import rmm.statistics
+
+try:
+    import duckdb
+
+    duckdb_err = None
+except ImportError as e:
+    duckdb = None
+    duckdb_err = e
+
 try:
     import pynvml
 except ImportError:
     pynvml = None
 
 try:
+    from cudf_polars.dsl.ir import IRExecutionContext
     from cudf_polars.dsl.translate import Translator
     from cudf_polars.experimental.explain import explain_query
     from cudf_polars.experimental.parallel import evaluate_streaming
@@ -41,7 +57,17 @@ except ImportError:
 
 if TYPE_CHECKING:
     from collections.abc import Callable, Sequence
-    from pathlib import Path
+
+
+try:
+    import structlog
+    import structlog.contextvars
+    import structlog.processors
+    import structlog.stdlib
+except ImportError:
+    _HAS_STRUCTLOG = False
+else:
+    _HAS_STRUCTLOG = True
 
 
 ExecutorType = Literal["in-memory", "streaming", "cpu"]
@@ -52,8 +78,28 @@ class Record:
     """Results for a single run of a single PDS-H query."""
 
     query: int
+    iteration: int
     duration: float
     shuffle_stats: dict[str, dict[str, int | float]] | None = None
+    traces: list[dict[str, Any]] | None = None
+
+    @classmethod
+    def new(
+        cls,
+        query: int,
+        iteration: int,
+        duration: float,
+        shuffle_stats: dict[str, dict[str, int | float]] | None = None,
+        traces: list[dict[str, Any]] | None = None,
+    ) -> Record:
+        """Create a Record from plain data."""
+        return cls(
+            query=query,
+            iteration=iteration,
+            duration=duration,
+            shuffle_stats=shuffle_stats,
+            traces=traces,
+        )
 
 
 @dataclasses.dataclass
@@ -181,7 +227,10 @@ class RunConfig:
     queries: list[int]
     suffix: str
     executor: ExecutorType
-    scheduler: str
+    runtime: str
+    stream_policy: str | None
+    cluster: str
+    scheduler: str  # Deprecated, kept for backward compatibility
     n_workers: int
     versions: PackageVersions = dataclasses.field(
         default_factory=PackageVersions.collect
@@ -205,7 +254,10 @@ class RunConfig:
     rapidsmpf_spill: bool
     spill_device: float
     query_set: str
+    collect_traces: bool = False
     stats_planning: bool
+    max_io_threads: int
+    native_parquet: bool
 
     def __post_init__(self) -> None:  # noqa: D105
         if self.gather_shuffle_stats and self.shuffle != "rapidsmpf":
@@ -217,10 +269,38 @@ class RunConfig:
     def from_args(cls, args: argparse.Namespace) -> RunConfig:
         """Create a RunConfig from command line arguments."""
         executor: ExecutorType = args.executor
+        cluster = args.cluster
         scheduler = args.scheduler
+        runtime = args.runtime
+        stream_policy = args.stream_policy
+
+        # Handle "auto" stream policy
+        if stream_policy == "auto":
+            stream_policy = None
 
+        # Deal with deprecated scheduler argument
+        # and non-streaming executors
         if executor == "in-memory" or executor == "cpu":
+            cluster = None
             scheduler = None
+        elif scheduler is not None:
+            if cluster is not None:
+                raise ValueError(
+                    "Cannot specify both -s/--scheduler and -c/--cluster. "
+                    "Please use -c/--cluster only."
+                )
+            else:
+                warnings.warn(
+                    "The -s/--scheduler argument is deprecated. Use -c/--cluster instead.",
+                    FutureWarning,
+                    stacklevel=2,
+                )
+                cluster = "single" if scheduler == "synchronous" else "distributed"
+        elif cluster is not None:
+            scheduler = "synchronous" if cluster == "single" else "distributed"
+        else:
+            cluster = "single"
+            scheduler = "synchronous"
 
         path = args.path
         name = args.query_set
@@ -240,12 +320,25 @@ class RunConfig:
             scale_factor = _infer_scale_factor(name, path, args.suffix)
         if path is None:
             path = f"{args.root}/scale-{scale_factor}"
+
+        scale_factor = float(scale_factor)
         try:
-            scale_factor = int(scale_factor)
+            scale_factor_int = int(scale_factor)
         except ValueError:
-            scale_factor = float(scale_factor)
-
-        if "pdsh" in name and args.scale is not None:
+            pass
+        else:
+            if scale_factor_int == scale_factor:
+                scale_factor = scale_factor_int
+
+        skip_scale_factor_inference = (
+            "LIBCUDF_IO_REROUTE_LOCAL_DIR_PATTERN" in os.environ
+        ) and ("LIBCUDF_IO_REROUTE_REMOTE_DIR_PATTERN" in os.environ)
+
+        if (
+            "pdsh" in name
+            and args.scale is not None
+            and skip_scale_factor_inference is False
+        ):
             # Validate the user-supplied scale factor
             sf_inf = _infer_scale_factor(name, path, args.suffix)
             rel_error = abs((scale_factor - sf_inf) / sf_inf)
@@ -258,7 +351,10 @@ class RunConfig:
         return cls(
             queries=args.query,
             executor=executor,
+            cluster=cluster,
             scheduler=scheduler,
+            runtime=runtime,
+            stream_policy=stream_policy,
             n_workers=args.n_workers,
             shuffle=args.shuffle,
             gather_shuffle_stats=args.rapidsmpf_dask_statistics,
@@ -275,7 +371,10 @@ class RunConfig:
             rapidsmpf_spill=args.rapidsmpf_spill,
             max_rows_per_partition=args.max_rows_per_partition,
             query_set=args.query_set,
+            collect_traces=args.collect_traces,
             stats_planning=args.stats_planning,
+            max_io_threads=args.max_io_threads,
+            native_parquet=args.native_parquet,
         )
 
     def serialize(self, engine: pl.GPUEngine | None) -> dict:
@@ -297,13 +396,17 @@ class RunConfig:
         print(f"path: {self.dataset_path}")
         print(f"scale_factor: {self.scale_factor}")
         print(f"executor: {self.executor}")
+        print(f"stream_policy: {self.stream_policy}")
         if self.executor == "streaming":
-            print(f"scheduler: {self.scheduler}")
+            print(f"runtime: {self.runtime}")
+            print(f"cluster: {self.cluster}")
             print(f"blocksize: {self.blocksize}")
             print(f"shuffle_method: {self.shuffle}")
             print(f"broadcast_join_limit: {self.broadcast_join_limit}")
             print(f"stats_planning: {self.stats_planning}")
-        if self.scheduler == "distributed":
+            if self.runtime == "rapidsmpf":
+                print(f"native_parquet: {self.native_parquet}")
+        if self.cluster == "distributed":
             print(f"n_workers: {self.n_workers}")
             print(f"threads: {self.threads}")
             print(f"rmm_async: {self.rmm_async}")
@@ -338,20 +441,31 @@ def get_executor_options(
     """Generate executor_options for GPUEngine."""
     executor_options: dict[str, Any] = {}
 
-    if run_config.blocksize:
-        executor_options["target_partition_size"] = run_config.blocksize
-    if run_config.max_rows_per_partition:
-        executor_options["max_rows_per_partition"] = run_config.max_rows_per_partition
-    if run_config.shuffle:
-        executor_options["shuffle_method"] = run_config.shuffle
-    if run_config.broadcast_join_limit:
-        executor_options["broadcast_join_limit"] = run_config.broadcast_join_limit
-    if run_config.rapidsmpf_spill:
-        executor_options["rapidsmpf_spill"] = run_config.rapidsmpf_spill
-    if run_config.scheduler == "distributed":
-        executor_options["scheduler"] = "distributed"
-    if run_config.stats_planning:
-        executor_options["stats_planning"] = {"use_reduction_planning": True}
+    if run_config.executor == "streaming":
+        if run_config.blocksize:
+            executor_options["target_partition_size"] = run_config.blocksize
+        if run_config.max_rows_per_partition:
+            executor_options["max_rows_per_partition"] = (
+                run_config.max_rows_per_partition
+            )
+        if run_config.shuffle:
+            executor_options["shuffle_method"] = run_config.shuffle
+        if run_config.broadcast_join_limit:
+            executor_options["broadcast_join_limit"] = run_config.broadcast_join_limit
+        if run_config.rapidsmpf_spill:
+            executor_options["rapidsmpf_spill"] = run_config.rapidsmpf_spill
+        if run_config.cluster == "distributed":
+            executor_options["cluster"] = "distributed"
+        executor_options["stats_planning"] = {
+            "use_reduction_planning": run_config.stats_planning,
+            "use_sampling": (
+                # Always allow row-group sampling for rapidsmpf runtime
+                run_config.stats_planning or run_config.runtime == "rapidsmpf"
+            ),
+        }
+        executor_options["client_device_threshold"] = run_config.spill_device
+        executor_options["runtime"] = run_config.runtime
+        executor_options["max_io_threads"] = run_config.max_io_threads
 
     if (
         benchmark
@@ -390,7 +504,7 @@ def print_query_plan(
     if args.explain_logical:
         print(f"\nQuery {q_id} - Logical plan\n")
         print(explain_query(q, engine, physical=False))
-    if args.explain:
+    if args.explain and run_config.executor == "streaming":
         print(f"\nQuery {q_id} - Physical plan\n")
         print(explain_query(q, engine))
     else:
@@ -399,9 +513,9 @@
         )
 
 
-def initialize_dask_cluster(run_config: RunConfig, args: argparse.Namespace):  # type: ignore
+def initialize_dask_cluster(run_config: RunConfig, args: argparse.Namespace):  # type: ignore[no-untyped-def]
     """Initialize a Dask distributed cluster."""
-    if run_config.scheduler != "distributed":
+    if run_config.cluster != "distributed":
         return None
 
     from dask_cuda import LocalCUDACluster
@@ -437,6 +551,10 @@ def initialize_dask_cluster(run_config: RunConfig, args: argparse.Namespace): #
                     }
                 ),
             )
+            # Setting this globally makes the peak statistics not meaningful
+            # across queries / iterations. But doing it per query isn't worth
+            # the effort right now.
+            client.run(rmm.statistics.enable_statistics)
         except ImportError as err:
             if run_config.shuffle == "rapidsmpf":
                 raise ImportError(
@@ -468,10 +586,18 @@ def execute_query(
     if args.debug:
         translator = Translator(q._ldf.visit(), engine)
         ir = translator.translate_ir()
+        context = IRExecutionContext.from_config_options(
+            translator.config_options
+        )
         if run_config.executor == "in-memory":
-            return ir.evaluate(cache={}, timer=None).to_polars()
+            return ir.evaluate(
+                cache={}, timer=None, context=context
+            ).to_polars()
         elif run_config.executor == "streaming":
-            return evaluate_streaming(ir, translator.config_options).to_polars()
+            return evaluate_streaming(
+                ir,
+                translator.config_options,
+            )
         assert_never(run_config.executor)
     else:
         return q.collect(engine=engine)
@@ -558,22 +684,51 @@ def parse_args(
             - streaming : Partitioned evaluation (default)
             - cpu : Use Polars CPU engine"""),
     )
+    parser.add_argument(
+        "-c",
+        "--cluster",
+        default=None,
+        type=str,
+        choices=["single", "distributed"],
+        help=textwrap.dedent("""\
+            Cluster type to use with the 'streaming' executor.
+            - single : Run locally in a single process
+            - distributed : Use Dask for multi-GPU execution"""),
+    )
     parser.add_argument(
         "-s",
        "--scheduler",
-        default="synchronous",
+        default=None,
         type=str,
        choices=["synchronous", "distributed"],
        help=textwrap.dedent("""\
+            *Deprecated*: Use --cluster instead.
+
             Scheduler type to use with the 'streaming' executor.
             - synchronous : Run locally in a single process
             - distributed : Use Dask for multi-GPU execution"""),
     )
+    parser.add_argument(
+        "--runtime",
+        type=str,
+        choices=["tasks", "rapidsmpf"],
+        default="tasks",
+        help="Runtime to use for the streaming executor (tasks or rapidsmpf).",
+    )
+    parser.add_argument(
+        "--stream-policy",
+        type=str,
+        choices=["auto", "default", "new", "pool"],
+        default="auto",
+        help=textwrap.dedent("""\
+            CUDA stream policy (auto, default, new, pool).
+            Default: auto (use the default policy for the runtime)"""),
    )
     parser.add_argument(
         "--n-workers",
         default=1,
         type=int,
-        help="Number of Dask-CUDA workers (requires 'distributed' scheduler).",
+        help="Number of Dask-CUDA workers (requires 'distributed' cluster).",
     )
     parser.add_argument(
         "--blocksize",
@@ -627,11 +782,12 @@ def parse_args(
     )
     parser.add_argument(
         "--rmm-pool-size",
-        default=0.5,
+        default=None,
         type=float,
         help=textwrap.dedent("""\
             Fraction of total GPU memory to allocate for RMM pool.
-            Default: 0.5 (50%% of GPU memory)"""),
+            Default: 0.5 (50%% of GPU memory) when --no-rmm-async,
+            None when --rmm-async"""),
     )
     parser.add_argument(
         "--rmm-release-threshold",
@@ -646,7 +802,7 @@ def parse_args(
         "--rmm-async",
         action=argparse.BooleanOptionalAction,
         default=False,
-        help="Use RMM async memory resource.",
+        help="Use RMM async memory resource. Note: only affects distributed cluster!",
     )
     parser.add_argument(
         "--rapidsmpf-oom-protection",
@@ -721,13 +877,40 @@ def parse_args(
         default="duckdb",
         help="Which engine to use as the baseline for validation.",
     )
+
+    parser.add_argument(
+        "--collect-traces",
+        action=argparse.BooleanOptionalAction,
+        default=False,
+        help="Collect data tracing cudf-polars execution.",
+    )
+
     parser.add_argument(
         "--stats-planning",
         action=argparse.BooleanOptionalAction,
         default=False,
         help="Enable statistics planning.",
     )
-    return parser.parse_args(args)
+    parser.add_argument(
+        "--max-io-threads",
+        default=2,
+        type=int,
+        help="Maximum number of IO threads for rapidsmpf runtime.",
+    )
+    parser.add_argument(
+        "--native-parquet",
+        action=argparse.BooleanOptionalAction,
+        default=True,
+        help="Use C++ read_parquet nodes for the rapidsmpf runtime.",
+    )
+
+    parsed_args = parser.parse_args(args)
+
+    if parsed_args.rmm_pool_size is None and not parsed_args.rmm_async:
+        # The default rmm pool size depends on the rmm_async flag
+        parsed_args.rmm_pool_size = 0.5
+
+    return parsed_args
 
 
 def run_polars(
@@ -742,17 +925,28 @@ def run_polars(
     validation_failures: list[int] = []
     query_failures: list[tuple[int, int]] = []
 
-    client = initialize_dask_cluster(run_config, args)  # type: ignore
+    client = initialize_dask_cluster(run_config, args)
 
     records: defaultdict[int, list[Record]] = defaultdict(list)
     engine: pl.GPUEngine | None = None
 
     if run_config.executor != "cpu":
         executor_options = get_executor_options(run_config, benchmark=benchmark)
+        if run_config.runtime == "rapidsmpf":
+            parquet_options = {
+                "use_rapidsmpf_native": run_config.native_parquet,
+            }
+        else:
+            parquet_options = {}
         engine = pl.GPUEngine(
             raise_on_fail=True,
+            memory_resource=rmm.mr.CudaAsyncMemoryResource()
+            if run_config.rmm_async
+            else None,
+            cuda_stream_policy=run_config.stream_policy,
             executor=run_config.executor,
             executor_options=executor_options,
+            parquet_options=parquet_options,
         )
 
     for q_id in run_config.queries:
@@ -764,8 +958,12 @@ def run_polars(
         print_query_plan(q_id, q, args, run_config, engine)
 
         records[q_id] = []
-
         for i in range(args.iterations):
+            if _HAS_STRUCTLOG and run_config.collect_traces:
+                setup_logging(q_id, i)
+                if client is not None:
+                    client.run(setup_logging, q_id, i)
+
             t0 = time.monotonic()
 
             try:
@@ -781,8 +979,8 @@
                         gather_shuffle_statistics,
                     )
 
-                    shuffle_stats = gather_shuffle_statistics(client)  # type: ignore[arg-type]
-                    clear_shuffle_statistics(client)  # type: ignore[arg-type]
+                    shuffle_stats = gather_shuffle_statistics(client)
+                    clear_shuffle_statistics(client)
                 else:
                     shuffle_stats = None
 
@@ -800,15 +998,65 @@ def run_polars(
                 print(f"❌ Query {q_id} failed validation!\n{e}")
 
             t1 = time.monotonic()
-            record = Record(query=q_id, duration=t1 - t0, shuffle_stats=shuffle_stats)
+            record = Record(
+                query=q_id, iteration=i, duration=t1 - t0, shuffle_stats=shuffle_stats
+            )
             if args.print_results:
                 print(result)
 
-            print(f"Query {q_id} - Iteration {i} finished in {record.duration:0.4f}s")
+            print(
+                f"Query {q_id} - Iteration {i} finished in {record.duration:0.4f}s",
+                flush=True,
+            )
             records[q_id].append(record)
 
     run_config = dataclasses.replace(run_config, records=dict(records))
 
+    # consolidate logs
+    if _HAS_STRUCTLOG and run_config.collect_traces:
+
+        def gather_logs() -> str:
+            logger = logging.getLogger()
+            return logger.handlers[0].stream.getvalue()  # type: ignore[attr-defined]
+
+        if client is not None:
+            all_logs = "\n".join(client.run(gather_logs).values())
+        else:
+            all_logs = gather_logs()
+
+        parsed_logs = [json.loads(log) for log in all_logs.splitlines() if log]
+        # Some other log records can end up in here. Filter those out.
+        parsed_logs = [log for log in parsed_logs if log["event"] == "Execute IR"]
+        # Now we want to augment the existing Records with the trace data.
+
+        def group_key(x: dict) -> int:
+            return x["query_id"]
+
+        def sort_key(x: dict) -> tuple[int, int]:
+            return x["query_id"], x["iteration"]
+
+        grouped = itertools.groupby(
+            sorted(parsed_logs, key=sort_key),
+            key=group_key,
+        )
+
+        for query_id, run_logs_group in grouped:
+            run_logs = list(run_logs_group)
+            by_iteration = [
+                list(x)
+                for _, x in itertools.groupby(run_logs, key=lambda x: x["iteration"])
+            ]
+            run_records = run_config.records[query_id]
+            assert len(by_iteration) == len(run_records)  # same number of iterations
+            all_traces = [list(iteration) for iteration in by_iteration]
+
+            new_records = [
+                dataclasses.replace(record, traces=traces)
+                for record, traces in zip(run_records, all_traces, strict=True)
+            ]
+
+            run_config.records[query_id] = new_records
+
     if args.summarize:
         run_config.summarize()
 
@@ -830,3 +1078,312 @@ def run_polars(
 
     if query_failures or validation_failures:
         sys.exit(1)
+
+
+def setup_logging(query_id: int, iteration: int) -> None:  # noqa: D103
+    import cudf_polars.dsl.tracing
+
+    if not cudf_polars.dsl.tracing.LOG_TRACES:
+        msg = (
+            "Tracing requested via --collect-traces, but tracking is not enabled. "
+            "Verify that 'CUDF_POLARS_LOG_TRACES' is set and structlog is installed."
+        )
+        raise RuntimeError(msg)
+
+    if _HAS_STRUCTLOG:
+        # structlog uses contextvars to propagate context down to where log records
+        # are emitted. Ideally, we'd just set the contextvars here using
+        # structlog.bind_contextvars; for the distributed cluster we would need
+        # to use something like client.run to set the contextvars on the worker.
+        # However, there's an unfortunate conflict between structlog's use of
+        # context vars and how Dask Workers actually execute tasks, such that
+        # the contextvars set via `client.run` aren't visible to the actual
+        # tasks.
+        #
+        # So instead we make a new logger each time we need a new context,
+        # i.e. for each query/iteration pair.
+
+        def make_injector(
+            query_id: int, iteration: int
+        ) -> Callable[[logging.Logger, str, dict[str, Any]], dict[str, Any]]:
+            def inject(
+                logger: Any, method_name: Any, event_dict: Any
+            ) -> dict[str, Any]:
+                event_dict["query_id"] = query_id
+                event_dict["iteration"] = iteration
+                return event_dict
+
+            return inject
+
+        shared_processors = [
+            structlog.contextvars.merge_contextvars,
+            make_injector(query_id, iteration),
+            structlog.processors.add_log_level,
+            structlog.processors.CallsiteParameterAdder(
+                parameters=[
+                    structlog.processors.CallsiteParameter.PROCESS,
+                    structlog.processors.CallsiteParameter.THREAD,
+                ],
+            ),
+            structlog.processors.StackInfoRenderer(),
+            structlog.dev.set_exc_info,
+            structlog.processors.TimeStamper(fmt="%Y-%m-%d %H:%M:%S.%f", utc=False),
+        ]
+
+        # For logging to a file
+        json_renderer = structlog.processors.JSONRenderer()
+
+        stream = io.StringIO()
+        json_file_handler = logging.StreamHandler(stream)
+        json_file_handler.setFormatter(
+            structlog.stdlib.ProcessorFormatter(
+                processor=json_renderer,
+                foreign_pre_chain=shared_processors,
+            )
+        )
+
+        logging.basicConfig(level=logging.INFO, handlers=[json_file_handler])
+
+        structlog.configure(
+            processors=[
+                *shared_processors,
+                structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
+            ],
+            logger_factory=structlog.stdlib.LoggerFactory(),
+            wrapper_class=structlog.make_filtering_bound_logger(logging.INFO),
+            cache_logger_on_first_use=True,
+        )
+
+
+PDSDS_TABLE_NAMES: list[str] = [
+    "call_center",
+    "catalog_page",
+    "catalog_returns",
+    "catalog_sales",
+    "customer",
+    "customer_address",
+    "customer_demographics",
+    "date_dim",
+    "household_demographics",
+    "income_band",
+    "inventory",
+    "item",
+    "promotion",
+    "reason",
+    "ship_mode",
+    "store",
+    "store_returns",
+    "store_sales",
+    "time_dim",
+    "warehouse",
+    "web_page",
+    "web_returns",
+    "web_sales",
+    "web_site",
+]
+
+PDSH_TABLE_NAMES: list[str] = [
+    "customer",
+    "lineitem",
+    "nation",
+    "orders",
+    "part",
+    "partsupp",
+    "region",
+    "supplier",
+]
+
+
+def print_duckdb_plan(
+    q_id: int,
+    sql: str,
+    dataset_path: Path,
+    suffix: str,
+    query_set: str,
+    args: argparse.Namespace,
+) -> None:
+    """Print DuckDB query plan using EXPLAIN."""
+    if duckdb is None:
+        raise ImportError(duckdb_err)
+
+    if query_set == "pdsds":
+        tbl_names = PDSDS_TABLE_NAMES
+    else:
+        tbl_names = PDSH_TABLE_NAMES
+
+    with duckdb.connect() as conn:
+        for name in tbl_names:
+            pattern = (Path(dataset_path) / name).as_posix() + suffix
+            conn.execute(
+                f"CREATE OR REPLACE VIEW {name} AS "
+                f"SELECT * FROM parquet_scan('{pattern}');"
+            )
+
+        if args.explain_logical and args.explain:
+            conn.execute("PRAGMA explain_output = 'all';")
+        elif args.explain_logical:
+            conn.execute("PRAGMA explain_output = 'optimized_only';")
+        else:
+            conn.execute("PRAGMA explain_output = 'physical_only';")
+
+        print(f"\nDuckDB Query {q_id} - Plan\n")
+
+        plan_rows = conn.execute(f"EXPLAIN {sql}").fetchall()
+        for _, line in plan_rows:
+            print(line)
+
+
+def execute_duckdb_query(
+    query: str,
+    dataset_path: Path,
+    *,
+    suffix: str = ".parquet",
+    query_set: str = "pdsh",
+) -> pl.DataFrame:
+    """Execute a query with DuckDB."""
+    if duckdb is None:
+        raise ImportError(duckdb_err)
+    if query_set == "pdsds":
+        tbl_names = PDSDS_TABLE_NAMES
+    else:
+        tbl_names = PDSH_TABLE_NAMES
+    with duckdb.connect() as conn:
+        for name in tbl_names:
+            pattern = (Path(dataset_path) / name).as_posix() + suffix
+            conn.execute(
+                f"CREATE OR REPLACE VIEW {name} AS "
+                f"SELECT * FROM parquet_scan('{pattern}');"
+            )
+        return conn.execute(query).pl()
+
+
+def run_duckdb(
+    duckdb_queries_cls: Any, options: Sequence[str] | None = None, *, num_queries: int
+) -> None:
+    """Run the benchmark with DuckDB."""
+    args = parse_args(options, num_queries=num_queries)
+    vars(args).update({"query_set": duckdb_queries_cls.name})
+    run_config = RunConfig.from_args(args)
+    records: defaultdict[int, list[Record]] = defaultdict(list)
+
+    for q_id in run_config.queries:
+        try:
+            get_q = getattr(duckdb_queries_cls, f"q{q_id}")
+        except AttributeError as err:
+            raise NotImplementedError(f"Query {q_id} not implemented.") from err
+
+        sql = get_q(run_config)
+
+        if args.explain or args.explain_logical:
+            print_duckdb_plan(
+                q_id=q_id,
+                sql=sql,
+                dataset_path=run_config.dataset_path,
+                suffix=run_config.suffix,
+                query_set=duckdb_queries_cls.name,
+                args=args,
+            )
+
+        print(f"DuckDB Executing: {q_id}")
+        records[q_id] = []
+
+        for i in range(args.iterations):
+            t0 = time.time()
+            result = execute_duckdb_query(
+                sql,
+                run_config.dataset_path,
+                suffix=run_config.suffix,
+                query_set=duckdb_queries_cls.name,
+            )
+            t1 = time.time()
+            record = Record(query=q_id, iteration=i, duration=t1 - t0)
+            if args.print_results:
+                print(result)
+            print(f"Query {q_id} - Iteration {i} finished in {record.duration:0.4f}s")
+            records[q_id].append(record)
+
+    run_config = dataclasses.replace(run_config, records=dict(records))
+    if args.summarize:
+        run_config.summarize()
+
+
+def run_validate(
+    polars_queries_cls: Any,
+    duckdb_queries_cls: Any,
+    options: Sequence[str] | None = None,
+    *,
+    num_queries: int,
+    check_dtypes: bool,
+    check_column_order: bool,
+) -> None:
+    """Validate Polars CPU/GPU vs DuckDB."""
+    from polars.testing import assert_frame_equal
+
+    args = parse_args(options, num_queries=num_queries)
+    vars(args).update({"query_set": polars_queries_cls.name})
+    run_config = RunConfig.from_args(args)
+
+    baseline = args.baseline
+    if baseline not in {"duckdb", "cpu"}:
+        raise ValueError("Baseline must be one of: 'duckdb', 'cpu'")
+
+    failures: list[int] = []
+
+    engine: pl.GPUEngine | None = None
+    if run_config.executor != "cpu":
+        engine = pl.GPUEngine(
+            raise_on_fail=True,
+            executor=run_config.executor,
+            executor_options=get_executor_options(run_config, polars_queries_cls),
+        )
+
+    for q_id in run_config.queries:
+        print(f"\nValidating Query {q_id}")
+        try:
+            get_pl = getattr(polars_queries_cls, f"q{q_id}")
+            get_ddb = getattr(duckdb_queries_cls, f"q{q_id}")
+        except AttributeError as err:
+            raise NotImplementedError(f"Query {q_id} not implemented.") from err
+
+        polars_query = get_pl(run_config)
+        if baseline == "duckdb":
+            base_sql = get_ddb(run_config)
+            base_result = execute_duckdb_query(
+                base_sql,
+                run_config.dataset_path,
+                query_set=duckdb_queries_cls.name,
+            )
+        else:
+            base_result = polars_query.collect(engine="streaming")
+
+        if run_config.executor == "cpu":
+            test_result = polars_query.collect(engine="streaming")
+        else:
+            try:
+                test_result = polars_query.collect(engine=engine)
+            except Exception as e:
+                failures.append(q_id)
+                print(f"❌ Query {q_id} failed validation: GPU execution failed.\n{e}")
+                continue
+
+        try:
+            assert_frame_equal(
+                base_result,
+                test_result,
+                check_dtypes=check_dtypes,
+                check_column_order=check_column_order,
+            )
+            print(f"✅ Query {q_id} passed validation.")
+        except AssertionError as e:
+            failures.append(q_id)
+            print(f"❌ Query {q_id} failed validation:\n{e}")
+            if args.print_results:
+                print("Baseline Result:\n", base_result)
+                print("Test Result:\n", test_result)
+
+    if failures:
+        print("\nValidation Summary:")
+        print("===================")
+        print(f"{len(failures)} query(s) failed: {failures}")
+    else:
+        print("\nAll queries passed validation.")